In [None]:
on_colab = False
if on_colab:
    from google.colab import drive
    !pip install tqdm
    !pip install Biopython

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
#import torchvision.transforms.functional as TF
#from torchvision.transforms import ToTensor
#from torchvision.utils import save_image
import torch.optim as optim
import random
import re

from sklearn.manifold import TSNE

from sklearn.decomposition import PCA

from mpl_toolkits import mplot3d

import scipy.misc
from scipy import ndimage


import tqdm

from Bio import SeqIO

import os
import sys
import datetime

import math

from torch.utils import data

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from PIL import Image

In [None]:

class Dataset(data.Dataset):
    '''
    Protein-sequence dataset read from a FASTA file.

    Each item is a sequence padded to ``max_seq_len`` with the last character of
    the alphabet. The sequence is encoded either as one-hot rows
    (output_type="onehot") or as integer indices (output_type="embed"). Each item
    also carries next-position labels and the unpadded length.
    '''

    def __is_legal_seq__(self, seq):
        '''
        Checks whether a given sequence is legal.

        A sequence is legal when it meets all three conditions below:
        - it is no longer than ``max_seq_len``;
        - it contains none of the ambiguous residue codes X, B, Z, J;
        - every character is part of the alphabet. Otherwise __getitem__ could
          not encode it.

        Parameters
        ----------
        seq : String (or Bio.Seq)
            Upper-cased sequence from the fasta file

        Returns
        ----------
        Bool
        '''
        if len(seq) > self.max_seq_len:
            return False
        residues = set(str(seq))
        if residues & {'X', 'B', 'Z', 'J'}:
            return False
        # Characters outside the alphabet have no entry in acid_dict and would
        # raise a KeyError when the item is fetched.
        return residues <= set(self.acids)

    def __gen_acid_dict__(self, acids):
        '''
        Generates the encoding dictionaries for an alphabet.

        Parameters
        ----------
        acids : String
            An "alphabet" to generate the dictionary from.
            The last char will be used as a padding value

        Returns
        ----------
        acid_dict : Dictionary of Tensors
            Maps each character to its one-hot tensor of length len(acids).
        int_acid_dict : Dictionary Tensor -> Int
            Maps each one-hot tensor object to its index. torch.Tensor hashes by
            object identity, so only the exact tensor objects stored here can be
            looked up (not equal-valued copies).
        int_to_acid_dict : Dictionary Int -> Tensor
            Maps each index to its one-hot tensor.
        '''
        acid_dict = {}
        int_acid_dict = {}
        int_to_acid_dict = {}
        for i, elem in enumerate(acids):
            temp = torch.zeros(len(acids))
            temp[i] = 1
            acid_dict[elem] = temp
            int_acid_dict[temp] = i
            int_to_acid_dict[i] = temp
        return acid_dict, int_acid_dict, int_to_acid_dict

    def __init__(self, filename, max_seq_len, output_type="onehot", acids="ACDEFGHIKLMNPQRSTVWY-", get_prot_class=False):
        '''
        Loads every legal sequence of a fasta file into memory.

        Parameters
        ----------
        filename : String
            Path to a fasta file with data.
        max_seq_len : Int
            An integer representing the longest sequences we want to take into account.
        output_type : String
            "embed" returns integer indices from __getitem__, anything else one-hot rows.
        acids : String
            An "alphabet" to generate the dictionary from; last char is padding.
        get_prot_class : Bool
            Also parse a one-character protein class label from each record's
            description and return it from __getitem__.

        Raises
        ----------
        ValueError
            If get_prot_class is set and a legal record's description has no label.
        '''
        self.acids = acids
        self.get_prot_class = get_prot_class
        self.output_type = output_type
        self.acid_dict, self.int_acid_dict, self.int_to_acid_dict = self.__gen_acid_dict__(acids)
        self.max_seq_len = max_seq_len

        elem_list = []
        label_list = []
        # Captures the word character before ".<digits>" in the description,
        # presumably the SCOPe class letter - TODO confirm against the data files.
        prot_class_re = re.compile(r" (\w)\.\d+")
        # Loading the entire input file into memory
        for record in SeqIO.parse(filename, "fasta"):
            seq = record.seq.upper()
            if not self.__is_legal_seq__(seq):
                continue
            elem_list.append(seq)
            if get_prot_class:
                match = prot_class_re.search(record.description)
                if match is None:
                    raise ValueError(
                        "No protein class label found in FASTA description: {!r}".format(record.description))
                label_list.append(match.group(1))
        self.data = elem_list
        self.prot_labels = label_list

    def __len__(self):
        '''Returns the number of loaded sequences.'''
        return len(self.data)

    def __getitem__(self, index):
        '''
        Preprocesses a sequence into padded model input and labels.

        Parameters
        ----------
        index : Int
            Index to take the data from

        Returns
        ----------
        tensor_seq : Tensor of size max_seq_len x len(acid_dict)
            The padded one-hot encoded acids.
            If output_type="embed" it is a LongTensor of size max_seq_len.
        labels_seq : LongTensor of size max_seq_len
            Index of the NEXT symbol for each position (sequence shifted by one).
        valid_elems : Int
            Length of the sequence before padding
        prot_label : String
            Only returned when get_prot_class is set.
        '''
        seq = self.data[index]
        valid_elems = min(len(seq), self.max_seq_len)

        # Pad to max_seq_len + 1 so inputs (positions 0..n-1) and shifted
        # labels (positions 1..n) both end up with length max_seq_len.
        padded = str(seq).ljust(self.max_seq_len + 1, self.acids[-1])
        one_hot = torch.stack([self.acid_dict[c] for c in padded])

        if self.output_type == "embed":
            tensor_seq = torch.argmax(one_hot[:-1], dim=1).long()
        else:
            tensor_seq = one_hot[:-1].float()

        labels_seq = torch.argmax(one_hot[1:], dim=1).long()
        if self.get_prot_class:
            return tensor_seq, labels_seq, valid_elems, self.prot_labels[index]
        return tensor_seq, labels_seq, valid_elems


In [None]:
# Alphabet used by the models: 22 residue letters plus '-' as the padding symbol
acids = "ACDEFGHIKLMNOPQRSTUVWY-"

# On Colab the data lives in the mounted Google Drive folder; locally it is
# expected in the working directory.
data_dir = "/content/gdrive/My Drive/proteinData" if on_colab else ""

large_file = "uniref50.fasta"
small_file = os.path.join(data_dir, "100k_rows.fasta")

small_label_file1 = os.path.join(data_dir, "astral-scopedom-seqres-gd-sel-gs-bib-40-2.07.fasta")
big_label_file1 = os.path.join(data_dir, "astral-scopedom-seqres-gd-sel-gs-bib-95-2.07.fasta")

small_label_file2 = os.path.join(data_dir, "scope_data_40.fasta")
big_label_file2 = os.path.join(data_dir, "scope_data_95.fasta")

test_file = "test.fasta"

max_seq_len = 500

# Good sizes: 16/700 or 32/400 on laptop
# 32/1500 on desktop

# Use Cuda if available
use_cuda = torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

batch_size = 64
dataset = Dataset(small_file, max_seq_len, acids=acids)
base_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=16)

In [None]:
class CNN(nn.Module):
    '''
    1D convolutional autoencoder over integer-encoded sequences.

    Parameters
    ----------
    latent_space : Int
        Length the encoder output is adaptively pooled to. The flattened latent
        vector returned by forward has 4 * latent_space entries.
    vocab_size : Int
        Number of symbols: rows of the embedding and output channels. Default 23.
    seq_len : Int
        Length of the reconstructed output. Default 500.
    '''

    def __init__(self, latent_space, vocab_size=23, seq_len=500):
        super().__init__()
        self.latent_dim = latent_space
        self.vocab_size = vocab_size
        self.seq_len = seq_len

        self.embed = nn.Embedding(vocab_size, 30)

        # Encode layers
        self.conv1 = nn.Conv1d(30, 20, 5, padding=1)
        self.conv2 = nn.Conv1d(20, 14, 5, padding=1)
        self.conv3 = nn.Conv1d(14, 8, 5, padding=1)
        self.conv4 = nn.Conv1d(8, 4, 5, padding=1)

        # padding=2 keeps the latent length unchanged
        self.conv_mid = nn.Conv1d(4, 4, 5, padding=2)

        # Decode layers
        self.conv5 = nn.Conv1d(4, 8, 5, padding=1)
        self.conv6 = nn.Conv1d(8, 14, 5, padding=1)
        self.conv7 = nn.Conv1d(14, 20, 5, padding=1)
        self.conv8 = nn.Conv1d(20, vocab_size, 5, padding=2)

        self.Avg_pool = nn.AvgPool1d(2)
        self.Latent_avg_pool = nn.AdaptiveAvgPool1d(self.latent_dim)

        # Decoder starts from a fixed length of 62, doubles twice and is finally
        # resized to seq_len so the output aligns with the input positions.
        self.Up_sample_first = nn.Upsample(size=62)
        self.Up_sample_mid = nn.Upsample(scale_factor=2)
        self.Up_sample_last = nn.Upsample(size=seq_len)

    def initialize(self, input_data):
        '''Embeds integer indices (batch x length) to (batch x 30 x length).'''
        init_x = self.embed(input_data)
        return torch.transpose(init_x, 1, 2)

    def Encode(self, data):
        '''Maps embedded input to a (batch x 4 x latent_space) code.'''
        x = F.relu(self.conv1(data))
        x = self.Avg_pool(x)

        x = F.relu(self.conv2(x))
        x = self.Avg_pool(x)

        x = F.relu(self.conv3(x))
        x = self.Avg_pool(x)

        x = self.conv4(x)
        x = self.Latent_avg_pool(x)
        return self.conv_mid(x)

    def Decode(self, x):
        '''Maps a code to (batch x vocab_size x seq_len) unnormalized logits.'''
        x_con = self.Up_sample_first(x)
        x_con = F.relu(self.conv5(x_con))

        x_con = self.Up_sample_mid(x_con)
        x_con = F.relu(self.conv6(x_con))

        x_con = self.Up_sample_mid(x_con)
        x_con = F.relu(self.conv7(x_con))

        x_con = self.Up_sample_last(x_con)
        # No activation: CrossEntropyLoss expects raw logits. A ReLU here would
        # clamp every negative logit to 0 and stop its gradient.
        return self.conv8(x_con)

    def forward(self, data):
        '''Returns (logits, flattened latent code) for integer input data.'''
        init_data = self.initialize(data)
        x = self.Encode(init_data)
        x_con = self.Decode(x)
        return x_con, torch.flatten(x, start_dim=1)

    def save(self, filename):
        '''Saves weights plus constructor args so CNN(**args_dict) rebuilds the model.'''
        args_dict = {
            "latent_space": self.latent_dim,
            "vocab_size": self.vocab_size,
            "seq_len": self.seq_len,
        }
        torch.save({
            "state_dict": self.state_dict(),
            "args_dict": args_dict
        }, filename)

In [None]:
# Integer ("embed") encodings feed the CNN's nn.Embedding layer directly.
batch_size = 128
dataset = Dataset(
    small_file,
    max_seq_len,
    output_type="embed",
    acids=acids,
)
base_generator = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)

model = CNN(latent_space=50).to(processor)

In [None]:
# Autoencoder training: the model reconstructs its own unpadded input, so the
# targets are the input indices themselves (the loader's shifted labels are unused).
lr = 1e-6
loss_func = nn.CrossEntropyLoss(reduction="mean").to(processor)

optimizer = optim.Adam(model.parameters(), lr=lr)

loss_list = []
epochs = 100
time_diff = 0  # duration of the most recent batch (a timedelta once measured)

batches = float("inf")  # optional cap on batches per epoch, used only for the ETA
min_loss = float("inf")
no_improv = 0  # batches since the best loss last improved

batches_per_epoch = min(batches, len(base_generator))

model.train()
for epoch in range(epochs):
    for batch_index, (batch, labels, valid_elems) in enumerate(base_generator):
        start_time = datetime.datetime.now()

        batches_left = (batches_per_epoch - batch_index) + batches_per_epoch * (epochs - (epoch + 1))
        est_time_left = str(time_diff * batches_left).split(".")[0]

        sys.stdout.write("\rEpoch: {0}. Batch: {1}. Min loss: {2:.5f}. Estimated time left: {3}. Best: {4} batches ago.".format(epoch+1, batch_index+1, min_loss, est_time_left, no_improv))

        # Gradients accumulate by default; clear them so each step uses only this batch
        optimizer.zero_grad()

        batch = batch.to(processor)
        outputs, x = model(batch)

        # (batch, vocab, seq_len) -> (batch, seq_len, vocab) so packing drops the padding
        outputs = torch.transpose(outputs, 1, 2)
        outputs = rnn.pack_padded_sequence(outputs, valid_elems, enforce_sorted=False, batch_first=True)
        targets = rnn.pack_padded_sequence(batch, valid_elems, enforce_sorted=False, batch_first=True)

        loss = loss_func(outputs.data, targets.data)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        # Checkpoint whenever a single batch beats the best loss seen so far
        if loss_value < min_loss:
            no_improv = 0
            model.save("temp_best_cnn_model.pth")
            min_loss = loss_value
        else:
            no_improv += 1

        loss_list.append(loss_value)
        time_diff = datetime.datetime.now() - start_time


In [None]:
print(len(loss_list))
fig, ax = plt.subplots()
ax.plot(loss_list)
ax.set(title="Training loss per batch", xlabel="Batch", ylabel="Cross-entropy loss")
plt.show()

load_model = True

if load_model:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    loaded_params = torch.load("temp_best_cnn_model.pth", map_location=processor)
    # Rebuild with the constructor args stored in the checkpoint so the latent
    # size matches the one used during training.
    model = CNN(**loaded_params["args_dict"]).to(processor)
    model.load_state_dict(loaded_params["state_dict"])

batch_size1 = 1

dataset = Dataset(small_file, max_seq_len, output_type="embed", acids=acids)
base_generator = data.DataLoader(dataset, batch_size=batch_size1, shuffle=True, num_workers=16)

def print_seq(preds, valid, alphabet):
    '''
    Prints every sequence in `preds`, decoded through `alphabet`.

    Parameters
    ----------
    preds : Tensor of size batch x seq_len x len(alphabet)
        Scores or one-hot rows; the argmax over the last dim picks each symbol.
    valid : sequence (or 1-D tensor) of Int
        Unpadded length of each sequence; only that prefix is decoded.
    alphabet : String
        Symbol for each index.

    Returns
    ----------
    List of single-character strings for the FIRST sequence, or None if
    `preds` is empty.
    '''
    decoded = []
    for i, seq in enumerate(preds):
        print("Sequence {}".format(i))
        indexes = torch.argmax(seq[:valid[i]], dim=1).tolist()
        chars = [alphabet[x] for x in indexes]
        print("".join(chars))
        decoded.append(chars)
    return decoded[0] if decoded else None

# Reconstruct one randomly drawn sequence and compare it with the input.
model.eval()
with torch.no_grad():
    batch, labels, valid_elems = next(iter(base_generator))
    batch = batch.to(processor)
    pred, x = model(batch)

# (batch, vocab, seq_len) -> (batch, seq_len, vocab): print_seq takes the argmax over the last dim
pred = torch.transpose(pred, 1, 2).cpu()
# Re-encode the integer input as one-hot rows so it decodes exactly like the prediction
batch = torch.stack([
    torch.stack([dataset.int_to_acid_dict[int(elem)] for elem in seq])
    for seq in batch.cpu()
])

print("\nInput")
inp_str = print_seq(batch[:1], valid_elems, acids)

print("\nPrediction")
pred_str = print_seq(pred[:1], valid_elems, acids)

# Fraction of unpadded positions reconstructed correctly (first sequence only)
acc = np.mean([pred_elem == original_elem for pred_elem, original_elem in zip(inp_str, pred_str)])

print(acc)

In [None]:
def structure_to_int(structure_chars):
    '''
    Assigns consecutive integers to the distinct characters of `structure_chars`,
    in order of first appearance.

    Returns
    ----------
    (char_to_int, int_to_char) : the mapping and its inverse.
    '''
    char_to_int = {}
    for symbol in structure_chars:
        # setdefault keeps the first index a symbol received
        char_to_int.setdefault(symbol, len(char_to_int))
    int_to_char = {idx: symbol for symbol, idx in char_to_int.items()}
    return char_to_int, int_to_char

def plot_data(nr_plots, model_output, labels, colors, title, rev_structure_labels, exclude=('d',)):
    '''
    Scatter-plots 2-D points, one colour and legend entry per structure class.

    Parameters
    ----------
    nr_plots : Int
        Currently unused; kept for call compatibility.
    model_output : array of shape (n_points, >= 2)
        Columns 0 and 1 are used as x and y.
    labels : array-like of length n_points
        Integer class label per point (keys of rev_structure_labels).
    colors : list of matplotlib colour names
        Cycled through, one per class. An empty list uses matplotlib's defaults.
    title : String
    rev_structure_labels : dict Int -> String
        Maps integer labels to the class character shown in the legend.
    exclude : iterable of String
        Class characters left out of the plot. The labels are ints, so the
        comparison is made on the mapped character.
    '''
    labels = np.asarray(labels)
    fig, ax = plt.subplots(1, figsize=(12, 8))
    for i, unique in enumerate(np.unique(labels)):
        name = rev_structure_labels[unique]
        if name in exclude:
            continue
        points = model_output[labels == unique]
        color = colors[i % len(colors)] if colors else None
        ax.scatter(points[:, 0], points[:, 1], label=name, marker='.', color=color)
    ax.set_title(title)
    ax.legend()
    plt.show()
    



In [None]:
data_points = []
structure_labels = []
batch_size1 = 1

structure_chars = "dcabgfe"

char_to_int_dict, rev_structure_labels = structure_to_int(structure_chars)

# The CNN starts with nn.Embedding, which needs integer indices, so the labelled
# data must use output_type="embed". One-hot float input would be rejected.
dataset = Dataset(small_label_file2, max_seq_len, output_type="embed", acids=acids, get_prot_class=True)
base_generator = data.DataLoader(dataset, batch_size=batch_size1, shuffle=True, num_workers=16)

model.eval()
with torch.no_grad():
    for batch, labels, valid_elems, protein_labels in base_generator:
        # Class 'd' is left out of the visualisation
        keep = [i for i, label in enumerate(protein_labels) if label != 'd']
        if not keep:
            continue
        output, x_out = model(batch[keep].to(processor))
        # x_out is (n_kept, latent features); store one row per sequence
        data_points.extend(x_out.cpu().numpy())
        structure_labels.extend(char_to_int_dict[protein_labels[i]] for i in keep)



In [None]:
data_points = np.array(data_points)
structure_labels = np.array(structure_labels)

pca = PCA(n_components=2)
PCA_data = pca.fit_transform(data_points)

# A fixed random_state makes the t-SNE layout reproducible between runs
t_sne_data = TSNE(n_components=2, perplexity=15, learning_rate=200, random_state=0).fit_transform(data_points)


#z = low_dim_points[:,2]

In [None]:
colors = ['Grey', 'Purple', 'Blue', 'Green', 'Orange', 'Red',
          'Yellow', 'Black']

x = t_sne_data[:, 0]
y = t_sne_data[:, 1]

fig, ax = plt.subplots()
ax.scatter(x, y, marker='.', s=0.5)
ax.set_title("T-sne embedding, all classes")
plt.show()

# Describe the run from the objects actually used in training so the titles cannot
# drift from the configuration. A separate name avoids overwriting the optimizer object.
lossfun = type(loss_func).__name__
optimizer_name = type(optimizer).__name__

run_desc = "using batch_size: {}, lr: {}, epochs: {}, loss_func: {}, optimizer: {}".format(
    batch_size, lr, epochs, lossfun, optimizer_name)
t_sne_title = "T-sne, " + run_desc
PCA_title = "PCA, " + run_desc
print(t_sne_title)
print(PCA_title)

plot_data(10000, t_sne_data, structure_labels, colors, t_sne_title, rev_structure_labels)
plot_data(10000, PCA_data, structure_labels, colors, PCA_title, rev_structure_labels)