In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

import numpy as np
import matplotlib.pyplot as plt
import sys

from Bio import SeqIO
from datetime import datetime

from torch.utils import data
#from data_generator import data_generator
from data_generator import Dataset
from lstm import LSTM_model
from lstm import LSTMCell
from time import sleep

In [None]:
def gen_alphabet(mod_val):
    """Return the lowercase Latin alphabet rotated left by ``mod_val`` positions.

    The shift wraps modulo 26, so ``gen_alphabet(0)`` and ``gen_alphabet(26)``
    both give ``"abc...z"``, ``gen_alphabet(1)`` gives ``"bcd...za"`` and a
    negative shift rotates right (Python's ``%`` is always non-negative here).
    """
    letters = "abcdefghijklmnopqrstuvwxyz"
    shift = mod_val % len(letters)
    head, tail = letters[:shift], letters[shift:]
    return tail + head

# Sanity check: every rotation for shifts 0..31 (shifts >= 26 wrap around to the start).
test = list(map(gen_alphabet, range(32)))
print(test)

In [None]:
class alpha_set(data.Dataset):
    """Next-symbol prediction dataset built from rotated alphabets.

    Item ``i`` is ``gen_alphabet(i)`` (the 26 lowercase letters rotated left by
    ``i % 26``), cut/right-padded with ``'-'`` and turned into an
    (input, target) pair shifted by one position.

    Args:
        acids: Symbol vocabulary; each symbol gets a one-hot vector of length
            ``len(acids)``. Must contain every symbol that appears in the data
            and the padding symbol ``'-'``.
        length: Number of time steps per item (stored as ``max_seq_len``).
        num_seqs: Number of items in the dataset.
    """

    def __gen_acid_dict__(self, acids):
        """Return a dict mapping each symbol in ``acids`` to its one-hot float vector."""
        acid_dict = {}
        for i, elem in enumerate(acids):
            one_hot = torch.zeros(len(acids))
            one_hot[i] = 1
            acid_dict[elem] = one_hot
        return acid_dict

    def __init__(self, acids, length, num_seqs):
        self.max_seq_len = length
        self.acids = acids
        self.acid_dict = self.__gen_acid_dict__(acids)
        self.data = [gen_alphabet(i) for i in range(num_seqs)]

    def __prepare_seq__(self, seq):
        """Encode ``seq`` as a fixed-length next-symbol training example.

        The sequence is truncated/padded with ``'-'`` to exactly
        ``max_seq_len + 1`` symbols. The first ``max_seq_len`` symbols become
        the inputs, and the last ``max_seq_len`` (shifted by one) become the targets.

        Returns:
            tensor_seq: float tensor of shape ``(max_seq_len, len(acids))``,
                one one-hot row per input position.
            labels_seq: long tensor of shape ``(max_seq_len,)`` holding the
                index in ``acids`` of the next symbol at each position.
            valid_elems: ``min(len(seq) + 1, max_seq_len)``, i.e. the real
                symbols plus one, capped at ``max_seq_len``.
                NOTE(review): the visible training/eval cells never use this
                value; confirm the ``+ 1`` is intended.

        Raises:
            KeyError: if the (truncated/padded) sequence contains a symbol
                missing from ``acids``, including ``'-'`` when padding is needed.
        """
        valid_elems = min(len(seq) + 1, self.max_seq_len)
        window = self.max_seq_len + 1
        # str.ljust never shortens, so truncate first. Otherwise a sequence
        # longer than max_seq_len + 1 would yield tensors longer than
        # max_seq_len and the configured length would be silently ignored.
        padded = str(seq)[:window].ljust(window, '-')
        one_hots = [self.acid_dict[symbol] for symbol in padded]
        tensor_seq = torch.stack(one_hots[:-1]).float()
        # Targets are the same window shifted by one, stored as class indices
        labels_seq = torch.argmax(torch.stack(one_hots[1:]), dim=1).long()
        return tensor_seq, labels_seq, valid_elems

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.__prepare_seq__(self.data[index])

# Train LSTM_model for one pass over the dataset on next-character prediction.

# Use CUDA if available; set USE_GPU = False to force the CPU.
USE_GPU = True
use_cuda = USE_GPU and torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

batch_size = 1      # the squeeze(0)/squeeze(1) calls below rely on a batch of one
SEQ_LEN = 64        # time steps per example; shorter sequences are padded with '-'
NUM_SEQS = 10000    # examples per pass (gen_alphabet yields only 26 distinct rotations)
LEARNING_RATE = 1e-3
LOG_EVERY = 1000    # print the mean training loss every LOG_EVERY steps

alphabet = "abcdefghijklmnopqrstuvwxyz-"
alpha_dataset = alpha_set(alphabet, SEQ_LEN, NUM_SEQS)
alpha_generator = data.DataLoader(alpha_dataset, batch_size=batch_size, shuffle=True)
# Summed (not averaged) over every time step, including '-' padding targets
loss_function = nn.CrossEntropyLoss(reduction="sum").to(processor)
# NOTE(review): LSTM_model's signature is not visible here. The last argument
# was the literal 27 (== len(alphabet)) and is presumably the number of output
# classes, so it is tied to the alphabet size. 100 and 1 are presumably the
# hidden size and layer count; confirm against lstm.LSTM_model.
lstm = LSTM_model(len(alphabet), 100, 1, len(alphabet)).to(processor)
optimiser = optim.SGD(lstm.parameters(), lr=LEARNING_RATE, momentum=0.9, nesterov=True)

running_loss = 0.0
for i, (seq, label, _) in enumerate(alpha_generator):
    seq = seq.to(processor)
    label = label.to(processor)

    # (batch, seq_len, vocab) -> (seq_len, batch, vocab); presumably the
    # sequence-first layout LSTM_model expects (TODO confirm)
    seq = seq.transpose(0, 1)
    label = label.squeeze(0)  # (1, seq_len) -> (seq_len,)
    if i == 0:
        print("Input:\t", seq.size())
        print("Labels:\t", label.size())

    optimiser.zero_grad()

    out, hidden = lstm(seq)
    # Assumes out is (seq_len, batch, classes); dropping the batch axis lines
    # it up with the (seq_len,) targets (TODO confirm)
    out = out.squeeze(1)

    if i == 0:
        print("Output:\t", out.size())
        # NOTE(review): assumes LSTM_model's second output is a single tensor
        print("Hidden:\t", hidden.size())

    loss = loss_function(out, label)
    loss.backward()
    optimiser.step()

    # Accumulate the summed loss and periodically report its per-step mean
    running_loss += loss.item()
    if (i + 1) % LOG_EVERY == 0:
        print(f"step {i + 1}: mean loss per sequence {running_loss / LOG_EVERY:.4f}")
        running_loss = 0.0

print("Finished Training")

In [None]:
# Spot-check the trained model on a few shuffled sequences: print the
# per-position argmax predictions next to the target indices.
NUM_EVAL_BATCHES = 6  # number of shuffled batches to inspect

# Switch to inference mode (matters if LSTM_model uses dropout/normalisation),
# and restore the previous mode afterwards so re-running training is unaffected.
was_training = lstm.training
lstm.eval()
try:
    with torch.no_grad():
        for i, (seq, label, _) in enumerate(alpha_generator):
            # Same layout handling as the training loop: sequence-first input,
            # (seq_len,) targets for batch_size == 1
            seq = seq.to(processor).transpose(0, 1)
            label = label.to(processor).squeeze(0)

            out, hidden = lstm(seq)
            # Assumes out is (seq_len, batch, classes) -> predictions (batch, seq_len)
            predictions = torch.argmax(out, dim=2).transpose(0, 1)

            print(out.size())
            print("Predictions:", predictions)
            print("Labels:", label)
            # Fraction of positions predicted correctly (padding positions included)
            print("Accuracy:", (predictions == label).float().mean().item())

            if i + 1 >= NUM_EVAL_BATCHES:
                break
finally:
    lstm.train(was_training)

In [None]:
# nn.LSTM can also consume a whole sequence in one call.
# The first return value holds the hidden state at every time step; the second
# is the final (h_n, c_n) pair. For this single-layer, unidirectional LSTM the
# last slice of the output equals h_n (asserted below):
#   - the output gives access to every hidden state in the sequence;
#   - the final state lets you continue the sequence later (and backpropagate
#     through it) by passing it back into the LSTM.
# All names are prefixed with `demo_` so this cell does not overwrite the
# trained `lstm` model (or `out`/`hidden`) used by the training/evaluation cells.
demo_steps = [torch.randn(1, 3) for _ in range(5)]
demo_lstm = nn.LSTM(3, 3)
# Stack the five (1, 3) steps and add the batch axis: (seq_len=5, batch=1, input_size=3)
demo_inputs = torch.cat(demo_steps).view(len(demo_steps), 1, -1)

print(demo_inputs.size())

# Random initial (h_0, c_0), each of shape (num_layers=1, batch=1, hidden_size=3)
demo_hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))

print(demo_hidden[0].size(), demo_hidden[1].size())

demo_out, demo_hidden = demo_lstm(demo_inputs, demo_hidden)
print(demo_out.size())
print(demo_out, "\n")
print(demo_hidden)

# The last time step of the output is the final hidden state h_n
assert demo_out[-1].allclose(demo_hidden[0][0])