In [None]:
import json
import torch
import sklearn
import numpy as np
import torch.nn as nn

from torch.utils import data

# Our code
from data_generator import Dataset
from lstm import LSTM_model

In [None]:
# Locations of the stability dataset splits (JSON), relative to the
# notebook's working directory.
train_file, valid_file, test_file = (
    "stab_data/stability_train.json",
    "stab_data/stability_valid.json",
    "stab_data/stability_test.json",
)

In [None]:
class stab_dataset(Dataset):
    """Protein sequences paired with stability scores, loaded from a JSON file.

    The file is read as a JSON list of records; each record must provide a
    ``"primary"`` entry (the sequence string) and a ``"stability_score"``
    entry. Only sequences accepted by the inherited ``__is_legal_seq__`` are
    kept.

    NOTE(review): the parent ``Dataset.__init__`` is deliberately not called;
    the attributes below are set by hand instead. Confirm against
    ``data_generator.Dataset`` that this is the full set its methods rely on.
    """

    def __init__(self, filename, max_seq_len, output_type="onehot", acids="ACDEFGHIKLMNPQRSTVWY-"):
        """Load all legal sequences and their stability scores into memory.

        Args:
            filename: Path of the JSON file to read.
            max_seq_len: Stored as ``self.max_seq_len``; presumably the length
                limit used by the parent class -- verify there.
            output_type: Stored as ``self.output_type``; presumably selects the
                encoding produced by the parent class -- verify there.
            acids: Alphabet string passed to ``__gen_acid_dict__``.
        """
        self.acids = acids
        self.output_type = output_type
        self.acid_dict, self.int_acid_dict = self.__gen_acid_dict__(acids)
        self.max_seq_len = max_seq_len
        # presumably makes the parent class return the labels as well -- TODO confirm
        self.get_prot_class = True

        # Load the entire input file into memory; the context manager makes
        # sure the file handle is closed again once parsing is done.
        with open(filename) as handle:
            records = json.load(handle)

        elem_list = []
        label_list = []
        for elem in records:
            seq = elem["primary"].upper()
            if self.__is_legal_seq__(seq):
                elem_list.append(seq)
                label_list.append(elem["stability_score"])

        self.data = elem_list
        self.prot_labels = label_list

class lin_reg(nn.Module):
    """Single linear layer used as a regression head.

    Args:
        input_size: Number of input features per sample.
        output_size: Number of regression targets per sample.
    """

    def __init__(self, input_size, output_size):
        super(lin_reg, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the linear projection of ``x`` as a floating-point tensor.

        The output must stay floating point: casting it to an integer type
        would truncate the predicted values and detach them from the autograd
        graph, so no gradient could ever reach ``self.linear``.
        """
        return self.linear(x)

# Use Cuda if available
use_cuda = torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# Hyper-parameters handed to the LSTM_model constructor below; they have to
# match the architecture stored in the checkpoint for load_state_dict to succeed.
max_seq_len = 500
batch_size = 64
hidden_dim = 512
embed_size = 30
hidden_layers = 1

scope_file = "scope_data_40.fasta"
acids = "ACDEFGHIKLMNOPQRSTUVWY-"

# Use the shared max_seq_len constant so the datasets and the model cannot
# drift apart when the limit is changed.
stab_train = stab_dataset(train_file, max_seq_len, output_type="embed", acids=acids)
stab_valid = stab_dataset(valid_file, max_seq_len, output_type="embed", acids=acids)
stab_test = stab_dataset(test_file, max_seq_len, output_type="embed", acids=acids)

train_generator = data.DataLoader(stab_train, batch_size=batch_size, shuffle=True, num_workers=8)
valid_generator = data.DataLoader(stab_valid, batch_size=batch_size, shuffle=True, num_workers=8)
test_generator = data.DataLoader(stab_test, batch_size=batch_size, shuffle=True, num_workers=8)

model = LSTM_model(len(acids), embed_size, hidden_dim, hidden_layers, max_seq_len, batch_size, processor, dropout=0).to(processor)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint that was saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("best_lstm.pth", map_location=processor))

# Regression head mapping a hidden_dim-sized vector to a single value.
top_model = lin_reg(hidden_dim, 1).to(processor)

In [None]:
# Run a single optimisation step of the regression head on one training batch.
# Only top_model's parameters are handed to the optimiser, so the LSTM weights
# are never updated here.
for param in top_model.parameters():
    param.requires_grad = True
loss_function = nn.MSELoss()
optimiser = torch.optim.SGD(top_model.parameters(), lr=0.001)

for i, (batch, labels, valid_elems, scores) in enumerate(train_generator):
    model.eval()

    batch = batch.to(processor)
    labels = labels.to(processor)
    scores = scores[0].to(processor)

    # Transposing from (batch x seq x feature_size) to (seq x batch x feature_size)
    batch = batch.transpose(0, 1)

    # Debug output: shows which tensors are attached to the autograd graph.
    print("Batch:", batch.grad_fn)
    out, hidden = model(batch, valid_elems)
    print("Out:", out.grad_fn)
    print("Hidden:", hidden.grad_fn)
    # Average over the first axis of hidden (presumably the layer axis -- TODO confirm).
    reduced_hidden = torch.mean(hidden, dim=0)
    print("Mean:", reduced_hidden.grad_fn)

    top_model.zero_grad()
    top_model.train()
    pred = top_model(reduced_hidden)
    # The loss must stay attached to the autograd graph: calling backward() on
    # a detached tensor raises and no gradient would reach top_model.
    # Targets are cast to the prediction dtype because MSELoss needs matching
    # dtypes (scores are presumably float64 after collation -- TODO confirm).
    target = scores.view(-1, 1).to(pred.dtype)
    loss = loss_function(pred, target)
    print("Loss:", loss)
    loss.backward()
    optimiser.step()

    # Deliberately stop after the first batch (single-step smoke test).
    break