In [1]:
import torch
from torch import nn, optim
import RVQE

In [2]:
# Cap the number of CPU threads PyTorch uses for intra-op parallelism at two.
torch.set_num_threads(2)

Our goal is to create an RNN or LSTM with roughly 1965 parameters and compare it on the DNA long-sequence task implemented within RVQE.

In [3]:
def dataset_t(length):
    """Build the RVQE "dna" dataset for sentences of the given length.

    Everything except ``sentence_length`` is fixed here: the first positional
    argument is 0, ``num_shards=0`` and ``batch_size=16``.
    """
    dna_dataset = RVQE.datasets.all_datasets["dna"]
    return dna_dataset(0, num_shards=0, batch_size=16, sentence_length=length)

In [4]:
def count_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are left out of the count.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def to_one_hot(labels, num_classes=2**3):
    """Turn integer class labels into one-hot vectors.

    Each label picks the matching row of a ``num_classes x num_classes``
    identity matrix, so the last axis of the result has size ``num_classes``.
    """
    identity = torch.eye(num_classes)
    one_hot = identity[labels]
    return one_hot

In [5]:
# One training run is performed per seed for every sentence length.
SEEDS = [
    9120, 2783, 2057, 6549, 3201, 7063, 5243, 3102, 5303, 5819,
    3693, 4884, 2231, 5514, 8850, 6861, 3106, 2378, 8697, 1821,
    9480, 8483, 1633, 9678, 6596, 4509, 8618, 9765, 6346, 2969,
]
# Sentence lengths to evaluate.
LENGTHS = [5, 10, 20, 50, 100, 200, 500, 1000]

# RNN

In [6]:
class SimpleRNN(nn.Module):
    """
    A deliberately simple RNN baseline: a stacked ``nn.RNN`` followed by a
    linear read-out that maps every hidden state back to ``io_size`` values.

    Authors' note: we found a single layer performs much better than two
    layers with a smaller hidden size. Without doubt one can improve the
    performance of this model. Yet we didn't optimize the QRNN setup for the
    task at hand either.

    NOTE(review): the default is ``num_layers=2``, which does not obviously
    match the note above about a single layer -- confirm which is intended.

    Args:
        io_size: size of each input vector and of each output vector.
        hidden_size: hidden state size of every recurrent layer.
        num_layers: number of stacked recurrent layers. Defaults to 2, the
            value that was previously hard-coded.
    """
    def __init__(self, io_size=2**3, hidden_size=80, num_layers=2):
        super().__init__()

        self.rnn = nn.RNN(input_size=io_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.lin = nn.Linear(hidden_size, io_size)

    def reset(self):
        """
        Re-initialise all weights in place.

        The linear layer uses its own default reset. For the RNN,
        hidden-to-hidden weights get an orthogonal start, input-to-hidden
        weights a Xavier-uniform start, and biases are zeroed.

        Raises:
            ValueError: if the RNN exposes a parameter whose name matches
                none of the handled kinds.
        """
        self.lin.reset_parameters()
        for name, param in self.rnn.named_parameters():
            # the nn.init functions run without gradient tracking, so the
            # parameters can be passed directly instead of going through .data
            if "weight_hh" in name:
                # give an orthogonal start
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.zeros_(param)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(param)
            else:
                raise ValueError(f"cannot initialize {name}")

    @property
    def num_parameters(self):
        """Number of trainable parameters in the RNN plus the linear read-out."""
        return count_parameters(self.rnn) + count_parameters(self.lin)

    def forward(self, sentence):
        """
        Run the RNN over a batch-first input and apply the linear read-out
        to the output at every time step.

        Args:
            sentence: tensor of shape (batch, sequence, io_size).

        Returns:
            Tensor of shape (batch, sequence, io_size).
        """
        rnn_out, _ = self.rnn(sentence)
        return self.lin(rnn_out)

In [7]:
# Display the trainable parameter count (RNN layers plus linear read-out) of a default SimpleRNN.
SimpleRNN().num_parameters

20808

In [18]:
# Maps sentence length -> list of [seed, step] entries appended by the training loop.
# Re-running this cell discards everything recorded so far.
results = {}

In [None]:
# Train one freshly seeded SimpleRNN per (sentence length, seed) pair. For each
# run, `results[length]` receives [seed, step], where step is the step at which
# the loss dropped below the threshold, or the last step if it never did.
for length in LENGTHS:

    dataset = dataset_t(length)
    print(f"creating RNN with {SimpleRNN().num_parameters} parameters")

    criterion = nn.CrossEntropyLoss()

    # keep entries from earlier runs so an interrupted cell can be resumed
    runs = results.setdefault(length, [])

    for seed in SEEDS:
        # skip seeds that already have an entry for this length
        if any(done_seed == seed for done_seed, _ in runs):
            continue

        torch.manual_seed(seed)
        model = SimpleRNN()
        model.reset()
        optimizer = optim.Adam(model.parameters(), lr=0.01)  # this has been found to converge fastest

        # cap amounts to the same number of samples seen as for qrnn
        # NOTE(review): range(1, 16 * 1000) stops at step 15999, not 16000 -- confirm this is the intended cap
        for step in range(1, 16 * 1000):
            sentence, target = dataset.next_batch(0, RVQE.data.TrainingStage.TRAIN)

            # transform sentence to one-hot as in the qrnn case
            sentence = to_one_hot(RVQE.data.targets_for_loss(sentence))

            optimizer.zero_grad()
            out = model(sentence.float())

            # unlike the qrnn case, we use the entire output as loss
            # this gives the rnn an advantage!
            # CrossEntropyLoss expects the class axis second, hence the transpose
            out = out.transpose(1, 2)
            target = RVQE.data.targets_for_loss(target)
            loss = criterion(out, target)

            loss.backward()
            optimizer.step()

            # read the scalar once; used for both the convergence check and logging
            loss_value = loss.item()

            if loss_value < 0.0005:
                runs.append([seed, step])
                print(f"length {length} converged after {step} steps.")
                break

            if step % 500 == 0:
                print(f"{step:06d} {loss_value:.2e}")

        else:
            # for/else: only reached when the loop finished without `break`.
            # Fixed: the message previously printed the literal word "step"
            # because the braces around the variable were missing.
            print(f"length {length} did not converge after {step} steps.")
            runs.append([seed, step])

creating RNN with 20808 parameters
creating RNN with 20808 parameters
creating RNN with 20808 parameters
creating RNN with 20808 parameters
000500 1.34e+00
001000 1.37e+00
001500 1.53e+00
002000 1.36e+00
002500 1.55e+00
003000 1.31e+00
003500 1.33e+00
004000 1.42e+00
004500 1.57e+00
005000 1.39e+00
005500 1.59e+00
006000 1.33e+00
006500 1.32e+00
007000 1.37e+00
007500 1.50e+00
008000 1.44e+00
008500 1.39e+00
009000 1.58e+00
009500 1.28e+00
010000 1.39e+00
010500 1.41e+00
011000 1.23e+00
011500 1.22e+00
012000 1.53e+00
012500 1.41e+00
013000 1.32e+00
013500 1.39e+00
014000 1.43e+00
014500 1.26e+00
015000 1.49e+00
015500 1.45e+00
length 50 did not converge after step steps.
000500 1.43e+00
001000 1.43e+00
001500 1.37e+00
002000 1.32e+00
002500 1.27e+00
003000 1.22e+00
003500 1.43e+00
004000 1.16e+00
004500 1.34e+00
005000 1.28e+00
005500 1.14e+00
006000 1.47e+00
006500 1.27e+00
007000 1.35e+00
007500 1.28e+00
008000 1.28e+00
008500 1.39e+00
009000 1.36e+00
009500 1.32e+00
010000 1.89e+00

In [27]:
import pandas as pd

In [43]:
# Flatten `results` into one row per (sentence length, seed) run and write it to CSV.
# The "hparams/validate_best" column is filled with the constant 0.0.
# NOTE(review): `index=None` is the DataFrame default, so the CSV still gets the
# default integer index as its first column -- confirm that is wanted.
pd.DataFrame(
    [
        [sentence_length, seed, step, 0.0]
        for sentence_length, runs in results.items()
        for seed, step in runs
    ],
    columns=["sentence_length", "seed", "hparams/epoch", "hparams/validate_best"],
    index=None,
).to_csv("~/long-rnn.csv")

In [30]:
# Display every recorded (sentence length, [[seed, step], ...]) pair.
results.items() 

dict_items([(5, [[9120, 48], [2783, 45], [2057, 43], [6549, 57], [3201, 53], [7063, 65], [5243, 44], [3102, 41], [5303, 53], [5819, 50], [3693, 47], [4884, 47], [2231, 49], [5514, 46], [8850, 58], [6861, 42], [3106, 40], [2378, 68], [8697, 44], [1821, 46], [9480, 47], [8483, 53], [1633, 53], [9678, 49], [6596, 43], [4509, 43], [8618, 46], [9765, 46], [6346, 44], [2969, 49]]), (10, [[9120, 386], [2783, 276], [2057, 285], [6549, 304], [3201, 387], [7063, 432], [5243, 216], [3102, 352], [5303, 298], [5819, 415], [3693, 262], [4884, 317], [2231, 386], [5514, 342], [8850, 436], [6861, 424], [3106, 294], [2378, 285], [8697, 331], [1821, 348], [9480, 299], [8483, 419], [1633, 374], [9678, 401], [6596, 412], [4509, 422], [8618, 385], [9765, 277], [6346, 602], [2969, 302]]), (20, [[9120, 15999], [2783, 15999], [2057, 15999], [6549, 15999], [3201, 15999], [7063, 15999], [5243, 15999], [3102, 15999], [5303, 15999], [5819, 15999], [3693, 15999], [4884, 15999], [2231, 15999], [5514, 15999], [8850, 