In [None]:
# NOTE(review): versions are unpinned, so a fresh install may resolve to
# different releases than the ones this notebook was last run with.
# torch is imported below but not installed here; presumably it is provided
# by the environment. TODO confirm.
%pip install scikit-learn transformers datasets pandas

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split

import numpy as np 
import pandas as pd
import math

from Dependencies.Early_Stop import EarlyStopping
from Dependencies.AdditionalFunctions import topK_one_hot, smooth_multi_hot
from Dependencies.MovieDataset import MovieGenresDataset, MovieOverviewDataset, collate_fn, PAD_VALUE, EPOCH_NUMBER
from Dependencies.RNN_model_class import RNN

  from .autonotebook import tqdm as notebook_tqdm


### Initialize Model and Device

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda:0


### Initialize the Dataset

In [3]:
# Wrapper object around the movie/genre data (project class imported from
# Dependencies.MovieDataset).
mgd_ds = MovieGenresDataset()
# Underlying dataset; the embedding cell indexes it with ["overview"] to
# iterate over the overview texts.
movie_genre_ds = mgd_ds.getDs()
# Class information handed to MovieOverviewDataset together with the
# overview embeddings; presumably the genre labels per movie. TODO confirm
# against MovieGenresDataset.get_classes.
movie_id_loc = mgd_ds.get_classes()


In [4]:
def create_smoothed_list(target_batch, class_lst, num_classes=19, pad_value=None):
    """
    Append one smoothed multi-hot target per sample to ``class_lst``.

    Each row of ``target_batch`` holds genre indices padded with
    ``pad_value``. Padding is stripped, the remaining indices are one-hot
    encoded with ``topK_one_hot`` and then smoothed with
    ``smooth_multi_hot``. A uniform distribution (``1 / num_classes`` for
    every class) is used as a fallback when a row has no valid labels,
    contains an index outside ``[0, num_classes)``, or when smoothing
    produces NaN/Inf.

    Args:
        target_batch: Tensor of shape (batch_size, max_num_genres) with
            genre indices; rows are iterated one sample at a time.
        class_lst: List the results are appended to (mutated in place).
        num_classes: Total number of genre classes. Defaults to 19, the
            value that was previously hard-coded.
        pad_value: Marker used for padding entries. When ``None`` the
            module-level ``PAD_VALUE`` is used.

    Returns:
        ``class_lst`` itself, with one target tensor appended per sample.
    """
    if pad_value is None:
        # Resolve lazily so the project-wide constant stays the default.
        pad_value = PAD_VALUE

    def uniform_target():
        # Single definition of the fallback used by every error path below.
        return torch.ones(num_classes, dtype=torch.float32) / num_classes

    for idx, target in enumerate(target_batch):
        # Drop padding entries before any further processing.
        valid_targets = target[target != pad_value]

        # All-padding row: there is nothing to encode.
        if valid_targets.numel() == 0:
            print(f"⚠️ Sample {idx} in batch has no valid labels (all padding), using uniform distribution")
            class_lst.append(uniform_target())
            continue

        # topK_one_hot receives a plain Python list, so go through CPU/numpy.
        valid_targets_np = valid_targets.cpu().numpy()

        # Reject indices that cannot address one of the num_classes slots.
        if (valid_targets_np < 0).any() or (valid_targets_np >= num_classes).any():
            print(f"⚠️ Sample {idx} has invalid genre indices: {valid_targets_np}")
            class_lst.append(uniform_target())
            continue

        one_hot_target = topK_one_hot(valid_targets_np.tolist(), num_classes)

        smoothed_target = smooth_multi_hot(
            torch.tensor(one_hot_target, dtype=torch.float32),
            num_valid_labels=len(valid_targets),
        )

        # Last line of defence: never hand NaN/Inf targets to the loss.
        if torch.isnan(smoothed_target).any() or torch.isinf(smoothed_target).any():
            print(f"⚠️ Sample {idx} produced NaN/Inf after smoothing, using uniform distribution")
            smoothed_target = uniform_target()

        class_lst.append(smoothed_target)

    return class_lst




def diagnose_batch(movie_ovw_batch, y_hat, classes, batch_idx):
    """
    Print a diagnostic report for one batch (inputs, logits, targets, loss).

    Args:
        movie_ovw_batch: Input tensor fed to the model for this batch.
        y_hat: Model output (logits) for this batch.
        classes: Target tensor the loss is computed against.
        batch_idx: Batch index shown in the report header.

    Returns:
        None. Everything is written to stdout.
    """
    banner = '=' * 60

    def report_stats(tensor, include_inf=True):
        # Shape / moments / range / NaN (and optionally Inf) for one tensor.
        print(f"   Shape: {tensor.shape}")
        print(f"   Mean: {tensor.mean().item():.6f}")
        print(f"   Std: {tensor.std().item():.6f}")
        print(f"   Min: {tensor.min().item():.6f}")
        print(f"   Max: {tensor.max().item():.6f}")
        print(f"   Contains NaN: {torch.isnan(tensor).any().item()}")
        if include_inf:
            print(f"   Contains Inf: {torch.isinf(tensor).any().item()}")

    print(f"\n{banner}")
    print(f"DIAGNOSTICS FOR BATCH {batch_idx}")
    print(banner)

    # 1. Inputs
    print("\n1. INPUT EMBEDDINGS:")
    report_stats(movie_ovw_batch)
    # A (near-)constant input carries no signal for the model.
    if movie_ovw_batch.std().item() < 1e-6:
        print("   ⚠️ WARNING: Input has very low variance (nearly constant)")

    # 2. Logits
    print("\n2. MODEL OUTPUT (LOGITS):")
    report_stats(y_hat)
    # Count logits far enough from zero to saturate the sigmoid.
    n_very_negative = (y_hat < -50).sum().item()
    n_very_positive = (y_hat > 50).sum().item()
    if n_very_negative > 0:
        print(f"   ⚠️ WARNING: {n_very_negative} logits < -50 (will cause underflow)")
    if n_very_positive > 0:
        print(f"   ⚠️ WARNING: {n_very_positive} logits > 50 (will cause overflow)")

    # 3. Targets (no Inf line for targets; the full tensor is dumped instead)
    print("\n3. TARGETS:")
    report_stats(classes, include_inf=False)
    print(classes)

    # 4. Re-create pieces of the BCE-with-logits computation by hand
    print("\n4. LOSS COMPUTATION CHECK:")
    with torch.no_grad():
        probs = torch.sigmoid(y_hat)
        print(f"   Sigmoid output range: [{probs.min().item():.6f}, {probs.max().item():.6f}]")

        # Exact 0/1 probabilities mean the sigmoid saturated in floating point.
        n_saturated_low = (probs == 0).sum().item()
        n_saturated_high = (probs == 1).sum().item()
        if n_saturated_low > 0:
            print(f"   ⚠️ WARNING: {n_saturated_low} sigmoid outputs exactly 0 (underflow)")
        if n_saturated_high > 0:
            print(f"   ⚠️ WARNING: {n_saturated_high} sigmoid outputs exactly 1 (overflow)")

        # Individual terms only; they are inspected, never summed into a loss.
        clamped = torch.clamp(y_hat, min=0)
        negative_term = (1 - classes) * y_hat
        log_term = torch.log(torch.exp(-clamped) + torch.exp(y_hat - clamped))

        print(f"   Loss part 1 (negative term) range: [{negative_term.min():.6f}, {negative_term.max():.6f}]")
        print(f"   Loss part 3 (log term) contains NaN: {torch.isnan(log_term).any().item()}")

    print(banner + "\n")


### **Training Functions**

In [5]:
def epoch_train(rnn, optimizer, dev, train_loader, val_loader, batch_size, ecpoh_num):
    """
    Train ``rnn`` for one epoch over ``train_loader`` and log per-batch stats.

    For every batch the padded genre targets are converted to smoothed
    multi-hot vectors via ``create_smoothed_list``, the model is run with the
    sequence lengths, and one optimizer step is taken on a
    ``BCEWithLogitsLoss`` with gradient-norm clipping at 1.0. The epoch is
    aborted as soon as the targets or the loss contain NaN/Inf.

    Args:
        rnn: Model to train. Must expose ``rnnL1.weight_hh`` and
            ``rnnL2.weight_hh`` parameters; their gradients are tracked.
        optimizer: Optimizer that updates ``rnn``'s parameters.
        dev: Device the batch tensors are moved to.
        train_loader: Iterable yielding ``(inputs, targets, seq_lengths)``
            batches; must support ``len()`` (number of batches).
        val_loader: Not used by this function; kept for caller compatibility.
        batch_size: Not used by this function; kept for caller compatibility.
        ecpoh_num: Zero-based epoch index, used in the progress output and in
            the CSV file name.

    Returns:
        bool: ``True`` when every batch was processed, ``False`` when the
        epoch was aborted because of NaN/Inf targets or loss.

    Side effects:
        Writes ``model_track_epoch_{ecpoh_num}.csv`` (squared gradient norms
        and loss per batch) if at least one batch completed.
    """
    rnn.train()

    # The loss module is stateless, so build it once instead of per batch.
    loss_func = nn.BCEWithLogitsLoss()

    loss_arr = []
    l1_grad_sq = []
    l2_grad_sq = []
    continue_run = True
    num_batches = len(train_loader)

    # Iterate over every batch the loader yields. The previous bound
    # `len(train_loader) - len(train_loader) % batch_size` mixed up "number of
    # batches" with "number of samples" and silently skipped trailing batches.
    for i, (movie_ovw_batch, target_batch, seq_lengths) in enumerate(train_loader):
        movie_ovw_batch = movie_ovw_batch.to(dev)
        target_batch = target_batch.to(dev)
        seq_lengths = seq_lengths.to(dev)

        # Sanity check on the raw targets before building smoothed labels.
        if torch.isnan(target_batch).any() or torch.isinf(target_batch).any():
            print(f"\n{'!'*60}")
            print(f"NaN/Inf VALUE DETECTED AT TARGET BATCH {i}")
            print(f"{'!'*60}")
            print(f"Target batch: {target_batch}")
            continue_run = False
            break

        classes = torch.stack(create_smoothed_list(target_batch, [])).to(dev)

        # Call the module itself (not .forward) so registered hooks still run.
        y_hat = rnn(movie_ovw_batch, seq_lengths)
        loss = loss_func(y_hat, classes)

        # Abort with a full diagnostic dump on a non-finite loss.
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"\n{'!'*60}")
            print(f"NaN/Inf LOSS DETECTED AT BATCH {i}")
            print(f"{'!'*60}")
            print(f"Sequence lengths: {seq_lengths}")
            print(f"Input shape: {movie_ovw_batch.shape}")
            print(f"Logits range: [{y_hat.min():.4f}, {y_hat.max():.4f}]")
            print(f"Target batch: {target_batch}")
            diagnose_batch(movie_ovw_batch, y_hat, classes, i)
            continue_run = False
            break

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)

        # Squared gradient norms of the recurrent weights, measured after
        # clipping. NaN is recorded when a gradient is missing so the three
        # tracking lists always have the same length for the DataFrame below.
        grad_l1 = rnn.rnnL1.weight_hh.grad
        grad_l2 = rnn.rnnL2.weight_hh.grad
        grad_norm_l1 = grad_l1.norm().item() ** 2 if grad_l1 is not None else float("nan")
        grad_norm_l2 = grad_l2.norm().item() ** 2 if grad_l2 is not None else float("nan")
        l1_grad_sq.append(grad_norm_l1)
        l2_grad_sq.append(grad_norm_l2)

        optimizer.step()
        loss_arr.append(loss.item())

        # Progress line every 10 batches.
        if (i + 1) % 10 == 0:
            avg_seq_len = seq_lengths.float().mean().item()
            grad_info = f"L2 grad²={grad_norm_l2:.6f}" if grad_l2 is not None else "No grads"
            print(f"Epoch {ecpoh_num+1} | Batch {i+1}/{num_batches} | "
                  f"Loss={loss.item():.6f} | Avg seq len={avg_seq_len:.1f} | {grad_info}")

    print("\nEpoch finished.")

    # Persist the tracking data, also for an aborted epoch that got some
    # batches through.
    if len(loss_arr) > 0:
        df = pd.DataFrame({
            'l1_gradient_sq': l1_grad_sq,
            'l2_gradient_sq': l2_grad_sq,
            'loss_arr': loss_arr
        })
        df.to_csv(f"model_track_epoch_{ecpoh_num}.csv", index=False, header=True)

    return continue_run


### **Dataset and DataLoader Setup**

In [6]:
# Cache of pre-computed overview embeddings.
# Building the cache is slow, so it only happens when 'overview_embs.pt' is
# missing; every later run just loads the file. Embeddings are moved to the
# CPU before being stored so they do not occupy GPU memory.

import os

embedding_file = "overview_embs.pt"

if os.path.exists(embedding_file):
    print(f"Loading embeddings from {embedding_file}")
else:
    print("Embedding file not found. Creating embeddings...")
    # Throw-away model instance, used only for its tokenize_input method.
    temp_model = RNN().to(device)
    overviews = movie_genre_ds["overview"]
    overview_ds = []
    for i, overview in enumerate(overviews):
        # Compute on `device`, then keep only a CPU copy in the list.
        tokenized_ovw = temp_model.tokenize_input(overview, device=device).cpu()
        overview_ds.append(tokenized_ovw)
        if (i + 1) % 100 == 0:
            print(f"Processed {i+1}/{len(overviews)} overviews")

    torch.save(overview_ds, embedding_file)
    print(f"Saved embeddings to {embedding_file}")
    del temp_model  # release the temporary model

# Both branches end by reading the cache file from disk.
tokenized_overview_tensors = torch.load(embedding_file)

Loading embeddings from overview_embs.pt


### **Train RNN**

In [7]:
if __name__ == "__main__":
    # Training driver: build the model, split the data, train for
    # EPOCH_NUMBER epochs, and save the weights if no epoch was aborted.
    BATCH_SIZE = 8

    my_rnn = RNN().to(device)
    optimizer = optim.Adam(params=my_rnn.parameters(), lr=5e-6, weight_decay=1e-2)

    full_dataset = MovieOverviewDataset(tokenized_overview_tensors, movie_id_loc)

    # 80% train; the remainder is halved into validation and test, with test
    # absorbing the odd sample so the three sizes always sum to the total.
    train_size = int(0.8 * len(full_dataset))
    val_size = int((len(full_dataset) - train_size) / 2)
    test_size = len(full_dataset) - train_size - val_size

    print(f"Dataset sizes - Train: {train_size}, Val: {val_size}, Test: {test_size}")
    # The unpack order (train, test, val) matches the order of the sizes list.
    train_ds, test_ds, val_ds = random_split(full_dataset, [train_size, test_size, val_size])

    # collate_fn yields (inputs, targets, seq_lengths), which epoch_train unpacks.
    train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True,
                             num_workers=0, collate_fn=collate_fn)
    test_loader = DataLoader(dataset=test_ds, batch_size=1, shuffle=True,
                            num_workers=0, collate_fn=collate_fn)
    val_loader = DataLoader(dataset=val_ds, batch_size=BATCH_SIZE, shuffle=True,
                           num_workers=0, collate_fn=collate_fn)

    print("Starting training...")
    # Bind the flag before the loop: with EPOCH_NUMBER == 0 the loop body never
    # runs and the check after the loop would otherwise raise NameError.
    continue_run = True
    for epoch_iter in range(EPOCH_NUMBER):
        print(f"\n{'='*70}\nEPOCH {epoch_iter + 1}\n{'='*70}")
        continue_run = epoch_train(my_rnn, optimizer=optimizer, dev=device,
                                   train_loader=train_loader, val_loader=val_loader,
                                   batch_size=BATCH_SIZE, ecpoh_num=epoch_iter)

        # epoch_train returns False when it hit NaN/Inf; stop training then.
        if not continue_run:
            print(f"Training stopped at epoch {epoch_iter + 1}")
            break

    # Only persist weights from a run that was not aborted.
    if continue_run:
        print("\nTraining complete. Saving model...")
        torch.save(my_rnn.state_dict(), "model_parameters.pt")


Dataset sizes - Train: 7984, Val: 998, Test: 998
Starting training...

EPOCH 1
Epoch 1 | Batch 10/998 | Loss=0.394375 | Avg seq len=72.5 | L2 grad²=0.012403
Epoch 1 | Batch 20/998 | Loss=0.223173 | Avg seq len=77.1 | L2 grad²=0.013041
Epoch 1 | Batch 30/998 | Loss=0.184688 | Avg seq len=84.8 | L2 grad²=0.013573
Epoch 1 | Batch 40/998 | Loss=0.195886 | Avg seq len=97.4 | L2 grad²=0.014322
Epoch 1 | Batch 50/998 | Loss=0.199011 | Avg seq len=64.2 | L2 grad²=0.014436
Epoch 1 | Batch 60/998 | Loss=0.198356 | Avg seq len=86.0 | L2 grad²=0.018499
Epoch 1 | Batch 70/998 | Loss=0.185009 | Avg seq len=52.9 | L2 grad²=0.015499
Epoch 1 | Batch 80/998 | Loss=0.220273 | Avg seq len=68.9 | L2 grad²=0.014181
Epoch 1 | Batch 90/998 | Loss=0.192878 | Avg seq len=58.2 | L2 grad²=0.019997
⚠️ Sample 2 in batch has no valid labels (all padding), using uniform distribution
Epoch 1 | Batch 100/998 | Loss=0.206879 | Avg seq len=57.5 | L2 grad²=0.016415
Epoch 1 | Batch 110/998 | Loss=0.218741 | Avg seq len=72.