In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
import torch.optim as optim

torch.manual_seed(1)

import numpy as np
import matplotlib.pyplot as plt
import sys

from Bio import SeqIO
from datetime import datetime

from torch.utils import data
#from data_generator import data_generator
from data_generator import Dataset
from lstm import LSTM_model
from lstm import LSTMCell

In [2]:
def gen_alphabet(mod_val):
    """Return the lowercase Latin alphabet rotated left by ``mod_val`` places.

    The rotation amount is taken modulo 26, so any integer (including
    negative values and values >= 26) is accepted.
    """
    letters = "abcdefghijklmnopqrstuvwxyz"
    shift = mod_val % len(letters)
    # Doubling the string lets a single slice express the wrap-around.
    return (letters * 2)[shift:shift + len(letters)]

# Sanity check: 32 consecutive rotations, so the last six wrap around and
# repeat the first six.
test = list(map(gen_alphabet, range(32)))
print(test)

['abcdefghijklmnopqrstuvwxyz', 'bcdefghijklmnopqrstuvwxyza', 'cdefghijklmnopqrstuvwxyzab', 'defghijklmnopqrstuvwxyzabc', 'efghijklmnopqrstuvwxyzabcd', 'fghijklmnopqrstuvwxyzabcde', 'ghijklmnopqrstuvwxyzabcdef', 'hijklmnopqrstuvwxyzabcdefg', 'ijklmnopqrstuvwxyzabcdefgh', 'jklmnopqrstuvwxyzabcdefghi', 'klmnopqrstuvwxyzabcdefghij', 'lmnopqrstuvwxyzabcdefghijk', 'mnopqrstuvwxyzabcdefghijkl', 'nopqrstuvwxyzabcdefghijklm', 'opqrstuvwxyzabcdefghijklmn', 'pqrstuvwxyzabcdefghijklmno', 'qrstuvwxyzabcdefghijklmnop', 'rstuvwxyzabcdefghijklmnopq', 'stuvwxyzabcdefghijklmnopqr', 'tuvwxyzabcdefghijklmnopqrs', 'uvwxyzabcdefghijklmnopqrst', 'vwxyzabcdefghijklmnopqrstu', 'wxyzabcdefghijklmnopqrstuv', 'xyzabcdefghijklmnopqrstuvw', 'yzabcdefghijklmnopqrstuvwx', 'zabcdefghijklmnopqrstuvwxy', 'abcdefghijklmnopqrstuvwxyz', 'bcdefghijklmnopqrstuvwxyza', 'cdefghijklmnopqrstuvwxyzab', 'defghijklmnopqrstuvwxyzabc', 'efghijklmnopqrstuvwxyzabcd', 'fghijklmnopqrstuvwxyzabcde']


In [3]:
class alpha_set(data.Dataset):
    """Toy next-symbol dataset built from rotated alphabets.

    Each sample is one string produced by ``gen_alphabet(i)``. A sample is
    returned as ``(inputs, labels, valid_elems)`` where ``inputs`` holds the
    one-hot encoding of every symbol except the last one of the padded
    string, ``labels`` holds the class index of the following symbol at each
    step, and ``valid_elems`` is the number of non-padding input steps.
    """

    # Symbol used to right-pad short sequences; it must be one of ``acids``,
    # otherwise the one-hot lookup raises ``KeyError``.
    PAD_CHAR = '-'

    def __gen_acid_dict__(self, acids):
        """Map every symbol of ``acids`` to a one-hot vector of length ``len(acids)``."""
        acid_dict = {}
        for i, elem in enumerate(acids):
            one_hot = torch.zeros(len(acids))
            one_hot[i] = 1
            acid_dict[elem] = one_hot
        return acid_dict

    def __init__(self, acids, length, num_seqs):
        """Build the dataset.

        Args:
            acids: iterable of symbols (here a string); position defines the
                class index of each symbol.
            length: number of time steps every returned sample has.
            num_seqs: number of sequences to generate.
        """
        self.max_seq_len = length
        self.acids = acids
        self.acid_dict = self.__gen_acid_dict__(acids)
        self.data = [gen_alphabet(i) for i in range(num_seqs)]

    def __prepare_seq__(self, seq):
        """Encode one raw sequence as ``(inputs, labels, valid_elems)``.

        Returns:
            inputs: float tensor of shape ``(max_seq_len, len(acids))``.
            labels: long tensor of shape ``(max_seq_len,)`` with the class
                index of the symbol that follows each input step.
            valid_elems: ``min(len(seq), max_seq_len)``.

        Raises:
            KeyError: if ``seq`` (or the padding) contains a symbol that is
                not in ``acids``.
        """
        valid_elems = min(len(seq), self.max_seq_len)

        # One extra symbol is kept so inputs and labels can be the same
        # window shifted by one step. Truncating *before* padding guarantees
        # a fixed sample length even for over-long sequences; ``ljust`` alone
        # only pads and would let such samples grow past ``max_seq_len``.
        window = self.max_seq_len + 1
        seq = str(seq)[:window].ljust(window, self.PAD_CHAR)

        temp_seq = [self.acid_dict[x] for x in seq]
        tensor_seq = torch.stack(temp_seq[:-1]).float()

        # Labels are class indices (not one-hot vectors), recovered from the
        # one-hot encoding of the shifted window.
        labels_seq = torch.argmax(torch.stack(temp_seq[1:]), dim=1).long()

        return tensor_seq, labels_seq, valid_elems

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the encoded sample stored at ``index``."""
        return self.__prepare_seq__(self.data[index])
    

# --- Device selection -------------------------------------------------------
# The trailing `and False` forces CPU execution even when CUDA is present.
use_cuda = torch.cuda.is_available() and False
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# --- Hyper-parameters, data and model ----------------------------------------
batch_size = 8

max_seq_len = 50
alphabet = "abcdefghijklmnopqrstuvwxyz-"
# NOTE(review): the dataset pads every sample to 64 steps while the model is
# built with max_seq_len = 50; the loop below only works because every valid
# length (26) is below both. Confirm whether the two are meant to agree.
alpha_dataset = alpha_set(alphabet, 64, 2000)
alpha_generator = data.DataLoader(alpha_dataset, batch_size=batch_size, shuffle=True)
loss_function = nn.CrossEntropyLoss(reduction="sum").to(processor)
# NOTE(review): the positional arguments presumably mean (input size, hidden
# size, layers, max length, dropout) -- verify against lstm.LSTM_model.
lstm = LSTM_model(len(alphabet), 100, 1, max_seq_len, 0).to(processor)
optimiser = optim.SGD(lstm.parameters(), lr=1e-3, momentum=0.9, nesterov=True)

# --- Training: a single pass over the data -----------------------------------
for i, (seq, label, valid) in enumerate(alpha_generator):
    seq = seq.to(processor)
    label = label.to(processor)
    # pack_padded_sequence requires the lengths tensor to live on the CPU,
    # so it must not follow the data onto the compute device.
    valid = valid.cpu()

    # The loader yields batch-first tensors; the model is fed time-major ones.
    seq = seq.transpose(0, 1)

    if i == 0:
        print("Input:\t", seq.size())
        print("Labels:\t", label.size())
        print("Valid:\t", valid)

    lstm.zero_grad()

    seq = rnn.pack_padded_sequence(seq, valid, enforce_sorted=False)

    out, hidden = lstm(seq)

    if i == 0:
        print("Output:\t", out.size())
        print("Hidden:\t")#, hidden.size())

    # Summed cross-entropy over the valid (non-padding) steps of each
    # sequence; `out` is indexed (time, batch, class), `label` (batch, time).
    loss = 0
    for j in range(out.size(1)):
        n_valid = int(valid[j])
        loss += loss_function(out[:n_valid, j], label[j, :n_valid])

    loss.backward()
    optimiser.step()

print("Finished Training")

Using GPU: False
Input:	 torch.Size([64, 8, 27])
Labels:	 torch.Size([8, 64])
Valid:	 tensor([26, 26, 26, 26, 26, 26, 26, 26])
Output:	 torch.Size([50, 8, 27])
Hidden:	


Finished Training


In [4]:
# Evaluate next-symbol accuracy of the trained model.
# NOTE(review): batches come from the same loader that was used for training,
# so the reported figure is not a held-out test accuracy.
lstm.eval()  # switching to evaluation mode once is sufficient

# Running totals are kept outside the batch loop so that evaluating more than
# one batch accumulates results instead of restarting from zero each time.
correct = 0
total = 0

with torch.no_grad():
    for i, (seq, label, valid) in enumerate(alpha_generator):
        seq = seq.to(processor)
        label = label.to(processor)
        # pack_padded_sequence requires the lengths tensor to be on the CPU.
        valid = valid.cpu()

        # Loader output is batch-first; the model is fed time-major input.
        seq = seq.transpose(0, 1)
        seq = rnn.pack_padded_sequence(seq, valid, enforce_sorted=False)

        out, hidden = lstm(seq)

        # Back to batch-first so out[j] is the prediction for sequence j.
        out = out.transpose(0, 1)

        # Iterate over the real batch size so a short final batch is safe.
        for j in range(out.size(0)):
            predicted = torch.argmax(out[j], dim=1)
            # Predictions and labels may differ in length; compare only the
            # overlapping prefix. Padding steps are included in the count.
            steps = min(predicted.size(0), label[j].size(0))
            hits = predicted[:steps] == label[j][:steps]
            print(steps)
            correct += int(hits.sum())
            total += steps

        #if i > 4:
        break

accuracy = correct / total if total else 0.0
print("Test Accuracy: {0:.3f}%".format(accuracy*100))

def print_preds(preds):
    """Print the arg-max class index at every step of each sequence in ``preds``.

    ``preds`` is iterated along its first dimension; each item is reduced
    with ``torch.argmax`` over dimension 1 and printed under a numbered
    header.
    """
    seq_no = 0
    for scores in preds:
        print("Sequence {}".format(seq_no))
        print(torch.argmax(scores, dim=1))
        seq_no += 1
        
# Show the predicted class index at every time step for each sequence of the
# batch evaluated above.
print_preds(out)

50
50
50
50
50
50
50
50
Test Accuracy: 98.000%
Sequence 0
tensor([22, 23, 24, 25,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,
        14, 15, 16, 17, 18, 19, 20, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26])
Sequence 1
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26])
Sequence 2
tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
        20, 21, 22, 23, 24, 25, 26,  1, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26])
Sequence 3
tensor([14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  0,  1,  2,  3,  4,  5,
         6,  7,  8,  9, 10, 11, 26, 13, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26])
Sequence 