In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

import numpy as np
import matplotlib.pyplot as plt
import sys

from Bio import SeqIO
from datetime import datetime

from torch.utils import data
#from data_generator import data_generator
from data_generator import Dataset
from lstm import LSTM_model
from lstm import LSTMCell
from time import sleep

import gc
#from google.colab import drive

In [None]:
# Debug aid: number of objects currently tracked by the garbage collector
print(len(gc.get_objects()))
# Symbol alphabet; its length is passed to the model as the first constructor argument.
# NOTE(review): "-" is presumably a padding/gap symbol — confirm in data_generator.
acids = "ACDEFGHIKLMNOPQRSTUVWY-"
# Candidate FASTA inputs; only small_file is referenced by the visible code
large_file = "uniref50.fasta"
small_file = "100k_rows.fasta"
test_file = "test.fasta"

# Handed to both the model and the Dataset.
# NOTE(review): presumably the length sequences are padded/truncated to — confirm in data_generator.
max_seq_len = 2000

# Good sizes: 16/700 or 32/400 on laptop
# 32/1500 on desktop
batch_size = 32
hidden_dim = 200

hidden_layers = 1

# Use Cuda if available.
# NOTE(review): the trailing `and True` looks like a manual switch (set to False to force CPU).
use_cuda = torch.cuda.is_available() and True
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
# Let cuDNN auto-tune its kernels
torch.backends.cudnn.benchmark = True


# Initialising for training: model, loss and optimiser (model and loss are moved to `processor`)
model = LSTMCell(len(acids), hidden_dim, hidden_layers, max_seq_len).to(processor)
loss_function = nn.CrossEntropyLoss(reduction="mean").to(processor)
optimiser = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)

# Initialising data generator: shuffled mini-batches of `batch_size` items read from small_file
dataset = Dataset(small_file, max_seq_len, acids=acids)
base_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=16)
#print(dataset.__len__())

In [None]:
# Training cell: runs at most `batches` mini-batches from `base_generator`,
# recording the per-batch loss in `loss_list` and checkpointing the weights
# that produced the lowest batch loss to "model.pth".
batches = 500
time_diff = 0            # duration of the previous batch (0 until one has finished)
no_improv = 0            # batches processed since `min_loss` last improved
min_loss = float("inf")  # kept as a plain float so no autograd graph is retained
loss_list = []

# Main training loop
for i, (batch, labels, valid_elems) in enumerate(base_generator):

    # Keeping track of progress on a single, carriage-return-refreshed line
    start_time = datetime.now()
    sys.stdout.write("\rBatch: {0}. Min loss: {1:.5f}. Estimated time left: {2}. Batches since improvement: {3}.".format(i+1, min_loss, time_diff*(batches - i), no_improv))

    # Resetting gradients
    model.zero_grad()

    # Putting data on the selected device and creating a fresh hidden state
    # sized for this batch (the last batch may be smaller than batch_size)
    batch = batch.to(processor)
    labels = labels.to(processor)
    hn, cn = model.init_hidden(batch.size()[0])

    # Swapping the first two axes so that dimension 0 is the sequence position
    # and dimension 1 is the index of the sequence within the batch
    batch = torch.transpose(batch, 0, 1)
    labels = torch.transpose(labels, 0, 1)

    # Output buffer with the same shape/dtype/device as the input batch.
    # NOTE(review): this assumes the model emits one score per input feature
    # (i.e. output size == len(acids)) — confirm against lstm.LSTMCell.
    output = batch.new_empty(batch.size())

    # Filling the output tensor one sequence position at a time. The hidden
    # state is detached before every step, so gradients do not flow between
    # time steps.
    for j in range(batch.size()[0]):
        output[j], (hn, cn) = model(batch[j], (hn.detach(), cn.detach()))

    # Computing the loss only over the non-padded prefix of every sequence.
    # Sequence j is selected with index j (previously index 0 was used for
    # every j), and the batch axis is dropped so that CrossEntropyLoss sees
    # (length, classes) scores.
    # NOTE(review): labels[:length, j] is (length,) for class-index labels or
    # (length, classes) for probability labels — confirm which one
    # data_generator.Dataset yields.
    seq_losses = []
    for j in range(batch.size()[1]):
        length = int(valid_elems[j])
        if length == 0:
            # An empty sequence would give a NaN mean loss; skip it
            continue
        seq_losses.append(loss_function(output[:length, j], labels[:length, j]))

    if seq_losses:
        # Average over the sequences actually present in this batch
        loss = torch.stack(seq_losses).sum() / batch.size()[1]
        loss_value = loss.item()

        # Checkpointing happens before optimiser.step(), so the saved weights
        # are the ones that produced `loss_value`
        if loss_value < min_loss:
            torch.save(model.state_dict(), "model.pth")
            min_loss = loss_value
            no_improv = 0
        else:
            no_improv += 1

        loss_list.append(loss_value)
        loss.backward()
        optimiser.step()

    # For tracking progress
    end_time = datetime.now()
    time_diff = end_time - start_time

    # Breaking when it's run through the given number of batches
    if i+1 >= batches:
        break

In [None]:
# Draw the recorded per-batch training loss on the current axes and save the figure
loss_axes = plt.gca()
loss_axes.plot(loss_list)
loss_axes.set_title("Loss plotted through each batch")
loss_axes.set_xlabel("Batch number")
loss_axes.set_ylabel("Loss")
# For log-scaled axes, call loss_axes.set_yscale("log") / loss_axes.set_xscale("log") here
loss_axes.figure.savefig("loss_log.png")

In [None]:
# Predicting: reload the checkpointed weights and run the model over the first
# batch of the dataset, collecting the per-position outputs of every sequence
# up to its valid length.
with torch.no_grad():
    # Loading saved model from file; map_location lets a checkpoint written on
    # the GPU be loaded when only the CPU is available (and vice versa)
    model.load_state_dict(torch.load("model.pth", map_location=processor))
    model.eval()

    # Initialising data generator (unshuffled, so the same sequences come first every run)
    dataset = Dataset(small_file, max_seq_len, acids=acids)
    base_generator = data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=16)

    # A for loop is the easiest way to get an element from
    # the generator even though we only loop once
    for batch, _, valid_elems in base_generator:

        # Putting the data on the selected device
        batch = batch.to(processor)

        # Fresh hidden state sized for THIS batch. Reusing hn/cn left over from
        # the training cell would depend on hidden kernel state and carry the
        # training batch size instead of this one.
        hn, cn = model.init_hidden(batch.size()[0])
        hn = hn.to(processor)
        cn = cn.to(processor)

        # Swapping the first two axes so that dimension 0 is the sequence position
        batch = torch.transpose(batch, 0, 1)

        # One list of per-position outputs for every sequence in the batch
        output_list = [[] for _ in valid_elems]
        for j in range(batch.size()[0]):
            output, (hn, cn) = model(batch[j], (hn, cn))

            # True for the sequences that still have a valid element at position j
            output_mask = j < valid_elems
            if not output_mask.any():
                # Every sequence in the batch has ended
                break

            for seq_idx in torch.nonzero(output_mask).flatten().tolist():
                output_list[seq_idx].append(output[seq_idx])

        # One (valid_length x output_size) tensor per sequence
        output_list = [torch.stack(elem) for elem in output_list]
        break
print(output_list[0].size())

def print_preds(preds, alphabet=None):
    """Print the most likely symbol string for every predicted sequence.

    Args:
        preds: iterable of 2-D tensors, one per sequence, where row t holds
            the scores of every symbol at position t.
        alphabet: indexable collection of symbols; column k of a prediction
            is printed as alphabet[k]. Defaults to the notebook-level `acids`.

    Prints a "Sequence <n>" header followed by the decoded string for each
    element of `preds`; returns nothing.
    """
    if alphabet is None:
        # Resolved at call time so the notebook-level alphabet is picked up
        alphabet = acids
    for i, seq in enumerate(preds):
        print("Sequence {}".format(i))
        # Best-scoring column per row, converted to plain ints for indexing
        indexes = torch.argmax(seq, dim=1).tolist()
        print("".join(alphabet[x] for x in indexes))
        
# Decode and print the predicted symbol string of every sequence collected in output_list
print_preds(output_list)