In [1]:
import torch
from torch import nn
from src.models import AffectModel, masked_mse_loss
from src.data import setup_dataloader

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Config
# Everything a reader might want to tune lives in this cell. MODEL_CONFIG is
# unpacked into AffectModel(**MODEL_CONFIG) and DATA_CONFIG into
# setup_dataloader(**DATA_CONFIG) in the setup cell below.
TRAIN_PATH = "data/TRAIN_RELEASE_3SEP2025/train_subtask1.csv"

# Single source of truth for the pretrained checkpoint. The tokenizer and the
# encoder are both derived from it so the two cannot silently drift apart
# when one of them is changed.
PRETRAINED_CHECKPOINT = "bert-base-uncased"
TOKENIZER_PATH = PRETRAINED_CHECKPOINT

MODEL_CONFIG = {
    # Encoder
    'model_path': PRETRAINED_CHECKPOINT,
    'n_groups': 4,
    'grouped_mode': 'attention',
    'pooling': 'mean',
    'freeze_backbone': True,
    'conv_kernel_size': 3,
    # LSTM
    'lstm_hidden': 256,
    'lstm_layers': 2,
    'bidirectional': True,
    # Head
    'constrain_arousal': True,
    # Shared
    'dropout': 0.3,
}

DATA_CONFIG = {
    'csv_path': TRAIN_PATH,
    'tokenizer_path': TOKENIZER_PATH,
    'max_text_length': 512,
    'batch_size': 4,
    'shuffle': True,
    'num_workers': 0,
}

In [None]:
# Setup
# Seed torch's global RNG before anything stochastic is constructed so that
# Restart & Run All reproduces the same numbers. The config enables shuffling
# and dropout, and the model's non-pretrained layers are randomly initialised.
# NOTE(review): this presumably also fixes the shuffle order, which holds only
# if setup_dataloader does not pass its own generator to the DataLoader -- confirm.
SEED = 42
torch.manual_seed(SEED)

train_loader, train_dataset = setup_dataloader(**DATA_CONFIG)
model = AffectModel(**MODEL_CONFIG)

# Count parameters once, up front, so the print statements stay readable.
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Dataset size: {len(train_dataset)} users")
print(f"Model parameters: {n_params:,}")
print(f"Trainable parameters: {n_trainable:,}")

Dataset size: 137 users
Model parameters: 119,850,246
Trainable parameters: 10,368,006


In [None]:
# Test forward pass
# Smoke-test the model on exactly one batch. The batch is fetched explicitly
# rather than with a `for ... break` loop: an empty loader then fails loudly
# here instead of leaving `batch` / `predictions` / `mask` undefined (or stale
# from an earlier run) for the cells below.
batch = next(iter(train_loader), None)
if batch is None:
    raise RuntimeError("train_loader yielded no batches; check DATA_CONFIG['csv_path']")

predictions = model(
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    seq_lengths=batch['seq_lengths'],
    seq_mask=batch['seq_attention_mask']
)

# Stack the two regression targets on a new trailing axis so they line up
# with the model output, whose last dimension is 2.
targets = torch.stack([batch['valences'], batch['arousals']], dim=-1)
# Boolean mask handed to the loss; presumably True marks real (non-padded)
# timesteps -- confirm against masked_mse_loss.
mask = batch['seq_attention_mask'].bool()
loss = masked_mse_loss(predictions, targets, mask)

print(f"Input shape: {batch['input_ids'].shape}")
print(f"Predictions: {predictions.shape}")
print(f"Targets: {targets.shape}")
print(f"Loss: {loss.item():.4f}")

Input shape: torch.Size([4, 8, 512])
Predictions: torch.Size([4, 8, 2])
Targets: torch.Size([4, 8, 2])
Loss: 1.3923


In [6]:
# Sanity-check value ranges on the batch from the forward-pass cell.
# Statistics are restricted to the positions selected by `mask`: the batch is
# padded (only part of the timesteps are valid), so min/max over the full
# tensors would mix in padding values and model outputs at padded slots.
valid_valences = batch['valences'][mask]
valid_arousals = batch['arousals'][mask]
valid_predictions = predictions[mask]  # boolean-indexes the leading dims, keeps the last

n_valid = int(mask.sum())
n_total = mask.numel()

# min()/max() raise on an empty selection, so skip the ranges if nothing is valid.
if n_valid:
    print(f"Valence range: [{valid_valences.min():.2f}, {valid_valences.max():.2f}]")
    print(f"Arousal range: [{valid_arousals.min():.2f}, {valid_arousals.max():.2f}]")
    print(f"Prediction range: [{valid_predictions.min():.2f}, {valid_predictions.max():.2f}]")
print(f"Valid timesteps: {n_valid} / {n_total} ({100*n_valid/n_total:.1f}%)")

Valence range: [-2.00, 2.00]
Arousal range: [0.00, 2.00]
Prediction range: [-0.08, 1.02]
Valid timesteps: 22 / 32 (68.8%)
