# Amazon Review Star Rating Prediction – RNN Model
Author: Aiden Devine  
Model: BiLSTM with GloVe embeddings (50d)  
Input: Tokenized review text  
Output: Predicted star rating (1–5)

In [None]:
import os, random, sys, copy, json
import torch, torch.nn as nn, numpy as np
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from nltk.tokenize import word_tokenize
from datasets import load_dataset, concatenate_datasets, load_from_disk
from datetime import datetime
from sklearn.metrics import classification_report
from collections import Counter, defaultdict
import re
import matplotlib.pyplot as plt

### Load Dataset

In [None]:
# Load the review dataset previously saved to disk with the `datasets` library.
reviews = load_from_disk("filetred_amazon_reviews")

# Quick inspection: dataset size, per-rating counts, two sample rows and the
# available column names.
num_reviews = len(reviews)
print(num_reviews)

rating_counts = Counter(reviews["rating"])
print(rating_counts)

first_review, second_review = reviews[0], reviews[1]
print(first_review, '\n')
print(second_review)

print(reviews.column_names)

### Load GloVe Embeddings

In [None]:
glove_file = 'glove.6B.50d.txt' # modify to appropriate path for your file system

# Maps each word to its embedding vector (numpy array), in file order.
embeddings_dict = {}

with open(glove_file, 'r', encoding='utf8') as f:
    for line in f:
        # Each line is "<word> <v1> <v2> ... <vd>", separated by single spaces.
        parts = line.strip().split(' ')
        if len(parts) < 2:
            # Blank or malformed line: no vector to read, skip it instead of
            # storing an empty array that would break the matrix fill below.
            continue
        embeddings_dict[parts[0]] = np.asarray(parts[1:], "float")


print('Loaded {} words from glove'.format(len(embeddings_dict)))

if not embeddings_dict:
    raise ValueError('No embeddings were loaded from {}'.format(glove_file))

# Infer the vector width from the file instead of hard-coding 50, so pointing
# `glove_file` at a different GloVe variant keeps working.
embedding_dim = len(next(iter(embeddings_dict.values())))

# Every row starts as uniform noise in [-1/3, 1/3); rows 1..N are then
# overwritten with the pretrained vectors. Row 0 is reserved for '<pad>' and
# keeps its random initial values.
low = -1.0 / 3
high = 1.0 / 3
embedding_matrix = np.random.uniform(low=low, high=high, size=(len(embeddings_dict)+1, embedding_dim))

# word -> row index into `embedding_matrix` (ids start at 1; 0 is padding).
word2id = {}
for i, word in enumerate(embeddings_dict, 1):
    word2id[word] = i
    embedding_matrix[i] = embeddings_dict[word]

word2id['<pad>'] = 0

### Set up train and validation datasets

In [None]:
# Modified from HW_3
class RNNAmazonReviewDataset(torch.utils.data.Dataset):
    """Dataset of fixed-length token-id sequences with 0-4 star labels.

    Every item is a tuple ``(input_ids, length, label)`` where ``input_ids``
    is a tensor of ``max_length`` ids (right-padded with the ``'<pad>'`` id),
    ``length`` is the number of real (non-padding) ids, and ``label`` is the
    review rating shifted from 1-5 to 0-4.
    """

    def __init__(self, hf_dataset=None, word2id=None, finalized_data=None, data_limit=None, max_length=128):
        """
        :param hf_dataset: A Hugging Face Dataset object (preloaded and filtered)
        :param word2id: The GloVe word2id dictionary (must contain '<pad>')
        :param finalized_data: Iterable of examples used instead of hf_dataset
            when it is non-empty (e.g. a pre-split list of rows)
        :param data_limit: Max number of examples to use
        :param max_length: Max sequence length
        :raises ValueError: if neither data source provides any data
        """
        self.max_length = max_length
        self.word2id = word2id

        # Unified logic: load from finalized_data or hf_dataset
        data_source = finalized_data if finalized_data else hf_dataset
        if data_source is None:
            raise ValueError("No data provided: pass a non-empty finalized_data or an hf_dataset")
        limit = len(data_source) if data_limit is None else data_limit

        examples, labels = [], []
        for i, example in enumerate(data_source):
            if i >= limit:
                break
            examples.append(example["text"])
            labels.append(int(example["rating"]) - 1)  # 1–5 stars → 0–4

        # The aligned encoding keeps one slot per example (None for reviews
        # with no tokens) so that dropping empty reviews cannot shift the
        # labels of the reviews that follow them.
        encoded = self._tokenize_aligned(examples)
        self.data = [
            (item[0], item[1], label)
            for item, label in zip(encoded, labels)
            if item is not None
        ]
        # NOTE: this reseeds the global `random` module, as the notebook has
        # always done, so the shuffle order is reproducible.
        random.seed(42)
        random.shuffle(self.data)

    def _encode_tokens(self, tokens):
        """Map tokens to vocabulary ids; return ``(ids, misses)``.

        Tokens are lower-cased before lookup. Tokens missing from the
        vocabulary are counted as misses and mapped to the 'unk' id, or to 0
        when the vocabulary has no 'unk' entry.
        """
        unk_id = self.word2id.get('unk', 0)
        ids = []
        misses = 0
        for tok in tokens:
            tok_id = self.word2id.get(tok.lower())
            if tok_id is None:
                misses += 1
                tok_id = unk_id
            ids.append(tok_id)
        return ids, misses

    def _pad_or_truncate(self, ids):
        """Cut or right-pad ``ids`` to ``max_length``; return ``(tensor, true_length)``."""
        length = min(len(ids), self.max_length)
        padding = [self.word2id['<pad>']] * (self.max_length - length)
        return torch.tensor(ids[:self.max_length] + padding), length

    def _tokenize_aligned(self, examples):
        """Encode every text, returning one entry per example.

        Entries are ``(ids_tensor, length)`` or ``None`` for texts that yield
        no tokens. Prints the out-of-vocabulary rate once at the end.
        """
        encoded = []
        misses = 0
        total = 0
        for example in tqdm(examples):
            ids, example_misses = self._encode_tokens(word_tokenize(example))
            misses += example_misses
            total += len(ids)
            encoded.append(self._pad_or_truncate(ids) if ids else None)

        # Guard the ratio so an empty input does not raise ZeroDivisionError.
        miss_rate = misses / total if total else 0.0
        print(f'Missed {misses} out of {total} words -- {miss_rate:.2%}')
        return encoded

    def tokenize(self, examples):
        """Encode texts into ``(ids_tensor, length)`` pairs, skipping empty ones.

        :param examples: iterable of raw review strings
        :return: list of ``(ids_tensor, length)``; texts with no tokens are omitted
        """
        return [item for item in self._tokenize_aligned(examples) if item is not None]

    def generate_validation_split(self, ratio=0.8):
        """Keep the first ``ratio`` share of the data and return the remainder.

        The dataset is modified in place: after the call it only holds the
        leading ``int(ratio * len)`` items.
        """
        split_idx = int(ratio * len(self.data))
        val_split = self.data[split_idx:]
        self.data = self.data[:split_idx]
        return val_split

    def __getitem__(self, index):
        return self.data[index]  # returns (input_ids, length, label)

    def __len__(self):
        return len(self.data)


In [None]:
def stratified_split_full_usage(hf_dataset, seed=42, per_class=20026):
    """Split examples into class-balanced train / validation / test lists.

    Examples are grouped by rating, each group is shuffled with ``seed``, and
    the first ``per_class`` examples of every group are divided 80% / 19% /
    remainder (about 1%) into train, validation and test.

    :param hf_dataset: iterable of examples exposing a 'rating' field (1-5)
    :param seed: seed applied to the global ``random`` module before each
        per-class shuffle
    :param per_class: number of examples taken from every rating class; the
        default matches the 20026-per-class dataset this notebook uses
    :return: ``(train_data, val_data, test_data)`` lists of examples
    :raises ValueError: if any rating class has fewer than ``per_class`` examples
    """
    grouped_by_rating = defaultdict(list)

    # Group by rating
    for ex in hf_dataset:
        rating = int(ex['rating']) - 1  # Convert 1–5 → 0–4
        grouped_by_rating[rating].append(ex)

    # Per-class split sizes (for the default 20026: 16020 / 3804 / 202).
    n_train = int(0.80 * per_class)
    n_val = int(0.19 * per_class)
    n_test = per_class - n_train - n_val
    test_end = n_train + n_val + n_test

    # Create fixed stratified splits
    train_data, val_data, test_data = [], [], []

    for rating, examples in grouped_by_rating.items():
        if len(examples) < per_class:
            raise ValueError(f"Expected {per_class} examples for rating {rating}, got {len(examples)}")

        # Reseeding per class keeps each class's shuffle independent of how
        # many classes were processed before it.
        random.seed(seed)
        random.shuffle(examples)

        train_data.extend(examples[:n_train])
        val_data.extend(examples[n_train:n_train + n_val])
        # Bound the test slice so a class with more than `per_class` examples
        # cannot contribute extra rows and unbalance the test set.
        test_data.extend(examples[n_train + n_val:test_end])

    random.shuffle(train_data)
    random.shuffle(val_data)
    random.shuffle(test_data)

    return train_data, val_data, test_data

# Stratified split of the raw reviews into train / validation / test lists.
train_data, val_data, test_data = stratified_split_full_usage(reviews)

# Tokenize and pad each split (in train, validation, test order) using the
# shared GloVe vocabulary.
train_dataset, valid_dataset, test_dataset = (
    RNNAmazonReviewDataset(finalized_data=split, word2id=word2id)
    for split in (train_data, val_data, test_data)
)

print(f"Train: {len(train_dataset)}  Val: {len(valid_dataset)}  Test: {len(test_dataset)}")

# Show one encoded validation item: (input_ids, length, label)
print(valid_dataset[0])

In [None]:
# Sanity check: the first `length` ids of an item are its real tokens; the
# rest of `input_ids` is padding.
input_ids, length, label = valid_dataset[1]
unpadded_ids = input_ids[:length]
for caption, value in (
    ("True length:", length),
    ("Non-padded input:", unpadded_ids),
    ("Label:", label),
):
    print(caption, value)


In [None]:
class LSTMModel(nn.Module):
    """(Bi)LSTM classifier over pretrained embeddings producing 5 class logits.

    The final hidden state of the last LSTM layer (forward and backward
    states concatenated when bidirectional) is passed through a ReLU hidden
    layer with dropout and projected to 5 outputs.
    """

    def __init__(self, embedding_matrix, lstm_hidden_size=75, num_lstm_layers=2, bidirectional=True):
        """
        :param embedding_matrix: 2-D array (vocab_size, embedding_dim) of
            pretrained vectors; loaded with ``nn.Embedding.from_pretrained``
        :param lstm_hidden_size: hidden size of each LSTM direction
        :param num_lstm_layers: number of stacked LSTM layers
        :param bidirectional: whether the LSTM reads the sequence both ways
        """
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix))
        self.lstm = nn.LSTM(input_size = embedding_matrix.shape[1],
                            hidden_size = lstm_hidden_size,
                            num_layers = num_lstm_layers,
                            bidirectional = bidirectional,
                            batch_first = True,
                            # Inter-layer dropout only exists between stacked
                            # layers; PyTorch warns when it is non-zero for a
                            # single-layer LSTM.
                            dropout = 0.3 if num_lstm_layers > 1 else 0.0)

        self.num_directions = 2 if bidirectional else 1
        self.hidden_1 = nn.Linear(lstm_hidden_size * self.num_directions, lstm_hidden_size)
        self.hidden_2 = nn.Linear(lstm_hidden_size, 5) # final layer to 5 classes
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, input_batch, input_lengths):
        """Return logits of shape (batch, 5).

        :param input_batch: integer id tensor indexed as (batch, sequence)
        :param input_lengths: true (unpadded) length of each sequence, as a
            tensor or a list of ints
        """
        embedded_input = self.embedding(input_batch)

        # pack_padded_sequence requires the lengths on the CPU, even when the
        # batch itself lives on another device.
        lengths = torch.as_tensor(input_lengths).cpu()
        packed_input = pack_padded_sequence(embedded_input, lengths, batch_first=True, enforce_sorted=False)

        _, (hn, _) = self.lstm(packed_input)

        # Split the stacked final states into (layer, direction, batch, hidden)
        # so the last layer can be selected directly.
        hn_view = hn.view(self.lstm.num_layers, self.num_directions, input_batch.shape[0], self.lstm.hidden_size)
        last_layer = hn_view[-1]

        if self.num_directions == 2:
            # bidirectional → concat forward and backward
            hn_cat = torch.cat([last_layer[0], last_layer[1]], dim=1)
        else:
            # unidirectional → just take last layer output
            hn_cat = last_layer[-1]

        hid = self.relu(self.hidden_1(hn_cat))
        hid = self.dropout(hid)

        return self.hidden_2(hid)

In [None]:
def evaluate_model_on_test(model, test_dataset, batch_size=256):
    """Print a per-class classification report for ``model`` on ``test_dataset``.

    The model is switched to eval mode and run without gradient tracking over
    the dataset in batches of ``batch_size``; the predicted class of each
    example is the argmax of its logits. Nothing is returned.
    """
    loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
    model.eval()

    gold_labels = []
    predicted_labels = []
    with torch.no_grad():
        for batch_ids, batch_lengths, batch_labels in loader:
            batch_logits = model(batch_ids, batch_lengths)
            gold_labels += batch_labels.tolist()
            predicted_labels += torch.argmax(batch_logits, dim=1).tolist()

    class_names = ["1★", "2★", "3★", "4★", "5★"]
    print("=== Test Set Evaluation ===")
    print(classification_report(gold_labels, predicted_labels, target_names=class_names, zero_division=0))

In [None]:
def train_lstm_classification(model, train_dataset, valid_dataset, epochs=10, batch_size=256, learning_rate=.001, print_frequency=25):
    """Train ``model`` with cross-entropy loss and evaluate it after every epoch.

    :param model: module called as ``model(input_ids, lengths)`` returning class logits
    :param train_dataset: dataset yielding ``(input_ids, length, label)`` items
    :param valid_dataset: dataset of the same form used for validation
    :param epochs: number of passes over the training data
    :param batch_size: batch size for both loaders
    :param learning_rate: AdamW learning rate
    :param print_frequency: print the training loss every this many batches
    :return: tuple ``(final_state_dict, best_state_dict, all_batch_losses,
        epoch_avg_losses, epoch_val_losses, epoch_accuracies)``;
        ``best_state_dict`` is a copy of the weights from the epoch with the
        highest validation accuracy
    """
    criteria = nn.CrossEntropyLoss() # changed from BCE
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validation order does not need to be random; a fixed order keeps the
    # per-batch averaged validation loss reproducible between runs.
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    # len(DataLoader) is the actual number of batches (the last one may be partial).
    print('Total train batches: {}'.format(len(train_dataloader)))

    best_accuracy = 0.0
    best_model_sd = None

    all_batch_losses = []
    epoch_avg_losses = []
    epoch_accuracies = []
    epoch_val_losses = []

    # Uncomment to create directories for saved models and metrics
    # os.makedirs("saved_models", exist_ok=True)
    # os.makedirs("saved_metrics", exist_ok=True)

    for i in range(epochs):
        epoch_number = i + 1
        print('### Epoch: ' + str(epoch_number) + ' ###')

        model.train()

        total_loss = 0
        batch_count = 0

        for step, data in enumerate(tqdm(train_dataloader, desc=f"Training Epoch {epoch_number}")):

            x, x_lengths, y = data

            optimizer.zero_grad()

            model_output = model(x, x_lengths)

            loss = criteria(model_output, y)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            all_batch_losses.append(batch_loss)
            total_loss += batch_loss
            batch_count += 1

            if step % print_frequency == (print_frequency - 1):
                # One-based epoch and batch numbers, matching the epoch header above.
                print(f"epoch: {epoch_number} batch: {step + 1} loss: {batch_loss:.4f}")

        avg_epoch_loss = total_loss / max(1, batch_count)
        epoch_avg_losses.append(avg_epoch_loss)

        print('Evaluating...')
        model.eval()
        total_val_loss = 0
        val_batch_count = 0
        total_correct = 0
        total_examples = 0

        with torch.no_grad():
            for input_data, length, y in valid_dataloader:
                logits = model(input_data, length)
                preds = torch.argmax(logits, dim=1)
                total_val_loss += criteria(logits, y).item()
                val_batch_count += 1

                total_correct += (preds == y).sum().item()
                total_examples += len(y)

        # Guard against an empty validation set instead of dividing by zero.
        acc = total_correct / total_examples if total_examples else 0.0
        val_loss_avg = total_val_loss / max(1, val_batch_count)
        epoch_val_losses.append(val_loss_avg)
        print(f"Validation Accuracy: {acc:.4f} | Validation Loss: {val_loss_avg:.4f}")
        epoch_accuracies.append(acc)

        # Always keep a snapshot from the first epoch so the returned best
        # state is never None, then replace it whenever accuracy improves.
        if best_model_sd is None or acc > best_accuracy:
            best_model_sd = copy.deepcopy(model.state_dict())
            best_accuracy = acc

        # Uncomment to save model
        # epoch_filename = f"saved_models/model_epoch_{i+1}.pth"
        # torch.save(model.state_dict(), epoch_filename)
        # print(f"Saved model checkpoint to {epoch_filename}")

        # Only consumed by the optional metrics dump below.
        metrics = {
            "all_batch_losses": all_batch_losses,
            "epoch_avg_losses": epoch_avg_losses,
            "epoch_val_losses": epoch_val_losses,
            "epoch_accuracies": epoch_accuracies
        }
        # Uncomment to save metrics
        # with open("saved_metrics/metrics_epoch_{:02d}.json".format(i+1), "w") as f:
        #    json.dump(metrics, f)

    return model.state_dict(), best_model_sd, all_batch_losses, epoch_avg_losses, epoch_val_losses, epoch_accuracies

In [None]:
# Build the 2-layer bidirectional LSTM classifier on top of the GloVe matrix.
model = LSTMModel(embedding_matrix, lstm_hidden_size=50, num_lstm_layers=2, bidirectional=True)

# Train for 5 epochs and keep the returned weights and loss/accuracy histories.
training_results = train_lstm_classification(
    model,
    train_dataset,
    valid_dataset,
    epochs=5,
    batch_size=256,
    learning_rate=1e-3,
    print_frequency=25,
)
final_model_state, best_model_state, all_losses, epoch_losses, epoch_val_losses, accuracies = training_results

print("\n=== Training Summary ===")
epoch_rows = zip(epoch_losses, epoch_val_losses, accuracies)
for i, (train_loss, val_loss, acc) in enumerate(epoch_rows, 1):
    summary = " | ".join([
        f"Train Loss = {train_loss:.4f}",
        f"Val Loss = {val_loss:.4f}",
        f"Val Accuracy = {acc:.4f}",
    ])
    print(f"Epoch {i:>2}: " + summary)

# NOTE: `model` still holds the weights from the last epoch here, which are
# not necessarily those stored in `best_model_state`.
evaluate_model_on_test(model, test_dataset, batch_size=256)

In [None]:
# Uncomment to save final model state

#torch.save(final_model_state, 'final_amazon_review_model.pth')
#print('Final model state saved')

In [None]:
# If you saved metrics during training, uncomment this cell to load them from JSON
# and plot the per-batch loss, the average loss per epoch, and the validation accuracy.

# with open("saved_metrics/metrics_epoch_05.json") as f:
#     metrics = json.load(f)

# all_losses = metrics["all_batch_losses"]
# epoch_losses = metrics["epoch_avg_losses"]
# val_losses = metrics["epoch_val_losses"]
# accuracies = metrics["epoch_accuracies"]

# # Plot all batch-level training losses
# plt.figure(figsize=(12, 5))
# plt.plot(all_losses, label="Batch Loss")
# plt.title("Training Loss per Batch")
# plt.xlabel("Batch #")
# plt.ylabel("Loss")
# plt.grid(True)
# plt.legend()
# plt.show()

# # Plot average loss per epoch
# plt.figure(figsize=(8, 5))
# plt.plot(epoch_losses, marker='o', label="Avg Train Loss per Epoch", color='orange')
# plt.plot(val_losses, marker='x', label="Val Loss per Epoch", color='red')
# plt.title("Average Loss per Epoch")
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# plt.grid(True)
# plt.legend()
# plt.show()

# # Plot validation accuracy per epoch
# plt.figure(figsize=(8, 5))
# plt.plot(accuracies, marker='o', label="Validation Accuracy", color='green')
# plt.title("Validation Accuracy per Epoch")
# plt.xlabel("Epoch")
# plt.ylabel("Accuracy")
# plt.grid(True)
# plt.legend()
# plt.show()

In [None]:
# If you saved model checkpoints during training, uncomment this cell to load the weights
# from the .pth files and evaluate each checkpoint on the test dataset.
# Make sure the hidden size passed to LSTMModel matches the hidden size of the saved model,
# and that the checkpoint file names match the ones used when saving.

# # === Define your test dataloader ===
# test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256)

# # === Define your model parameters (must match training!) ===
# def load_model(epoch):
#     model = LSTMModel(embedding_matrix, lstm_hidden_size=75, num_lstm_layers=2, bidirectional=True)
#     model_path = f"saved_models/model_epoch_{epoch}.pth"
#     model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
#     model.eval()
#     return model

# # === Evaluate all saved models ===
# for epoch in range(1, 6):
#     print(f"\n=== Evaluating model from epoch {epoch} ===")
#     model = load_model(epoch)

#     y_true, y_pred = [], []

#     with torch.no_grad():
#         for input_data, length, y in test_dataloader:
#             logits = model(input_data, length)
#             preds = torch.argmax(logits, dim=1)
#             y_true.extend(y.tolist())
#             y_pred.extend(preds.tolist())

#     print(classification_report(y_true, y_pred, target_names=["1★", "2★", "3★", "4★", "5★"], zero_division=0))


     