In [126]:
import numpy as np
import pandas as pd
from keras.preprocessing.sequence import pad_sequences
from tqdm import tqdm_notebook

In [2]:
import torch

In [105]:
with open("normalized_names.txt") as f:
    text = [line.lower().strip('\n ') for line in f]
    
strain_names = []
for t in text:
    if ',' in t:
        strain_names.extend(t.split(','))
    else:
        strain_names.append(t)
names = [["<BOS>"] + list(name) for name in strain_names]
targets = [list(name) + ["<EOS>"] for name in strain_names]

In [114]:
uniq_chars = {char for name in names for char in name}
uniq_chars.update({'<EOS>'})
char2idx = {char: i+1 for i, char in enumerate(uniq_chars)}
char2idx['<MASK>'] = 0

idx2char = {i: char for char, i in char2idx.items()}

In [107]:
train_names = names[:1000]
train_targets = targets[:1000]
test_names = names[1000:]
test_targets = targets[1000:]

In [209]:
train_names[:10]

[['<BOS>',
  'g',
  '1',
  '3',
  ' ',
  's',
  'u',
  'p',
  'e',
  'r',
  ' ',
  's',
  'i',
  'l',
  'v',
  'e',
  'r',
  ' ',
  'h',
  'a',
  'z',
  'e'],
 ['<BOS>', 'c', 'h', 'e', 'r', 'r', 'y', ' ', 'o', 'g', ' ', 'b', 'x'],
 ['<BOS>',
  'm',
  'e',
  'n',
  'd',
  'o',
  'c',
  'i',
  'n',
  'o',
  ' ',
  'p',
  'u',
  'r',
  'p',
  'l',
  'e'],
 ['<BOS>',
  'l',
  'o',
  'o',
  'm',
  'p',
  'a',
  "'",
  's',
  ' ',
  'c',
  'o',
  'l',
  'u',
  'm',
  'b',
  'i',
  'a',
  'n',
  ' ',
  'c',
  'h',
  'e',
  'm',
  'd',
  'a',
  'w',
  'g',
  ' ',
  'd',
  ' ',
  'f',
  '2'],
 ['<BOS>', 'k', 'e', 'n', "'", 's', ' ', 'k', 'u', 's', 'h'],
 ['<BOS>', 's', 'p', 'i', 'c', 'y', ' ', 'c', 'b', 'd'],
 ['<BOS>', 's', 'p', 'a', 'c', 'e', ' ', 'k', 'u', 's', 'h'],
 ['<BOS>', 'k', 'a', 'r', 'm', 'd', 'o', 'w', 'n'],
 ['<BOS>',
  'd',
  'o',
  'u',
  'b',
  'l',
  'e',
  ' ',
  's',
  't',
  'u',
  'f',
  'f',
  'e',
  'd'],
 ['<BOS>', 'f', 'i', 'r', 'e', ' ', 'o', 'g', ' ', 'k', 'u', 's', 

In [167]:
max_seq_length = len(max(train_names, key=len))
X_train = torch.zeros((len(train_names), max_seq_length)).long()
for i, strain in enumerate(train_names):
    for j in range(max_seq_length):
        if j < len(strain):
            X_train[i][j] = char2idx[strain[j]]
        else:
            X_train[i][j] = char2idx['<MASK>']

max_seq_length = len(max(train_targets, key=len))
y_train = torch.zeros((len(train_targets), max_seq_length)).long()
for i, strain in enumerate(train_targets):
    for j in range(max_seq_length):
        if j < len(strain):
            y_train[i][j] = char2idx[strain[j]]
        else:
            y_train[i][j] = char2idx['<MASK>']

In [168]:
max_seq_length = len(max(test_names, key=len))
X_val = torch.zeros((len(test_names), max_seq_length)).long()
for i, strain in enumerate(test_names):
    for j in range(max_seq_length):
        if j < len(strain):
            X_val[i][j] = char2idx[strain[j]]
        else:
            X_val[i][j] = char2idx['<MASK>']

max_seq_length = len(max(test_targets, key=len))
y_val = torch.zeros((len(test_targets), max_seq_length)).long()
for i, strain in enumerate(test_targets):
    for j in range(max_seq_length):
        if j < len(strain):
            y_val[i][j] = char2idx[strain[j]]
        else:
            y_val[i][j] = char2idx['<MASK>']

In [169]:
def batch_iter(X, y, batch_size=25):
    i = 0
    while i < X.shape[0]:
        xs = X[i:i+batch_size]
        ys = y[i:i+batch_size]
        i += batch_size
        yield xs, ys

In [195]:
def decode_samples(samples):
    samples = samples.argmax(2)
    all_strain_names = []
    for sample in range(samples.shape[0]):
        strain_name = ""
        for time_step in range(samples.shape[1]):
            char_idx = samples[sample, time_step].item()
            char = idx2char[char_idx]
            if char in ('<BOS>', '<EOS>', '<MASK>'):
                continue
            else:
                strain_name += char
        all_strain_names.append(strain_name)
    return all_strain_names

In [196]:
def normalize_sizes(y_pred, y_true):
    if len(y_pred.size()) == 3:
        y_pred = y_pred.contiguous().view(-1, y_pred.size(2))
    if len(y_true.size()) == 2:
        y_true = y_true.contiguous().view(-1)
    return y_pred, y_true

def compute_accuracy(y_pred, y_true, mask_index):
    y_pred, y_true = normalize_sizes(y_pred, y_true)
    _, y_pred_indices = y_pred.max(dim=1)
    correct_indices = torch.eq(y_pred_indices, y_true).float()
    valid_indices = torch.ne(y_true, mask_index).float()
    n_correct = (correct_indices * valid_indices).sum().item()
    n_valid = valid_indices.sum().item()

    return n_correct / n_valid * 100

def sequence_loss(y_pred, y_true, mask_index):
    y_pred, y_true = normalize_sizes(y_pred, y_true)
    return torch.nn.functional.cross_entropy(y_pred, y_true, ignore_index=mask_index)

In [197]:
def make_train_state(learning_rate, model_state_file):
    return {'stop_early': False,
            'early_stopping_step': 0,
            'early_stopping_best_val': 1e8,
            'learning_rate': learning_rate,
            'epoch_index': 0,
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'test_loss': -1,
            'test_acc': -1,
            'model_filename': model_state_file}


def update_train_state(early_stopping_criteria, model, train_state):
    """Handle the training state updates"""
    if train_state['epoch_index'] == 0:
        torch.save(model.state_dict(), train_state['model_filename'])
        train_state['stop_early'] = False

    elif train_state['epoch_index'] >= 1:
        loss_tm1, loss_t = train_state['val_loss'][-2:]
         
        # If loss worsened
        if loss_t >= loss_tm1:
            # Update step
            train_state['early_stopping_step'] += 1
        # Loss decreased
        else:
            # Save the best model
            if loss_t < train_state['early_stopping_best_val']:
                torch.save(model.state_dict(), train_state['model_filename'])
                train_state['early_stopping_best_val'] = loss_t
                
            # Reset early stopping step
            train_state['early_stopping_step'] = 0

        # Stop early ?
        train_state['stop_early'] = \
            train_state['early_stopping_step'] >= early_stopping_criteria

    return train_state

In [230]:
class StrainNameModel(torch.nn.Module):
    def __init__(self, char2idx, embed_size=100, lstm_size=128):
        super().__init__()
        self.embed = torch.nn.Embedding(len(char2idx), embed_size, padding_idx=0)
        self.lstm = torch.nn.LSTM(embed_size, lstm_size)
        self.fc = torch.nn.Linear(lstm_size, len(char2idx))
    
    def forward(self, X):
        out, _ = self.lstm(self.embed(X))
        batch_size, seq_size, feat_size = out.shape
        
        out = out.contiguous().view(batch_size * seq_size, feat_size)
        out = self.fc(torch.nn.functional.dropout(out, p=0.75))
        
        # out = torch.nn.functional.softmax(out, dim=1)
        
        new_feat_size = out.shape[-1]
        out = out.view(batch_size, seq_size, new_feat_size)
        
        return out

In [238]:
mask_index = char2idx['<MASK>']
num_epochs = 1000
learning_rate = 0.00001
model_save_path = "model.pth"
batch_size = 4
early_stopping_criteria = 100

model = StrainNameModel(char2idx)

optimizer = torch.optim.Adam(model.parameters(), 
                             learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                       mode='min', factor=0.3,
                                                       patience=1)
train_state = make_train_state(learning_rate, model_save_path)
epoch_bar = tqdm_notebook(desc='training routine', 
                          total=num_epochs,
                          position=0)

train_bar = tqdm_notebook(desc='split=train',
                          total= 1 + (X_train.shape[0] // batch_size), 
                          position=1, 
                          leave=True)
val_bar = tqdm_notebook(desc='split=val',
                        total= 1 + (X_val.shape[0] // batch_size), 
                        position=1, 
                        leave=True)

try:
    for epoch_index in range(num_epochs):
        train_state['epoch_index'] = epoch_index

        # Iterate over training dataset

        # setup: batch generator, set loss and acc to 0, set train mode on
        running_loss = 0.0
        running_acc = 0.0
        model.train()
        
        for batch_index, (X, y) in enumerate(batch_iter(X_train, y_train, batch_size)):
            # the training routine is these 5 steps:

            # --------------------------------------    
            # step 1. zero the gradients
            optimizer.zero_grad()

            # step 2. compute the output
            y_pred = model(X)
            # step 3. compute the loss
            loss = sequence_loss(y_pred, y, mask_index)
            # step 4. use loss to produce gradients
            loss.backward()
            # step 5. use optimizer to take gradient step
            optimizer.step()
            
            # -----------------------------------------
            # compute the  running loss and running accuracy
            running_loss += (loss.item() - running_loss) / (batch_index + 1)
            acc_t = compute_accuracy(y_pred, y, mask_index)
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # update bar
            train_bar.set_postfix(loss=running_loss,
                                  acc=running_acc,
                                  epoch=epoch_index)
            train_bar.update()

        train_state['train_loss'].append(running_loss)
        train_state['train_acc'].append(running_acc)

        # Iterate over val dataset

        # setup: batch generator, set loss and acc to 0; set eval mode on
        running_loss = 0.
        running_acc = 0.
        model.eval()

        for batch_index, (X, y)  in enumerate(batch_iter(X_val, y_val, batch_size)):
            # compute the output
            y_pred = model(X)
            # step 3. compute the loss
            loss = sequence_loss(y_pred, y, mask_index)
            # compute the  running loss and running accuracy
            running_loss += (loss.item() - running_loss) / (batch_index + 1)
            acc_t = compute_accuracy(y_pred, y, mask_index)
            running_acc += (acc_t - running_acc) / (batch_index + 1)
            
            # Update bar
            val_bar.set_postfix(loss=running_loss, 
                                acc=running_acc, 
                                epoch=epoch_index)
            val_bar.update()

        train_state['val_loss'].append(running_loss)
        train_state['val_acc'].append(running_acc)

        train_state = update_train_state(early_stopping_criteria, 
                                         model, 
                                         train_state)

        scheduler.step(train_state['val_loss'][-1])

        if train_state['stop_early']:
            break
        
        # move model to cpu for sampling
        sampled_strains = random.sample(decode_samples(y_pred), 3)
        epoch_bar.set_postfix(sample1=sampled_strains[0], 
                              sample2=sampled_strains[1])
        
        train_bar.n = 0
        val_bar.n = 0
        epoch_bar.update()
        
except KeyboardInterrupt:
    print("Exiting loop")

HBox(children=(IntProgress(value=0, description='training routine', max=1000, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='split=train', max=251, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='split=val', max=75, style=ProgressStyle(description_width='in…

Exiting loop


In [208]:
[idx2char[i.item()] for i in y[0] if i != 0]

['t', 'r', 'i', 'p', 'l', 'e', ' ', 'x', '<EOS>']

In [227]:
decode_samples(y)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In [229]:
random.sample(decode_samples(y_pred), 10)

['s  eeea   sa  ',
 'sae a ase e eeeee',
 'seeeeee s  ',
 's  e aeaaeeeeeeeee eeeeee',
 's e     sees a ee',
 's   e s a  ee saea',
 's e e   e a',
 'sa     ees ee',
 'sse esa  as e',
 'see  eae e ']

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for embedding)