In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
import sys

from Bio import SeqIO
from datetime import datetime

from torch.utils import data

from sklearn.manifold import TSNE

# Our code
from data_generator import Dataset
from lstm import LSTM_model
from lstm import LSTMCell
from alpha_set import alpha_set
from print_seq import print_seq

In [None]:
# ---------------------------------------------------------------------------
# Configuration: everything a reader might want to tune lives in this cell.
# ---------------------------------------------------------------------------

# Alphabet handed to the Dataset and used to size the model's vocabulary.
# NOTE(review): "-" is presumably the padding/gap symbol - confirm in data_generator.
#acids = "abcdefghijklmnopqrstuvwxyz-"
acids = "ACDEFGHIKLMNOPQRSTUVWY-"
large_file = "uniref50.fasta"
small_file = "100k_rows.fasta"
test_file = "test.fasta"

# Maximum sequence length passed to the Dataset and the model (50 and 2000 were also tried).
max_seq_len = 500

# Per-batch training losses; appended to by the training cell and plotted afterwards.
loss_list = []

# Good sizes: 64/512 on laptop
# 32/1500 on desktop
batch_size = 64
hidden_dim = 512

embed_size = 30
hidden_layers = 1

# Seed the RNGs so shuffling, weight initialisation and training are repeatable
# after "Restart Kernel -> Run All".
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Use Cuda if available; flip force_cpu to True to run on the CPU regardless.
force_cpu = False
use_cuda = torch.cuda.is_available() and not force_cpu
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
# NOTE: cudnn autotuning may pick different kernels between runs, so GPU results
# can still differ slightly despite the seed above.
torch.backends.cudnn.benchmark = True


# Initialising data generator
dataset = Dataset(small_file, max_seq_len, output_type="embed", acids=acids)
#dataset = alpha_set(acids, max_seq_len, 3200)
base_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Initialising for training
model = LSTM_model(len(acids), embed_size, hidden_dim, hidden_layers, max_seq_len, batch_size, processor, dropout=0).to(processor)
loss_function = nn.CrossEntropyLoss(reduction="mean").to(processor)
#optimiser = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)

# NOTE(review): an earlier note here recorded 2e-2 as the best learning rate so far,
# but the configured value is 8e-4 - confirm which one is intended.
optimiser = optim.Adam(model.parameters(), lr=8e-4)#, weight_decay=0.001)

In [None]:
# Initialising some variables for use in training
batches = 500 #float("inf")   # upper bound on batches processed per epoch
time_diff = 0                 # duration of the previous batch (int 0 until one has run)
no_improv = 0                 # batches since the lowest batch loss was seen
min_loss = float("inf")
epochs = 30
print_stuff = False           # set True for shape / grad_fn / loss diagnostics

# Number of batches actually processed per epoch: the DataLoader length, capped by `batches`.
batches_per_epoch = min(batches, len(base_generator))

# Main training loop
for epoch in range(epochs):
    # Nothing inside the loop switches to eval mode, so once per epoch is enough.
    model.train()

    for i, (batch, labels, valid_elems) in enumerate(base_generator):

        # Keeping track of stuff
        start_time = datetime.now()

        # Remaining batches in this epoch plus all batches of the remaining epochs,
        # each assumed to take as long as the previous batch did.
        remaining_batches = (batches_per_epoch - i) + batches_per_epoch * (epochs - (epoch + 1))
        est_time_left = str(time_diff * remaining_batches).split(".")[0]
        sys.stdout.write("\rEpoch: {0}. Batch: {1}. Min loss: {2:.5f}. Estimated time left: {3}. Best: {4} batches ago.".format(epoch+1, i+1, min_loss, est_time_left, no_improv))

        # pack_padded_sequence requires its lengths on the CPU, so keep a CPU copy
        # before the tensors are moved to the training device.
        lengths_cpu = valid_elems.cpu()

        # Putting data on gpu
        batch = batch.to(processor)
        labels = labels.to(processor)
        valid_elems = valid_elems.to(processor)

        # Transposing from (batch x seq x feature_size) to (seq x batch x feature_size)
        batch = batch.transpose(0, 1)

        # Resetting gradients
        model.zero_grad()

        verbose = print_stuff and i == 0
        if verbose:
            print("\nInput:\t", batch.size())
            print("Labels:\t", labels.size())
            print("Valid:\t", valid_elems)

        # Packing drops the padded positions so that they do not contribute to the loss.
        # Labels are batch-first, the model output is sequence-first; with identical
        # lengths both packings order their elements the same way.
        labels = rnn.pack_padded_sequence(labels, lengths_cpu, enforce_sorted=False, batch_first=True)

        out, hidden = model(batch, valid_elems)
        if verbose:
            print("\nBatch:", batch.grad_fn)
            print("Out:", out.grad_fn)
            print("Hidden:", hidden.grad_fn)

        out = rnn.pack_padded_sequence(out, lengths_cpu, enforce_sorted=False)
        if verbose:
            print("Output:\t", out.data.size())
            print("Hidden:\t", hidden.data.size())

        loss = loss_function(out.data, labels.data)
        if print_stuff:
            print(loss, "\n")

        loss.backward()
        optimiser.step()

        # Read the scalar once; every .item() call forces a device synchronisation.
        loss_value = loss.item()
        loss_list.append(loss_value)

        # NOTE: "best" is judged on the loss of a single training batch, so this
        # checkpoint follows a noisy signal rather than a validation metric.
        if loss_value < min_loss:
            torch.save(model.state_dict(), "temp_best_model.pth")
            min_loss = loss_value
            no_improv = 0
        else:
            no_improv += 1

        # For tracking progress
        end_time = datetime.now()
        time_diff = end_time - start_time

        # Breaking when it's run through the given number of batches
        if i+1 >= batches:
            break

torch.save(model.state_dict(), "temp_model.pth")

In [None]:
# Draw the per-batch training loss recorded in loss_list and write it to disk.
plt.plot(loss_list)
loss_axes = plt.gca()
loss_axes.set_title("Loss plotted through each batch")
loss_axes.set_ylabel("Loss")
loss_axes.set_xlabel("Batch number")
# Switch either axis to a log scale here if the curve flattens out:
#loss_axes.set_yscale("log")
#loss_axes.set_xscale("log")
plt.gcf().savefig("loss_log.png")

In [None]:
# Predicting: run the model on one batch, report per-residue accuracy over the
# non-padded positions and print the first sequence of that batch.
accuracies = []
load_model = True
with torch.no_grad():
    # Loading saved model from file; map_location lets a checkpoint written on a
    # GPU machine be loaded when only the CPU is available.
    if load_model:
        model.load_state_dict(torch.load("best_lstm.pth", map_location=processor))
    model.eval()

    # Only a single batch is evaluated.
    batch, labels, valid_elems = next(iter(base_generator))

    # CPU copy of the raw input, kept for printing below.
    test = batch
    batch = batch.to(processor)
    labels = labels.to(processor)

    # Transposing from (batch x seq x feature_size) to (seq x batch x feature_size)
    batch = batch.transpose(0,1)

    out, hidden = model(batch, valid_elems)

    # Back to batch-first so that out[j] is the prediction for sample j.
    out = out.transpose(0,1)

    print(out.size())
    print(valid_elems.size())

    # Plain Python ints for slicing; iterating over the lengths (rather than
    # range(batch_size)) also handles a batch smaller than batch_size.
    lengths = valid_elems.tolist()

    correct = 0
    for j, length in enumerate(lengths):
        preds = torch.argmax(out[j], dim=1)[:length]
        actual = labels[j][:length]
        correct += (preds == actual).sum().item()
    accuracy = correct / sum(lengths)

print(labels.size())
print("Test Accuracy: {0:.3f}%".format(accuracy*100))

print("\nPredictions")
print_seq(out[0].view(1,out.size()[1], out.size()[2]), valid_elems, acids)

print("\nInput")
print_test = "".join(acids[idx] for idx in test[0][:lengths[0]].tolist())
print(print_test)

print_labels = "".join(acids[idx] for idx in labels[0][:lengths[0]].tolist())

print("\nLabels")
print(print_labels)


## t-SNE and Visualisation

In [None]:
# Embed each sequence of the SCOPe file with the trained model, project the
# layer-averaged hidden states to 2D with t-SNE and colour the points by class.
scope_file = "scope_data_40.fasta"
scope_dataset = Dataset(scope_file, max_seq_len, output_type="embed", acids=acids, get_prot_class=True)
scope_generator = data.DataLoader(scope_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
load_model = True

model_file = "best_lstm.pth"

max_tsne_batches = 64   # number of batches embedded before stopping
excluded_class = 'd'    # class label left out of the scatter plot
tsne_seed = 0           # fixes the t-SNE initialisation so the plot is repeatable

with torch.no_grad():
    # Loading saved model from file; map_location lets a checkpoint written on a
    # GPU machine be loaded when only the CPU is available.
    if load_model:
        model.load_state_dict(torch.load(model_file, map_location=processor))
    model.eval()

    hidden_chunks = []
    full_labels = []

    for i, (batch, labels, valid_elems, prot_label) in enumerate(scope_generator):
        batch = batch.to(processor)

        # Transposing from (batch x seq x feature_size) to (seq x batch x feature_size)
        batch = batch.transpose(0,1)

        out, hidden = model(batch, valid_elems)

        # Average over the first axis of the hidden state.
        # NOTE(review): assumes hidden is (layers x batch x hidden_dim) - confirm in lstm.py.
        hidden_chunks.append(torch.mean(hidden, dim=0))
        full_labels.extend(prot_label)

        if i + 1 >= max_tsne_batches:
            break

    # Concatenating once at the end avoids re-copying the growing tensor every batch.
    full_hidden = torch.cat(hidden_chunks, 0)

print("\nFull Size:", full_hidden.size())
print("Starting TSNE")
t_sne = TSNE(n_components=2, perplexity=15, learning_rate=200, random_state=tsne_seed).fit_transform(full_hidden.cpu().numpy())


# NOTE(review): assumes the class labels are strings (they are compared against 'd').
label_array = np.asarray(full_labels)
fig, ax = plt.subplots(1, figsize=(12, 8))
for unique in np.unique(label_array):
    # Skip the excluded class entirely so that it leaves no empty legend entry.
    if unique == excluded_class:
        continue
    class_points = t_sne[label_array == unique]
    ax.scatter(class_points[:,0], class_points[:,1], label=unique, marker='.')


ax.legend()
fig.savefig("tsne_base_plot.png")