## Sentiment analysis on the IMDB dataset

http://nlpprogress.com/english/sentiment_analysis.html

Top: [XLNet (Yang et al., 2019)](https://arxiv.org/pdf/1906.08237.pdf), accuracy: 96.21

But can we get near that with the "A Simple but Tough-to-Beat Baseline for Sentence Embeddings" encoder (Arora et al., 2017)?

In [1]:
import itertools
import os

from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report

from tqdm.notebook import tqdm

from encoder import build_from_fasttext_bin
from nn import train_w2v, train_nn, fasttext, load_model
from utils import read_imdb, preprocess_sentence


Download the [IMDB dataset](https://ai.stanford.edu/~ang/papers/acl11-WordVectorsSentimentAnalysis.pdf) (Maas et al., 2011) and encode the labels:

In [2]:
X_train, y_train = read_imdb(subset='train')
X_test, y_test = read_imdb(subset='test')

label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(y_train)
y_test = label_encoder.transform(y_test)

Build a corpus for word2vec training and pre-process it with the `textacy` library:
  - normalize the Unicode charset
  - remove accents (résumé -> resume)
  - unpack contractions (he's -> he is)
  - remove emojis, hashtags, URLs, emails, etc.
  - remove punctuation marks
  - strip whitespace
  - lowercase
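
The steps above can be approximated with the standard library alone. This is a minimal sketch, not the `preprocess_sentence` helper from `utils` (its implementation is not shown here) and not `textacy`'s API; the function name `simple_preprocess`, the tiny contraction table, and the regular expressions are all illustrative assumptions.

```python
import re
import string
import unicodedata

# Tiny, illustrative contraction table (an assumption; real tables are much larger).
CONTRACTIONS = {
    "he's": "he is",
    "she's": "she is",
    "it's": "it is",
    "don't": "do not",
    "can't": "cannot",
    "i'm": "i am",
}

URL_RE = re.compile(r"https?://\S+|www\.\S+")
EMAIL_RE = re.compile(r"\S+@\S+\.\S+")
HASHTAG_RE = re.compile(r"#\w+")
PUNCT_TABLE = str.maketrans({ch: " " for ch in string.punctuation})


def simple_preprocess(text: str) -> str:
    """Approximate the listed steps; hypothetical stand-in for preprocess_sentence."""
    # Normalize Unicode and drop combining accents (résumé -> resume).
    text = unicodedata.normalize("NFKD", text)
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    text = text.lower()
    # Remove URLs, emails and hashtags before touching punctuation.
    text = URL_RE.sub(" ", text)
    text = EMAIL_RE.sub(" ", text)
    text = HASHTAG_RE.sub(" ", text)
    # Unpack the few contractions we know about.
    for short, full in CONTRACTIONS.items():
        text = re.sub(r"\b" + re.escape(short) + r"\b", full, text)
    # Drop remaining non-ASCII symbols (e.g. emojis), then punctuation.
    text = text.encode("ascii", "ignore").decode("ascii")
    text = text.translate(PUNCT_TABLE)
    # Collapse runs of whitespace and strip the ends.
    return " ".join(text.split())
```

For example, `simple_preprocess("He's reading my Résumé!!")` yields `"he is reading my resume"`. Contractions are expanded before punctuation is removed, because the apostrophe is what identifies them.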

Train a word2vec skip-gram model as follows:
  - dim = 200
  - lr = relatively low (0.015 in the cell below)
  - epochs = 15 (but should probably be ~ 25); the cell below uses 20
  - ws = 5 (but should probably be ~ 7); the cell below uses 7
  - sub-word information (minn = 3, maxn = 6)
  
Alternatively, we can use a [pre-built model](https://fasttext.cc/docs/en/pretrained-vectors.html).
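
To make the sub-word setting in the list above concrete (minn = 3, maxn = 6), here is a small sketch of how character n-grams are extracted in fastText-style models: the word is wrapped in `<` and `>` boundary markers and every substring of length minn to maxn is collected. The helper name `char_ngrams` is hypothetical, and the sketch leaves out what fastText does afterwards (hashing the n-grams into buckets and adding the whole word as its own unit).

```python
from typing import List


def char_ngrams(word: str, minn: int = 3, maxn: int = 6) -> List[str]:
    """Return the character n-grams of `word`, including boundary markers."""
    # Boundary markers let the model tell prefixes and suffixes apart.
    wrapped = f"<{word}>"
    ngrams = []
    for n in range(minn, maxn + 1):
        for start in range(len(wrapped) - n + 1):
            ngrams.append(wrapped[start:start + n])
    return ngrams


# The 3-grams of "where" are: <wh, whe, her, ere, re>
print(char_ngrams("where", minn=3, maxn=3))
```

A word vector is then built from the vectors of its n-grams, which is why such a model can still produce a vector for misspelled or unseen words in reviews.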

In [3]:
W2V_PREBUILT_MODEL = 'cc.en.300.bin'
W2V_MODEL = 'model.bin' # W2V_PREBUILT_MODEL

if W2V_MODEL == W2V_PREBUILT_MODEL:
    ! [[ ! -f {W2V_PREBUILT_MODEL} ]] && wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/{W2V_PREBUILT_MODEL}.gz
    ! [[ ! -f {W2V_PREBUILT_MODEL} ]] && gzip -d {W2V_PREBUILT_MODEL}.gz
    ! ls -lh {W2V_PREBUILT_MODEL}

if not os.path.isfile(W2V_MODEL):
    # build w2v corpus
    corpus = []
    raw_sentences, _ = read_imdb(subset=None, with_label=False)
    for raw_sentence in tqdm(raw_sentences):
        sent = preprocess_sentence(raw_sentence)
        corpus.append(sent)

    # train word2vec
    model = train_w2v(corpus,
                      model='skipgram',
                      dim=200,
                      min_count=20,
                      lr=0.015,
                      epoch=20,
                      ws=7,
                      minn=3,
                      maxn=6)
    # save model
    model.save_model(W2V_MODEL)

else:  # load the existing model (trained earlier or pre-built)
    model = fasttext.load_model(W2V_MODEL)




Wrap the word2vec model in the "Simple but Tough-to-Beat" sentence encoder:
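
For reference, the baseline from Arora et al. (2017) has two steps: average the word vectors of a sentence with weights a / (a + p(w)), where p(w) is the unigram probability of the word, then subtract from every sentence vector its projection onto the first singular vector of the sentence matrix. The sketch below illustrates that idea with NumPy; it is not the code of the `encoder` module (which is not shown here), and the function name `sif_embeddings`, its arguments, and the default `a=1e-3` are assumptions.

```python
import numpy as np


def sif_embeddings(sentences, word_vectors, word_prob, a=1e-3):
    """Encode tokenized sentences with SIF weighting and common-component removal.

    sentences:    list of token lists
    word_vectors: dict token -> 1-D numpy array (all of the same length)
    word_prob:    dict token -> unigram probability p(w)
    """
    dim = len(next(iter(word_vectors.values())))
    emb = np.zeros((len(sentences), dim))
    for i, tokens in enumerate(sentences):
        known = [t for t in tokens if t in word_vectors]
        if not known:
            continue  # keep an all-zero row for sentences with no known tokens
        # Frequent words get small weights: a / (a + p(w)).
        weights = np.array([a / (a + word_prob.get(t, 0.0)) for t in known])
        vectors = np.stack([word_vectors[t] for t in known])
        emb[i] = weights @ vectors / len(known)
    # Remove the projection onto the first right singular vector.
    _, _, vt = np.linalg.svd(emb, full_matrices=False)
    pc = vt[0]
    return emb - np.outer(emb @ pc, pc)
```

In a train/test setting, the singular vector would normally be estimated on the training sentences and reused for the test sentences, which matches the `fit_transform` / `transform` split used in the cells that follow.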

In [4]:
sentence_encoder = build_from_fasttext_bin(model, preprocessor=preprocess_sentence, weighted=True)

del model # free some memory !

In [6]:
X_train = sentence_encoder.fit_transform(X_train)
print('X_train.shape = ', X_train.shape)

X_train.shape =  (25000, 200)


In [7]:
X_test = sentence_encoder.transform(X_test)
print('X_test.shape = ', X_test.shape)

X_test.shape =  (25000, 200)


Now we can train a binary classification net:
  - 1 hidden layer (128).
  - dropout ~ [0.2 - 0.5].
  - binary logloss.
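
As an illustration of what such a network computes, here is a NumPy sketch of a single forward pass: one ReLU hidden layer of 128 units, inverted dropout, a sigmoid output, and the binary log loss. It is not the `train_nn` implementation (which is not shown here); the random weights, the toy batch, and the function names are assumptions made for the example.

```python
import numpy as np


def forward(X, params, dropout=0.0, rng=None):
    """One hidden ReLU layer followed by a single sigmoid output unit."""
    W1, b1, W2, b2 = params
    hidden = np.maximum(0.0, X @ W1 + b1)
    if dropout > 0.0 and rng is not None:
        # Inverted dropout: zero units at random and rescale the survivors.
        mask = rng.random(hidden.shape) >= dropout
        hidden = hidden * mask / (1.0 - dropout)
    logits = hidden @ W2 + b2
    return 1.0 / (1.0 + np.exp(-logits))


def binary_log_loss(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy, clipped to avoid log(0)."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)
    return float(-np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob)))


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 200))  # toy batch shaped like the 200-d sentence vectors
y = np.array([0, 1, 1, 0])     # toy labels
params = (
    rng.normal(scale=0.05, size=(200, 128)), np.zeros(128),  # hidden layer
    rng.normal(scale=0.05, size=(128,)), 0.0,                # output unit
)
probs = forward(X, params, dropout=0.4, rng=rng)  # training-time pass with dropout
loss = binary_log_loss(y, probs)
```

At prediction time dropout is switched off (`dropout=0.0`), so the same input always gives the same probability.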

In [8]:
MODEL_PT = 'model.h5'

model = train_nn(
    X_train,
    y_train,
    hidden_layers=(128,),
    activation='relu',
    dropout=0.4,
    epochs=20,
    batch_size=32,
    validation_split=None,
    validation_data=(X_test, y_test),
    patience=4,
    shuffle=True,
    optimizer='adam',
    pt=MODEL_PT,
)

Train on 25000 samples, validate on 25000 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
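
The log ends at epoch 15 although `epochs=20` was requested, which is consistent with early stopping under `patience=4`. How `train_nn` implements this is not shown; the sketch below assumes the validation loss is monitored and that training stops after `patience` consecutive epochs without improvement. The function name and the loss values in the usage line are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=4):
    """Return the 1-based epoch at which training would stop.

    Training stops once the monitored loss has failed to improve for
    `patience` consecutive epochs; otherwise it runs to the last epoch.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses)


# Made-up losses: the best value is at epoch 3, so training stops at epoch 7.
stop = early_stopping_epoch([0.5, 0.4, 0.3, 0.35, 0.36, 0.37, 0.38, 0.2])
```

Under that assumption, a run that stops at epoch 15 had its best validation loss at epoch 11.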


89% accuracy with a pretty simple encoder. That's nice!

In [9]:
model = load_model(MODEL_PT)
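# predict_classes is gone from recent Keras releases; threshold the output instead
# (assumes a single sigmoid output unit, as implied by the binary log loss)
probs = model.predict(X_test, batch_size=32)
preds = (probs > 0.5).astype('int32').reshape(-1)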

report = classification_report(y_test, preds, target_names=label_encoder.classes_)
print(report)

              precision    recall  f1-score   support

         neg       0.89      0.90      0.89     12500
         pos       0.89      0.89      0.89     12500

    accuracy                           0.89     25000
   macro avg       0.89      0.89      0.89     25000
weighted avg       0.89      0.89      0.89     25000
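
For readers unfamiliar with the columns of the report, the sketch below shows how precision, recall and F1 are derived for one class from true and predicted labels. It is a plain-Python illustration, not scikit-learn's implementation; the function name `binary_report` and the toy labels are assumptions.

```python
def binary_report(y_true, y_pred, positive=1):
    """Return (precision, recall, f1) for the class `positive`."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    # Guard against empty denominators (no predicted or no actual positives).
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


# Toy labels: 2 true positives, 1 false positive, 1 false negative.
precision, recall, f1 = binary_report([1, 1, 1, 0, 0], [1, 1, 0, 1, 0])
```

The `macro avg` row of the report is the unweighted mean of the per-class values, and `weighted avg` weights each class by its support; with 12500 reviews per class the two coincide.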

