In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
!mkdir 'timit_dataset'
!unzip '/content/drive/MyDrive/sps/preprocessed_timit_dataset.zip' -d timit_dataset

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SX206.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SX26.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SI476.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SI1106.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SX386.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SA1.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SI1736.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SA2.WAV  
  inflating: timit_dataset/TRAIN/DR3/MMSM0/SX116.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SX335.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SX65.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SI2135.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SX155.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SI1505.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SA1.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SX245.WAV  
  inflating: timit_dataset/TRAIN/DR3/FEME0/SA2.WAV  
  infl

In [3]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, accuracy_score
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
import copy
import math

In [4]:
# ---------------------------------------------------------------------------
# Paths: audio is read from the local (extracted) copy, CSVs and checkpoints
# live in the Drive project folder.
# ---------------------------------------------------------------------------
AUDIO_ROOT_DIR = '/content/timit_dataset'
DRIVE_PATH = '/content/drive/MyDrive/sps/'

TRAIN_CSV_PATH = os.path.join(DRIVE_PATH, 'final_train_data_merged.csv')
TEST_CSV_PATH = os.path.join(DRIVE_PATH, 'final_test_data_merged.csv')
CKPT_PATH = os.path.join(DRIVE_PATH, 'sps_transformer_best_model.pth')

# ---------------------------------------------------------------------------
# Audio: every clip is padded / cropped to a fixed window of WAV_LEN samples.
# ---------------------------------------------------------------------------
SAMPLE_RATE = 16000
CLIP_SECONDS = 4.0
WAV_LEN = int(SAMPLE_RATE * CLIP_SECONDS)

# ---------------------------------------------------------------------------
# Augmentation (training split only).
# Noise is gated twice: once by AUGMENTATION_PROBABILITY in the dataset and
# again by ADD_NOISE_PROBABILITY inside AddGaussianNoise, so the effective
# per-clip noise rate is the product of the two.
# ---------------------------------------------------------------------------
APPLY_AUGMENTATIONS = True
AUGMENTATION_PROBABILITY = 0.2
ADD_NOISE_PROBABILITY = 0.2
NOISE_SNR_MIN, NOISE_SNR_MAX = 5.0, 20.0  # dB

# ---------------------------------------------------------------------------
# Pretrained wav2vec 2.0 backbone.
# ---------------------------------------------------------------------------
PRETRAINED_W2V2 = 'facebook/wav2vec2-base-960h'
W2V2_OUTPUT_DIM = 768
FREEZE_ENCODER = True

# ---------------------------------------------------------------------------
# Transformer encoder stacked on the backbone; its width must equal the
# backbone's hidden size.
# ---------------------------------------------------------------------------
TRANSFORMER_D_MODEL = W2V2_OUTPUT_DIM
TRANSFORMER_NHEAD = 8
TRANSFORMER_NUM_ENCODER_LAYERS = 3
TRANSFORMER_DIM_FEEDFORWARD = 2048
TRANSFORMER_DROPOUT = 0.1

# Shared head dropout. NOTE(review): not referenced by any visible cell; the
# model reads the per-head rates from HEAD_CONFIGS instead.
MODEL_DROPOUT_RATE = 0.3

# ---------------------------------------------------------------------------
# Optimisation.
# ---------------------------------------------------------------------------
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 3e-5
OPTIMIZER_WEIGHT_DECAY = 0.01
GRADIENT_CLIP_MAX_NORM = 1.0

# ---------------------------------------------------------------------------
# Train / validation split (VAL_SPLIT_SEED also seeds the global RNGs).
# ---------------------------------------------------------------------------
VAL_SPLIT_RATIO = 0.25
VAL_SPLIT_SEED = 42

# ---------------------------------------------------------------------------
# Tasks. 'type' selects loss and metric; 'loss_weight' scales the task's
# contribution to the training loss. Classification entries receive
# 'num_classes' once the training labels have been inspected.
# ---------------------------------------------------------------------------
TASKS = {
    'age': {'type': 'regression', 'loss_weight': 0.4},
    'Gender': {'type': 'classification', 'loss_weight': 0.4},
    'height': {'type': 'regression', 'loss_weight': 0.2},
}

# Per-task MLP head sizes.
HEAD_CONFIGS = {
    'age': {'head_hidden_dim': 128, 'head_dropout_rate': 0.3},
    'Gender': {'head_hidden_dim': 96, 'head_dropout_rate': 0.25},
    'height': {'head_hidden_dim': 64, 'head_dropout_rate': 0.2},
}

# Placeholders; both are filled in from the training data in a later cell.
GENDER_MAP = {}
NORM_STATS = {
    'age': {'mean': 0.0, 'std': 1.0},
    'height': {'mean': 0.0, 'std': 1.0},
}

# ---------------------------------------------------------------------------
# DataLoader.
# ---------------------------------------------------------------------------
NUM_WORKERS = 4
PIN_MEMORY = DEVICE == 'cuda'

# ---------------------------------------------------------------------------
# Checkpointing, early stopping and LR-on-plateau.
# ---------------------------------------------------------------------------
MONITOR_METRIC_MODE = 'min'
EARLY_STOPPING_PATIENCE = 10
LR_SCHEDULER_PATIENCE = 5
LR_SCHEDULER_FACTOR = 0.1
MIN_LR = 1e-6

# Weights combining the per-task validation losses into the monitored metric.
EVAL_LOSS_WEIGHTS = {'age': 0.4, 'Gender': 0.4, 'height': 0.2}

In [None]:
class PadCrop:
    """Force a waveform to exactly ``length`` samples along its last axis.

    Longer inputs are cropped: a random window in ``'train'`` mode, the centred
    window in any other mode. Shorter inputs are zero-padded on both sides; when
    the padding is odd, the extra sample goes on the right. Inputs that already
    have the right length are returned as-is (same object).
    """

    def __init__(self, length: int, mode: str = 'train'):
        self.length = length
        self.mode = mode

    def __call__(self, wav: torch.Tensor):
        excess = wav.shape[-1] - self.length
        if excess == 0:
            return wav
        if excess < 0:
            deficit = -excess
            left = deficit // 2
            return F.pad(wav, (left, deficit - left), mode='constant', value=0.0)
        # Random offset acts as augmentation during training; eval is deterministic.
        offset = random.randint(0, excess) if self.mode == 'train' else excess // 2
        return wav[..., offset: offset + self.length]


class AddGaussianNoise:
    """With probability ``p``, add white Gaussian noise at a random SNR.

    The SNR in dB is drawn uniformly from ``[min_snr_db, max_snr_db]``, and the
    noise variance is set to ``signal_power / 10**(snr_db / 10)``. Signal power
    is measured on the first channel of multi-channel input, and on the whole
    signal for 1-D or single-channel input. Near-silent inputs (power < 1e-9),
    0-d tensors and inputs with an empty leading dimension are returned
    unchanged (same object).
    """

    def __init__(self, min_snr_db: float = 5.0, max_snr_db: float = 20.0, p: float = 0.5):
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db
        self.p = p

    def __call__(self, waveform: torch.Tensor):
        if not (random.random() < self.p):
            return waveform

        # Choose which samples the signal power is measured on.
        if waveform.ndim == 1:
            reference = waveform
        elif waveform.ndim > 1 and waveform.shape[0] == 1:
            reference = waveform.squeeze(0)
        elif waveform.ndim > 1 and waveform.shape[0] > 1:
            reference = waveform[0, :]
        else:
            return waveform

        power = reference.pow(2).mean()
        if power.item() < 1e-9:
            return waveform

        snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
        noise_std = torch.sqrt(power / 10**(snr_db / 10.0))
        return waveform + torch.randn_like(waveform) * noise_std

In [6]:
class TimitDataset(Dataset):
    """TIMIT utterances paired with normalized multi-task targets.

    Each item is ``(wav, targets)``:
      * ``wav`` is a 1-D tensor of ``WAV_LEN`` samples at ``SAMPLE_RATE`` Hz
        (resampled if needed, down-mixed to mono, padded/cropped by PadCrop).
      * ``targets`` maps every task in ``TASKS`` to a scalar tensor. Regression
        tasks get float32 z-scores using ``NORM_STATS``. Classification tasks
        get an int64 class index looked up in ``GENDER_MAP``.

    Unusable items are reported with ``print`` and returned as ``None``, and
    ``collate_fn`` filters them out. Unusable means: an out-of-range index,
    NaN audio, a NaN target, or a class label missing from ``GENDER_MAP``.

    Args:
        data_df: rows with a ``FilePath`` column (joined onto AUDIO_ROOT_DIR)
            and one column per task name.
        mode: ``'train'`` enables random cropping and noise augmentation; any
            other value gives deterministic centre crops and no augmentation.
    """

    def __init__(self, data_df: pd.DataFrame, mode: str = 'train'):
        self.data_df = data_df.reset_index(drop=True)
        self.mode = mode
        self.pad_crop = PadCrop(WAV_LEN, mode)
        self.target_cols = list(TASKS.keys())
        self.add_noise_transform = None
        if self.mode == 'train' and APPLY_AUGMENTATIONS:
            self.add_noise_transform = AddGaussianNoise(NOISE_SNR_MIN, NOISE_SNR_MAX, ADD_NOISE_PROBABILITY)
        # Resample modules keyed by source rate. Building one designs a filter
        # kernel, so reuse it instead of rebuilding it for every item.
        self._resamplers = {}

    def __len__(self) -> int:
        return len(self.data_df)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict] | None:
        if idx >= len(self.data_df):
            print(f"Index {idx} out of bounds...")
            return None
        row = self.data_df.iloc[idx]
        wav_relative_path = row['FilePath']
        full_wav_path = os.path.normpath(os.path.join(AUDIO_ROOT_DIR, wav_relative_path))
        wav, sr = torchaudio.load(full_wav_path)
        if sr != SAMPLE_RATE:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
            wav = self._resamplers[sr](wav)
        if wav.shape[0] > 1:
            # Down-mix multi-channel audio to a single channel.
            wav = torch.mean(wav, dim=0, keepdim=True)
        # Outer augmentation gate. AddGaussianNoise applies its own probability
        # on top, so the effective noise rate is the product of the two.
        if self.mode == 'train' and APPLY_AUGMENTATIONS and random.random() < AUGMENTATION_PROBABILITY:
            if self.add_noise_transform:
                wav = self.add_noise_transform(wav)
        wav = self.pad_crop(wav).squeeze(0)
        if torch.isnan(wav).any():
            print(f"NaNs in wav {idx}.")
            return None

        targets = {}
        for task_name, task_info in TASKS.items():
            value = row[task_name]
            if pd.isna(value):
                print(f"NaN target {task_name} at {idx}.")
                return None
            if task_info['type'] == 'regression':
                mean, std = NORM_STATS[task_name]['mean'], NORM_STATS[task_name]['std']
                targets[task_name] = torch.tensor((value - mean) / (std if std > 1e-6 else 1.0), dtype=torch.float32)
            elif task_info['type'] == 'classification':
                label = str(value).upper()
                if label not in GENDER_MAP:
                    # Do not silently fold unseen labels into class 0; that
                    # corrupts both the loss and the accuracy metric.
                    print(f"Unknown {task_name} label {value!r} at {idx}.")
                    return None
                targets[task_name] = torch.tensor(GENDER_MAP[label], dtype=torch.long)
        return wav, targets

In [7]:
def collate_fn(batch: list) -> tuple[torch.Tensor | dict, dict] | None:
    batch = [item for item in batch if item is not None]
    if not batch:
      return None
    wavs = [item[0] for item in batch]
    target_dicts = [item[1] for item in batch]
    padded_wavs = torch.nn.utils.rnn.pad_sequence(wavs, batch_first=True, padding_value=0.0)
    collated_targets = {}
    if target_dicts:
        first_item_keys = target_dicts[0].keys()
        for key in first_item_keys:
            if all(key in d for d in target_dicts):
                 collated_targets[key] = torch.stack([d[key] for d in target_dicts])
    if not collated_targets and target_dicts:
      return None

    return padded_wavs, collated_targets

In [8]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to sequence-first input, then dropout.

    Input and output shape is ``(seq_len, batch, d_model)``. Even feature indices
    get ``sin(pos * 10000^(-2i/d_model))`` and odd indices the matching ``cos``.
    Odd ``d_model`` is supported: the last sin column has no cos partner.

    Args:
        d_model: feature dimension of the input.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the precomputed table covers.

    Raises:
        ValueError: from ``forward`` if ``seq_len`` exceeds ``max_len``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, one fewer than even ones when d_model is odd.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model)
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        x = x + self.pe[:seq_len]
        return self.dropout(x)

In [9]:
class sps_Transformer(nn.Module):
    """Multi-task speaker-profiling model.

    The pipeline is: raw waveform -> HF feature extractor -> wav2vec 2.0
    (optionally frozen) -> sinusoidal positional encoding -> nn.TransformerEncoder
    -> masked mean pooling over valid frames -> one MLP head per task in TASKS.

    ``forward`` takes a ``(N, T)`` (or ``(T,)``) waveform tensor. It returns a
    dict keyed by task name. Regression tasks map to ``(N,)`` predictions and
    classification tasks to ``(N, num_classes)`` logits.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED_W2V2)
        self.wav2vec2_encoder = Wav2Vec2Model.from_pretrained(PRETRAINED_W2V2)

        if FREEZE_ENCODER:
            # Freezes the weights only; the backbone still follows model.train()/eval().
            for param in self.wav2vec2_encoder.parameters():
                param.requires_grad = False

        self.pos_encoder = PositionalEncoding(TRANSFORMER_D_MODEL, TRANSFORMER_DROPOUT)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=TRANSFORMER_D_MODEL,
            nhead=TRANSFORMER_NHEAD,
            dim_feedforward=TRANSFORMER_DIM_FEEDFORWARD,
            dropout=TRANSFORMER_DROPOUT,
            batch_first=False  # expects (S, N, E)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=TRANSFORMER_NUM_ENCODER_LAYERS
        )

        self.heads = nn.ModuleDict()
        head_common_input_dim = TRANSFORMER_D_MODEL

        for task_name, task_info in TASKS.items():
            head_specific_config = HEAD_CONFIGS[task_name]
            head_hidden_dim = head_specific_config['head_hidden_dim']
            head_dropout_rate = head_specific_config['head_dropout_rate']
            # num_classes is filled into TASKS by the normalization/mapping cell.
            output_dim = 1 if task_info['type'] == 'regression' else task_info.get('num_classes')
            if output_dim is None and task_info['type'] == 'classification':
                raise ValueError(f"num_classes not set for classification task '{task_name}'")

            self.heads[task_name] = nn.Sequential(
                nn.Linear(head_common_input_dim, head_hidden_dim), nn.ReLU(),
                nn.Dropout(head_dropout_rate), nn.Linear(head_hidden_dim, output_dim)
            )

    def _frame_attention_mask(self, sample_mask: torch.Tensor, num_frames: int) -> torch.Tensor:
        """Convert a sample-level mask to a frame-level validity mask.

        Args:
            sample_mask: ``(N, S_in)`` mask over waveform samples (1 = real audio).
            num_frames: number of frames ``S_out`` produced by the encoder.

        Returns:
            ``(N, num_frames)`` bool tensor, True for valid (non-padded) frames.
        """
        # The CNN feature encoder downsamples the waveform, so valid-frame counts
        # come from the encoder's own output-length formula. Slicing the sample
        # mask would only look at the first S_out *samples*. Uses the helper the
        # HF implementation itself relies on.
        input_lengths = sample_mask.sum(dim=-1)
        output_lengths = self.wav2vec2_encoder._get_feat_extract_output_lengths(input_lengths)
        output_lengths = output_lengths.to(torch.long).clamp(min=0, max=num_frames)
        frame_idx = torch.arange(num_frames, device=sample_mask.device)
        return frame_idx.unsqueeze(0) < output_lengths.unsqueeze(1)

    def forward(self, waveform: torch.Tensor) -> dict:
        current_device = next(self.parameters()).device
        # The HF feature extractor takes a list of 1-D numpy arrays.
        waveform_list = [wav.cpu().numpy() for wav in waveform] if waveform.ndim == 2 else [waveform.cpu().numpy()]

        inputs = self.feature_extractor(
            waveform_list, sampling_rate=SAMPLE_RATE, return_tensors="pt",
            padding="longest", return_attention_mask=True
        )
        inputs = {k: v.to(current_device) for k, v in inputs.items()}
        attention_mask_input = inputs.get('attention_mask')  # (N, S_in)
        if attention_mask_input is None:
            attention_mask_input = torch.ones(inputs['input_values'].shape[0], inputs['input_values'].shape[1], dtype=torch.long, device=current_device)

        # NOTE(review): HF docs advise not passing attention_mask to checkpoints
        # whose feature extractor defaults to return_attention_mask=False (e.g.
        # wav2vec2-base). This is harmless when all inputs have equal length;
        # revisit before feeding variable-length batches.
        w2v2_outputs = self.wav2vec2_encoder(inputs['input_values'], attention_mask=attention_mask_input)
        hidden_states = w2v2_outputs.last_hidden_state  # (N, S_out, E)

        frame_mask = self._frame_attention_mask(attention_mask_input, hidden_states.shape[1])  # (N, S_out), True = valid

        transformer_input = self.pos_encoder(hidden_states.transpose(0, 1))  # (S_out, N, E)
        transformer_output = self.transformer_encoder(
            transformer_input,
            src_key_padding_mask=~frame_mask  # True marks padding
        ).transpose(0, 1)  # (N, S_out, E)

        # Mean over valid frames only; clamp avoids division by zero for empty items.
        frame_weights = frame_mask.unsqueeze(-1).to(transformer_output.dtype)  # (N, S_out, 1)
        pooled_output = (transformer_output * frame_weights).sum(dim=1) / frame_weights.sum(dim=1).clamp(min=1.0)  # (N, E)

        predictions = {}
        for task_name, head_module in self.heads.items():
            task_prediction = head_module(pooled_output)
            if TASKS[task_name]['type'] == 'regression':
                task_prediction = task_prediction.squeeze(-1)
            predictions[task_name] = task_prediction
        return predictions

In [10]:
print(f"Using device: {DEVICE}")
# Seed Python, NumPy and PyTorch (CPU) RNGs, then every visible GPU.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(VAL_SPLIT_SEED)
if DEVICE == 'cuda':
    torch.cuda.manual_seed_all(VAL_SPLIT_SEED)

Using device: cuda


In [11]:
full_train_df_raw = pd.read_csv(TRAIN_CSV_PATH)
test_df_raw = pd.read_csv(TEST_CSV_PATH)
print(f"Loaded raw train data: {len(full_train_df_raw)} rows, raw test data: {len(test_df_raw)} rows")

# Drop unused metadata columns (whichever exist) and any row missing the audio
# path or a target. The raw frames are left untouched.
cols_to_drop = ['index', 'Use_x', 'DR', 'Ethnicity']
critical_cols = ['FilePath'] + list(TASKS.keys())
full_train_df, test_df = (
    raw_df.drop(columns=cols_to_drop, errors='ignore')
          .dropna(subset=critical_cols)
          .reset_index(drop=True)
    for raw_df in (full_train_df_raw, test_df_raw)
)

print(f"Processed train data: {len(full_train_df)} rows, Processed test data: {len(test_df)} rows")

Loaded raw train data: 4490 rows, raw test data: 1640 rows
Processed train data: 4490 rows, Processed test data: 1640 rows


In [12]:
print("Calculating normalization stats and mappings...")
# NOTE(review): statistics use the full training CSV, i.e. before the
# speaker-level train/validation split, so validation speakers also
# contribute to the normalization.
for task_name, task_info in TASKS.items():
    task_kind = task_info['type']
    if task_kind == 'regression':
        values = full_train_df[task_name].astype(float)
        mean, std = values.mean(), values.std()
        # Fall back to unit scale when the spread is degenerate or undefined.
        if not (np.isfinite(std) and std > 1e-6):
            std = 1.0
        NORM_STATS[task_name]['mean'] = mean
        NORM_STATS[task_name]['std'] = std
        print(f"  {task_name.capitalize()} stats: Mean={NORM_STATS[task_name]['mean']:.2f}, Std={NORM_STATS[task_name]['std']:.2f}")
    elif task_kind == 'classification':
        if task_name != 'Gender':
            TASKS[task_name]['num_classes'] = len(full_train_df[task_name].unique())
            print(f"  {task_name.capitalize()} Num Classes: {TASKS[task_name]['num_classes']}")
            continue
        # Sorted, upper-cased labels -> contiguous class indices.
        cats = sorted(full_train_df[task_name].astype(str).str.upper().unique())
        GENDER_MAP = {cat: i for i, cat in enumerate(cats)}
        TASKS[task_name]['num_classes'] = len(cats)
        print(f"  Gender mapping: {GENDER_MAP}, Num Classes: {TASKS[task_name]['num_classes']}")

Calculating normalization stats and mappings...
  Age stats: Mean=30.29, Std=7.77
  Gender mapping: {'F': 0, 'M': 1}, Num Classes: 2
  Height stats: Mean=175.75, Std=9.52


In [13]:
if 'SpeakerID' not in full_train_df.columns:
    print("Warning: 'SpeakerID' not found. Performing random split.")
    train_df, val_df = train_test_split(full_train_df, test_size=VAL_SPLIT_RATIO, random_state=VAL_SPLIT_SEED)
else:
    # Split speakers, not utterances, so no voice appears in both splits.
    print("Splitting data by SpeakerID...")
    speaker_ids = full_train_df['SpeakerID'].unique()
    train_spk_ids, val_spk_ids = train_test_split(speaker_ids, test_size=VAL_SPLIT_RATIO, random_state=VAL_SPLIT_SEED)
    speaker_col = full_train_df['SpeakerID']
    train_df = full_train_df.loc[speaker_col.isin(train_spk_ids)].copy()
    val_df = full_train_df.loc[speaker_col.isin(val_spk_ids)].copy()
print(f"Data split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

Splitting data by SpeakerID...
Data split: Train=3360, Val=1130, Test=1640


In [14]:
train_dataset = TimitDataset(train_df, mode='train')
val_dataset = TimitDataset(val_df, mode='eval')
test_dataset = TimitDataset(test_df, mode='eval')

# Options shared by every loader; only shuffling and drop_last differ.
common_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)
# Training shuffles and drops the ragged final batch; evaluation keeps order and every sample.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **common_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **common_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_kwargs)

In [15]:
# Build the multi-task model (fetches the pretrained wav2vec 2.0 weights and
# feature-extractor config via from_pretrained) and move it to DEVICE.
model = sps_Transformer().to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Hand only trainable parameters to AdamW; with FREEZE_ENCODER these are the
# Transformer encoder and the task heads.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=OPTIMIZER_WEIGHT_DECAY)

In [17]:
# Training losses: MSE on z-scored regression targets, cross-entropy on class logits.
criterion_reg, criterion_cls = nn.MSELoss(), nn.CrossEntropyLoss()
# Shrink the LR by LR_SCHEDULER_FACTOR when the weighted validation loss
# plateaus. `verbose` is deprecated for PyTorch LR schedulers and would only
# duplicate logging: the training loop already prints the current LR each epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=MONITOR_METRIC_MODE,
    factor=LR_SCHEDULER_FACTOR,
    patience=LR_SCHEDULER_PATIENCE,
    min_lr=MIN_LR,
)

In [18]:
def validate_epoch(val_loader, model, device):
    """Evaluate ``model`` on ``val_loader`` and return per-task and weighted mean losses.

    Losses are summed per sample (``reduction='sum'``) and divided by the number
    of samples in valid batches, so they are true per-sample means even when the
    last batch is smaller.

    Some batches are skipped without being counted:
      * A batch whose predictions or targets lack a task.
      * A batch with a NaN/Inf loss. Here the offending task is also charged a
        penalty of 1e9 per sample, so a diverging model cannot become the best
        checkpoint.

    The other tasks' losses from a skipped batch are discarded rather than left
    in the running sums.

    Args:
        val_loader: iterable of ``(wav, targets)`` batches or ``None``.
        model: callable mapping a waveform batch to a dict of per-task predictions.
        device: device the batch is moved to.

    Returns:
        ``(avg_task_losses, weighted_loss)``: per-task mean losses, and their
        sum weighted by ``EVAL_LOSS_WEIGHTS``. Everything is ``inf`` when no batch
        was valid.
    """
    model.eval()
    running_task_losses_val = {task: 0.0 for task in TASKS}
    num_samples_val = 0
    criterion_reg_val = nn.MSELoss(reduction='sum')
    criterion_cls_val = nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for batch_data in val_loader:
            if batch_data is None:
                continue
            wav, targets = batch_data
            if wav.numel() == 0 or not targets:
                continue
            wav = wav.to(device)
            targets_device = {k: v.to(device) for k, v in targets.items()}
            current_batch_size = wav.size(0)
            predictions = model(wav)

            # Stage this batch's losses; commit them only if every task is usable.
            batch_losses = {}
            valid_batch = True
            for task_name, task_info in TASKS.items():
                if task_name not in predictions or task_name not in targets_device:
                    valid_batch = False
                    break
                pred_val, target_val = predictions[task_name], targets_device[task_name]
                if task_info['type'] == 'regression':
                    if pred_val.shape != target_val.shape:
                        target_val = target_val.view_as(pred_val)
                    loss_val = criterion_reg_val(pred_val, target_val)
                else:
                    loss_val = criterion_cls_val(pred_val, target_val)
                if not torch.isfinite(loss_val):
                    running_task_losses_val[task_name] += 1e9 * current_batch_size
                    valid_batch = False
                    break
                batch_losses[task_name] = loss_val.item()
            if not valid_batch:
                continue
            for task_name, loss_item in batch_losses.items():
                running_task_losses_val[task_name] += loss_item
            num_samples_val += current_batch_size

    if num_samples_val == 0:
        return {task_name: float('inf') for task_name in TASKS}, float('inf')

    avg_task_losses_val = {
        task_name: running_task_losses_val[task_name] / num_samples_val for task_name in TASKS
    }
    weighted_val_loss = sum(
        (EVAL_LOSS_WEIGHTS[task_name] * avg_loss
         for task_name, avg_loss in avg_task_losses_val.items()
         if task_name in EVAL_LOSS_WEIGHTS),
        0.0,
    )
    return avg_task_losses_val, weighted_val_loss

In [19]:
# Training loop. Each epoch runs one optimisation pass over train_loader, then
# validation, LR scheduling, best-checkpoint saving and early stopping. All of
# these are driven by the weighted validation loss returned by validate_epoch.
# NOTE: the stored output below was produced by an earlier revision of this
# cell; it shows "E1 Summary" where this code prints "Epoch1 Summary".
print("\n--- Starting Training ---")
best_val_metric = float('inf') if MONITOR_METRIC_MODE == 'min' else float('-inf')
epochs_no_improve = 0
best_model_state = None  # deep copy of the best state_dict seen so far

for epoch in range(1, EPOCHS + 1):
    model.train()
    # Sample-weighted running sums, divided by num_samples_processed at epoch end.
    total_loss_epoch = 0.0
    task_losses_epoch = {task: 0.0 for task in TASKS}
    num_samples_processed = 0
    for batch_idx, batch_data in enumerate(train_loader):
        # collate_fn returns None when every item in the batch was dropped.
        if batch_data is None: continue
        wav, targets = batch_data
        if wav.numel() == 0 or not targets: continue
        wav, targets_device = wav.to(DEVICE), {k: v.to(DEVICE) for k, v in targets.items()}
        current_batch_size = wav.size(0)
        optimizer.zero_grad()
        predictions = model(wav)
        # Weighted sum of per-task losses (weights from TASKS[...]['loss_weight']).
        combined_loss_batch = torch.tensor(0.0, device=DEVICE)
        current_batch_task_losses = {}
        valid_batch_for_loss = True
        for task_name, task_info in TASKS.items():
            if task_name not in predictions or task_name not in targets_device: valid_batch_for_loss = False; break
            pred, target = predictions[task_name], targets_device[task_name]
            if task_info['type'] == 'regression' and pred.shape != target.shape: target = target.view_as(pred)
            loss = criterion_reg(pred, target) if task_info['type'] == 'regression' else criterion_cls(pred, target)
            # A non-finite task loss invalidates the whole batch (no backward/step).
            if torch.isnan(loss) or torch.isinf(loss): valid_batch_for_loss = False; break
            combined_loss_batch += task_info['loss_weight'] * loss
            current_batch_task_losses[task_name] = loss.item()

        if valid_batch_for_loss:
            combined_loss_batch.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP_MAX_NORM) # Gradient Clipping
            optimizer.step()
            # Batch-mean losses are scaled by batch size so epoch averages are per-sample.
            total_loss_epoch += combined_loss_batch.item() * current_batch_size
            for task_name, loss_item in current_batch_task_losses.items():
                task_losses_epoch[task_name] += loss_item * current_batch_size
            num_samples_processed += current_batch_size
            # if (batch_idx + 1) % 50 == 0:
            #      print(f"  E{epoch} B{batch_idx+1}/{len(train_loader)} | BLoss: {combined_loss_batch.item():.3f} | " + \
            #            " | ".join([f"{k[:3]}L: {v:.3f}" for k,v in current_batch_task_losses.items()]))

    avg_loss_epoch = total_loss_epoch/num_samples_processed if num_samples_processed > 0 else float('inf')
    avg_task_train_losses = {k:v/num_samples_processed if num_samples_processed > 0 else float('inf') for k,v in task_losses_epoch.items()}
    print(f"Epoch{epoch} Summary: AvgTrL: {avg_loss_epoch:.4f} | " + \
          " | ".join([f"Avg {k[:3]}TrL: {v:.4f}" for k,v in avg_task_train_losses.items()]))

    if len(val_loader) > 0:
        avg_task_val_losses, current_val_metric = validate_epoch(val_loader, model, DEVICE)
        val_loss_str = " | ".join([f"Avg {k[:3]}VaL: {v:.4f}" for k,v in avg_task_val_losses.items()])
        # The LR shown is read before scheduler.step, so a reduction appears from the next epoch.
        print(f"E{epoch} Validation: WValL: {current_val_metric:.4f} | {val_loss_str} | LR: {optimizer.param_groups[0]['lr']:.2e}")
        scheduler.step(current_val_metric)
        if (MONITOR_METRIC_MODE == 'min' and current_val_metric < best_val_metric) or \
           (MONITOR_METRIC_MODE == 'max' and current_val_metric > best_val_metric):
            # New best: keep an in-memory copy and also persist it to CKPT_PATH.
            best_val_metric, epochs_no_improve = current_val_metric, 0
            best_model_state = copy.deepcopy(model.state_dict())
            torch.save(best_model_state, CKPT_PATH)
            print(f"  Best model saved! E{epoch}, ValM: {best_val_metric:.4f}")
        else:
            epochs_no_improve += 1
            print(f"  No improvement for {epochs_no_improve} epoch(s). Best: {best_val_metric:.4f}")
        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping E{epoch}. No improvement for {EARLY_STOPPING_PATIENCE} epochs."); break
    else: print("Skipping validation.")
print("--- Training Finished ---")


--- Starting Training ---
E1 Summary: AvgTrL: 0.8226 | Avg ageTrL: 0.9571 | Avg GenTrL: 0.6006 | Avg heiTrL: 0.9977
E1 Validation: WValL: 0.9786 | Avg ageVaL: 1.2600 | Avg GenVaL: 0.6602 | Avg heiVaL: 1.0528 | LR: 3.00e-05
  Best model saved! E1, ValM: 0.9786
E2 Summary: AvgTrL: 0.8072 | Avg ageTrL: 0.9266 | Avg GenTrL: 0.5969 | Avg heiTrL: 0.9890
E2 Validation: WValL: 0.9648 | Avg ageVaL: 1.2413 | Avg GenVaL: 0.6441 | Avg heiVaL: 1.0534 | LR: 3.00e-05
  Best model saved! E2, ValM: 0.9648
E3 Summary: AvgTrL: 0.8049 | Avg ageTrL: 0.9290 | Avg GenTrL: 0.5908 | Avg heiTrL: 0.9851
E3 Validation: WValL: 0.9717 | Avg ageVaL: 1.2354 | Avg GenVaL: 0.6664 | Avg heiVaL: 1.0551 | LR: 3.00e-05
  No improvement for 1 epoch(s). Best: 0.9648
E4 Summary: AvgTrL: 0.8037 | Avg ageTrL: 0.9229 | Avg GenTrL: 0.5943 | Avg heiTrL: 0.9842
E4 Validation: WValL: 0.9631 | Avg ageVaL: 1.2391 | Avg GenVaL: 0.6449 | Avg heiVaL: 1.0471 | LR: 3.00e-05
  Best model saved! E4, ValM: 0.9631
E5 Summary: AvgTrL: 0.8024 |

In [20]:
print("\n--- Evaluating on Test Set with Best Model (Transformer) ---")
# Prefer the in-memory best weights from the training cell. Otherwise fall back
# to the checkpoint file, and failing that keep the current weights.
if best_model_state is not None:
    print(f"Loading best model from memory (ValM: {best_val_metric:.4f})")
    model.load_state_dict(best_model_state)
elif os.path.exists(CKPT_PATH):
    print(f"Loading best model from checkpoint: {CKPT_PATH}")
    checkpoint_state = torch.load(CKPT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint_state)
else:
    print("No best model state found. Evaluating with the last model state.")


--- Evaluating on Test Set with Best Model (Transformer) ---
Loading best model from memory (ValM: 0.9631)


In [21]:
# Run the model over the test set. For each task, accumulate the summed loss,
# plus predictions and targets in original units (regression is de-normalized;
# classification stores argmax class indices).
model.eval()
all_targets_test = {t: [] for t in TASKS}
all_preds_test = {t: [] for t in TASKS}
running_task_losses_test = {t: 0.0 for t in TASKS}
num_samples_test = 0
crit_reg_test = nn.MSELoss(reduction='sum')
crit_cls_test = nn.CrossEntropyLoss(reduction='sum')

with torch.no_grad():
    for batch_idx, batch_data in enumerate(test_loader):
        if batch_data is None:
            continue
        wav, targets = batch_data
        if wav.numel() == 0 or not targets:
            continue
        wav = wav.to(DEVICE)
        targets_dev = {k: v.to(DEVICE) for k, v in targets.items()}
        current_batch_size = wav.size(0)
        predictions = model(wav)

        for task_name, task_info in TASKS.items():
            if task_name not in predictions or task_name not in targets_dev:
                continue
            is_regression = task_info['type'] == 'regression'
            pred_v, target_v = predictions[task_name], targets_dev[task_name]
            if is_regression and pred_v.shape != target_v.shape:
                target_v = target_v.view_as(pred_v)
            batch_criterion = crit_reg_test if is_regression else crit_cls_test
            running_task_losses_test[task_name] += batch_criterion(pred_v, target_v).item()

            pred_cpu, targ_cpu = pred_v.cpu(), targets[task_name]
            if is_regression:
                # Undo the z-scoring so metrics are reported in the original units.
                task_stats = NORM_STATS[task_name]
                scale = task_stats['std'] if task_stats['std'] > 1e-6 else 1.0
                all_preds_test[task_name].extend((pred_cpu * scale + task_stats['mean']).tolist())
                all_targets_test[task_name].extend((targ_cpu * scale + task_stats['mean']).tolist())
            else:
                all_preds_test[task_name].extend(pred_cpu.argmax(dim=1).tolist())
                all_targets_test[task_name].extend(targ_cpu.tolist())

        num_samples_test += current_batch_size
        if (batch_idx + 1) % 20 == 0:
            print(f"  Evaluated test batch {batch_idx+1}/{len(test_loader)}")

  Evaluated test batch 20/52
  Evaluated test batch 40/52


In [22]:
# Summarise the test pass. Report the mean loss per task, plus MSE/RMSE
# (regression, original units) or accuracy (classification).
metrics_test, avg_task_losses_test = {}, {}
print("\n--- Final Test Results (Best Transformer Model) ---")
if num_samples_test == 0:
    print("No samples processed during final test evaluation.")
else:
    for task_name, task_info in TASKS.items():
        task_avg_loss = running_task_losses_test[task_name] / num_samples_test
        avg_task_losses_test[task_name] = task_avg_loss
        print(f"  Avg Test Loss ({task_name}): {task_avg_loss:.4f}")

        targets_np = np.array(all_targets_test[task_name])
        preds_np = np.array(all_preds_test[task_name])
        if len(targets_np) == 0 or len(targets_np) != len(preds_np):
            print(f"  Could not calculate metric for {task_name}.")
            continue

        if task_info['type'] == 'regression':
            mse = mean_squared_error(targets_np, preds_np)
            metrics_test[f"{task_name}_mse"] = mse
            print(f"  Test MSE ({task_name}): {mse:.4f} (RMSE: {np.sqrt(mse):.4f})")
        else:
            acc = accuracy_score(targets_np, preds_np)
            metrics_test[f"{task_name}_acc"] = acc
            print(f"  Test Accuracy ({task_name}): {acc:.4f}")
print("-" * 30)


--- Final Test Results (Best Transformer Model) ---
  Avg Test Loss (age): 1.2079
  Test MSE (age): 72.9151 (RMSE: 8.5390)
  Avg Test Loss (Gender): 0.6424
  Test Accuracy (Gender): 0.6585
  Avg Test Loss (height): 0.9039
  Test MSE (height): 81.9962 (RMSE: 9.0552)
------------------------------
