In [1]:
%load_ext autoreload
%autoreload 2

MODEL TRAINING

In [2]:
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.optim import Adam
from pathlib import Path
from torch.utils.data import TensorDataset, DataLoader
from struct2seq.rna_features import RNAFeatures, PositionalEncodings
from struct2seq.rna_struct2seq import RNAStruct2Seq
from struct2seq import noam_opt

In [3]:
def loss_nll(S, log_probs, mask):
    """Per-position negative log-likelihood and its masked mean.

    Args:
        S: integer tensor of target class indices.
        log_probs: tensor of log-probabilities whose leading dimensions match
            ``S`` and whose last dimension indexes the classes.
        mask: tensor broadcastable against ``S``; used as per-position weights.

    Returns:
        A tuple ``(loss, loss_av)`` where ``loss`` has the shape of ``S`` and
        holds the unreduced NLL at every position, and ``loss_av`` is the
        scalar ``sum(loss * mask) / sum(mask)``.
    """
    num_classes = log_probs.size(-1)

    # Collapse every leading dimension so the loss sees (N, C) against (N,).
    flat_log_probs = log_probs.contiguous().view(-1, num_classes)
    flat_targets = S.contiguous().view(-1)

    flat_loss = torch.nn.functional.nll_loss(
        flat_log_probs, flat_targets, reduction='none'
    )
    loss = flat_loss.view(S.size())

    masked_total = (loss * mask).sum()
    loss_av = masked_total / mask.sum()
    return loss, loss_av




def loss_smoothed(S, log_probs, mask, weight=0.1, vocab_size=6):
    """Cross-entropy against label-smoothed targets, plus its masked mean.

    Args:
        S: integer tensor of target class indices.
        log_probs: tensor of log-probabilities; its last dimension must line
            up with ``vocab_size``.
        mask: tensor broadcastable against ``S``; used as per-position weights.
        weight: total extra mass spread uniformly over the classes before the
            target distribution is renormalised.
        vocab_size: number of classes used for the one-hot encoding of ``S``.

    Returns:
        A tuple ``(loss, loss_av)`` where ``loss`` has the shape of ``S`` and
        ``loss_av`` is the scalar ``sum(loss * mask) / sum(mask)``.
    """
    targets = torch.nn.functional.one_hot(S, num_classes=vocab_size).float()

    # Label smoothing: add a uniform offset to every class, then renormalise
    # so that each target distribution sums to one again.
    uniform_offset = weight / float(targets.size(-1))
    targets = targets + uniform_offset
    targets = targets / targets.sum(-1, keepdim=True)

    weighted_log_probs = targets * log_probs
    loss = -weighted_log_probs.sum(-1)

    masked_total = (loss * mask).sum()
    loss_av = masked_total / mask.sum()
    return loss, loss_av


In [4]:

# All data paths are resolved relative to the current working directory.
basedir = Path('.').resolve()
processed_dir = basedir / 'data/rna/processed_for_ml'

# Distance map is loaded from .npy and kept on the CPU.
dist_map = torch.tensor(np.load(processed_dir / 'distance_map.npy'), device='cpu')

# NOTE: torch.load unpickles its input -- only point it at files you trust.
X_train = torch.load(processed_dir / 'train.pt')
X_val = torch.load(processed_dir / 'val.pt')

# DataLoader defaults to shuffle=False, so the training loader must request
# shuffling explicitly; otherwise every epoch replays the batches in the same
# order. Validation keeps a fixed order.
train_dl = DataLoader(TensorDataset(X_train), batch_size=10, shuffle=True)
val_dl = DataLoader(TensorDataset(X_val), batch_size=10, shuffle=False)

In [5]:
# Hyperparameters -- partially based on Ingraham
vocab_size = 6
num_node_feats = 64
num_edge_feats = 64
hidden_dim = 64
num_encoder_layers = 1
num_decoder_layers = 3

# Weight handed to loss_smoothed in the training loop.
smoothing_weight = 0.1

k_nbrs = 10

# Use the GPU when one is available; fall back to the CPU so the notebook
# still runs (slowly) on machines without CUDA instead of crashing in .to().
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = RNAStruct2Seq(vocab_size, num_node_feats, num_edge_feats, dist_map, hidden_dim, num_encoder_layers, num_decoder_layers, k_nbrs)
model.to(device)

num_epochs = 2

# Plain Adam is the optimizer actually used for training. A Noam-scheduled
# optimizer (noam_opt.get_std_opt(model.parameters(), hidden_dim)) used to be
# built here and was immediately overwritten by this line, so that dead
# assignment has been dropped.
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.NLLLoss(reduction='none')

In [7]:
# Log files
log_folder = basedir / 'log' / 'test_rna'
# parents=True: the intermediate 'log' directory may not exist on a fresh checkout.
log_folder.mkdir(parents=True, exist_ok=True)

# Start a fresh log with a header; one row per epoch is appended below.
logfile = log_folder / 'log.txt'
with open(logfile, 'w') as f:
    f.write('Epoch\tTrain\tValidation\n')

start_train = time.time()
epoch_losses_train, epoch_losses_valid = [], []
epoch_checkpoints = []  # not populated in this cell
total_step = 0
num_epochs = 10  # overrides the value set in the hyperparameter cell
for e in range(num_epochs):
    # Training epoch
    model.train()
    train_sum, train_weights = 0., 0.
    for train_i, S in tqdm(enumerate(train_dl), total=len(train_dl)):
        # TensorDataset batches arrive as a one-element sequence.
        S = S[0].to(device)

        # Every position contributes to the loss (all-ones mask).
        mask = torch.ones_like(S)
        start_batch = time.time()

        # Optimise the label-smoothed objective.
        optimizer.zero_grad()
        log_probs = model(S)
        _, loss_av_smoothed = loss_smoothed(
            S, log_probs, mask, weight=smoothing_weight, vocab_size=vocab_size
        )
        loss_av_smoothed.backward()
        optimizer.step()

        # Report the unsmoothed NLL (from the pre-step forward pass); it is
        # only used for monitoring, so no autograd graph is needed.
        with torch.no_grad():
            loss, loss_av = loss_nll(S, log_probs, mask)

        # Timing
        elapsed_batch = time.time() - start_batch
        elapsed_train = time.time() - start_train
        total_step += 1

        # Accumulate true loss as plain Python numbers (same as validation).
        train_sum += torch.sum(loss * mask).item()
        train_weights += torch.sum(mask).item()
    train_loss = train_sum / train_weights
    print(f"Train Loss: {train_loss:.3f}")

    # Validation epoch
    with torch.no_grad():
        val_sum = 0.0
        val_weights = 0.0
        model.eval()
        for val_i, S in tqdm(enumerate(val_dl), total=len(val_dl)):
            S = S[0].to(device)
            mask = torch.ones_like(S)
            log_probs = model(S)
            loss, loss_av = loss_nll(S, log_probs, mask)
            val_sum += torch.sum(loss * mask).item()
            val_weights += torch.sum(mask).item()
    val_loss = val_sum / val_weights
    print(f"Val Loss: {val_loss:.3f}")

    # Record the epoch: in memory for later inspection, and as a row in the
    # log file matching the 'Epoch\tTrain\tValidation' header written above.
    epoch_losses_train.append(train_loss)
    epoch_losses_valid.append(val_loss)
    with open(logfile, 'a') as f:
        f.write(f'{e}\t{train_loss}\t{val_loss}\n')

 13%|█▎        | 337/2604 [00:36<04:03,  9.30it/s]


KeyboardInterrupt: 