In [108]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from os import listdir
from os.path import isfile, join
from dataloader import create_data_loaders, collate_fn, CustomDataset

In [114]:
# create RNN
class RNN(nn.Module):
    """Single-layer LSTM followed by a linear layer that scores every timestep.

    Args:
        input_size: Number of features per timestep fed to the LSTM.
        hidden_size: Size of the LSTM hidden state.
        num_classes: Number of output scores produced per timestep.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        # rnn (batch_first=True -> inputs are (batch, seq, input_size))
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            batch_first = True)

        # define fully connected layers
        self.linear = nn.Linear(hidden_size, num_classes, bias = False)
        # Not applied inside forward(): nn.CrossEntropyLoss already applies
        # log-softmax, so feeding it probabilities would squash the gradients.
        # Kept as a submodule so probabilities can be obtained explicitly with
        # `net.softmax(net(x))`.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return unnormalised class scores (logits) for every timestep.

        Args:
            x: Tensor of shape (batch, seq, input_size).

        Returns:
            Tensor of shape (batch * seq, num_classes) holding raw logits.
        """
        # Only the per-timestep outputs are needed; the final (h, c) state is unused.
        x, _ = self.lstm(x)

        # Flatten batch and time so the linear layer scores each timestep.
        # reshape() also copes with non-contiguous tensors, unlike view().
        x = x.reshape(-1, self.lstm.hidden_size)

        return self.linear(x)

# Build the model: 512 features per timestep, 200 hidden units, 6 output classes.
net = RNN(input_size=512, hidden_size=200, num_classes=6)
print(net)

RNN(
  (lstm): LSTM(512, 200, batch_first=True)
  (linear): Linear(in_features=200, out_features=6, bias=False)
  (softmax): Softmax(dim=1)
)


In [85]:
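# have a look at 4_conv ex1 for inspiration
# NOTE: this whole cell is commented out, yet the training cell further down reads
# x_train / targets_train, which only the last lines of this cell define. On a fresh
# kernel (Restart & Run All) those names will not exist; the output shown beneath
# comes from an earlier run of this code.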
# def one_hot_encode_labels(labels, num_classes, encoding):
#     labels = np.array([*labels])

#     one_hot = np.zeros((labels.size, num_classes))
    
#     for idx, label in enumerate(labels):
#         one_hot[idx, encoding[label]] = 1
    
#     return one_hot

# def load_data(split, protein, encoding):
#     if split == "train":
#         data_train = []
        
#         for cv in ["cv0", "cv1", "cv2"]:
#             path = f"encoder_proteins/{cv}"
            
#             onlyfiles = [f for f in listdir(path) if isfile(join(path, f))]

#             load_split = np.load(path, allow_pickle = True).item()
    
#     data = load_split["data"]
#     labels = load_split["labels"]
    
#     num_classes = 6
#     labels = one_hot_encode_labels(labels, num_classes, encoding)
    
#     return torch.from_numpy(data), torch.from_numpy(labels)

# encoding = {'I': 0, 'O':1 , 'P': 2, 'S': 3, 'M':4, 'B': 5}

# x_train, targets_train = load_data("cv0", "Q55210.npy", encoding)
# x_train, targets_train

(tensor([[-0.0085,  0.2434, -0.0823,  ...,  0.4764,  0.0955, -0.5988],
         [-1.2657, -0.4229, -0.6991,  ...,  0.2343, -0.0476, -0.1440],
         [-1.2906, -0.4980, -0.7681,  ...,  0.1442, -0.0278,  0.1105],
         ...,
         [ 0.4188,  0.3088,  0.5084,  ...,  0.4166,  0.1377,  0.3491],
         [ 1.0614,  0.7385,  0.5483,  ...,  0.1471,  0.4259,  0.1575],
         [ 0.4986,  0.5775, -0.1428,  ...,  0.2616,  0.0052,  0.1141]]),
 tensor([[0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0., 0.],
         ...,
         [0., 1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0.]], dtype=torch.float64))

In [98]:
# Training cell. Requires `net`, `x_train` and `targets_train` to already exist
# in the kernel.
epochs = 1
# CrossEntropyLoss applies log-softmax itself and, with integer targets,
# expects one class index per row of the model output.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = 1e-3)

training_loss, validation_loss = [], []
batch_size = 1
seq_length = x_train.shape[0]
feat_length = x_train.shape[1]

train_set = TensorDataset(x_train, targets_train)
# Each dataset row is one position of the sequence and the whole sequence is
# fetched as a single batch. shuffle must stay False: shuffling would permute
# the positions before they are reshaped into one LSTM sequence below.
train_loader = DataLoader(train_set, batch_size = x_train.shape[0], shuffle = False, drop_last = False)

for i in range(epochs):
    epoch_training_loss = 0
    epoch_validation_loss = 0  # no validation pass is run in this cell

    net.train()
    for inputs, targets in train_loader:
        # (seq, feat) -> (batch=1, seq, feat), the layout the batch_first LSTM expects.
        inputs = inputs.reshape(batch_size, seq_length, feat_length)

        # One-hot rows -> class indices; avoids the float64/float32 mismatch
        # and works on every PyTorch version. 1-D targets pass through as-is.
        if targets.dim() > 1:
            targets = targets.argmax(dim = 1)

        output = net(inputs)

        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update loss (item() gives a plain Python float)
        epoch_training_loss += loss.item()

    # criterion already averages within a batch, so average over batches,
    # not over the number of dataset rows.
    training_loss.append(epoch_training_loss / len(train_loader))

    if i % 10 == 0:
        print(f'Epoch {i}, training loss: {training_loss[-1]}')


ValueError: Expected input batch_size (340) to match target batch_size (170).

In [None]:
# Specify the input size (L * 512) and number of classes (6)
# NOTE(review): `L` is not defined in any cell visible here — presumably the
# sequence length; confirm it is assigned before this cell runs.
input_size = L * 512
num_classes = 6
