# Sentiment Analysis on IMDB Rating
Implemented using RNN

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import numpy as np
import matplotlib.pyplot as plt
import datasets
import re
from collections import Counter, OrderedDict

## Preprocessing

In [10]:
#Load test and train data
train_data, test_data = datasets.load_dataset('imdb', split=['train','test'])

#Split test data into train (20k) and validate (5k)
from torch.utils.data.dataset import random_split
torch.manual_seed(1)
train_data, valid_data = random_split(list(train_data),[20000,5000])

In [23]:
def tokenizer(text):
    # Remove HTML tags
    text = re.sub('<[^>]*>', '', text)
    # Extract emoticons
    emoticons = re.findall(r'(?::|;|=)(?:-)?(?:\)|\(|D|P)', text.lower())
    # Eliminate excessive whitespace and convert text to lowercase
    text = re.sub(r'[\W]+', ' ', text.lower())
    # Append emoticons at the end, removing the "nose" for standardization
    text = text + ' ' + ' '.join(emoticons).replace('-', '')
    #Split by white space
    tokenized = text.split()
    return tokenized

In [29]:
#How many unique tokens are in the text corpus?
token_counts = Counter()
for review in train_data:
    text = review['text']
    tokens = tokenizer(text)
    token_counts.update(tokens)
print('number of tokens', len(token_counts))

number of tokens 69006


In [47]:
#Map each token to a unique integer. In reverse frequency order. 0 and 1 placeholders
#Sort counter in reverse frequency order
sorted_dict = sorted(
    token_counts.items(), key=lambda x:x[1], reverse=True
)
ordered_dict = OrderedDict(sorted_dict)

#Word_encode contains word:encoding pairs
word_encode = {}
counter = 2
for word, freq in ordered_dict.items():
    word_encode[word] = counter
    counter += 1

#0 reserverd for padding. 1 reserved for unknown words
word_encode['<pad>'] = 0
word_encode['<unk>'] = 1

#Demonstrate encoding scheme works
def encode(text):
    encoding = []
    tokens = tokenizer(text)
    for token in tokens:
        encoding.append(word_encode.get(token,1))
    return encoding

#Testing
print(encode("Roses are red"))
print(encode("roSes ARE reD :)"))

[11558, 26, 736]
[11558, 26, 736, 2152]


In [81]:
def build_dataloader(batch):
    label_list, text_list, lengths = [], [], []
    for review in batch:
        text = review['text']
        label = review['label']
        label_list.append(label)
        processed_text = torch.tensor(encode(text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
    label_list = torch.tensor(label_list)
    lengths = torch.tensor(lengths)
    #Ensure all sequence in minibatch have same length to store efficiently as tensor
    padded_text_list = nn.utils.rnn.pad_sequence(text_list, batch_first=True)
    return padded_text_list, label_list, lengths

In [82]:
#Load a small sample with batchsize of 4
from torch.utils.data import DataLoader
dataloader = DataLoader(train_data,batch_size=4,shuffle=False, collate_fn=build_dataloader)
text_batch, label_batch, length_batch = next(iter(dataloader))

In [83]:
text_batch

tensor([[   35,  1739,     7,   449,   721,     6,   301,     4,   787,     9,
             4,    18,    44,     2,  1705,  2460,   186,    25,     7,    24,
           100,  1874,  1739,    25,     7, 34414,  3568,  1103,  7517,   787,
             5,     2,  4991, 12401,    36,     7,   148,   111,   939,     6,
         11598,     2,   172,   135,    62,    25,  3199,  1602,     3,   928,
          1500,     9,     6,  4601,     2,   155,    36,    14,   274,     4,
         42944,     9,  4991,     3,    14, 10296,    34,  3568,     8,    51,
           148,    30,     2,    58,    16,    11,  1893,   125,     6,   420,
          1214,    27, 14542,   940,    11,     7,    29,   951,    18,    17,
         15994,   459,    34,  2480, 15211,  3713,     2,   840,  3200,     9,
          3568,    13,   107,     9,   175,    94,    25,    51, 10297,  1796,
            27,   712,    16,     2,   220,    17,     4,    54,   722,   238,
           395,     2,   787,    32,    27,  5236,  