In [None]:
import torch
from torch import nn
from src.models import AffectModel, masked_mse_loss
from src.data import setup_dataloader

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Config
# Every tunable path / hyperparameter for this notebook lives in this cell.
# MODEL_CONFIG and DATA_CONFIG are unpacked as keyword arguments by the cell
# that builds the model and the dataloader.
TRAIN_PATH = "data/TRAIN_RELEASE_3SEP2025/train_subtask1.csv"

# Single source of truth for the pretrained checkpoint: the tokenizer and the
# encoder are both derived from it, so changing the checkpoint in one place
# cannot leave the other pointing at a different vocabulary.
ENCODER_NAME = "bert-base-uncased"
TOKENIZER_PATH = ENCODER_NAME

MODEL_CONFIG = {
    # Encoder
    'model_path': ENCODER_NAME,
    # Set Attention
    'n_seeds': 4,
    'n_inducing': 32,
    'n_heads': 8,
    # LSTM
    'lstm_hidden': 256,
    'lstm_layers': 2,
    'bidirectional': True,
    # Head
    'constrain_output': True,
    # Shared
    'dropout': 0.3,
    # Debug
    'verbose': True,
}

DATA_CONFIG = {
    'csv_path': TRAIN_PATH,
    'tokenizer_path': TOKENIZER_PATH,
    'max_text_length': 512,
    # Tiny batch: this notebook only smoke-tests a single forward pass.
    'batch_size': 2,
    'shuffle': True,
    'num_workers': 0,
}

In [None]:
# Setup
# Seed torch's default generator before building anything, so a fresh
# "Restart & Run All" reproduces the same numbers in the cells below.
# NOTE(review): assumes AffectModel's weight init and the shuffled loader draw
# from torch's default generator -- confirm in src.models / src.data.
SEED = 42
torch.manual_seed(SEED)

train_loader, train_dataset = setup_dataloader(**DATA_CONFIG)
model = AffectModel(**MODEL_CONFIG)

# Count parameters once instead of inside the f-strings.
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n{'='*50}")
print(f"Dataset size: {len(train_dataset)} users")
print(f"Model parameters: {n_params:,}")
print(f"Trainable parameters: {n_trainable:,}")
print(f"{'='*50}")

In [None]:
# Test forward pass
# Smoke test: push exactly one batch through the model in eval mode and
# compute the masked MSE against the valence/arousal targets.
model.eval()

# Only the first batch is needed; fail loudly if the loader is empty instead
# of hitting a NameError on `predictions` further down.
batch = next(iter(train_loader), None)
if batch is None:
    raise RuntimeError("train_loader yielded no batches; cannot run the forward-pass check")

with torch.no_grad():
    predictions = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        seq_lengths=batch['seq_lengths'],
        seq_mask=batch['seq_attention_mask'],
    )

    # Valence and arousal become the two channels of the last target axis.
    targets = torch.stack([batch['valences'], batch['arousals']], dim=-1)
    mask = batch['seq_attention_mask'].bool()
    loss = masked_mse_loss(predictions, targets, mask)

print(f"\n{'='*50}")
print("RESULTS")
print(f"{'='*50}")
print(f"Input shape:     {batch['input_ids'].shape}")
print(f"Predictions:     {predictions.shape}")
print(f"Targets:         {targets.shape}")
print(f"Loss:            {loss.item():.4f}")

In [None]:
# Inspect ranges
# Ranges are reported over valid timesteps only (selected with `mask`), so
# padding entries cannot distort the min/max.
# NOTE(review): assumes batch['valences'], batch['arousals'] and
# predictions[..., k] all share mask's shape -- confirm against the collate
# function in src.data.


def _fmt_range(values):
    """Return '[min, max]' (2 decimals) for a tensor, or 'n/a' if it is empty."""
    if values.numel() == 0:
        return "n/a"
    return f"[{values.min():.2f}, {values.max():.2f}]"


print("Target Ranges:")
print(f"  Valence: {_fmt_range(batch['valences'][mask])}")
print(f"  Arousal: {_fmt_range(batch['arousals'][mask])}")

print("\nPrediction Ranges:")
print(f"  Valence: {_fmt_range(predictions[..., 0][mask])}")
print(f"  Arousal: {_fmt_range(predictions[..., 1][mask])}")

# Plain Python ints keep the arithmetic and formatting free of tensor reprs.
n_valid = int(mask.sum())
n_total = mask.numel()
pct_valid = 100 * n_valid / n_total if n_total else 0.0

print("\nMask Stats:")
print(f"  Valid timesteps: {n_valid} / {n_total} ({pct_valid:.1f}%)")
print(f"  Seq lengths: {batch['seq_lengths'].tolist()}")

In [None]:
# Inspect one sample
# Prints a small table comparing predicted and true valence/arousal for the
# first few documents of one sequence in the current batch.
sample_idx = 0
seq_len = batch['seq_lengths'][sample_idx].item()

header_lines = [
    f"Sample {sample_idx} (user: {batch['user_ids'][sample_idx]})",
    f"  Sequence length: {seq_len} documents",
    "\n  Predictions vs Targets (first 5 docs):",
    f"  {'Doc':<5} {'Pred V':>8} {'True V':>8} {'Pred A':>8} {'True A':>8}",
    f"  {'-'*41}",
]
print("\n".join(header_lines))

# Show at most five rows, fewer when the sequence is shorter.
for i in range(min(5, seq_len)):
    # Channel 0 is read as valence and channel 1 as arousal.
    pred_v, pred_a = (predictions[sample_idx, i, k].item() for k in (0, 1))
    true_v, true_a = (targets[sample_idx, i, k].item() for k in (0, 1))
    print(f"  {i:<5} {pred_v:>8.3f} {true_v:>8.3f} {pred_a:>8.3f} {true_a:>8.3f}")