# Training Notebook - Narrative Similarity (SemEval4)

This notebook trains the `AspectSupervisedEncoder` model for **narrative similarity** using triplet learning with aspect attention. The model combines a pre-trained encoder (DistilBERT) with specialized projections to capture narrative similarity between texts.
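
To make the triplet objective concrete, here is a minimal sketch of a triplet margin loss on plain Python lists. It assumes cosine distance and a margin of 0.3 (one of the values in the hyperparameter grid in section 6). The actual distance function and loss composition inside `Trainer` are not shown in this notebook, so `cosine_distance` and `triplet_margin_loss` are illustrative names, not the project's API.

```python
import math

def cosine_distance(u, v):
    """Return 1 - cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)

def triplet_margin_loss(anchor, positive, negative, margin=0.3):
    """Hinge loss: zero once the negative is at least `margin` farther away than the positive."""
    d_pos = cosine_distance(anchor, positive)
    d_neg = cosine_distance(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)

# Toy 2-D embeddings, made up for illustration (assumption: not real model outputs)
anchor = [1.0, 0.0]
positive = [1.0, 0.0]   # same direction as the anchor: distance 0
negative = [0.0, 1.0]   # orthogonal to the anchor: distance 1

easy = triplet_margin_loss(anchor, positive, negative)  # 0 - 1 + 0.3 < 0, clipped to 0
hard = triplet_margin_loss(anchor, negative, positive)  # roles swapped: 1 - 0 + 0.3
```

A larger margin demands a wider gap between positive and negative distances before a triplet stops contributing to the loss, which is why `triplet_margin` is worth searching over.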

## 1. Path Setup

The parent directory (`..`) is added to `sys.path` so that modules from the `src` package can be imported.

In [1]:
import sys
sys.path.insert(0, '..')

## 2. Import Libraries

All necessary dependencies are imported.

In [None]:
import torch
from torch.utils.data import DataLoader
from src.trainer import Trainer
from src.datasets.dataset import AspectTripletDatasetTrain, AspectTripletDatasetDev, load_aspect_triplets, load_eval_triplets
from src.models.encoder import AspectSupervisedEncoder
import itertools
import io
import re
import copy

## 3. Load Data

The preprocessed data is loaded:
- **Training data**: triplets (anchor, positive, negative) with augmentation based on narrative aspects
- **Eval data**: triplets for model validation

The files are in JSONL format and contain examples already processed for triplet learning.

In [None]:
qwen_train = load_aspect_triplets("../data/processed/qwen_augment_from_partial_dev_large_with_aspects.jsonl")
llama_train = load_aspect_triplets("../data/processed/llama_augment_from_partial_dev_large_with_aspects.jsonl")
past_training_data = load_aspect_triplets("../data/processed/train_from_dev_w_aspect_aug_triplets.jsonl")
eval_data = load_eval_triplets("../data/processed/eval_from_dev_triplets.jsonl")

training_data = qwen_train + past_training_data + llama_train

len(training_data), len(eval_data)

(418, 60)
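
The loaders `load_aspect_triplets` and `load_eval_triplets` live in `src` and are not shown here. As a rough sketch of what reading a JSONL triplet file involves, the snippet below parses one JSON object per line from an in-memory stand-in for a file. The field names `anchor`, `positive`, and `negative` are assumptions for illustration; the real files may use different keys and carry extra aspect fields.

```python
import io
import json

def read_jsonl(handle):
    """Parse one JSON object per non-empty line of a text stream."""
    return [json.loads(line) for line in handle if line.strip()]

# In-memory stand-in for a file on disk; field names are assumed, not taken from the real data.
sample = io.StringIO(
    '{"anchor": "A", "positive": "B", "negative": "C"}\n'
    '\n'
    '{"anchor": "D", "positive": "E", "negative": "F"}\n'
)

records = read_jsonl(sample)
triplets = [(r["anchor"], r["positive"], r["negative"]) for r in records]
```

With a real file, the same function could be called on the handle from `open(path, encoding="utf-8")`.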

## 4. Create DataLoaders

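The datasets and DataLoaders for training and validation are created. The training loader shuffles its data and uses a batch size of 32; the validation loader keeps the original order and uses a batch size of 8.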
- **AspectTripletDatasetTrain**: dataset for training with triplets (anchor, positive, negative)
- **AspectTripletDatasetDev**: dataset for validation

In [None]:
train_dataset = AspectTripletDatasetTrain(training_data)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

eval_dataset = AspectTripletDatasetDev(eval_data)
eval_dataloader = DataLoader(eval_dataset, batch_size=8, shuffle=False)

## 5. Model Configuration

The **AspectSupervisedEncoder** is initialized with the following parameters:

| Parameter | Value | Description |
|-----------|-------|-------------|
| `model_name` | distilbert-base-uncased | Pre-trained transformer model as base encoder |
| `projection_dim` | 512 | Dimension of the projected embedding space |
| `aspect_dim` | 128 | Dimension of aspect representations |
| `num_heads` | 4 | Number of attention heads for multi-head attention |
| `dropout` | 0.3 | Dropout probability for regularization |
| `freeze_encoder` | True | Freeze pre-trained encoder weights |
| `use_lora` | False | LoRA (Low-Rank Adaptation) disabled |

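**Note**: `use_lora` and the LoRA parameters (`lora_r=8`, `lora_alpha=16`) are not passed explicitly in the cell below, so they are assumed to come from the constructor defaults. The LoRA parameters have no effect while `use_lora=False`.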

In [None]:
# Model
original_model = AspectSupervisedEncoder(
    model_name="distilbert-base-uncased", # You can change to other models
    projection_dim=512,
    aspect_dim=128,
    num_heads=4,
    dropout=0.3,
    freeze_encoder=True,
)

# Count params
total = sum(p.numel() for p in original_model.parameters())
trainable = sum(p.numel() for p in original_model.parameters() if p.requires_grad)
print(f"Parameters: {trainable:,} trainable / {total:,} total")

## 6. Model Training

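A grid search over the training hyperparameters is configured and run. For each combination, a fresh copy of the untrained model is trained and evaluated, and the configuration with the highest concatenated (story + aspects) accuracy is kept.
Training uses the GPU (CUDA) if one is available and falls back to the CPU otherwise.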
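
The code cell that follows expands a hyperparameter grid into individual runs. The number of runs is the product of the list lengths, which is worth checking before starting because every run trains a model from scratch. A minimal sketch, reusing the grid values from that cell:

```python
import itertools

param_grid = {
    "epochs": [5, 10],
    "lr": [5e-4],
    "weight_decay": [0.01, 0.001],
    "triplet_margin": [0.3, 0.4, 0.5],
    "cross_margin": [0.2],
    "gamma": [0.3],
    "beta": [1.0],
}

# Cartesian product of all value lists, one dict per run
keys = list(param_grid)
combinations = [
    dict(zip(keys, values))
    for values in itertools.product(*param_grid.values())
]

print(len(combinations))  # 2 * 1 * 2 * 3 * 1 * 1 * 1 = 12
print(combinations[0])    # first combination: the first value of every list
```

Adding a second value to any single-valued entry (for example `lr`) doubles the number of runs.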

In [None]:
# Define the hyperparameter grid
param_grid = {
    'epochs': [5, 10],
    'lr': [5e-4],
    'weight_decay': [0.01, 0.001],
    'triplet_margin': [0.3, 0.4, 0.5],
    'cross_margin': [0.2],
    'gamma': [0.3],       # Aspect loss weight
    'beta': [1.0]         # Cross loss weight
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def run_grid_search(original_model, param_grid, train_loader, eval_loader):
    keys, values = zip(*param_grid.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
    
    best_concat_acc = 0.0
    best_params = {}
    best_metrics = {}
    results = []

    print(f"Starting Grid Search with {len(combinations)} combinations...\n")

    for i, params in enumerate(combinations):
        print(f"--- Run {i+1}/{len(combinations)}: {params} ---")
        
        # Deep-copy the untrained model so every run starts from the same initial weights
        current_model = copy.deepcopy(original_model)
        current_model.to(DEVICE)
        
        # Initialize Trainer
        trainer = Trainer(
            model=current_model, 
            device=DEVICE,
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            triplet_margin=params['triplet_margin'],
            cross_margin=params['cross_margin'],
            gamma=params['gamma'],
            beta=params['beta']
        )

        # Capture Standard Output
        captured_output = io.StringIO()
        original_stdout = sys.stdout
        
        try:
            sys.stdout = captured_output
            trainer.fit(train_loader, eval_loader, epochs=params['epochs'])
        finally:
            sys.stdout = original_stdout

        output_str = captured_output.getvalue()
        
        # Parse the captured output for all metrics.
        # Maps each metric name to the regex that matches its printed line.
        metric_patterns = {
            "Story": r"Story Embedding Only:\s*(\d+\.\d+)",
            "Theme": r"Theme Aspect Only:\s*(\d+\.\d+)",
            "Action": r"Action Aspect Only:\s*(\d+\.\d+)",
            "Outcome": r"Outcome Aspect Only:\s*(\d+\.\d+)",
            "Concatenated": r"Concatenated \(Story \+ Aspects\):\s*(\d+\.\d+)"
        }
        
        current_metrics = {}
        for key, pattern in metric_patterns.items():
            # Note: re.search returns the first occurrence in the captured output
            match = re.search(pattern, output_str)
            if match:
                current_metrics[key] = float(match.group(1))
            else:
                current_metrics[key] = 0.0  # Default if parsing fails
        
        # Check if we successfully parsed the main metric
        if current_metrics["Concatenated"] > 0:
            print(f"Run {i+1} Result: {current_metrics}")
            
            # Save params and full metrics dict to results
            results.append((params, current_metrics))
            
            # Update Best Model (optimizing for Concatenated accuracy)
            if current_metrics["Concatenated"] > best_concat_acc:
                best_concat_acc = current_metrics["Concatenated"]
                best_params = params
                best_metrics = current_metrics
        else:
            print(f"Run {i+1} Warning: Could not parse metrics. Check output format.")

    return best_params, best_metrics, results

# Execute Grid Search
# 'original_model' must be the variable name of your UNTRAINED model instance
best_params, best_metrics, all_results = run_grid_search(original_model, param_grid, train_dataloader, eval_dataloader)

print("\n" + "="*50)
print("TOP PERFORMING CONFIGURATION")
print("="*50)
print(f"Best Parameters: {best_params}")
print(f"Best Metrics:    {best_metrics}")
print("="*50)
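
The grid search above recovers metrics by scraping the text that training prints. The sketch below shows the same idea with `contextlib.redirect_stdout`, which restores `sys.stdout` automatically, and `re.findall`, which exposes every occurrence of a metric. The `fake_fit` function and the numbers it prints are made up to stand in for `trainer.fit`; whether the real trainer prints its report once or once per epoch is not known from this notebook.

```python
import contextlib
import io
import re

def fake_fit():
    """Hypothetical stand-in for trainer.fit: prints one report per epoch."""
    print("Epoch 1")
    print("Concatenated (Story + Aspects): 0.6167")
    print("Epoch 2")
    print("Concatenated (Story + Aspects): 0.7000")

buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):  # stdout is restored even if fake_fit raises
    fake_fit()

pattern = r"Concatenated \(Story \+ Aspects\):\s*(\d+\.\d+)"
matches = re.findall(pattern, buffer.getvalue())

first_value = float(matches[0]) if matches else None   # what re.search would pick up
last_value = float(matches[-1]) if matches else None   # value from the final report
```

If the report is printed after every epoch, the first and last values differ, so it matters which one the search optimizes. Returning the metrics from `fit` as a dictionary, if the `Trainer` API allows it, would avoid text parsing altogether.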