In [1]:
from datasets import load_dataset
import numpy as np
dataset = load_dataset("rotten_tomatoes")
train_dataset = dataset['train'].to_pandas()
validation_dataset = dataset['validation'].to_pandas()
test_dataset = dataset['test'].to_pandas()
max_len=max(0,train_dataset["text"].apply(lambda x:len(x)).max())
max_len=max(max_len,validation_dataset["text"].apply(lambda x:len(x)).max())
max_len=max(max_len,test_dataset["text"].apply(lambda x:len(x)).max())
max_len+=5


In [2]:
import nltk

def prep_pretrained_embedding():
    def build_vocab(train_dataset):
        # Create set, unique words only
        vocab = set()
        train_dataset_pos = []
        
        # Loop thru each sentence in training dataset
        for sentence in train_dataset['text']:
            # Basic text processing
            
            # Case folding
            sentence = sentence.lower()
            
            # NLTK tokenizer does a good job at separating meaningful words + punctuations
            # Better than defining regex ourselves
            word_list = nltk.tokenize.word_tokenize(sentence)
            
            # # Further split words into separate words
            # # e.g., 'well-being' -> 'well', 'being'
            # # e.g., 'music/song' -> 'music', 'song'
            # split_word_list = []
            # for word in sentence_list:
            #     split_word_list.extend(word.replace('-', ' ').replace('/', ' ').split())
            
            # Dont remove all special characters, some are meaningful
            # Some words are surrounded by single/double quotes
            word_list = [word.strip("'\"") for word in word_list]
            
            # Add into set
            vocab.update(word_list)
            
            # Get pos tags
            # Also build POS tags
            pos_tags = nltk.pos_tag(word_list)
            train_dataset_pos.append(pos_tags)
            
        vocab.discard('')
        return vocab, train_dataset_pos

    vocab, train_dataset_pos = build_vocab(train_dataset)



    def load_glove_embeddings(path):
        glove_embeddings = {}
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.split()
                word = values[0]
                vector = np.asarray(values[1:], dtype='float64')
                glove_embeddings[word] = vector
                
        return glove_embeddings

    glove_embeddings = load_glove_embeddings('glove.6B.50d.txt')
    vocab_word_to_index = {word: idx for idx, word in enumerate(vocab)}

    def create_embedding_matrix(word_to_index, glove_embeddings):
        # Initialize embedding matrix with zeros
        # 50d
        embedding_matrix = np.zeros((len(vocab)+2, 50), dtype='float64')
        
        # Loop thru each word in vocab
        for word, idx in word_to_index.items():
            # Check if word exists in glove embeddings
            if word in glove_embeddings:
                # Copy glove embedding to embedding matrix
                embedding_matrix[idx] = glove_embeddings[word]
                # If OOV, assign None first
                
        return embedding_matrix

    embedding_matrix = create_embedding_matrix(vocab_word_to_index, glove_embeddings)
    #handle <unk>
    embedding_matrix[-2]=[ 0.01513297,  0.2400952 , -0.13676383,  0.13166569, -0.28283166,
        0.10421129,  0.39747017,  0.07944959,  0.29670785,  0.05400998,
        0.48425894,  0.26516231, -0.48021244, -0.25129253, -0.24367068,
       -0.24188322,  0.47579495, -0.2097357 , -0.02568224, -0.31143999,
       -0.3196337 ,  0.44878632, -0.07379564,  0.32765833, -0.49052161,
       -0.33455611, -0.34772199, -0.05043562, -0.0898296 ,  0.04898804,
        0.4993778 ,  0.04359836,  0.40077601, -0.31343237,  0.24126281,
       -0.4907152 , -0.20372591, -0.32123346, -0.39554707,  0.37386547,
        0.44720326,  0.45492689, -0.16420979,  0.42844699,  0.15748723,
       -0.23547929, -0.33962153,  0.04243802, -0.03647524, -0.0042893 ]
    
    return vocab_word_to_index,embedding_matrix


In [3]:
import pickle

def prep_embedding(handle_oov=False):
    if handle_oov:
        with open('embedding_matrix.pkl', 'rb') as file:  
            embedding_matrix = pickle.load(file)
            embedding_matrix = np.concatenate((embedding_matrix, np.zeros((1, 50))), axis=0)
        with open('vocab_word_to_index.pkl', 'rb') as file:  
            vocab_word_to_index = pickle.load(file)
            del vocab_word_to_index['<UNK>']
    else:
        vocab_word_to_index,embedding_matrix= prep_pretrained_embedding()
        embedding_matrix[-1]=np.zeros(50)
    # print(embedding_matrix)
    # print(embedding_matrix.shape)
    # print(len(vocab_word_to_index))
    # print(vocab_word_to_index)
    return vocab_word_to_index,embedding_matrix

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset

device=torch.device('cuda')

class CustomedDataset(Dataset):
    def __init__(self,sentences,labels,vocab_word_to_index):
        self.features=torch.tensor([[vocab_word_to_index[word] if word in vocab_word_to_index else len(vocab_word_to_index) for word in sentence]+[len(vocab_word_to_index)+1]*(max_len-len(sentence)) for sentence in sentences]).to(device)
        self.labels=torch.tensor(labels).to(device)
    
    def __len__(self):
        return self.features.shape[0]
    
    def __getitem__(self,idx):
        return self.features[idx],self.labels[idx]

def prep_dataloader(train_dataset,validation_dataset,test_dataset,batch_size,vocab_word_to_index):
    train_dataloader=DataLoader(CustomedDataset(train_dataset["text"],train_dataset["label"],vocab_word_to_index),batch_size=batch_size,shuffle=True)
    validation_dataloader=DataLoader(CustomedDataset(validation_dataset["text"],validation_dataset["label"],vocab_word_to_index),batch_size=batch_size)
    test_dataloader=DataLoader(CustomedDataset(test_dataset["text"],test_dataset["label"],vocab_word_to_index),batch_size=batch_size)
    return train_dataloader,validation_dataloader,test_dataloader
    

In [5]:


class CNNTextClassifier(nn.Module):
    def __init__(self, embedding_matrix, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        embedding_matrix=torch.tensor(embedding_matrix,dtype=torch.float32)
        # print(embedding_matrix)
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
        # self.embedding = embedding_matrix
        # print(self.embedding.shape)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, n_filters, (fs, embedding_matrix.shape[1])) for fs in filter_sizes]
        )
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.softmax=nn.Softmax(-1)

    def forward(self, sentences):
        # text = [batch size, sent len]
        # embedded=[[self.embedding[idx] for idx in sentence] for sentence in sentences]
        embedded = self.embedding(sentences)  # embedded = [batch size, sent len, emb dim]
        embedded = embedded.unsqueeze(1)  # embedded = [batch size, 1, sent len, emb dim]
        # print(embedded)
        # print(embedded.shape)
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]  # conv_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]  # pooled_n = [batch size, n_filters]
        cat = self.dropout(torch.cat(pooled, dim=1))  # cat = [batch size, n_filters * len(filter_sizes)]
        
        return self.softmax(self.fc(cat))


In [6]:


def train(model,optimizer,criterion,num_epoch,train_dataloader,validation_dataloader):
    from tqdm import tqdm
    model.to(device)
    for _ in range(num_epoch):
        acc_loss=0
        model.train()
        process_bar=tqdm(train_dataloader,desc=f"Epoch {_}/{num_epoch}",leave=True)
        for features,labels in process_bar:
            
            pred=model(features)
            
            optimizer.zero_grad()
            loss=criterion(pred,labels)
            loss.backward()
            optimizer.step()
            
            acc_loss+=loss.item()
            process_bar.set_postfix_str(f"Mean loss: {acc_loss/(process_bar.n+1)}")
        
        print("Train loss:",acc_loss/process_bar.n)
        
        acc_loss=0
        model.eval()
        with torch.no_grad():
            acc_loss=0
            process_bar=tqdm(validation_dataloader,desc="Validating",leave=True)
            for features,labels in process_bar:
                
                pred=model(features)
                
                loss=criterion(pred,labels)
                
                acc_loss+=loss.item()
                process_bar.set_postfix_str(f"Mean loss: {acc_loss/(process_bar.n+1)}")
                
            print("Validation loss:",acc_loss/process_bar.n)



In [7]:
def work_flow(model_type,handle_oov,params):
    vocab_word_to_index,embedding_matrix=prep_embedding(handle_oov)
    # print(embedding_matrix)
    train_dataloader,validation_dataloader,test_dataloader=prep_dataloader(train_dataset,validation_dataset,test_dataset,params["batch_size"],vocab_word_to_index)

    if model_type=="CNN":
        model = CNNTextClassifier(embedding_matrix, params["n_filters"], params["filter_sizes"], params["output_dim"], params["dropout"])

    criterion=nn.CrossEntropyLoss()
    optimizer=torch.optim.Adam(model.parameters(),lr=params["lr"])
    
    train(model,optimizer,criterion,params["num_epoch"],train_dataloader,validation_dataloader)
    
    model.eval()
    test_acc=0
    tot_samples=0
    with torch.no_grad():
        for features,labels in test_dataloader:
            pred_labels=model(features)
            # print(labels.shape,pred_labels.shape)
            test_acc+=(labels==pred_labels.argmax(dim=1)).sum().item()
            tot_samples+=labels.shape[0]
        print(f"Test acc is:{test_acc/tot_samples*100}%")

params={"batch_size":256,"n_filters":128,"filter_sizes":[1,2,3,5],"output_dim":2,"dropout":0.2,"lr":0.001,"num_epoch":30}
work_flow("CNN",True,params)

Epoch 0/30: 100%|██████████| 34/34 [00:01<00:00, 24.56it/s, Mean loss: 0.694391527596642] 


Train loss: 0.694391527596642


Validating: 100%|██████████| 5/5 [00:00<00:00, 84.03it/s, Mean loss: 3.3717047572135925]


Validation loss: 0.6743409514427186


Epoch 1/30: 100%|██████████| 34/34 [00:00<00:00, 39.48it/s, Mean loss: 0.6960725964921893]


Train loss: 0.6755998730659485


Validating: 100%|██████████| 5/5 [00:00<00:00, 88.17it/s, Mean loss: 3.402778387069702]


Validation loss: 0.6805556774139404


Epoch 2/30: 100%|██████████| 34/34 [00:00<00:00, 39.98it/s, Mean loss: 0.681260220932238] 


Train loss: 0.6612231556107017


Validating: 100%|██████████| 5/5 [00:00<00:00, 89.98it/s, Mean loss: 3.3044593334198]


Validation loss: 0.6608918666839599


Epoch 3/30: 100%|██████████| 34/34 [00:00<00:00, 39.78it/s, Mean loss: 0.7122497097138436]


Train loss: 0.6494041470920339


Validating: 100%|██████████| 5/5 [00:00<00:00, 85.53it/s, Mean loss: 3.3273062705993652]


Validation loss: 0.665461254119873


Epoch 4/30: 100%|██████████| 34/34 [00:00<00:00, 40.03it/s, Mean loss: 0.695120140429466] 


Train loss: 0.6337860103915719


Validating: 100%|██████████| 5/5 [00:00<00:00, 78.09it/s, Mean loss: 3.1803221702575684]


Validation loss: 0.6360644340515137


Epoch 5/30: 100%|██████████| 34/34 [00:00<00:00, 40.03it/s, Mean loss: 0.6458510900988723]


Train loss: 0.6268554698018467


Validating: 100%|██████████| 5/5 [00:00<00:00, 90.87it/s, Mean loss: 3.123402953147888]


Validation loss: 0.6246805906295776


Epoch 6/30: 100%|██████████| 34/34 [00:00<00:00, 39.61it/s, Mean loss: 0.6093016775215373]


Train loss: 0.6093016775215373


Validating: 100%|██████████| 5/5 [00:00<00:00, 87.43it/s, Mean loss: 3.1291255354881287]


Validation loss: 0.6258251070976257


Epoch 7/30: 100%|██████████| 34/34 [00:00<00:00, 39.82it/s, Mean loss: 0.6641861450287604]


Train loss: 0.6055814851732815


Validating: 100%|██████████| 5/5 [00:00<00:00, 88.08it/s, Mean loss: 3.1332091093063354]


Validation loss: 0.626641821861267


Epoch 8/30: 100%|██████████| 34/34 [00:00<00:00, 39.64it/s, Mean loss: 0.6084854602813721]


Train loss: 0.5905888290966258


Validating: 100%|██████████| 5/5 [00:00<00:00, 87.98it/s, Mean loss: 3.162788450717926]


Validation loss: 0.6325576901435852


Epoch 9/30: 100%|██████████| 34/34 [00:00<00:00, 39.70it/s, Mean loss: 0.6353477874109822]


Train loss: 0.5792876885217779


Validating: 100%|██████████| 5/5 [00:00<00:00, 80.07it/s, Mean loss: 3.0534550547599792]


Validation loss: 0.6106910109519958


Epoch 10/30: 100%|██████████| 34/34 [00:00<00:00, 39.11it/s, Mean loss: 0.5863286870898623]


Train loss: 0.5690837257048663


Validating: 100%|██████████| 5/5 [00:00<00:00, 79.59it/s, Mean loss: 3.1180153489112854]


Validation loss: 0.6236030697822571


Epoch 11/30: 100%|██████████| 34/34 [00:00<00:00, 39.90it/s, Mean loss: 0.5693994497551638]


Train loss: 0.5693994497551638


Validating: 100%|██████████| 5/5 [00:00<00:00, 82.85it/s, Mean loss: 3.0428932905197144]


Validation loss: 0.6085786581039428


Epoch 12/30: 100%|██████████| 34/34 [00:00<00:00, 39.34it/s, Mean loss: 0.5744770945924701]


Train loss: 0.5575807094573975


Validating: 100%|██████████| 5/5 [00:00<00:00, 82.93it/s, Mean loss: 3.3500396609306335]


Validation loss: 0.6700079321861268


Epoch 13/30: 100%|██████████| 34/34 [00:00<00:00, 39.81it/s, Mean loss: 0.5510136765592238]


Train loss: 0.5510136765592238


Validating: 100%|██████████| 5/5 [00:00<00:00, 86.30it/s, Mean loss: 3.010809063911438]


Validation loss: 0.6021618127822876


Epoch 14/30: 100%|██████████| 34/34 [00:00<00:00, 39.44it/s, Mean loss: 0.5612619121869405]


Train loss: 0.5447542088873246


Validating: 100%|██████████| 5/5 [00:00<00:00, 89.46it/s, Mean loss: 2.965977132320404]


Validation loss: 0.5931954264640809


Epoch 15/30: 100%|██████████| 34/34 [00:00<00:00, 39.68it/s, Mean loss: 0.5564098683270541]


Train loss: 0.5400448721997878


Validating: 100%|██████████| 5/5 [00:00<00:00, 87.61it/s, Mean loss: 2.9776124358177185]


Validation loss: 0.5955224871635437


Epoch 16/30: 100%|██████████| 34/34 [00:00<00:00, 38.88it/s, Mean loss: 0.5466898878415426]


Train loss: 0.5306107734932619


Validating: 100%|██████████| 5/5 [00:00<00:00, 81.38it/s, Mean loss: 3.069404900074005]


Validation loss: 0.6138809800148011


Epoch 17/30: 100%|██████████| 34/34 [00:00<00:00, 38.80it/s, Mean loss: 0.5441874672066082]


Train loss: 0.5281819534652373


Validating: 100%|██████████| 5/5 [00:00<00:00, 82.08it/s, Mean loss: 2.9362486600875854]


Validation loss: 0.5872497320175171


Epoch 18/30: 100%|██████████| 34/34 [00:00<00:00, 38.29it/s, Mean loss: 0.5432988567785784]


Train loss: 0.5273194786380319


Validating: 100%|██████████| 5/5 [00:00<00:00, 82.18it/s, Mean loss: 2.9878429770469666]


Validation loss: 0.5975685954093933


Epoch 19/30: 100%|██████████| 34/34 [00:00<00:00, 39.87it/s, Mean loss: 0.5142481563722386]


Train loss: 0.5142481563722386


Validating: 100%|██████████| 5/5 [00:00<00:00, 88.04it/s, Mean loss: 2.9849127531051636]


Validation loss: 0.5969825506210327


Epoch 20/30: 100%|██████████| 34/34 [00:00<00:00, 39.84it/s, Mean loss: 0.5093022006399491]


Train loss: 0.5093022006399491


Validating: 100%|██████████| 5/5 [00:00<00:00, 83.01it/s, Mean loss: 2.9300588965415955]


Validation loss: 0.586011779308319


Epoch 21/30: 100%|██████████| 34/34 [00:00<00:00, 39.22it/s, Mean loss: 0.5208939620942781]


Train loss: 0.5055735514444464


Validating: 100%|██████████| 5/5 [00:00<00:00, 83.37it/s, Mean loss: 2.89328470826149]


Validation loss: 0.578656941652298


Epoch 22/30: 100%|██████████| 34/34 [00:00<00:00, 39.45it/s, Mean loss: 0.509218825136914] 


Train loss: 0.509218825136914


Validating: 100%|██████████| 5/5 [00:00<00:00, 87.33it/s, Mean loss: 3.0786243081092834]


Validation loss: 0.6157248616218567


Epoch 23/30: 100%|██████████| 34/34 [00:00<00:00, 39.57it/s, Mean loss: 0.49543416237129884]


Train loss: 0.49543416237129884


Validating: 100%|██████████| 5/5 [00:00<00:00, 84.03it/s, Mean loss: 3.0199597477912903]


Validation loss: 0.6039919495582581


Epoch 24/30: 100%|██████████| 34/34 [00:00<00:00, 39.83it/s, Mean loss: 0.49751485884189606]


Train loss: 0.49751485884189606


Validating: 100%|██████████| 5/5 [00:00<00:00, 83.27it/s, Mean loss: 2.8901124596595764]


Validation loss: 0.5780224919319152


Epoch 25/30: 100%|██████████| 34/34 [00:00<00:00, 39.44it/s, Mean loss: 0.49627482891082764]


Train loss: 0.49627482891082764


Validating: 100%|██████████| 5/5 [00:00<00:00, 87.85it/s, Mean loss: 2.880316436290741]


Validation loss: 0.5760632872581481


Epoch 26/30: 100%|██████████| 34/34 [00:00<00:00, 39.41it/s, Mean loss: 0.5014495967012463]


Train loss: 0.48670107915120964


Validating: 100%|██████████| 5/5 [00:00<00:00, 77.28it/s, Mean loss: 2.9862104654312134]


Validation loss: 0.5972420930862427


Epoch 27/30: 100%|██████████| 34/34 [00:00<00:00, 39.53it/s, Mean loss: 0.5276506937319233] 


Train loss: 0.4810932795791065


Validating: 100%|██████████| 5/5 [00:00<00:00, 87.64it/s, Mean loss: 2.947641611099243]


Validation loss: 0.5895283222198486


Epoch 28/30: 100%|██████████| 34/34 [00:00<00:00, 34.84it/s, Mean loss: 0.4939006232854092] 


Train loss: 0.47937413436525006


Validating: 100%|██████████| 5/5 [00:00<00:00, 89.19it/s, Mean loss: 3.1967454850673676]


Validation loss: 0.6393490970134735


Epoch 29/30: 100%|██████████| 34/34 [00:00<00:00, 39.56it/s, Mean loss: 0.49930560408216534]


Train loss: 0.48462014513857227


Validating: 100%|██████████| 5/5 [00:00<00:00, 83.23it/s, Mean loss: 2.9429466128349304]


Validation loss: 0.588589322566986
Test acc is:73.26454033771107%
