# NLP: text classification on top of a language model for the IMDb dataset

**Language model: Given a few words of a sentence, build a model that allows you to predict what the next word is going to be.**

Task: predict whether a review on IMDb is positive or negative.

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [30]:
from fastai.learner import *

# PyTorch's NLP library:
import torchtext
from torchtext.datasets import language_modeling
from torchtext import vocab, data

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

import dill as pickle # Imported dill as pickle because pickle can't handle `TEXT.vocab` correctly!

## Plan: pretrain a language model to predict the next word in a sequence of words, then fine-tune it to classify sentiment

## Dataset: IMDb

In [12]:
PATH = 'data/aclImdb/'

TRN_PATH = 'train/all/'
VAL_PATH = 'test/all/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'

%ls {PATH}

README  imdb.vocab  imdbEr.txt  test/  train/


Training folder

In [13]:
trn_files = !ls {TRN}
trn_files[:5]

['0_0.txt', '0_3.txt', '0_9.txt', '10000_0.txt', '10000_4.txt']

Let's look at an example

In [27]:
review = !cat {TRN}{trn_files[2]}
review

['Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High\'s satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\'t!']

How many words are in the dataset?

In [16]:
!find {TRN} -name '*.txt' | xargs cat | wc -w

17486270


In [17]:
!find {VAL} -name '*.txt' | xargs cat | wc -w

5686609


## Tokenization

In [21]:
spacy_tok = spacy.load('en')

In [28]:
' '.join([word.string.strip() for word in spacy_tok(review[0])])

'Bromwell High is a cartoon comedy . It ran at the same time as some other programs about school life , such as " Teachers " . My 35 years in the teaching profession lead me to believe that Bromwell High \'s satire is much closer to reality than is " Teachers " . The scramble to survive financially , the insightful students who can see right through their pathetic teachers \' pomp , the pettiness of the whole situation , all remind me of the schools I knew and their students . When I saw the episode in which a student repeatedly tried to burn down the school , I immediately recalled ......... at .......... High . A classic line : INSPECTOR : I \'m here to sack one of your teachers . STUDENT : Welcome to Bromwell High . I expect that many adults of my age think that Bromwell High is far fetched . What a pity that it is n\'t !'

Notice, for example, how `isn't` became `is n't` during tokenization!
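
To make the contraction handling concrete, here is a toy regex tokenizer that reproduces the `is n't` and `High 's` splits seen above. It is a hypothetical illustration, not how spaCy works internally (spaCy combines tokenizer exceptions with prefix, suffix and infix rules), so it will disagree with spaCy on many inputs.

```python
import re

# Toy tokenizer (NOT spaCy): split off the clitic "n't", possessive "'s"-style
# suffixes and single punctuation characters. Illustration only.
TOKEN_RE = re.compile(r"\w+(?=n't)|n't|'\w+|\w+|[^\w\s]")

def toy_tokenize(text):
    """Return the list of tokens found in `text`."""
    return TOKEN_RE.findall(text)

print(toy_tokenize("What a pity that it isn't!"))
# ['What', 'a', 'pity', 'that', 'it', 'is', "n't", '!']
```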

## A torchtext *field* describes how to preprocess a chunk of text
1. convert to lowercase
1. tokenize using spacy

In [37]:
TEXT = data.Field(lower=True, tokenize='spacy')  # Part of torchtext

1. **bs** (batch size): how many sequences of words are processed at the same time. The entire corpus is split into *bs* contiguous streams, one per column of a batch.
1. **bptt** (*backprop through time*): how long each sequence of words in a minibatch is, i.e. the number of time steps the model backpropagates through.
1. **min_freq**: when converting the words to integers, every word that appears fewer than *min_freq* times is mapped to *unknown* (`<unk>`).
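
How `bs` and `bptt` carve up the corpus can be sketched in plain Python, assuming (as fastai's language-model loader does) that the whole training text is first concatenated into one long stream of token ids. Both helpers are hypothetical. For this corpus, about 20.5M training tokens split into 64 streams gives 320,949 rows per stream, and 320,949 / 70 ≈ 4,585 chunks, close to the 4,583 batches reported further down (presumably the small difference comes from the sequence length varying between batches).

```python
def batchify(token_ids, bs):
    """Split one long stream of token ids into `bs` contiguous columns.

    Returns a list of rows; row t holds the t-th token of every column,
    so the result has shape (n_rows, bs). Leftover tokens that do not
    fill a complete row are dropped.
    """
    n_rows = len(token_ids) // bs
    trimmed = token_ids[: n_rows * bs]
    # column j is the j-th contiguous chunk of the corpus
    columns = [trimmed[j * n_rows:(j + 1) * n_rows] for j in range(bs)]
    # transpose to (n_rows, bs)
    return [list(row) for row in zip(*columns)]

def bptt_chunks(batched, bptt):
    """Yield consecutive slices of at most `bptt` rows."""
    for start in range(0, len(batched), bptt):
        yield batched[start:start + bptt]

stream = list(range(20))          # stand-in for a numericalized corpus
batched = batchify(stream, bs=4)  # 5 rows x 4 columns
print(batched[0])                 # [0, 5, 10, 15]
print([len(c) for c in bptt_chunks(batched, bptt=2)])  # [2, 2, 1]
```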

In [33]:
bs = 64
bptt = 70

In [34]:
FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)

In [38]:
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)

The `ModelData` object fills the `TEXT` field with the crucial attribute `TEXT.vocab`, which stores the words/tokens found in the text and maps each one to a unique integer.
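
A simplified sketch of what the vocabulary amounts to: count tokens, keep those seen at least `min_freq` times, put the special tokens first and sort the rest by frequency. This is a re-implementation for illustration, not torchtext's code; the alphabetical tie-breaking in particular is an assumption. `build_vocab` and `numericalize` are hypothetical helpers (the latter plays the role of `TEXT.numericalize`).

```python
from collections import Counter

def build_vocab(tokens, min_freq, specials=("<unk>", "<pad>")):
    """Return (itos, stoi): specials first, then tokens by descending frequency."""
    counts = Counter(tokens)
    words = sorted(
        (w for w, c in counts.items() if c >= min_freq and w not in specials),
        key=lambda w: (-counts[w], w),  # most frequent first, ties alphabetical
    )
    itos = list(specials) + words
    stoi = {w: i for i, w in enumerate(itos)}
    return itos, stoi

def numericalize(tokens, stoi, unk="<unk>"):
    """Map tokens to ids, sending out-of-vocabulary tokens to <unk>."""
    return [stoi.get(t, stoi[unk]) for t in tokens]

toks = "the movie was the best movie , the end".split()
itos, stoi = build_vocab(toks, min_freq=2)
print(itos)                                          # ['<unk>', '<pad>', 'the', 'movie']
print(numericalize(["the", "best", "movie"], stoi))  # [2, 0, 3]
```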

In [40]:
os.makedirs(f'{PATH}models', exist_ok=True)

In [41]:
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pk1', 'wb'))

In [70]:
# number of batches, vocabulary size (unique tokens), number of tokens in the training set
len(md.trn_dl), md.nt, len(md.trn_ds[0].text)

(4583, 37392, 20540756)

Let's take a look at the mapping from tokens to ints

In [53]:
# int-to-string, sorted according to frequency apart from first two
TEXT.vocab.itos[:10]

['<unk>', '<pad>', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is']

In [54]:
TEXT.vocab.stoi['<unk>']

0

In [59]:
md.trn_ds[0].text[:10]

['"', 'planet', 'of', 'the', 'apes', '"', 'an', 'awesome', ',', 'classic']

In [61]:
TEXT.numericalize([md.trn_ds[0].text[:10]])

Variable containing:
   15
 1359
    7
    2
 5355
   15
   44
 1186
    3
  387
[torch.cuda.LongTensor of size 10x1 (GPU 0)]

In [62]:
TEXT.vocab.itos[1359]

'planet'

### Let's take a look at the *data loader* of the *LanguageModelData* object
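Each batch has bs=64 columns and roughly bptt=70 rows; the sequence length is varied randomly around `bptt`, which is why this batch has 68 rows. Note that the second tensor the following command returns (the target) is the input shifted down by one token and flattened (68 × 64 = 4352 elements), because we aim to predict the next word. Each column contains many sentences.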

In [69]:
next(iter(md.trn_dl))

(Variable containing:
     15      3    229  ...      74     13    205
   1359     24      9  ...    1866      9     62
      7     12      6  ...       4     75      5
         ...            ⋱           ...         
     18   1512   1279  ...       8    136   4779
   2494     10      3  ...     120     23   1413
     13      2    589  ...      66     12     42
 [torch.cuda.LongTensor of size 68x64 (GPU 0)], Variable containing:
   1359
     24
      9
   ⋮   
    941
     35
  15906
 [torch.cuda.LongTensor of size 4352 (GPU 0)])
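
The input/target pair above can be mimicked in plain Python: the target is the block of rows starting one row later, flattened row by row, which matches the first target values 1359, 24, 9 being the second row of the input. `get_batch` is a hypothetical helper for illustration, not fastai's loader.

```python
def get_batch(batched, start, seq_len):
    """Return (inputs, targets) for rows [start, start + seq_len).

    `batched` is a list of rows (shape n_rows x bs). Targets are the
    inputs shifted down by one row and flattened row by row.
    """
    seq_len = min(seq_len, len(batched) - 1 - start)
    inputs = batched[start:start + seq_len]
    targets = [tok for row in batched[start + 1:start + 1 + seq_len] for tok in row]
    return inputs, targets

batched = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]  # 5 rows x bs=2
x, y = get_batch(batched, start=0, seq_len=3)
print(x)  # [[0, 5], [1, 6], [2, 7]]
print(y)  # [1, 6, 2, 7, 3, 8]
```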

## Embedding

Each token gets its own embedding vector, giving us an embedding matrix of 37392 x 200 (vocabulary size x embedding size).
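
An embedding layer is just a lookup table: token id *i* selects row *i* of the matrix. The sketch below checks the size of the matrix described above and shows the lookup on a tiny, randomly initialized table; `embed` is a hypothetical helper, and in PyTorch this role is played by `nn.Embedding`.

```python
import random

n_tokens, em_sz = 37392, 200
print(n_tokens * em_sz)  # 7478400 weights in the embedding matrix

# Toy version: a tiny embedding matrix, one row per token id.
random.seed(0)
vocab_size, dim = 5, 3
emb = [[random.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(vocab_size)]

def embed(ids, emb):
    """Look up the embedding vector (a row of `emb`) for each token id."""
    return [emb[i] for i in ids]

vectors = embed([2, 0, 2], emb)
print(len(vectors), len(vectors[0]))  # 3 3
```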

In [72]:
em_sz = 200  # size of each embedding vector. Rule of thumb: between 50 and 600
nh = 500     # number of hidden activations per layer
nl = 3       # number of layers

## Optimizer and regularization

In [73]:
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

Researchers have found that Adam's default amount of momentum (beta1 = 0.9) is too high for these models.
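
To see where the first beta enters, here is the standard Adam update (Kingma & Ba) written for a single scalar parameter; PyTorch's `optim.Adam` applies the same rule element-wise. `beta1` sets how long past gradients keep influencing the step, which is why lowering it from 0.9 to 0.7 means less momentum. `adam_step` is a hypothetical helper for illustration.

```python
import math

def adam_step(w, grad, m, v, t, lr=3e-3, betas=(0.7, 0.99), eps=1e-8):
    """One Adam update for a scalar parameter. Returns (w, m, v)."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * grad          # moving average of gradients (momentum)
    v = b2 * v + (1 - b2) * grad * grad   # moving average of squared gradients
    m_hat = m / (1 - b1 ** t)             # bias correction for early steps
    v_hat = v / (1 - b2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# With beta1 = 0.7 a past gradient's weight falls to 0.7**10 ≈ 0.028 after
# 10 steps, versus 0.9**10 ≈ 0.349 with the default beta1 = 0.9.
print(round(0.7 ** 10, 3), round(0.9 ** 10, 3))  # 0.028 0.349
```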

Regularization through dropout according to [Merity et al. 2017](https://arxiv.org/pdf/1708.02182.pdf). Increase the dropout values when overfitting and decrease them when underfitting.

In [74]:
learner = md.get_model(opt_fn=opt_fn, emb_sz=em_sz, n_hid=nh, n_layers=nl,
                      dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)
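
The `dropoute` argument presumably refers to embedding dropout from Merity et al.: instead of zeroing individual activations, it zeroes entire word vectors, so a dropped word vanishes from the whole batch. A plain-Python sketch of that idea under this assumption; `embedding_dropout` is a hypothetical helper, not fastai's function.

```python
import random

def embedding_dropout(emb, p, rng=random):
    """Drop entire rows (words) of an embedding matrix with probability p.

    Surviving rows are scaled by 1 / (1 - p) so the expected value is
    unchanged (inverted dropout). Returns a new matrix.
    """
    if p == 0:
        return [row[:] for row in emb]
    scale = 1.0 / (1.0 - p)
    out = []
    for row in emb:
        keep = rng.random() >= p  # one decision per word, not per value
        out.append([x * scale if keep else 0.0 for x in row])
    return out

random.seed(1)
emb = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(embedding_dropout(emb, p=0.5))
```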

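Activation regularization (AR, weighted by `alpha`) and temporal activation regularization (TAR, weighted by `beta`), also from Merity et al. 2017: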

In [75]:
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
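
A sketch of the two penalties, assuming `alpha` scales AR (the mean squared activation of the last layer's dropped-out outputs) and `beta` scales TAR (the mean squared difference between consecutive time steps of its outputs without dropout), as described by Merity et al. 2017. This is our reading of what `seq2seq_reg` adds to the loss, not its actual source; `ar_tar_penalty` and the list-of-lists layout are hypothetical.

```python
def ar_tar_penalty(hs, dropped_hs, alpha=2.0, beta=1.0):
    """AR + TAR penalty for last-layer hidden states.

    hs, dropped_hs: lists of time steps, each a flat list of activations
    (without / with dropout). AR penalizes large activations; TAR
    penalizes large changes between consecutive time steps.
    """
    def mean_sq(vals):
        return sum(x * x for x in vals) / len(vals)

    ar = alpha * mean_sq([a for step in dropped_hs for a in step])
    diffs = [b - a for prev, cur in zip(hs[:-1], hs[1:]) for a, b in zip(prev, cur)]
    tar = beta * mean_sq(diffs) if diffs else 0.0
    return ar + tar

hs = [[1.0, 0.0], [1.0, 2.0]]  # two time steps, two hidden units
print(ar_tar_penalty(hs, hs))  # 2*(1+0+1+4)/4 + 1*(0+4)/2 = 3.0 + 2.0 = 5.0
```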

Gradient clipping:

In [76]:
learner.clip=0.3
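
Gradient clipping rescales the gradients whenever their overall size exceeds a threshold, which keeps RNN training stable when gradients explode. Assuming `learner.clip` caps the global L2 norm of the gradients (as PyTorch's `torch.nn.utils.clip_grad_norm_` does), a sketch on a plain list of gradient values with the threshold 0.3 used here; `clip_grad_norm` below is a hypothetical helper.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Rescale gradient values so their global L2 norm is at most max_norm.

    Returns (clipped_grads, original_norm). The direction is preserved;
    only the magnitude shrinks.
    """
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm or total == 0.0:
        return list(grads), total
    scale = max_norm / total
    return [g * scale for g in grads], total

clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=0.3)
print(norm, clipped)  # 5.0 and roughly [0.18, 0.24]
```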

## Training

In [77]:
learner.fit(lrs=3e-3, n_cycle=4, cycle_len=1, cycle_mult=2)

epoch      trn_loss   val_loss                                
    0      4.836045   4.717827  
    1      4.637386   4.502882                                
    2      4.523966   4.424905                                
    3      4.565808   4.427167                                
    4      4.46659    4.358337                                
    5      4.41535    4.309937                                
    6      4.352574   4.297426                                
    7      4.488351   4.365648                                
    8      4.450402   4.339508                                
    9      4.395329   4.309739                                
    10     4.371314   4.284916                                
    11     4.3584     4.261108                                
    12     4.319931   4.243292                                
    13     4.297283   4.234079                                
    14     4.265207   4.231984                                



[array([4.23198])]
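
The run above lasts 15 epochs because `n_cycle=4, cycle_len=1, cycle_mult=2` gives cycles of 1, 2, 4 and 8 epochs (SGDR, stochastic gradient descent with warm restarts). Within each cycle the learning rate is annealed and then jumps back up at the next restart, which is consistent with the rise in training loss at epochs 3 and 7, where the third and fourth cycles start. A sketch assuming cosine annealing as in Loshchilov & Hutter's SGDR (fastai's exact per-iteration schedule may differ); both helpers are hypothetical.

```python
import math

def sgdr_epochs(n_cycle, cycle_len, cycle_mult):
    """Length in epochs of each cycle when cycle lengths grow geometrically."""
    return [cycle_len * cycle_mult ** i for i in range(n_cycle)]

def cosine_lr(max_lr, step, steps_in_cycle):
    """Cosine-annealed learning rate within one cycle (restarts at max_lr)."""
    return max_lr * (1 + math.cos(math.pi * step / steps_in_cycle)) / 2

lengths = sgdr_epochs(n_cycle=4, cycle_len=1, cycle_mult=2)
print(lengths, sum(lengths))  # [1, 2, 4, 8] 15
print(cosine_lr(3e-3, 0, 100), cosine_lr(3e-3, 50, 100))  # 0.003, then about 0.0015
```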

Save the encoder of the encoder-decoder model

In [78]:
learner.save_encoder('enc1')

In [79]:
learner.load_encoder('enc1')

In [82]:
learner.fit(lrs=3e-3, n_cycle=1, wds=1e-6, cycle_len=10)
# wds weight decay parameter (L2 reg)

epoch      trn_loss   val_loss                                
    0      4.510499   4.377362  
    1      4.500967   4.384203                                
    2      4.48165    4.361681                                
    3      4.453777   4.339405                                
    4      4.422646   4.316233                                
    5      4.394123   4.293792                                
    6      4.366147   4.272203                                
    7      4.313716   4.256799                                
    8      4.31329    4.249368                                
    9      4.303805   4.247358                                



[array([4.24736])]
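
The `wds=1e-6` comment describes L2 regularization: a penalty proportional to the squared weights, whose gradient adds `wd * w` to each weight's gradient and so pulls weights toward zero. A one-parameter SGD sketch of that classic form; whether fastai applies `wds` exactly like this (rather than decaying the weights in a separate step) is not shown in this notebook, so treat it as an assumption. `sgd_step_with_l2` is a hypothetical helper.

```python
def sgd_step_with_l2(w, grad, lr, wd):
    """One plain-SGD step with L2 regularization folded into the gradient.

    Adding wd * w to the gradient is equivalent to minimising
    loss + (wd / 2) * w ** 2.
    """
    return w - lr * (grad + wd * w)

w = 1.0
for _ in range(3):
    # zero loss gradient: only the decay term moves the weight
    w = sgd_step_with_l2(w, grad=0.0, lr=0.1, wd=1e-6)
print(w)  # slightly below 1.0, about (1 - 1e-7) ** 3
```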

**At this point, we could train this further (we are not overfitting) if we actually planned to use the network. However, it takes quite some time.**

In [83]:
learner.save_encoder('enc2')

In [84]:
learner.load_encoder('enc2')

The quality of a language model is often measured by *perplexity*: the exponential of the average cross-entropy loss per token (`exp()` of the loss).

In [85]:
math.exp(4.24736)

69.9205781465884
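
Perplexity can be read as the effective number of equally likely words the model is choosing between at each step: a model that spreads its probability uniformly over *V* words has perplexity exactly *V*, so a perplexity of about 70 is roughly as uncertain as a uniform pick among 70 words. A sketch computing it from the probabilities a model assigned to the actual next tokens; `perplexity` is a hypothetical helper.

```python
import math

def perplexity(token_probs):
    """exp of the mean negative log-likelihood of the observed next tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

print(perplexity([1 / 70] * 10))  # ~70.0: uniform over 70 choices
print(math.exp(4.24736))          # ~69.92, from the validation loss above
```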

In [86]:
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pk1', 'wb'))

## Test the language model

In [87]:
m = learner.model

In [159]:
string = """. So, it wasn't quite was I was expecting, but I really liked it anyway! The best"""
string_tok = [TEXT.preprocess(string)]
' '.join(string_tok[0])

". so , it was n't quite was i was expecting , but i really liked it anyway ! the best"

In [160]:
string_num = TEXT.numericalize(string_tok)

In [161]:
# Set the batch size of the model to 1
m[0].bs = 1
# Turn off dropout for evaluation
m.eval()
# Reset the hidden state
m.reset()
# Get prediction from model
res, *_ = m(string_num)
# Restore the training batch size
m[0].bs = bs

In [162]:
# topk returns (values, indices); keep the indices of the 10 most likely next tokens
next_words = torch.topk(res[-1], 10)[1]
[TEXT.vocab.itos[o] for o in next_words.data.cpu().numpy()]

['part',
 'thing',
 'scene',
 'way',
 'parts',
 'of',
 'line',
 'aspect',
 'performance',
 'scenes']

**All of these could plausibly be the next word of our test sentence!**
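
`res[-1]` holds the scores for the token following the last input word, and `torch.topk(..., 10)[1]` keeps the indices of the 10 highest scores, which `TEXT.vocab.itos` turns back into words. The same selection in plain Python, with made-up scores and a tiny hypothetical vocabulary; note that `<unk>` can rank near the top, which is why the generation loop below skips it.

```python
import heapq

def topk_indices(scores, k):
    """Indices of the k largest scores, highest first (like torch.topk(...)[1])."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)

itos = ["<unk>", "<pad>", "the", "part", "thing", "scene"]
scores = [2.5, -9.0, 0.1, 3.0, 2.0, 1.5]  # made-up logits for illustration
print([itos[i] for i in topk_indices(scores, 3)])  # ['part', '<unk>', 'thing']
```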

Let's let our model write some more...

In [163]:
print(string, "\n")
for i in range(50):
    n=res[-1].topk(2)[1]
    n = n[1] if n.data[0]==0 else n[0]
    print(TEXT.vocab.itos[n.data[0]], end = ' ')
    res, *_ = m(n[0].unsqueeze(0))  # feed the chosen token back; unsqueeze(0) makes a (seq_len, batch) = (1, 1) input

. So, it wasn't quite was I was expecting, but I really liked it anyway! The best 

part of the movie . the movie is a bit of a mess , but it 's not a bad movie . it 's a very good movie , but it 's not a bad movie . it 's a good movie , but it 's not a good movie 

Ok, ahaha, the language model could probably use some more training, but right now I don't want to invest the time.
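
The generation loop above is greedy decoding with one twist: if `<unk>` (index 0) is the top prediction, it takes the runner-up. Greedy decoding is deterministic, so once the model reaches a context similar to one it has already produced, it tends to repeat itself, which is consistent with the "it 's a good movie , but it 's not a ..." loop above. A framework-free sketch of the loop; `step_fn` and the toy model are hypothetical stand-ins for calling `m`.

```python
def pick_next(scores, unk_idx=0):
    """Greedy choice that never returns <unk>: take the best-scoring token,
    or the runner-up if the best one is <unk>."""
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return order[1] if order[0] == unk_idx else order[0]

def generate(step_fn, first_scores, n_words, itos):
    """Greedy generation: `step_fn(token_id)` returns next-token scores."""
    scores, words = first_scores, []
    for _ in range(n_words):
        idx = pick_next(scores)
        words.append(itos[idx])
        scores = step_fn(idx)
    return words

# Toy "model": <unk> always scores highest, then the token after the current one.
itos = ["<unk>", "good", "movie", "."]
def toy_step(idx):
    nxt = idx % 3 + 1  # 1 -> 2 -> 3 -> 1 ...
    return [5.0 if i == 0 else (1.0 if i == nxt else 0.0) for i in range(4)]

print(generate(toy_step, [0.0, 1.0, 0.5, 0.2], 5, itos))
# ['good', 'movie', '.', 'good', 'movie']
```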

## Sentiment analysis
### Fine tune the language model to do classification

Of course, we need the same vocabulary as the language model, so we reload the pickled `TEXT` field.

In [164]:
TEXT = pickle.load(open(f'{PATH}models/TEXT.pk1', 'rb'))

In [165]:
# sequential=False tells torchtext not to tokenize,
# since each label is a single value: 'pos' or 'neg'
IMDB_LABEL = data.Field(sequential=False)

Here, we use the IMDb dataset built into torchtext. Refer to arxiv.ipynb to see how a torchtext dataset is defined.

In [167]:
splits = torchtext.datasets.IMDB.splits(text_field=TEXT, label_field=IMDB_LABEL, root='data/')

downloading aclImdb_v1.tar.gz


In [180]:
test = splits[0].examples[4]

In [181]:
test.label, ' '.join(test.text[:50])

('pos',
 'when you come across a gem of a movie like this , you realize why the \' 80s were the greatest decade to live thru . the rock music ruled , & so did movies ... especially horror movies . filmmakers knew how to entertain us , & " trick')

Create a fastai ModelData object from torchtext splits

In [182]:
md2 = TextData.from_splits(PATH, splits=splits, bs=bs)

In [183]:
m3 = md2.get_model(opt_fn, max_sl=1500, bptt=bptt, emb_sz=em_sz, n_hid=nh, n_layers=nl,
                   dropout=0.1, dropouti=0.4, wdrop=0.5, dropoute=0.05, dropouth=0.3)

In [184]:
m3.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)

Load the pretrained *encoder*. We use differential learning rates as we are fine-tuning a model.

In [185]:
m3.load_encoder('enc2')

In [186]:
m3.clip = 25.
lrs = np.array([1e-4, 1e-4, 1e-4, 1e-3, 1e-2])
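
With differential learning rates, each layer group gets its own rate: the pretrained lower layers, which already encode general knowledge about language, move slowly, while the freshly initialized classifier head moves fastest. `lrs` has five entries, presumably one per layer group (the embedding, the three LSTM layers and the classifier head); that grouping is an assumption, and the group names and the `trainable_lrs` helper below are hypothetical. The sketch also mimics the idea behind `freeze_to(-1)`: only the last group trains.

```python
lrs = [1e-4, 1e-4, 1e-4, 1e-3, 1e-2]
# Hypothetical names for the five layer groups, lowest to highest.
groups = ["embedding", "lstm_1", "lstm_2", "lstm_3", "classifier_head"]

def trainable_lrs(groups, lrs, freeze_to=None):
    """Map each layer group to its learning rate.

    If freeze_to is given, groups before that index are frozen (lr None);
    negative indices count from the end, so -1 leaves only the last group.
    """
    first_trainable = 0 if freeze_to is None else freeze_to % len(groups)
    return {g: (lr if i >= first_trainable else None)
            for i, (g, lr) in enumerate(zip(groups, lrs))}

print(trainable_lrs(groups, [lr / 2 for lr in lrs], freeze_to=-1))  # only the head, at 5e-3
print(trainable_lrs(groups, lrs))                                   # after unfreeze(): all groups
```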

In [187]:
m3.freeze_to(-1)
m3.fit(lrs/2, 1, metrics=[accuracy])

epoch      trn_loss   val_loss   accuracy                    
    0      0.661577   0.487363   0.761779  



[array([0.48736]), 0.7617787532059408]

In [188]:
m3.unfreeze()

In [189]:
m3.fit(lrs, 1, metrics=[accuracy], cycle_len=1)

epoch      trn_loss   val_loss   accuracy                    
    0      0.464077   0.315161   0.876894  



[array([0.31516]), 0.8768944070587855]

In [191]:
m3.fit(lrs, 7, metrics=[accuracy], cycle_len=2, cycle_save_name='imdb')

epoch      trn_loss   val_loss   accuracy                    
    0      0.425103   0.260042   0.897343  
    1      0.391833   0.274921   0.89697                     
    2      0.368948   0.286929   0.892561                    
    3      0.357182   0.261986   0.901059                    
    4      0.341577   0.265945   0.901818                    
    5      0.350868   0.253552   0.905013                    
    6      0.333761   0.255058   0.90734                     
    7      0.327202   0.255533   0.908164                    
    8      0.31469    0.287294   0.903437                    
    9      0.296798   0.27291    0.906724                    
    10     0.306673   0.27883    0.913195                    
    11     0.283881   0.284334   0.907709                    
    12     0.307085   0.252115   0.915737                    
    13     0.276472   0.271925   0.91426                     



[array([0.27192]), 0.9142602107326157]

### Accuracy

In [224]:
preds, targs = m3.predict_with_targs()

In [226]:
preds = np.argmax(preds, axis=1)

In [227]:
acc = (preds==targs).mean()

In [228]:
acc

0.91888
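
The last three cells boil down to taking the highest-scoring class for each review and comparing it with the label. A plain-Python equivalent of `np.argmax(preds, axis=1)` followed by `(preds==targs).mean()`, using made-up scores; `accuracy_from_scores` is a hypothetical helper.

```python
def accuracy_from_scores(scores, targets):
    """Fraction of rows whose highest-scoring class equals the target."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

scores = [[-0.1, -2.3], [-1.9, -0.2], [-0.4, -1.1], [-3.0, -0.05]]  # e.g. log-probs
targets = [0, 1, 1, 1]
print(accuracy_from_scores(scores, targets))  # 0.75
```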