In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# !mkdir 'timit_dataset'
# !unzip '/content/drive/MyDrive/sps/preprocessed_timit_dataset.zip' -d timit_dataset

In [3]:
import os
import random
import time
import numpy as np
import pandas as pd
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, accuracy_score
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, logging as hf_logging
from tqdm.notebook import tqdm
import warnings
import copy

In [4]:
# torch.cuda.empty_cache()

In [5]:
warnings.filterwarnings("ignore")
hf_logging.set_verbosity_error()
# torch.autograd.set_detect_anomaly(True)

For this config, L4 GPU is required

In [None]:
AUDIO_ROOT_DIR = '/content/timit_dataset'
DRIVE_PATH = '/content/drive/MyDrive/sps/'

TRAIN_CSV_PATH = os.path.join(DRIVE_PATH, 'final_train_data_merged.csv')
TEST_CSV_PATH = os.path.join(DRIVE_PATH, 'final_test_data_merged.csv')
CKPT_PATH = os.path.join(DRIVE_PATH, 'sps_bilstm_best_model.pth')

# Audio Processing
SAMPLE_RATE = 16000
CLIP_SECONDS = 4.0
WAV_LEN = int(SAMPLE_RATE * CLIP_SECONDS)

# Audio Augmentations
APPLY_AUGMENTATIONS = True
AUGMENTATION_PROBABILITY = 0.20
ADD_NOISE_PROBABILITY = 0.20
NOISE_SNR_MIN = 5.0
NOISE_SNR_MAX = 20.0

# Wav2Vec2.0 Base Model
PRETRAINED_W2V2 = 'facebook/wav2vec2-base-960h'
W2V2_OUTPUT_DIM = 768
FREEZE_ENCODER = True

# BiLSTM Specific Parameters
LSTM_HIDDEN_SIZE = 256
LSTM_NUM_LAYERS = 4

# Dropout
MODEL_DROPOUT_RATE = 0.3

# Training Hyperparameters
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
OPTIMIZER_WEIGHT_DECAY = 0.01

# Validation
VAL_SPLIT_RATIO = 0.3
VAL_SPLIT_SEED = 42

# Task Configuration
TASKS = {
    'age': {'type': 'regression', 'loss_weight': 0.4},
    'Gender': {'type': 'classification', 'loss_weight': 0.4},
    'height': {'type': 'regression', 'loss_weight': 0.2}
}

# Head-Specific Hyperparameters
HEAD_CONFIGS = {
    'age': {'head_hidden_dim': 128, 'head_dropout_rate': 0.3},
    'Gender': {'head_hidden_dim': 128, 'head_dropout_rate': 0.3},
    'height': {'head_hidden_dim': 64, 'head_dropout_rate': 0.2}
}

# Mappings and Normalization Stats
GENDER_MAP = {}
NORM_STATS = {
    'age': {'mean': 0.0, 'std': 1.0},
    'height': {'mean': 0.0, 'std': 1.0}
}

# DataLoader
NUM_WORKERS = 4
PIN_MEMORY = True if DEVICE == 'cuda' else False

# --- Model Saving, Early Stopping, LR Scheduling ---
MONITOR_METRIC_MODE = 'min'
EARLY_STOPPING_PATIENCE = 15
LR_SCHEDULER_PATIENCE = 5
LR_SCHEDULER_FACTOR = 0.1
MIN_LR = 3e-5


EVAL_LOSS_WEIGHTS = {
    'age': 0.45,
    'Gender': 0.45,
    'height': 0.10
}

In [None]:
class PadCrop:
    def __init__(self, length, mode = 'train'):
        self.length = length
        self.mode = mode

    def __call__(self, wav):
        current_len = wav.shape[-1]
        if current_len == self.length:
            return wav
        elif current_len > self.length:
            if self.mode == 'train':
                start = random.randint(0, current_len - self.length)
            else:
                start = (current_len - self.length) // 2
            wav = wav[..., start : start + self.length]
        else:
            pad_width = self.length - current_len
            pad_left = pad_width // 2
            pad_right = pad_width - pad_left
            wav = F.pad(wav, (pad_left, pad_right), mode='constant', value=0.0)
        return wav

class AddGaussianNoise:
    def __init__(self, min_snr_db = 5.0, max_snr_db = 20.0, p = 0.5):
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db
        self.p = p

    def __call__(self, waveform):
        if random.random() < self.p:
            if waveform.ndim > 1 and waveform.shape[0] > 1:
                wav_data_for_power_calc = waveform[0,:]
            elif waveform.ndim > 1 and waveform.shape[0] == 1:
                wav_data_for_power_calc = waveform.squeeze(0)
            elif waveform.ndim == 1:
                wav_data_for_power_calc = waveform
            else:
                return waveform

            signal_power = torch.mean(wav_data_for_power_calc**2)
            if signal_power.item() < 1e-9:
                return waveform

            snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
            snr_linear = 10**(snr_db / 10.0)
            noise_power = signal_power / snr_linear
            noise = torch.randn_like(waveform) * torch.sqrt(noise_power)
            return waveform + noise
        return waveform

In [None]:
class TimitDataset(Dataset):
    def __init__(self, data_df, mode= 'train'):
        self.data_df = data_df.reset_index(drop=True)
        self.mode = mode
        self.pad_crop = PadCrop(WAV_LEN, mode)
        self.target_cols = list(TASKS.keys())

        self.add_noise_transform = None
        if self.mode == 'train' and APPLY_AUGMENTATIONS:
            self.add_noise_transform = AddGaussianNoise(
                min_snr_db=NOISE_SNR_MIN,
                max_snr_db=NOISE_SNR_MAX,
                p=ADD_NOISE_PROBABILITY
            )

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        if idx >= len(self.data_df):
            print(f"Index {idx} out of bounds for dataset size {len(self.data_df)}")
            return None

        row = self.data_df.iloc[idx]
        wav_relative_path = row['FilePath']
        full_wav_path = os.path.normpath(os.path.join(AUDIO_ROOT_DIR, wav_relative_path))

        wav, sr = torchaudio.load(full_wav_path)

        if sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
            wav = resampler(wav)

        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)

        if self.mode == 'train' and APPLY_AUGMENTATIONS:
            if random.random() < AUGMENTATION_PROBABILITY:
                if self.add_noise_transform:
                    wav = self.add_noise_transform(wav)

        wav = self.pad_crop(wav)
        wav = wav.squeeze(0)

        if torch.isnan(wav).any():
            print(f"NaNs in wav for item {idx} AFTER processing, path: {full_wav_path}. Returning None.")
            return None

        targets = {}
        valid_item = True
        for task_name, task_info in TASKS.items():
            value = row[task_name]
            if pd.isna(value):
                print(f"NaN target for task '{task_name}' at idx {idx}. Returning None.")
                valid_item = False
                break

            if task_info['type'] == 'regression':
                mean = NORM_STATS[task_name]['mean']
                std = NORM_STATS[task_name]['std']
                norm_value = (value - mean) / (std if std > 1e-6 else 1.0)
                targets[task_name] = torch.tensor(norm_value, dtype=torch.float32)
            elif task_info['type'] == 'classification':
                value_upper = str(value).upper()
                mapping = GENDER_MAP
                idx_value = mapping.get(value_upper, 0)
                targets[task_name] = torch.tensor(idx_value, dtype=torch.long)

        if not valid_item:
            return None
        return wav, targets

In [None]:
def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return None

    wavs = [item[0] for item in batch]
    target_dicts = [item[1] for item in batch]
    padded_wavs = torch.nn.utils.rnn.pad_sequence(wavs, batch_first=True, padding_value=0.0)

    collated_targets = {}
    if target_dicts:
        first_item_keys = target_dicts[0].keys()
        for key in first_item_keys:
            if all(key in d for d in target_dicts):
                 collated_targets[key] = torch.stack([d[key] for d in target_dicts])

    if not collated_targets and target_dicts:
         return None
    return padded_wavs, collated_targets

In [None]:
class sps_BiLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED_W2V2)
        self.wav2vec2_encoder = Wav2Vec2Model.from_pretrained(PRETRAINED_W2V2)

        if FREEZE_ENCODER:
            for param in self.wav2vec2_encoder.parameters():
                param.requires_grad = False

        encoder_output_dim = W2V2_OUTPUT_DIM
        self.bilstm = nn.LSTM(
            input_size=encoder_output_dim, hidden_size=LSTM_HIDDEN_SIZE,
            num_layers=LSTM_NUM_LAYERS, batch_first=True,
            bidirectional=True, dropout=MODEL_DROPOUT_RATE if LSTM_NUM_LAYERS > 1 else 0
        )
        bilstm_output_dim = LSTM_HIDDEN_SIZE * 2
        self.heads = nn.ModuleDict()
        head_common_input_dim = bilstm_output_dim

        for task_name, task_info in TASKS.items():
            head_specific_config = HEAD_CONFIGS[task_name]
            head_hidden_dim = head_specific_config['head_hidden_dim']
            head_dropout_rate = head_specific_config['head_dropout_rate']
            output_dim = 1 if task_info['type'] == 'regression' else task_info.get('num_classes')
            if output_dim is None and task_info['type'] == 'classification':
                raise ValueError(f"num_classes not set for classification task '{task_name}'")

            self.heads[task_name] = nn.Sequential(
                nn.Linear(head_common_input_dim, head_hidden_dim), nn.ReLU(),
                nn.Dropout(head_dropout_rate), nn.Linear(head_hidden_dim, output_dim)
            )

    def forward(self, waveform):
        current_device = next(self.parameters()).device
        waveform_list = [wav.cpu().numpy() for wav in waveform] if waveform.ndim == 2 else [waveform.cpu().numpy()]

        inputs = self.feature_extractor(
            waveform_list, sampling_rate=SAMPLE_RATE, return_tensors="pt",
            padding="longest", return_attention_mask=True
        )
        inputs = {k: v.to(current_device) for k, v in inputs.items()}
        attention_mask_input = inputs.get('attention_mask')
        if attention_mask_input is None:
            attention_mask_input = torch.ones(inputs['input_values'].shape[0], inputs['input_values'].shape[1], dtype=torch.long, device=current_device)

        w2v2_outputs = self.wav2vec2_encoder(inputs['input_values'], attention_mask=attention_mask_input)
        hidden_states = w2v2_outputs.last_hidden_state
        bilstm_output, _ = self.bilstm(hidden_states)

        actual_sequence_length = hidden_states.shape[1]
        # Correct attention_mask_output slicing
        if attention_mask_input.shape[1] >= actual_sequence_length:
            attention_mask_output = attention_mask_input[:, :actual_sequence_length]
        else: 
            attention_mask_output = torch.ones(hidden_states.shape[0], actual_sequence_length, device=current_device, dtype=torch.long)


        expanded_attention_mask = attention_mask_output.unsqueeze(-1).expand_as(bilstm_output)
        masked_bilstm_output = bilstm_output * expanded_attention_mask
        summed_bilstm_output = torch.sum(masked_bilstm_output, dim=1)
        valid_lengths = expanded_attention_mask.sum(dim=1).clamp(min=1e-9)
        pooled_output = summed_bilstm_output / valid_lengths

        predictions = {}
        for task_name, head_module in self.heads.items():
            task_prediction = head_module(pooled_output)
            if TASKS[task_name]['type'] == 'regression':
                predictions[task_name] = task_prediction.squeeze(-1)
            else:
                predictions[task_name] = task_prediction
        return predictions

In [None]:
print(f"Using device: {DEVICE}")

random.seed(VAL_SPLIT_SEED)
np.random.seed(VAL_SPLIT_SEED)
torch.manual_seed(VAL_SPLIT_SEED)

if DEVICE == 'cuda':
    torch.cuda.manual_seed_all(VAL_SPLIT_SEED)

Using device: cuda


In [12]:
full_train_df_raw = pd.read_csv(TRAIN_CSV_PATH)
test_df_raw = pd.read_csv(TEST_CSV_PATH)

print(f"Loaded raw train data: {len(full_train_df_raw)} rows, raw test data: {len(test_df_raw)} rows")

Loaded raw train data: 4490 rows, raw test data: 1640 rows


In [13]:
cols_to_drop = ['index', 'Use_x', 'DR', 'Ethnicity']

full_train_df = full_train_df_raw.drop(columns=[col for col in cols_to_drop if col in full_train_df_raw.columns], errors='ignore')
test_df = test_df_raw.drop(columns=[col for col in cols_to_drop if col in test_df_raw.columns], errors='ignore')

In [None]:
critical_cols = ['FilePath'] + list(TASKS.keys())
full_train_df.dropna(subset=critical_cols, inplace=True)
test_df.dropna(subset=critical_cols, inplace=True)
full_train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

print(f"Processed train data: {len(full_train_df)} rows, Processed test data: {len(test_df)} rows")

Processed train data: 4490 rows, Processed test data: 1640 rows


In [15]:
print("Calculating normalization stats and mappings...")
for task_name, task_info in TASKS.items():
    if task_info['type'] == 'regression':
        mean = full_train_df[task_name].astype(float).mean()
        std = full_train_df[task_name].astype(float).std()
        NORM_STATS[task_name]['mean'] = mean
        NORM_STATS[task_name]['std'] = std if (np.isfinite(std) and std > 1e-6) else 1.0
        print(f"  {task_name.capitalize()} stats: Mean={NORM_STATS[task_name]['mean']:.2f}, Std={NORM_STATS[task_name]['std']:.2f}")
    elif task_info['type'] == 'classification':
        if task_name == 'Gender':
            cats = sorted(list(full_train_df[task_name].astype(str).str.upper().unique()))
            TASKS[task_name]['num_classes'] = len(cats)
            GENDER_MAP = {cat: i for i, cat in enumerate(cats)}
            print(f"  Gender mapping: {GENDER_MAP}, Num Classes: {TASKS[task_name]['num_classes']}")
        else:
             unique_values = full_train_df[task_name].unique()
             TASKS[task_name]['num_classes'] = len(unique_values)
             print(f"  {task_name.capitalize()} Num Classes: {TASKS[task_name]['num_classes']}")

Calculating normalization stats and mappings...
  Age stats: Mean=30.29, Std=7.77
  Gender mapping: {'F': 0, 'M': 1}, Num Classes: 2
  Height stats: Mean=175.75, Std=9.52


In [16]:
if 'SpeakerID' in full_train_df.columns:
    print("Splitting data by SpeakerID...")
    speaker_ids = full_train_df['SpeakerID'].unique()
    train_spk_ids, val_spk_ids = train_test_split(speaker_ids, test_size=VAL_SPLIT_RATIO, random_state=VAL_SPLIT_SEED)
    train_df = full_train_df[full_train_df['SpeakerID'].isin(train_spk_ids)].copy()
    val_df = full_train_df[full_train_df['SpeakerID'].isin(val_spk_ids)].copy()
else:
    print("Warning: 'SpeakerID' not found. Performing random split.")
    train_df, val_df = train_test_split(full_train_df, test_size=VAL_SPLIT_RATIO, random_state=VAL_SPLIT_SEED)
print(f"Data split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")


Splitting data by SpeakerID...
Data split: Train=3140, Val=1350, Test=1640


In [17]:
train_dataset = TimitDataset(train_df, mode='train')
val_dataset = TimitDataset(val_df, mode='eval')
test_dataset = TimitDataset(test_df, mode='eval')

In [18]:
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn, drop_last=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn
)

In [19]:
model = sps_BiLSTM().to(DEVICE)
# print(model)

In [20]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LEARNING_RATE, weight_decay=OPTIMIZER_WEIGHT_DECAY
)

In [21]:
criterion_reg = nn.MSELoss()
criterion_cls = nn.CrossEntropyLoss()

In [22]:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode=MONITOR_METRIC_MODE, factor=LR_SCHEDULER_FACTOR,
    patience=LR_SCHEDULER_PATIENCE, min_lr=MIN_LR, verbose=True
)

In [None]:
def validate_epoch(val_loader, model, device):
    model.eval()
    running_task_losses_val = {task: 0.0 for task in TASKS}
    num_samples_val = 0
    all_targets_val = {task: [] for task in TASKS}
    all_preds_val = {task: [] for task in TASKS}

    criterion_reg_val = nn.MSELoss(reduction='sum')
    criterion_cls_val = nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(val_loader):
            if batch_data is None: continue
            wav, targets = batch_data
            if wav.numel() == 0 or not targets: continue

            wav = wav.to(device)
            targets = {k: v.to(device) for k, v in targets.items()}
            current_batch_size = wav.size(0)

            predictions = model(wav)
            valid_batch = True
            for task_name, task_info in TASKS.items():
                if task_name not in predictions or task_name not in targets:
                    print(f"Warning: Missing pred/target for task '{task_name}' in val batch {batch_idx}")
                    valid_batch = False
                    break
                pred_val = predictions[task_name]
                target_val = targets[task_name]
                if task_info['type'] == 'regression' and pred_val.shape != target_val.shape:
                    target_val = target_val.view_as(pred_val)

                loss_val = criterion_reg_val(pred_val, target_val) if task_info['type'] == 'regression' \
                           else criterion_cls_val(pred_val, target_val)

                if torch.isnan(loss_val) or torch.isinf(loss_val):
                    print(f"NaN/Inf val loss for task {task_name}, batch {batch_idx}. Setting to high value.")
                    running_task_losses_val[task_name] += 1e9 * current_batch_size
                    valid_batch = False
                    break
                else:
                    running_task_losses_val[task_name] += loss_val.item()
            if not valid_batch:
              continue

            num_samples_val += current_batch_size

    avg_task_losses_val = {}
    weighted_val_loss = 0.0
    if num_samples_val > 0:
        for task_name in TASKS.keys():
            avg_loss = running_task_losses_val[task_name] / num_samples_val
            avg_task_losses_val[task_name] = avg_loss
            if task_name in EVAL_LOSS_WEIGHTS:
                 weighted_val_loss += EVAL_LOSS_WEIGHTS[task_name] * avg_loss
    else:
        for task_name in TASKS.keys(): avg_task_losses_val[task_name] = float('inf')
        weighted_val_loss = float('inf')

    return avg_task_losses_val, weighted_val_loss

In [24]:
print("\n--- Starting Training ---")
best_val_metric = float('inf') if MONITOR_METRIC_MODE == 'min' else float('-inf')
epochs_no_improve = 0
best_model_state = None

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss_epoch = 0.0
    task_losses_epoch = {task: 0.0 for task in TASKS}
    num_samples_processed = 0

    for batch_idx, batch_data in enumerate(train_loader):
        if batch_data is None: continue
        wav, targets = batch_data
        if wav.numel() == 0 or not targets:
          continue

        wav = wav.to(DEVICE)
        targets = {k: v.to(DEVICE) for k, v in targets.items()}
        current_batch_size = wav.size(0)
        optimizer.zero_grad()
        predictions = model(wav)
        combined_loss_batch = torch.tensor(0.0, device=DEVICE)
        current_batch_task_losses = {}
        valid_batch_for_loss = True

        for task_name, task_info in TASKS.items():
            if task_name not in predictions or task_name not in targets:
                 valid_batch_for_loss = False
                 break
            pred = predictions[task_name]
            target = targets[task_name]
            if task_info['type'] == 'regression' and pred.shape != target.shape:
                target = target.view_as(pred)
            loss = criterion_reg(pred, target) if task_info['type'] == 'regression' else criterion_cls(pred, target)
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN/Inf train loss for task {task_name}, batch {batch_idx+1}. Skipping batch update.")
                valid_batch_for_loss = False
                break
            combined_loss_batch += task_info['loss_weight'] * loss
            current_batch_task_losses[task_name] = loss.item()

        if valid_batch_for_loss:
            combined_loss_batch.backward()
            optimizer.step()
            total_loss_epoch += combined_loss_batch.item() * current_batch_size
            for task_name, loss_item in current_batch_task_losses.items():
                task_losses_epoch[task_name] += loss_item * current_batch_size
            num_samples_processed += current_batch_size
            if (batch_idx + 1) % 50 == 0:
                 loss_str = " | ".join([f"{k[:3]}L: {v:.3f}" for k,v in current_batch_task_losses.items()])
                #  print(f"  Epoch {epoch}/{EPOCHS} | Batch {batch_idx+1}/{len(train_loader)} | Batch Loss: {combined_loss_batch.item():.4f} | {loss_str}")

    avg_loss_epoch = total_loss_epoch / num_samples_processed if num_samples_processed > 0 else float('inf')
    avg_task_train_losses_epoch = {k: v / num_samples_processed if num_samples_processed > 0 else float('inf') for k, v in task_losses_epoch.items()}
    train_loss_str = " | ".join([f"Avg {k[:3]} TrL: {v:.4f}" for k, v in avg_task_train_losses_epoch.items()])
    print(f"Epoch {epoch} Summary: Avg Train Loss: {avg_loss_epoch:.4f} | {train_loss_str}")

    # Validation
    if len(val_loader) > 0:
        avg_task_val_losses, current_val_metric = validate_epoch(val_loader, model, DEVICE)
        val_loss_str = " | ".join([f"Avg {k[:3]} VaL: {v:.4f}" for k,v in avg_task_val_losses.items()])
        print(f"Epoch {epoch} Validation: Weighted Val Loss: {current_val_metric:.4f} | {val_loss_str}")

        scheduler.step(current_val_metric)

        # Save Best Model
        if (MONITOR_METRIC_MODE == 'min' and current_val_metric < best_val_metric) or \
           (MONITOR_METRIC_MODE == 'max' and current_val_metric > best_val_metric):
            best_val_metric = current_val_metric
            best_model_state = copy.deepcopy(model.state_dict())
            torch.save(best_model_state, CKPT_PATH)
            print(f"  Best model saved! Epoch {epoch}, Val Metric: {best_val_metric:.4f}")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"  No improvement in val metric for {epochs_no_improve} epoch(s). Best: {best_val_metric:.4f}")

        # Early Stopping
        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping triggered after {epoch} epochs due to no improvement for {EARLY_STOPPING_PATIENCE} epochs.")
            break
    else:
        print("Skipping validation as val_loader is empty.")

print("--- Training Finished ---")


--- Starting Training ---
Epoch 1 Summary: Avg Train Loss: 0.8179 | Avg age TrL: 0.9346 | Avg Gen TrL: 0.6097 | Avg hei TrL: 1.0008
Epoch 1 Validation: Weighted Val Loss: 0.9058 | Avg age VaL: 1.1623 | Avg Gen VaL: 0.6241 | Avg hei VaL: 1.0197
  Best model saved! Epoch 1, Val Metric: 0.9058
Epoch 2 Summary: Avg Train Loss: 0.8117 | Avg age TrL: 0.9294 | Avg Gen TrL: 0.5994 | Avg hei TrL: 1.0008
Epoch 2 Validation: Weighted Val Loss: 0.9117 | Avg age VaL: 1.1704 | Avg Gen VaL: 0.6322 | Avg hei VaL: 1.0053
  No improvement in val metric for 1 epoch(s). Best: 0.9058
Epoch 3 Summary: Avg Train Loss: 0.8121 | Avg age TrL: 0.9310 | Avg Gen TrL: 0.6000 | Avg hei TrL: 0.9987
Epoch 3 Validation: Weighted Val Loss: 0.9107 | Avg age VaL: 1.1665 | Avg Gen VaL: 0.6339 | Avg hei VaL: 1.0050
  No improvement in val metric for 2 epoch(s). Best: 0.9058
Epoch 4 Summary: Avg Train Loss: 0.8125 | Avg age TrL: 0.9287 | Avg Gen TrL: 0.6025 | Avg hei TrL: 0.9999
Epoch 4 Validation: Weighted Val Loss: 0.9065

In [None]:
print("\n--- Evaluating on Test Set with Best Model ---")
if best_model_state is not None:
    print(f"Loading best model from memory (Val Metric: {best_val_metric:.4f}) for final test.")
    model.load_state_dict(best_model_state)
elif os.path.exists(CKPT_PATH): 
    print(f"Loading best model from checkpoint: {CKPT_PATH}")
    model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
else:
    print("No best model state found. Evaluating with the last model state.")

model.eval()
all_targets_test = {task: [] for task in TASKS}
all_preds_test = {task: [] for task in TASKS}
running_task_losses_test = {task: 0.0 for task in TASKS}
num_samples_test = 0

criterion_reg_test = nn.MSELoss(reduction='sum')
criterion_cls_test = nn.CrossEntropyLoss(reduction='sum')

with torch.no_grad():
    for batch_idx, batch_data in enumerate(test_loader):
        if batch_data is None: continue
        wav, targets = batch_data
        if wav.numel() == 0 or not targets: continue

        wav = wav.to(DEVICE)
        targets_device = {k: v.to(DEVICE) for k, v in targets.items()}
        current_batch_size = wav.size(0)
        predictions = model(wav)

        for task_name, task_info in TASKS.items():
            if task_name not in predictions or task_name not in targets_device: continue
            pred_val = predictions[task_name]
            target_val = targets_device[task_name]
            if task_info['type'] == 'regression' and pred_val.shape != target_val.shape:
                target_val = target_val.view_as(pred_val)

            loss_val = criterion_reg_test(pred_val, target_val) if task_info['type'] == 'regression' \
                       else criterion_cls_test(pred_val, target_val)
            running_task_losses_test[task_name] += loss_val.item()

            pred_cpu = pred_val.cpu()
            target_cpu = targets[task_name] 
            if task_info['type'] == 'regression':
                mean, std = NORM_STATS[task_name]['mean'], NORM_STATS[task_name]['std']
                pred_denorm = (pred_cpu * (std if std > 1e-6 else 1.0)) + mean
                target_denorm = (target_cpu * (std if std > 1e-6 else 1.0)) + mean
                all_preds_test[task_name].extend(pred_denorm.tolist())
                all_targets_test[task_name].extend(target_denorm.tolist())
            else:
                pred_labels = torch.argmax(pred_cpu, dim=1)
                all_preds_test[task_name].extend(pred_labels.tolist())
                all_targets_test[task_name].extend(target_cpu.tolist())
        num_samples_test += current_batch_size
        if (batch_idx + 1) % 20 == 0: print(f"  Evaluated test batch {batch_idx+1}/{len(test_loader)}")


--- Evaluating on Test Set with Best Model ---
Loading best model from memory (Val Metric: 0.9047) for final test.
  Evaluated test batch 20/52
  Evaluated test batch 40/52


In [None]:
metrics_test = {}
avg_task_losses_test = {}
print("\n--- Final Test Results (Best Model) ---")
if num_samples_test > 0:
    for task_name, task_info in TASKS.items():
        avg_task_losses_test[task_name] = running_task_losses_test[task_name] / \
            num_samples_test
        # print(f"  Avg Test Loss ({task_name}): {avg_task_losses_test[task_name]:.4f}")
        targets_np = np.array(all_targets_test[task_name])
        preds_np = np.array(all_preds_test[task_name])
        if len(targets_np) > 0 and len(targets_np) == len(preds_np):
            if task_info['type'] == 'regression':
                mse = mean_squared_error(targets_np, preds_np)
                metrics_test[f"{task_name}_mse"] = mse
                print(
                    f"  Test MSE ({task_name}): {mse:.4f} (RMSE: {np.sqrt(mse):.4f})")
            else:
                acc = accuracy_score(targets_np, preds_np)
                metrics_test[f"{task_name}_acc"] = acc
                print(f"  Test Accuracy ({task_name}): {acc:.4f}")
        else:
            metric_key = f"{task_name}_{'mse' if task_info['type'] == 'regression' else 'acc'}"
            metrics_test[metric_key] = float('nan')
            print(
                f"  Could not calculate metric for {task_name} (data issue).")
else:
    print("No samples processed during final test evaluation.")
print("-" * 30)


--- Final Test Results (Best Model) ---
  Test MSE (age): 72.5292 (RMSE: 8.5164)
  Test Accuracy (Gender): 0.6585
  Test MSE (height): 82.7372 (RMSE: 9.0960)
------------------------------



--- Final Test Results (Best Model) --- 2 layers BiLSTM
  * Test MSE (age): 72.5123 (RMSE: 8.5154)
  * Test Accuracy (Gender): 0.6585
  * Test MSE (height): 82.8730 (RMSE: 9.1035)


--- Final Test Results --- 3 layers BiLSTM
 * Test MSE (age): 73.1914 (RMSE: 8.5551)
 * Test Accuracy (Gender): 0.6585
 * Test MSE (height): 80.2048 (RMSE: 8.9557)

--- Final Test Results (Best Model) --- 4 layers BiLSTM
  * Test MSE (age): 72.5292 (RMSE: 8.5164)
  * Test Accuracy (Gender): 0.6585
  * Test MSE (height): 82.7372 (RMSE: 9.0960)