### Language models for RNA sequences

In [7]:
#!g1.1
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
#!g1.1
# Special characters of the vocabulary: end-of-sequence (also used as the
# padding value by to_matrix), begin-of-sequence, and the placeholder for
# any character that is not in the vocabulary.
EOS = '\n'
BOS = ' '
UNK = '_'

# Character -> integer id. Note that the ids are not contiguous: the largest
# id is 13 even though the vocabulary only has 7 entries.
token_to_id = {BOS: 0, EOS: 1, UNK: 2, 'a': 10, 'c': 11, 'g': 12, 'u': 13}
# Reverse lookup: integer id -> character.
id_to_token = {v: k for k, v in token_to_id.items()}

In [16]:
#!g1.1
def to_matrix(lines, max_len=None, pad=token_to_id[EOS], unk=token_to_id[UNK], dtype=np.int64):
    """
    Cast a list of lines into a torch-digestible matrix of token ids.

    :param lines: sequence of strings; each character is looked up in token_to_id
    :param max_len: number of columns of the result. Longer lines are truncated,
        shorter ones are right-padded. Defaults to the length of the longest line.
    :param pad: id used to fill positions after the end of a line
    :param unk: id used for characters missing from token_to_id
    :param dtype: numpy dtype of the returned matrix
    :returns: numpy array of shape [len(lines), max_len]
    """
    # default=0 keeps an empty batch from crashing inside max().
    max_len = max_len or max(map(len, lines), default=0)
    # The width must follow max_len (not the longest line), otherwise an
    # explicit max_len would truncate the tokens but not the matrix.
    matrix = np.full([len(lines), max_len], pad, dtype=dtype)

    for i, line in enumerate(lines):
        line_tokenized = [token_to_id.get(c, unk) for c in line[:max_len]]
        matrix[i, :len(line_tokenized)] = line_tokenized

    return matrix

In [17]:
#!g1.1
# Demo: 't' is not in token_to_id, so it is encoded as UNK, and the shorter
# sequence is right-padded with the EOS id up to the longest sequence.
seqs = ['acut', 'acgacucuug']
to_matrix(seqs)

array([[10, 11, 13,  2,  1,  1,  1,  1,  1,  1],
       [10, 11, 12, 10, 11, 13, 11, 13, 13, 12]])

In [21]:
#!g1.1
def compute_mask(matrix, eos=token_to_id[EOS]):
    """
    Build a boolean mask that is True for every position up to and including
    the first EOS along the last axis, and False for everything after it.
    """
    # Number of EOS tokens seen so far, with the last column dropped so that
    # column j describes what happened strictly before column j + 1.
    eos_seen_before = torch.cumsum(matrix == eos, dim=-1)[..., :-1]
    still_open = eos_seen_before < 1
    # Shift right by one column; the first position is always kept.
    return F.pad(still_open, pad=(1, 0, 0, 0), value=True)

In [20]:
#!g1.1
# Demo: positions up to and including the first EOS are True, the EOS padding
# that follows it is False.
compute_mask(torch.tensor(to_matrix(seqs)))

tensor([[ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

In [None]:
#!g1.1
def compute_loss(model, matrix):
    """
    Compute the mean next-token cross-entropy of a language model on a batch.

    :param model: language model that maps token indices [batch_size, length]
        to next-token logits [batch_size, length, n_tokens]
    :param matrix: integer matrix of tokens, shape [batch_size, length], padded
        with the EOS id; it is converted to an int64 tensor on the model's device
    :returns: scalar loss tensor, cross-entropy averaged over every target
        position up to and including the first EOS of each row (the EOS
        padding after it is excluded)
    """
    # nn.Module has no built-in `device` attribute: use one if the model
    # provides it, otherwise infer the device from the model's parameters.
    device = getattr(model, 'device', None)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    matrix = torch.as_tensor(matrix, dtype=torch.int64).to(device)

    # Predict token t+1 from tokens up to t: inputs drop the last column,
    # targets drop the first one.
    logits = model(matrix[:, :-1])
    reference_answers = matrix[:, 1:]

    mask = compute_mask(reference_answers).to(torch.int32)

    # cross_entropy expects the class axis second: [batch, n_tokens, length].
    token_losses = F.cross_entropy(logits.permute(0, 2, 1), reference_answers, reduction='none')

    return torch.sum(token_losses * mask) / mask.sum()


def score_lines(model, lines, batch_size):
    """
    Evaluate the model on `lines` in batches of `batch_size` and return the
    average of the per-batch losses, each weighted by its number of lines.
    The model is switched to eval mode and no gradients are tracked.
    """
    model.eval()

    weighted_loss_total = 0.
    n_lines_total = 0.

    with torch.no_grad():
        for start in range(0, len(lines), batch_size):
            chunk = lines[start: start + batch_size]
            batch_matrix = to_matrix(chunk)
            n_rows = len(batch_matrix)

            batch_loss = compute_loss(model, batch_matrix).item()
            weighted_loss_total += batch_loss * n_rows
            n_lines_total += n_rows

    return weighted_loss_total / n_lines_total

In [24]:
#!g1.1
import time
from IPython.display import clear_output
from tqdm import tqdm
import torch.nn.functional as F

def train_model(model, optimizer, val_lines, batch_size, scheduler=None):
    """
    Run one training epoch and return the mean of the per-batch losses.

    :param model: language model accepted by compute_loss
    :param optimizer: optimizer over the model's parameters
    :param val_lines: lines to train on. Despite the name this is the training
        data (train() passes its train_lines here); the name is kept so that
        keyword callers keep working.
    :param batch_size: number of lines per optimisation step
    :param scheduler: optional learning-rate scheduler, stepped once per call
    :returns: mean training loss over the batches, NaN if there were no batches
    """
    batch_losses = []
    model.train(True)
    for i in range(0, len(val_lines), batch_size):
        optimizer.zero_grad()

        batch = to_matrix(val_lines[i: i + batch_size])
        loss = compute_loss(model, batch)

        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # The scheduler is optional (defaults to None), so only step it when given.
    if scheduler is not None:
        scheduler.step()

    # np.mean of an empty list warns and yields NaN; return NaN explicitly.
    return np.mean(batch_losses) if batch_losses else float('nan')

def train(model, optimizer, train_lines, val_lines, num_epochs, batch_size, scheduler=None):
    """
    Train the model for `num_epochs` epochs, redrawing the loss curves and
    printing a short report after every epoch.

    :param model: language model accepted by compute_loss
    :param optimizer: optimizer over the model's parameters
    :param train_lines: lines used for the optimisation steps
    :param val_lines: held-out lines scored after every epoch
    :param num_epochs: number of passes over train_lines
    :param batch_size: number of lines per batch, for both training and scoring
    :param scheduler: optional learning-rate scheduler forwarded to train_model
    """
    train_history = []
    test_history = []

    start_time = time.time()
    for epoch in range(num_epochs):

        train_loss = train_model(model, optimizer, train_lines, batch_size, scheduler)
        train_history.append((epoch + 1, train_loss))

        test_history.append((epoch + 1, score_lines(model, val_lines, batch_size)))

        # Replace the previous epoch's plot instead of stacking a new one.
        clear_output(True)
        plt.figure(figsize=(16, 8))
        plt.subplot(1, 2, 1)
        plt.plot(*zip(*train_history), color='blue', label='train_loss')
        plt.plot(*zip(*test_history), color='red', label='dev_loss')
        plt.legend()
        plt.grid()
        plt.show()

        # Then we print the results for this epoch. Only losses are tracked
        # here, so there is no accuracy figure to report.
        print("Epoch {} of {} took {:.3f}min".format(epoch + 1, num_epochs, (time.time() - start_time) / 60))
        print("  training loss (in-iteration): {}".format(train_history[-5:]))
        print("  validation loss (in-iteration): {}".format(test_history[-5:]))

In [28]:
#!g1.1
class FixedWindowLanguageModel(nn.Module):
    """
    Language model that predicts the next token from a fixed-size window of
    preceding tokens: embedding -> left-padded 1d convolution -> linear layer.
    """

    def __init__(self, n_tokens=None, emb_size=16, hid_size=64, window_size=5):
        """
        :param n_tokens: size of the embedding table and of the output layer.
            Defaults to the largest id in token_to_id plus one, because the
            ids are not contiguous and len(token_to_id) would be too small.
        :param emb_size: dimensionality of the token embeddings
        :param hid_size: number of convolution output channels
        :param window_size: how many tokens (the current one and the ones
            before it) each prediction may look at
        """
        super().__init__()

        if n_tokens is None:
            n_tokens = max(token_to_id.values()) + 1

        self.window_size = window_size
        self.emb = nn.Embedding(n_tokens, emb_size)
        # Pad only on the left of the time axis so that position t never sees
        # tokens after t and the output keeps the input length.
        self.pad = nn.ZeroPad2d((self.window_size - 1, 0, 0, 0))
        self.conv = nn.Conv1d(emb_size, hid_size, kernel_size=self.window_size)
        self.fc = nn.Linear(hid_size, n_tokens)

    @property
    def device(self):
        """Device on which the model's parameters live."""
        return next(self.parameters()).device

    def forward(self, input_ix):
        """
        :param input_ix: int64 tensor of token ids, shape [batch_size, sequence_length]
        :returns: next-token logits, shape [batch_size, sequence_length, n_tokens]
        """
        # Conv1d wants channels before time: [batch, emb_size, length].
        out = self.emb(input_ix).permute((0, 2, 1))
        out = self.pad(out)
        out = self.conv(out).permute((0, 2, 1))

        return self.fc(out)  # [batch_size, sequence_length, n_tokens]
        

In [25]:
#!g1.1
class RNNLanguageModel(nn.Module):
    """
    Recurrent language model: embedding -> single-layer LSTM -> dropout ->
    linear layer producing next-token logits.
    """

    def __init__(self, n_tokens=None, emb_size=16, hid_size=256, dropout=0.2):
        """
        :param n_tokens: size of the embedding table and of the output layer.
            Defaults to the largest id in token_to_id plus one, because the
            ids are not contiguous.
        :param emb_size: dimensionality of the token embeddings
        :param hid_size: LSTM hidden state size
        :param dropout: dropout probability applied to the LSTM outputs
            before the output layer (active in training mode only)
        """
        super().__init__()

        if n_tokens is None:
            n_tokens = max(token_to_id.values()) + 1

        self.emb = nn.Embedding(n_tokens, emb_size)
        self.lstm = nn.LSTM(emb_size, hid_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hid_size, n_tokens)

    @property
    def device(self):
        """Device on which the model's parameters live."""
        return next(self.parameters()).device

    def forward(self, input_ix):
        """
        :param input_ix: int64 tensor of token ids, shape [batch_size, sequence_length]
        :returns: next-token logits, shape [batch_size, sequence_length, n_tokens]
        """
        self.lstm.flatten_parameters()

        x = self.emb(input_ix)
        states, _ = self.lstm(x)
        states = self.dropout(states)

        return self.fc(states)  # [batch_size, sequence_length, n_tokens]
        

In [None]:
#!g1.1
