In [1]:
import torch, random, os, string
from torchtext import data , datasets
from collections import defaultdict, Counter
import numpy as np

os.environ['GENSIM_DATA_DIR'] = os.path.join(os.getcwd(), 'gensim-data')

import gensim.downloader as api
from gensim.models.fasttext import load_facebook_model

from spacy.lang.en.stop_words import STOP_WORDS
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from sklearn.preprocessing import label_binarize
import time

In [2]:
### Part 0: Dataset Preparation

# For tokenization
TEXT = data.Field ( tokenize = 'spacy', tokenizer_language = 'en_core_web_sm', include_lengths = True )

# For multi - class classification labels
LABEL = data.LabelField ()

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the TREC dataset
# Train / Validation / Test split
train_data, test_data = datasets.TREC.splits( TEXT, LABEL, fine_grained = False )

train_data, validation_data = train_data.split(
    split_ratio=0.8,
    stratified=True,
    strata_field='label',
    random_state=random.seed(42)
)
print(vars(train_data.examples[0]))


# Count how many samples per label in the train set
label_counts = Counter([ex.label for ex in train_data.examples])
total_examples = len(train_data)

print("\nLabel distribution in training set:")
for label, count in sorted(label_counts.items()):
    percentage = (count / total_examples) * 100
    print(f"- {label}: {count} samples ({percentage:.2f}%)")

# Optional sanity check: total percentages should sum â‰ˆ 100%
total_percentage = sum((count / total_examples) * 100 for count in label_counts.values())
print(f"Total samples: {total_examples}, Sum of percentages: {total_percentage:.2f}%")


{'text': ['How', 'tall', 'is', 'Prince', 'Charles', '?'], 'label': 'NUM'}

Label distribution in training set:
- ABBR: 69 samples (1.58%)
- DESC: 930 samples (21.32%)
- ENTY: 1000 samples (22.93%)
- HUM: 978 samples (22.42%)
- LOC: 668 samples (15.31%)
- NUM: 717 samples (16.44%)
Total samples: 4362, Sum of percentages: 100.00%


# Part 1: Prepare Word Embeddings

In [3]:
#### a) Size of Vocabulary formed from training data according to tokenization method
# Vocabulary size (includes specials like <unk>, <pad>)
TEXT.build_vocab(train_data, min_freq=1)
vocab_size = len(TEXT.vocab)
print("Vocabulary Size (with specials):", vocab_size)

vocab_wo_specials = len([w for w in TEXT.vocab.stoi if w not in {TEXT.unk_token, TEXT.pad_token}])
print("Vocabulary size (no specials):", vocab_wo_specials)

Vocabulary Size (with specials): 8118
Vocabulary size (no specials): 8116


In [6]:
#### b) How many OOV words exist in your training data?
####    What is the number of OOV words for each topic category?
w2v = api.load("word2vec-google-news-300")
w2v_vocab = w2v.key_to_index

# Get training vocab tokens (types), excluding specials
specials = {TEXT.unk_token, TEXT.pad_token}
train_vocab_types = [w for w in TEXT.vocab.stoi.keys() if w not in specials]

# Overall OOV types in training vocab
oov_types_overall = {w for w in train_vocab_types if w not in w2v_vocab}
print("Number of OOV word types (overall):", len(oov_types_overall))

# OOV types per label (unique types per category across its sentences)
label_to_oov_types = defaultdict(set)
label_to_total_types = defaultdict(set)

for ex in train_data.examples:
    label = ex.label
    # Count by unique types per sentence to avoid overcounting repeats
    for w in set(ex.text):
        label_to_total_types[label].add(w)
        if w not in specials and w not in w2v_vocab:
            label_to_oov_types[label].add(w)

print("\nOOV word types per topic label:")
for label in sorted(label_to_total_types.keys()):
    num_oov = len(label_to_oov_types[label])
    num_types = len(label_to_total_types[label])
    rate = (num_oov / num_types) if num_types > 0 else 0.0
    print(f"- {label}: {num_oov} OOV types (out of {num_types}, rate={rate:.2%})")

Number of OOV word types (overall): 420

OOV word types per topic label:
- ABBR: 20 OOV types (out of 148, rate=13.51%)
- DESC: 118 OOV types (out of 2278, rate=5.18%)
- ENTY: 157 OOV types (out of 2979, rate=5.27%)
- HUM: 141 OOV types (out of 3055, rate=4.62%)
- LOC: 79 OOV types (out of 1762, rate=4.48%)
- NUM: 76 OOV types (out of 1880, rate=4.04%)


In [4]:
# #### c) OOV mitigation strategy (No transformer-based language models allowed)
# Implement your solution in your source code. Show the corresponding code snippet.
# 1. Fast Text Model Implementatation
# Load FastText with subword info (pretrained on Wikipedia)
# First download is large; cached afterwards

# 2. Modelling Unknown (<UNK>) token approach
# Make the <unk> vector informative and trainable by initializing it
# as the mean of available pretrained vectors.

# Loading fasttext model
fatter_fasttext_bin = load_facebook_model('crawl-300d-2M-subword/crawl-300d-2M-subword.bin')
embedding_dim = fatter_fasttext_bin.wv.vector_size

# Build embedding matrix aligned to TEXT.vocab
num_tokens = len(TEXT.vocab)
emb_matrix = np.zeros((num_tokens, embedding_dim), dtype=np.float32)

# torchtext 0.4.0: TEXT.vocab.itos is index->token, stoi is token->index
pad_tok = TEXT.pad_token
unk_tok = TEXT.unk_token

# Getting index of <unk> in vocab
unk_index = TEXT.vocab.stoi[TEXT.unk_token]
known_vecs = []

for idx, token in enumerate(TEXT.vocab.itos):
    # Skip specials here; we will set them explicitly below
    if token in {pad_tok, unk_tok}:
        continue

    vec = fatter_fasttext_bin.wv[token]
    emb_matrix[idx] = vec
    known_vecs.append(vec)

if len(known_vecs) > 0:
    unk_mean = torch.tensor(np.mean(known_vecs, axis=0), dtype=torch.float32)
else:
    unk_mean = torch.empty(embedding_dim).uniform_(-0.05, 0.05)
with torch.no_grad():
    emb_matrix[unk_index] = unk_mean

# Create Embedding layer initialized with FastText
fatter_embedding = torch.nn.Embedding(num_tokens, embedding_dim, padding_idx=TEXT.vocab.stoi[TEXT.pad_token])
fatter_embedding.weight.data.copy_(torch.from_numpy(emb_matrix))

torch.save(fatter_embedding, 'embedding_weights_fatter_fasttext.pt')

In [7]:
#### d) Select the 20 most frequent words from each topic category in the training set (removing
# stopwords if necessary). Retrieve their pretrained embeddings (from Word2Vec or GloVe).
# Project these embeddings into 2D space (using e.g., t-SNE or Principal Component Analysis).
# Plot the points in a scatter plot, color-coded by their topic category. Attach your plot here.
# Analyze your findings.

# Build per-label token frequency (lowercased, stopwords/punct filtered)
label_to_counter = defaultdict(Counter)
valid_chars = set(string.ascii_letters)

def is_valid_token(tok: str) -> bool:
    t = tok.strip("'\"")
    if len(t) == 0:
        return False
    # Keep purely alphabetic tokens to avoid punctuation/numbers
    return t.isalpha()

for ex in train_data.examples:
    label = ex.label
    for tok in ex.text:
        tok_l = tok.lower()
        if tok_l in STOP_WORDS:
            continue
        if not is_valid_token(tok_l):
            continue
        label_to_counter[label][tok_l] += 1

# Select top 20 per label that exist in Word2Vec
topk = 20
label_to_top_tokens = {}
for label, ctr in label_to_counter.items():
    selected = []
    for tok, _ in ctr.most_common():
        if tok in w2v.key_to_index:
            selected.append(tok)
        if len(selected) >= topk:
            break
    label_to_top_tokens[label] = selected

# Collect embeddings and labels
points = []
point_labels = []
point_words = []
for label, toks in label_to_top_tokens.items():
    for tok in toks:
        vec = w2v.get_vector(tok)
        points.append(vec)
        point_labels.append(label)
        point_words.append(tok)

if len(points) > 0:
    X = np.vstack(points)

    # 2D projections
    tsne_2d = TSNE(n_components=2, random_state=42, init="pca", perplexity=30).fit_transform(X)
    pca_2d = PCA(n_components=2, random_state=42).fit_transform(X)

    # Assign colors per label
    unique_labels = sorted(set(point_labels))
    color_map = {lab: plt.cm.tab10(i % 10) for i, lab in enumerate(unique_labels)}

    def plot_scatter(Y2, title: str, fname: str):
        plt.figure(figsize=(10, 8))
        for lab in unique_labels:
            idxs = [i for i, l in enumerate(point_labels) if l == lab]
            plt.scatter(Y2[idxs, 0], Y2[idxs, 1], c=[color_map[lab]], label=lab, alpha=0.8, s=40)
            # Light word annotations (optional; can clutter)
            for i in idxs:
                plt.annotate(point_words[i], (Y2[i, 0], Y2[i, 1]), fontsize=7, alpha=0.7)
        plt.legend(title="TREC label")
        plt.title(title)
        plt.tight_layout()
        plt.savefig(fname, dpi=200)
        plt.close()

    plot_scatter(tsne_2d, "Top-20 per TREC label (Word2Vec) - t-SNE", "trec_top20_tsne.png")
    plot_scatter(pca_2d, "Top-20 per TREC label (Word2Vec) - PCA", "trec_top20_pca.png")

    print("Saved plots: trec_top20_tsne.png, trec_top20_pca.png")
    for lab in unique_labels:
        print(f"{lab}: {label_to_top_tokens[lab]}")
else:
    print("No points collected for visualization. Check filtering or embedding availability.")

Saved plots: trec_top20_tsne.png, trec_top20_pca.png
ABBR: ['stand', 'abbreviation', 'mean', 'computer', 'cnn', 'letters', 'national', 'bureau', 'investigation', 'acronym', 'form', 'cpr', 'reading', 'classified', 'ads', 'dsl', 'scsi', 'washington', 'shield', 'psi']
DESC: ['mean', 'origin', 'difference', 'word', 'find', 'come', 'work', 'people', 'term', 'meaning', 'causes', 'like', 't', 'school', 'definition', 'called', 'time', 'happened', 'famous', 'computer']
ENTY: ['fear', 'called', 'kind', 'world', 'best', 'film', 'color', 'war', 'movie', 'novel', 'book', 'word', 'animal', 'drink', 'term', 'english', 'sport', 'known', 'play', 'use']
HUM: ['president', 'company', 'wrote', 'world', 'famous', 'invented', 'won', 'character', 'movie', 'team', 'baseball', 'new', 'tv', 'portrayed', 'known', 'american', 'actor', 'star', 'king', 'played']
LOC: ['country', 'city', 'state', 'world', 'largest', 'find', 'countries', 'located', 'river', 'highest', 'live', 'capital', 'airport', 'mountain', 'island

In [81]:
fatter_embedding

Embedding(8166, 300, padding_idx=1)

In [None]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Build vocabulary for labels
LABEL.build_vocab(train_data)
num_classes = len(LABEL.vocab)
print(f"\nNumber of classes: {num_classes}")
print(f"Classes: {LABEL.vocab.itos}")

# Create iterators for batching
def create_iterators(train_data, validation_data, test_data, batch_size):
    train_iterator = data.BucketIterator(
        train_data,
        batch_size=batch_size,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        device=device
    )
    
    val_iterator = data.BucketIterator(
        validation_data,
        batch_size=batch_size,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        device=device
    )
    
    test_iterator = data.BucketIterator(
        test_data,
        batch_size=batch_size,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        device=device
    )
    
    return train_iterator, val_iterator, test_iterator


class RNN_Classifier(nn.Module):
    """
    Simple RNN for topic classification with multiple aggregation strategies
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, 
                 n_layers=1, bidirectional=False, dropout=0.5, 
                 padding_idx=0, pretrained_embeddings=None,
                 aggregation='last'):
        super(RNN_Classifier, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        self.aggregation = aggregation  # 'last', 'mean', 'max', 'attention'
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        
        # Initialize with pretrained embeddings
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)
        
        # Make embeddings learnable (updated during training)
        self.embedding.weight.requires_grad = True
        
        # RNN layer
        self.rnn = nn.RNN(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0
        )
        
        # Dropout layer
        self.dropout = nn.Dropout(dropout)
        
        # Attention mechanism for aggregation
        if aggregation == 'attention':
            rnn_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
            self.attention = nn.Linear(rnn_output_dim, 1)
        
        # Fully connected output layer
        rnn_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(rnn_output_dim, output_dim)
        
    def forward(self, text, text_lengths):
        # text: [batch_size, seq_len]
        # text_lengths: [batch_size]
        
        # Embed the input
        embedded = self.dropout(self.embedding(text))
        # embedded: [batch_size, seq_len, embedding_dim]
        
        # Pack the padded sequences
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        
        # Pass through RNN
        packed_output, hidden = self.rnn(packed_embedded)
        # packed_output: packed sequence of [batch_size, seq_len, hidden_dim * num_directions]
        # hidden: [n_layers * num_directions, batch_size, hidden_dim]
        
        # Unpack the sequences
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        # output: [batch_size, seq_len, hidden_dim * num_directions]
        
        # Aggregate word representations to sentence representation
        if self.aggregation == 'last':
            # Use the last hidden state
            if self.bidirectional:
                # Concatenate last states from forward and backward
                hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
            else:
                hidden = hidden[-1,:,:]
            sentence_repr = hidden
            
        elif self.aggregation == 'mean':
            # Mean pooling over all outputs (ignoring padding)
            # Create mask for padding
            batch_size, seq_len, hidden_size = output.size()
            mask = torch.arange(seq_len, device=device).unsqueeze(0) < text_lengths.unsqueeze(1)
            mask = mask.unsqueeze(2).float()  # [batch_size, seq_len, 1]
            
            # Apply mask and compute mean
            masked_output = output * mask
            sum_output = masked_output.sum(dim=1)
            sentence_repr = sum_output / text_lengths.unsqueeze(1).float()
            
        elif self.aggregation == 'max':
            # Max pooling over all outputs
            sentence_repr, _ = torch.max(output, dim=1)
            
        elif self.aggregation == 'attention':
            # Attention mechanism
            # Compute attention scores
            attn_scores = self.attention(output).squeeze(2)  # [batch_size, seq_len]
            
            # Mask padding positions
            mask = torch.arange(output.size(1), device=device).unsqueeze(0) < text_lengths.unsqueeze(1)
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
            
            # Apply softmax
            attn_weights = torch.softmax(attn_scores, dim=1).unsqueeze(1)  # [batch_size, 1, seq_len]
            
            # Weighted sum
            sentence_repr = torch.bmm(attn_weights, output).squeeze(1)  # [batch_size, hidden_dim * num_directions]
        
        # Apply dropout
        sentence_repr = self.dropout(sentence_repr)
        
        # Pass through fully connected layer
        output = self.fc(sentence_repr)
        
        return output


def count_parameters(model):
    """Count trainable parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def train_epoch(model, iterator, optimizer, criterion, device, l1_lambda=0.0, l2_lambda=0.0):
    """Train for one epoch"""
    model.train()
    epoch_loss = 0
    all_preds = []
    all_labels = []
    
    for batch in iterator:
        text, text_lengths = batch.text
        labels = batch.label
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        predictions = model(text, text_lengths)
        
        # Calculate loss
        loss = criterion(predictions, labels)
        
        # Add L1 regularization
        if l1_lambda > 0:
            l1_norm = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_norm
        
        # Add L2 regularization (can also use weight_decay in optimizer)
        if l2_lambda > 0:
            l2_norm = sum(p.pow(2).sum() for p in model.parameters())
            loss = loss + l2_lambda * l2_norm
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # Update weights
        optimizer.step()
        
        epoch_loss += loss.item()
        
        # Store predictions and labels for metrics
        preds = torch.argmax(predictions, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    
    return epoch_loss / len(iterator), accuracy, f1


def evaluate(model, iterator, criterion, device, return_predictions=False):
    """Evaluate the model"""
    model.eval()
    epoch_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for batch in iterator:
            text, text_lengths = batch.text
            labels = batch.label
            
            # Forward pass
            predictions = model(text, text_lengths)
            
            # Calculate loss
            loss = criterion(predictions, labels)
            epoch_loss += loss.item()
            
            # Store predictions and labels
            probs = torch.softmax(predictions, dim=1)
            preds = torch.argmax(predictions, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    
    # Calculate AUC-ROC (one-vs-rest for multiclass)
    try:
        all_probs_array = np.array(all_probs)
        all_labels_bin = label_binarize(all_labels, classes=range(num_classes))
        auc_roc = roc_auc_score(all_labels_bin, all_probs_array, average='weighted', multi_class='ovr')
    except:
        auc_roc = 0.0
    
    if return_predictions:
        return epoch_loss / len(iterator), accuracy, f1, auc_roc, all_preds, all_labels
    
    return epoch_loss / len(iterator), accuracy, f1, auc_roc


def train_model(model, train_iterator, val_iterator, optimizer, criterion, 
                n_epochs, device, patience=5, l1_lambda=0.0, l2_lambda=0.0,
                save_path='best_model.pt'):
    """
    Train the model with early stopping
    """
    best_val_acc = 0
    patience_counter = 0
    
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    
    print(f"\nStarting training for {n_epochs} epochs...")
    print(f"Device: {device}")
    print(f"Model parameters: {count_parameters(model):,}")
    
    for epoch in range(n_epochs):
        start_time = time.time()
        
        # Train
        train_loss, train_acc, train_f1 = train_epoch(
            model, train_iterator, optimizer, criterion, device, l1_lambda, l2_lambda
        )
        
        # Evaluate on validation set
        val_loss, val_acc, val_f1, val_auc = evaluate(model, val_iterator, criterion, device)
        
        end_time = time.time()
        epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
        
        # Store metrics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        
        print(f'Epoch: {epoch+1:02}/{n_epochs} | Time: {int(epoch_mins)}m {int(epoch_secs)}s')
        print(f'\tTrain Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Train F1: {train_f1:.4f}')
        print(f'\tVal Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}% | Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}')
        
        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), save_path)
            print(f'\t>>> New best model saved with Val Acc: {val_acc*100:.2f}%')
        else:
            patience_counter += 1
            print(f'\t>>> No improvement. Patience: {patience_counter}/{patience}')
            
            if patience_counter >= patience:
                print(f'\nEarly stopping triggered after epoch {epoch+1}')
                break
    
    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc
    }


def evaluate_per_topic(model, iterator, device):
    """Evaluate model performance per topic category"""
    model.eval()
    
    topic_correct = defaultdict(int)
    topic_total = defaultdict(int)
    
    with torch.no_grad():
        for batch in iterator:
            text, text_lengths = batch.text
            labels = batch.label
            
            # Forward pass
            predictions = model(text, text_lengths)
            preds = torch.argmax(predictions, dim=1)
            
            # Count per topic
            for pred, label in zip(preds.cpu().numpy(), labels.cpu().numpy()):
                topic_name = LABEL.vocab.itos[label]
                topic_total[topic_name] += 1
                if pred == label:
                    topic_correct[topic_name] += 1
    
    # Calculate accuracy per topic
    topic_accuracies = {}
    for topic in sorted(topic_total.keys()):
        acc = topic_correct[topic] / topic_total[topic] if topic_total[topic] > 0 else 0
        topic_accuracies[topic] = acc
        print(f'{topic}: {topic_correct[topic]}/{topic_total[topic]} = {acc*100:.2f}%')
    
    return topic_accuracies


def plot_training_curves(history, save_prefix='rnn'):
    """Plot training and validation curves"""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss curve
    axes[0].plot(history['train_losses'], label='Train Loss', marker='o')
    axes[0].plot(history['val_losses'], label='Val Loss', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy curve
    axes[1].plot([acc*100 for acc in history['train_accs']], label='Train Acc', marker='o')
    axes[1].plot([acc*100 for acc in history['val_accs']], label='Val Acc', marker='s')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'{save_prefix}_training_curves.png', dpi=200)
    plt.close()
    print(f'Saved training curves to {save_prefix}_training_curves.png')


Number of classes: 6
Classes: ['ENTY', 'HUM', 'DESC', 'NUM', 'LOC', 'ABBR']


In [87]:
print("\n" + "="*80)
print("PART 2: RNN MODEL TRAINING")
print("="*80)

# Get pretrained embeddings from Part 1
pretrained_embeddings = fatter_embedding.weight.data.clone()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters for baseline
BATCH_SIZE = 64
HIDDEN_DIM = 256
N_LAYERS = 1
DROPOUT = 0.5
N_EPOCHS = 50
LEARNING_RATE = 0.001
PATIENCE = 10

# Create data iterators
train_iterator, val_iterator, test_iterator = create_iterators(
    train_data, validation_data, test_data, BATCH_SIZE
)

# Initialize baseline model
baseline_model = RNN_Classifier(
    vocab_size=len(TEXT.vocab),
    embedding_dim=embedding_dim,
    hidden_dim=HIDDEN_DIM,
    output_dim=num_classes,
    n_layers=N_LAYERS,
    bidirectional=False,
    dropout=DROPOUT,
    padding_idx=TEXT.vocab.stoi[TEXT.pad_token],
    pretrained_embeddings=pretrained_embeddings,
    aggregation='last'
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(baseline_model.parameters(), lr=LEARNING_RATE)

print(f"\n>>> Training Baseline RNN Model")
print(f"Configuration: Hidden={HIDDEN_DIM}, Layers={N_LAYERS}, Dropout={DROPOUT}, LR={LEARNING_RATE}, Batch={BATCH_SIZE}")

# Train baseline model
baseline_history = train_model(
    baseline_model, train_iterator, val_iterator, optimizer, criterion,
    n_epochs=N_EPOCHS, device=device, patience=PATIENCE,
    save_path='rnn_baseline_best.pt'
)

# Load best model and evaluate on test set
baseline_model.load_state_dict(torch.load('rnn_baseline_best.pt'))
test_loss, test_acc, test_f1, test_auc = evaluate(baseline_model, test_iterator, criterion, device)

print(f"\n>>> Baseline Model Test Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Test F1 Score: {test_f1:.4f}")
print(f"Test AUC-ROC: {test_auc:.4f}")

# Topic-wise accuracy
print(f"\n>>> Topic-wise Accuracy (Baseline):")
baseline_topic_acc = evaluate_per_topic(baseline_model, test_iterator, device)

# Plot training curves
plot_training_curves(baseline_history, save_prefix='rnn_baseline')


PART 2: RNN MODEL TRAINING

>>> Training Baseline RNN Model
Configuration: Hidden=256, Layers=1, Dropout=0.5, LR=0.001, Batch=64

Starting training for 50 epochs...
Device: cpu
Model parameters: 2,594,190

[torchtext.data.batch.Batch of size 64]
	[.text]:('[torch.LongTensor of size 8x64]', '[torch.LongTensor of size 64]')
	[.label]:[torch.LongTensor of size 64]


IndexError: index out of range in self