In [1]:
import os

os.environ.setdefault('TORCH_COMPILE_DISABLE', '1')
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch

# Method 2: Patch torch._dynamo.disable decorator after import
try:
    import torch._dynamo
    # Patch the disable function to ignore the 'wrapping' parameter
    if hasattr(torch._dynamo, 'disable'):
        def patched_disable(fn=None, *args, **kwargs):
            # Remove problematic 'wrapping' parameter if present
            if 'wrapping' in kwargs:
                kwargs.pop('wrapping')
            if fn is None:
                # Decorator usage: @disable
                return lambda f: f
            # Function usage: disable(fn) or disable(fn, **kwargs)
            # Simply return the function unwrapped to avoid recursion
            # The original disable was causing issues, so we bypass it entirely
            return fn
        torch._dynamo.disable = patched_disable
except Exception as e:
    print(f"Warning: Could not patch torch._dynamo: {e}")
    pass  # If patching fails, continue anyway

import random, string

from torchtext import data , datasets
from collections import defaultdict, Counter
import numpy as np

os.environ['GENSIM_DATA_DIR'] = os.path.join(os.getcwd(), 'gensim-data')

import gensim.downloader as api
from gensim.models import KeyedVectors
from gensim.models.fasttext import load_facebook_model

from spacy.lang.en.stop_words import STOP_WORDS
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from sklearn.preprocessing import label_binarize
import time, copy
from utils import *

In [2]:
SEED = 42
train_data, validation_data, test_data, LABEL, TEXT, pretrained_embed = data_prep(SEED)

[*] Prepping Data...
[+] Test set formed!
[+] Train and Validation sets formed!
[+] Data prepped successfully!
[*] Retrieving pretrained word embeddings...
[*] Loading fasttext model...
[+] Model loaded!
[*] Forming embedding matrix...
[+] Embedding matrix formed!
[+] Embeddings retrieved successfully!

Label distribution in training set:
- ABBR: 69 samples (1.58%)
- DESC: 930 samples (21.32%)
- ENTY: 1000 samples (22.93%)
- HUM: 978 samples (22.42%)
- LOC: 668 samples (15.31%)
- NUM: 717 samples (16.44%)
Total samples: 4362, Sum of percentages: 100.00%


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
# Build vocabulary for labels
LABEL.build_vocab(train_data)
num_classes = len(LABEL.vocab)
print(f"\nNumber of classes: {num_classes}")
print(f"Classes: {LABEL.vocab.itos}")

# Get pretrained embeddings from Part 1 (frozen embeddings)
pretrained_embeddings = pretrained_embed.weight.data

# Get embedding dimension and vocab size from the fasttext embedding layer
embedding_dim = pretrained_embed.weight.shape[1]
embedding_vocab_size = pretrained_embed.weight.shape[0]  # Vocab size from saved embedding


Number of classes: 6
Classes: ['ENTY', 'HUM', 'DESC', 'NUM', 'LOC', 'ABBR']


In [5]:
print(f'TEXT Vocab Size: {len(TEXT.vocab.stoi)}')

TEXT Vocab Size: 8102


In [6]:
# ============================================================================
# PART 3.4: Targeted Improvement for Weak Topics
# Strategy: Data Augmentation, Positional Embeddings
# ============================================================================

print("\n" + "="*80)
print("PART 3.4: TARGETED IMPROVEMENT FOR WEAK TOPICS")
print("="*80)
print("\nStrategies:")
print("  1. Data Augmentation for imbalanced classes (especially ABBR)")
print("  2. Positional Embeddings in attention layer")
print("="*80)

# Import required libraries for augmentation
import nltk
from nltk.corpus import wordnet
try:
    nltk.download('wordnet', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
except:
    pass



PART 3.4: TARGETED IMPROVEMENT FOR WEAK TOPICS

Strategies:
  1. Data Augmentation for imbalanced classes (especially ABBR)
  2. Positional Embeddings in attention layer


In [7]:
# ============================================================================
# Step 2: Data Augmentation Functions for Imbalanced Classes
# ============================================================================

print("\n>>> Step 2: Implementing Data Augmentation Functions...")

def get_synonyms(word):
    """Get synonyms for a word using WordNet"""
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonym = lemma.name().replace('_', ' ').lower()
            if synonym != word and synonym.isalpha():
                synonyms.add(synonym)
    return list(synonyms)

def synonym_replacement(tokens, n=1):
    """Replace n random words with their synonyms"""
    new_tokens = tokens.copy()
    words_to_replace = [i for i, word in enumerate(tokens) if word.isalpha() and len(word) > 2]
    
    if len(words_to_replace) == 0:
        return tokens
    
    num_replacements = min(n, len(words_to_replace))
    indices_to_replace = random.sample(words_to_replace, num_replacements)
    
    for idx in indices_to_replace:
        synonyms = get_synonyms(tokens[idx])
        if synonyms:
            new_tokens[idx] = random.choice(synonyms)
    
    return new_tokens

def random_insertion(tokens, n=1):
    """Randomly insert synonyms of n words"""
    new_tokens = tokens.copy()
    
    for _ in range(n):
        if len(new_tokens) == 0:
            break
        word = random.choice(new_tokens)
        synonyms = get_synonyms(word)
        if synonyms:
            synonym = random.choice(synonyms)
            insert_pos = random.randint(0, len(new_tokens))
            new_tokens.insert(insert_pos, synonym)
    
    return new_tokens

def random_deletion(tokens, p=0.1):
    """Randomly delete words with probability p"""
    if len(tokens) == 1:
        return tokens
    
    new_tokens = []
    for token in tokens:
        if random.random() > p:
            new_tokens.append(token)
    
    if len(new_tokens) == 0:
        return tokens[:1]
    
    return new_tokens

def random_swap(tokens, n=1):
    """Randomly swap n pairs of words"""
    new_tokens = tokens.copy()
    
    for _ in range(n):
        if len(new_tokens) < 2:
            break
        idx1, idx2 = random.sample(range(len(new_tokens)), 2)
        new_tokens[idx1], new_tokens[idx2] = new_tokens[idx2], new_tokens[idx1]
    
    return new_tokens

def augment_text(text, augmentation_techniques=['synonym', 'insertion', 'deletion', 'swap'], 
                 num_augmentations=3):
    """Apply data augmentation to text"""
    augmented_texts = []
    
    for _ in range(num_augmentations):
        aug_text = text.copy()
        technique = random.choice(augmentation_techniques)
        
        if technique == 'synonym' and len(aug_text) > 0:
            aug_text = synonym_replacement(aug_text, n=random.randint(1, 2))
        elif technique == 'insertion' and len(aug_text) > 0:
            aug_text = random_insertion(aug_text, n=random.randint(1, 2))
        elif technique == 'deletion' and len(aug_text) > 1:
            aug_text = random_deletion(aug_text, p=0.1)
        elif technique == 'swap' and len(aug_text) > 1:
            aug_text = random_swap(aug_text, n=1)
        
        augmented_texts.append(aug_text)
    
    return augmented_texts

print("    ✓ Data augmentation functions ready")



>>> Step 2: Implementing Data Augmentation Functions...
    ✓ Data augmentation functions ready


In [8]:

# ============================================================================
# Step 3: Apply Data Augmentation for Imbalanced Classes
# ============================================================================

print("\n>>> Step 3: Applying Data Augmentation for Imbalanced Classes...")

# Count current label distribution
label_counts_p34 = Counter([ex.label for ex in train_data.examples])
print(f"\nOriginal label distribution:")
for label, count in sorted(label_counts_p34.items()):
    print(f"  {label}: {count} samples ({count/len(train_data.examples)*100:.2f}%)")

# Augmentation targets (boost weaker topics more aggressively)
target_counts_p34 = {
    'ABBR': 900,   # heavy boost (~13x) to improve weakest class
    'DESC': 930,   # keep strong class unchanged
    'ENTY': 1300,  # moderate boost (~1.3x)
    'HUM': 1200,   # boost (~1.23x)
    'LOC': 800,    # modest boost (~1.2x)
    'NUM': 850     # modest boost (~1.2x)
}

# Create augmented examples
augmented_examples = list(train_data.examples)  # Start with all original examples

for label in label_counts_p34.keys():
    current_count = label_counts_p34[label]
    target_count = target_counts_p34[label]
    
    if current_count < target_count:
        label_examples = [ex for ex in train_data.examples if ex.label == label]
        num_augmentations_needed = target_count - current_count
        
        print(f"\n  Augmenting {label}: {current_count} -> {target_count} samples")
        print(f"    Generating {num_augmentations_needed} additional samples...")
        
        augmented_count = 0
        while augmented_count < num_augmentations_needed:
            original_ex = random.choice(label_examples)
            aug_texts = augment_text(original_ex.text, num_augmentations=1)
            
            for aug_text in aug_texts:
                if augmented_count >= num_augmentations_needed:
                    break
                
                new_ex = data.Example.fromlist([aug_text, label], 
                                               fields=[('text', TEXT), ('label', LABEL)])
                augmented_examples.append(new_ex)
                augmented_count += 1
        
        print(f"    ✓ Generated {augmented_count} augmented samples")

# Create augmented dataset with proper field structure
augmented_train_data = data.Dataset(augmented_examples, fields=[('text', TEXT), ('label', LABEL)])

# Verify augmented distribution
new_label_counts = Counter([ex.label for ex in augmented_examples])
print(f"\nAugmented label distribution:")
for label, count in sorted(new_label_counts.items()):
    print(f"  {label}: {count} samples ({count/len(augmented_examples)*100:.2f}%)")

print(f"\n  Total samples: {len(train_data.examples)} -> {len(augmented_examples)}")
print(f"  ✓ Data augmentation complete")


>>> Step 3: Applying Data Augmentation for Imbalanced Classes...

Original label distribution:
  ABBR: 69 samples (1.58%)
  DESC: 930 samples (21.32%)
  ENTY: 1000 samples (22.93%)
  HUM: 978 samples (22.42%)
  LOC: 668 samples (15.31%)
  NUM: 717 samples (16.44%)

  Augmenting ENTY: 1000 -> 1300 samples
    Generating 300 additional samples...
    ✓ Generated 300 augmented samples

  Augmenting HUM: 978 -> 1200 samples
    Generating 222 additional samples...
    ✓ Generated 222 augmented samples

  Augmenting LOC: 668 -> 800 samples
    Generating 132 additional samples...
    ✓ Generated 132 augmented samples

  Augmenting ABBR: 69 -> 900 samples
    Generating 831 additional samples...
    ✓ Generated 831 augmented samples

  Augmenting NUM: 717 -> 850 samples
    Generating 133 additional samples...
    ✓ Generated 133 augmented samples

Augmented label distribution:
  ABBR: 900 samples (15.05%)
  DESC: 930 samples (15.55%)
  ENTY: 1300 samples (21.74%)
  HUM: 1200 samples (20.07

In [9]:
# augmented_train_data.examples[0].label
[(ex.text, ex.label) for ex in augmented_train_data if ex.label not in ['ABBR','DESC','ENTY','HUM','LOC','NUM']]

[]

In [10]:
LABEL.vocab.stoi

defaultdict(None,
            {'ENTY': 0, 'HUM': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})

In [11]:
abbr_aug_ex = [ex for ex in augmented_train_data if ex.label == "ABBR"]
abbr_ex = [ex for ex in train_data if ex.label == "ABBR"]

count = 0
# for ex in abbr_aug_ex:
#     if ex not in abbr_ex:
#         print(ex.text)
#         print(ex.label)
#         count += 1

print(count)

0


In [12]:
# Count how many samples per label in the train set
label_counts_p34 = Counter([ex.label for ex in augmented_train_data.examples])
total_examples_p34 = len(augmented_train_data)

print("\nLabel distribution in training set:")
for label, count in sorted(label_counts_p34.items()):
    percentage = (count / total_examples_p34) * 100
    print(f"- {label}: {count} samples ({percentage:.2f}%)")


Label distribution in training set:
- ABBR: 900 samples (15.05%)
- DESC: 930 samples (15.55%)
- ENTY: 1300 samples (21.74%)
- HUM: 1200 samples (20.07%)
- LOC: 800 samples (13.38%)
- NUM: 850 samples (14.21%)


In [13]:
# ============================================================================
# PART 3.5: Sampling Strategies vs Data Augmentation
# ============================================================================

print("\n" + "="*80)
print("PART 3.4: SAMPLING STRATEGIES VS DATA AUGMENTATION")
print("="*80)
print("Comparing text augmentation, weighted sampling, and their combination")
print("across the Simple RNN baseline (Part 2 best config) and the RNN + BERT hybrid.")
print("="*80)

# Extended RNN Classifier with multiple aggregation methods
class RNN_Classifier_Aggregation(nn.Module):
    """
    RNN for topic classification with multiple aggregation strategies.
    Uses pretrained embeddings (learnable/updated during training).
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, 
                 n_layers=1, dropout=0.0, padding_idx=0, pretrained_embeddings=None,
                 aggregation='mean'):
        super(RNN_Classifier_Aggregation, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.aggregation = aggregation  # 'last', 'mean', 'max'
        
        # Embedding layer initialized with pretrained embeddings
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)
            self.embedding.weight.requires_grad = True
        
        # RNN layer
        self.rnn = nn.RNN(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=0.0  # No dropout in baseline
        )
                
        # Fully connected output layer
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, text, text_lengths):        
        # Embed the input
        embedded = self.embedding(text)  # [batch_size, seq_len, embedding_dim]
        
        # Get dimensions
        batch_size = embedded.size(0)
        seq_len = embedded.size(1)
        
        # Handle text_lengths
        text_lengths_flat = text_lengths.flatten().cpu().long()
        if len(text_lengths_flat) != batch_size:
            raise ValueError(
                f"text_lengths size mismatch: got {len(text_lengths_flat)} elements, "
                f"expected {batch_size}"
            )
        
        # Clamp lengths
        text_lengths_clamped = torch.clamp(text_lengths_flat, min=1, max=seq_len)
        
        text_lengths_clamped_device = text_lengths_clamped.to(text.device)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths_clamped, batch_first=True, enforce_sorted=False
        )
        
        # Pass through RNN
        packed_output, hidden = self.rnn(packed_embedded)
        
        # Aggregate word representations to sentence representation
        if self.aggregation == 'last':
            sentence_repr = hidden[-1]  # [batch_size, hidden_dim]
            
        elif self.aggregation == 'mean':
            output, output_lengths = nn.utils.rnn.pad_packed_sequence(
                packed_output, batch_first=True
            )
            # output: [batch_size, seq_len, hidden_dim]
            
            # Create mask for padding
            mask = torch.arange(seq_len, device=text.device).unsqueeze(0) < text_lengths_clamped_device.unsqueeze(1)
            mask = mask.unsqueeze(2).float()  # [batch_size, seq_len, 1]
            
            # Apply mask and compute mean
            masked_output = output * mask
            sum_output = masked_output.sum(dim=1)  # [batch_size, hidden_dim]
            sentence_repr = sum_output / text_lengths_clamped_device.unsqueeze(1).float()
            
        elif self.aggregation == 'max':
            # Max pooling over all outputs
            output, output_lengths = nn.utils.rnn.pad_packed_sequence(
                packed_output, batch_first=True
            )
            # output: [batch_size, seq_len, hidden_dim]
            
            mask = torch.arange(seq_len, device=text.device).unsqueeze(0) < text_lengths_clamped_device.unsqueeze(1)
            mask = mask.unsqueeze(2).float()  # [batch_size, seq_len, 1]
            
            masked_output = output * mask + (1 - mask) * float('-inf')
            sentence_repr, _ = torch.max(masked_output, dim=1)
        
        # Pass through fully connected layer
        output = self.fc(sentence_repr)  # [batch_size, output_dim]
        
        return output

print(">>> Simple RNN ready (mean aggregation baseline)")



PART 3.4: SAMPLING STRATEGIES VS DATA AUGMENTATION
Comparing text augmentation, weighted sampling, and their combination
across the Simple RNN baseline (Part 2 best config) and the RNN + BERT hybrid.
>>> Simple RNN ready (mean aggregation baseline)


In [None]:
from huggingface_hub import login
login("")

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
# ============================================================================
# Hybrid Model: RNN + BERT with Attention (Part 3.3 best model)
# ============================================================================

try:
    from transformers import BertModel, BertTokenizer
    BERT_AVAILABLE = True
    print(">>> Transformers library available")
except ImportError:
    BERT_AVAILABLE = False
    print(">>> Warning: transformers library not found. Install `transformers` to run BERT experiments.")


class RNNBertClassifier(nn.Module):
    """
    RNN with Pretrained BERT embeddings
    Uses BERT to get contextualized embeddings, then passes through BiLSTM with attention
    """
    def __init__(self, output_dim, hidden_dim=256, n_layers=2, dropout=0.5,
                 bert_model_name='distilbert-base-uncased', freeze_bert=False):
        super(RNNBertClassifier, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.freeze_bert = freeze_bert
        
        # Load pretrained BERT model and tokenizer
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        
        # Freeze BERT parameters if specified
        if freeze_bert:
            for param in self.bert_model.parameters():
                param.requires_grad = False
        
        # BERT output dimension (768 for bert-base-uncased)
        bert_output_dim = self.bert_model.config.hidden_size
        
        # Bidirectional LSTM to process BERT embeddings
        self.bilstm = nn.LSTM(
            bert_output_dim,
            hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if n_layers > 1 else 0
        )
        
        # Additive Attention Mechanism (Bahdanau-style)
        self.attention_linear1 = nn.Linear(hidden_dim * 2, hidden_dim)  # *2 for bidirectional
        self.attention_linear2 = nn.Linear(hidden_dim, 1)
        self.tanh = nn.Tanh()
        
        # Dropout layer
        self.dropout = nn.Dropout(dropout)
        
        # Fully connected output layer
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        
    def forward(self, text, text_lengths, text_vocab=None):        
        batch_size = text.size(0)
        seq_len = text.size(1)
        device = text.device
        
        # Convert token indices back to text strings using vocab
        text_list = []
        for i in range(batch_size):
            actual_len = text_lengths[i].item() if isinstance(text_lengths[i], torch.Tensor) else text_lengths[i]
            tokens = []
            for j in range(min(actual_len, seq_len)):
                token_idx = text[i, j].item()
                if text_vocab is not None and token_idx < len(text_vocab):
                    token = text_vocab[token_idx]
                    # Skip special tokens
                    if token not in ['<pad>', '<unk>', '<sos>', '<eos>']:
                        tokens.append(token)
                else:
                    # Fallback if vocab not provided
                    tokens.append(str(token_idx))
            # Join tokens to form sentence
            sentence = " ".join(tokens)
            text_list.append(sentence)
        
        # Tokenize with BERT
        encoded = self.bert_tokenizer(
            text_list,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(device)
        
        # Get BERT embeddings
        with torch.set_grad_enabled(not self.freeze_bert):
            bert_outputs = self.bert_model(**encoded)
            bert_embeddings = bert_outputs.last_hidden_state  # [batch_size, seq_len, 768]
        
        # Get actual sequence lengths from BERT tokenizer
        bert_lengths = encoded['attention_mask'].sum(dim=1).cpu()


        # Get BERT embeddings
        with torch.set_grad_enabled(not self.freeze_bert):
            bert_outputs = self.bert_model(**encoded)
            bert_embeddings = bert_outputs.last_hidden_state  # [batch_size, seq_len, 768]

        # Align lengths
        bert_lengths = encoded['attention_mask'].sum(dim=1)
        max_len = bert_embeddings.size(1)
        bert_lengths = bert_lengths.clamp(max=max_len).cpu()

        # Pack safely
        packed_bert = nn.utils.rnn.pack_padded_sequence(
        bert_embeddings, bert_lengths, batch_first=True, enforce_sorted=False
        )

        
        # Pack sequences for efficient RNN processing
        packed_bert = nn.utils.rnn.pack_padded_sequence(
            bert_embeddings, bert_lengths, batch_first=True, enforce_sorted=False
        )
        
        # Pass through bidirectional LSTM
        packed_output, (hidden, cell) = self.bilstm(packed_bert)
        
        # Unpack sequences
        bilstm_output, output_lengths = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True
        )
        # bilstm_output: [batch_size, seq_len, hidden_dim * 2]
        
        # Apply Attention Mechanism
        attention_scores = self.attention_linear1(bilstm_output)  # [batch_size, seq_len, hidden_dim]
        attention_scores = self.tanh(attention_scores)
        attention_scores = self.attention_linear2(attention_scores).squeeze(2)  # [batch_size, seq_len]
        
        # Mask padding positions
        batch_size_attn, seq_len_attn = bilstm_output.size(0), bilstm_output.size(1)
        mask = torch.arange(seq_len_attn, device=device).unsqueeze(0) < bert_lengths.unsqueeze(1).to(device)
        attention_scores = attention_scores.masked_fill(~mask, float('-inf'))
        
        # Apply softmax to get attention weights
        attention_weights = torch.softmax(attention_scores, dim=1).unsqueeze(2)  # [batch_size, seq_len, 1]
        
        # Compute weighted sum
        context_vector = torch.sum(attention_weights * bilstm_output, dim=1)  # [batch_size, hidden_dim * 2]
        
        # Apply dropout
        context_vector = self.dropout(context_vector)
        
        # Pass through fully connected layer
        output = self.fc(context_vector)  # [batch_size, output_dim]
        
        return output

print(">>> RNNBertClassifier ready (Part 3.3 best model)")

>>> Transformers library available
>>> RNNBertClassifier ready (Part 3.3 best model)


In [16]:
# =========================================================================
# Helper: Topic-wise evaluation
# =========================================================================

def evaluate_per_topic_p35(model, iterator, device, text_vocab=None):
    """Evaluate accuracy per topic on the provided iterator."""
    model.eval()
    topic_correct = defaultdict(int)
    topic_total = defaultdict(int)
    idx_to_label = LABEL.vocab.itos

    with torch.no_grad():
        for batch in iterator:
            text, text_lengths, labels = process_batch(batch, debug=False)
            if text_vocab is not None:
                predictions = model(text, text_lengths, text_vocab=text_vocab)
            else:
                predictions = model(text, text_lengths)
            preds = torch.argmax(predictions, dim=1)
            for pred, label in zip(preds.cpu().numpy(), labels.cpu().numpy()):
                topic_name = idx_to_label[label]
                topic_total[topic_name] += 1
                if pred == label:
                    topic_correct[topic_name] += 1

    topic_metrics = {}
    for topic in sorted(topic_total.keys()):
        total = topic_total[topic]
        correct = topic_correct[topic]
        accuracy = correct / total if total > 0 else 0.0
        topic_metrics[topic] = {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
        }
    return topic_metrics

In [17]:
# ============================================================================
# Dataset Variants & Utilities for Experiments
# ============================================================================

def describe_dataset(name, dataset):
    counts = Counter(ex.label for ex in dataset.examples)
    total = len(dataset.examples)
    print(f"  - {name}: {total} samples")
    for label, count in sorted(counts.items()):
        print(f"      {label}: {count} ({count/total*100:.2f}%)")
    return counts


# Topic-wise accuracy from latest weighted-sampler evaluation (used to boost weak classes)
P35_TOPIC_ACCURACY = {
    "ABBR": 0.7778,
    "DESC": 0.9855,
    "ENTY": 0.7128,
    "HUM": 0.8769,
    "LOC": 0.8889,
    "NUM": 0.8584,
}
# Convert to difficulty scores (higher when accuracy is lower)
P35_TOPIC_DIFFICULTY = {label: max(0.0, 1.0 - acc) for label, acc in P35_TOPIC_ACCURACY.items()}
# Global multiplier for difficulty adjustment; tweak to emphasise weak topics more/less
P35_DIFFICULTY_SCALE = 2.0


def create_weighted_dataset(source_dataset, target_size=None, seed=SEED, difficulty_scale=P35_DIFFICULTY_SCALE):
    """Mimic WeightedRandomSampler by sampling examples according to class weights, with extra boosts for weak topics."""
    rng = random.Random(seed)
    counts = Counter(ex.label for ex in source_dataset.examples)
    total = sum(counts.values())
    base_class_weights = {label: total / count for label, count in counts.items()}

    class_boosts = {
        label: 1.0 + difficulty_scale * P35_TOPIC_DIFFICULTY.get(label, 0.0)
        for label in counts.keys()
    }

    weights = [base_class_weights[ex.label] * class_boosts.get(ex.label, 1.0) for ex in source_dataset.examples]
    sample_size = target_size or len(source_dataset.examples)

    sampled_examples = rng.choices(source_dataset.examples, weights=weights, k=sample_size)
    fields = [('text', TEXT), ('label', LABEL)]
    return data.Dataset(sampled_examples, fields=fields)


print("\n>>> Preparing dataset variants for Part 3.4 experiments...")
base_counts = describe_dataset("Original train", train_data)
aug_counts = describe_dataset("Augmented train", augmented_train_data)

weighted_train_data = create_weighted_dataset(train_data)
weighted_counts = describe_dataset("Weighted-sampled train", weighted_train_data)

augmented_weighted_train_data = create_weighted_dataset(augmented_train_data)
aug_weighted_counts = describe_dataset("Augmented + weighted train", augmented_weighted_train_data)

p35_datasets = {
    "original": train_data,
    "augmented": augmented_train_data,
    "weighted": weighted_train_data,
    "augmented_weighted": augmented_weighted_train_data,
}

criterion = nn.CrossEntropyLoss()

p35_results = {
    "simple_rnn_baseline": {},
    "rnn_bert": {}
}

print("\n>>> Dataset variants ready. Criterion initialised for upcoming runs.")



>>> Preparing dataset variants for Part 3.4 experiments...
  - Original train: 4362 samples
      ABBR: 69 (1.58%)
      DESC: 930 (21.32%)
      ENTY: 1000 (22.93%)
      HUM: 978 (22.42%)
      LOC: 668 (15.31%)
      NUM: 717 (16.44%)
  - Augmented train: 5980 samples
      ABBR: 900 (15.05%)
      DESC: 930 (15.55%)
      ENTY: 1300 (21.74%)
      HUM: 1200 (20.07%)
      LOC: 800 (13.38%)
      NUM: 850 (14.21%)
  - Weighted-sampled train: 4362 samples
      ABBR: 849 (19.46%)
      DESC: 581 (13.32%)
      ENTY: 862 (19.76%)
      HUM: 704 (16.14%)
      LOC: 670 (15.36%)
      NUM: 696 (15.96%)
  - Augmented + weighted train: 5980 samples
      ABBR: 1057 (17.68%)
      DESC: 813 (13.60%)
      ENTY: 1213 (20.28%)
      HUM: 922 (15.42%)
      LOC: 944 (15.79%)
      NUM: 1031 (17.24%)

>>> Dataset variants ready. Criterion initialised for upcoming runs.


In [18]:
# ============================================================================
# Simple RNN (mean pooling) experiment runner
# ============================================================================

def reset_random_seeds(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def build_iterator(dataset, batch_size, shuffle):
    return data.BucketIterator(
        dataset,
        batch_size=batch_size,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        shuffle=shuffle,
        device=device
    )

RNN_BASE_HYPERPARM = hyperparam_prep()
RNN_BASE_HYPERPARM['HIDDEN_DIM'] *= 2
RNN_BASE_HYPERPARM['N_LAYERS'] = 1
RNN_BASE_HYPERPARM['SAVE_MODEL'] = True


def run_simple_rnn_experiment(dataset_key, description, save_suffix):
    if dataset_key not in p35_datasets:
        raise ValueError(f"Unknown dataset key: {dataset_key}")

    reset_random_seeds(SEED)
    train_dataset = p35_datasets[dataset_key]

    print("\n" + "-"*80)
    print(f"Running Simple RNN (mean pooling) experiment: {description}")
    print("-"*80)

    train_iter = build_iterator(train_dataset, RNN_BASE_HYPERPARM['BATCH_SIZE'], shuffle=True)
    val_iter = build_iterator(validation_data, RNN_BASE_HYPERPARM['BATCH_SIZE'], shuffle=False)
    test_iter = build_iterator(test_data, RNN_BASE_HYPERPARM['BATCH_SIZE'], shuffle=False)

    model = RNN_Classifier_Aggregation(
        vocab_size=len(TEXT.vocab),
        embedding_dim=embedding_dim,
        hidden_dim=RNN_BASE_HYPERPARM['HIDDEN_DIM'],
        output_dim=num_classes,
        n_layers=RNN_BASE_HYPERPARM['N_LAYERS'],
        dropout=RNN_BASE_HYPERPARM['DROPOUT'],
        padding_idx=TEXT.vocab.stoi[TEXT.pad_token],
        pretrained_embeddings=pretrained_embeddings,
        aggregation=RNN_BASE_HYPERPARM['AGGREGATOR'],
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    
    # Select optimizer with best learning rate
    if RNN_BASE_HYPERPARM['OPTIMIZER'] == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=RNN_BASE_HYPERPARM['LEARNING_RATE'],
                                        weight_decay=RNN_BASE_HYPERPARM['L2_LAMBDA'])
    elif RNN_BASE_HYPERPARM['OPTIMIZER'] == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=RNN_BASE_HYPERPARM['LEARNING_RATE'], momentum=0.9,
                                        weight_decay=RNN_BASE_HYPERPARM['L2_LAMBDA'])
    elif RNN_BASE_HYPERPARM['OPTIMIZER'] == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=RNN_BASE_HYPERPARM['LEARNING_RATE'],
                                        weight_decay=RNN_BASE_HYPERPARM['L2_LAMBDA'])
    elif RNN_BASE_HYPERPARM['OPTIMIZER'] == 'Adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=RNN_BASE_HYPERPARM['LEARNING_RATE'],
                                        weight_decay=RNN_BASE_HYPERPARM['L2_LAMBDA'])

    model, history = train_model_with_history(
        model,
        train_iter,
        val_iter,
        optimizer,
        criterion,
        RNN_BASE_HYPERPARM['N_EPOCHS'],
        device,
        num_classes,
        RNN_BASE_HYPERPARM['L1_LAMBDA'],
        patience=RNN_BASE_HYPERPARM['PATIENCE'],
        model_name=f"Simple RNN ({description})",
    )

    test_loss, test_acc, test_f1, test_auc = evaluate_model(
        model,
        test_iter,
        criterion,
        device,
        f"Simple RNN ({description})",
        num_classes,
    )

    topic_metrics = evaluate_per_topic_p35(model, test_iter, device)

    model_path = f"weights/part35_simple_rnn_{save_suffix}.pt"
    if RNN_BASE_HYPERPARM['SAVE_MODEL']:
        torch.save(model.state_dict(), model_path)

    p35_results["simple_rnn_baseline"][save_suffix] = {
        "description": description,
        "dataset_key": dataset_key,
        "history": history,
        "test_metrics": {
            "loss": test_loss,
            "accuracy": test_acc,
            "f1": test_f1,
            "auc": test_auc,
        },
        "topic_metrics": topic_metrics,
        "model_path": model_path if RNN_BASE_HYPERPARM['SAVE_MODEL'] else None,
        "model": model,
    }

    return {
        "history": history,
        "topic_metrics": topic_metrics,
        "test_metrics": {
            "loss": test_loss,
            "accuracy": test_acc,
            "f1": test_f1,
            "auc": test_auc,
        },
    }

In [19]:
# ============================================================================
# Run Simple RNN + Attention experiments
# ============================================================================

print("\n>>> Executing Simple RNN experiments...")

rnn_results_text_aug = run_simple_rnn_experiment(
    dataset_key="augmented",
    description="Text Augmentation",
    save_suffix="text_aug",
)

rnn_results_weighted = run_simple_rnn_experiment(
    dataset_key="weighted",
    description="Weighted Sampling",
    save_suffix="weighted_sampler",
)

rnn_results_aug_weighted = run_simple_rnn_experiment(
    dataset_key="augmented_weighted",
    description="Augmentation + Weighted Sampling",
    save_suffix="text_aug_weighted",
)

print("\n>>> Simple RNN experiments queued. Run the cells to execute training if needed.")



>>> Executing Simple RNN experiments...

--------------------------------------------------------------------------------
Running Simple RNN (mean pooling) experiment: Text Augmentation
--------------------------------------------------------------------------------

>>> Training Simple RNN (Text Augmentation)
    Parameters: 2,850,446
    Max epochs: 100, Patience: 10
Epoch: 01/100 | Time: 0m 1s
	Train Loss: 1.5556 | Train Acc: 37.42%
	Val Loss: 1.2733 | Val Acc: 44.04% | Val F1: 0.3522 | Val AUC: 0.8234
Epoch: 02/100 | Time: 0m 0s
	Train Loss: 1.0554 | Train Acc: 59.82%
	Val Loss: 1.0492 | Val Acc: 56.88% | Val F1: 0.5445 | Val AUC: 0.8734
Epoch: 03/100 | Time: 0m 0s
	Train Loss: 0.7579 | Train Acc: 76.40%
	Val Loss: 0.8201 | Val Acc: 71.19% | Val F1: 0.7267 | Val AUC: 0.9248
Epoch: 04/100 | Time: 0m 0s
	Train Loss: 0.5398 | Train Acc: 85.50%
	Val Loss: 0.7163 | Val Acc: 76.24% | Val F1: 0.7732 | Val AUC: 0.9406
Epoch: 05/100 | Time: 0m 0s
	Train Loss: 0.4751 | Train Acc: 87.76%
	Va

In [20]:
# ============================================================================
# Topic-wise accuracy summary for Simple RNN experiments
# ============================================================================

def display_topic_metrics(title, metrics_dict):
    print(f"\n{title}")
    print("-" * len(title))
    header = f"{'Topic':<10} {'Accuracy %':<12} {'Correct':<10} {'Total':<10}"
    print(header)
    print("-" * len(header))
    for topic in sorted(metrics_dict.keys()):
        stats = metrics_dict[topic]
        acc_pct = stats['accuracy'] * 100
        print(f"{topic:<10} {acc_pct:<12.2f} {stats['correct']:<10} {stats['total']:<10}")
    print("-" * len(header))
    
    total_cor = sum([metrics_dict[topic]['correct'] for topic in metrics_dict])
    total_sam = sum([metrics_dict[topic]['total'] for topic in metrics_dict])
    total_acc = total_cor / total_sam 
    print(f"{'Topic':<10} {total_acc:<12.2f} {total_cor:<10} {total_sam:<10}")

print("\n>>> Topic-wise accuracy for Simple RNN variants")
for run_key, info in p35_results["simple_rnn_baseline"].items():
    topic_metrics = info.get("topic_metrics")
    if not topic_metrics:
        continue
    title = f"Simple RNN ({info['description']})"
    display_topic_metrics(title, topic_metrics)


>>> Topic-wise accuracy for Simple RNN variants

Simple RNN (Text Augmentation)
------------------------------
Topic      Accuracy %   Correct    Total     
---------------------------------------------
ABBR       77.78        7          9         
DESC       86.23        119        138       
ENTY       67.02        63         94        
HUM        87.69        57         65        
LOC        87.65        71         81        
NUM        86.73        98         113       
---------------------------------------------
Topic      0.83         415        500       

Simple RNN (Weighted Sampling)
------------------------------
Topic      Accuracy %   Correct    Total     
---------------------------------------------
ABBR       77.78        7          9         
DESC       96.38        133        138       
ENTY       71.28        67         94        
HUM        80.00        52         65        
LOC        90.12        73         81        
NUM        85.84        97         113     

In [21]:
# ============================================================================
# RNN + BERT experiment runner
# ============================================================================
BERT_HYPERPARAM = hyperparam_prep()
BERT_HYPERPARAM['LEARNING_RATE'] = 2e-5
BERT_HYPERPARAM['OTHER_LR'] = BERT_HYPERPARAM['LEARNING_RATE'] * 10
BERT_HYPERPARAM['N_EPOCHS'] //= 2
BERT_HYPERPARAM['MODEL_NAME'] = 'bert-base-uncased'
BERT_HYPERPARAM['FREEZE'] = False
BERT_HYPERPARAM['SAVE_MODEL'] = True


def run_rnn_bert_experiment(dataset_key, description, save_suffix, TEXT, freeze_bert=BERT_HYPERPARAM['FREEZE']):
    if not BERT_AVAILABLE:
        raise RuntimeError("Transformers library is unavailable; cannot run RNN+BERT experiments.")
    if dataset_key not in p35_datasets:
        raise ValueError(f"Unknown dataset key: {dataset_key}")

    reset_random_seeds(SEED)
    train_dataset = p35_datasets[dataset_key]

    print("\n" + "-"*80)
    print(f"Running RNN + BERT experiment: {description}")
    print("-"*80)

    train_iter = build_iterator(train_dataset, BERT_HYPERPARAM['BATCH_SIZE'], shuffle=True)
    val_iter = build_iterator(validation_data, BERT_HYPERPARAM['BATCH_SIZE'], shuffle=False)
    test_iter = build_iterator(test_data, BERT_HYPERPARAM['BATCH_SIZE'], shuffle=False)

    model = RNNBertClassifier(
        output_dim=num_classes,
        hidden_dim=BERT_HYPERPARAM['HIDDEN_DIM'],
        n_layers=BERT_HYPERPARAM['N_LAYERS'],
        dropout=BERT_HYPERPARAM['DROPOUT'],
        bert_model_name=BERT_HYPERPARAM['MODEL_NAME'],
        freeze_bert=freeze_bert,
    ).to(device)

    bert_params = [p for p in model.bert_model.parameters() if p.requires_grad]
    other_params = [p for n, p in model.named_parameters() if 'bert_model' not in n]

    optimizer_grouped_parameters = []
    if bert_params:
        optimizer_grouped_parameters.append({'params': bert_params, 'lr': BERT_HYPERPARAM['LEARNING_RATE']})
    if other_params:
        optimizer_grouped_parameters.append({'params': other_params, 'lr': BERT_HYPERPARAM['OTHER_LR']})

    # Select optimizer with best learning rate
    if BERT_HYPERPARAM['OPTIMIZER'] == 'Adam':
        optimizer = optim.Adam(optimizer_grouped_parameters, weight_decay=BERT_HYPERPARAM['L2_LAMBDA'])
    elif BERT_HYPERPARAM['OPTIMIZER'] == 'SGD':
        optimizer = optim.SGD(optimizer_grouped_parameters, momentum=0.9, weight_decay=BERT_HYPERPARAM['L2_LAMBDA'])
    elif BERT_HYPERPARAM['OPTIMIZER'] == 'RMSprop':
        optimizer = optim.RMSprop(optimizer_grouped_parameters, weight_decay=BERT_HYPERPARAM['L2_LAMBDA'])
    elif BERT_HYPERPARAM['OPTIMIZER'] == 'Adagrad':
        optimizer = optim.Adagrad(optimizer_grouped_parameters, weight_decay=BERT_HYPERPARAM['L2_LAMBDA'])

    model, history = train_model_with_history_bert(
        model,
        train_iter,
        val_iter,
        optimizer,
        criterion,
        BERT_HYPERPARAM['N_EPOCHS'],
        device,
        num_classes,
        BERT_HYPERPARAM['L1_LAMBDA'],
        patience=BERT_HYPERPARAM['PATIENCE'],
        model_name=f"RNN+BERT ({description})",
        text_vocab=TEXT.vocab.itos,
    )

    test_loss, test_acc, test_f1, test_auc = evaluate_model_bert(
        model,
        test_iter,
        criterion,
        device,
        f"RNN+BERT ({description})",
        num_classes,
        text_vocab=TEXT.vocab.itos,
    )

    topic_metrics = evaluate_per_topic_p35(model, test_iter, device, TEXT.vocab.itos)

    model_path = f"weights/part35_rnn_bert_{save_suffix}.pt"
    if BERT_HYPERPARAM['SAVE_MODEL']:
        torch.save(model.state_dict(), model_path)

    p35_results["rnn_bert"][save_suffix] = {
        "description": description,
        "dataset_key": dataset_key,
        "history": history,
        "test_metrics": {
            "loss": test_loss,
            "accuracy": test_acc,
            "f1": test_f1,
            "auc": test_auc,
        },
        "topic_metrics": topic_metrics,
        "model_path": model_path if BERT_HYPERPARAM['SAVE_MODEL'] else None,
        "freeze_bert": freeze_bert,
        "model": model,
    }

    return {
        "history": history,
        "topic_metrics": topic_metrics,
        "test_metrics": {
            "loss": test_loss,
            "accuracy": test_acc,
            "f1": test_f1,
            "auc": test_auc,
        },
    }


In [22]:
# ============================================================================
# Run RNN + BERT experiments
# ============================================================================

if BERT_AVAILABLE:
    print("\n>>> Executing RNN + BERT experiments...")

    bert_results_text_aug = run_rnn_bert_experiment(
        dataset_key="augmented",
        description="Text Augmentation",
        save_suffix="text_aug",
        TEXT=TEXT,
    )

    bert_results_weighted = run_rnn_bert_experiment(
        dataset_key="weighted",
        description="Weighted Sampling",
        save_suffix="weighted_sampler",
        TEXT=TEXT,
    )

    bert_results_aug_weighted = run_rnn_bert_experiment(
        dataset_key="augmented_weighted",
        description="Augmentation + Weighted Sampling",
        save_suffix="text_aug_weighted",
        TEXT=TEXT,
    )

    print("\n>>> RNN + BERT experiments queued. Run the cells to execute training if needed.")
else:
    print("\n>>> Skipping RNN + BERT experiments (transformers library unavailable).")



>>> Executing RNN + BERT experiments...

--------------------------------------------------------------------------------
Running RNN + BERT experiment: Text Augmentation
--------------------------------------------------------------------------------

>>> Training RNN+BERT (Text Augmentation)
    Parameters: 113,295,111
    Max epochs: 50, Patience: 10
Epoch: 01/50 | Time: 0m 19s
	Train Loss: 0.7184 | Train Acc: 74.15%
	Val Loss: 0.4994 | Val Acc: 86.33% | Val F1: 0.8724 | Val AUC: 0.9749
Epoch: 02/50 | Time: 0m 20s
	Train Loss: 0.1680 | Train Acc: 95.30%
	Val Loss: 0.4237 | Val Acc: 89.08% | Val F1: 0.8952 | Val AUC: 0.9839
Epoch: 03/50 | Time: 0m 19s
	Train Loss: 0.1031 | Train Acc: 97.44%
	Val Loss: 0.7240 | Val Acc: 86.88% | Val F1: 0.8804 | Val AUC: 0.9655
Epoch: 04/50 | Time: 0m 19s
	Train Loss: 0.0555 | Train Acc: 98.75%
	Val Loss: 0.7127 | Val Acc: 88.53% | Val F1: 0.8943 | Val AUC: 0.9772
Epoch: 05/50 | Time: 0m 19s
	Train Loss: 0.0393 | Train Acc: 99.11%
	Val Loss: 0.6487 |

In [23]:
# =========================================================================
# Topic-wise accuracy summary for RNN + BERT experiments
# =========================================================================

if p35_results["rnn_bert"]:
    print("\n>>> Topic-wise accuracy for RNN + BERT variants")
    for key, info in p35_results["rnn_bert"].items():
        topic_metrics = info.get("topic_metrics")
        if not topic_metrics:
            continue
        title = f"RNN + BERT ({info['description']})"
        display_topic_metrics(title, topic_metrics)




>>> Topic-wise accuracy for RNN + BERT variants

RNN + BERT (Text Augmentation)
------------------------------
Topic      Accuracy %   Correct    Total     
---------------------------------------------
ABBR       88.89        8          9         
DESC       34.06        47         138       
ENTY       86.17        81         94        
HUM        95.38        62         65        
LOC        98.77        80         81        
NUM        94.69        107        113       
---------------------------------------------
Topic      0.77         385        500       

RNN + BERT (Weighted Sampling)
------------------------------
Topic      Accuracy %   Correct    Total     
---------------------------------------------
ABBR       77.78        7          9         
DESC       83.33        115        138       
ENTY       86.17        81         94        
HUM        92.31        60         65        
LOC        97.53        79         81        
NUM        93.81        106        113     

In [24]:
# ============================================================================
# PART 3.4 SUMMARY
# ============================================================================

print("\n" + "="*80)
print("PART 3.4: TEXT AUGMENTATION VS WEIGHTED SAMPLING")
print("="*80)

print("\n>>> Dataset Variants Used:")
for key in ["original", "augmented", "weighted", "augmented_weighted"]:
    dataset = p35_datasets[key]
    print(f"  - {key}: {len(dataset.examples)} samples")

print("\n>>> Simple RNN (mean pooling) Experiments:")
if p35_results["simple_rnn_baseline"]:
    for key, info in p35_results["simple_rnn_baseline"].items():
        metrics = info.get("test_metrics", {})
        print(f"  {info['description']} [{key}]")
        if metrics:
            print(f"    - Test Accuracy: {metrics.get('accuracy', 0)*100:.2f}%")
            print(f"    - Test F1: {metrics.get('f1', 0):.4f}")
            print(f"    - Test AUC: {metrics.get('auc', 0):.4f}")
            print(f"    - Test Loss: {metrics.get('loss', 0):.4f}")
        topic_metrics = info.get("topic_metrics")
        if topic_metrics:
            weakest = min(topic_metrics.items(), key=lambda kv: kv[1]['accuracy'])
            print(f"    - Weakest Topic: {weakest[0]} ({weakest[1]['accuracy']*100:.2f}%)")
        print(f"    - Model saved to: {info['model_path']}")
else:
    print("  - Pending (run the experiment cells above)")

print("\n>>> RNN + BERT Experiments:")
if p35_results["rnn_bert"]:
    for key, info in p35_results["rnn_bert"].items():
        metrics = info.get("test_metrics", {})
        print(f"  {info['description']} [{key}]")
        if metrics:
            print(f"    - Test Accuracy: {metrics.get('accuracy', 0)*100:.2f}%")
            print(f"    - Test F1: {metrics.get('f1', 0):.4f}")
            print(f"    - Test AUC: {metrics.get('auc', 0):.4f}")
            print(f"    - Test Loss: {metrics.get('loss', 0):.4f}")
        topic_metrics = info.get("topic_metrics")
        if topic_metrics:
            weakest = min(topic_metrics.items(), key=lambda kv: kv[1]['accuracy'])
            print(f"    - Weakest Topic: {weakest[0]} ({weakest[1]['accuracy']*100:.2f}%)")
        print(f"    - Model saved to: {info['model_path']}")
else:
    print("  - Pending (run the experiment cells above)")

print("\n" + "="*80)
print("PART 3.4 SETUP COMPLETE")
print("="*80)



PART 3.4: TEXT AUGMENTATION VS WEIGHTED SAMPLING

>>> Dataset Variants Used:
  - original: 4362 samples
  - augmented: 5980 samples
  - weighted: 4362 samples
  - augmented_weighted: 5980 samples

>>> Simple RNN (mean pooling) Experiments:
  Text Augmentation [text_aug]
    - Test Accuracy: 83.00%
    - Test F1: 0.8382
    - Test AUC: 0.9586
    - Test Loss: 0.6631
    - Weakest Topic: ENTY (67.02%)
    - Model saved to: weights/part35_simple_rnn_text_aug.pt
  Weighted Sampling [weighted_sampler]
    - Test Accuracy: 85.80%
    - Test F1: 0.8570
    - Test AUC: 0.9565
    - Test Loss: 0.5815
    - Weakest Topic: ENTY (71.28%)
    - Model saved to: weights/part35_simple_rnn_weighted_sampler.pt
  Augmentation + Weighted Sampling [text_aug_weighted]
    - Test Accuracy: 68.60%
    - Test F1: 0.7295
    - Test AUC: 0.9464
    - Test Loss: 0.7742
    - Weakest Topic: DESC (42.03%)
    - Model saved to: weights/part35_simple_rnn_text_aug_weighted.pt

>>> RNN + BERT Experiments:
  Text Augme