# Language model using a GRU implemented from scratch

### Setup

In [1]:
#hide
! [ -e /content ] && pip install -Uqq fastbook
import fastbook
fastbook.setup_book()

In [2]:
#hide
from fastbook import *
import random
import numpy as np
import torch

In [3]:
from fastai.text.all import *
path = untar_data(URLs.HUMAN_NUMBERS)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#hide
Path.BASE_PATH = path
path.ls()

(#2) [Path('valid.txt'),Path('train.txt')]

In [4]:
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())
lines

(#9998) ['one \n','two \n','three \n','four \n','five \n','six \n','seven \n','eight \n','nine \n','ten \n'...]

In [5]:
text = ' . '.join([l.strip() for l in lines])
text[:100]

'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'

In [6]:
tokens = text.split(' ')
tokens[:10]

['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']

In [7]:
vocab = L(*tokens).unique()
vocab

(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]

In [8]:
word2idx = {w:i for i,w in enumerate(vocab)}
nums = L(word2idx[i] for i in tokens)
nums

(#63095) [0,1,2,1,3,1,4,1,5,1...]

The block below will create a training and validation dataset.
Each dataset will have x has a sequence of consecutive words of length 16 and same for y but shifted by one 

In [9]:
def group_chunks(ds, bs):
    m = len(ds) // bs
    new_ds = L()
    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
    return new_ds

sl = 16
bs = 64
seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))
             for i in range(0,len(nums)-sl-1,sl))
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(
    group_chunks(seqs[:cut], bs), 
    group_chunks(seqs[cut:], bs), 
    bs=bs, drop_last=True, shuffle=False)

In [10]:
x, y = dls.train.one_batch()

In [11]:
(x.shape, y.shape)

(torch.Size([64, 16]), torch.Size([64, 16]))

In [12]:
print("Example of x: " + " ".join([vocab[s] for s in x[0]]))
print("Example of y: " + " ".join([vocab[s] for s in y[0]]))

Example of x: one . two . three . four . five . six . seven . eight .
Example of y: . two . three . four . five . six . seven . eight . nine


## Language model with custom GRU

### Single layer GRU

In [13]:
class GRUCell(Module):
    def __init__(self, n_input, n_hidden):
        self.reset_gate = nn.Linear(n_input + n_hidden, n_hidden)
        self.update_gate = nn.Linear(n_input + n_hidden, n_hidden)
        self.candidate_gate = nn.Linear(n_input + n_hidden, n_hidden)
        
    def forward(self, x, h):
        x_h = torch.cat((x, h), dim=1)
        reset = torch.sigmoid(self.reset_gate(x_h))
        update = torch.sigmoid(self.update_gate(x_h))
        x_reset_h = torch.cat((x, h*reset), dim=1)
        candidate = torch.tanh(self.candidate_gate(x_reset_h))
        return (1 - update)*h + update*candidate

class GRU(Module):
    def __init__(self, n_input, n_hidden):
        self.gru_cell = GRUCell(n_input, n_hidden)

    def forward(self, x, h):
        outs = []
        next_h = h
        for i in range(x.shape[1]):
            next_h = self.gru_cell(x[:, i, :], next_h)
            outs.append(next_h)
        return torch.stack(outs, dim=1), next_h
        
class GRUWrapper(Module):
    def __init__(self, vocab_sz, n_hidden ,p):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = GRU(n_hidden, n_hidden)
        self.drop = nn.Dropout(p)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h_o.weight = self.i_h.weight
        self.h = torch.zeros(bs, n_hidden)
        
    def forward(self, x):
        raw_out, next_h = self.rnn(self.i_h(x), self.h)
        self.h = next_h.detach()
        out = self.drop(raw_out)
        return self.h_o(out),raw_out,out
    
    def reset(self): 
        self.h = self.h.zero_()
        
learn = TextLearner(dls, GRUWrapper(len(vocab), 64, 0.4),
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)

learn.fit_one_cycle(15, 1e-2, wd=0.1)

epoch,train_loss,valid_loss,accuracy,time
0,2.905448,2.114603,0.328451,00:02
1,2.064665,1.770513,0.454753,00:03
2,1.475262,1.23237,0.590495,00:06
3,0.930198,0.774314,0.755046,00:01
4,0.558953,0.74125,0.776449,00:02
5,0.354457,0.665194,0.800944,00:01
6,0.232769,0.460846,0.872559,00:11
7,0.167412,0.487501,0.857503,00:01
8,0.125297,0.409847,0.892497,00:04
9,0.098681,0.433908,0.880371,00:01


In [14]:
### Multi layer GRU

In [24]:
class GRUCell(Module):
    def __init__(self, n_input, n_hidden):
        self.reset_gate = nn.Linear(n_input + n_hidden, n_hidden)
        self.update_gate = nn.Linear(n_input + n_hidden, n_hidden)
        self.candidate_gate = nn.Linear(n_input + n_hidden, n_hidden)
        
    def forward(self, x, h):
        x_h = torch.cat((x, h), dim=1)
        reset = torch.sigmoid(self.reset_gate(x_h))
        update = torch.sigmoid(self.update_gate(x_h))
        x_reset_h = torch.cat((x, h*reset), dim=1)
        candidate = torch.tanh(self.candidate_gate(x_reset_h))
        return (1 - update)*h + update*candidate

class GRU(Module):
    def __init__(self, n_input, n_hidden, n_layers):
        self.layers = []
        self.layers = [GRUCell(n_input, n_hidden) for _ in range(n_layers)]

    def forward(self, x, h):
        outs = []
        current_h = h
        for i in range(x.shape[1]):
            next_h = []
            next_input = x[:, i, :]
            for j, layer in enumerate(self.layers):
                next_input = layer(next_input, current_h[j])
                next_h.append(next_input)
            current_h = next_h
            outs.append(next_input)
        return torch.stack(outs, dim=1), torch.stack(current_h)
        
class GRUWrapper(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers, p):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = GRU(n_hidden, n_hidden, n_layers)
        self.drop = nn.Dropout(p)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(n_layers, bs, n_hidden)
        
    def forward(self, x):
        raw_out, h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        out = self.drop(raw_out)
        return self.h_o(out),raw_out,out
    
    def reset(self): 
        self.h = self.h.zero_()
        
learn = TextLearner(dls, GRUWrapper(len(vocab), 64, 2, 0.5),
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)

learn.fit_one_cycle(15, 1e-1, wd=0.1)

epoch,train_loss,valid_loss,accuracy,time
0,1.810028,1.70142,0.469482,00:18
1,1.467555,1.806924,0.488688,00:17
2,1.370051,1.748739,0.497721,00:02
3,1.431406,1.743019,0.546224,00:05
4,1.358479,1.631945,0.539225,00:06
5,1.335646,1.681896,0.54012,00:13
6,1.233187,1.404806,0.592773,00:05
7,1.041691,1.371786,0.63444,00:02
8,0.844506,1.169821,0.719727,00:02
9,0.612875,0.753473,0.737142,00:08
