In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

import numpy as np
import matplotlib.pyplot as plt
import sys

from Bio import SeqIO
from datetime import datetime

from torch.utils import data
#from data_generator import data_generator
from data_generator import Dataset
from lstm import LSTM_model
from lstm import LSTMCell
from time import sleep

import gc

In [None]:

# Alphabet string passed to Dataset via its `acids` argument
# (the trailing "-" is presumably a gap/padding symbol -- TODO confirm).
acids = "ACDEFGHIKLMNOPQRSTUVWY-"

# FASTA input files.
large_file, small_file, test_file = "uniref50.fasta", "100k_rows.fasta", "test.fasta"

# Sequence-length bound handed to Dataset, and mini-batch size for the DataLoader.
max_seq_len, batch_size = 2000, 32

# Run on the first CUDA device when one is available; change the trailing
# `True` to `False` to force CPU execution.
use_cuda = torch.cuda.is_available() and True
print("Using GPU:", use_cuda)
processor = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Let cuDNN benchmark and pick its convolution algorithms.
torch.backends.cudnn.benchmark = True



In [None]:
class AE(nn.Module):
    """Fully connected autoencoder with a five-layer encoder and a mirrored decoder.

    Encoder widths: ``input_size -> 256 -> 128 -> 64 -> 32 -> latent_dim``.
    Decoder widths: ``latent_dim -> 32 -> 64 -> 128 -> 256 -> output_size``.
    Hidden layers use ReLU, the latent code is a plain linear output, and the
    reconstruction is passed through a sigmoid.

    Args:
        input_size: number of features in each input vector.
        output_size: number of features in each reconstructed vector.
        latent_dim: size of the bottleneck (latent) representation.
    """

    def __init__(self, input_size, output_size, latent_dim):
        super().__init__()

        # Encoder: progressively narrower linear layers ending in the
        # latent code, which has exactly `latent_dim` units.
        self.fc_enc1 = nn.Linear(input_size, 256)
        self.fc_enc2 = nn.Linear(256, 128)
        self.fc_enc3 = nn.Linear(128, 64)
        self.fc_enc4 = nn.Linear(64, 32)
        self.fc_enc5 = nn.Linear(32, latent_dim)

        # Decoder: mirror image of the encoder.
        self.fc_dec1 = nn.Linear(latent_dim, 32)
        self.fc_dec2 = nn.Linear(32, 64)
        self.fc_dec3 = nn.Linear(64, 128)
        self.fc_dec4 = nn.Linear(128, 256)
        self.fc_dec5 = nn.Linear(256, output_size)

    def encode(self, x):
        """Map `x` to its latent code (no activation on the final layer)."""
        hidden = x
        for layer in (self.fc_enc1, self.fc_enc2, self.fc_enc3, self.fc_enc4):
            hidden = F.relu(layer(hidden))
        return self.fc_enc5(hidden)

    def decode(self, z):
        """Map a latent code `z` to a reconstruction with values in [0, 1]."""
        hidden = z
        for layer in (self.fc_dec1, self.fc_dec2, self.fc_dec3, self.fc_dec4):
            hidden = F.relu(layer(hidden))
        # Sigmoid keeps every reconstructed value inside [0, 1].
        return torch.sigmoid(self.fc_dec5(hidden))

    def forward(self, x):
        """Return the pair ``(reconstruction, latent_code)`` for `x`."""
        latent = self.encode(x)
        return self.decode(latent), latent

In [None]:
# Build the project Dataset over `small_file`.
# NOTE(review): `int_version=True` presumably selects an integer encoding of
# the sequences -- confirm against data_generator.Dataset.
dataset = Dataset(
    small_file,
    max_seq_len,
    acids=acids,
    int_version=True,
)

# Shuffled mini-batches of `batch_size` samples, loaded by 16 worker processes.
base_generator = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)

In [None]:
print("starting")

# Optimisation hyper-parameters.
lr = 0.001
epochs = 10

# One entry per epoch: the mean of the per-batch losses for that epoch.
AE_loss_list = []
# 2000 input/output features and a 2-dimensional latent code.
# NOTE(review): 2000 presumably mirrors max_seq_len -- confirm.
AE_model = AE(2000, 2000, 2).to(processor)
optimizer = optim.Adam(AE_model.parameters(), lr=lr)
# BCELoss defaults to reduction='mean', so every loss value below is already
# averaged over all elements of the batch.
# NOTE(review): BCE expects targets in [0, 1]; confirm that the Dataset
# yields values in that range.
loss_function = nn.BCELoss().to(processor)

num_batches = len(base_generator)

for epoch in range(1, epochs + 1):
    AE_model.train()
    train_loss = 0
    for batch_idx, (batch, labels, valid_elems) in enumerate(base_generator):
        # Linear layers need floating-point input; the cast is a no-op when
        # the Dataset already yields float32. `labels` and `valid_elems` are
        # not used by the autoencoder, so they are not moved to the device.
        batch = batch.to(processor, dtype=torch.float32)
        optimizer.zero_grad()

        xHat, z = AE_model(batch)
        loss = loss_function(xHat, batch)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % 5 == 0:
            # The loss is already a mean, so it is reported as-is rather
            # than being divided by the batch size a second time.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(batch), len(base_generator.dataset),
                100. * batch_idx / num_batches,
                batch_loss))

    # Average the per-batch means over the number of batches (not over the
    # number of samples); max() guards against an empty loader.
    epoch_loss = train_loss / max(num_batches, 1)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_loss))
    AE_loss_list.append(epoch_loss)
 
        

In [None]:
# Plot the per-epoch training loss collected in AE_loss_list.
# Epochs are numbered from 1, matching the training loop.
plt.plot(list(range(1, len(AE_loss_list) + 1)), AE_loss_list)
plt.xlabel("Epoch")
plt.ylabel("Training loss")
plt.title("Loss of AE model")
plt.show()

# 