In [None]:
import copy
import sys
from datetime import datetime
from time import sleep

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed torch immediately after importing it, before the remaining imports run.
torch.manual_seed(1)

import matplotlib.pyplot as plt
import numpy as np
from Bio import SeqIO
from torch.utils import data

#from data_generator import data_generator
from data_generator import Dataset
from lstm import LSTM_model
from lstm import LSTMCell

#from google.colab import drive

In [None]:
# Smoke test: build a dataset and loader from the small FASTA file and
# inspect the sizes of the first sample.
filename = "100k_rows.fasta"
batch_size = 32

# The second positional argument is passed as `max_seq_len` elsewhere in this notebook.
dataset = Dataset(filename, 2000)
base_generator = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# Print the size of element 0 and element 1 of sample 0
# (presumably the input and its labels -- confirm against data_generator.Dataset).
for position in (0, 1):
    print(dataset[0][position].size())

In [None]:
# ---- Alphabet and input files ----
acids = "ACDEFGHIKLMNOPQRSTUVWY-"
large_file = "uniref50.fasta"
small_file = "100k_rows.fasta"
test_file = "test.fasta"

# Per-batch losses; the training cells append to this list.
loss_list = []

# ---- Hyperparameters ----
max_seq_len = 2000

# Good sizes: 16/700 or 32/400 on laptop
# 32/1500 on desktop
batch_size = 16
hidden_dim = 500
hidden_layers = 1

# ---- Device selection ----
# Use Cuda if available (the trailing `and True` looks like a manual
# on/off switch -- change it to False to force CPU).
use_cuda = torch.cuda.is_available() and True
print("Using GPU:", use_cuda)
if use_cuda:
    processor = torch.device("cuda:0")
else:
    processor = torch.device("cpu")
torch.backends.cudnn.benchmark = True

# ---- Model, loss and optimiser ----
# Alternative models that were tried:
#model = LSTM_model(len(acids), hidden_dim, hidden_layers, max_seq_len).to(processor)
#model = test_LSTM(len(acids), hidden_dim, hidden_layers).to(processor)
model = LSTMCell(len(acids), hidden_dim, hidden_layers, max_seq_len).to(processor)

# Alternative loss that was tried:
#loss_function = nn.NLLLoss().to(processor)
loss_function = nn.CrossEntropyLoss().to(processor)

optimiser = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    nesterov=True,
)

# ---- Data pipeline ----
dataset = Dataset(small_file, max_seq_len, acids=acids)
base_generator = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)
print(len(dataset))

In [None]:
# Train for at most `batches` mini-batches, feeding the model one time step
# at a time and summing the per-step losses into a single loss per batch.
batches = 100
time_diff = 0          # duration of the previous batch (0 until one has finished)
no_improv = 0          # batches seen since the lowest loss so far
min_loss = float("inf")


for i, (batch, labels) in enumerate(base_generator):
    start_time = datetime.now()

    sys.stdout.write("\rBatch: {0}. Min loss: {1:.5f}. Estimated time left: {2}. Batches since improvement: {3}.".format(i+1, min_loss, time_diff*(batches - i), no_improv))

    model.zero_grad()

    batch = batch.to(processor)
    labels = labels.to(processor)
    # Fresh hidden/cell state for every batch, sized by the actual batch
    # dimension (the last batch may be smaller than `batch_size`).
    hn, cn = model.init_hidden(batch.size()[0])
    hn = hn.to(processor)
    cn = cn.to(processor)

    # Transposing from (batch x seq x feature_size) to (seq x batch x feature_size)
    batch = torch.transpose(batch, 0, 1)
    labels = torch.transpose(labels, 0, 1)

    # Accumulate the loss over every step of the sequence.
    loss = 0
    for j in range(batch.size()[0]):
        output, (hn, cn) = model(batch[j], (hn, cn))
        loss += loss_function(output, labels[j])

    loss.backward()
    optimiser.step()

    # Record a plain Python float rather than the tensor: keeping the tensor
    # would hold on to the whole autograd graph of every batch and cannot be
    # plotted directly.
    loss_value = loss.item()
    loss_list.append(loss_value)

    # Keep the statistics shown in the progress line up to date.
    if loss_value < min_loss:
        min_loss = loss_value
        no_improv = 0
    else:
        no_improv += 1

    end_time = datetime.now()
    time_diff = end_time - start_time

    if i+1 >= batches:
        break

In [None]:
# Train for 10 passes over the loader, each capped at `batches` mini-batches,
# passing whole batches to the model and remembering the best state seen.
#batches = dataset.__len__()
batches = 150
time_diff = 0          # duration of the previous batch (0 until one has finished)
no_improv = 0          # batches seen since the lowest loss so far
min_loss = float("inf")


# The pass counter has its own name so the inner batch index `i` does not shadow it.
for epoch in range(10):
    for i, (batch, labels) in enumerate(base_generator):
        start_time = datetime.now()

        sys.stdout.write("\rBatch: {0}. Min loss: {1:.5f}. Estimated time left: {2}. batches since improvement: {3}.".format(i+1, min_loss, time_diff*(batches - i), no_improv))

        model.zero_grad()

        batch = batch.to(processor)
        labels = labels.to(processor)

        # NOTE(review): this call passes no hidden state, unlike the per-step
        # training cell -- presumably written for LSTM_model; confirm that the
        # currently configured model accepts this signature.
        output, (hn, cn) = model(batch)
        loss = loss_function(output, labels)
        loss.backward()
        optimiser.step()

        # Store a plain float so the autograd graph of each batch can be freed
        # and the list can be plotted directly.
        loss_value = loss.item()
        loss_list.append(loss_value)

        if loss_value < min_loss:
            no_improv = 0
            min_loss = loss_value
            # state_dict() returns references to the live tensors, so the
            # snapshots must be deep-copied or later optimiser steps would
            # silently overwrite the "best" checkpoint. The 'model' and
            # 'optimiser' entries remain live references to the objects.
            checkpoint = {'model': model,
                          'state_dict': copy.deepcopy(model.state_dict()),
                          'optimiser' : optimiser,
                          'optim_state' : copy.deepcopy(optimiser.state_dict())}
        else:
            no_improv += 1
            #torch.save(checkpoint, 'checkpoint.pth')

        end_time = datetime.now()
        time_diff = end_time - start_time

        if i+1 >= batches:
            break

In [None]:
# Plotting the loss recorded for each batch
# Entries may be plain floats or single-element tensors (possibly on the GPU
# and attached to the autograd graph); float() handles both, whereas handing
# such tensors straight to matplotlib fails.
loss_values = [float(value) for value in loss_list]
plt.plot(loss_values)
#plt.scatter(range(len(loss_list)), loss_list)
plt.title("Loss plotted through each batch")
plt.ylabel("Loss")
plt.xlabel("Batch number")
#plt.yscale("log")
#plt.xscale("log")
plt.savefig("loss_log.png")

In [None]:
def convert(output, alphabet=None):
    """Map a score tensor to the symbol with the highest score.

    Args:
        output: Tensor of scores. The argmax is taken over the flattened
            tensor, so pass the scores of a single prediction.
        alphabet: Indexable collection of symbols (e.g. a string). When
            omitted, the notebook-level ``dataset.acids`` is used, which is
            the original behaviour.

    Returns:
        The element of ``alphabet`` at the position of the largest score.
    """
    if alphabet is None:
        # Looked up at call time so the currently bound dataset is used.
        alphabet = dataset.acids
    # Convert the 0-dim index tensor to a plain int before indexing.
    letter = int(torch.argmax(output))
    return alphabet[letter]

#def 

# Run the model over the first (unshuffled) batch without tracking gradients
# and print the symbols predicted for the first sequence.
with torch.no_grad():
    # Rebuild the dataset and use a small, ordered loader for inspection.
    dataset = Dataset(small_file, max_seq_len, acids=acids)
    base_generator = data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        num_workers=16,
    )

    for batch, _ in base_generator:
        batch = batch.to(processor)

        # Initial hidden and cell state sized by this batch's first dimension.
        hn, cn = model.init_hidden(batch.size()[0])
        hn = hn.to(processor)
        cn = cn.to(processor)

        # Transposing from (batch x seq x feature_size) to (seq x batch x feature_size)
        batch = torch.transpose(batch, 0, 1)

        # Feed one slice along the first axis at a time, carrying the state forward.
        output_list = []
        for seq_elem in batch:
            output, (hn, cn) = model(seq_elem, (hn, cn))
            output_list.append(output)
        test = torch.stack(output_list)

        # Only the first batch is needed for this preview.
        break

    print(test.size())
    print(len(output_list))
    print(output_list[0].size())

    # Decode the first six outputs for element 0 of the batch.
    for step in range(6):
        print(convert(output_list[step][0]))