In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
import torch.optim as optim
import random

import scipy.misc
from scipy import ndimage

!pip install Biopython
from Bio import SeqIO

import os
import sys
import datetime

import math

from torch.utils import data

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from PIL import Image

from google.colab import drive

In [None]:
class Dataset(data.Dataset):
    """In-memory FASTA dataset producing next-residue prediction pairs.

    Each accepted sequence is right-padded with '-' to ``max_seq_len + 1``
    symbols. The first ``max_seq_len`` symbols form the input and the last
    ``max_seq_len`` symbols (the input shifted by one position) form the
    labels. The padding symbol '-' must therefore be part of ``acids``.
    """

    # Checks whether a given sequence is legal
    def __is_legal_seq__(self, seq):
        """Return True when ``seq`` can be encoded by this dataset.

        A sequence is rejected when it is longer than ``self.max_seq_len``,
        when it contains one of the codes X, B, Z or J, or when it contains
        any symbol that has no entry in ``self.acid_dict`` (which would
        otherwise raise ``KeyError`` when the item is fetched).
        """
        seq = str(seq)
        if len(seq) > self.max_seq_len:
            return False
        if any(code in seq for code in "XBZJ"):
            return False
        return all(residue in self.acid_dict for residue in seq)

    # Generates a dictionary given a string with all the elements
    def __gen_acid_dict__(self, acids):
        """Build the encoding tables for the alphabet ``acids``.

        Returns:
            (acid_dict, int_acid_dict) where ``acid_dict`` maps each symbol to
            a one-hot tensor of length ``len(acids)`` and ``int_acid_dict``
            maps that same tensor object to the symbol's index.
        """
        acid_dict = {}
        int_acid_dict = {}
        for i, elem in enumerate(acids):
            one_hot = torch.zeros(len(acids))
            one_hot[i] = 1
            acid_dict[elem] = one_hot
            # Tensors hash by identity, so this table can only be queried with
            # the exact tensor objects stored in acid_dict.
            int_acid_dict[one_hot] = i
        return acid_dict, int_acid_dict

    def __init__(self, filename, max_seq_len, acids="ACDEFGHIKLMNPQRSTVWY-", int_version=False):
        """Load every legal sequence of a FASTA file into memory.

        Args:
            filename: path of the FASTA file, parsed with ``SeqIO.parse``.
            max_seq_len: longest sequence accepted; also the length of every
                returned input/label pair.
            acids: alphabet used for the encoding (must contain '-').
            int_version: when True, items are encoded as one scalar per
                position (symbol index divided by the alphabet size) instead
                of one-hot vectors.
        """
        self.acids = acids
        # The encoding tables must exist before filtering, because the
        # legality check consults self.acid_dict.
        self.acid_dict, self.int_acid_dict = self.__gen_acid_dict__(acids)
        self.max_seq_len = max_seq_len
        self.int_version = int_version
        # Loading the entire input file into memory
        self.data = [
            record.seq
            for record in SeqIO.parse(filename, "fasta")
            if self.__is_legal_seq__(record.seq)
        ]

    def __len__(self):
        """Number of sequences that passed the legality filter."""
        return len(self.data)

    def __prepare_seq__(self, seq):
        """Encode one sequence as an (input, labels, valid_elems) triple.

        Returns:
            tensor_seq: encoding of the first ``max_seq_len`` padded symbols.
            labels_seq: encoding of the padded symbols shifted by one.
            valid_elems: ``min(len(seq) + 1, max_seq_len)``.

            One-hot mode gives tensors of shape ``(max_seq_len, len(acids))``;
            int mode gives tensors of shape ``(max_seq_len,)``.
        """
        valid_elems = min(len(seq) + 1, self.max_seq_len)
        # One extra symbol so that input and shifted labels have equal length.
        padded = str(seq).ljust(self.max_seq_len + 1, '-')
        one_hots = [self.acid_dict[x] for x in padded]

        if self.int_version:
            # Scale by the alphabet this instance was built with, not by a
            # module-level global that may be missing or different.
            alphabet_size = len(self.acids)
            scaled = [self.int_acid_dict[vec] / alphabet_size for vec in one_hots]
            tensor_seq = torch.tensor(scaled[:-1]).float()
            labels_seq = torch.tensor(scaled[1:])
        else:
            tensor_seq = torch.stack(one_hots[:-1]).float()
            labels_seq = torch.stack(one_hots[1:]).float()

        return tensor_seq, labels_seq, valid_elems

    def __getitem__(self, index):
        """Return the encoded triple for the sequence at ``index``."""
        return self.__prepare_seq__(self.data[index])

    

In [None]:
# ---- Data configuration ----
# Alphabet used for the one-hot encoding; '-' doubles as the padding symbol.
acids = "ACDEFGHIKLMNOPQRSTUVWY-"
large_file = "uniref50.fasta"
# NOTE(review): this path presumes Google Drive is mounted at /content/gdrive;
# no drive.mount(...) call is visible in this notebook - confirm.
small_file = "/content/gdrive/My Drive/proteinData/100k_rows.fasta"
test_file = "test.fasta"

max_seq_len = 2000

# Good sizes: 16/700 or 32/400 on laptop
# 32/1500 on desktop
batch_size = 32

# ---- Device selection ----
# Use Cuda if available; flip allow_gpu to False to force CPU execution.
allow_gpu = True
use_cuda = allow_gpu and torch.cuda.is_available()
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# ---- Data loading ----
# Never request more loader processes than the host has cores: PyTorch warns
# that oversubscribing workers can make the DataLoader slow or even freeze.
num_workers = min(16, os.cpu_count() or 1)

dataset = Dataset(small_file, max_seq_len, acids=acids)
base_generator = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
class CNN(nn.Module):
    """Small U-Net style encoder/decoder for 2-D maps.

    Two 2x2 max-pool stages shrink the input, two bilinear upsampling stages
    restore it, and the encoder activations of matching resolution are
    concatenated back in as skip connections. A 1x1 convolution followed by a
    sigmoid produces one output channel with values in (0, 1).

    Input:  ``(N, channel_size, H, W)`` with H and W of at least 4.
    Output: ``(N, 1, H, W)``.
    """

    @staticmethod
    def _same_padding(kernelSize):
        """Padding that keeps the spatial size for odd kernels at stride 1."""
        if isinstance(kernelSize, (tuple, list)):
            return tuple(k // 2 for k in kernelSize)
        return kernelSize // 2

    def conv(self, in_channels, out_channels, kernelSize):
        """Two size-preserving Conv2d + ReLU layers.

        The padding is derived from ``kernelSize`` (it was previously fixed
        to 1, which only preserved the size for 3x3 kernels). Odd kernel
        sizes are required for the output size to equal the input size.
        """
        padding = self._same_padding(kernelSize)
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernelSize, padding=padding),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernelSize, padding=padding),
            nn.ReLU())

    def __init__(self, channel_size, kernel_size=3, verbose=False):
        """Create the network.

        Args:
            channel_size: number of channels of the input tensor.
            kernel_size: (odd) kernel size of every 2-layer conv stage.
            verbose: when True, ``forward`` prints the size of each
                intermediate activation (debugging aid, off by default).
        """
        super(CNN, self).__init__()
        self.in_channels = channel_size
        self.kernel_size = kernel_size
        self.verbose = verbose

        self.pool0 = torch.nn.MaxPool2d(kernel_size=2)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)

        # Not used by forward(); kept so the module's attributes are unchanged.
        self.dropout_02 = nn.Dropout2d(p=0.2)
        self.dropout_04 = nn.Dropout2d(p=0.4)

        # Encoder stages.
        self.conv0 = self.conv(channel_size, 16, kernel_size)
        self.conv1 = self.conv(16, 32, kernel_size)
        self.conv2 = self.conv(32, 64, kernel_size)

        # Decoder stages applied after each skip concatenation.
        self.conv3 = self.conv(64, 32, kernel_size)
        self.conv4 = self.conv(32, 16, kernel_size)

        self.final_conv = nn.Conv2d(16, 1, 1, padding=0)

        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Channel reduction applied right after each upsampling step.
        self.up_conv0 = self.conv(64, 32, kernel_size)
        self.up_conv1 = self.conv(32, 16, kernel_size)

    def _log_size(self, label, tensor):
        """Print ``label`` and the tensor size when verbose mode is on."""
        if self.verbose:
            print(label, tensor.size())

    def _upsample_like(self, tensor, skip):
        """Upsample ``tensor`` to the spatial size of ``skip``.

        When doubling the size already matches ``skip`` the fixed 2x
        upsampler is used. Otherwise (H or W was odd before pooling) the
        tensor is interpolated directly to the skip size so that the
        following concatenation does not fail.
        """
        target = tuple(skip.shape[-2:])
        doubled = (2 * tensor.shape[-2], 2 * tensor.shape[-1])
        if doubled == target:
            return self.up_sample(tensor)
        return F.interpolate(tensor, size=target, mode='bilinear', align_corners=True)

    def forward(self, data):
        """Run the encoder/decoder and return a sigmoid map of shape (N, 1, H, W)."""
        # first layer
        layer0 = self.conv0(data)
        self._log_size("layer0: ", layer0)

        layer1 = self.conv1(self.pool0(layer0))
        self._log_size("layer1: ", layer1)

        # Second and last layer
        layer2 = self.conv2(self.pool1(layer1))
        self._log_size("layer2: ", layer2)

        # first uplayer: upsample, halve channels, merge skip, convolve
        up_layer2 = self.up_conv0(self._upsample_like(layer2, layer1))
        self._log_size("up_layer2: ", up_layer2)
        up_layer2 = torch.cat([up_layer2, layer1], dim=1)
        self._log_size("up_layer2_cat: ", up_layer2)
        up_layer2 = self.conv3(up_layer2)
        self._log_size("up_layer2_conv: ", up_layer2)

        # second and last uplayer
        up_layer1 = self.up_conv1(self._upsample_like(up_layer2, layer0))
        up_layer1 = torch.cat([up_layer1, layer0], dim=1)
        up_layer1 = self.conv4(up_layer1)

        final_conv = self.final_conv(up_layer1)

        return torch.sigmoid(final_conv)

In [None]:
# Build the model, loss and optimizer on the device selected in the
# configuration cell, so the notebook also runs on CPU-only hosts.
test = CNN(1).to(processor)

loss_func = nn.MSELoss().to(processor)
optimizer = optim.Adamax(test.parameters(), lr=0.1)
loss_list = []  # one entry per optimisation step (batch), across all epochs
epochs = 50
time_diff = 0  # duration of the previous epoch; 0 until one has completed

for epoch in range(epochs):
    start_time = datetime.datetime.now()

    # Remaining time is estimated as (last epoch duration) x (epochs still to run).
    sys.stdout.write("\rCurrently at epoch: " + str(epoch+1) + ". Estimated time remaining: {}".format(time_diff*(epochs - epoch)))
    sys.stdout.flush()
    # NOTE(review): valid_elems is not used below, so the loss also covers the
    # padded positions of every sequence - confirm this is intended.
    for batch_index, (batch, labels, valid_elems) in enumerate(base_generator):
        optimizer.zero_grad()

        # Insert the channel axis. -1 lets the final, possibly smaller batch
        # through instead of failing on a fixed batch_size.
        batch = batch.view(-1, 1, max_seq_len, len(acids)).to(processor)
        labels = labels.view(-1, 1, max_seq_len, len(acids)).to(processor)
        # Left-pad the last axis by one column so its length (23 -> 24) is a
        # multiple of 4 and survives the two 2x pool / 2x upsample round trips.
        batch = F.pad(batch, (1,0,0,0))
        labels = F.pad(labels, (1,0,0,0))

        outputs = test(batch)

        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_list.append(loss.item())
    end_time = datetime.datetime.now()
    time_diff = end_time - start_time