In [1]:
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from statistics import mean
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
import itertools
import joblib
from matplotlib import pyplot as plot

In [2]:
import gc


torch.cuda.empty_cache()
gc.collect()

16

In [3]:
#loading file
amino_acid_df = pd.read_excel("data/AminoAcid.xlsx", header=None)
amino_acid_df.columns = ['protein', 'sequence']

  warn("Workbook contains no default style, apply openpyxl's default")


In [4]:
gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu

device(type='cuda')

In [5]:
amino_acid_df.head()

Unnamed: 0,protein,sequence
0,1EP9_1,VQLKGRDLLTLKNFTGEEIKYMLWLSADLKFRIKQKGEYLPLLQGK...
1,1BH9_1,LFSKELRCMMYGFGDDQNPYTESVDILEDLVIEFITEMTHKAMSI
2,1G96_1,VGGPMDASVEEEGVRRALDFAVGEYNKASNDMYHSRALQVVRARKQ...
3,1CYV_1,MIPGGLSEAKPATPEIQEIVDKVKPQLEEKTNETYGKLEAVQYKTQ...
4,1KE5_1,MENFQKVEKIGEGTYGVVYKARNKLTGEVVALKKIR


In [6]:
#spliting the sequences in letters
amino_acid_df['sequence'] = amino_acid_df['sequence'].apply(lambda seq: list(seq))

In [7]:
amino_acid_df.head()

Unnamed: 0,protein,sequence
0,1EP9_1,"[V, Q, L, K, G, R, D, L, L, T, L, K, N, F, T, ..."
1,1BH9_1,"[L, F, S, K, E, L, R, C, M, M, Y, G, F, G, D, ..."
2,1G96_1,"[V, G, G, P, M, D, A, S, V, E, E, E, G, V, R, ..."
3,1CYV_1,"[M, I, P, G, G, L, S, E, A, K, P, A, T, P, E, ..."
4,1KE5_1,"[M, E, N, F, Q, K, V, E, K, I, G, E, G, T, Y, ..."


One-hot encoding of the amino-acid sequence.

In [8]:
unique_amino_acid = list(set([item for sublist in amino_acid_df['sequence'] for item in sublist]))
unique_amino_acid.sort()

In [9]:
max_lenght_amino_acid = len(max(amino_acid_df['sequence'], key=len))

In [10]:
# Function to one-hot encode a sequence of amino acid.
# The output is a matrix of max_lenght_amino_acid (1965) x unique_amino_acid (21). Sequence shorter than max_lenght_amino_acid are filled with 0. 
def one_hot_encode(seq):
    matrix = np.zeros((max_lenght_amino_acid, len(unique_amino_acid)))
    for idx, elem in enumerate(seq):
        matrix[idx][unique_amino_acid.index(elem)] = 1
    return matrix.astype(np.float32)


In [11]:
amino_acid_df['seq_one_hot'] = amino_acid_df['sequence'].apply(lambda seq: one_hot_encode(seq))


Autoencoder

In [12]:
class AE_lin(nn.Module):
    def __init__(self, input_shape, layers):
        super().__init__()
        self.layers = [input_shape, *layers]
        self.input_shape = input_shape
        # Encoder
        encoder_modules = nn.ModuleList()
        for idx in range(len(self.layers)-1):
            # print('encoder: ',self.layers[idx], self.layers[idx+1])

            encoder_modules.append(
                nn.Sequential(
                    nn.Linear(self.layers[idx], self.layers[idx+1]),
                    nn.ReLU(),
                )
            )
        self.encoder = nn.Sequential(*encoder_modules)

        # Decoder
        decoder_modules = nn.ModuleList()
        for idx in range(len(self.layers)-1):
            # print('decoder: ',self.layers[-idx - 1], self.layers[-idx-2])
            decoder_modules.append(
                nn.Sequential(
                    nn.Linear(self.layers[-idx-1], self.layers[-idx-2]),
                    nn.ReLU(),
                )
            )
        self.decoder = nn.Sequential(*decoder_modules)


    def forward(self, input):
        # input = input.reshape(1,-1)
        input = self.encoder(input)
        input = self.decoder(input)
        return input


In [13]:
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid acitvation
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

In [14]:
def train(autoencoder:AE_lin, data:pd.DataFrame, optimizer:optim, epochs:int = 10, batch_size:int = 64, loss_fn =DiceLoss()):
    losses = []
    dataloader = DataLoader(data,batch_size=batch_size, shuffle=True)
    for _ in range(epochs):
        epoch_loss = []
        for _, batch in tqdm(enumerate(dataloader), total=np.ceil(len(data)/batch_size)):
            autoencoder.zero_grad()
            batch = batch.to(gpu)
            output = autoencoder(batch)
            loss = loss_fn(output, batch).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())
            # print(epoch_loss)
        losses.append(np.average(epoch_loss))
        print(losses)
    return losses
batch_size = 64
    
        


In [20]:
batch_size=64


torch.cuda.empty_cache()
gc.collect()
# [2000,200], [2000, 500, 200],
for layers, (opt,name), (loss,loss_name) in itertools.product([[5000, 500], [10000, 1000]],
                                                       [(optim.Adam, 'Adam'), (optim.Adagrad, 'Adagrad')],
                                                       [(nn.MSELoss, "MSE"), (DiceLoss, 'DiceLoss'), (nn.L1Loss,'L1')]):
    print(layers, name, loss_name)
    ae = AE_lin(len(unique_amino_acid), layers).to(gpu)
    optimizer = opt(ae.parameters(), lr=0.01)
    loss_fn = loss()
    test = train(ae, amino_acid_df['seq_one_hot'], optimizer,epochs=15, batch_size=batch_size, loss_fn=loss_fn)
    joblib.dump({'autoencoder': ae,
                 "optimizer": optimizer,
                 "loss": test,
                 "layers":layers,
                 "loss_fn": loss
                 }, "models/{}_{}_{}.joblib".format(layers,name,loss_name))

[5000, 500] Adam MSE


100%|██████████| 680/680.0 [09:08<00:00,  1.24it/s]


[0.08587559185381156]


100%|██████████| 680/680.0 [09:10<00:00,  1.23it/s]


[0.08587559185381156, 0.004306091548984542]


100%|██████████| 680/680.0 [09:11<00:00,  1.23it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094]


100%|██████████| 680/680.0 [09:11<00:00,  1.23it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488]


100%|██████████| 680/680.0 [09:12<00:00,  1.23it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577]


100%|██████████| 680/680.0 [09:12<00:00,  1.23it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946]


100%|██████████| 680/680.0 [09:11<00:00,  1.23it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478]


100%|██████████| 680/680.0 [08:46<00:00,  1.29it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074]


100%|██████████| 680/680.0 [09:06<00:00,  1.24it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074, 0.004305469439200619]


100%|██████████| 680/680.0 [09:06<00:00,  1.24it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074, 0.004305469439200619, 0.0043055267357404396]


100%|██████████| 680/680.0 [09:05<00:00,  1.25it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074, 0.004305469439200619, 0.0043055267357404396, 0.0043058071083471395]


100%|██████████| 680/680.0 [09:08<00:00,  1.24it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074, 0.004305469439200619, 0.0043055267357404396, 0.0043058071083471395, 0.004305596630130072]


100%|██████████| 680/680.0 [09:08<00:00,  1.24it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074, 0.004305469439200619, 0.0043055267357404396, 0.0043058071083471395, 0.004305596630130072, 0.004305633597130723]


100%|██████████| 680/680.0 [09:07<00:00,  1.24it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074, 0.004305469439200619, 0.0043055267357404396, 0.0043058071083471395, 0.004305596630130072, 0.004305633597130723, 0.004305892044281149]


100%|██████████| 680/680.0 [08:40<00:00,  1.31it/s]


[0.08587559185381156, 0.004306091548984542, 0.004305369485690094, 0.004306152497541488, 0.004305488946984577, 0.004305572241804946, 0.00430572909912478, 0.0043057851757005074, 0.004305469439200619, 0.0043055267357404396, 0.0043058071083471395, 0.004305596630130072, 0.004305633597130723, 0.004305892044281149, 0.004305332906283986]
[5000, 500] Adam DiceLoss


 48%|████▊     | 323/680.0 [04:35<05:04,  1.17it/s]


KeyboardInterrupt: 