# BAST Multitask Training Notebook

This notebook recreates the training loop from `train_multitask.py` with all configuration set directly here.


In [None]:
# Import necessary modules
from network.BAST import BAST_Variant, AngularLossWithCartesianCoordinate, MixWithCartesianCoordinate , MSELossWithPolarCoordinate
from data_loading import SpectrogramDataset
import argparse  # Not used but keeping for compatibility
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from datetime import datetime

In [None]:
# Configuration - Set all parameters directly here.
# Every constant in this cell is read by later cells (data split, model
# construction, loss weighting, checkpoint naming), so edit values here only.

# Dataset configuration
CSV_PATH = 'tensor_metadata.csv'  # Replace with your CSV file path; passed to SpectrogramDataset

# Model architecture parameters (forwarded as keyword arguments to BAST_Variant)
SPECTROGRAM_SIZE = [64, 18]  # [frequency_bins, time_frames] - cropped to 100ms
PATCH_SIZE = 16
PATCH_OVERLAP = 10
NUM_OUTPUT = 3  # Output dimension for localization: [x, y, elevation_deg]
EMBEDDING_DIM = 1024
TRANSFORMER_DEPTH = 3
TRANSFORMER_HEADS = 16
TRANSFORMER_MLP_DIM = 1024
TRANSFORMER_DIM_HEAD = 64
INPUT_CHANNEL = 1  # Single channel per ear
DROPOUT = 0.2
EMB_DROPOUT = 0.2
TRANSFORMER_POOL = 'conv'

# Training hyperparameters
EPOCHS = 150
BATCH_SIZE = 1500  # Used for the train, validation and test loaders alike
LEARNING_RATE = 0.0001
TEST_SPLIT = 0.1  # Test split ratio (10% of total dataset)
VAL_SPLIT = 0.2   # Validation split ratio (20% of remaining 90% after test split)
SEED = 42         # Seeds torch/NumPy and the split generators (SEED and SEED+1)

# Loss weights (multiplied into the per-sample loss terms before summing)
LOC_WEIGHT = 1.0      # Localization (azimuth x,y) loss
ELEV_WEIGHT = 0.1     # Elevation loss
CLS_WEIGHT = 1.0      # Classification loss weight
OBJ_WEIGHT = 1.0      # Objectness loss weight

# Model configuration
BACKBONE = 'vanilla'  # Transformer variant: 'vanilla'
BINAURAL_INTEGRATION = 'SUB'  # 'SUB', 'ADD', 'CONCAT'
SHARE_WEIGHTS = False  # Share weights between left/right branches
MAX_SOURCES = 4        # number of detection slots
LOSS_TYPE = 'MIX'      # Localization loss: 'MSE', 'AD', 'MIX', 'MSE_POLAR'

# GPU configuration
GPU_LIST = [0] if torch.cuda.is_available() else []  # Use GPU 0 if available; empty list means CPU

# Directory configuration
MODEL_SAVE_DIR = './output/models/'  # Created on demand; holds the *_best.pkl and *_last.pkl checkpoints
MODEL_NAME = 'BAST'                  # Prefix of the checkpoint file name

# DataLoader configuration
NUM_WORKERS = 4  # Worker processes per DataLoader

In [None]:
# Helper function to get localization criterion
def get_localization_criterion(name: str):
    """Build the localization loss module selected by ``name``.

    Supported names are ``'MSE'``, ``'AD'``, ``'MIX'`` and ``'MSE_POLAR'``.

    Args:
        name: Identifier of the loss to construct.

    Returns:
        A freshly constructed loss object for the requested name.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    # Factories are wrapped in lambdas so each loss class is only looked up
    # (and instantiated) when its name is actually requested.
    factories = {
        'MSE': lambda: nn.MSELoss(),
        'AD': lambda: AngularLossWithCartesianCoordinate(),
        'MIX': lambda: MixWithCartesianCoordinate(),
        "MSE_POLAR": lambda: MSELossWithPolarCoordinate(),
    }
    build = factories.get(name)
    if build is None:
        raise ValueError('Unknown localization loss')
    return build()

# Set random seeds for reproducibility.
# Seeds the torch and NumPy global RNGs; Python's built-in `random` module is
# not seeded here.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Create output directory for checkpoints (no error if it already exists)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

In [None]:
# Read the CSV-described dataset and report its size and label vocabulary.
print(f"[{datetime.now()}] Loading dataset from {CSV_PATH} ...")
dataset = SpectrogramDataset(CSV_PATH)
num_classes = len(dataset.class_to_index)
print(f"[{datetime.now()}] Samples: {len(dataset)} | Classes: {num_classes}")

# Stage 1: hold out the test subset (TEST_SPLIT of everything).
test_size = int(len(dataset) * TEST_SPLIT)
remaining_size = len(dataset) - test_size
remaining_ds, test_ds = random_split(
    dataset,
    [remaining_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Stage 2: carve the validation subset (VAL_SPLIT) out of what is left.
# A separate generator seeded with SEED+1 keeps both splits repeatable.
val_size = int(remaining_size * VAL_SPLIT)
train_size = remaining_size - val_size
train_ds, val_ds = random_split(
    remaining_ds,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED+1),
)

# All three loaders share the same settings; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

print(f"Train set: {len(train_ds)} samples")
print(f"Validation set: {len(val_ds)} samples")
print(f"Test set: {len(test_ds)} samples")

In [None]:
# Instantiate the network from the configuration cell and place it on the
# compute device.
print(f"[{datetime.now()}] Building model ...")

# Collect the constructor arguments first so the whole architecture
# configuration is visible in one place.
_model_kwargs = {
    'image_size': SPECTROGRAM_SIZE,
    'patch_size': PATCH_SIZE,
    'patch_overlap': PATCH_OVERLAP,
    'num_coordinates_output': NUM_OUTPUT,
    'dim': EMBEDDING_DIM,
    'depth': TRANSFORMER_DEPTH,
    'heads': TRANSFORMER_HEADS,
    'mlp_dim': TRANSFORMER_MLP_DIM,
    'pool': TRANSFORMER_POOL,
    'channels': INPUT_CHANNEL,
    'dim_head': TRANSFORMER_DIM_HEAD,
    'dropout': DROPOUT,
    'emb_dropout': EMB_DROPOUT,
    'binaural_integration': BINAURAL_INTEGRATION,
    'share_params': SHARE_WEIGHTS,
    'transformer_variant': BACKBONE,
    'max_sources': MAX_SOURCES,
    'classify_sound': True,
    'num_classes_cls': num_classes,
}
net = BAST_Variant(**_model_kwargs)

# Pick the device, then move the model onto it. With CUDA and a non-empty
# GPU_LIST the model is wrapped in DataParallel before the move.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

if not (use_cuda and GPU_LIST):
    net = net.to(device)
else:
    net = nn.DataParallel(net, device_ids=GPU_LIST).to(device)

print(f"Model built successfully. Using device: {device}")
print(f"Model parameters: {sum(p.numel() for p in net.parameters())}")

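We now use a detection-style head that outputs per-source slots: (loc_out [B,K,3], obj_logit [B,K], cls_logit [B,K,C]), where the last axis of loc_out holds [x, y, elevation_deg] (NUM_OUTPUT = 3). Losses are computed after matching ground-truth (GT) sources to slots on the azimuth (x, y) components.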

In [None]:
# Loss functions. The three per-slot criteria use reduction='none' so the
# training code can average them itself after slot matching.
criterion_loc = get_localization_criterion(LOSS_TYPE)
criterion_elev = nn.MSELoss(reduction='none')           # per-slot elevation MSE (degrees)
criterion_obj = nn.BCEWithLogitsLoss(reduction='none')  # per-slot objectness
criterion_cls = nn.BCEWithLogitsLoss(reduction='none')  # multi-label per class per slot

# Optimizer over every model parameter, with weight decay disabled.
optimizer = torch.optim.AdamW(
    net.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.0,
)

# Checkpoint base name: <model>_<integration>_<loss>_DET_<SP|NSP>_<backbone>
_share_tag = 'SP' if SHARE_WEIGHTS else 'NSP'
model_save_name = f"{MODEL_NAME}_{BINAURAL_INTEGRATION}_{LOSS_TYPE}_DET_{_share_tag}_{BACKBONE}"

print(f"Model will be saved as: {model_save_name}")
print(f"Training for {EPOCHS} epochs with batch size {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Loss weights - Loc: {LOC_WEIGHT}, Elev: {ELEV_WEIGHT}, Cls: {CLS_WEIGHT}, Obj: {OBJ_WEIGHT}")

In [None]:
# Utilities for greedy matching (per sample)
def greedy_match(pred_xy, gt_xy):
    """Greedily pair predicted slots with ground-truth sources by L2 distance.

    Repeatedly picks the closest (prediction, ground-truth) pair among those
    still unassigned, until every prediction or every ground truth has been
    used. Each index therefore appears at most once in the result.

    Args:
        pred_xy: Float tensor of shape [K, D], one row per predicted slot.
        gt_xy: Float tensor of shape [N, D], one row per ground-truth source.

    Returns:
        Tuple ``(pred_idx, gt_idx)`` of 1-D ``torch.long`` tensors of length
        ``min(K, N)`` on ``pred_xy``'s device, where ``pred_idx[i]`` is matched
        with ``gt_idx[i]``. Both tensors are empty when ``K`` or ``N`` is 0.
    """
    dev = pred_xy.device
    K = pred_xy.size(0)
    N = gt_xy.size(0)
    if K == 0 or N == 0:
        # Always hand back long index tensors (never a list or a float
        # tensor) so callers can use .numel() and indexing uniformly.
        return (
            torch.empty((0,), dtype=torch.long, device=dev),
            torch.empty((0,), dtype=torch.long, device=dev),
        )

    # Matching is a discrete decision, so build the distance matrix without
    # an autograd graph. The result is a fresh tensor we can mask in place.
    d = torch.cdist(pred_xy.detach(), gt_xy.detach(), p=2)  # [K, N]

    matched_pred_idx = []
    matched_gt_idx = []
    for _ in range(min(K, N)):
        # argmin works on the flattened [K*N] view; recover (row, col).
        pi, gi = divmod(int(torch.argmin(d)), N)
        matched_pred_idx.append(pi)
        matched_gt_idx.append(gi)
        # Retire the chosen slot and source so neither is picked again.
        d[pi, :] = float('inf')
        d[:, gi] = float('inf')

    return (
        torch.tensor(matched_pred_idx, dtype=torch.long, device=dev),
        torch.tensor(matched_gt_idx, dtype=torch.long, device=dev),
    )

# Training loop
def _batch_to_device(batch):
    """Unpack one DataLoader batch and move every tensor to ``device``.

    Shared by the training and validation passes so that predictions and all
    targets (including the azimuth/elevation tensor) end up on the same device.

    Args:
        batch: The 4-tuple ``(specs, loc_xy, cls_idx, az_el_deg)`` yielded by
            the loaders.

    Returns:
        The same 4-tuple on ``device``; ``cls_idx`` has dimension 1 squeezed.
    """
    specs, loc_xy, cls_idx, az_el_deg = batch
    return (
        specs.to(device, non_blocking=True),
        loc_xy.to(device, non_blocking=True),                # [B, 2]
        cls_idx.to(device, non_blocking=True).squeeze(1),    # [B]
        az_el_deg.to(device, non_blocking=True),             # column 1 is read as elevation
    )


def _detection_loss(loc_out, obj_logit, cls_logit, loc_xy, cls_idx, az_el_deg):
    """Return the mean per-sample detection loss for one batch.

    For every sample, the single ground-truth source is matched to a slot on
    the (x, y) components. The objectness loss covers all slots, while the
    localization, elevation and classification losses cover matched slots only.
    The four terms are combined with LOC_WEIGHT, ELEV_WEIGHT, CLS_WEIGHT and
    OBJ_WEIGHT.

    Args:
        loc_out: Model output; last axis is read as [x, y, elevation].
        obj_logit: Per-slot objectness logits, indexed [sample, slot].
        cls_logit: Per-slot class logits, indexed [sample, slot, class].
        loc_xy: Ground-truth (x, y) per sample.
        cls_idx: Ground-truth class index per sample.
        az_el_deg: Ground-truth angles per sample; column 1 is the elevation.

    Returns:
        Scalar tensor: summed per-sample loss divided by the batch size.
    """
    batch_size = loc_out.size(0)
    n_cls = cls_logit.size(2)
    zero = torch.tensor(0., device=device)
    total = torch.tensor(0., device=device)

    # NOTE: per-sample Python loop kept for clarity and to support several
    # GT sources per sample later (stack them into gt_xy_b / gt_el_b).
    for b in range(batch_size):
        pred_xy_b = loc_out[b][..., :2]       # [K,2] -> azimuth as unit vector
        pred_el_b = loc_out[b][..., 2]        # [K]   -> elevation in degrees
        obj_b = obj_logit[b]                  # [K]
        cls_b = cls_logit[b]                  # [K,C]

        # For now, 1 GT source per sample from CSV
        gt_xy_b = loc_xy[b].unsqueeze(0)        # [1,2]
        gt_el_b = az_el_deg[b, 1].unsqueeze(0)  # [1]
        gt_cls_multi_hot = torch.zeros(n_cls, device=device)
        gt_cls_multi_hot[cls_idx[b]] = 1.0      # [C]

        # Greedy matching on azimuth vectors
        pred_idx, gt_idx = greedy_match(pred_xy_b, gt_xy_b)

        # Objectness targets: matched slots -> 1, others -> 0
        obj_target_b = torch.zeros_like(obj_b)
        obj_target_b[pred_idx] = 1.0
        loss_obj_b = criterion_obj(obj_b, obj_target_b).mean()

        # Localization + elevation + class on matched slots only
        if pred_idx.numel() > 0:
            loss_loc_b = criterion_loc(pred_xy_b[pred_idx], gt_xy_b[gt_idx])
            loss_el_b = criterion_elev(pred_el_b[pred_idx], gt_el_b[gt_idx]).mean()
            matched_cls = cls_b[pred_idx]       # [M,C]
            gt_cls_rep = gt_cls_multi_hot.unsqueeze(0).expand(matched_cls.size(0), -1)
            loss_cls_b = criterion_cls(matched_cls, gt_cls_rep).mean()
        else:
            loss_loc_b = loss_el_b = loss_cls_b = zero

        total = total + (
            LOC_WEIGHT * loss_loc_b
            + ELEV_WEIGHT * loss_el_b
            + CLS_WEIGHT * loss_cls_b
            + OBJ_WEIGHT * loss_obj_b
        )

    return total / max(1, batch_size)


print(f"[{datetime.now()}] Start training ...")

best_val = float('inf')
train_losses = []
val_losses = []

for epoch in range(EPOCHS):
    # Training phase
    net.train()
    running = 0.0
    n_train = 0

    for batch in train_loader:
        # Every target, including az_el_deg, must be on `device`: the
        # elevation target is compared against predictions living there.
        specs, loc_xy, cls_idx, az_el_deg = _batch_to_device(batch)

        # Forward. The loss reads loc_out[..., :2] as (x, y) and
        # loc_out[..., 2] as elevation; obj_logit is [B,K], cls_logit [B,K,C].
        loc_out, obj_logit, cls_logit = net(specs)
        loss = _detection_loss(loc_out, obj_logit, cls_logit, loc_xy, cls_idx, az_el_deg)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch mean by batch size so the epoch average is per sample.
        bsz = specs.size(0)
        running += loss.item() * bsz
        n_train += bsz

    avg_tr = running / max(1, n_train)
    train_losses.append(avg_tr)

    # Validation phase
    net.eval()
    running_val = 0.0
    n_val = 0

    with torch.no_grad():
        for batch in val_loader:
            specs, loc_xy, cls_idx, az_el_deg = _batch_to_device(batch)
            loc_out, obj_logit, cls_logit = net(specs)
            loss = _detection_loss(loc_out, obj_logit, cls_logit, loc_xy, cls_idx, az_el_deg)
            bsz = specs.size(0)
            running_val += loss.item() * bsz
            n_val += bsz

    avg_val = running_val / max(1, n_val)
    val_losses.append(avg_val)

    print(f"[{datetime.now()}] Epoch {epoch+1:03d}/{EPOCHS} | train {avg_tr:.4f} | val {avg_val:.4f}")

    # Checkpointing: always store the unwrapped module's weights so the file
    # loads regardless of whether DataParallel is used later.
    improved = avg_val < best_val
    if improved:
        best_val = avg_val
    state = net.module.state_dict() if isinstance(net, nn.DataParallel) else net.state_dict()
    last_ckpt = {
        'epoch': epoch,
        'state_dict': state,
        'best_loss': best_val,
        'log': {'training': train_losses, 'validation': val_losses},
    }

    # Save best model (additionally records the input/output configuration)
    if improved:
        torch.save({
            **last_ckpt,
            'conf': {
                'image_size': SPECTROGRAM_SIZE,
                'patch_size': PATCH_SIZE,
                'patch_overlap': PATCH_OVERLAP,
                'num_coordinates_output': NUM_OUTPUT,
                'max_sources': MAX_SOURCES,
            }
        }, os.path.join(MODEL_SAVE_DIR, model_save_name + '_best.pkl'))
        print(f"  -> Saved best model with validation loss: {best_val:.4f}")

    # Save last model
    torch.save(last_ckpt, os.path.join(MODEL_SAVE_DIR, model_save_name + '_last.pkl'))

print(f"[{datetime.now()}] Training completed!")
# Evaluate on test set


In [None]:
# Load best checkpoint (written by the training cell) into the live model.
best_ckpt_path = os.path.join(MODEL_SAVE_DIR, model_save_name + '_best.pkl')
ckpt = torch.load(best_ckpt_path, map_location=device)
state = ckpt['state_dict']
# The checkpoint holds the unwrapped module's weights, so load into
# net.module when the model is wrapped in DataParallel.
if isinstance(net, nn.DataParallel):
    net.module.load_state_dict(state)
else:
    net.load_state_dict(state)
print(f"Loaded best checkpoint from: {best_ckpt_path}")

print(f"[{datetime.now()}] Evaluating on test set...")
net.eval()
test_running = 0.0
n_test = 0

test_correct = 0
test_total = 0
az_errors = []
el_errors = []

with torch.no_grad():
    for batch in test_loader:
        specs, loc_xy, cls_idx, az_el_deg = batch
        specs = specs.to(device, non_blocking=True)
        loc_xy = loc_xy.to(device, non_blocking=True)
        cls_idx = cls_idx.to(device, non_blocking=True).squeeze(1)
        az_el_deg = az_el_deg.to(device, non_blocking=True)

        loc_out, obj_logit, cls_logit = net(specs)
        B = specs.size(0)
        # The elevation channel is optional here: only use it when the model
        # emits at least three coordinates per slot.
        has_elev = loc_out.size(-1) >= 3

        # Detection loss, computed the same way as the validation loss so the
        # two numbers are directly comparable.
        total_loss = torch.tensor(0., device=device)
        for b in range(B):
            # Match on the (x, y) azimuth components only; the GT is [1,2],
            # so passing the full coordinate vector would break cdist.
            pred_xy_b = loc_out[b][..., :2]
            obj_b = obj_logit[b]
            cls_b = cls_logit[b]
            gt_xy_b = loc_xy[b].unsqueeze(0)
            gt_cls_multi_hot = torch.zeros(cls_b.size(1), device=device)
            gt_cls_multi_hot[cls_idx[b]] = 1.0

            pred_idx, gt_idx = greedy_match(pred_xy_b, gt_xy_b)

            # Objectness targets: matched slots -> 1, others -> 0
            obj_target_b = torch.zeros_like(obj_b)
            obj_target_b[pred_idx] = 1.0
            loss_obj_b = criterion_obj(obj_b, obj_target_b).mean()

            loss_loc_b = torch.tensor(0., device=device)
            loss_el_b = torch.tensor(0., device=device)
            loss_cls_b = torch.tensor(0., device=device)
            if pred_idx.numel() > 0:
                loss_loc_b = criterion_loc(pred_xy_b[pred_idx], gt_xy_b[gt_idx])
                if has_elev:
                    pred_el_b = loc_out[b][..., 2]
                    gt_el_b = az_el_deg[b, 1].unsqueeze(0)
                    loss_el_b = criterion_elev(pred_el_b[pred_idx], gt_el_b[gt_idx]).mean()
                matched_cls = cls_b[pred_idx]
                gt_cls_rep = gt_cls_multi_hot.unsqueeze(0).expand(matched_cls.size(0), -1)
                loss_cls_b = criterion_cls(matched_cls, gt_cls_rep).mean()

            total_loss = total_loss + (
                LOC_WEIGHT * loss_loc_b
                + ELEV_WEIGHT * loss_el_b
                + CLS_WEIGHT * loss_cls_b
                + OBJ_WEIGHT * loss_obj_b
            )
        loss = total_loss / max(1, B)

        # Metrics: select best slot by objectness
        obj_prob = torch.sigmoid(obj_logit)  # [B,K]
        best_slot = torch.argmax(obj_prob, dim=1)  # [B]
        batch_idx = torch.arange(B, device=device)

        pred_loc = loc_out[batch_idx, best_slot]           # [B, num_coordinates_output]
        pred_cls_logits = cls_logit[batch_idx, best_slot]  # [B, C]
        pred_cls = torch.argmax(pred_cls_logits, dim=1)

        test_correct += (pred_cls == cls_idx).sum().item()
        test_total += B

        # Azimuth error from the predicted [x, y] vector. atan2 yields
        # (-180, 180] degrees, while the labels may use another convention
        # (e.g. [0, 360)), so wrap the signed difference into [-180, 180)
        # before taking the magnitude. This never produces a negative error.
        pred_theta = torch.atan2(pred_loc[:, 1], pred_loc[:, 0])
        pred_az_deg = pred_theta * 180.0 / torch.pi
        true_az_deg = az_el_deg[:, 0]
        az_err = torch.abs((pred_az_deg - true_az_deg + 180.0) % 360.0 - 180.0)
        az_errors.extend(az_err.detach().cpu().tolist())

        # Elevation error (3rd component is elevation in degrees)
        if has_elev:
            pred_el_deg = pred_loc[:, 2]
            el_err = torch.abs(pred_el_deg - az_el_deg[:, 1])
            el_errors.extend(el_err.detach().cpu().tolist())

        # Weight the batch mean by batch size so the final average is per sample.
        bsz = specs.size(0)
        test_running += loss.item() * bsz
        n_test += bsz

avg_test = test_running / max(1, n_test)
cls_acc = test_correct / max(1, test_total)
mean_az_err = float(np.mean(az_errors)) if len(az_errors) else float('nan')
mean_el_err = float(np.mean(el_errors)) if len(el_errors) else float('nan')

print(f"[{datetime.now()}] Test evaluation completed!")
print(f"Best validation loss: {best_val:.4f}")
print(f"Test loss: {avg_test:.4f}")
print(f"Classification accuracy: {cls_acc:.4f}")
print(f"Average azimuth error: {mean_az_err:.2f}°")
if len(el_errors):
    print(f"Average elevation error: {mean_el_err:.2f}°")
else:
    print("Average elevation error: n/a (model does not output elevation)")
print(f"Models saved in: {MODEL_SAVE_DIR}")

In [None]:
# Plot training curves (optional): the same two loss histories are drawn
# twice, side by side - linear y-axis on the left, logarithmic on the right.
import matplotlib.pyplot as plt

_fig, (_ax_linear, _ax_log) = plt.subplots(1, 2, figsize=(10, 5))

_panels = (
    (_ax_linear, 'Training and Validation Loss', None),
    (_ax_log, 'Training and Validation Loss (Log Scale)', 'log'),
)
for _ax, _title, _yscale in _panels:
    _ax.plot(train_losses, label='Training Loss')
    _ax.plot(val_losses, label='Validation Loss')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel('Loss')
    _ax.set_title(_title)
    _ax.legend()
    if _yscale is not None:
        _ax.set_yscale(_yscale)
    _ax.grid(True)

plt.tight_layout()
plt.show()

# Losses recorded for the last completed epoch.
print(f"Final training loss: {train_losses[-1]:.4f}")
print(f"Final validation loss: {val_losses[-1]:.4f}")