In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
import torch.optim as optim
import random

from mpl_toolkits import mplot3d

import scipy.misc
from scipy import ndimage

!pip install tqdm
import tqdm

!pip install Biopython
from Bio import SeqIO

import os
import sys
import datetime

import math

from torch.utils import data

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from PIL import Image



from google.colab import drive

In [None]:
class Dataset(data.Dataset):
    """In-memory FASTA dataset yielding one-hot sequences and next-symbol labels.

    Every record of `filename` that passes `__is_legal_seq__` is kept in
    `self.data`. Indexing returns the tuple produced by `__prepare_seq__`.

    Args:
        filename: path of a FASTA file readable by `SeqIO.parse`.
        max_seq_len: longest sequence accepted; also the padded length of
            every returned tensor.
        acids: alphabet string. Its last character is used as the padding
            symbol, and a symbol's position is its class index.
        int_version: stored on the instance; not used by this class itself.
    """

    # Symbols that are always rejected, even if they appear in the alphabet.
    __REJECTED_CODES = "XBZJ"

    # Checks whether a given sequence is legal
    def __is_legal_seq__(self, seq):
        """Return True if `seq` can be encoded by this dataset.

        A sequence is legal when it is at most `max_seq_len` long, contains
        none of X/B/Z/J, and consists only of symbols from the alphabet.
        The last condition keeps `__prepare_seq__` from failing with a
        KeyError on a symbol that has no one-hot vector.
        """
        text = str(seq)
        if len(text) > self.max_seq_len:
            return False
        if any(code in text for code in self.__REJECTED_CODES):
            return False
        return all(symbol in self.acid_dict for symbol in text)

    # Generates a dictionary given a string with all the elements
    def __gen_acid_dict__(self, acids):
        """Build the one-hot lookup tables for the alphabet `acids`.

        Returns:
            acid_dict: symbol -> 1-D one-hot tensor of length len(acids).
            int_acid_dict: one-hot tensor -> class index. The keys are the
                tensor objects stored in `acid_dict`; torch tensors hash by
                object identity, so only those exact objects can be looked up.
        """
        acid_dict = {}
        int_acid_dict = {}
        for i, elem in enumerate(acids):
            temp = torch.zeros(len(acids))
            temp[i] = 1
            acid_dict[elem] = temp
            int_acid_dict[temp] = i
        return acid_dict, int_acid_dict

    def __init__(self, filename, max_seq_len, acids="ACDEFGHIKLMNPQRSTVWY-", int_version=False):
        self.acids = acids
        self.acid_dict, self.int_acid_dict = self.__gen_acid_dict__(acids)
        self.max_seq_len = max_seq_len
        self.int_version = int_version
        # Loading the entire input file into memory (legal sequences only).
        # The lookup tables above must exist first: the filter reads them.
        self.data = [
            record.seq
            for record in SeqIO.parse(filename, "fasta")
            if self.__is_legal_seq__(record.seq)
        ]

    def __len__(self):
        return len(self.data)

    def __prepare_seq__(self, seq):
        """Encode one sequence.

        Returns:
            tensor_seq: float tensor (max_seq_len, len(acids)); the one-hot
                sequence right-padded with the padding symbol.
            labels_seq: long tensor (max_seq_len,); class index of the symbol
                that follows each input position (the sequence shifted by one).
            valid_elems: min(len(seq) + 1, max_seq_len).
        """
        valid_elems = min(len(seq) + 1, self.max_seq_len)

        # Pad to max_seq_len + 1 so that input and shifted labels both have
        # exactly max_seq_len positions.
        padded = str(seq).ljust(self.max_seq_len + 1, self.acids[-1])
        one_hot = [self.acid_dict[symbol] for symbol in padded]
        tensor_seq = torch.stack(one_hot[:-1], dim=0).float()

        # Labels consisting of the index of correct class
        labels_seq = torch.argmax(torch.stack(one_hot[1:]), dim=1).long()

        return tensor_seq, labels_seq, valid_elems

    def __getitem__(self, index):
        return self.__prepare_seq__(self.data[index])


In [None]:
# ---- Run configuration -------------------------------------------------------
# Alphabet: 22 residue letters followed by '-', which Dataset uses as the
# padding symbol (it pads with the last character of the alphabet).
acids = "ACDEFGHIKLMNOPQRSTUVWY-"
large_file = "uniref50.fasta"
small_file = "/content/gdrive/My Drive/proteinData/100k_rows.fasta"
small_label_file = "/content/gdrive/My Drive/proteinData/astral-scopedom-seqres-gd-sel-gs-bib-40-2.07.fasta"
big_label_file = "/content/gdrive/My Drive/proteinData/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07.fasta"
test_file = "test.fasta"

max_seq_len = 500

# Good sizes: 16/700 or 32/400 on laptop
# 32/1500 on desktop
batch_size = 32

# The data lives on Google Drive. Mount it when the file is not reachable yet,
# so the notebook also works after "Restart runtime -> Run all".
if not os.path.exists(small_file):
    drive.mount("/content/gdrive")

# Use Cuda if available (set allow_gpu to False to force the CPU)
allow_gpu = True
use_cuda = allow_gpu and torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

dataset = Dataset(small_file, max_seq_len, acids=acids, int_version=False)
# Never ask for more loader processes than the machine has cores.
num_workers = min(16, os.cpu_count() or 1)
base_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
class CNN(nn.Module):
    """1-D convolutional autoencoder for one-hot encoded sequences.

    The encoder narrows the channels in_channels -> 15 -> 8 -> 4 -> 1, halving
    the length with average pooling after each of the first three convolutions
    and finishing with an adaptive average pool to `latent_space` positions.
    The decoder widens the channels back and restores the length with
    nearest-neighbour upsampling.

    Args:
        latent_space: number of positions in the single-channel latent code.
        in_channels: alphabet size, i.e. channels of the input and of the
            reconstruction. Defaults to 23.
        seq_len: length of the reconstructed sequence. Defaults to 500.
    """

    def __init__(self, latent_space, in_channels=23, seq_len=500):
        super(CNN, self).__init__()
        self.latent_dim = latent_space
        self.in_channels = in_channels
        self.seq_len = seq_len

        # Encode layers
        self.conv1 = nn.Conv1d(in_channels, 15, 3, padding=1)
        self.conv2 = nn.Conv1d(15, 8, 3, padding=1)
        self.conv3 = nn.Conv1d(8, 4, 3, padding=1)
        self.conv4 = nn.Conv1d(4, 1, 3, padding=1)

        # Decode layers
        self.conv5 = nn.Conv1d(1, 4, 3, padding=1)
        self.conv6 = nn.Conv1d(4, 8, 3, padding=1)
        self.conv7 = nn.Conv1d(8, 15, 3, padding=1)
        self.conv8 = nn.Conv1d(15, in_channels, 3, padding=1)

        self.Avg_pool = torch.nn.AvgPool1d(2)
        self.Latent_avg_pool = nn.AdaptiveAvgPool1d(self.latent_dim)

        self.Up_sample_first = nn.Upsample(size=16, mode='nearest')
        self.Up_sample_mid = nn.Upsample(scale_factor=2, mode='nearest')
        self.Up_sample_last = nn.Upsample(size=seq_len, mode='nearest')

        # Not used by Encode/Decode; parameter-free, kept as attributes only.
        self.Max_pool = torch.nn.MaxPool1d(2, return_indices=True)
        self.UnPool = nn.MaxUnpool1d(2, stride=2)

    def Encode(self, data):
        """Compress `data` (batch, in_channels, length) to (batch, 1, latent_dim)."""
        x = F.relu(self.conv1(data))
        x = self.Avg_pool(x)

        x = F.relu(self.conv2(x))
        x = self.Avg_pool(x)

        x = F.relu(self.conv3(x))
        x = self.Avg_pool(x)

        x = F.relu(self.conv4(x))
        x = self.Latent_avg_pool(x)

        return x

    def Decode(self, x):
        """Expand a latent code (batch, 1, latent_dim) to (batch, in_channels, seq_len).

        The result is not passed through an activation: the values are raw
        per-class scores for every position.
        """
        x_con = F.relu(self.conv5(x))
        x_con = self.Up_sample_first(x_con)  # latent_dim -> 16
        x_con = self.Up_sample_mid(x_con)    # 16 -> 32

        x_con = F.relu(self.conv6(x_con))
        x_con = self.Up_sample_mid(x_con)    # 32 -> 64

        x_con = F.relu(self.conv7(x_con))
        x_con = self.Up_sample_mid(x_con)    # 64 -> 128

        x_con = self.conv8(x_con)
        x_con = self.Up_sample_last(x_con)   # 128 -> seq_len

        return x_con

    def forward(self, data):
        """Return `(reconstruction, latent_code)` for a batch of inputs."""
        x = self.Encode(data)
        x_con = self.Decode(x)
        return x_con, x

# Smoke test on the CPU: one forward pass through an untrained model, checking
# that the reconstruction has the input's shape and the latent code has the
# requested size.
test = CNN(3)
lel = torch.zeros(32, 23, 500).float()
with torch.no_grad():
    juhu = test(lel)
assert juhu[0].shape == lel.shape, juhu[0].shape
assert juhu[1].shape == (32, 1, 3), juhu[1].shape


In [None]:
# Train the network as an autoencoder: it must reproduce its own one-hot input,
# so the target class at each position is the argmax of the input itself. The
# next-symbol labels delivered by the dataset are not used here.
test = CNN(2).to(processor)

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(test.parameters(), lr=0.00005)
loss_list = []  # one entry per optimisation step
epochs = 10
time_diff = 0  # duration of the previous epoch; a timedelta after the first one

for epoch in range(epochs):
    start_time = datetime.datetime.now()

    sys.stdout.write("\rCurrently at epoch: " + str(epoch+1) + ". Estimated time remaining: {}\n".format(time_diff*(epochs - epoch)))
    for batch_index, (batch, labels, valid_elems) in enumerate(base_generator):
        optimizer.zero_grad()

        batch = batch.to(processor)
        # (batch, length, classes) -> class index per position
        labels = torch.argmax(batch, dim=2)

        # Conv1d wants channels first: (batch, classes, length)
        batch = torch.transpose(batch, 1, 2)

        outputs, x = test(batch)

        loss = loss_func(outputs, labels)

        loss.backward()
        optimizer.step()

        if batch_index % 500 == 0:
            print("epoch: {}. batch_index: {}. Loss = {}".format(epoch, batch_index, loss.item()))

        loss_list.append(loss.item())
    end_time = datetime.datetime.now()
    time_diff = end_time - start_time

In [None]:
# Training curve: one point per optimisation step.
print(len(loss_list))
fig, ax = plt.subplots()
ax.plot(loss_list)
ax.set(xlabel="training step", ylabel="cross-entropy loss", title="Training loss")
plt.show()

def print_seq(preds, valid, alphabet, return_all=False):
    """Decode and print a batch of per-position class scores.

    Args:
        preds: tensor (n_sequences, length, n_classes) of scores or one-hot rows.
        valid: per-sequence count of leading positions to decode (list or
            1-D tensor). Sequences without a matching entry are skipped.
        alphabet: string mapping a class index to its symbol.
        return_all: when True, return the decoding of every sequence.

    Returns:
        By default the decoded symbols (list of 1-character strings) of the
        first sequence, or None when nothing was decoded. With
        `return_all=True`, a list with one such list per sequence.
    """
    decoded = []
    for i, (seq, n_valid) in enumerate(zip(preds, valid)):
        print("Sequence {}".format(i))
        indexes = torch.argmax(seq[:int(n_valid)], dim=1).tolist()
        symbols = [alphabet[k] for k in indexes]
        print("".join(symbols))
        decoded.append(symbols)

    if return_all:
        return decoded
    return decoded[0] if decoded else None

# Qualitative check on one example: run the first usable sequence of the file
# through the trained model and compare the decoded output with the input.

def _encode_padded(text):
    """One-hot encode `text`, right-padded with '-' to max_seq_len, with a leading batch axis."""
    padded = text.ljust(max_seq_len, '-')
    return torch.stack([dataset.acid_dict[symbol] for symbol in padded]).float().unsqueeze(0)


# Same length / ambiguity-code filter as used when building the dataset.
for seq in SeqIO.parse(small_file, "fasta"):
    residues = str(seq.seq)
    if len(residues) > max_seq_len or any(code in residues for code in "XBZJ"):
        continue
    break
else:
    raise ValueError("No usable sequence found in {}".format(small_file))

print(residues[:-1].ljust(max_seq_len, '-'))

original = _encode_padded(residues)           # full sequence
seq_tensor = _encode_padded(residues[:-1])    # model input: last residue dropped
label_tensor = _encode_padded(residues[1:])   # sequence shifted by one

# Run on whatever device the model lives on; no gradients are needed here.
model_device = next(test.parameters()).device
with torch.no_grad():
    pred, x = test(torch.transpose(seq_tensor, 1, 2).to(model_device))
pred = torch.transpose(pred, 1, 2)  # back to (1, length, classes)

# Number of real (unpadded) positions of this sequence, rather than a value
# left over from the last training batch.
full_len = [len(residues)]
shifted_len = [max(len(residues) - 1, 0)]

print("\nOriginal")
ori_str = print_seq(original, full_len, acids)

print("\nInput")
inp_str = print_seq(seq_tensor, shifted_len, acids)

print("\nPrediction")
pred_str = print_seq(pred, shifted_len, acids)

print("\nLabel")
label_str = print_seq(label_tensor, shifted_len, acids)

# Reconstruction accuracy: share of input positions the model reproduced.
matches = sum(1 for inp_elem, pred_elem in zip(inp_str, pred_str) if inp_elem == pred_elem)
acc = matches / len(pred_str) if pred_str else 0.0

print(acc)

In [None]:
# Encode every usable sequence of the file and collect its latent code.
data_points = []
model_device = next(test.parameters()).device

# Inference only: skip building an autograd graph for each sequence.
with torch.no_grad():
    for seq in SeqIO.parse(small_file, "fasta"):
        residues = str(seq.seq)
        if len(residues) > max_seq_len or any(code in residues for code in "XBZJ"):
            continue
        padded = residues.ljust(max_seq_len, '-')
        seq_tensor = torch.stack([dataset.acid_dict[symbol] for symbol in padded]).float()
        # (length, classes) -> (1, classes, length), as expected by Conv1d
        seq_tensor = seq_tensor.unsqueeze(0).transpose(1, 2).to(model_device)
        _, low_dim = test(seq_tensor)
        # (1, 1, latent) -> (latent,)
        low_dim = low_dim.squeeze(0).squeeze(0)
        data_points.append(low_dim.cpu().numpy())


In [None]:
# Scatter plot of the first two latent coordinates (the model trained above
# has a latent size of 2). asarray keeps this cell safe to re-run.
data_points = np.asarray(data_points)
x = data_points[:, 0]
y = data_points[:, 1]

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set(xlabel="latent dimension 0", ylabel="latent dimension 1", title="Latent codes")
plt.show()