In [None]:
# Notebook setup: (re)load the autoreload extension in mode 2 so edited modules are
# re-imported before each cell executes, and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from sentimentanalyser.utils.data  import Path, pad_collate, grandparent_splitter
from sentimentanalyser.utils.data  import parent_labeler, listify
from sentimentanalyser.data.text   import TextList, ItemList, SplitData
from sentimentanalyser.utils.files import pickle_dump, pickle_load

from sentimentanalyser.preprocessing.processor import TokenizerProcessor, NuemericalizeProcessor
from sentimentanalyser.preprocessing.processor import CategoryProcessor

In [None]:
from functools import partial

In [None]:
# Root directory of the IMDB dataset; passed to TextList.from_files together with
# the 'train' and 'test' folder names.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
path_imdb = Path("/home/anukoolpurohit/Documents/AnukoolPurohit/Datasets/imdb")

In [None]:
# Processors used when labelling the data: the tokenizer and numericalizer are applied
# to the inputs (proc_x), the category processor to the labels (proc_y).
# proc_num.vocab is read later when the model is built.
proc_tok = TokenizerProcessor()
proc_num = NuemericalizeProcessor()
proc_cat = CategoryProcessor()

In [None]:
# Build the labelled dataset in three ordered steps:
#   1. list the text files under the 'train' and 'test' folders of path_imdb,
#   2. split them with grandparent_splitter, using 'test' as the validation name,
#   3. label them with parent_labeler while applying the x / y processors.
# NOTE(review): presumably the splitter and labeler key off the grandparent / parent
# folder names of each file — confirm in sentimentanalyser.utils.data.
tl_imdb = TextList.from_files(path=path_imdb, folders=['train','test'])
sd_imdb = tl_imdb.split_by_func(partial(grandparent_splitter, valid_name='test'))
ll_imdb = sd_imdb.label_by_func(parent_labeler, proc_x=[proc_tok, proc_num], proc_y=proc_cat)

In [None]:
# Wrap the labelled lists into the data object consumed by the Trainer.
# NOTE(review): 64 is presumably the batch size — confirm against clas_databunchify's signature.
imdb_data = ll_imdb.clas_databunchify(64)

In [None]:
# Pull the first item yielded by the training dataloader and unpack it into two parts.
# Neither x1 nor y1 is used again in the visible cells.
x1,y1 = next(iter(imdb_data.train_dl))

In [None]:
from sentimentanalyser.utils.callbacks import sched_cos, combine_scheds
from sentimentanalyser.callbacks.training import LR_Find, CudaCallback, GradientClipping
from sentimentanalyser.callbacks.progress import ProgressCallback
from sentimentanalyser.callbacks.scheduler import ParamScheduler
from sentimentanalyser.callbacks.stats import AvgStatsCallback
from sentimentanalyser.callbacks.recorder import Recorder
from sentimentanalyser.training.trainer import Trainer

In [None]:
from sentimentanalyser.utils.dev import print_dims
from sentimentanalyser.utils.training import get_embedding_vectors
from sentimentanalyser.utils.metrics import accuracy
from tqdm.auto import tqdm

In [None]:
from sentimentanalyser.models.rnn import AttnAWDModel
from sentimentanalyser.models.regularization import WeightDropout
from sentimentanalyser.data.core import ListContainer

In [None]:
import torch
import torchtext
from torch import nn
from torchtext import vocab
import matplotlib.pyplot as plt

In [None]:
# Directory handed to vocab.GloVe as its cache location.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
path_cache = Path('/home/anukoolpurohit/Documents/AnukoolPurohit/Models/WordEmbeddings')

In [None]:
# GloVe vectors from torchtext with its default name/dimension arguments, cached in path_cache.
# NOTE(review): presumably triggers a large download when the cache is empty — confirm.
glove_eng = vocab.GloVe(cache=path_cache)

In [None]:
# Alias for the numericalizer's vocabulary. Not referenced again in the visible cells
# (get_basic reads proc_num.vocab directly).
local_vocab = proc_num.vocab

In [None]:
def get_basic(Model, num_layers=2, lr=1e-3):
    """Build a model together with its loss function and optimizer.

    Args:
        Model: model class (or factory) invoked as
            ``Model(vocab, vectors, num_layers=num_layers)``.
        num_layers: forwarded unchanged to ``Model``.
        lr: learning rate given to AdamW. Defaults to ``1e-3``, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(model, loss_func, opt)`` where ``loss_func`` is
        ``nn.CrossEntropyLoss()`` and ``opt`` is ``torch.optim.AdamW`` over
        ``model.parameters()``.
    """
    # Relies on the notebook-level `proc_num` (its vocab) and `glove_eng`
    # (pretrained vectors) created in earlier cells.
    model = Model(proc_num.vocab, glove_eng, num_layers=num_layers)
    loss_func = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, loss_func, opt

In [None]:
# Learning-rate schedule made of two cosine segments (1e-4 -> 1e-3, then 1e-3 -> 3e-5)
# combined with weights [0.3, 0.7].
# NOTE(review): presumably the weights are the fraction of training spent in each
# segment (30% warm-up, 70% annealing) — confirm in sentimentanalyser.utils.callbacks.
sched = combine_scheds([0.3, 0.7], [sched_cos(1e-4, 1e-3), sched_cos(1e-3, 3e-5)])

In [None]:
# Callback factories handed to the Trainer via cb_funcs. Entries are classes or
# partials rather than instances: accuracy is bound as the metric for AvgStatsCallback,
# `sched` as the schedule for the 'lr' parameter, and 0.1 as the gradient-clipping value.
cbfs = [partial(AvgStatsCallback, [accuracy]),
        partial(ParamScheduler,'lr', [sched]),
        partial(GradientClipping, clip=0.1),
        ProgressCallback,
        CudaCallback,
        Recorder
       ]

In [None]:
# Instantiate the attention AWD model with get_basic's default num_layers,
# plus its loss function and optimizer.
model, loss_func, opt = get_basic(AttnAWDModel)

In [None]:
class Hook():
    """Forward hook attached to a single module.

    Registers ``f`` on module ``m`` so that every forward pass of ``m`` calls
    ``f(hook, module, input, output)``, where ``hook`` is this object. The
    callback can therefore store any state it collects as attributes on the
    hook itself.
    """
    def __init__(self, m, f):
        # Bind this Hook as the first argument of the callback.
        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        """Detach the hook from its module.

        Does nothing when no handle exists, e.g. when registration failed
        inside ``__init__`` and ``__del__`` later runs on the half-built object.
        """
        handle = getattr(self, 'hook', None)
        if handle is not None:
            handle.remove()

    def __del__(self):
        # Make sure the module does not keep calling a discarded hook.
        self.remove()

In [None]:
def append_stats(hook, mod, inp, outp):
    """Forward-hook callback that records the mean and std of a module's output.

    On first call it creates ``hook.stats`` (a ``(means, stds)`` pair of lists)
    and ``hook.layer_name`` (the module's class name). One value is appended to
    each list per forward pass, but only while ``mod.training`` is true.

    Args:
        hook: object the statistics are stored on.
        mod: module whose forward pass just ran.
        inp: forward inputs (unused).
        outp: forward output. For ``WeightDropout`` modules, and for any module
            returning a tuple or list, the first element is measured.
    """
    if not hasattr(hook, 'stats'):
        hook.stats = ([],[])
    if not hasattr(hook, 'layer_name'):
        hook.layer_name = mod.__class__.__name__
    if not mod.training:
        return
    # Tuple-returning modules (e.g. recurrent layers) have no `.data` on the
    # tuple itself, so measure their first element instead.
    if isinstance(outp, (tuple, list)) or isinstance(mod, WeightDropout):
        outp = outp[0]
    # `.data` is kept on purpose: it is valid on a tensor and, unlike
    # `.detach()`, on a PackedSequence as well.
    acts = outp.data
    means, stds = hook.stats
    means.append(acts.mean().item())
    stds.append(acts.std().item())

In [None]:
class Hooks(ListContainer):
    """Collection of `Hook` objects, one per module in `ms`, all sharing callback `f`.

    Usable as a context manager: every hook is removed when the `with` block exits.
    """
    def __init__(self, ms, f):
        registered = [Hook(module, f) for module in ms]
        super().__init__(registered)

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # Returns None, so an exception raised inside the block still propagates.
        self.remove()

    def __del__(self):
        self.remove()

    def __delitem__(self, i):
        # Detach the hook(s) first, then drop them from the container.
        self[i].remove()
        super().__delitem__(i)

    def remove(self):
        """Detach every hook held by this container."""
        for hook in self:
            hook.remove()

In [None]:
# Assemble the Trainer from the data, model, loss function and optimizer built above;
# cbfs supplies the callback factories.
trainer = Trainer(imdb_data, model, loss_func, opt,
                  cb_funcs=cbfs)

In [None]:
# Train for 2 epochs with forward hooks on the model's first four child modules,
# then plot the mean / std of their outputs. append_stats only records while a
# module is in training mode, so each point is one training forward pass.
with Hooks(list(model.children())[:4], append_stats) as hooks:
    trainer.fit(2)
    names = [h.layer_name for h in hooks]

    # First figure zooms in on the first 10 recorded passes, second shows the whole run.
    for window, suffix in ((slice(None, 10), ' first 10'), (slice(None), '')):
        fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
        for h in hooks:
            ms, ss = h.stats
            ax0.plot(ms[window])
            ax1.plot(ss[window])
        ax0.set_title('mean' + suffix)
        ax1.set_title('std' + suffix)
        # Same placement as the former plt.legend call: right-hand axes.
        ax1.legend(names)