In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from functools import partial
from sentimentanalyser.data.text import TextList, ItemList, DataBunch, SplitData
from sentimentanalyser.utils.data import Path, listify, random_splitter, compose, parallel, pad_collate, parent_labeler, read_wiki, grandparent_splitter
from sentimentanalyser.data.samplers import SortishSampler, SortSampler
from sentimentanalyser.utils.preprocessing import *
from sentimentanalyser.utils.files import pickle_dump, pickle_load
from sentimentanalyser.preprocessing.processor import TokenizerProcessor, NuemericalizeProcessor, CategoryProcessor



In [3]:
path_imdb = Path("/home/anukoolpurohit/Documents/AnukoolPurohit/Datasets/imdb")

In [4]:
proc_tok = TokenizerProcessor()
proc_num = NuemericalizeProcessor()
proc_cat = CategoryProcessor()

In [6]:
il_imdb = TextList.from_files(path_imdb, folders=['train', 'test'])
sd_imdb = il_imdb.split_by_func(partial(grandparent_splitter, valid_name='test'))
ll_imdb = sd_imdb.label_by_func(parent_labeler, proc_x = [proc_tok, proc_num], proc_y=proc_cat)

HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [7]:
pickle_dump(ll_imdb, 'dumps/variable/ll_imdb.pickle')

In [8]:
ll_imdb = pickle_load('dumps/variable/ll_imdb.pickle')

In [9]:
imdb_data = ll_imdb.clas_databunchify(64)

In [130]:
import torch
from torch import nn
import torch.nn.functional as F
import time
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [12]:
def accuracy(preds, y):
    preds   = torch.argmax(preds, dim=1)
    # y       = torch.argmax(y, dim =1)
    correct = (preds == y).float()
    acc     = correct.sum() / len(correct)
    return acc

In [13]:
loss_func = nn.CrossEntropyLoss()

# Training

## Callbacks

In [121]:
from sentimentanalyser.callbacks.callback import Callback

### TrainEvalCallback

In [120]:
class TrainEvalCallback(Callback):
    def begin_epoch(self):
        self.model.train()
        self.trainer.in_train = True
        return
    def begin_validate(self):
        self.model.eval()
        self.trainer.in_train = False

### LoggerCallback

In [131]:
class Logger(Callback):
    def begin_fit(self):
        self.train_losses = []
        self.valid_losses = []
        self.lrs          = [[] for _ in self.opt.param_groups()]
        return
    
    def after_batch(self):
        if not self.in_train:
            self.valid_losses.append(self.loss.detach().cpu())
            return
        
        self.train_losses.append(self.loss.detach().cpu())
        for pg, lr in zip(self.opt.param_groups, self.lrs):
            lr.append(pg['lr'])
        return
    
    def plot(self, skip_last=0, pgid=-1):
        losses = [loss.item() for loss in self.train_losses]
        lrs    = self.lrs[pgid]
        n      = len(losses) - skip_last
        
        plt.xscale('log')
        plt.plot(lrs[:n], losses[:n])

## Exceptions

In [111]:
class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

### Utils funcs

## Trainer Class

In [122]:
class Trainer():
    def __init__(self, data, model, loss_func, opt, cbs):
        self.data, self.model    = data, model
        self.loss_func, self.opt = loss_func, opt
        self.cbs = cbs
    
    def add_cbs(self, cbs):
        [self.add_cb(cb) for cb in cbs]
    
    def one_batch(self, xb, yb):
        try:
            self.xb, self.yb = xb.cuda(), yb.cuda()

            self('begin_batch')
            self.preds  = self.model(self.xb)
            self('after_pred')
            self.loss   = self.loss_func(self.preds, self.yb)
            self('after_loss')
            self.losses.append(self.loss.detach().item())
            self.acc   += accuracy(self.preds, self.yb).item() 

            if not self.in_train: return

            self.loss.backward()
            self('after_backward')
            self.opt.step()
            self('after_step')
            self.opt.zero_grad()
        except CancelBatchException:
            self('after_cancel_batch')
        finally:
            self('after_batch')
        return
    
    def all_batches(self, dl):
        self.iters = len(dl)
        try:
            for xb, yb in tqdm(dl, leave=False):
                self.one_batch(xb, yb)
        except CancelEpochException:
            self('after_cancel_epoch')
        return 
    
    def fit(self, epochs=3):
        try:
            for cb in self.cbs:
                cb.set_trainer(self)
            self('begin_fit')
            self.model = self.model.cuda()
            for epoch in tqdm(range(epochs)):
                self.start_time = time.time()
                self.epoch = epoch

                # Training Loop
                if not self('begin_epoch'):
                    self.train_loss, self.train_acc = self.all_batches(self.data.train_dl)

                # Validation Loop
                with torch.no_grad():
                    if not self('begin_validate'):
                        self.valid_loss, self.valid_acc = self.all_batches(self.data.valid_dl)

                elapsed_time = time.strftime('%M:%S', time.gmtime(time.time() - self.start_time))
                print(f'Epoch: {self.epoch} | Time [min:sec]: {elapsed_time}')
                print(f'Train Loss: {sum(self.train_loss)/len(self.train_loss):.4f} | Train acc: {self.train_acc/len(self.train_loss)*100:.2f}%')
                print(f'Valid Loss: {sum(self.valid_loss)/len(self.valid_loss):.4f} | Valid acc: {self.valid_acc/len(self.valid_loss)*100:.2f}%')
        except CancelTrainException:
            self('after_cancel_train')
        finally:
            self('after_fit')
    def __call__(self, cb_name):
        res = True
        for cb in sorted(self.cbs, key=lambda x: x._order):
            res = cb(cb_name) and res
        return res

# Model

In [123]:
def get_lengths(x):
    return x.size(1) - (x == 1).sum(1)

In [124]:
class Model0(nn.Module):
    def __init__(self, vocab_size=proc_num.vocab_size, num_layers=2,
                 hidden_size=50, output_size=2, bidirectional=True,
                 padding_idx=1, bs=64):
        super().__init__()
        
        self.vocab_size, self.hidden_size, self.output_size   = vocab_size, hidden_size, output_size
        self.num_layers, self.bidirectional, self.padding_idx = num_layers, bidirectional, padding_idx
        self.bs = bs
        
        self.bidir          = 2 if bidirectional is True else 1
        
        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=self.padding_idx)
        
        self.dropout   = nn.Dropout()
        
        self.rnn       = nn.LSTM(self.hidden_size, self.hidden_size,
                                 num_layers=self.num_layers,
                                 batch_first=True,
                                 bidirectional=self.bidirectional)
        
        
        self.bn        = nn.BatchNorm1d(self.hidden_size * self.bidir * self.num_layers)
        self.fc        = nn.Linear(self.hidden_size * self.bidir * self.num_layers, self.output_size)
        
        self.softmax   = nn.Softmax(dim=1)
        return
    
    def forward(self, texts):
        text_lengths = get_lengths(texts)
        embeded = self.dropout(self.embedding(texts))
        packed_embed = nn.utils.rnn.pack_padded_sequence(embeded, text_lengths, batch_first=True)
        
        packed_output, (hidden, cell) = self.rnn(packed_embed)
        
        hidden = self.dropout(torch.cat([h for h in hidden], dim=1))
        linear = self.fc(hidden)
        return linear

In [125]:
model = Model0(num_layers=4)

In [126]:
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

In [127]:
trainer = Trainer(imdb_data, model, loss_func, opt, cbs=[TrainEvalCallback()])

In [128]:
trainer.fit(epochs=3)

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

AttributeError: 'Trainer' object has no attribute 'losses'