In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from sentimentanalyser.utils.data import (
    Path, pad_collate, grandparent_splitter, random_splitter)
from sentimentanalyser.utils.data import parent_labeler
from sentimentanalyser.data.text import TextList, SplitData
from sentimentanalyser.utils.files import pickle_dump, pickle_load



In [3]:
from sentimentanalyser.preprocessing.processor import TokenizerProcessor
from sentimentanalyser.preprocessing.processor import NuemericalizeProcessor
from sentimentanalyser.preprocessing.processor import CategoryProcessor
from sentimentanalyser.utils.data import read_wiki

In [4]:
from sentimentanalyser.utils.callbacks import sched_cos, combine_scheds
from sentimentanalyser.callbacks.training import (
    LRFind, CudaCallback, GradientClipping, RNNCustomTrainer)
from sentimentanalyser.callbacks.progress import ProgressCallback
from sentimentanalyser.callbacks.scheduler import ParamSchedulerCustom
from sentimentanalyser.callbacks.stats import AvgStatsCallback
from sentimentanalyser.callbacks.recorder import RecorderCustom
from sentimentanalyser.training.trainer import Trainer

In [5]:
from sentimentanalyser.optimizers import adam_opt
from sentimentanalyser.utils.metrics import accuracy
from sentimentanalyser.utils.callbacks import combine_scheds, sched_cos, cos_1cycle_anneal
from sentimentanalyser.utils.callbacks import create_phases

In [6]:
from sentimentanalyser.preprocessing.tokens import TOKENS

In [7]:
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from functools import partial

In [8]:
# Machine-specific absolute locations (IMDB corpus, WikiText-103 corpus, word-embedding
# cache). These must be edited before the notebook can run on another machine.
path_imdb = Path("/home/anukoolpurohit/Documents/AnukoolPurohit/Datasets/imdb")
path_wiki = Path("/home/anukoolpurohit/Documents/AnukoolPurohit/Datasets/wikitext-103")
path_cache = Path('/home/anukoolpurohit/Documents/AnukoolPurohit/Models/WordEmbeddings')

In [9]:
# Root directory for saved models; the pretrained weights and vocab are later read
# from its `pretrained/` subfolder. Also a machine-specific absolute path.
path_model = Path('/home/anukoolpurohit/Documents/AnukoolPurohit/Models')

In [10]:
# Uncomment the lines below to download the pretrained AWD-LSTM model from fastai.
# path = path_model
# !wget http://files.fast.ai/models/wt103_tiny.tgz -P {path} --no-check-certificate
# !tar xf {path}/wt103_tiny.tgz -C {path}

In [11]:
# Processor instances (presumably tokenization and numericalization, going by the class
# names — confirm). In the visible cells they are only referenced by the commented-out
# dataset-building cell, where they are passed as proc_x=[proc_tok, proc_num].
proc_tok = TokenizerProcessor()
proc_num = NuemericalizeProcessor()

In [12]:
# tl_imdb_lm = TextList.from_files(path=path_imdb, folders=['train','test', 'unsup'])
# sd_imdb_lm = tl_imdb_lm.split_by_func(partial(random_splitter, p=[0.9, 0.1]))
# ll_imdb_lm = sd_imdb_lm.label_by_func(lambda x:0, proc_x=[proc_tok, proc_num])

In [13]:
# pickle_dump(ll_imdb_lm, 'dumps/variable/ll_imdb_lm.pickle')

In [14]:
# Restore the labelled IMDB language-model data from the file that the commented-out
# pickle_dump cell writes. The path is relative, so it presumably resolves against the
# notebook's working directory — confirm.
# NOTE(review): pickle_load presumably unpickles; only load dumps you created yourself.
ll_imdb_lm = pickle_load('dumps/variable/ll_imdb_lm.pickle')

In [15]:
# bs and bptt are passed straight to lm_databunchify; presumably the batch size and the
# back-propagation-through-time sequence length — TODO confirm against lm_databunchify.
bs, bptt = 64, 70
imdb_lm_data = ll_imdb_lm.lm_databunchify(bs, bptt)

In [16]:
# Vocabulary held by the training split's x-processor at index 1 (presumably the
# numericalizer, matching the proc_x=[proc_tok, proc_num] order — confirm).
vocab = ll_imdb_lm.train.proc_x[1].vocab

In [17]:
# Position of the padding token in the vocabulary. Not used again in the visible cells.
tok_pad = vocab.index(TOKENS.PAD)

In [18]:
from sentimentanalyser.models.rnn import EncDecLanguageModel

In [19]:
def cross_entropy_flat(input, target):
    """Cross-entropy loss over every position of a batch of sequences.

    Args:
        input: Tensor of scores with ``bs * sl * n_classes`` elements; it is
            flattened to ``(bs * sl, -1)`` so the last axis holds the
            per-class scores for one position.
        target: 2-D tensor of class indices; its two sizes are read as
            ``(bs, sl)``.

    Returns:
        The result of ``F.cross_entropy`` on the flattened pair.

    Raises:
        ValueError: If ``target`` is not 2-D (the size unpacking fails).
    """
    bs, sl = target.size()
    # reshape rather than view: view raises on non-contiguous tensors (e.g. a
    # transposed or sliced batch), while reshape is identical for contiguous
    # input and falls back to a copy otherwise.
    return F.cross_entropy(input.reshape(bs * sl, -1), target.reshape(bs * sl))

In [20]:
def accuracy_flat(input, target):
    """Apply the `accuracy` metric across every position of a batch of sequences.

    Args:
        input: Tensor of scores with ``bs * sl * n_classes`` elements; it is
            flattened to ``(bs * sl, -1)`` before being scored.
        target: 2-D tensor of class indices; its two sizes are read as
            ``(bs, sl)``.

    Returns:
        Whatever ``accuracy`` returns for the flattened pair.

    Raises:
        ValueError: If ``target`` is not 2-D (the size unpacking fails).
    """
    bs, sl = target.size()
    # reshape rather than view so non-contiguous tensors are accepted; for
    # contiguous tensors the result is the same view as before.
    return accuracy(input.reshape(bs * sl, -1), target.reshape(bs * sl))

In [21]:
def get_basic(Model, vocab, **kwargs):
    """Assemble a model, its loss function and an optimizer in one call.

    Args:
        Model: Callable invoked as ``Model(len(vocab), **kwargs)`` to build the model.
        vocab: Sized collection; only its length is used.
        **kwargs: Forwarded unchanged to ``Model``.

    Returns:
        A ``(model, loss_func, opt)`` tuple where ``loss_func`` is
        ``cross_entropy_flat`` and ``opt`` is ``adam_opt()`` applied to the
        model's parameters.
    """
    net = Model(len(vocab), **kwargs)
    criterion = cross_entropy_flat
    # adam_opt() yields a callable that is then applied to the parameters.
    build_optimizer = adam_opt()
    optimizer = build_optimizer(net.parameters())
    return net, criterion, optimizer

In [22]:
# Two sched_cos segments joined by combine_scheds: 1e-4 -> 1e-3, then 1e-3 -> 3e-5.
# NOTE(review): [0.3, 0.7] are presumably the fractions of training given to each
# segment, and sched_cos presumably cosine interpolation — confirm in utils.callbacks.
sched = combine_scheds([0.3, 0.7], [sched_cos(1e-4, 1e-3), sched_cos(1e-3, 3e-5)])

In [None]:
# Callback factories later handed to Trainer through cb_funcs: either bare classes or
# partials with their constructor arguments already bound.
cbs = [partial(AvgStatsCallback, [accuracy_flat]),
       partial(ParamSchedulerCustom,'lr', [sched]),
       partial(GradientClipping, clip=0.1),
       ProgressCallback,
       CudaCallback,
       # NOTE(review): α and β are presumably the AR / TAR regularization weights — confirm.
       partial(RNNCustomTrainer, α=2., β=1.),
       RecorderCustom]

In [None]:
# Pretrained weights and their vocabulary, read from <path_model>/pretrained/.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
old_wgts = torch.load(path_model/'pretrained'/'pretrained.pth')
old_vocab = pickle_load(path_model/'pretrained'/'vocab.pkl')

In [None]:
# Build a model only (the loss function and optimizer are discarded) so that its
# parameters can be inspected in the following cells.
model,_,_ = get_basic(EncDecLanguageModel, vocab, dropout=0.5)

In [None]:
# Print each parameter name of the new model, for comparison with the key names of the
# pretrained checkpoint printed further down.
for name, param in model.named_parameters():
    print(name)

encoder.embeddings.weight
encoder.rnn.rnns.0.weight_hh_l0_raw
encoder.rnn.rnns.0.module.weight_ih_l0
encoder.rnn.rnns.0.module.weight_hh_l0
encoder.rnn.rnns.0.module.bias_ih_l0
encoder.rnn.rnns.0.module.bias_hh_l0
encoder.rnn.rnns.1.weight_hh_l0_raw
encoder.rnn.rnns.1.module.weight_ih_l0
encoder.rnn.rnns.1.module.weight_hh_l0
encoder.rnn.rnns.1.module.bias_ih_l0
encoder.rnn.rnns.1.module.bias_hh_l0
decoder.decoder.bias


In [None]:
# Number of entries in the pretrained checkpoint.
len(old_wgts.keys())

14

In [None]:
# Number of named parameters in the new model, to compare with the checkpoint's entry count.
len(list(model.named_parameters()))

12

In [None]:
# Print the checkpoint's key names; match_embeds relies on the '0.emb.weight',
# '0.emb_dp.emb.weight', '1.decoder.weight' and '1.decoder.bias' entries.
for key in old_wgts.keys():
    print(key)

0.emb.weight
0.emb_dp.emb.weight
0.rnns.0.weight_hh_l0_raw
0.rnns.0.module.weight_ih_l0
0.rnns.0.module.weight_hh_l0
0.rnns.0.module.bias_ih_l0
0.rnns.0.module.bias_hh_l0
0.rnns.1.weight_hh_l0_raw
0.rnns.1.module.weight_ih_l0
0.rnns.1.module.weight_hh_l0
0.rnns.1.module.bias_ih_l0
0.rnns.1.module.bias_hh_l0
1.decoder.weight
1.decoder.bias


In [None]:
def match_embeds(old_wgts, old_vocab, new_vocab):
    """Re-index pretrained embedding rows and decoder biases to a new vocabulary.

    For each token of ``new_vocab`` the row of ``old_wgts['0.emb.weight']`` and
    the entry of ``old_wgts['1.decoder.bias']`` at that token's position in
    ``old_vocab`` are copied over. Tokens missing from ``old_vocab`` receive the
    mean embedding row and the mean bias instead.

    Args:
        old_wgts: Mapping holding the pretrained tensors. It is modified in
            place.
        old_vocab: Iterable of tokens in the order of the pretrained rows. If a
            token is repeated, its last position is used.
        new_vocab: Sized iterable of tokens defining the new row order.

    Returns:
        The same ``old_wgts`` object. Its '0.emb.weight', '0.emb_dp.emb.weight'
        and '1.decoder.weight' entries all refer to one shared new tensor, and
        '1.decoder.bias' holds the re-indexed bias.
    """
    src_emb = old_wgts['0.emb.weight']
    src_bias = old_wgts['1.decoder.bias']

    # Fallback values for tokens the pretrained vocabulary does not contain.
    mean_emb = src_emb.mean(dim=0)
    mean_bias = src_bias.mean()

    # new_zeros keeps the dtype and device of the pretrained tensors.
    emb = src_emb.new_zeros(len(new_vocab), src_emb.size(1))
    dec_bias = src_bias.new_zeros(len(new_vocab))

    old_position = {token: pos for pos, token in enumerate(old_vocab)}
    for row, token in enumerate(new_vocab):
        src = old_position.get(token)
        if src is None:
            emb[row] = mean_emb
            dec_bias[row] = mean_bias
        else:
            emb[row] = src_emb[src]
            dec_bias[row] = src_bias[src]

    # The three weight entries deliberately share a single tensor object.
    for key in ('0.emb.weight', '0.emb_dp.emb.weight', '1.decoder.weight'):
        old_wgts[key] = emb
    old_wgts['1.decoder.bias'] = dec_bias
    return old_wgts

In [None]:
# Remap the pretrained embedding / decoder rows onto the IMDB vocabulary.
# match_embeds updates old_wgts in place and returns it, so `wgts` and `old_wgts` are the
# same object. Re-running this cell without reloading old_wgts would remap
# already-remapped weights.
wgts = match_embeds(old_wgts, old_vocab, vocab)

In [None]:
# The call parentheses are absent, so this displays the bound method's repr (which
# embeds the module structure) rather than the state dict itself.
model.state_dict

<bound method Module.state_dict of EncDecLanguageModel(
  (encoder): AWDLSTMEncoder(
    (embeddings): Embedding(60002, 300, padding_idx=1)
    (embeddings_dropout): EmbeddingsWithDropout(
      (embeddings): Embedding(60002, 300, padding_idx=1)
    )
    (rnn): AWDLSTM(
      (rnns): ModuleList(
        (0): WeightDropout(
          (module): LSTM(300, 300, batch_first=True)
        )
        (1): WeightDropout(
          (module): LSTM(300, 300, batch_first=True)
        )
      )
      (hidden_dropouts): ModuleList(
        (0): RNNDropout()
        (1): RNNDropout()
      )
    )
    (input_dropout): RNNDropout()
  )
  (decoder): LinearDecoder(
    (output_dp): RNNDropout()
    (decoder): Linear(in_features=300, out_features=60002, bias=True)
  )
)>

In [None]:
# NOTE(review): get_basic is called again here, so the trainer receives a newly
# constructed model rather than `model` from above, and no visible cell loads the
# remapped `wgts` into either of them.
trainer = Trainer(imdb_lm_data, *get_basic(EncDecLanguageModel, vocab, dropout=0.5), cb_funcs=cbs)

In [None]:
# Run training with fit()'s default arguments; the recorded output shows one epoch.
trainer.fit()

epoch,train_loss,train_accuracy_flat,valid_loss,valid_accuracy_flat,time
0,5.766635,0.1589,4.784607,0.23582,15:35
