# Stateful recurrent models

In the previous notebook *Character-based language model from scratch* we arrived at a point where we had a model that used characters 0 to 7, for example, to predict characters 1 to 8.

The problem is that the hidden state of the RNN is reset to zero for every new minibatch, so the model forgets everything it has seen so far. Let's improve on this by preserving the hidden state across minibatches (while detaching it from its history, so that backpropagation stops at the minibatch boundary).
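
A minimal sketch of what preserving the hidden state while detaching it looks like. It is written against the current PyTorch tensor API (`Tensor.detach`) rather than the `Variable`-based fastai helper `repackage_var` used later in this notebook; the layer sizes and the dummy loss are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=8)   # toy sizes, not the notebook's
h = torch.zeros(1, 2, 8)                    # (num_layers, bs, n_hidden)

for step in range(3):
    x = torch.randn(5, 2, 4)                # (bptt, bs, input_size) dummy minibatch
    out, h = rnn(x, h)                      # h carries values over from earlier minibatches
    loss = out.pow(2).mean()                # dummy loss, only to have something to backprop
    rnn.zero_grad()
    loss.backward()                         # gradients flow through the current minibatch only
    h = h.detach()                          # keep the values, drop the graph history

print(h.requires_grad, h.shape)
```

Without the `detach()` call, the second `backward()` would try to backpropagate into the graph of the first minibatch, which has already been freed.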

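Difference to the previous notebook: there, the first bs=512 consecutive sequences of bptt=8 characters formed the first minibatch, the next 512 the second, and so on. When preserving the hidden state you don't want this, because row i of one minibatch would not continue row i of the previous one. Instead, you split the entire text into bs equally sized contiguous sequences (bs=64 in this notebook), and each minibatch takes the next bptt characters from every one of them.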
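
The layout can be sketched in plain Python. This is only an illustration, not the actual fastai or torchtext implementation; `batchify` and `minibatches` are hypothetical helper names, and the toy sizes (bs=4, bptt=3) are chosen for readability.

```python
def batchify(tokens, bs):
    """Split one long token sequence into bs equally sized contiguous streams."""
    n = len(tokens) // bs          # length of each stream; leftover tokens are dropped
    return [tokens[i * n:(i + 1) * n] for i in range(bs)]

def minibatches(streams, bptt):
    """Yield minibatches of bptt steps; row i always continues stream i."""
    n = len(streams[0])
    for start in range(0, n, bptt):
        yield [s[start:start + bptt] for s in streams]

text = list("abcdefghijklmnopqrstuvwx")   # 24 toy "characters"
streams = batchify(text, bs=4)            # 4 streams of 6 characters each
batches = list(minibatches(streams, bptt=3))

for b in batches:
    print(["".join(row) for row in b])
```

Row 0 of the second minibatch (`def`) continues row 0 of the first (`abc`), which is what makes carrying the hidden state over meaningful.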

In [1]:
from torchtext import vocab, data

from fastai.nlp import *
from fastai.lm_rnn import *

In [2]:
PATH = 'data/nietzsche/'

TRN_PATH = 'trn/'
VAL_PATH = 'val/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'

In [3]:
# Field is a set of instructions on how to process the text
# list('abc') gives ['a', 'b', 'c'], so the tokens
# are the unique characters

TEXT = data.Field(lower=True, tokenize=list) 
bs=64
bptt=8
n_emb=42
n_hidden=256

In [4]:
FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)

In [5]:
modeldata = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=3)

In [6]:
len(modeldata.trn_dl), modeldata.nt, len(modeldata.trn_ds), len(modeldata.trn_ds[0].text)  
# number of batches, nt = number of unique tokens

(947, 55, 1, 485749)

## RNN

In [29]:
class CharSeqStatefulRnn(nn.Module):
    def __init__(self, voc_size, n_emb, bs, nh):
        super().__init__()
        self.voc_size = voc_size
        self.nh = nh
        self.e = nn.Embedding(self.voc_size, n_emb)
        self.rnn = nn.RNN(n_emb, nh)
        self.l_out = nn.Linear(nh, voc_size)
        self.init_hidden_state(bs)

    def forward(self, cs):
        bs = cs[0].size(0)
        # The last minibatch might be smaller
        # We need to account for this:
        if self.h.size(1) != bs:
            self.init_hidden_state(bs)
        outp, h = self.rnn(self.e(cs), self.h)
        
        # Wraps h in new Variables, to detach it 
        # from its history of operations.
        # Backprop will stop here
        self.h = repackage_var(h)
        
        # We need to flatten the output (view)
        # because loss functions in pytorch
        # currently don't support rank 3 tensors
        # (bptt, bs, n_voc); nn.RNN is sequence-first by default
        # target is flattened automatically
        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.voc_size)
    
    
    def init_hidden_state(self, bs):
        self.h = V(torch.zeros(1, bs, self.nh))
        

In [30]:
model = CharSeqStatefulRnn(modeldata.nt, n_emb, 512, n_hidden).cuda()

In [31]:
opt = optim.Adam(model.parameters(), 1e-3)

In [32]:
fit(model, modeldata, 4, opt, F.nll_loss)

epoch      trn_loss   val_loss                               
    0      1.869997   1.858461  
    1      1.689194   1.703229                               
    2      1.59974    1.638832                               
    3      1.55035    1.597093                               



[array([1.59709])]

In [33]:
opt = optim.Adam(model.parameters(), 1e-4)

In [34]:
fit(model, modeldata, 4, opt, F.nll_loss)

epoch      trn_loss   val_loss                               
    0      1.47539    1.557163  
    1      1.473396   1.551019                               
    2      1.47469    1.547668                               
    3      1.464748   1.543166                               



[array([1.54317])]

## Let's look deeper into nn.RNN

From the PyTorch source code:

```
def RNNCell(input, hidden, w_ih, w_hh, b_ih, b_hh):
   return F.tanh(F.linear(input, w_ih, b_ih) +
          F.linear(hidden, w_hh, b_hh))
```

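Interestingly, PyTorch does not concatenate the input and the hidden state but simply adds their two linear projections. The two formulations are algebraically equivalent: adding the projections is the same as applying a single weight matrix to the concatenated vector.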
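
A small pure-Python check of this equivalence, with biases and the tanh left out. The weights and vectors are made up, and `matvec` is a hypothetical helper.

```python
import math

def matvec(w, v):
    """Multiply matrix w (a list of rows) with vector v."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

x = [1.0, 2.0]            # input, size 2
h = [0.5, -1.0, 0.25]     # hidden state, size 3

w_ih = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]                   # 3x2
w_hh = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [-1.0, 0.5, 0.0]]   # 3x3

# "Add" formulation, as in the RNNCell snippet
added = [a + b for a, b in zip(matvec(w_ih, x), matvec(w_hh, h))]

# "Concat" formulation: one matrix [w_ih | w_hh] applied to [x, h]
w_cat = [ri + rh for ri, rh in zip(w_ih, w_hh)]               # 3x5
concat = matvec(w_cat, x + h)

same = all(math.isclose(a, c, abs_tol=1e-12) for a, c in zip(added, concat))
print(same)
```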

*Tanh* is often used here because its output is bounded to (-1, 1), which appears to make it better at avoiding exploding gradients than *relu*.

The next cell writes the loop over time steps by hand with `nn.RNNCell`. Don't do this in practice unless you want to use a new type of cell or some other concept that does not exist in PyTorch yet.

In [37]:
class CharSeqStatefulRnn2(nn.Module):
    def __init__(self, voc_size, n_emb, bs, nh):
        super().__init__()
        self.voc_size = voc_size
        self.nh = nh
        self.e = nn.Embedding(self.voc_size, n_emb)
        self.rnn = nn.RNNCell(n_emb, nh)
        self.l_out = nn.Linear(nh, voc_size)
        self.init_hidden_state(bs)

    def forward(self, cs):
        bs = cs[0].size(0)
        # nn.RNNCell expects a hidden state of shape (bs, nh)
        if self.h.size(0) != bs:
            self.init_hidden_state(bs)
        
        # Collect the hidden state after every time step
        outp = []
        o = self.h
        for c in cs:
            o = self.rnn(self.e(c), o)
            outp.append(o)
            
        outp = self.l_out(torch.stack(outp))
        self.h = repackage_var(o)
        
        return F.log_softmax(outp, dim=-1).view(-1, self.voc_size)
    
    def init_hidden_state(self, bs):
        # No leading num_layers dimension, unlike nn.RNN
        self.h = V(torch.zeros(bs, self.nh))
        

In [40]:
model = CharSeqStatefulRnn2(modeldata.nt, n_emb, 512, n_hidden).cuda()
opt = optim.Adam(model.parameters(), 1e-3)

In [41]:
fit(model, modeldata, 4, opt, F.nll_loss)

epoch      trn_loss   val_loss                               
    0      1.8718     1.85822   
    1      1.687939   1.69713                                
    2      1.60298    1.628003                               
    3      1.549262   1.599414                               



[array([1.59941])]

**This plain RNN cell is rarely used in practice. Vanishing and exploding gradients are still a problem despite the tanh, so you need small learning rates and small values of bptt. Use GRU or LSTM instead.**
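
The mechanism behind both problems is repeated multiplication by a similar factor at every time step. A scalar toy recurrence makes this visible. This is only an illustration: `grad_through_time` is a hypothetical helper, the weights are made up, and the linear variant is included only to show the growth case.

```python
import math

def grad_through_time(w, steps, h0=0.5, use_tanh=True):
    """Return d h_T / d h_0 for the scalar recurrence h_t = f(w * h_{t-1})."""
    h, grad = h0, 1.0
    for _ in range(steps):
        pre = w * h
        if use_tanh:
            h = math.tanh(pre)
            grad *= w * (1.0 - h * h)   # chain rule: tanh'(pre) = 1 - tanh(pre)^2
        else:
            h = pre
            grad *= w                   # linear recurrence: the derivative is just w
    return grad

for steps in (8, 32):
    print(steps,
          grad_through_time(0.5, steps),                  # shrinks towards 0
          grad_through_time(1.5, steps, use_tanh=False))  # grows like 1.5**steps
```

The longer the sequence (the larger bptt), the more factors are multiplied together, which is why the plain cell forces small bptt values.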

## GRU (Gated recurrent unit)

From the PyTorch source:

```
def GRUCell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    gi = F.linear(input, w_ih, b_ih)
    gh = F.linear(hidden, w_hh, b_hh)
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = F.sigmoid(i_r + h_r)
    inputgate = F.sigmoid(i_i + h_i)
    newgate = F.tanh(i_n + resetgate * h_n)
    return newgate + inputgate * (hidden - newgate)
```

To address the vanishing gradient problem, GRU uses the so-called **update gate** and **reset gate** (`inputgate` and `resetgate` in the snippet above).

**Update gate:**

The update gate determines how much of the previous hidden state is carried over to the next time step and how much is replaced by new content.

$z_t = \sigma\left(W_z\cdot[h_{t-1}, x_t]\right)$

**Reset Gate:**

Determines how much of the past information to forget. The network could for example learn to forget almost everything once it finds a "."

$r_t = \sigma\left(W_r\cdot[h_{t-1}, x_t]\right)$


**Use the reset gate to calculate the current memory content:**

$\tilde h_t=\tanh\left(W\cdot[r_t\ast h_{t-1}, x_t]\right)$

**Calculate the new hidden state as an interpolation of the old hidden state and the current memory content using the update gate:**

$h_t=(1-z_t)\ast h_{t-1} + z_t\ast\tilde h_t$
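
The four equations can be traced with a scalar toy version in plain Python. This is a sketch under simplifying assumptions: no biases, scalar state and input, and each weight given as a made-up pair (weight on h, weight on x) standing in for the matrix applied to the concatenation. `gru_step` is a hypothetical helper.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def gru_step(h_prev, x, w_z, w_r, w):
    """One scalar GRU step following the four equations above (no biases)."""
    z = sigmoid(w_z[0] * h_prev + w_z[1] * x)             # update gate
    r = sigmoid(w_r[0] * h_prev + w_r[1] * x)             # reset gate
    h_tilde = math.tanh(w[0] * (r * h_prev) + w[1] * x)   # current memory content
    return (1.0 - z) * h_prev + z * h_tilde               # interpolation

h = 0.0
for x in [1.0, -0.5, 0.25]:
    h = gru_step(h, x, w_z=(0.5, 1.0), w_r=(-1.0, 0.3), w=(0.8, 1.2))
    print(round(h, 4))
```

Note that the PyTorch snippet above uses a slightly different convention. Its last line is equivalent to `(1 - inputgate) * newgate + inputgate * hidden`, so there the gate weights the old state rather than the new content, and the reset gate is applied to the already projected hidden state `h_n`.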

In [42]:
class CharSeqStatefulGRU(nn.Module):
    def __init__(self, voc_size, n_emb, bs, nh):
        super().__init__()
        self.voc_size = voc_size
        self.nh = nh
        self.e = nn.Embedding(self.voc_size, n_emb)
        self.rnn = nn.GRU(n_emb, nh)
        self.l_out = nn.Linear(nh, voc_size)
        self.init_hidden_state(bs)

    def forward(self, cs):
        bs = cs[0].size(0)
        # The last minibatch might be smaller
        # We need to account for this:
        if self.h.size(1) != bs:
            self.init_hidden_state(bs)
        
        outp, h = self.rnn(self.e(cs), self.h)
        
        # Wraps h in new Variables, to detach it 
        # from its history of operations.
        # Backprop will stop here
        self.h = repackage_var(h)
        
        # We need to flatten the output (view)
        # because loss functions in pytorch
        # currently don't support rank 3 tensors
        # (bptt, bs, n_voc); nn.GRU is sequence-first by default
        # target is flattened automatically
        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.voc_size)
    
    
    def init_hidden_state(self, bs):
        self.h = V(torch.zeros(1, bs, self.nh))
        

In [43]:
model = CharSeqStatefulGRU(modeldata.nt, n_emb, 512, n_hidden).cuda()

In [44]:
opt = optim.Adam(model.parameters(), 1e-3)

In [45]:
fit(model, modeldata, 6, opt, F.nll_loss)

epoch      trn_loss   val_loss                               
    0      1.743342   1.736539  
    1      1.562659   1.581666                               
    2      1.477294   1.526515                               
    3      1.42534    1.490105                               
    4      1.386104   1.468811                               
    5      1.356454   1.464969                               



[array([1.46497])]

In [46]:
opt = optim.Adam(model.parameters(), 1e-4)

In [47]:
fit(model, modeldata, 3, opt, F.nll_loss)

epoch      trn_loss   val_loss                               
    0      1.271521   1.429878  
    1      1.27452    1.426042                               
    2      1.267621   1.425261                               



[array([1.42526])]

## LSTM

LSTM has an additional *cell state* compared to GRU. You therefore need to initialize the state as a *tuple* of two zero tensors: one for the hidden state and one for the cell state.
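
A minimal sketch of the tuple state, using the current PyTorch tensor API instead of the fastai helpers `V` and `repackage_var`. The sizes are toy values, not the ones used in this notebook.

```python
import torch
import torch.nn as nn

n_layers, bs, n_hidden, n_emb = 2, 3, 16, 8      # toy sizes
lstm = nn.LSTM(n_emb, n_hidden, n_layers)

# One zero tensor for the hidden state and one for the cell state
state = (torch.zeros(n_layers, bs, n_hidden),
         torch.zeros(n_layers, bs, n_hidden))

x = torch.randn(5, bs, n_emb)                    # (bptt, bs, n_emb)
out, state = lstm(x, state)

# Detach both tensors of the tuple before the next minibatch
state = tuple(s.detach() for s in state)

print(out.shape, state[0].shape, state[1].shape)
```

Both state tensors have a leading dimension equal to the number of layers, which is why `init_hidden_state` below uses `self.nl` instead of 1.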

In [7]:
from fastai import sgdr

n_hidden = 512

In [24]:
class CharSeqStatefulLSTM(nn.Module):
    def __init__(self, voc_size, n_emb, bs, nh, nl):
        super().__init__()
        self.voc_size = voc_size
        self.nl = nl
        self.nh = nh
        self.e = nn.Embedding(self.voc_size, n_emb)
        self.rnn = nn.LSTM(n_emb, nh, nl, dropout = 0.5)
        self.l_out = nn.Linear(nh, voc_size)
        self.init_hidden_state(bs)

    def forward(self, cs):
        bs = cs[0].size(0)
        if self.h[0].size(1) != bs:
            self.init_hidden_state(bs)
        
        outp, h = self.rnn(self.e(cs), self.h)
        self.h = repackage_var(h)
        
        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.voc_size)
    
    def init_hidden_state(self, bs):
        self.h = (V(torch.zeros(self.nl, bs, self.nh)),
                  V(torch.zeros(self.nl, bs, self.nh)))

In [25]:
model = CharSeqStatefulLSTM(modeldata.nt, n_emb, 512, n_hidden, 2).cuda()

In [26]:
lo = LayerOptimizer(optim.Adam, model, lrs=1e-2, wds=1e-5)

In [27]:
os.makedirs(f'{PATH}models', exist_ok=True)

In [29]:
fit(model, modeldata, 2, lo.opt, F.nll_loss)

epoch      trn_loss   val_loss                               
    0      1.636604   1.592825  
    1      1.610442   1.5788                                 



[array([1.5788])]

In [30]:
on_end = lambda sched, cycle: save_model(model, f'{PATH}models/cyc_{cycle}')

cb = [CosAnneal(lo, len(modeldata.trn_dl), cycle_mult=2, on_cycle_end=on_end)]

fit(model, modeldata, 2**4 - 1, lo.opt, F.nll_loss, callbacks=cb)

epoch      trn_loss   val_loss                               
    0      1.464206   1.454684  
    1      1.51953    1.49333                                
    2      1.411895   1.416746                               
    3      1.550094   1.503142                               
    4      1.486141   1.469236                               
    5      1.405634   1.411643                               
    6      1.350604   1.379368                               
    7      1.534413   1.517684                               
    8      1.505779   1.487957                               
    9      1.477565   1.462818                               
    10     1.447093   1.445088                               
    11     1.411204   1.415879                               
    12     1.368016   1.382721                               
    13     1.332156   1.363008                               
    14     1.305085   1.350068                               



[array([1.35007])]
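
The epoch count `2**4 - 1` in the run above comes from `cycle_mult=2`: the first cycle lasts one epoch and every following cycle is twice as long, so four cycles take 1 + 2 + 4 + 8 = 15 epochs. The loss jumps at epochs 1, 3 and 7 are the restarts. A rough sketch of such a schedule, assuming the standard cosine annealing formula; `cosine_lr` and `cycle_lengths` are hypothetical helpers, not the fastai `CosAnneal` callback, and the rate is evaluated only once per epoch here for brevity.

```python
import math

def cosine_lr(lr_max, it, cycle_len):
    """Cosine annealing from lr_max down towards 0 within one cycle."""
    return lr_max / 2 * (1 + math.cos(math.pi * it / cycle_len))

def cycle_lengths(n_cycles, first=1, cycle_mult=2):
    """Cycle lengths in epochs when each cycle is cycle_mult times the previous one."""
    return [first * cycle_mult ** i for i in range(n_cycles)]

lengths = cycle_lengths(4)
print(lengths, sum(lengths))

# Learning rate at the start of each epoch, with lr_max = 1e-2 as above
schedule = []
for n in lengths:
    schedule += [cosine_lr(1e-2, epoch, n) for epoch in range(n)]
print([round(lr, 4) for lr in schedule])
```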

In [31]:
on_end = lambda sched, cycle: save_model(model, f'{PATH}models/cyc_{cycle}')

cb = [CosAnneal(lo, len(modeldata.trn_dl), cycle_mult=2, on_cycle_end=on_end)]

fit(model, modeldata, 2**6 - 1, lo.opt, F.nll_loss, callbacks=cb)

epoch      trn_loss   val_loss                               
    0      1.292349   1.347956  
    1      1.2951     1.347379                               
    2      1.29271    1.345982                               
    3      1.294296   1.344165                               
    4      1.289099   1.341886                               
    5      1.281891   1.340308                               
    6      1.283003   1.339759                               
    7      1.28202    1.340669                               
    8      1.2719     1.338476                               
    9      1.265462   1.337162                               
    10     1.263195   1.335769                               
    11     1.260501   1.334463                               
    12     1.253951   1.33328                                
    13     1.249763   1.332153                               
    14     1.251737   1.331728                               
    15     1.259258   1.334247       

[array([1.36642])]

## Test

In [32]:
def get_next(inp):
    idxs = TEXT.numericalize(inp)
    p = model(VV(idxs.transpose(0, 1)))
    r = torch.multinomial(p[-1].exp(), 1)
    return TEXT.vocab.itos[to_np(r)[0]]

In [33]:
def get_next_n(inp, n):
    result = inp
    for i in range(n):
        char = get_next(inp)
        result += char
        inp = inp[1:] + char
        
    return result
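
`get_next` samples the next character instead of always picking the most likely one: the model returns log-probabilities (it ends in `log_softmax`), `.exp()` turns them back into probabilities, and `torch.multinomial` draws one index according to them. The same idea in plain Python, as a sketch with a made-up three-character vocabulary; `sample_index` is a hypothetical helper.

```python
import math
import random

def sample_index(log_probs, rng=random):
    """Draw one index with probability exp(log_probs[i])."""
    probs = [math.exp(lp) for lp in log_probs]
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

vocab = ["a", "b", "c"]                          # hypothetical toy vocabulary
log_probs = [math.log(p) for p in (0.7, 0.2, 0.1)]

rng = random.Random(0)                           # seeded for reproducibility
draws = [vocab[sample_index(log_probs, rng)] for _ in range(1000)]
print({ch: draws.count(ch) for ch in vocab})
```

Sampling keeps the generated text from getting stuck in the same most likely loop, at the price of occasional unlikely characters.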

In [34]:
print(get_next_n('for those', 400))

for those exclusion of the root of through their anth--life as yet grow theleops, again, difficulty, divined claal man, weel viced agrown,diffule, trained, and afwords of history of this all godand depth, to overlooks or to other. for this hand. how possiblity! so that one must necessarily responsible, sequently fredom!" or oven our culture to be expediency, instinct, rationary evidence, again philosophy--


**Ok, Nietzsche does sound a little better :)**