In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.learner import *

import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

import dill as pickle
import spacy

In [3]:
# Restore the pickled TEXT object (it supplies tokenize / numericalize / vocab
# further down). `pickle` here is the `dill` module imported under that alias.
# NOTE(review): unpickling executes code from the file - only load trusted files.
with open('data/aclImdb/models/TEXT.pkl', mode='rb') as text_file:
    TEXT = pickle.load(text_file)

In [4]:
# Load spaCy's 'en' pipeline.
# NOTE(review): spacy_tok is not referenced again in the visible cells -
# confirm whether something outside this notebook relies on it.
spacy_tok = spacy.load('en')

In [5]:
# Batch size and `bptt` value passed to LanguageModelData.from_text_files
# (bptt is presumably the back-prop-through-time window - TODO confirm).
bs = 32
bptt = 20

In [6]:
# Root data directory plus the train / validation sub-folders (relative to it).
PATH = 'data/'

TRN_PATH, VAL_PATH = 'kk_train/', 'kk_valid/'

# Full paths, built by joining the root with each sub-folder.
TRN = PATH + TRN_PATH
VAL = PATH + VAL_PATH

# List the contents of PATH to check that the expected data folders exist.
%ls {PATH}

[0m[01;34maclImdb[0m/                    [01;34mkk_train[0m/        [01;34mmodels[0m/      [01;34mwiki_valid[0m/
cleaned-extra-kk-jokes.txt  [01;31mkk_train.tar.gz[0m  [01;34mtmp[0m/
extra_jokes.p               [01;34mkk_valid[0m/        wiki_en.txt
jokes_ds.p                  [01;31mkk_valid.tar.gz[0m  [01;34mwiki_train[0m/


In [7]:
# Folder used for each split; note the validation folder is reused as the
# test split.
FILES = {
    'train': TRN_PATH,
    'validation': VAL_PATH,
    'test': VAL_PATH,
}

# Build the language-model data object from the text files under PATH, using
# the pre-loaded TEXT object.
md = LanguageModelData.from_text_files(
    PATH,
    TEXT,
    **FILES,
    bs=bs,
    bptt=bptt,
    min_freq=10,
)

In [9]:
# Summary of the language-model data.
# Label fix: len(md.trn_ds) is the number of items in the training dataset
# (1 in the recorded run - presumably the whole corpus is concatenated into a
# single example), and len(md.trn_ds[0].text) is the length of that item's
# text, i.e. the token count. The recorded numbers agree with this reading:
# 17922 tokens / (bs=32 * bptt=20) is roughly the 27 batches reported.
print(
    f'batches: {len(md.trn_dl)}\n'
    f'unique tokens: {md.nt}\n'
    f'items in training set: {len(md.trn_ds)}\n'
    f'tokens in training set: {len(md.trn_ds[0].text)}'
)

batches: 27
unique tokens: 37392
tokens in training set: 1
sentences: 17922


### Train

In [10]:
# Model dimensions passed to md.get_model.
em_sz, n_layers_act, n_layers = (
    200,  # size of each embedding vector
    500,  # number of hidden activations per layer
    3,    # number of layers
)

In [11]:
# Optimizer factory: Adam with betas=(0.7, 0.99) pre-bound, so callers only
# need to supply the remaining arguments.
opt_fn = partial(
    optim.Adam,
    betas=(0.7, 0.99),
)

In [12]:
# Build the learner from the data object using the optimizer factory and the
# model dimensions defined above; every dropout-related setting is kept small.
learner = md.get_model(
    opt_fn, em_sz, n_layers_act, n_layers,
    dropouti=0.05,
    dropout=0.05,
    wdrop=0.1,
    dropoute=0.02,
    dropouth=0.05,
)

# Regularisation hook: seq2seq_reg with alpha=2 and beta=1 pre-bound.
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)

# presumably the gradient-clipping threshold - TODO confirm against fastai
learner.clip = 0.2

In [13]:
# Start from encoder weights saved under the name 'adam3_20_enc'.
# NOTE(review): the cell that produced this file is not part of this notebook;
# presumably it comes from an earlier pre-training run - confirm it exists
# before a fresh Run All.
learner.load_encoder('adam3_20_enc')

In [14]:
# Stage 1: fit at 3e-4 with cycle_len=1 and cycle_mult=2.
# The recorded log covers 15 epochs (0-14), which matches 4 cycles of
# 1 + 2 + 4 + 8 epochs; validation loss ends at about 2.49.
learner.fit(3e-4, 4, wds=1e-6, cycle_len=1, cycle_mult=2)

epoch      trn_loss   val_loss                            
    0      4.918121   4.239881  
    1      4.127887   3.433996                            
    2      3.708695   3.258515                            
    3      3.441216   3.19427                             
    4      3.115167   2.730807                            
    5      2.886957   2.682442                            
    6      2.733479   2.647665                            
    7      2.662317   2.555713                            
    8      2.540474   2.483023                            
    9      2.430573   2.467169                            
    10     2.328552   2.533455                            
    11     2.235601   2.4704                              
    12     2.171336   2.445823                            
    13     2.134262   2.482945                            
    14     2.129152   2.493166                            



[2.4931664]

In [15]:
# Checkpoint the encoder after stage 1 under the name 'kkadam_2'.
learner.save_encoder('kkadam_2')

In [16]:
# Stage 2: a single cycle of 10 epochs at 3e-3 (ten times the stage-1 value).
# In the recorded log the training loss falls from 2.20 to 0.87 while the
# validation loss drifts up from 2.41 to 2.60, so the model appears to be
# overfitting from this point on.
learner.fit(3e-3, 1, wds=1e-6, cycle_len=10)

epoch      trn_loss   val_loss                            
    0      2.204255   2.411452  
    1      2.061987   2.412294                            
    2      1.844674   2.525815                            
    3      1.613543   2.506678                            
    4      1.390922   2.538958                            
    5      1.234058   2.600569                            
    6      1.096162   2.602224                            
    7      1.023965   2.561308                             
    8      0.933281   2.667499                             
    9      0.872674   2.603978                             



[2.6039782]

In [17]:
# Checkpoint the encoder after stage 2 under the name 'kkadam2_2'.
learner.save_encoder('kkadam2_2')

In [18]:
# Stage 3: fit at 3e-3 with cycle_len=1 and cycle_mult=3.
# The recorded log covers 13 epochs (0-12), which matches 3 cycles of
# 1 + 3 + 9 epochs. Training loss keeps dropping (0.95 -> 0.51) but the
# validation loss stays around 2.7-2.8, i.e. no validation improvement.
learner.fit(3e-3, 3, wds=1e-6, cycle_len=1, cycle_mult=3)

epoch      trn_loss   val_loss                             
    0      0.945318   2.673752  
    1      0.909592   2.714366                             
    2      0.849857   2.710828                             
    3      0.825967   2.633471                             
    4      0.829777   2.805096                             
    5      0.807564   2.791475                             
    6      0.750853   2.805592                             
    7      0.714836   2.780608                             
    8      0.661802   2.779977                             
    9      0.607998   2.799284                             
    10     0.572157   2.843715                             
    11     0.535977   2.833151                             
    12     0.514901   2.769182                             



[2.769182]

In [19]:
# Checkpoint the encoder after stage 3 under the name 'kkadam3_2'.
learner.save_encoder('kkadam3_2')

### Test

In [20]:
# Grab the trained model and turn the seed prompt into model input.
m = learner.model
ss = """knock knock who's there"""

# Lower-case the prompt. str.lower() replaces a join over a generator that
# walked the string one character at a time (despite its loop variable `w`).
ss = ss.lower()

# A one-element list holding the prompt's tokens, then its numericalized form
# as produced by TEXT.numericalize.
s = [TEXT.tokenize(ss)]
t = TEXT.numericalize(s)

# Display the tokenised prompt as the cell output.
' '.join(s[0])

"knock knock who 's there"

In [21]:
# Set batch size to 1
m[0].bs=1
# Turn off dropout
m.eval()
# Reset hidden state
m.reset()
# Get predictions from model
res,*_ = m(t)
# Put the batch size back to what it was
m[0].bs=bs

In [22]:
# Indices of the 10 highest-scoring entries in the last row of `res`,
# mapped back to vocabulary strings for display.
nexts = res[-1].topk(10)[1]
[TEXT.vocab.itos[idx] for idx in to_np(nexts)]

['!', '?', '.', '<eos>', 'who', ';', 'there', 'and', '...', ',']

In [23]:
import random

In [28]:
# Generate 100 tokens that continue the prompt `ss`: pick a token from the
# latest prediction `res`, print it, and feed it back into the model `m` to get
# the next prediction.
tk=8      # size of the candidate pool used for the random picks
pct=.75   # a step is greedy unless random.random() exceeds this value
print(ss,"\n")
for i in range(100):
    if random.random()>pct:
        # Exploratory step: choose uniformly among candidates 2..tk.
        # Position 0 is always skipped. When it holds token id 0 (presumably
        # the unknown-word token - TODO confirm via TEXT.vocab.itos[0]) the
        # greedy branch below refuses to emit it too; previously this branch
        # sampled from 0..tk-1 in that case and could still emit it.
        # When position 0 holds any other token, behaviour is unchanged.
        n=res[-1].topk(tk)[1]
        n = n[random.randint(1,tk-1)]
    else:
        # Greedy step: take the best candidate, falling back to the runner-up
        # when the best one is token id 0.
        n=res[-1].topk(2)[1]
        n = n[1] if n.data[0]==0 else n[0]
    print(TEXT.vocab.itos[n.data[0]], end=' ')
    # Feed the chosen token back in; only the first returned value is kept.
    res,*_ = m(n[0].unsqueeze(0))
print('...')

knock knock who's there 

to open this door ? ? <eos> do you want to meet another knock - knock joke . <eos> do you have an moustache or an answer . <eos> knock , let 's be there ? butter butter who ? butter be quick , you have to go through the window . . <eos> knock , knock ! who ’s there ? i m a . i m the only one left ! ! <eos> knock , knock . . who who ? you really want to hear another another knock - knock joke ? please open up . who ...
