In [1]:
from torchtext.datasets import IMDB
from torchtext.vocab import FastText
from torchtext import data, datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from IPython.core.debugger import set_trace
from tqdm import tqdm_notebook as tqdm

In [2]:
# Review text: lower-cased, emitted time-major (batch_first=False) and without
# per-example lengths.
TEXT = data.Field(
    lower=True,
    include_lengths=False,
    batch_first=False,
)
# Labels are a single non-sequential value per example.
LABEL = data.Field(sequential=False)

# IMDB train / test splits built with the two fields above.
train, test = datasets.IMDB.splits(TEXT, LABEL)

# Vocabularies come from the training split only; the text vocabulary is capped
# at 60000 entries and carries English FastText vectors.
TEXT.build_vocab(train, max_size=60000, vectors=FastText(language='en'))
LABEL.build_vocab(train)

In [3]:
# Number of examples per batch; passed to the BucketIterator splits and used as
# the model's initial `bsz`.
batch_size = 32

In [4]:
# Batch iterators over both splits; batches are placed on the GPU and both
# iterators are created with shuffle=True.
train_iter, test_iter = data.BucketIterator.splits(
    (train, test),
    batch_size=batch_size,
    device='cuda',
    shuffle=True,
)

In [13]:
# Grab one training batch and show the shape of its text tensor.
x = next(iter(train_iter))
x.text.size()

torch.Size([1364, 32])

In [5]:
# Vectors attached to the text vocabulary; the first dimension is the number
# of tokens.
weight_matrix = TEXT.vocab.vectors
n_tokens = weight_matrix.shape[0]

In [26]:
class ClassifierModel(nn.Module):
    """LSTM sequence classifier: embedding -> stacked LSTM -> linear head.

    Inputs are time-major token-index tensors of shape ``(seq_len, batch)``:
    the batch size is read from dimension 1 and the LSTM is built without
    ``batch_first``. The prediction is made from the LSTM output at the last
    time step.

    Args:
        ntoken: vocabulary size (number of embedding rows).
        ninp: embedding dimension (LSTM input size).
        nhid: hidden units per LSTM layer.
        nlayers: number of stacked LSTM layers.
        bsz: initial batch size; overwritten on every forward pass with the
            batch size of the actual input.
        noutputs: number of output classes (size of the returned logits).
        dropout: dropout probability applied to the embeddings, between LSTM
            layers, and to the LSTM output.
    """

    def __init__(self, ntoken, ninp,
                 nhid, nlayers, bsz, noutputs,
                 dropout=0.5):
        super(ClassifierModel, self).__init__()
        self.nhid, self.nlayers, self.bsz = nhid, nlayers, bsz
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        # nn.LSTM only applies dropout between layers; passing a non-zero
        # value with a single layer has no effect and only emits a warning.
        self.rnn = nn.LSTM(ninp, nhid, nlayers,
                           dropout=dropout if nlayers > 1 else 0.0)
        self.output = nn.Linear(nhid, noutputs)

        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output weights; zero the output bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.output.bias.data.fill_(0)
        self.output.weight.data.uniform_(-initrange, initrange)

    def forward(self, input):
        """Return class logits of shape ``(batch, noutputs)``.

        Args:
            input: token indices of shape ``(seq_len, batch)``.
        """
        # Track the batch size of the current input (the last batch of an
        # epoch can be smaller than the nominal one).
        self.bsz = input.size(1)
        emb = self.drop(self.encoder(input))
        # A fresh zero state is used for every batch: no state is carried over.
        self.hidden = self.init_hidden(self.bsz)
        output, self.hidden = self.rnn(emb, self.hidden)
        # Classify from the last time step only.
        # NOTE(review): if batches are right-padded, the last step of shorter
        # sequences is a padding position - confirm this is intended.
        output = self.drop(output)[-1]
        return self.output(output)

    def init_hidden(self, bsz):
        """Return zero-filled ``(h_0, c_0)``, each ``(nlayers, bsz, nhid)``.

        The tensors are created with the dtype and device of the model's
        parameters, so the model works on CPU as well as on GPU.
        """
        weight = next(self.parameters())
        shape = (self.nlayers, bsz, self.nhid)
        return weight.new_zeros(shape), weight.new_zeros(shape)


# Size the embedding table from the vocabulary vectors (rows = tokens,
# columns = embedding dimension), then move the model to the GPU.
weight_matrix = TEXT.vocab.vectors
model = ClassifierModel(
    ntoken=weight_matrix.size(0),
    ninp=weight_matrix.size(1),
    nhid=200,
    nlayers=3,
    bsz=batch_size,
    noutputs=2,
).cuda()

In [27]:
# Warm-start from the 'lstm_8.pt' checkpoint: copy over only the entries whose
# name exists in this model and whose tensor size matches; every other
# parameter keeps its current value.
model_state = model.state_dict()
pretrained_state = {
    name: tensor
    for name, tensor in torch.load('lstm_8.pt').items()
    if name in model_state and tensor.size() == model_state[name].size()
}
model_state.update(pretrained_state)
model.load_state_dict(model_state)

In [28]:
# Cross-entropy loss (moved to the GPU) and an Adam optimizer over all model
# parameters; both are read as globals by fit().
loss_function = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def fit(epoch,model,data_loader,phase='training'):
    """Run one pass over ``data_loader`` and report loss and accuracy.

    Uses the module-level ``optimizer`` and ``loss_function``.

    Args:
        epoch: epoch index (accepted for the caller's bookkeeping; not used).
        model: the network to train or evaluate.
        data_loader: iterable of batches exposing ``.text`` and ``.label``,
            with a ``.dataset`` attribute supporting ``len()``.
        phase: ``'training'`` updates the weights; ``'validation'`` puts the
            model in eval mode and only measures.

    Returns:
        ``(loss, accuracy)`` as Python floats: the mean per-example loss and
        the accuracy in percent, both over ``len(data_loader.dataset)``.
    """
    is_training = phase == 'training'
    if is_training:
        model.train()
    if phase == 'validation':
        model.eval()

    n_samples = len(data_loader.dataset)
    running_loss = 0.0
    running_correct = 0
    # Only build the autograd graph when weights are actually updated.
    with torch.set_grad_enabled(is_training):
        for batch in data_loader:
            # Labels are shifted down by one to obtain class indices.
            # NOTE(review): presumably index 0 of the label vocabulary is a
            # reserved token - confirm against the LABEL field.
            text, target = batch.text, batch.label - 1

            if is_training:
                optimizer.zero_grad()
            output = model(text)
            loss = loss_function(output, target)

            if is_training:
                loss.backward()
                optimizer.step()

            # loss_function returns the batch *mean*; weight it by the batch
            # size so that dividing by the dataset size below yields a true
            # per-example average. `.item()` keeps the running totals as
            # plain Python numbers rather than (GPU) tensors.
            running_loss += loss.item() * target.size(0)
            preds = output.argmax(dim=1)
            running_correct += (preds == target).sum().item()

    loss = running_loss / n_samples
    accuracy = 100. * running_correct / n_samples
    print(f'{phase} loss is {loss:{5}.{2}} and {phase} accuracy is {running_correct}/{len(data_loader.dataset)}{accuracy:{10}.{4}}')
    return loss,accuracy
# Per-epoch metric histories, appended to by the training loop.
train_losses = []
train_accuracy = []
val_losses = []
val_accuracy = []

In [None]:
# Train for 5 epochs; after each training pass the test iterator is used for
# the validation pass, and both sets of metrics are recorded.
for epoch in tqdm(range(5)):
    epoch_loss, epoch_accuracy = fit(epoch, model, train_iter, 'training')
    val_epoch_loss, val_epoch_accuracy = fit(epoch, model, test_iter, 'validation')

    train_losses.append(epoch_loss)
    val_losses.append(val_epoch_loss)
    train_accuracy.append(epoch_accuracy)
    val_accuracy.append(val_epoch_accuracy)

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))