This notebook contains our LSTM model for the protein data set and relevant pre processing functions to be able to use our data sets with the model, furthermore the test loss and perplexity can be computed in the bottom of the notebook with the possibility of generating new sequences through sampling.

Authors: Mathias Dizon Olsson s174031, Mathias Sabroe Simonsen s164034

In [0]:
# Combine the domains into single test, training and validation splits
domain_dir = [['/domain_set/virus/test.txt',
               '/domain_set/virus/train.txt',
               '/domain_set/virus/valid.txt'],
              ['/domain_set/archaea/test.txt',
               '/domain_set/archaea/train.txt',
               '/domain_set/archaea/valid.txt'],
              ['/domain_set/bacteria/test.txt',
               '/domain_set/bacteria/train.txt',
               '/domain_set/bacteria/valid.txt'],
              ['/domain_set/eukaryotes/test.txt',
               '/domain_set/eukaryotes/train.txt',
               '/domain_set/eukaryotes/valid.txt'],
              ['/domain_set/completenoeuk/test.txt',
               '/domain_set/completenoeuk/train.txt',
               '/domain_set/completenoeuk/valid.txt']]

domains = ['virus',
           'archaea',
           'bacteria',
           'eukaryotes',
           'completenoeuk']

# Combining everything but the eukaryotes as described in the paper
read_files = [[domain_dir[0][0],domain_dir[1][0],domain_dir[2][0]],
              [domain_dir[0][1],domain_dir[1][1],domain_dir[2][1]],
              [domain_dir[0][2],domain_dir[1][2],domain_dir[2][2]]]

for i in range(len(read_files)):
    with open(domain_dir[4][i], "wb") as outfile:
        for f in read_files[i]:
            with open(f, "rb") as infile:
                outfile.write(infile.read())

In [0]:
# Ignore the specific warning in regards to module weights as explained in the discussions chapter of the paper
import warnings
warnings.simplefilter("ignore", UserWarning)

In [0]:
import numpy as np
import sys
import math

import torch

import time

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torch.autograd import Variable

import torchtext
from torchtext import data
from torchtext import datasets

from torch.nn.utils import clip_grad_norm_

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.FloatTensor

In [0]:
from torch.nn import Parameter
from functools import wraps

# Modified WeightDrop function based on the function made by salesforce (https://github.com/salesforce/) found here -> https://github.com/ChengyueGongR/awd-lstm-lm/blob/master/weight_drop.py
# using the modifications made in the post -> https://github.com/salesforce/awd-lstm-lm/issues/79 by github user -> https://github.com/shirishr
class WeightDrop(torch.nn.Module):
    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        # We need to replace flatten_parameters with a nothing function
        # It must be a function rather than a lambda as otherwise pickling explodes
        # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
        # (╯°□°）╯︵ ┻━┻
        return

    def _setup(self):
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
            # the modified code    
                w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)
            else:
                w = torch.nn.Parameter(torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training))
            # the original code
            #    w = mask.expand_as(raw_w) * raw_w
            #else:
            #    w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)

In [0]:
# Function to select which dataset should be used
def domain_select(domain, field, rootpath):
    
    splits = ['train',
              'valid',
              'test']
    split = [None]*3

    for i in range(3):
        path = rootpath + '/' + domain + '/' + splits[i] + '.txt'
        split[i] = torchtext.datasets.LanguageModelingDataset(path = path, 
                                                              text_field = field)

    return (split[0],split[1],split[2])

The desired domain can be chosen by changing the domain field to the available domains in the code comment.

In [0]:
# set up field, path and domain
TEXT = data.Field(lower=True, batch_first=False)
path = '/content/drive/My Drive/Deep Learning Project/domain_set'
domain = 'completenoeuk' # Available domains: virus, archaea, bacteria, eukaryotes (might crash due to size), completenoeuk (a combined set of all domains but eukaryotes)

# create torchtext datasets
train, valid, test = domain_select(domain = domain,
                                   field = TEXT,
                                   rootpath = path)

# build the vocabulary
TEXT.build_vocab(train)

# batch size and backpropagate-through-time length
batch_size = 20
bptt_len = 600

# create iterator for splits
train_iter, valid_iter, test_iter = data.BPTTIterator.splits((train, valid, test), 
                                                              batch_sizes = (batch_size, batch_size, 1), 
                                                              bptt_len = bptt_len, 
                                                              device = device)

In [0]:
# Display the vocabulary
for i in range(len(TEXT.vocab)):
    if i < len(TEXT.vocab)-1:
        print(TEXT.vocab.itos[i], end = ', ')
    else:
        print(TEXT.vocab.itos[i], end = '.') 

<unk>, <pad>, l, a, e, g, v, k, d, i, t, s, r, p, n, q, f, y, m, h, w, c, <eos>, x, b, u, z, o.

In [0]:
# Check amount of batches
validation_set = (iter(valid_iter))
training_set = (iter(train_iter))
test_set = (iter(test_iter))

validation_batch_amount = 0
while True:
    try:
        c = (next(validation_set).text)
    except StopIteration:
        print('Amount of validation text batches: ',validation_batch_amount)
        break  # Iterator exhausted: stop the loop
    else:
        validation_batch_amount = validation_batch_amount + 1

training_batch_amount = 0
while True:
    try:
        c = (next(training_set).text)
    except StopIteration:
        print('Amount of training batches: ',training_batch_amount)
        break  # Iterator exhausted: stop the loop
    else:
        training_batch_amount = training_batch_amount + 1
        
test_batch_amount = 0
while True:
    try:
        c = (next(test_set).text)
    except StopIteration:
        print('Amount of test batches: ',test_batch_amount)
        break  # Iterator exhausted: stop the loop
    else:
        test_batch_amount = test_batch_amount + 1      

Amount of validation text batches:  358
Amount of training batches:  2579
Amount of test batches:  19098


In [0]:
# Reset iterators
validation_set = (iter(valid_iter))
training_set = (iter(train_iter))
test_set = (iter(test_iter))

In [0]:
class Net(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, weights, dropout):
        super(Net, self).__init__()

        self.embed = nn.Embedding(num_embeddings = vocab_size, 
                                  embedding_dim = embed_size)
        
        self.lstm = nn.LSTM(input_size = embed_size, 
                            hidden_size = hidden_size, 
                            num_layers = num_layers, 
                            batch_first = False)

        self.weightdrop = WeightDrop(module = self.lstm, 
                                     weights = weights, 
                                     dropout = dropout)
        
        self.l_out = nn.Linear(in_features = hidden_size, 
                               out_features = vocab_size, 
                               bias = False)
        
    def forward(self, x, hc):
        # Embed word ids to vectors
        x = self.embed(x)
        
        # Weightdropped LSTM returns output and last hidden state
        out, (h, c) = self.weightdrop(x, hc)

        # Reshape output to (sequence_length*batch_size, hidden_size)
        out = out.reshape(out.shape[0]*out.shape[1], -1)

        # Decode hidden states of all time steps
        out = self.l_out(out)

        return out, (h, c)

In [0]:
# Hyper-parameters
num_epochs = 10
embed_size = 25
hidden_size = 128
num_layers = 3
learning_rate = 0.002
max_norm = 0.25
weights = ['weight_hh_l' + str(i) for i in range(num_layers)]
dropout = 0.9
batch_size = 20
switch = False
vocab_size = len(TEXT.vocab)

In [0]:
# Initialize a new network
net = Net(vocab_size, embed_size, hidden_size, num_layers, weights, dropout).to(device)
print(net)

Applying weight drop of 0.9 to weight_hh_l0
Applying weight drop of 0.9 to weight_hh_l1
Applying weight drop of 0.9 to weight_hh_l2
Net(
  (embed): Embedding(28, 25)
  (lstm): LSTM(25, 128, num_layers=3)
  (weightdrop): WeightDrop(
    (module): LSTM(25, 128, num_layers=3)
  )
  (l_out): Linear(in_features=128, out_features=28, bias=False)
)


In [0]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=learning_rate)

In [0]:
# Detach hidden and cell states of the LSTM
def detach(states):
    return [state.detach() for state in states]

# Weight initialization
def weight_init():
    return (torch.zeros(num_layers, batch_size, hidden_size).to(device),
            torch.zeros(num_layers, batch_size, hidden_size).to(device))

In [0]:
# Training loop
# Track loss
training_loss, validation_loss = [], []

NT_ASGD_loss = []

# For each epoch
for epoch in range(num_epochs):
    
    # Display epoch number
    if epoch == 0:
        print(f'Epoch {epoch+1} of {num_epochs}')
    else:
        print(f'\nEpoch {epoch+1} of {num_epochs}')
        
    # Track loss
    epoch_training_loss = 0
    epoch_validation_loss = 0
    
    # Validation
    net.eval()
    
    hc = weight_init()

    batch_num = 0
    print('Net validation')

    for batch in valid_iter:
        
        batch_num = batch_num + 1  
        
        text = (batch.text).to(device)
        target = (batch.target).to(device)
        
        # Forward pass
        hc = detach(hc)
        outputs, hc = net(text, hc)
        loss = criterion(outputs, target.reshape(-1))
        
        
        # Update loss
        epoch_validation_loss += loss.detach()

        # Loading bar
        sys.stdout.write('\r')
        p = (batch_num) / validation_batch_amount
        sys.stdout.write("[%-50s] %d%%" % ('='*int(50*p), 100*p))
        sys.stdout.flush()

    # Training
    net.train()    

    hc = weight_init()

    batch_num = 0
    print('\nNet training')
                  
    for batch in train_iter:
        batch_num = batch_num + 1

        text = (batch.text).to(device)
        target = (batch.target).to(device)

        # Forward pass
        hc = detach(hc)
        outputs, hc = net(text, hc)
        loss = criterion(outputs, target.reshape(-1))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(net.parameters(), max_norm)
        optimizer.step()
      
        # Update loss
        epoch_training_loss += loss.detach()
        
        # Loading bar
        sys.stdout.write('\r')
        p = (batch_num) / training_batch_amount
        sys.stdout.write("[%-50s] %d%%" % ('='*int(50*p), 100*p))
        sys.stdout.flush()
        

    # Save loss for plot
    training_loss.append(epoch_training_loss/(training_batch_amount))
    validation_loss.append(epoch_validation_loss/(validation_batch_amount))

    # Print loss every epoch
    print(f'\nTraining loss: {training_loss[-1]}, Validation loss: {validation_loss[-1]}')

    # NT-ASGD
    NT_ASGD_loss.append(validation_loss[-1].cpu())
    if len(NT_ASGD_loss) == 6:
        del NT_ASGD_loss[0]
        if switch == False:
            if np.isclose(NT_ASGD_loss, NT_ASGD_loss[4], rtol=0, atol = 5e-4).all() == True:
                print('\nOptimizer has changed to ASGD')
                switch = True
                optimizer = optim.ASGD(params=net.parameters(), lr=learning_rate)

Epoch 1 of 10
Net validation
Net training
Training loss: 3.313389778137207, Validation loss: 3.334681272506714

Epoch 2 of 10
Net validation
Net training
Training loss: 3.2634942531585693, Validation loss: 3.2883694171905518

Epoch 3 of 10
Net validation
Net training
Training loss: 3.1988396644592285, Validation loss: 3.2248568534851074

Epoch 4 of 10
Net validation
Net training
Training loss: 3.128234386444092, Validation loss: 3.1467931270599365

Epoch 5 of 10
Net validation
Net training
Training loss: 3.060789108276367, Validation loss: 3.0677640438079834

Epoch 6 of 10
Net validation
Net training
Training loss: 3.0058393478393555, Validation loss: 3.001575469970703

Epoch 7 of 10
Net validation
Net training
Training loss: 2.967348098754883, Validation loss: 2.9563148021698

Epoch 8 of 10
Net validation
Net training
Training loss: 2.9443562030792236, Validation loss: 2.93093204498291

Epoch 9 of 10
Net validation
Net training
Training loss: 2.9296019077301025, Validation loss: 2.916

In [0]:
# Compute the test loss and perplexity
complete_test_loss = []
test_loss = 0
batch_num = 0

for batch in test_iter:

    hc = (torch.zeros(num_layers, 1, hidden_size).to(device),
          torch.zeros(num_layers, 1, hidden_size).to(device))
    
    batch_num = batch_num + 1
    text = (batch.text).to(device)
    target = (batch.target).to(device)

    # Forward pass
    hc = detach(hc)
    outputs, hc = net(text, hc)
    loss = criterion(outputs, target.reshape(-1))

    # Update loss
    test_loss += loss.detach()

    # Loading bar
    sys.stdout.write('\r')
    p = (batch_num) / test_batch_amount
    sys.stdout.write("[%-50s] %d%%" % ('='*int(50*p), 100*p))
    sys.stdout.flush()

complete_test_loss = test_loss / test_batch_amount
print(f'\ntest loss: {complete_test_loss}, perplexity: {np.exp(complete_test_loss.cpu())}')

test_loss: 2.9294137954711914, perplexity: 18.716655731201172


In [0]:
# Sampling to a txt file

num_samples = 2000

with torch.no_grad():
    with open('completenoeuk_results.txt', 'w') as f:
        
        # Set intial hidden and cell states
        hc = (torch.zeros(num_layers, 1, hidden_size).to(device),
                 torch.zeros(num_layers, 1, hidden_size).to(device))

        # Starting word, in this case the amino acid M
        input = torch.Tensor([[18]]).to(device).long()

        for i in range(num_samples):

            # Forward pass
            output, hc = net(input, hc)

            # Sample word
            prob = output.exp()
            word_id = torch.multinomial(prob, num_samples=1).item()

            # Fill input with sampled word for the next time step
            input.fill_(word_id)

            # File write
            if i == 0:
                f.write(f'perplexity: {np.exp(complete_test_loss.cpu())}' + '\n')
                f.write('Sequence: ' + TEXT.vocab.itos[18] + ' ') 
            word = TEXT.vocab.itos[word_id]
            word = '\n' + 'Sequence: ' if word == '<eos>' else word + ' '
            f.write(word)

    print('Words have been sampled and saved to .txt file')

Words have been sampled and saved to .txt file


In [0]:
# Sampling to python output

num_samples = 2000

with torch.no_grad():
    # Set intial hidden and cell states
    hc = (torch.zeros(num_layers, 1, hidden_size).to(device),
              torch.zeros(num_layers, 1, hidden_size).to(device))

    # Starting word, in this case the amino acid M
    input = torch.Tensor([[18]]).to(device).long()

    for i in range(num_samples):

        # Forward pass
        output, hc = net(input, hc)

        # Sample word
        prob = output.exp()
        word_id = torch.multinomial(prob, num_samples=1).item()

        # Fill input with sampled word for the next time step
        input.fill_(word_id)

        # File write
        if i == 0:
            print('Sequence: ' + TEXT.vocab.itos[18], end = ' ') 
        word = TEXT.vocab.itos[word_id]
        word = '\n' + 'Sequence: ' if word == '<eos>' else word
        print(word, end = ' ')

Sequence: w a g q q g r i t d s l a i l q r l v m v r i e n y n e n k v a k l i k r t l v l d d d l g v f d s e i v t t e a p l a v f d a d v l i l n m s i w s h d l a v e s n d s a r h q f v i a r n h g t l s r g l r d l i v p a i p i e e l a k r i m v g t v g a k a a g r k l t d k m r d k a n g i i f e v l d t a a s r g v l a i i v a a v q l y a a a v w l i t v e k d a t l t c a h i r e e n e f d p a l l a y g g v n r e h a n c l l <pad> v a v i i p l g k t p r q n i n g e k k v m f v q d a f k t a a a p g n k k n s v a a a e t m g r e i a s l a y r y k a q g k a v y f n a d a i h e k l w d e k g k n g i a p i g r k k d s q y p g v a v a r f g l k p a s k r h e g i k q f l i n a v e l a p g l g l k e t a p h e l q r r t v k v q e k m e l s s t f t a g k p v l v k v k m g r g l e r q n m q v q l f m i w k v m a p a v v 
Sequence:  g v t v f g e h v m e d a g a n a d g k e e p e t a w v e h n y g i d a r s i s v r h a g a q k e k a e g n d a l f k e l v t c g a k v p <pad> s k s d l p 