# Predicting the sentiment of IMDb movie reviews


In [50]:
import re
import torch
from torchtext.datasets import IMDB
from torch import nn
from torch.utils.data.dataset import random_split
from collections import Counter, OrderedDict
train_dataset = IMDB(split='train')
test_dataset = IMDB(split='test')

# We need to apply several preprocessing steps:
# 1. Split the training dataset into separate training and validation partitions.
# 2. Identify the unique words in the training dataset
# 3. Map each unique word to a unique integer and encode the review text into encoded integers (an index of each unique word)
# 4. Divide the dataset into mini-batches as input to the model

torch.manual_seed(1) # Set the random seed for reproducibility purposes
train_dataset,valid_dataset = random_split(list(train_dataset), [20000,5000])



# Tokenizer
def tokenizer(text):
    text = re.sub('<[^>]*>', '', text)
    emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)', text.lower())
    text = re.sub('[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
    tokenized = text.split()
    return tokenized

token_counts=Counter()

for label, line in train_dataset:
    tokens = tokenizer(line)
    token_counts.update(tokens)

print("vocab size: ", len(token_counts))

from torchtext.vocab import vocab 
sorted_by_freq_tuples = sorted(token_counts.items(), key=lambda x:x[1],reverse= True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)
vocab = vocab(ordered_dict)
vocab.insert_token("<pad>",0)
vocab.insert_token("<unk>",1)
vocab.set_default_index(1)

print([vocab[token] for token in ["this", "and", "an"]])

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: 1 if x == "pos" else 0
def collate_batch(batch):
    label_list, text_list, lengths = [], [], []
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
    label_list = torch.tensor(label_list)
    lengths = torch.tensor(lengths)
    padded_text_list = nn.utils.rnn.pad_sequence(
        text_list, batch_first=True)
    return padded_text_list, label_list, lengths

from torch.utils.data import DataLoader
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=False, collate_fn=collate_batch)


text_batch, label_batch, length_batch = next(iter(dataloader))

print(text_batch)
batch_size=32
train_dl= DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)
valid_dl= DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=collate_batch)
test_dl= DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_batch)
embedding=nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0)
text_encoded_inputs= torch.LongTensor([[1,2,4,5],[4,3,2,0]])
print ("text_encoded_ inputs: ", embedding(text_encoded_inputs))
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size):
        super().__init__()
        self.embedding=nn.Embedding(vocab_size, embed_dim,)   
        self.rnn= nn.LSTM(embed_dim,rnn_hidden_size, batch_first=True)
        self.fc1= nn.Linear(rnn_hidden_size, fc_hidden_size)
        self.relu= nn.ReLU()
        self.fc2= nn.Linear(fc_hidden_size,1)
        self.sigmoid= nn.Sigmoid()
        
    def forward(self, text, lengths):
        out = self.embedding(text)
        out = nn.utils.rnn.pack_padded_sequence(
            out,lengths.cpu().numpy(),
            enforce_sorted=False, batch_first=True)
        out , (hidden, cell)= self.rnn(out)
        out = hidden[-1,:,:]
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out
        
vocab_size = len(vocab)
embed_dim = 20
rnn_hidden_size = 64
fc_hidden_size = 64
torch.manual_seed(1)    
model = RNN(vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size)
model 
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
def train(dataloader):
    model.train()
    total_acc, total_loss = 0, 0
    for text_batch, label_batch, lengths in dataloader:
        optimizer.zero_grad()
        pred = model(text_batch, lengths)[:,0]
        print("pred: ", pred)
        loss= loss_fn(pred,label_batch)
        loss.backward()
        optimizer.step()
        total_acc += (
        (pred >= 0.5).float() == label_batch
        ).float().sum().item()
        total_loss += loss.item()*label_batch.size(0)
def evaluate(dataloader):
    model.eval()
    total_acc, total_loss = 0, 0
    with torch.no_grad():
        for text_batch, label_batch, lengths in dataloader:
            pred = model(text_batch, lengths)[:,0]
                  # Convert label_batch to float32
            label_batch = label_batch.float()
            pred = pred.float()
            loss = loss_fn(pred, label_batch)
            total_acc += (
                (pred >= 0.5).float() == label_batch
            ).float().sum().item()
            total_loss += loss.item()*label_batch.size(0)
    return total_acc/len(dataloader.dataset), total_loss/len(dataloader.dataset)
num_epochs = 10
torch.manual_seed(1)
for epoch in range(num_epochs):
    acc_train, loss_train =train(train_dl)
    acc_valid, loss_calid =evaluate(valid_dl)
    print(f"Epoch: {epoch}, Train Loss: {loss_train:.4f}, Train Acc: {acc_train:.4f}, Valid Loss: {loss_valid:.4f}, Valid Acc: {acc_valid:.4f}")

acc_test,_ = evaluate(test_dl)
print(f"Test Acc: {acc_test:.4f}")

vocab size:  69023
[11, 3, 35]
tensor([[   35,  1739,     7,   449,   721,     6,   301,     4,   787,     9,
             4,    18,    44,     2,  1705,  2460,   186,    25,     7,    24,
           100,  1874,  1739,    25,     7, 34415,  3568,  1103,  7517,   787,
             5,     2,  4991, 12401,    36,     7,   148,   111,   939,     6,
         11598,     2,   172,   135,    62,    25,  3199,  1602,     3,   928,
          1500,     9,     6,  4601,     2,   155,    36,    14,   274,     4,
         42945,     9,  4991,     3,    14, 10296,    34,  3568,     8,    51,
           148,    30,     2,    58,    16,    11,  1893,   125,     6,   420,
          1214,    27, 14542,   940,    11,     7,    29,   951,    18,    17,
         15994,   459,    34,  2480, 15211,  3713,     2,   840,  3200,     9,
          3568,    13,   107,     9,   175,    94,    25,    51, 10297,  1796,
            27,   712,    16,     2,   220,    17,     4,    54,   722,   238,
           395,     2

RuntimeError: Found dtype Long but expected Float