# Review Classification

## Batching and Model Training

### Import Libraries

In [1]:
import pandas as pd # Loading data
import numpy as np
import warnings
from sklearn.model_selection import train_test_split # train test splits

warnings.filterwarnings('ignore')

### Data Loading and Processing

We will first do all the necessary pre-processing before starting to create batches and training the model. All the steps are explained in the notebook named `Text Cleaning.ipynb`

In [2]:
# Read dataset
data = pd.read_csv("Reviews.csv")
# Drop unnecesary columns and duplicates
new_data = data.drop_duplicates(subset=['UserId', 'ProfileName', 'Time', 'Text'])
# Get useful columns
useful_data = new_data[['Text', 'Score']]
# Calculate length of each sentence without tokenizer
useful_data['sudo_length'] = useful_data.Text.str.split().str.len()
# Filter examples by length
useful_data = useful_data[(useful_data.sudo_length > 20) & (useful_data.sudo_length < 100)]
# Remove length column
useful_data = useful_data.drop(['sudo_length'], axis = 1)
# print 5 rows
useful_data.head()

Unnamed: 0,Text,Score
0,I have bought several of the Vitality canned d...,5
1,Product arrived labeled as Jumbo Salted Peanut...,1
2,This is a confection that has been around a fe...,4
3,If you are looking for the secret ingredient i...,2
4,Great taffy at a great price. There was a wid...,5


#### Create Train and Test sets

In [3]:
train, test = train_test_split(useful_data, test_size = 0.2)
train.to_csv("./train_test_data/train.csv", index=False)
test.to_csv("./train_test_data/test.csv", index=False)

In [4]:
import torchtext
from torchtext.data import TabularDataset, Field, BucketIterator
import spacy

In [5]:
tok = spacy.load('en_core_web_sm')

In [6]:
def tokenize_en(sent):
    return [item.text for item in tok.tokenizer(sent)]

In [7]:
sent = "hello their, why don't u have a seat?"
tokenize_en(sent)

['hello', 'their', ',', 'why', 'do', "n't", 'u', 'have', 'a', 'seat', '?']

In [8]:
SENT_FIELD = Field(sequential=True, tokenize=tokenize_en)
LABEL_FIELD = Field(sequential=False, use_vocab=False, pad_token=None, unk_token=None)

data_fields = [
    ('Text', SENT_FIELD),
    ('Score', LABEL_FIELD)
]

In [9]:
train, val = TabularDataset.splits(
    path='./train_test_data',
    train='train.csv',
    validation = 'test.csv',
    format='csv',
    skip_header=True,
    fields=data_fields
)

In [10]:
SENT_FIELD.build_vocab(train)

In [11]:
BATCH_SIZE = 32
dev = 'cuda'

train_iter, val_iter = BucketIterator.splits(
    (train, val), 
    batch_sizes=(BATCH_SIZE, BATCH_SIZE), 
    sort_key=lambda x: len(x.Text), 
    shuffle=True, 
    sort_within_batch=True,
    repeat=False,
    device = dev
)

In [12]:
import torch
import torch.nn as nn
import torch.optim as opt
from sklearn.metrics import confusion_matrix

In [13]:
class ClassificationMetrics:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.classes = list(range(num_classes))
        self.epsilon = 1e-12
        self.cmatrix = np.zeros((num_classes, num_classes), dtype = np.int64) + self.epsilon
        
        self.total_correct = 0
        self.total_examples = 0
        
    def update(self, pred, truth):
        pred = pred.cpu()
        truth = truth.cpu()
        
        _, idx = pred.topk(1)
        truth = truth.view(-1, 1)
        
        self.total_examples += len(truth)
        self.total_correct += sum(idx == truth).item()
        
        val = confusion_matrix(truth, idx, labels=self.classes)
        
        self.cmatrix = self.cmatrix + val
        
        
    def precision_score(self):
        scores = {}
        for i in range(self.num_classes):
            scores[i] = self.cmatrix[i, i] / (sum(self.cmatrix[:, i]) + self.epsilon)
        
        return scores
    
    def recall_score(self):
        scores = {}
        for i in range(self.num_classes):
            scores[i] = self.cmatrix[i, i] / (sum(self.cmatrix[i, :]) + self.epsilon)
        
        return scores
    
    def scores(self, return_type = 'f1'):
        pscores = self.precision_score()
        rscores = self.recall_score()
        scores = {}
        for i in range(self.num_classes):
            if(pscores[i] == 0 and rscores[i] == 0):
                scores[i] = 0
            else:
                scores[i] = 2 * ((pscores[i] * rscores[i]) / (pscores[i] + rscores[i])  + self.epsilon)
            
        if return_type == 'f1':
            return scores
        elif return_type == 'all':
            all_scores = list(zip(pscores.values(), rscores.values(), scores.values()))
            t = {}
            for i in range(self.num_classes):
                t[i] = all_scores[i]
                
            return t
        else:
            raise Exception("Invalid argument for return type")
            
    def accuracy_score(self):
        return self.total_correct / self.total_examples
    
    def reset(self):
        self.total_correct = 0
        self.total_examples = 0
        self.cmatrix = np.zeros((self.num_classes, self.num_classes))
            
    def print_report(self):
        all_scores = self.scores('all')
        print("{:^15}\t{:^15}\t{:^15}\t{:^15}".format("Class", "Precision", "Recall", "F1-score"))
        for c, values in all_scores.items():
            print("{:^15}\t{:^15.3f}\t{:^15.3f}\t{:^15.3f}".format(c, values[0], values[1], values[2]))
            
        print("Accuracy : {:.5f} %".format(self.accuracy_score()))

In [14]:
class BiDirectionalLstm(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_classes):
        super(BiDirectionalLstm, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.hidden_size = hidden_size
        self.cell = nn.LSTM(embedding_dim, hidden_size, bidirectional = True)
        self.linear = nn.Linear(hidden_size * 2, num_classes)
        self.soft = nn.Softmax(dim=1)
        
    def forward(self, x, hstate = None):
        if hstate is None:
            hstate = self.init_hidden(self.hidden_size, x.shape[-1])
            
        cell_out, _ = self.cell(self.embedding(x), hstate)
        
        temp = torch.cat([cell_out[-1, :, :self.hidden_size], cell_out[0, :, self.hidden_size:]], axis = -1)
        
        out = self.linear(temp)
        
        return self.soft(out)
            
    def init_hidden(self, hidden_size, bs):
        return (torch.zeros(2, bs, hidden_size, device=dev), torch.zeros(2, bs, hidden_size, device=dev))
    
    def load_embeddings(self, embeddings):
        self.embedding.weight.data.copy_(embeddings)

In [15]:
VOCAB_SIZE = len(SENT_FIELD.vocab)
EMBEDDING_DIM = 300
HIDDEN_SIZE = 128
NUM_CLASSES = 5

In [16]:
import pandas as pd
from sklearn.utils import class_weight

train = pd.read_csv("./train_test_data/train.csv")
weight_array = class_weight.compute_class_weight('balanced', sorted(train.Score.unique()), train.Score)
del(train)

In [17]:
net = BiDirectionalLstm(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, NUM_CLASSES)
#net.load_embeddings(SENT_FIELD.vocab.vectors)
net = net.cuda()
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(weight_array).to(dev))

optimizer = opt.Adam(net.parameters(), lr = 0.001)

In [18]:
train_metrics = ClassificationMetrics(NUM_CLASSES)
val_metrics = ClassificationMetrics(NUM_CLASSES)

In [19]:
from tqdm import tqdm 
N_EPOCH = 10

for epoch in range(N_EPOCH):
    train_metrics.reset()
    losses = []
    net.train()
    for batch in tqdm(train_iter):
        optimizer.zero_grad()
        labels = batch.Score - 1
        pred = net(batch.Text)

        loss = criterion(pred, labels)
        loss.backward()
        losses.append(loss.item())
        optimizer.step()

        train_metrics.update(pred, labels)
    
    print("Training Run\nEpoch : {} Loss : {:.5f}".format(epoch + 1, sum(losses) / len(losses)))
    train_metrics.print_report()
    
    val_metrics.reset()
    val_losses = []
    net.eval()
    for batch in tqdm(val_iter):
        labels = batch.Score - 1
        pred = net(batch.Text)
        loss = criterion(pred, labels)
        val_losses.append(loss.item())
        val_metrics.update(pred, labels)
        
    print("Validation Run\nEpoch : {} Loss : {:.5f}".format(epoch + 1, sum(val_losses) / len(val_losses)))
    val_metrics.print_report()

100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [07:02<00:00, 16.59it/s]
  0%|▎                                                                                | 6/1753 [00:00<00:30, 57.85it/s]

Training Run
Epoch : 1 Loss : 1.42975
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.405     	     0.600     	     0.483     
       1       	     0.193     	     0.249     	     0.217     
       2       	     0.221     	     0.365     	     0.275     
       3       	     0.247     	     0.343     	     0.287     
       4       	     0.880     	     0.675     	     0.764     
Accuracy : 0.58296 %


100%|██████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:17<00:00, 97.46it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 1 Loss : 1.39277
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.561     	     0.599     	     0.580     
       1       	     0.259     	     0.161     	     0.198     
       2       	     0.276     	     0.531     	     0.363     
       3       	     0.250     	     0.457     	     0.324     
       4       	     0.900     	     0.687     	     0.780     
Accuracy : 0.61243 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:40<00:00, 17.52it/s]
  1%|▍                                                                                | 9/1753 [00:00<00:20, 85.94it/s]

Training Run
Epoch : 2 Loss : 1.34704
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.555     	     0.682     	     0.612     
       1       	     0.281     	     0.394     	     0.328     
       2       	     0.327     	     0.444     	     0.377     
       3       	     0.300     	     0.457     	     0.362     
       4       	     0.910     	     0.727     	     0.808     
Accuracy : 0.65195 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:15<00:00, 115.71it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 2 Loss : 1.37190
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.523     	     0.704     	     0.600     
       1       	     0.279     	     0.248     	     0.262     
       2       	     0.289     	     0.472     	     0.358     
       3       	     0.260     	     0.459     	     0.332     
       4       	     0.912     	     0.681     	     0.780     
Accuracy : 0.61818 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:37<00:00, 17.65it/s]
  0%|▏                                                                                | 4/1753 [00:00<00:46, 37.84it/s]

Training Run
Epoch : 3 Loss : 1.29656
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.628     	     0.714     	     0.668     
       1       	     0.348     	     0.489     	     0.407     
       2       	     0.377     	     0.520     	     0.437     
       3       	     0.339     	     0.511     	     0.408     
       4       	     0.922     	     0.748     	     0.826     
Accuracy : 0.68591 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:15<00:00, 113.73it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 3 Loss : 1.36746
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.561     	     0.615     	     0.586     
       1       	     0.261     	     0.403     	     0.317     
       2       	     0.291     	     0.429     	     0.347     
       3       	     0.313     	     0.375     	     0.341     
       4       	     0.898     	     0.771     	     0.830     
Accuracy : 0.66373 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:37<00:00, 17.62it/s]
  0%|▎                                                                                | 8/1753 [00:00<00:23, 75.67it/s]

Training Run
Epoch : 4 Loss : 1.25596
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.670     	     0.748     	     0.707     
       1       	     0.409     	     0.571     	     0.477     
       2       	     0.427     	     0.577     	     0.491     
       3       	     0.374     	     0.553     	     0.446     
       4       	     0.928     	     0.764     	     0.838     
Accuracy : 0.71288 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:15<00:00, 113.73it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 4 Loss : 1.36372
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.532     	     0.683     	     0.598     
       1       	     0.251     	     0.381     	     0.303     
       2       	     0.315     	     0.395     	     0.350     
       3       	     0.274     	     0.473     	     0.347     
       4       	     0.915     	     0.690     	     0.787     
Accuracy : 0.62566 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:36<00:00, 17.68it/s]
  1%|▍                                                                               | 10/1753 [00:00<00:17, 97.35it/s]

Training Run
Epoch : 5 Loss : 1.22121
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.702     	     0.775     	     0.737     
       1       	     0.466     	     0.627     	     0.534     
       2       	     0.466     	     0.626     	     0.534     
       3       	     0.408     	     0.590     	     0.482     
       4       	     0.934     	     0.781     	     0.851     
Accuracy : 0.73762 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:14<00:00, 121.33it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 5 Loss : 1.36534
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.517     	     0.698     	     0.594     
       1       	     0.255     	     0.373     	     0.303     
       2       	     0.313     	     0.390     	     0.347     
       3       	     0.311     	     0.399     	     0.350     
       4       	     0.903     	     0.754     	     0.822     
Accuracy : 0.65917 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:35<00:00, 17.71it/s]
  1%|▍                                                                                | 9/1753 [00:00<00:19, 88.47it/s]

Training Run
Epoch : 6 Loss : 1.19644
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.730     	     0.792     	     0.760     
       1       	     0.502     	     0.669     	     0.574     
       2       	     0.494     	     0.663     	     0.566     
       3       	     0.429     	     0.617     	     0.506     
       4       	     0.938     	     0.790     	     0.858     
Accuracy : 0.75297 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:15<00:00, 115.72it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 6 Loss : 1.37255
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.602     	     0.603     	     0.603     
       1       	     0.271     	     0.335     	     0.300     
       2       	     0.313     	     0.424     	     0.360     
       3       	     0.292     	     0.459     	     0.357     
       4       	     0.899     	     0.749     	     0.817     
Accuracy : 0.65521 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:37<00:00, 17.65it/s]
  0%|▎                                                                                | 8/1753 [00:00<00:22, 77.88it/s]

Training Run
Epoch : 7 Loss : 1.17758
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.742     	     0.801     	     0.771     
       1       	     0.528     	     0.693     	     0.600     
       2       	     0.521     	     0.690     	     0.593     
       3       	     0.453     	     0.639     	     0.530     
       4       	     0.941     	     0.801     	     0.866     
Accuracy : 0.76725 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:15<00:00, 115.42it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 7 Loss : 1.36840
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.599     	     0.612     	     0.605     
       1       	     0.262     	     0.378     	     0.309     
       2       	     0.297     	     0.446     	     0.356     
       3       	     0.292     	     0.414     	     0.343     
       4       	     0.903     	     0.750     	     0.819     
Accuracy : 0.65453 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:36<00:00, 17.70it/s]
  1%|▍                                                                               | 10/1753 [00:00<00:17, 99.27it/s]

Training Run
Epoch : 8 Loss : 1.16141
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.758     	     0.813     	     0.784     
       1       	     0.556     	     0.718     	     0.627     
       2       	     0.540     	     0.714     	     0.615     
       3       	     0.469     	     0.654     	     0.546     
       4       	     0.944     	     0.809     	     0.871     
Accuracy : 0.77813 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:14<00:00, 116.88it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 8 Loss : 1.37161
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.594     	     0.623     	     0.608     
       1       	     0.256     	     0.379     	     0.305     
       2       	     0.312     	     0.387     	     0.345     
       3       	     0.280     	     0.476     	     0.353     
       4       	     0.906     	     0.721     	     0.803     
Accuracy : 0.64028 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:35<00:00, 17.72it/s]
  1%|▍                                                                               | 10/1753 [00:00<00:17, 99.27it/s]

Training Run
Epoch : 9 Loss : 1.14963
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.764     	     0.825     	     0.793     
       1       	     0.578     	     0.732     	     0.646     
       2       	     0.556     	     0.726     	     0.630     
       3       	     0.485     	     0.672     	     0.563     
       4       	     0.946     	     0.815     	     0.875     
Accuracy : 0.78689 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:14<00:00, 118.56it/s]
  0%|                                                                                         | 0/7011 [00:00<?, ?it/s]

Validation Run
Epoch : 9 Loss : 1.37103
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.579     	     0.625     	     0.601     
       1       	     0.257     	     0.378     	     0.306     
       2       	     0.313     	     0.402     	     0.352     
       3       	     0.289     	     0.449     	     0.352     
       4       	     0.903     	     0.737     	     0.812     
Accuracy : 0.64875 %


100%|██████████████████████████████████████████████████████████████████████████████| 7011/7011 [06:35<00:00, 17.71it/s]
  2%|█▏                                                                             | 27/1753 [00:00<00:16, 104.75it/s]

Training Run
Epoch : 10 Loss : 1.13802
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.777     	     0.834     	     0.804     
       1       	     0.592     	     0.753     	     0.663     
       2       	     0.570     	     0.743     	     0.645     
       3       	     0.499     	     0.680     	     0.575     
       4       	     0.948     	     0.822     	     0.880     
Accuracy : 0.79547 %


100%|█████████████████████████████████████████████████████████████████████████████| 1753/1753 [00:14<00:00, 120.36it/s]

Validation Run
Epoch : 10 Loss : 1.37275
     Class     	   Precision   	    Recall     	   F1-score    
       0       	     0.577     	     0.667     	     0.619     
       1       	     0.263     	     0.341     	     0.297     
       2       	     0.303     	     0.396     	     0.344     
       3       	     0.298     	     0.413     	     0.346     
       4       	     0.898     	     0.762     	     0.824     
Accuracy : 0.66223 %





In [33]:
pos_sent = "This food was awesome!"

pos_rev = SENT_FIELD.process([SENT_FIELD.preprocess(pos_sent)])

net.eval()
pred = net(pos_rev.to(dev))

print("Review Rating as predicted : {}".format(pred.topk(1)[1].item() + 1))

Review Rating as predicted : 5


In [34]:
pos_sent = "I just hated that food. Not recommended at all."

pos_rev = SENT_FIELD.process([SENT_FIELD.preprocess(pos_sent)])

net.eval()
pred = net(pos_rev.to(dev))

print("Review Rating as predicted : {}".format(pred.topk(1)[1].item() + 1))

Review Rating as predicted : 1
