In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
import torch.optim as optim


import numpy as np
import matplotlib.pyplot as plt
import sys

from Bio import SeqIO
from datetime import datetime

from torch.utils import data
#from data_generator import data_generator
from data_generator import Dataset
from lstm import LSTM_model
from lstm import LSTMCell
from alpha_set import alpha_set
from print_seq import print_seq

In [2]:
def gen_alphabet(mod_val, alphabet="abcdefghijklmnopqrstuvwxyz"):
    """Return ``alphabet`` rotated left by ``mod_val`` positions.

    The rotation wraps around: ``mod_val`` may be larger than the alphabet
    length or negative (``%`` with a positive divisor always gives a
    non-negative index in Python).

    Args:
        mod_val: Number of positions to rotate left.
        alphabet: Symbols to rotate. Defaults to the lowercase Latin alphabet,
            which reproduces the original hard-coded behaviour.

    Returns:
        The rotated string, e.g. ``gen_alphabet(2) == "cdefghijklmnopqrstuvwxyzab"``.

    Raises:
        ValueError: If ``alphabet`` is empty (the rotation is undefined).
    """
    if not alphabet:
        raise ValueError("alphabet must be non-empty")
    # Derive the modulus from the alphabet itself instead of a magic 26.
    index = mod_val % len(alphabet)
    return alphabet[index:] + alphabet[:index]

# Sanity check: list the rotations for offsets 0..31; offsets past the
# alphabet length wrap back to the start.
test = list(map(gen_alphabet, range(32)))
print(test)

['abcdefghijklmnopqrstuvwxyz', 'bcdefghijklmnopqrstuvwxyza', 'cdefghijklmnopqrstuvwxyzab', 'defghijklmnopqrstuvwxyzabc', 'efghijklmnopqrstuvwxyzabcd', 'fghijklmnopqrstuvwxyzabcde', 'ghijklmnopqrstuvwxyzabcdef', 'hijklmnopqrstuvwxyzabcdefg', 'ijklmnopqrstuvwxyzabcdefgh', 'jklmnopqrstuvwxyzabcdefghi', 'klmnopqrstuvwxyzabcdefghij', 'lmnopqrstuvwxyzabcdefghijk', 'mnopqrstuvwxyzabcdefghijkl', 'nopqrstuvwxyzabcdefghijklm', 'opqrstuvwxyzabcdefghijklmn', 'pqrstuvwxyzabcdefghijklmno', 'qrstuvwxyzabcdefghijklmnop', 'rstuvwxyzabcdefghijklmnopq', 'stuvwxyzabcdefghijklmnopqr', 'tuvwxyzabcdefghijklmnopqrs', 'uvwxyzabcdefghijklmnopqrst', 'vwxyzabcdefghijklmnopqrstu', 'wxyzabcdefghijklmnopqrstuv', 'xyzabcdefghijklmnopqrstuvw', 'yzabcdefghijklmnopqrstuvwx', 'zabcdefghijklmnopqrstuvwxy', 'abcdefghijklmnopqrstuvwxyz', 'bcdefghijklmnopqrstuvwxyza', 'cdefghijklmnopqrstuvwxyzab', 'defghijklmnopqrstuvwxyzabc', 'efghijklmnopqrstuvwxyzabcd', 'fghijklmnopqrstuvwxyzabcde']


In [3]:
# Train the LSTM for one pass over the synthetic alphabet dataset and report
# the smallest per-batch (summed) loss seen.
torch.manual_seed(0)  # reproducible weight init and DataLoader shuffling

min_loss = float("inf")
# Use CUDA if available.
use_cuda = torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

batch_size = 32

max_seq_len = 2000
alphabet = "abcdefghijklmnopqrstuvwxyz-"
alpha_dataset = alpha_set(alphabet, max_seq_len, 3200)
alpha_generator = data.DataLoader(alpha_dataset, batch_size=batch_size, shuffle=True)
# reduction="sum": the loss is summed over every unpadded time step in a batch.
loss_function = nn.CrossEntropyLoss(reduction="sum").to(processor)
lstm = LSTM_model(len(alphabet), 400, 1, max_seq_len, batch_size, processor).to(processor)
optimiser = optim.SGD(lstm.parameters(), lr=1e-3, momentum=0.9, nesterov=True)

lstm.train()
for i, (seq, label, valid) in enumerate(alpha_generator):
    # pack_padded_sequence requires `lengths` to live on the CPU (recent
    # PyTorch raises a RuntimeError for CUDA lengths), so keep a CPU copy
    # for packing instead of moving `valid` to the GPU.
    lengths = valid.cpu()
    # The DataLoader yields batch-first input; the model is fed time-first.
    seq = seq.to(processor).transpose(0, 1)
    label = label.to(processor)
    if i == 0:
        print("Input:\t", seq.size())
        print("Labels:\t", label.size())
        print("Valid:\t", valid)

    lstm.zero_grad()

    seq = rnn.pack_padded_sequence(seq, lengths, enforce_sorted=False)
    # Labels stay batch-first; packing with the same lengths yields the same
    # element ordering in `.data` as the time-first packed output below.
    label = rnn.pack_padded_sequence(label, lengths, enforce_sorted=False, batch_first=True)

    out, hidden = lstm(seq)

    # Pack the (time-first) output so padded positions are dropped and
    # out.data lines up element-for-element with label.data.
    out = rnn.pack_padded_sequence(out, lengths, enforce_sorted=False)
    if i == 0:
        print("Output:\t", out.data.size())
        print("Hidden:\t", hidden.data.size())

    loss = loss_function(out.data, label.data)
    min_loss = min(min_loss, loss.item())
    loss.backward()
    optimiser.step()
print(min_loss)
print("Finished Training")

Using GPU: True


Input:	 torch.Size([2000, 32, 27])
Labels:	 torch.Size([32, 2000])
Valid:	 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26],
       device='cuda:0')
Output:	 torch.Size([832, 27])
Hidden:	 torch.Size([832, 400])


126.07563018798828
Finished Training


In [12]:
# Evaluate next-symbol accuracy on a single batch, then print one predicted
# sequence alongside its input.
# NOTE(review): batches are drawn from `alpha_generator`, the same loader used
# for training, so this "test" accuracy is measured on the training data source.
lstm.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for seq, label, valid in alpha_generator:
        test = seq  # CPU, batch-first input kept for print_seq below
        # pack_padded_sequence needs CPU lengths; `valid` itself goes to the
        # device because it is used for indexing/summing below.
        lengths = valid.cpu()
        seq = seq.to(processor).transpose(0, 1)  # batch-first -> time-first
        label = label.to(processor)
        valid = valid.to(processor)

        seq = rnn.pack_padded_sequence(seq, lengths, enforce_sorted=False)

        out, hidden = lstm(seq)

        out = out.transpose(0, 1)  # time-first -> batch-first

        preds = torch.argmax(out, dim=2)
        # Iterate over the actual batch size so a short final batch cannot
        # index past the end; compare only the unpadded prefix of each row.
        for j in range(out.size(0)):
            n = valid[j]
            correct += (preds[j, :n] == label[j, :n]).sum().item()
        total = torch.sum(valid).item()
        break

# Guard against an empty loader so `accuracy` is always defined.
accuracy = correct / total if total else float("nan")
print("Test Accuracy: {0:.3f}%".format(accuracy*100))

print(out.size())
print(test.size())
print(valid.size())
print_seq(out[0].view(1, max_seq_len, len(alphabet)), valid, alphabet)
print_seq(test[0].view(1, max_seq_len, len(alphabet)), valid, alphabet)

Test Accuracy: 95.913%
torch.Size([32, 2000, 27])
torch.Size([32, 2000, 27])
torch.Size([32])
Sequence 0
klmnopqrstuvwxyzabcdefghij
Sequence 0
jklmnopqrstuvwxyzabcdefghi
