In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_007a import *

# IMDB

## Fine-tuning the LM

The data was prepared as CSV files at the beginning of notebook 007a; we use it now.

### Loading the data

In [None]:
# Root of the IMDB data and its sub-folders: classifier data, language-model data, model weights.
PATH = Path('../data/aclImdb/')
CLAS_PATH, LM_PATH, MODEL_PATH = (PATH/sub for sub in ('clas', 'lm', 'models'))
for folder in (CLAS_PATH, LM_PATH, MODEL_PATH):
    os.makedirs(folder, exist_ok=True)

In [None]:
# Tokenizer built from the `rules` list, with BOS/FLD/UNK/PAD passed as special cases.
tokenizer = Tokenizer(special_cases=[BOS, FLD, UNK, PAD], rules=rules)
# Language-model train/valid datasets read from the CSVs in LM_PATH.
train_ds, valid_ds = TextDataset.from_csv(LM_PATH, tokenizer)

In [None]:
# Batch size and bptt (sequence length per training step).
bs, bptt = 50, 70
# The LM loader takes one long stream of ids, so each dataset's per-document
# id arrays are concatenated end to end (train first, then valid).
train_dl, valid_dl = [LanguageModelLoader(np.concatenate(ds.ids), bs, bptt)
                      for ds in (train_ds, valid_ds)]

In [None]:
# Bundle the train/valid language-model loaders; passed to the Learner below.
data = DataBunch(train_dl, valid_dl)

### Adapt the pre-trained weights to the new vocabulary

Download the pretrained model (`lstm1.pth`) and the corresponding itos dictionary (`itos.pkl`), and put them in the `MODEL_PATH` folder.

In [None]:
def replace(itos, tok1, tok2):
    "Replace the first occurrence of `tok1` in `itos` with `tok2`, in place, and return `itos`."
    # `list.index` raises ValueError when `tok1` is not in `itos`.
    itos[itos.index(tok1)] = tok2
    return itos

def apply_new_flags():
    "Temporary function to change the old special tokens by the new ones"
    # Rewrites MODEL_PATH/'itos.pkl' in place. Both files are opened with `with`, so the
    # handles are closed (and the dump flushed) deterministically.
    # NOTE: pickle.load can execute arbitrary code; only use an itos.pkl from a trusted source.
    itos_path = MODEL_PATH/'itos.pkl'
    with open(itos_path, 'rb') as f: itos_wt = pickle.load(f)
    olds = ['_unk_', '_pad_', 'xbos', 'xfld', 'u_n', 't_up', 'tk_rep', 'tk_wrep']
    news = [UNK, PAD, BOS, FLD, UNK, TOK_UP, TK_REP, TK_WREP]
    # Both '_unk_' and 'u_n' map to UNK, so the resulting list contains UNK twice.
    for tok1,tok2 in zip(olds, news):
        itos_wt = replace(itos_wt, tok1, tok2)
    with open(itos_path, 'wb') as f: pickle.dump(itos_wt, f)

In [None]:
# Run once, on a freshly downloaded itos.pkl, to rename its legacy special tokens (the file is overwritten).
#apply_new_flags()

In [None]:
# Vocabulary of the pretrained model: itos maps id -> token, stoi maps token -> id.
# The file is read inside `with` so its handle is closed.
# NOTE: pickle.load can execute arbitrary code; only load an itos.pkl from a trusted source.
with open(MODEL_PATH/'itos.pkl', 'rb') as f: itos_wt = pickle.load(f)
# If a token appears more than once in itos_wt, the dict keeps its last index.
stoi_wt = {v:k for k,v in enumerate(itos_wt)}

In [None]:
def convert_weights(wgts, stoi_wgts, itos_new):
    "Remap the pretrained embedding/decoder weights in `wgts` onto the vocabulary `itos_new`."
    # Tokens missing from `stoi_wgts` get the mean embedding row and the mean decoder bias.
    # `wgts` is updated in place and also returned.
    enc_wgts = wgts['0.encoder.weight']
    dec_bias = wgts['1.decoder.bias']
    mean_emb, mean_bias = enc_wgts.mean(0), dec_bias.mean(0)
    n_new, emb_dim = len(itos_new), enc_wgts.size(1)
    new_w = enc_wgts.new_zeros((n_new, emb_dim))
    new_b = dec_bias.new_zeros((n_new,))
    for row, tok in enumerate(itos_new):
        old = stoi_wgts.get(tok, -1)
        if old < 0:
            new_w[row], new_b[row] = mean_emb, mean_bias
        else:
            new_w[row], new_b[row] = enc_wgts[old], dec_bias[old]
    # The encoder, its embedding-dropout copy and the decoder weight all receive the same
    # values, each as its own tensor.
    wgts['0.encoder.weight'] = new_w
    wgts['0.encoder_dp.emb.weight'] = new_w.clone()
    wgts['1.decoder.weight'] = new_w.clone()
    wgts['1.decoder.bias'] = new_b
    return wgts

In [None]:
# Load the pretrained LM weights; map_location keeps every storage on the CPU.
wgts = torch.load(MODEL_PATH/'lstm1.pth', map_location=lambda storage, loc: storage)

In [None]:
# First decoder biases of the pretrained model, before remapping.
wgts['1.decoder.bias'][:10]

In [None]:
# First tokens of the pretrained vocabulary.
itos_wt[:10]

In [None]:
# Remap the embedding/decoder weights from the pretrained vocab onto the IMDB vocab.
wgts = convert_weights(wgts, stoi_wt, train_ds.vocab.itos)

In [None]:
# Same slice after remapping: rows now follow the IMDB vocabulary order shown below.
wgts['1.decoder.bias'][:10]

In [None]:
train_ds.vocab.itos[:10]

## Define the model

In [None]:
# Model sizes; emb_sz/nh/nl must match the pretrained checkpoint, since its weights are loaded below.
vocab_size = len(train_ds.vocab.itos)
emb_sz,nh,nl = 400,1150,3
# Base dropout probabilities (used below as input/output/weight/embed/hidden dropout), scaled by 0.5.
dps = np.array([0.25, 0.1, 0.2, 0.02, 0.15])*0.5

In [None]:
# NOTE(review): the positional 0 is presumably the pad token index — TODO confirm against
# get_language_model (the classifier below uses pad_token=1).
model = get_language_model(vocab_size, emb_sz, nh, nl, 0, input_p=dps[0], output_p=dps[1], weight_p=dps[2], 
                           embed_p=dps[3], hidden_p=dps[4])
# load_state_dict is strict by default: the remapped `wgts` must match the model's keys and shapes.
model.load_state_dict(wgts)

Split the model into layer groups for discriminative learning rates and gradual unfreezing.

In [None]:
# One group per (RNN layer, its hidden dropout); the embeddings and the decoder form the last group.
lm_core = model[0]
groups = [nn.Sequential(*pair) for pair in zip(lm_core.rnns, lm_core.hidden_dps)]
groups += [nn.Sequential(lm_core.encoder, lm_core.encoder_dp, model[1])]

In [None]:
learn = Learner(data, model)
# Custom layer groups (cell above) for discriminative lrs and gradual unfreezing.
learn.layer_groups = groups
# NOTE(review): alpha/beta are presumably activation / temporal-activation regularization weights — TODO confirm in RNNTrainer.
learn.callbacks.append(RNNTrainer(learn, bptt, alpha=2, beta=1))
learn.metrics = [accuracy]
# Start with only part of the model trainable (presumably the last group: embeddings + decoder).
learn.freeze()

In [None]:
# Learning-rate range test; presumably equivalent to `learn.lr_find()` used in the classifier section.
lr_find(learn)

In [None]:
learn.recorder.plot()

In [None]:
# One epoch on the trainable groups with the one-cycle policy (momentum range 0.8 -> 0.7).
learn.fit_one_cycle(1, 1e-2, moms=(0.8,0.7), wd=1e-7)

In [None]:
# Make every layer group trainable for the next stage, and checkpoint the current weights.
learn.unfreeze()
learn.save('fit_head')

In [None]:
# Starting from the checkpoint, fine-tune the whole model for 10 epochs at a lower max lr.
learn.load('fit_head')
learn.fit_one_cycle(10, 1e-3, moms=(0.8,0.7), wd=1e-7)

In [None]:
learn.save('fine_tuned')

In [None]:
# Reload the fine-tuned weights (lets the notebook resume from here).
learn.load('fine_tuned')

In [None]:
def save_encoder(learn, name):
    "Save the state dict of the encoder part of the model (`learn.model[0]`) to `learn.path/{name}.pth`."
    encoder_state = learn.model[0].state_dict()
    torch.save(encoder_state, learn.path/f'{name}.pth')

In [None]:
# Keep the fine-tuned encoder; the classifier section reloads it with load_encoder.
save_encoder(learn, 'fine_tuned_enc')

## Classifier

We need to use the same itos (index-to-token vocabulary) as the language model, so that token ids line up with the fine-tuned encoder's embeddings.

In [None]:
# Reuse the language model's vocabulary for the classifier. Create the destination folder
# first: shutil.copy raises FileNotFoundError if CLAS_PATH/'tmp' does not exist yet.
os.makedirs(CLAS_PATH/'tmp', exist_ok=True)
shutil.copy(LM_PATH/'tmp'/'itos.pkl', CLAS_PATH/'tmp'/'itos.pkl')

In [None]:
vocab = Vocab(CLAS_PATH/'tmp')
tokenizer = Tokenizer(rules=rules, special_cases=[BOS, FLD, UNK, PAD])
train_ds, valid_ds = TextDataset.from_csv(CLAS_PATH, tokenizer, vocab=vocab)

In [None]:
train_ds.vocab.itos[:20], vocab.itos[:20]

In [None]:
train_ds.vocab.textify(train_ds.ids[1]), train_ds.labels[1]

In [None]:
train_ds.vocab.textify(train_ds.ids[5]), train_ds.labels[5]

In [None]:
train_ds.vocab.textify(valid_ds.ids[15]), valid_ds.labels[15]

In [None]:
from torch.utils.data import Sampler, BatchSampler

class SortSampler(Sampler):
    "Go through the text data by order of length"
    # Yields indices of `data_source` sorted by `key(index)`, largest first.
    
    def __init__(self, data_source, key): self.data_source,self.key = data_source,key
    def __len__(self): return len(self.data_source)
    def __iter__(self):
        return iter(sorted(range(len(self.data_source)), key=self.key, reverse=True))


class SortishSampler(Sampler):
    "Go through the text data by order of length with a bit of randomness"
    # Yields every index of `data_source` once: batches of `bs` indices with similar `key`
    # values, in random batch order, except that the batch holding the largest key comes first.
    
    def __init__(self, data_source, key, bs):
        self.data_source,self.key,self.bs = data_source,key,bs

    def __len__(self): return len(self.data_source)

    def __iter__(self):
        n = len(self.data_source)
        # Nothing to sample; the concatenations below need at least one chunk.
        if n == 0: return iter([])
        idxs = np.random.permutation(n)
        # Sort by descending key inside shuffled mega-chunks of 50 batches.
        sz = self.bs*50
        ck_idx = [idxs[i:i+sz] for i in range(0, n, sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        # Cut into batch-sized chunks; the last one may be shorter than `bs`.
        sz = self.bs
        ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0]     # then make sure it goes first.
        # Shuffle the remaining chunks by permuting their positions. Permuting the list of
        # chunks itself fails when chunk sizes differ (ragged last batch) and when only one
        # chunk exists (empty list).
        rest = [ck_idx[j] for j in np.random.permutation(len(ck_idx) - 1) + 1]
        return iter(np.concatenate([ck_idx[0]] + rest))

In [None]:
class TextClasDataLoader():
    "Regroups sequences in a batch."
    
    def __init__(self, text_ds:'TextDataset', bs:int=64, sampler:Sampler=None, pad_idx=1, drop_last=False):
        "Batch `text_ds.ids`/`text_ds.labels` in the order given by `sampler` (sequential when None)."
        self.ids,self.labels = text_ds.ids,text_ds.labels
        self.bs,self.pad_idx = bs,pad_idx
        # Use a real sampler: a bare `iter(range(...))` has no `len` (needed by `__len__`) and is
        # exhausted after the first epoch.
        if sampler is None: sampler = torch.utils.data.SequentialSampler(self.ids)
        self.batch_sampler = BatchSampler(sampler, bs, drop_last)
    
    def __len__(self): return len(self.batch_sampler)
    
    def __iter__(self):
        "Yield `(x, y)`: x is a left-padded long tensor of shape (max_len, n_samples), y the labels."
        for samples in self.batch_sampler:
            max_len = max(len(self.ids[i]) for i in samples)
            # Size by the actual batch (the last one can be smaller than `bs`) so x and y stay aligned.
            res = torch.full((max_len, len(samples)), self.pad_idx, dtype=torch.long)
            labels = []
            for i,idx in enumerate(samples):
                seq = self.ids[idx]
                # Left-pad: the sequence fills the end of its column. Empty sequences are skipped,
                # since `res[-0:]` would address the whole column.
                if len(seq): res[-len(seq):,i] = torch.as_tensor(seq, dtype=torch.long)
                labels.append(self.labels[idx])
            yield res, torch.tensor(labels, dtype=torch.long)

In [None]:
# Training batches: similar lengths with randomness (SortishSampler). Validation batches: fully
# sorted by length (SortSampler). Training uses bs//2, presumably to fit long reviews in memory.
train_sampler = SortishSampler(train_ds.ids, key=lambda x: len(train_ds.ids[x]), bs=bs//2)
valid_sampler = SortSampler(valid_ds.ids, key=lambda x: len(valid_ds.ids[x]))
train_dl = TextClasDataLoader(train_ds, bs//2, train_sampler)
valid_dl = TextClasDataLoader(valid_ds, bs,    valid_sampler)
data = DataBunch(train_dl, valid_dl)

In [None]:
# Peek at one validation batch: x holds token ids as (max_len, batch), y the labels.
x,y = next(iter(valid_dl))

In [None]:
# Decode one column back to text; leading pad tokens come from the left padding.
vocab.textify(x[:,2]), y[2]

In [None]:
vocab.textify(x[:,6]), y[6]

In [None]:
class MultiBatchRNNCore(RNNCore):
    "RNNCore run over a long input in `bptt`-sized chunks, keeping the outputs of chunks that start within the last `max_seq` positions."
    def __init__(self, bptt, max_seq, *args, **kwargs):
        # Plain attributes, set before `super().__init__`; remaining args go to RNNCore.
        self.max_seq,self.bptt = max_seq,bptt
        super().__init__(*args, **kwargs)

    def concat(self, arrs):
        "Concatenate along dim 0, position by position, the per-chunk output lists in `arrs`."
        # Presumably each element of `arrs` holds one tensor per RNN layer — TODO confirm in RNNCore.
        # Raises IndexError if `arrs` is empty.
        return [torch.cat([l[si] for l in arrs]) for si in range(len(arrs[0]))]

    def forward(self, input):
        # assumes input is (seq_len, batch) — the layout produced by TextClasDataLoader.
        sl,bs = input.size()
        # Presumably resets the hidden state once per batch — TODO confirm in RNNCore.
        self.reset()
        raw_outputs, outputs = [],[]
        for i in range(0, sl, self.bptt):
            r, o = super().forward(input[i: min(i+self.bptt, sl)])
            # Keep only chunks starting after position sl-max_seq.
            # NOTE(review): when sl == max_seq the first chunk (i=0) is dropped; and if
            # max_seq <= bptt, possibly no chunk qualifies, making `concat` raise — confirm intent.
            if i>(sl-self.max_seq):
                raw_outputs.append(r)
                outputs.append(o)
        return self.concat(raw_outputs), self.concat(outputs)

In [None]:
def bn_dp_lin(n_in, n_out, drop, relu=True):
    "Return [BatchNorm1d, Dropout, Linear] layers, plus an in-place ReLU when `relu` is true."
    block = [nn.BatchNorm1d(n_in), nn.Dropout(drop), nn.Linear(n_in, n_out)]
    if not relu: return block
    return block + [nn.ReLU(inplace=True)]

In [None]:
class PoolingLinearClassifier(nn.Module):
    "Classifier head: concat pooling over the last RNN output, then BatchNorm/Dropout/Linear blocks."
    def __init__(self, layers, drops):
        super().__init__()
        n_blocks = len(layers) - 1
        blocks = []
        for k in range(n_blocks):
            # Every block except the final one ends with a ReLU.
            is_hidden = k != n_blocks - 1
            blocks.extend(bn_dp_lin(layers[k], layers[k + 1], drops[k], is_hidden))
        self.layers = nn.Sequential(*blocks)

    def pool(self, x, bs, is_max):
        "Max- or average-pool `x` (seq_len, bs, n_feat) over its first dimension into a (bs, n_feat) tensor."
        pool_fn = F.adaptive_max_pool1d if is_max else F.adaptive_avg_pool1d
        # The 1d pooling ops reduce the last dimension, so move the sequence axis there.
        pooled = pool_fn(x.permute(1, 2, 0), (1,))
        return pooled.view(bs, -1)

    def forward(self, input):
        "`input` is `(raw_outputs, outputs)`; returns `(head output, raw_outputs, outputs)`."
        raw_outputs, outputs = input
        last = outputs[-1]
        _, bs, _ = last.size()
        # Features: last time step, max-pool and average-pool, concatenated on dim 1.
        features = torch.cat([last[-1], self.pool(last, bs, True), self.pool(last, bs, False)], 1)
        return self.layers(features), raw_outputs, outputs

In [None]:
def get_rnn_classifier(bptt, max_seq, n_class, vocab_sz, emb_sz, n_hid, n_layers, pad_token, layers, drops, 
                       bidir=False, qrnn=False, hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
    "Build an RNN classifier: a `MultiBatchRNNCore` encoder followed by a `PoolingLinearClassifier` head."
    # `n_class` is not used here; the head's output size comes from `layers[-1]`.
    rnn_kwargs = dict(pad_token=pad_token, bidir=bidir, qrnn=qrnn, hidden_p=hidden_p,
                      input_p=input_p, embed_p=embed_p, weight_p=weight_p)
    encoder = MultiBatchRNNCore(bptt, max_seq, vocab_sz, emb_sz, n_hid, n_layers, **rnn_kwargs)
    head = PoolingLinearClassifier(layers, drops)
    return SequentialRNN(encoder, head)

In [None]:
# Classifier sizes; emb_sz/nh/nl must match the encoder saved as 'fine_tuned_enc', which is loaded into it.
vocab_size,n_class = len(train_ds.vocab.itos),2
emb_sz,nh,nl = 400,1150,3
# Same base dropouts as the language model, scaled by 0.7 instead of 0.5.
dps = np.array([0.25, 0.1, 0.2, 0.02, 0.15])*0.7
# Alternative dropout setting, kept for experimentation:
#dps = np.array([0.4,0.5,0.05,0.3,0.4])*0.5

In [None]:
# max_seq is tied to bptt (20 chunks' worth of tokens) instead of hard-coding 70, so it follows bptt.
# pad_token=1 matches TextClasDataLoader's default pad_idx. The head input is emb_sz*3 because it
# concatenates last-step, max-pooled and avg-pooled features (assumes the top RNN layer outputs emb_sz).
model = get_rnn_classifier(bptt, 20*bptt, n_class, vocab_size, emb_sz=emb_sz, n_hid=nh, n_layers=nl, pad_token=1,
          layers=[emb_sz*3, 50, n_class], drops=[dps[1], 0.1],
          input_p=dps[0], weight_p=dps[2], embed_p=dps[3], hidden_p=dps[4])

In [None]:
# Layer groups for discriminative lrs / gradual unfreezing: embeddings, one group per RNN layer, then the head.
groups = [nn.Sequential(model[0].encoder, model[0].encoder_dp)]
groups += [nn.Sequential(rnn, dp) for rnn, dp in zip(model[0].rnns, model[0].hidden_dps)] 
groups.append(model[1])

In [None]:
# Display the classifier head.
model[1]

In [None]:
learn = Learner(data, model)
learn.layer_groups = groups
# NOTE(review): adjust=False differs from the LM setup — TODO confirm what `adjust` controls in RNNTrainer.
learn.callbacks.append(RNNTrainer(learn, bptt, alpha=2, beta=1, adjust=False))
# Gradient clipping with clip=0.12 (presumably a max gradient norm — confirm in GradientClipping).
learn.callback_fns.append(partial(GradientClipping, clip=0.12))
learn.metrics = [accuracy]
learn.freeze()

In [None]:
def load_encoder(learn, name):
    "Load encoder weights saved by `save_encoder` from `learn.path/{name}.pth` into `learn.model[0]`."
    # Map storages to the CPU (as done for the pretrained LM weights) so a checkpoint saved from
    # a GPU also loads on a CPU-only machine; load_state_dict copies values into the model's own tensors.
    state = torch.load(learn.path/f'{name}.pth', map_location=lambda storage, loc: storage)
    learn.model[0].load_state_dict(state)

In [None]:
# Start from the fine-tuned language-model encoder saved in the LM section.
load_encoder(learn, 'fine_tuned_enc')

In [None]:
# Learning-rate range test; inspect the plot in the next cell to pick a max lr.
learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
# Stage 1: one epoch on the groups left trainable by learn.freeze() (presumably just the head), no weight decay.
learn.fit_one_cycle(1, 1e-2, moms=(0.8,0.7), wd=0)

In [None]:
learn.save('first')

In [None]:
# Discriminative learning rates, one per layer group (embeddings, 3 RNN layers, head):
# the head uses `lr`, and each earlier group is divided by one more factor of `lrm`.
lr=5e-3
lrm = 2.6
lrs = np.array([lr/(lrm**k) for k in range(4, -1, -1)])

In [None]:
# Stage 2: make the last two groups trainable (presumably the top RNN layer and the head); train with discriminative lrs.
learn.freeze_to(-2)
learn.fit_one_cycle(1, lrs, moms=(0.8,0.7), wd=0)

In [None]:
learn.save('second')

In [None]:
# Stage 3: train every layer group for two more epochs.
learn.unfreeze()
learn.fit_one_cycle(2, lrs, moms=(0.8,0.7), wd=0)