In [28]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# !mkdir 'timit_dataset'
# !unzip '/content/drive/MyDrive/sps/preprocessed_timit_dataset.zip' -d timit_dataset

In [30]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, accuracy_score
from transformers import WavLMModel, Wav2Vec2FeatureExtractor
import copy
import math

In [None]:
# Paths: audio is read from AUDIO_ROOT_DIR, CSVs and the checkpoint live under DRIVE_PATH
AUDIO_ROOT_DIR = '/content/timit_dataset'
DRIVE_PATH = '/content/drive/MyDrive/sps/'

TRAIN_CSV_PATH = os.path.join(DRIVE_PATH, 'final_train_data_merged.csv')
TEST_CSV_PATH = os.path.join(DRIVE_PATH, 'final_test_data_merged.csv')
CKPT_PATH = os.path.join(DRIVE_PATH, 'sps_wavlm_conformer_best_model.pth')

# Audio processing: every clip is padded/cropped to WAV_LEN samples
SAMPLE_RATE = 16000
CLIP_SECONDS = 4.0
WAV_LEN = int(CLIP_SECONDS * SAMPLE_RATE)

# Audio augmentations (training only).
# AUGMENTATION_PROBABILITY gates the augmentation step per sample in the dataset;
# ADD_NOISE_PROBABILITY is then applied again inside AddGaussianNoise.
APPLY_AUGMENTATIONS = True
AUGMENTATION_PROBABILITY = 0.3
ADD_NOISE_PROBABILITY = 0.3
NOISE_SNR_MIN = 7.0
NOISE_SNR_MAX = 25.0

# WavLM backbone
PRETRAINED_SSL_MODEL = 'microsoft/wavlm-base-plus'
SSL_OUTPUT_DIM = 768
FINETUNE_SSL_MODEL = True
FINETUNE_SSL_LAYERS = 2  # number of top encoder layers left trainable (0 freezes the backbone)

# Conformer encoder
CONFORMER_INPUT_DIM = SSL_OUTPUT_DIM
CONFORMER_NUM_BLOCKS = 4
CONFORMER_FF_DIM_FACTOR = 4
CONFORMER_NHEAD = 8
CONFORMER_CONV_KERNEL_SIZE = 31
CONFORMER_DROPOUT = 0.1

# Attentive statistics pooling
ATTN_POOL_DIM = 128

# Shared dropout
MODEL_DROPOUT_RATE = 0.3

# Training hyperparameters
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 5e-5
SSL_FINETUNE_LR_FACTOR = 0.1  # backbone LR = LEARNING_RATE * this factor
OPTIMIZER_WEIGHT_DECAY = 0.01
GRADIENT_CLIP_MAX_NORM = 1.0
LR_WARMUP_EPOCHS = 5

# Validation split
VAL_SPLIT_RATIO = 0.25
VAL_SPLIT_SEED = 42

# Task configuration: type of each target column and its weight in the training loss
TASKS = {
    'age': {'type': 'regression', 'loss_weight': 0.4},
    'Gender': {'type': 'classification', 'loss_weight': 0.3},
    'height': {'type': 'regression', 'loss_weight': 0.3},
}

# Per-head MLP hyperparameters
HEAD_CONFIGS = {
    'age': {'head_hidden_dim': 128, 'head_dropout_rate': 0.25},
    'Gender': {'head_hidden_dim': 96, 'head_dropout_rate': 0.25},
    'height': {'head_hidden_dim': 64, 'head_dropout_rate': 0.2},
}

# Label mapping and normalisation stats: placeholders here, filled in from the training data later
GENDER_MAP = {}
NORM_STATS = {task: {'mean': 0.0, 'std': 1.0} for task in ('age', 'height')}

# DataLoader
NUM_WORKERS = 4
PIN_MEMORY = DEVICE == 'cuda'

# Model saving, early stopping, LR scheduling
MONITOR_METRIC_MODE = 'min'
EARLY_STOPPING_PATIENCE = 15
LR_SCHEDULER_PATIENCE = 7
LR_SCHEDULER_FACTOR = 0.2
MIN_LR = 1e-7

# Weights used to combine per-task validation losses into the monitored metric
EVAL_LOSS_WEIGHTS = {'age': 0.4, 'Gender': 0.4, 'height': 0.2}

In [None]:
class PadCrop:
    """Force a waveform to exactly ``length`` samples along its last axis.

    Longer inputs are cropped: a random window when ``mode == 'train'``, the
    centred window otherwise. Shorter inputs are zero-padded on both sides,
    with the odd extra sample (if any) going to the right.
    """

    def __init__(self, length: int, mode: str = 'train'):
        self.length = length
        self.mode = mode

    def __call__(self, wav: torch.Tensor):
        # Positive -> too long (crop); negative -> too short (pad).
        surplus = wav.shape[-1] - self.length

        if surplus == 0:
            return wav

        if surplus < 0:
            missing = -surplus
            left = missing // 2
            return F.pad(wav, (left, missing - left), mode='constant', value=0.0)

        if self.mode == 'train':
            offset = random.randint(0, surplus)
        else:
            offset = surplus // 2
        return wav[..., offset : offset + self.length]

class AddGaussianNoise:
    """Add white Gaussian noise at a random SNR, with probability ``p``.

    The SNR (in dB) is drawn uniformly from ``[min_snr_db, max_snr_db]``.
    Signal power is measured on the first row of a multi-dimensional input
    (or on the whole vector for 1-D input); the noise is added to every element.
    Near-silent inputs (power below 1e-9) are returned unchanged.
    """

    def __init__(self, min_snr_db: float = 5.0, max_snr_db: float = 20.0, p: float = 0.5):
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db
        self.p = p

    def __call__(self, waveform: torch.Tensor):
        if not (random.random() < self.p):
            return waveform

        # Pick the samples used to estimate signal power.
        if waveform.ndim == 1:
            reference = waveform
        elif waveform.ndim > 1 and waveform.shape[0] >= 1:
            reference = waveform[0]
        else:
            # 0-d tensor or empty leading dimension: nothing to measure.
            return waveform

        signal_power = torch.mean(reference**2)
        if signal_power.item() < 1e-9:
            return waveform

        snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
        noise_std = torch.sqrt(signal_power / 10**(snr_db / 10.0))
        return waveform + torch.randn_like(waveform) * noise_std


In [None]:
class TimitDataset(Dataset):
    """Serve ``(waveform, targets)`` pairs for the rows of a metadata frame.

    Each row must provide a ``FilePath`` (joined onto ``AUDIO_ROOT_DIR``) and
    one column per entry in ``TASKS``. Audio is resampled to ``SAMPLE_RATE``,
    down-mixed to mono, optionally noise-augmented (train mode only) and
    padded/cropped to ``WAV_LEN`` samples.

    ``__getitem__`` returns ``None`` for samples that cannot be used (NaNs in
    the waveform or a missing target), so the collate function is expected
    to filter ``None`` entries out.
    """

    def __init__(self, data_df: pd.DataFrame, mode: str = 'train'):
        """
        Args:
            data_df: metadata frame; it is re-indexed 0..N-1 (the input is not modified).
            mode: ``'train'`` enables random cropping and augmentation; any other
                value gives deterministic centre cropping without augmentation.
        """
        self.data_df = data_df.reset_index(drop=True)
        self.mode = mode
        self.pad_crop = PadCrop(WAV_LEN, mode)
        self.target_cols = list(TASKS.keys())
        self.add_noise_transform = None
        if self.mode == 'train' and APPLY_AUGMENTATIONS:
            self.add_noise_transform = AddGaussianNoise(NOISE_SNR_MIN, NOISE_SNR_MAX, ADD_NOISE_PROBABILITY)

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx: int):
        """Load one sample.

        Returns:
            ``(wav, targets)`` where ``wav`` is the padded/cropped waveform with
            its leading channel axis squeezed out and ``targets`` maps each task
            name to a scalar tensor (float32 z-score for regression, long class
            index for classification); or ``None`` if the sample is unusable.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_rows = len(self.data_df)
        # Raise rather than return None: Python's sequence-iteration fallback
        # (``for x in dataset``) only stops on IndexError.
        if not -num_rows <= idx < num_rows:
            raise IndexError(f"Index {idx} out of bounds for dataset of size {num_rows}")

        row = self.data_df.iloc[idx]
        wav_relative_path = row['FilePath']
        full_wav_path = os.path.normpath(os.path.join(AUDIO_ROOT_DIR, wav_relative_path))
        wav, sr = torchaudio.load(full_wav_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        if wav.shape[0] > 1:
            # Down-mix multi-channel audio to mono, keeping the channel axis.
            wav = torch.mean(wav, dim=0, keepdim=True)
        if self.mode == 'train' and APPLY_AUGMENTATIONS and random.random() < AUGMENTATION_PROBABILITY:
            # The noise transform applies its own probability on top of this gate.
            if self.add_noise_transform:
                wav = self.add_noise_transform(wav)
        wav = self.pad_crop(wav).squeeze(0)
        if torch.isnan(wav).any():
            print(f"NaNs in wav {idx}.")
            return None

        targets = {}
        valid_item = True
        for task_name, task_info in TASKS.items():
            value = row[task_name]
            if pd.isna(value):
                print(f"NaN target {task_name} at {idx}.")
                valid_item = False
                break
            if task_info['type'] == 'regression':
                mean, std = NORM_STATS[task_name]['mean'], NORM_STATS[task_name]['std']
                # Guard against a degenerate std so the z-score never divides by ~0.
                targets[task_name] = torch.tensor((value - mean) / (std if std > 1e-6 else 1.0), dtype=torch.float32)
            elif task_info['type'] == 'classification':
                # Every classification target is encoded through GENDER_MAP;
                # labels missing from the map fall back to class 0.
                targets[task_name] = torch.tensor(GENDER_MAP.get(str(value).upper(), 0), dtype=torch.long)

        return (wav, targets) if valid_item else None

In [33]:
def collate_fn(batch: list):
    """Collate ``(wav, targets)`` samples into a batch, skipping ``None`` entries.

    Waveforms are right-padded with zeros to the longest one in the batch.
    Only target keys present in every sample's dict are kept; each is stacked
    into one tensor.

    Returns:
        ``(padded_wavs, collated_targets)``, or ``None`` when no usable sample
        remains or no target key is shared by all samples.
    """
    samples = [sample for sample in batch if sample is not None]
    if len(samples) == 0:
        return None

    waveforms = [sample[0] for sample in samples]
    target_dicts = [sample[1] for sample in samples]
    padded_wavs = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True, padding_value=0.0)

    # Keys are taken from the first sample and kept only if every sample has them.
    collated_targets = {
        key: torch.stack([targets[key] for targets in target_dicts])
        for key in target_dicts[0].keys()
        if all(key in targets for targets in target_dicts)
    }

    if not collated_targets:
        return None

    return padded_wavs, collated_targets

In [None]:
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [None]:
class ConvolutionModule(nn.Module):
    """Conformer convolution module.

    Pipeline: pointwise conv (expands to 2x channels) -> GLU gate ->
    depthwise conv -> BatchNorm -> activation -> pointwise conv.

    Args:
        channels: feature dimension of the input and output.
        kernel_size: depthwise kernel width; must be odd so that the
            ``(kernel_size - 1) // 2`` padding keeps the sequence length.
        activation: module applied after BatchNorm. ``None`` (default) creates
            a fresh ``Swish`` for this instance instead of sharing one module
            object across all instances.
        bias: whether the convolutions use a bias term.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, channels, kernel_size, activation=None, bias=True):
        super().__init__()
        if kernel_size % 2 == 0:
            # An even kernel with symmetric padding would drop one frame and
            # break the residual connection of the enclosing block.
            raise ValueError(f"kernel_size must be odd to preserve sequence length, got {kernel_size}")
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
        self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias)
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias)
        self.activation = Swish() if activation is None else activation

    def forward(self, x):
        """Map ``(batch, time, channels)`` to a tensor of the same shape."""
        x = x.transpose(1, 2)  # (batch, channels, time)
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x_act, x_gate = x.chunk(2, dim=1)  # (batch, channels, time) each
        x = x_act * torch.sigmoid(x_gate)  # GLU
        x = self.depthwise_conv(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        return x.transpose(1, 2)  # (batch, time, channels)

In [None]:
class FeedForwardModule(nn.Module):
    """Position-wise feed-forward block: Linear -> activation -> dropout -> Linear.

    Args:
        d_model: input and output feature dimension.
        d_ff: hidden dimension of the inner projection.
        dropout: dropout probability applied after the activation.
        activation: module applied after the first projection. ``None``
            (default) creates a fresh ``Swish`` for this instance instead of
            sharing one module object across all instances.
    """

    def __init__(self, d_model, d_ff, dropout, activation=None):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = Swish() if activation is None else activation

    def forward(self, x):
        """Apply the block over the last dimension; output shape equals input shape."""
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

In [None]:
class ConformerBlock(nn.Module):
    """One Conformer block: half-step FFN, self-attention, convolution, half-step FFN.

    Every sub-module is pre-normalised with its own LayerNorm and wrapped in a
    residual connection; a single dropout module is shared by all four branches.
    """

    def __init__(self, d_model, n_head, d_ff, conv_kernel_size, dropout):
        super().__init__()
        # Sub-modules (construction order kept stable so initialisation is reproducible).
        self.ffn1 = FeedForwardModule(d_model, d_ff, dropout)
        self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.conv_module = ConvolutionModule(d_model, conv_kernel_size)
        self.ffn2 = FeedForwardModule(d_model, d_ff, dropout)
        # One pre-norm per sub-module.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_key_padding_mask):
        """Run the block.

        Args:
            x: input of shape (batch, time, d_model).
            src_key_padding_mask: passed to the attention layer as
                ``key_padding_mask`` (True marks positions to ignore).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # First feed-forward, added with a half-step residual.
        ff_out = self.ffn1(self.norm1(x))
        x = self.dropout(ff_out) * 0.5 + x

        # Multi-head self-attention.
        attn_in = self.norm2(x)
        attn_out, _ = self.self_attn(attn_in, attn_in, attn_in, key_padding_mask=src_key_padding_mask)
        x = self.dropout(attn_out) + x

        # Convolution module.
        conv_out = self.conv_module(self.norm3(x))
        x = self.dropout(conv_out) + x

        # Second feed-forward, again with a half-step residual.
        ff_out = self.ffn2(self.norm4(x))
        return self.dropout(ff_out) * 0.5 + x

In [None]:
class AttentiveStatisticsPooling(nn.Module):
    """Pool a sequence into attention-weighted mean and standard deviation.

    A small MLP scores every frame; the softmax of those scores over time
    weights the per-feature mean and std, which are concatenated.
    """

    def __init__(self, input_dim, attention_dim):
        super().__init__()
        self.attention_mlp = nn.Sequential(
            nn.Linear(input_dim, attention_dim),
            nn.Tanh(),
            nn.Linear(attention_dim, 1),
        )

    def forward(self, x, attention_mask):
        """
        Args:
            x: (batch, seq_len, input_dim).
            attention_mask: (batch, seq_len), non-zero/True for valid frames and
                0/False for padding; ``None`` treats every frame as valid.

        Returns:
            (batch, 2 * input_dim): weighted mean followed by weighted std.
        """
        scores = self.attention_mlp(x).squeeze(-1)  # (batch, seq_len)

        if attention_mask is not None:
            # Large negative score -> (near) zero weight after the softmax.
            scores = scores.masked_fill(attention_mask == 0, -1e9)

        weights = torch.softmax(scores, dim=1).unsqueeze(-1)  # (batch, seq_len, 1)

        mean = (x * weights).sum(dim=1)  # (batch, input_dim)
        second_moment = (x.pow(2) * weights).sum(dim=1)
        # Var = E[x^2] - E[x]^2, clamped so the sqrt stays finite and differentiable.
        std = (second_moment - mean.pow(2)).clamp(min=1e-9).sqrt()

        return torch.cat((mean, std), dim=1)

In [None]:
class sps_ConformerWavLM(nn.Module):
    """Multi-task model: WavLM -> Conformer blocks -> attentive statistics pooling -> task heads.

    One MLP head is built per entry in ``TASKS``. Regression heads output a
    single value per sample; classification heads output ``num_classes``
    logits, so ``TASKS[task]['num_classes']`` must be set before the model is
    constructed.
    """

    def __init__(self):
        """
        Raises:
            ValueError: if a classification task has no ``num_classes`` entry.
        """
        super().__init__()
        # NOTE(review): trust_remote_code=True is kept from the original; it is
        # presumably unnecessary for the built-in WavLM classes -- confirm.
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED_SSL_MODEL, trust_remote_code=True)
        self.wavlm = WavLMModel.from_pretrained(PRETRAINED_SSL_MODEL, trust_remote_code=True)
        self._configure_ssl_trainability()

        # nn.Sequential is used as an indexed container only; the blocks are
        # iterated manually in forward() because each one also takes the padding mask.
        self.conformer_encoder = nn.Sequential(
            *[ConformerBlock(
                d_model=CONFORMER_INPUT_DIM,
                n_head=CONFORMER_NHEAD,
                d_ff=CONFORMER_INPUT_DIM * CONFORMER_FF_DIM_FACTOR,
                conv_kernel_size=CONFORMER_CONV_KERNEL_SIZE,
                dropout=CONFORMER_DROPOUT
              ) for _ in range(CONFORMER_NUM_BLOCKS)]
        )

        self.pooling = AttentiveStatisticsPooling(CONFORMER_INPUT_DIM, ATTN_POOL_DIM)
        pooled_output_dim = CONFORMER_INPUT_DIM * 2  # mean and std are concatenated

        self.heads = nn.ModuleDict()
        for task_name, task_info in TASKS.items():
            head_config = HEAD_CONFIGS[task_name]
            if task_info['type'] == 'regression':
                output_dim = 1
            else:
                output_dim = task_info.get('num_classes')
                if output_dim is None:
                    # Fail here with a clear message instead of inside nn.Linear(..., None).
                    raise ValueError(
                        f"num_classes missing for {task_name}; set TASKS['{task_name}']['num_classes'] "
                        "before building the model"
                    )

            self.heads[task_name] = nn.Sequential(
                nn.Linear(pooled_output_dim, head_config['head_hidden_dim']), nn.ReLU(),
                nn.Dropout(head_config['head_dropout_rate']),
                nn.Linear(head_config['head_hidden_dim'], output_dim)
            )

    def _configure_ssl_trainability(self):
        """Freeze/unfreeze WavLM parameters according to the FINETUNE_SSL_* settings.

        * ``FINETUNE_SSL_MODEL`` false, or ``FINETUNE_SSL_LAYERS == 0``: everything frozen.
        * ``FINETUNE_SSL_LAYERS > 0``: only the feature projection and the top N
          encoder layers stay trainable.
        * ``FINETUNE_SSL_LAYERS < 0`` with fine-tuning enabled: nothing is frozen.
        """
        if FINETUNE_SSL_MODEL and FINETUNE_SSL_LAYERS < 0:
            return

        for param in self.wavlm.parameters():
            param.requires_grad = False

        if not FINETUNE_SSL_MODEL or FINETUNE_SSL_LAYERS == 0:
            return

        # Unfreeze the feature projection layer (adapter).
        if hasattr(self.wavlm, 'feature_projection'):
            for param in self.wavlm.feature_projection.parameters():
                param.requires_grad = True

        # Unfreeze the top N encoder layers (all of them if N exceeds the depth).
        if hasattr(self.wavlm, 'encoder') and hasattr(self.wavlm.encoder, 'layers'):
            num_total_layers = len(self.wavlm.encoder.layers)
            first_trainable = max(0, num_total_layers - FINETUNE_SSL_LAYERS)
            for i in range(first_trainable, num_total_layers):
                for param in self.wavlm.encoder.layers[i].parameters():
                    param.requires_grad = True

    def _frame_padding_mask(self, sample_attention_mask, hidden_states):
        """Build the frame-level padding mask (True = padded) for the Conformer.

        The feature extractor's attention mask is per audio sample, while
        ``hidden_states`` is per encoder frame. Truncating the sample-level mask
        to the frame count would mark every frame of any non-empty clip as
        valid, so the mask is down-sampled with WavLM's own length arithmetic.
        """
        batch_size, num_frames = hidden_states.shape[0], hidden_states.shape[1]

        if sample_attention_mask is None:
            # No mask from the extractor: treat every frame as valid.
            return torch.zeros(batch_size, num_frames, dtype=torch.bool, device=hidden_states.device)

        if sample_attention_mask.shape[1] == num_frames:
            # Already at frame resolution.
            return sample_attention_mask == 0

        frame_valid = self.wavlm._get_feature_vector_attention_mask(num_frames, sample_attention_mask)
        return ~frame_valid.bool()

    def forward(self, waveform):
        """Predict every task for a batch of raw waveforms.

        Args:
            waveform: (batch, num_samples) tensor of audio sampled at ``SAMPLE_RATE``.

        Returns:
            Dict mapping task name to predictions: shape (batch,) for regression
            tasks, (batch, num_classes) logits for classification tasks.
        """
        current_device = waveform.device
        inputs = self.feature_extractor(
            [wav.cpu().numpy() for wav in waveform], # WavLM/W2V2 FE expects list of numpy
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            padding="longest"
        )
        input_values = inputs.input_values.to(current_device)

        # The extractor only returns a mask when its config asks for one.
        attention_mask_ssl_input = getattr(inputs, 'attention_mask', None)
        if attention_mask_ssl_input is not None:
            attention_mask_ssl_input = attention_mask_ssl_input.to(current_device)

        wavlm_outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask_ssl_input,
            output_hidden_states=False
        )
        hidden_states = wavlm_outputs.last_hidden_state

        # Conformer expects src_key_padding_mask: (batch, num_frames), True for padded.
        conformer_padding_mask = self._frame_padding_mask(attention_mask_ssl_input, hidden_states)

        conformer_output = hidden_states
        for block in self.conformer_encoder:
            conformer_output = block(conformer_output, src_key_padding_mask=conformer_padding_mask)

        # For AttentiveStatisticsPooling, we need a mask where True is valid, False is pad.
        pooling_attention_mask = ~conformer_padding_mask

        pooled_output = self.pooling(conformer_output, attention_mask=pooling_attention_mask)

        predictions = {}
        for task_name, head_module in self.heads.items():
            task_prediction = head_module(pooled_output)
            if TASKS[task_name]['type'] == 'regression':
                predictions[task_name] = task_prediction.squeeze(-1)
            else:
                predictions[task_name] = task_prediction
        return predictions

In [None]:
print(f"Using device: {DEVICE}")

# Seed every RNG the pipeline draws from (Python, NumPy, PyTorch CPU) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(VAL_SPLIT_SEED)

# CUDA generators are seeded separately, on all visible devices.
if DEVICE == 'cuda':
    torch.cuda.manual_seed_all(VAL_SPLIT_SEED)

Using device: cuda


In [None]:
# Read the raw metadata tables.
full_train_df_raw = pd.read_csv(TRAIN_CSV_PATH)
test_df_raw = pd.read_csv(TEST_CSV_PATH)
print(f"Loaded raw train data: {len(full_train_df_raw)} rows, raw test data: {len(test_df_raw)} rows")

# Columns that are not used downstream, and columns every usable row must have.
cols_to_drop = ['index', 'Use_x', 'DR', 'Ethnicity']
critical_cols = ['FilePath'] + list(TASKS.keys())

# Drop unused columns, discard rows missing a path or any target, then re-index.
full_train_df = (
    full_train_df_raw
    .drop(columns=[c for c in cols_to_drop if c in full_train_df_raw.columns], errors='ignore')
    .dropna(subset=critical_cols)
    .reset_index(drop=True)
)
test_df = (
    test_df_raw
    .drop(columns=[c for c in cols_to_drop if c in test_df_raw.columns], errors='ignore')
    .dropna(subset=critical_cols)
    .reset_index(drop=True)
)
print(f"Processed train data: {len(full_train_df)} rows, Processed test data: {len(test_df)} rows")

Loaded raw train data: 4490 rows, raw test data: 1640 rows
Processed train data: 4490 rows, Processed test data: 1640 rows


In [42]:
print("Calculating normalization stats and mappings...")
# Fills NORM_STATS, GENDER_MAP and TASKS[...]['num_classes'] from the full training frame.
# NOTE(review): the stats are taken before the train/validation split, so
# validation rows contribute to them -- confirm this is intended.
for task_name, task_info in TASKS.items():
    task_type = task_info['type']

    if task_type == 'regression':
        task_values = full_train_df[task_name].astype(float)
        mean = task_values.mean()
        std = task_values.std()

        NORM_STATS[task_name]['mean'] = mean
        # Fall back to 1.0 when the std is undefined or effectively zero.
        NORM_STATS[task_name]['std'] = std if (np.isfinite(std) and std > 1e-6) else 1.0

        print(f"  {task_name.capitalize()} stats: Mean={NORM_STATS[task_name]['mean']:.2f}, Std={NORM_STATS[task_name]['std']:.2f}")

    elif task_type == 'classification':
        if task_name != 'Gender':
            TASKS[task_name]['num_classes'] = len(full_train_df[task_name].unique())
            print(f"  {task_name.capitalize()} Num Classes: {TASKS[task_name]['num_classes']}")

        else:
            # Upper-cased labels, sorted so the class indices are deterministic.
            cats = sorted(full_train_df[task_name].astype(str).str.upper().unique())
            TASKS[task_name]['num_classes'] = len(cats)

            GENDER_MAP = {cat: i for i, cat in enumerate(cats)}
            print(f"  Gender mapping: {GENDER_MAP}, Num Classes: {TASKS[task_name]['num_classes']}")

Calculating normalization stats and mappings...
  Age stats: Mean=30.29, Std=7.77
  Gender mapping: {'F': 0, 'M': 1}, Num Classes: 2
  Height stats: Mean=175.75, Std=9.52


In [43]:
# Prefer a speaker-disjoint split so no speaker appears in both train and validation.
if 'SpeakerID' in full_train_df.columns:
    print("Splitting data by SpeakerID...")
    speaker_ids = full_train_df['SpeakerID'].unique()
    train_spk_ids, val_spk_ids = train_test_split(speaker_ids, test_size=VAL_SPLIT_RATIO, random_state=VAL_SPLIT_SEED)

    in_train = full_train_df['SpeakerID'].isin(train_spk_ids)
    in_val = full_train_df['SpeakerID'].isin(val_spk_ids)
    train_df = full_train_df[in_train].copy()
    val_df = full_train_df[in_val].copy()

else:
    print("Warning: 'SpeakerID' not found. Performing random split.")
    train_df, val_df = train_test_split(full_train_df, test_size=VAL_SPLIT_RATIO, random_state=VAL_SPLIT_SEED)

print(f"Data split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

Splitting data by SpeakerID...
Data split: Train=3360, Val=1130, Test=1640


In [44]:
# Only the training dataset runs in 'train' mode (random crop + augmentation).
train_dataset = TimitDataset(train_df, mode='train')
val_dataset = TimitDataset(val_df, mode='eval')
test_dataset = TimitDataset(test_df, mode='eval')

# Settings shared by all three loaders.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# Training batches are shuffled and the last incomplete batch is dropped.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [45]:
# Build the network (this loads the pretrained WavLM weights) and move it to the target device.
model = sps_ConformerWavLM().to(DEVICE)

In [46]:
if FINETUNE_SSL_MODEL and FINETUNE_SSL_LAYERS != 0:
    # Differential learning rates: the pretrained backbone gets a scaled-down LR.
    ssl_params = list(model.wavlm.parameters())
    conformer_params = list(model.conformer_encoder.parameters())
    pooling_params = list(model.pooling.parameters())
    head_params = list(model.heads.parameters())

    # Only the unfrozen backbone parameters go to the optimizer.
    trainable_ssl_params = [p for p in ssl_params if p.requires_grad]
    _ssl_lr = LEARNING_RATE * SSL_FINETUNE_LR_FACTOR if trainable_ssl_params else LEARNING_RATE

    # Group order matters: later code reads group 0 as "SSL" and group 1 as "main".
    _param_groups = [{'params': trainable_ssl_params, 'lr': _ssl_lr}]
    for _group_params in (conformer_params, pooling_params, head_params):
        _param_groups.append({'params': _group_params, 'lr': LEARNING_RATE})

    optimizer = torch.optim.AdamW(_param_groups, weight_decay=OPTIMIZER_WEIGHT_DECAY)

    print(f"Optimizer: AdamW with differential LR (SSL factor: {SSL_FINETUNE_LR_FACTOR if trainable_ssl_params else 1.0}). SSL params count: {len(trainable_ssl_params)}")

else:
    # Backbone frozen: a single LR for whatever is trainable.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LEARNING_RATE, weight_decay=OPTIMIZER_WEIGHT_DECAY
    )
    print("Optimizer: AdamW with single LR for all trainable parameters.")

Optimizer: AdamW with differential LR (SSL factor: 0.1). SSL params count: 42


In [47]:
# Training losses (default reduction): MSE for regression tasks, cross-entropy for classification tasks.
criterion_reg = nn.MSELoss()
criterion_cls = nn.CrossEntropyLoss()

In [48]:
# Scale every param group's LR by LR_SCHEDULER_FACTOR (never below MIN_LR) when the
# monitored validation metric stops improving for LR_SCHEDULER_PATIENCE epochs.
# The deprecated `verbose` argument is intentionally not passed; omitting it keeps
# the previous (silent) behaviour.
main_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=MONITOR_METRIC_MODE,
    factor=LR_SCHEDULER_FACTOR,
    patience=LR_SCHEDULER_PATIENCE,
    min_lr=MIN_LR,
)

In [None]:
def validate_epoch(val_loader, model, device):
    """Evaluate the model on a loader and return per-task and weighted losses.

    Losses are summed per batch and divided by the number of samples counted,
    giving per-sample averages (MSE on normalised targets for regression,
    cross-entropy for classification).

    A batch only contributes if every task in ``TASKS`` yields a finite loss.
    Its per-task losses are staged and committed together, so a failure on a
    later task cannot leave earlier tasks' sums updated without the matching
    sample count. A non-finite loss adds a large penalty to that task's sum.

    Args:
        val_loader: iterable of ``(wav, targets)`` batches; ``None`` batches are skipped.
        model: network returning a dict of predictions keyed by task name.
        device: device the batch tensors are moved to.

    Returns:
        ``(avg_task_losses, weighted_loss)``: a dict of average loss per task and
        their combination weighted by ``EVAL_LOSS_WEIGHTS``. Both are ``inf`` if
        no batch could be evaluated.
    """
    model.eval()
    running_task_losses_val = {task: 0.0 for task in TASKS}
    num_samples_val = 0
    criterion_reg_val = nn.MSELoss(reduction='sum')
    criterion_cls_val = nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for batch_data in val_loader:
            if batch_data is None:
                continue

            wav, targets = batch_data
            if wav.numel() == 0 or not targets:
                continue

            wav, targets_device = wav.to(device), {k: v.to(device) for k, v in targets.items()}

            current_batch_size = wav.size(0)
            predictions = model(wav)

            batch_task_losses = {}  # staged until the whole batch is known to be valid
            valid_batch = True

            for task_name, task_info in TASKS.items():
                if task_name not in predictions or task_name not in targets_device:
                    valid_batch = False
                    break

                pred_val, target_val = predictions[task_name], targets_device[task_name]
                if task_info['type'] == 'regression' and pred_val.shape != target_val.shape:
                    target_val = target_val.view_as(pred_val)

                if task_info['type'] == 'regression':
                    loss_val = criterion_reg_val(pred_val, target_val)
                else:
                    loss_val = criterion_cls_val(pred_val, target_val)

                if not torch.isfinite(loss_val):
                    # Penalise the offending task so a diverged model never looks good.
                    running_task_losses_val[task_name] += 1e9 * current_batch_size
                    valid_batch = False
                    break

                batch_task_losses[task_name] = loss_val.item()

            if not valid_batch:
                continue

            for task_name, batch_loss in batch_task_losses.items():
                running_task_losses_val[task_name] += batch_loss
            num_samples_val += current_batch_size

    avg_task_losses_val, weighted_val_loss = {}, 0.0
    if num_samples_val > 0:
        for task_name in TASKS.keys():
            avg_loss = running_task_losses_val[task_name] / num_samples_val
            avg_task_losses_val[task_name] = avg_loss
            if task_name in EVAL_LOSS_WEIGHTS:
                weighted_val_loss += EVAL_LOSS_WEIGHTS[task_name] * avg_loss

    else:
        for task_name in TASKS.keys():
            avg_task_losses_val[task_name] = float('inf')

        weighted_val_loss = float('inf')

    return avg_task_losses_val, weighted_val_loss

In [None]:
print("\n--- Starting Training (WavLM-Conformer Model) ---")
best_val_metric = float('inf') if MONITOR_METRIC_MODE == 'min' else float('-inf')
epochs_no_improve = 0
best_model_state = None
# Target LR of each param group; the warm-up ramps linearly up to these values.
initial_lrs = [pg['lr'] for pg in optimizer.param_groups]

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss_epoch = 0.0
    task_losses_epoch = {task: 0.0 for task in TASKS}
    num_samples_processed = 0

    # LR warm-up: epoch k of the warm-up uses k / LR_WARMUP_EPOCHS of the target LR.
    if epoch <= LR_WARMUP_EPOCHS:
        for i, param_group in enumerate(optimizer.param_groups):
            param_group['lr'] = initial_lrs[i] * (epoch / LR_WARMUP_EPOCHS)

        if len(optimizer.param_groups) > 1:
            # Group 0 is the SSL backbone, group 1 the first "main" group.
            print(f"Warm-up Epoch {epoch}: LR SSL: {optimizer.param_groups[0]['lr']:.2e}, LR Main: {optimizer.param_groups[1]['lr']:.2e}")

        else:
            print(f"Warm-up Epoch {epoch}: LR: {optimizer.param_groups[0]['lr']:.2e}")

    for batch_idx, batch_data in enumerate(train_loader):
        if batch_data is None:
            continue

        wav, targets = batch_data
        if wav.numel() == 0 or not targets:
            continue

        wav, targets_device = wav.to(DEVICE), {k: v.to(DEVICE) for k, v in targets.items()}
        current_batch_size = wav.size(0)
        optimizer.zero_grad()
        predictions = model(wav)
        combined_loss_batch = torch.tensor(0.0, device=DEVICE)
        current_batch_task_losses = {}
        valid_batch_for_loss = True

        # Weighted sum of the per-task losses; any missing or non-finite task skips the batch.
        for task_name, task_info in TASKS.items():
            if task_name not in predictions or task_name not in targets_device:
                valid_batch_for_loss = False
                break

            pred, target = predictions[task_name], targets_device[task_name]

            if task_info['type'] == 'regression' and pred.shape != target.shape:
                target = target.view_as(pred)

            loss = criterion_reg(pred, target) if task_info['type'] == 'regression' else criterion_cls(pred, target)

            if not torch.isfinite(loss):
                valid_batch_for_loss = False
                break

            combined_loss_batch += task_info['loss_weight'] * loss
            current_batch_task_losses[task_name] = loss.item()

        if valid_batch_for_loss:
            combined_loss_batch.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP_MAX_NORM)
            optimizer.step()
            # Batch losses are means, so weight them by batch size for a per-sample epoch average.
            total_loss_epoch += combined_loss_batch.item() * current_batch_size

            for task_name, loss_item in current_batch_task_losses.items():
                task_losses_epoch[task_name] += loss_item * current_batch_size

            num_samples_processed += current_batch_size

    avg_loss_epoch = total_loss_epoch/num_samples_processed if num_samples_processed > 0 else float('inf')
    avg_task_train_losses = {k:v/num_samples_processed if num_samples_processed > 0 else float('inf') for k,v in task_losses_epoch.items()}
    print(f"E{epoch} Summary: AvgTrL: {avg_loss_epoch:.4f} | " + \
          " | ".join([f"Avg {k[:3]}TrL: {v:.4f}" for k,v in avg_task_train_losses.items()]))

    if len(val_loader) > 0:
        avg_task_val_losses, current_val_metric = validate_epoch(val_loader, model, DEVICE)
        val_loss_str = " | ".join([f"Avg {k[:3]}VaL: {v:.4f}" for k,v in avg_task_val_losses.items()])

        current_lr_display = optimizer.param_groups[0]['lr']
        lr_info_str = f"LR: {current_lr_display:.2e}"
        if len(optimizer.param_groups) > 1:
            lr_info_str = f"LR SSL: {optimizer.param_groups[0]['lr']:.2e}, LR Main: {optimizer.param_groups[1]['lr']:.2e}"

        print(f"E{epoch} Validation: WValL: {current_val_metric:.4f} | {val_loss_str} | {lr_info_str}")

        # The plateau scheduler only takes over once the warm-up is finished.
        if epoch > LR_WARMUP_EPOCHS:
            main_scheduler.step(current_val_metric)

        if (MONITOR_METRIC_MODE == 'min' and current_val_metric < best_val_metric) or \
           (MONITOR_METRIC_MODE == 'max' and current_val_metric > best_val_metric):
            best_val_metric, epochs_no_improve = current_val_metric, 0
            # Keep an in-memory snapshot and also persist it to CKPT_PATH.
            best_model_state = copy.deepcopy(model.state_dict())
            torch.save(best_model_state, CKPT_PATH)
            print(f"  Best model saved! E{epoch}, ValM: {best_val_metric:.4f}")

        else:
            epochs_no_improve += 1
            print(f"  No improvement for {epochs_no_improve} epoch(s). Best: {best_val_metric:.4f}")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping E{epoch}. No improvement for {EARLY_STOPPING_PATIENCE} epochs.")
            break

    else:
        print("Skipping validation.")


print("--- Training Finished ---")


--- Starting Training (WavLM-Conformer Model) ---
Warm-up Epoch 1: LR SSL: 1.00e-06, LR Main: 1.00e-05




E1 Summary: AvgTrL: 0.6999 | Avg ageTrL: 0.8757 | Avg GenTrL: 0.4210 | Avg heiTrL: 0.7445
E1 Validation: WValL: 0.6698 | Avg ageVaL: 1.2315 | Avg GenVaL: 0.1324 | Avg heiVaL: 0.6209 | LR SSL: 1.00e-06, LR Main: 1.00e-05
  Best model saved! E1, ValM: 0.6698
Warm-up Epoch 2: LR SSL: 2.00e-06, LR Main: 2.00e-05
E2 Summary: AvgTrL: 0.4968 | Avg ageTrL: 0.7601 | Avg GenTrL: 0.1000 | Avg heiTrL: 0.5424
E2 Validation: WValL: 0.6361 | Avg ageVaL: 1.2453 | Avg GenVaL: 0.0517 | Avg heiVaL: 0.5867 | LR SSL: 2.00e-06, LR Main: 2.00e-05
  Best model saved! E2, ValM: 0.6361
Warm-up Epoch 3: LR SSL: 3.00e-06, LR Main: 3.00e-05
E3 Summary: AvgTrL: 0.4533 | Avg ageTrL: 0.6717 | Avg GenTrL: 0.0681 | Avg heiTrL: 0.5473
E3 Validation: WValL: 0.5090 | Avg ageVaL: 0.9510 | Avg GenVaL: 0.0334 | Avg heiVaL: 0.5759 | LR SSL: 3.00e-06, LR Main: 3.00e-05
  Best model saved! E3, ValM: 0.5090
Warm-up Epoch 4: LR SSL: 4.00e-06, LR Main: 4.00e-05
E4 Summary: AvgTrL: 0.4225 | Avg ageTrL: 0.6224 | Avg GenTrL: 0.0564 |

In [51]:
print("\n--- Evaluating on Test Set with Best Model (WavLM-Conformer) ---")
# Prefer the snapshot kept in memory by the training loop, then the checkpoint
# file; otherwise the model is evaluated with its current weights.
if best_model_state is not None:
    print(f"Loading best model from memory (ValM: {best_val_metric:.4f})")
    _state_to_restore = best_model_state
elif os.path.exists(CKPT_PATH):
    print(f"Loading best model from checkpoint: {CKPT_PATH}")
    _state_to_restore = torch.load(CKPT_PATH, map_location=DEVICE)
else:
    print("No best model state found. Evaluating with the last model state.")
    _state_to_restore = None

if _state_to_restore is not None:
    model.load_state_dict(_state_to_restore)


--- Evaluating on Test Set with Best Model (WavLM-Conformer) ---
Loading best model from memory (ValM: 0.4576)


In [54]:
# Evaluate the currently loaded model on the test set.
#
# Produces, for the metrics cell that follows:
#   all_preds_test / all_targets_test : per-task flat lists, one scalar per sample
#                                       (regression values are de-normalised with
#                                       NORM_STATS; classification values are class ids)
#   running_task_losses_test          : per-task loss, summed over all samples
#   num_samples_test                  : number of samples seen (wav.size(0) per batch)
model.eval()
all_targets_test, all_preds_test = {t: [] for t in TASKS}, {t: [] for t in TASKS}
running_task_losses_test, num_samples_test = {t: 0.0 for t in TASKS}, 0
# reduction='sum' so that batch losses can be added up and divided by the
# total sample count afterwards, independent of the size of each batch.
crit_reg_test, crit_cls_test = nn.MSELoss(reduction='sum'), nn.CrossEntropyLoss(reduction='sum')

with torch.no_grad():
    for batch_idx, batch_data in enumerate(test_loader):
        # Skip batches the loader could not produce or that carry no data.
        if batch_data is None:
            continue

        wav, targets = batch_data
        if wav.numel() == 0 or not targets:
            continue

        wav = wav.to(DEVICE)
        targets_dev = {k: v.to(DEVICE) for k, v in targets.items()}
        current_batch_size = wav.size(0)
        predictions = model(wav)

        for task_name, task_info in TASKS.items():
            if task_name not in predictions or task_name not in targets_dev:
                continue

            is_regression = task_info['type'] == 'regression'
            pred_v, target_v = predictions[task_name], targets_dev[task_name]
            if is_regression and pred_v.shape != target_v.shape:
                # Align the target with the prediction so MSELoss does not broadcast.
                target_v = target_v.view_as(pred_v)

            criterion = crit_reg_test if is_regression else crit_cls_test
            loss_v = criterion(pred_v, target_v)
            running_task_losses_test[task_name] += loss_v.item()

            # Metrics are collected on CPU; `targets` still holds the CPU originals.
            pred_cpu, targ_cpu = pred_v.cpu(), targets[task_name]
            if is_regression:
                m, s = NORM_STATS[task_name]['mean'], NORM_STATS[task_name]['std']
                # A (near-)zero std falls back to a scale of 1.0.
                scale = s if s > 1e-6 else 1.0
                pred_denorm = (pred_cpu * scale) + m
                targ_denorm = (targ_cpu * scale) + m
                # Flatten before extending: predictions and CPU targets may differ
                # in shape (e.g. (B, 1) vs (B,)), and a 0-dim tensor's .tolist()
                # returns a bare float that list.extend() cannot consume.
                all_preds_test[task_name].extend(pred_denorm.reshape(-1).tolist())
                all_targets_test[task_name].extend(targ_denorm.reshape(-1).tolist())
            else:
                all_preds_test[task_name].extend(torch.argmax(pred_cpu, dim=1).tolist())
                all_targets_test[task_name].extend(targ_cpu.tolist())

        num_samples_test += current_batch_size
        if (batch_idx + 1) % 20 == 0:
            print(f"  Evaluated test batch {batch_idx+1}/{len(test_loader)}")




  Evaluated test batch 20/52
  Evaluated test batch 40/52


In [55]:
# Summarise the test pass: average per-task loss plus one headline metric per
# task (MSE/RMSE for regression tasks, accuracy for classification tasks).
metrics_test = {}
avg_task_losses_test = {}
print("\n--- Final Test Results (Best WavLM-Conformer Model) ---")

if num_samples_test <= 0:
    print("No samples processed during final test evaluation.")
else:
    for task_name, task_info in TASKS.items():
        # Losses were accumulated with reduction='sum', so divide by sample count.
        avg_task_losses_test[task_name] = running_task_losses_test[task_name] / num_samples_test

        targets_np = np.array(all_targets_test[task_name])
        preds_np = np.array(all_preds_test[task_name])

        # A metric needs at least one sample and equally long prediction/target arrays.
        if len(targets_np) == 0 or len(targets_np) != len(preds_np):
            print(f"  Could not calculate metric for {task_name}.")
            continue

        if task_info['type'] == 'regression':
            mse = mean_squared_error(targets_np, preds_np)
            metrics_test[f"{task_name}_mse"] = mse
            print(f"  Test MSE ({task_name}): {mse:.4f} (RMSE: {np.sqrt(mse):.4f})")
        else:
            acc = accuracy_score(targets_np, preds_np)
            metrics_test[f"{task_name}_acc"] = acc
            print(f"  Test Accuracy ({task_name}): {acc:.4f}")

print("-" * 30)


--- Final Test Results (Best WavLM-Conformer Model) ---
  Test MSE (age): 50.4608 (RMSE: 7.1036)
  Test Accuracy (Gender): 0.9933
  Test MSE (height): 52.9677 (RMSE: 7.2779)
------------------------------



--- Final Test Results (WavLM-Conformer Model) ---

with Age loss_weight: 0.4, Gender loss_weight: 0.4 and Height loss_weight: 0.2

  * Test MSE (age): 55.3258 (RMSE: 7.4381)
  * Test Accuracy (Gender): 0.9945
  * Test MSE (height): 54.2472 (RMSE: 7.3653)


--- Final Test Results (Best WavLM-Conformer Model) ---

with Age loss_weight: 0.4, Gender loss_weight: 0.3 and Height loss_weight: 0.3
  * Test MSE (age): 50.4608 (RMSE: 7.1036)
  * Test Accuracy (Gender): 0.9933
  * Test MSE (height): 52.9677 (RMSE: 7.2779)