# 12. A Language Model from Scratch 

In [2]:
from fastai.text.all import * 

  from .autonotebook import tqdm as notebook_tqdm


## The Data

In [3]:
print(URLs.HUMAN_NUMBERS)
path = untar_data(URLs.HUMAN_NUMBERS); path.ls()

https://s3.amazonaws.com/fast-ai-sample/human_numbers.tgz


(#2) [Path('/home/mchristos/.fastai/data/human_numbers/valid.txt'),Path('/home/mchristos/.fastai/data/human_numbers/train.txt')]

In [4]:
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())
lines

(#9998) ['one \n','two \n','three \n','four \n','five \n','six \n','seven \n','eight \n','nine \n','ten \n'...]

In [5]:
text = ' . '.join([l.strip() for l in lines]); 
text[:500], text[-500:]

('one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fourteen . fifteen . sixteen . seventeen . eighteen . nineteen . twenty . twenty one . twenty two . twenty three . twenty four . twenty five . twenty six . twenty seven . twenty eight . twenty nine . thirty . thirty one . thirty two . thirty three . thirty four . thirty five . thirty six . thirty seven . thirty eight . thirty nine . forty . forty one . forty two . forty three . forty four . forty fi',
 'eighty seven . nine thousand nine hundred eighty eight . nine thousand nine hundred eighty nine . nine thousand nine hundred ninety . nine thousand nine hundred ninety one . nine thousand nine hundred ninety two . nine thousand nine hundred ninety three . nine thousand nine hundred ninety four . nine thousand nine hundred ninety five . nine thousand nine hundred ninety six . nine thousand nine hundred ninety seven . nine thousand nine hundred ninety eight . nine thousand nine hundred nine

In [6]:
tokens = text.split(' ')
tokens[:10], L(tokens[-10:])

(['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.'],
 (#10) ['hundred','ninety','eight','.','nine','thousand','nine','hundred','ninety','nine'])

In [7]:
vocab = L(*tokens).unique()
print(vocab)

['one', '.', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety', 'hundred', 'thousand']


the whole vocab is pretty short! just 30 elements 

In [8]:
# convert our dataset of tokens into numbers 
word2idx = {w: i for i,w in enumerate(vocab)}
nums = L(word2idx[t] for t in tokens)
nums 

(#63095) [0,1,2,1,3,1,4,1,5,1...]

## Our First Language Model from Scratch 

In [9]:
L((tokens[i:i+3],tokens[i+3]) for i in range(0,len(tokens)-4,3))

(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]

In [22]:
seqs = L((tensor(nums[i:i+3]),nums[i+3]) for i in range(0,len(nums)-4,3)); seqs

(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10,  1, 11]), 1),(tensor([ 1, 12,  1]), 13),(tensor([13,  1, 14]), 1),(tensor([ 1, 15,  1]), 16)...]

In [23]:
bs = 64 
cut = int(len(seqs)*0.8)
train, valid = seqs[:cut], seqs[cut:]
train, valid

((#16824) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10,  1, 11]), 1),(tensor([ 1, 12,  1]), 13),(tensor([13,  1, 14]), 1),(tensor([ 1, 15,  1]), 16)...],
 (#4207) [(tensor([ 1,  8, 29]), 26),(tensor([26,  5,  1]), 8),(tensor([ 8, 29, 26]), 6),(tensor([6, 1, 8]), 29),(tensor([29, 26,  7]), 1),(tensor([ 1,  8, 29]), 26),(tensor([26,  8,  1]), 8),(tensor([ 8, 29, 26]), 9),(tensor([9, 1, 8]), 29),(tensor([29, 27,  1]), 8)...])

In [25]:
dls = DataLoaders.from_dsets(train,valid,bs=bs,shuffle=False)

## Our Language Model in PyTorch 

In [26]:
class LMModel1(Module):
    def __init__(self, vocab_sz, n_hidden):
        # input -> hidden 
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # hidden -> hidden
        self.h_h = nn.Linear(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
    
    def forward(self, x):
        h = F.relu(self.h_h(self.i_h(x[:,0])))
        h = h + self.i_h(x[:,1])
        h = h + F.relu(self.h_h(h))
        h = h + self.i_h(x[:,2])
        h = h + F.relu(self.h_h(h))
        return self.h_o(h)

In [28]:
dls.train_ds[0]

(tensor([0, 1, 2]), 1)

In [32]:
learn = Learner(
    dls, 
    LMModel1(vocab_sz=len(vocab), n_hidden=64),
    loss_func=F.cross_entropy, 
    metrics=accuracy
)
learn.fit_one_cycle(4,lr_max=1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,1.971377,2.157408,0.413121,00:02
1,1.494654,1.922625,0.45662,00:02
2,1.466575,1.716544,0.495603,00:02
3,1.433506,1.701716,0.500832,00:02


In [39]:
n,counts = 0,torch.zeros(len(vocab))
# x and y are batches! 
for x,y in dls.valid:
    n += y.shape[0]
    for i in y:
        counts[i] += 1
counts

tensor([106., 637., 159., 107., 106., 159., 108., 106., 464., 442.,   6.,   7.,
          6.,   6.,   7.,   6.,   6.,   7.,   6.,   6.,  64.,  63.,  63.,  64.,
         63.,  63.,  66.,  66., 600., 638.])

In [44]:
# the most common token is 'thousand' 
i_max = torch.argmax(counts)
vocab[i_max], counts[i_max]/n

('thousand', tensor(0.1517))

In [43]:
class LMModel2(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.h_h = nn.Linear(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
    
    def forward(self, x):
        h = 0 
        for i in range(3):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
        return self.h_o(h)

In [45]:
learn = Learner(dls, LMModel2(len(vocab), n_hidden=64), loss_func=F.cross_entropy,metrics=accuracy)
learn.fit_one_cycle(4,1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,1.873248,1.960821,0.464939,00:02
1,1.390725,1.915783,0.446399,00:02
2,1.421181,1.725936,0.478488,00:02
3,1.376492,1.814246,0.411457,00:02
