# TL;DR

1. In this lab scenario you will compare the performance of a classic RNN and an LSTM on a toy example.
2. This toy example shows that maintaining memory over even 20 steps is non-trivial.
3. Finally, you will see how curriculum learning may make it possible to train a model on longer sequences.

# Problem definition

Here we consider a toy example, where the goal is to discriminate between two types of binary sequences:
* [Type 0] a sequence with exactly one zero (the remaining entries are equal to one),
* [Type 1] a sequence full of ones.

We are especially interested in how well the trained models discriminate between a sequence full of ones and a sequence with a leading zero followed by ones. For this pair the model effectively has to output the first element of the sequence, since the label (sequence type) is fully determined by it; in other words, the model has to remember the first input until it has read the whole sequence.
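To make the two sequence types concrete, here is a minimal sketch that builds both of them with `torch`. The helper names `type0_sequence` and `type1_sequence` are hypothetical and not part of the lab code.

```python
import torch


def type0_sequence(seq_len, zero_pos):
    # Type 0: all ones except a single zero at position zero_pos (label 0)
    seq = torch.ones(seq_len)
    seq[zero_pos] = 0
    return seq


def type1_sequence(seq_len):
    # Type 1: a sequence full of ones (label 1)
    return torch.ones(seq_len)


# The hardest pair: a leading zero versus all ones.
# For this pair the label equals the first element of the sequence.
hard_type0 = type0_sequence(5, zero_pos=0)
hard_type1 = type1_sequence(5)
print(hard_type0, int(hard_type0[0]))  # tensor([0., 1., 1., 1., 1.]) 0
print(hard_type1, int(hard_type1[0]))  # tensor([1., 1., 1., 1., 1.]) 1
```

For a Type 0 sequence with the zero at a later position the first element is 1, so the label coincides with the first element only for the hard pair.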

# Implementation

## Importing torch

Install `torch` and `torchvision`

In [1]:
!pip3 install torch torchvision



In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

torch.manual_seed(1)

<torch._C.Generator at 0x7f67984393d0>

## Understand dimensionality

Check the input and output specifications of [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and [RNN](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html). The following snippet shows how to process a sequence with an LSTM one element at a time and obtain a vector of size `hidden_dim` after reading each token of the sequence.

In [2]:
seq_len = 10
input_dim = 1
hidden_dim = 5
lstm = nn.LSTM(input_dim, hidden_dim)  # Input sequence contains elements - vectors of size 1

# create a random sequence
sequence = [torch.randn(input_dim) for _ in range(seq_len)]

# initialize the hidden state (including cell state)
hidden = (torch.zeros(1, 1, hidden_dim),
          torch.zeros(1, 1, hidden_dim))

for i, elem in enumerate(sequence):
  # view(1, 1, input_dim) gives the (seq_len, batch, input_size) layout nn.LSTM expects:
  # we process a single element of the sequence, there is only one sample (sequence)
  # in the batch, and each element is treated as a vector of size input_dim = 1
  out, hidden = lstm(elem.view(1, 1, input_dim), hidden)
  print(f'i={i} out={out.detach()}')
print(f'Final hidden state={hidden[0].detach()} cell state={hidden[1].detach()}')

i=0 out=tensor([[[-0.0675,  0.1179,  0.1081,  0.0414, -0.0341]]])
i=1 out=tensor([[[-0.1067,  0.1726,  0.1400,  0.0902, -0.0596]]])
i=2 out=tensor([[[-0.1148,  0.1885,  0.1956,  0.0974, -0.0840]]])
i=3 out=tensor([[[-0.1270,  0.2031,  0.1495,  0.1249, -0.0860]]])
i=4 out=tensor([[[-0.1281,  0.2019,  0.1810,  0.1475, -0.1027]]])
i=5 out=tensor([[[-0.1274,  0.2060,  0.0798,  0.1330, -0.0860]]])
i=6 out=tensor([[[-0.1318,  0.2039,  0.0997,  0.1772, -0.1011]]])
i=7 out=tensor([[[-0.1145,  0.2008, -0.0431,  0.1051, -0.0717]]])
i=8 out=tensor([[[-0.1289,  0.1989,  0.0515,  0.1944, -0.1030]]])
i=9 out=tensor([[[-0.1329,  0.1920,  0.0686,  0.1772, -0.0988]]])
Final hidden state=tensor([[[-0.1329,  0.1920,  0.0686,  0.1772, -0.0988]]]) cell state=tensor([[[-0.2590,  0.4080,  0.1307,  0.4329, -0.2895]]])
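With the default `batch_first=False`, `nn.LSTM` expects input of shape `(seq_len, batch, input_size)` and returns `out` of shape `(seq_len, batch, hidden_size)` together with `(h_n, c_n)`, each of shape `(num_layers, batch, hidden_size)`. A quick sketch checking these shapes; the batch size of 3 is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

seq_len, batch, input_dim, hidden_dim = 10, 3, 1, 5
lstm = nn.LSTM(input_dim, hidden_dim)  # batch_first=False by default

x = torch.randn(seq_len, batch, input_dim)  # (seq_len, batch, input_size)
out, (h_n, c_n) = lstm(x)  # the hidden and cell states default to zeros when not passed

print(out.shape)  # torch.Size([10, 3, 5]): one hidden vector per time step
print(h_n.shape)  # torch.Size([1, 3, 5]): (num_layers, batch, hidden_size)
print(c_n.shape)  # torch.Size([1, 3, 5])
```

For a single-layer, unidirectional LSTM the last row of `out` is exactly `h_n[0]`, which is why the final hidden state printed above matches the output at `i=9`.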


## To implement

Process the whole sequence all at once by calling `lstm` only once and check that the output is exactly the same as above (remember to initialize the hidden state the same way).

In [3]:
seq_len = 10
input_dim = 1
hidden_dim = 5
lstm = nn.LSTM(input_dim, hidden_dim)  # Input sequence contains elements - vectors of size 1

# create a random sequence
sequence = torch.randn((seq_len, 1, input_dim))

# initialize the hidden state (including cell state)
hidden = (torch.zeros(1, 1, hidden_dim),
          torch.zeros(1, 1, hidden_dim))

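out, (h, c) = lstm(sequence, hidden)  # a single call processes all seq_len elements

# the output after the last element equals the final hidden state
h.eq(out[-1]).all()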

tensor(True)
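The cell above checks that the last output equals the final hidden state, but it uses a fresh LSTM and a fresh random sequence, so it does not compare against the step-by-step loop. A minimal sketch of the comparison the task describes, running the same LSTM on the same sequence both ways (`torch.allclose` is used because the two code paths need not agree bit for bit):

```python
import torch
import torch.nn as nn

torch.manual_seed(1)
seq_len, input_dim, hidden_dim = 10, 1, 5
lstm = nn.LSTM(input_dim, hidden_dim)
sequence = torch.randn(seq_len, 1, input_dim)  # (seq_len, batch=1, input_size)
hidden0 = (torch.zeros(1, 1, hidden_dim), torch.zeros(1, 1, hidden_dim))

# 1) step by step: feed one element at a time, carrying the hidden state forward
hidden = hidden0
step_outputs = []
for elem in sequence:  # elem has shape (1, input_dim)
    out, hidden = lstm(elem.view(1, 1, input_dim), hidden)
    step_outputs.append(out)
step_outputs = torch.cat(step_outputs)  # (seq_len, 1, hidden_dim)

# 2) all at once: a single call over the whole sequence, same initial state
all_outputs, (h_all, c_all) = lstm(sequence, hidden0)

print(torch.allclose(step_outputs, all_outputs))
print(torch.allclose(hidden[0], h_all), torch.allclose(hidden[1], c_all))
```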

## Training a model

Below we define a very simple model: a single recurrent layer (an LSTM, or whatever layer class is passed as `rnn`), whose output after reading the last element of the sequence is passed through a ReLU and then a single fully connected layer that outputs a single number. This number serves as the logit for our classification problem.

In [4]:
class Model(nn.Module):

    def __init__(self, rnn, hidden_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.rnn = rnn(1, self.hidden_dim)
        self.hidden2label = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x has shape (seq_len, batch=1, input_size=1)
        out, _ = self.rnn(x)
        # use only the output after the last element; view(-1) assumes a batch of size 1
        logits = self.hidden2label(F.relu(out[-1].view(-1)))
        return logits
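`Model.forward` assumes a batch of size 1, since `view(-1)` flattens the whole last output. A minimal sketch of a variant that handles any batch size, assuming the same `(seq_len, batch, 1)` input layout; the class name `BatchedModel` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedModel(nn.Module):
    # hypothetical batched variant of `Model`; works with nn.LSTM or nn.RNN
    def __init__(self, rnn, hidden_dim):
        super().__init__()
        self.rnn = rnn(1, hidden_dim)
        self.hidden2label = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x: (seq_len, batch, 1)
        out, _ = self.rnn(x)            # out: (seq_len, batch, hidden_dim)
        last = F.relu(out[-1])          # (batch, hidden_dim): output after the last element
        return self.hidden2label(last).squeeze(-1)  # (batch,) logits


model = BatchedModel(nn.LSTM, hidden_dim=20)
print(model(torch.ones(10, 4, 1)).shape)  # torch.Size([4])

# the two hard examples as one batch
seqs = torch.ones(2, 10)
seqs[0][0] = 0                   # hard Type 0 example; seqs[1] is the all-ones Type 1 example
batch = seqs.T.unsqueeze(-1)     # (seq_len=10, batch=2, 1)
labels = torch.tensor([0.0, 1.0])
loss = nn.BCEWithLogitsLoss()(model(batch), labels)  # mean loss over the batch
```

With batching, one optimizer step would use the averaged loss over all examples instead of one step per example as in the training loop below.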

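Below is the training loop. Every 100 steps it evaluates the model on the two hard examples and stops once the predicted probability is below 0.01 for the Type 0 example and above 0.99 for the Type 1 example (with `cl=True` it instead moves on to hard examples that are one element longer). In the first experiment we train only on the two hardest examples.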

In [23]:
def eval_on_hard_examples(model, hard_examples):
    with torch.no_grad():
        probs = []
        for sequence, _ in hard_examples:
            input = sequence.view(-1, 1, 1)
            logit = model(input)
            probs.append(torch.sigmoid(logit.detach()))
        print(f'Probs for hard examples={probs}')
        return probs[0] < 0.01 and probs[1] > 0.99


def train_model(rnn, hidden_dim, lr, num_steps=10000, train_examples=None, hard_examples=None, cl=False):
    # train_examples and hard_examples are lists of (sequence, label) pairs
    model = Model(rnn=rnn, hidden_dim=hidden_dim)
    loss_function = nn.BCEWithLogitsLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.99)

    for step in range(num_steps):
        # every 100 steps check whether the hard examples are already classified confidently
        if step % 100 == 0:
            exceed = eval_on_hard_examples(model, hard_examples)
            if exceed:
                if cl:
                    # curriculum learning: move on to hard examples one element longer
                    next_len = hard_examples[0][0].shape[0] + 1
                    print(f'Next seqs have length {next_len}')
                    seqs = torch.ones((2, next_len))
                    seqs[0][0] = 0
                    hard_examples = [(seqs[0], 0),
                                     (seqs[1], 1)]
                    print(hard_examples)
                else:
                    break

        # one step = one pass over all training examples, one SGD update per example
        for sequence, label in train_examples:
            model.zero_grad()
            logit = model(sequence.view(-1, 1, 1))
            loss = loss_function(logit.view(-1), torch.tensor([label], dtype=torch.float32))
            loss.backward()
            optimizer.step()
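`nn.BCEWithLogitsLoss` applies the sigmoid internally, so the model outputs a raw logit $z$ and the loss for a label $y \in \{0, 1\}$ is $-[y \log \sigma(z) + (1 - y) \log(1 - \sigma(z))]$. The stopping test in `eval_on_hard_examples` (probability below 0.01 and above 0.99) therefore corresponds to logits beyond $\pm \ln 99 \approx \pm 4.6$. A small sketch checking both facts numerically; the logit value 2.0 is arbitrary.

```python
import math
import torch
import torch.nn as nn

loss_function = nn.BCEWithLogitsLoss()
z = torch.tensor([2.0])   # an arbitrary logit
y = torch.tensor([1.0])   # label of a Type 1 sequence

# BCEWithLogitsLoss equals binary cross-entropy applied to sigmoid(logit)
p = torch.sigmoid(z)
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p))
print(torch.allclose(loss_function(z, y), manual.mean()))  # True

# logit at which the predicted probability reaches 0.99 (the success threshold)
threshold = math.log(0.99 / 0.01)  # = ln 99
print(round(threshold, 3))  # 4.595
print(torch.sigmoid(torch.tensor(threshold)).item())  # approximately 0.99
```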

In [21]:
seq_len = 5
seqs = torch.ones((2, seq_len))
seqs[0][0] = 0
hard_examples = [(seqs[0], 0),
                 (seqs[1], 1)]

train_model(rnn=nn.LSTM, hidden_dim=20, lr=0.01, num_steps=10000, train_examples=hard_examples, hard_examples=hard_examples)

Probs for hard examples=[tensor([0.4979]), tensor([0.4978])]
Probs for hard examples=[tensor([0.5054]), tensor([0.5067])]
Probs for hard examples=[tensor([0.4448]), tensor([0.5452])]
Probs for hard examples=[tensor([0.0002]), tensor([0.9998])]


## To implement
Note that for steps 2-4 you may need to change the value of `num_steps`.


1. Check for what values of `seq_len` the model is able to discriminate between the two hard examples (after training).

In [14]:
seq_len = 10
seqs = torch.ones((2, seq_len))
seqs[0][0] = 0
hard_examples = [(seqs[0], 0),
                 (seqs[1], 1)]

train_model(rnn=nn.LSTM, hidden_dim=20, lr=0.01, num_steps=10000, train_examples=hard_examples, hard_examples=hard_examples)

Probs for hard examples=[tensor([0.5312]), tensor([0.5313])]
Probs for hard examples=[tensor([0.5046]), tensor([0.5046])]
Probs for hard examples=[tensor([0.4960]), tensor([0.4960])]
Probs for hard examples=[tensor([0.5007]), tensor([0.5008])]
Probs for hard examples=[tensor([0.5010]), tensor([0.5018])]
Probs for hard examples=[tensor([0.0484]), tensor([0.9899])]
Probs for hard examples=[tensor([3.2362e-05]), tensor([0.9999])]


For lengths 1-7 the model learns fairly quickly;
for 8-9 it catches on after a while.
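The four lines that build the hard pair are repeated in every experiment. A small helper can make a sweep over sequence lengths easier; `make_hard_examples` is a hypothetical name, not part of the original notebook.

```python
import torch


def make_hard_examples(seq_len):
    # the two hardest examples of length seq_len:
    # a leading zero followed by ones (label 0) and a sequence full of ones (label 1)
    seqs = torch.ones((2, seq_len))
    seqs[0][0] = 0
    return [(seqs[0], 0), (seqs[1], 1)]


print(make_hard_examples(3))
# [(tensor([0., 1., 1.]), 0), (tensor([1., 1., 1.]), 1)]
```

With it, a sweep could call `train_model(rnn=nn.LSTM, hidden_dim=20, lr=0.01, train_examples=hard, hard_examples=hard)` with `hard = make_hard_examples(seq_len)` for each `seq_len` in, for example, `range(2, 13)`.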

2. Instead of training on `hard_examples` only, modify the training loop to train on sequences where the zero may be at any position (i.e. any valid sequence of `Type 0`, not just the hardest one). After modifying the training loop, check for what values of `seq_len` you can train the model successfully.

In [15]:
# but does this mean one such sequence instead of the hardest one, or many?
bs = 4
seq_len = 10
train_examples = torch.ones((bs, seq_len))
type0 = torch.randperm(bs)[:bs//2]
for i, j in zip(type0, torch.LongTensor(bs//2).random_(0, seq_len)):
  train_examples[i][j] = 0

train_examples = [(exp, int(i not in type0)) for i, exp in enumerate(train_examples)]
train_examples

[(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 1),
 (tensor([1., 0., 1., 1., 1., 1., 1., 1., 1., 1.]), 0),
 (tensor([1., 1., 1., 1., 1., 1., 0., 1., 1., 1.]), 0),
 (tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 1)]

In [16]:
seqs = torch.ones((2, seq_len))
seqs[0][0] = 0
hard_examples = [(seqs[0], 0),
                 (seqs[1], 1)]

train_model(rnn=nn.LSTM, hidden_dim=20, lr=0.01, num_steps=10000, train_examples=train_examples, hard_examples=hard_examples)

Probs for hard examples=[tensor([0.5193]), tensor([0.5192])]
Probs for hard examples=[tensor([0.4627]), tensor([0.4928])]
Probs for hard examples=[tensor([0.4726]), tensor([0.4727])]
Probs for hard examples=[tensor([0.4931]), tensor([0.5123])]
Probs for hard examples=[tensor([0.0015]), tensor([0.9975])]
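The `train_examples` generated above form a fixed set of four sequences, so the zero appears in only a couple of positions and the hardest case (zero at position 0) may be missing; in the printed example the zeros are at positions 1 and 6. One alternative, sketched below as an assumption rather than the intended solution, is to build every Type 0 sequence (one per zero position) and balance them with the same number of all-ones sequences; `all_type0_examples` is a hypothetical helper.

```python
import random
import torch


def all_type0_examples(seq_len):
    # one Type 0 sequence for every possible zero position (label 0) ...
    examples = []
    for pos in range(seq_len):
        seq = torch.ones(seq_len)
        seq[pos] = 0
        examples.append((seq, 0))
    # ... and as many all-ones Type 1 sequences (label 1), so the classes are balanced
    examples += [(torch.ones(seq_len), 1) for _ in range(seq_len)]
    random.shuffle(examples)  # avoid presenting all Type 0 examples first
    return examples


train_examples = all_type0_examples(10)
print(len(train_examples))  # 20
```

Passed as `train_examples` to `train_model`, each step then performs `2 * seq_len` SGD updates instead of four.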


3. Replace the LSTM with a classic RNN and check for what values of `seq_len` you can train the model successfully.

In [18]:
seq_len = 10
seqs = torch.ones((2, seq_len))
seqs[0][0] = 0
hard_examples = [(seqs[0], 0),
                 (seqs[1], 1)]

train_model(rnn=nn.RNN, hidden_dim=20, lr=0.01, num_steps=10000, train_examples=hard_examples, hard_examples=hard_examples)

Probs for hard examples=[tensor([0.5599]), tensor([0.5599])]
Probs for hard examples=[tensor([0.0034]), tensor([0.9972])]
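A classic (Elman) RNN, as implemented by `nn.RNN` with the default `tanh` nonlinearity, updates its hidden state as $h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})$, with no gates and no separate cell state. A sketch that reproduces `nn.RNN` with this recurrence, using its `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0` and `bias_hh_l0` parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, input_dim, hidden_dim = 10, 1, 20
rnn = nn.RNN(input_dim, hidden_dim)  # nonlinearity='tanh' by default
x = torch.ones(seq_len, 1, input_dim)
x[0] = 0  # the hard Type 0 example: a leading zero followed by ones

out, h_n = rnn(x)

# the same recurrence written by hand, starting from a zero hidden state
h = torch.zeros(1, hidden_dim)
manual = []
for x_t in x:  # x_t: (batch=1, input_dim)
    h = torch.tanh(x_t @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                   + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
    manual.append(h)
manual = torch.stack(manual)  # (seq_len, 1, hidden_dim)

print(torch.allclose(out, manual, atol=1e-6))
```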


4. Write a proper curriculum learning loop that considers longer and longer sequences, expanding the sequence length only after the model has been trained successfully on the current length.
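In `train_model`, the `cl` branch rebinds the local name `hard_examples`, while `train_examples` still refers to the original list, so the model keeps training on the length-2 pair while being evaluated on longer ones. A minimal self-contained sketch in which training and evaluation move to the next length together; `curriculum`, `TinyModel`, `hard_pair` and `solved` are hypothetical names, and the hyperparameters mirror those used in this notebook (hidden size 20, SGD with lr 0.01 and momentum 0.99, evaluation every 100 steps).

```python
import torch
import torch.nn as nn


def hard_pair(seq_len):
    # leading zero followed by ones (label 0) vs. all ones (label 1)
    seqs = torch.ones((2, seq_len))
    seqs[0][0] = 0
    return [(seqs[0], 0), (seqs[1], 1)]


class TinyModel(nn.Module):
    # same architecture as `Model`: recurrent layer -> ReLU -> Linear on the last output
    def __init__(self, rnn=nn.LSTM, hidden_dim=20):
        super().__init__()
        self.rnn = rnn(1, hidden_dim)
        self.hidden2label = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.hidden2label(torch.relu(out[-1])).view(-1)


def solved(model, examples, margin=0.01):
    # same criterion as eval_on_hard_examples: p < 0.01 for label 0, p > 0.99 for label 1
    with torch.no_grad():
        probs = [torch.sigmoid(model(seq.view(-1, 1, 1))).item() for seq, _ in examples]
    return probs[0] < margin and probs[1] > 1 - margin


def curriculum(start_len=2, max_len=20, rnn=nn.LSTM, lr=0.01, num_steps=10000, eval_every=100):
    model = TinyModel(rnn)
    loss_function = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.99)
    seq_len = start_len
    examples = hard_pair(seq_len)
    for step in range(num_steps):
        if step % eval_every == 0 and solved(model, examples):
            print(f'step {step}: solved length {seq_len}')
            if seq_len == max_len:
                break
            seq_len += 1
            # the training examples are replaced too, not only the evaluation ones
            examples = hard_pair(seq_len)
        for seq, label in examples:
            optimizer.zero_grad()
            loss = loss_function(model(seq.view(-1, 1, 1)), torch.tensor([float(label)]))
            loss.backward()
            optimizer.step()
    # the model and the last length it was trained on
    return model, seq_len
```

For example, `model, reached = curriculum(max_len=30)` trains until length 30 is solved or `num_steps` runs out; `reached` is solved only if the loop ended through the `break`.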

In [25]:
seq_len = 2
seqs = torch.ones((2, seq_len))
seqs[0][0] = 0
hard_examples = [(seqs[0], 0),
                 (seqs[1], 1)]

train_model(rnn=nn.LSTM, hidden_dim=20, lr=0.01, num_steps=10000, train_examples=hard_examples, hard_examples=hard_examples, cl=True)

Probs for hard examples=[tensor([0.4797]), tensor([0.4789])]
Probs for hard examples=[tensor([0.4049]), tensor([0.6410])]
Probs for hard examples=[tensor([8.4739e-05]), tensor([0.9995])]
Next seqs have length 3
[(tensor([0., 1., 1.]), 0), (tensor([1., 1., 1.]), 1)]
Probs for hard examples=[tensor([4.9995e-06]), tensor([0.9999])]
Next seqs have length 4
[(tensor([0., 1., 1., 1.]), 0), (tensor([1., 1., 1., 1.]), 1)]
Probs for hard examples=[tensor([2.7092e-06]), tensor([0.9999])]
Next seqs have length 5
[(tensor([0., 1., 1., 1., 1.]), 0), (tensor([1., 1., 1., 1., 1.]), 1)]
Probs for hard examples=[tensor([2.2392e-06]), tensor([0.9999])]
Next seqs have length 6
[(tensor([0., 1., 1., 1., 1., 1.]), 0), (tensor([1., 1., 1., 1., 1., 1.]), 1)]
Probs for hard examples=[tensor([2.0600e-06]), tensor([0.9999])]
Next seqs have length 7
[(tensor([0., 1., 1., 1., 1., 1., 1.]), 0), (tensor([1., 1., 1., 1., 1., 1., 1.]), 1)]
Probs for hard examples=[tensor([1.9739e-06]), tensor([1.0000])]
Next seqs hav