In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# !mkdir 'timit_dataset'
# !unzip '/content/drive/MyDrive/sps/preprocessed_timit_dataset.zip' -d timit_dataset

In [3]:
import os
import random
import time
import gc
import numpy as np
import pandas as pd
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, accuracy_score
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, logging
from tqdm.notebook import tqdm

In [4]:
import warnings
warnings.filterwarnings("ignore")

In [5]:
logging.set_verbosity_error() # Suppress transformer warnings

In [6]:
train_df = pd.read_csv('/content/drive/MyDrive/sps/final_train_data_merged.csv')
test_df = pd.read_csv('/content/drive/MyDrive/sps/final_test_data_merged.csv')
print(f"Loaded train data: {len(train_df)} rows, test data: {len(test_df)} rows")

Loaded train data: 4490 rows, test data: 1640 rows


In [7]:
train_df.head()

Unnamed: 0,index,Use_x,DR,SpeakerID,filename,FilePath,Gender,Ethnicity,age,height
0,8,TRAIN,DR4,MMDM0,SI681.WAV,TRAIN/DR4/MMDM0/SI681.WAV,M,WHT,27.2,190.5
1,18,TRAIN,DR4,MMDM0,SA2.WAV,TRAIN/DR4/MMDM0/SA2.WAV,M,WHT,27.2,190.5
2,20,TRAIN,DR4,MMDM0,SX411.WAV,TRAIN/DR4/MMDM0/SX411.WAV,M,WHT,27.2,190.5
3,23,TRAIN,DR4,MMDM0,SA1.WAV,TRAIN/DR4/MMDM0/SA1.WAV,M,WHT,27.2,190.5
4,26,TRAIN,DR4,MMDM0,SX231.WAV,TRAIN/DR4/MMDM0/SX231.WAV,M,WHT,27.2,190.5


In [8]:
test_df.head()

Unnamed: 0,index,Use_x,DR,SpeakerID,filename,FilePath,Gender,Ethnicity,age,height
0,1,TEST,DR4,MGMM0,SX139.WAV,TEST/DR4/MGMM0/SX139.WAV,M,WHT,23.43,177.8
1,18,TEST,DR4,MGMM0,SA2.WAV,TEST/DR4/MGMM0/SA2.WAV,M,WHT,23.43,177.8
2,21,TEST,DR4,MGMM0,SX229.WAV,TEST/DR4/MGMM0/SX229.WAV,M,WHT,23.43,177.8
3,24,TEST,DR4,MGMM0,SA1.WAV,TEST/DR4/MGMM0/SA1.WAV,M,WHT,23.43,177.8
4,25,TEST,DR4,MGMM0,SX49.WAV,TEST/DR4/MGMM0/SX49.WAV,M,WHT,23.43,177.8


In [9]:
train_df.drop(columns=['index','Use_x','DR'], inplace=True)
test_df.drop(columns=['index','Use_x','DR'], inplace=True)

In [None]:
class Config:
  # --- Paths ---
  AUDIO_ROOT_DIR = '/content/timit_dataset'
  CKPT_PATH     = './sps_model.pth'

  # --- Audio Processing ---
  SAMPLE_RATE   = 16000
  CLIP_SECONDS  = 4.0
  WAV_LEN       = int(SAMPLE_RATE * CLIP_SECONDS)

  # --- Model Architecture ---
  PRETRAINED_W2V2 = 'facebook/wav2vec2-base-960h'
  FREEZE_ENCODER = True
  LSTM_HIDDEN   = 256
  LSTM_LAYERS   = 3
  DROPOUT_RATE  = 0.3

  # --- Training ---
  DEVICE        = 'cuda' if torch.cuda.is_available() else 'cpu'
  EPOCHS        = 50
  BATCH_SIZE    = 32
  LEARNING_RATE = 3e-5
  OPTIMIZER_WEIGHT_DECAY = 0.01
  USE_MIXED_PRECISION = False
  GRADIENT_CLIP_VAL = 1.0

  # --- Validation ---
  VAL_SPLIT_RATIO = 0.25
  VAL_SPLIT_SEED  = 42

  # --- Task Configuration  ---
  TASKS = { 'age': {'type': 'regression', 'loss_weight': 1.0},
            'Gender': {'type': 'classification', 'loss_weight': 1.0},
            'height': {'type': 'regression', 'loss_weight': 1.0}
          # 'Ethnicity': {'type': 'classification', 'loss_weight': 1.0}}
  }

  GENDER_MAP = {}
  # ETHNICITY_MAP = {}
  NORM_STATS = {'age': {'mean': 0.0, 'std': 1.0}, 'height': {'mean': 0.0, 'std': 1.0}}

  # --- DataLoader ---
  NUM_WORKERS   = 4 # Keep 1 or 2 to reduce RAM overhead
  PIN_MEMORY    = True 

In [11]:
cfg = Config()

In [12]:
random.seed(cfg.VAL_SPLIT_SEED)
np.random.seed(cfg.VAL_SPLIT_SEED)
torch.manual_seed(cfg.VAL_SPLIT_SEED)

if torch.cuda.is_available():
  torch.cuda.manual_seed_all(cfg.VAL_SPLIT_SEED)

In [13]:
required_cols_for_tasks = list(cfg.TASKS.keys()) + ['SpeakerID']
missing_cols = [col for col in required_cols_for_tasks if col not in train_df.columns]

if missing_cols:
  print(f"Missing required columns in train_df: {missing_cols}")

In [14]:
for task_name, task_info in cfg.TASKS.items():
  if task_info['type'] == 'regression':
    mean = train_df[task_name].mean()
    std = train_df[task_name].std()
    std = std if (np.isfinite(std) and std > 0) else 1.0 # Handle std=0 or NaN
    cfg.NORM_STATS[task_name]['mean'] = mean
    cfg.NORM_STATS[task_name]['std'] = std
    print(f"  {task_name.capitalize()} stats: Mean={mean:.2f}, Std={std:.2f}")


  elif task_info['type'] == 'classification':
    cats = train_df[task_name].astype(str).str.upper().unique()
    cats = sorted(list(cats)) # Sort for consistent mapping order
    mapping = {cat: i for i, cat in enumerate(cats)}
    num_classes = len(mapping)

    if task_name == 'Gender':
      cfg.GENDER_MAP = mapping
      print(f"Gender mapping: {cfg.GENDER_MAP}")
    # elif task_name == 'Ethnicity':
    #   cfg.ETHNICITY_MAP = mapping
    #   print(f" Ethnicity mapping: {cfg.ETHNICITY_MAP}")
    else:
      print(f"Warning: No specific map variable defined for classification task '{task_name}' in Config.")

    cfg.TASKS[task_name]['num_classes'] = num_classes
    if num_classes == 0:
      print(f"Warning: No unique categories found for '{task_name}' in training data!")

  Age stats: Mean=30.29, Std=7.77
Gender mapping: {'F': 0, 'M': 1}
  Height stats: Mean=175.75, Std=9.52


In [15]:
speaker_ids = train_df['SpeakerID'].unique()
train_spk_ids, val_spk_ids = train_test_split(
    speaker_ids,
    test_size=cfg.VAL_SPLIT_RATIO,
    random_state=cfg.VAL_SPLIT_SEED,
)

val_df = train_df[train_df['SpeakerID'].isin(val_spk_ids)].copy()
train_df = train_df[train_df['SpeakerID'].isin(train_spk_ids)].copy()
test_df = test_df.copy()

print(f"Data split: Train={len(train_df)} ({len(train_spk_ids)} spk), Val={len(val_df)} ({len(val_spk_ids)} spk), Test={len(test_df)} ({test_df['SpeakerID'].nunique()} spk)")


Data split: Train=3360 (336 spk), Val=1130 (113 spk), Test=1640 (164 spk)


In [16]:
print(f"train_df shape: {train_df.shape}")
print(f"val_df shape: {val_df.shape}")
print(f"test_df shape: {test_df.shape}")

train_df shape: (3360, 7)
val_df shape: (1130, 7)
test_df shape: (1640, 7)


In [17]:
# print("NaN entries in train_df:")
# print(train_df.isna().sum())
# print("\nNaN entries in val_df:")
# print(val_df.isna().sum())
# print("\nNaN entries in test_df:")
# print(test_df.isna().sum())

In [18]:
class PadCrop:
  def __init__(self, length, mode='train'):
    self.length = length
    self.mode = mode

  def __call__(self, wav):
    current_len = wav.shape[-1]
    if current_len == self.length:
      return wav
    elif current_len > self.length:
      start = random.randint(0, current_len - self.length) if self.mode == 'train' else (current_len - self.length) // 2
      wav = wav[..., start : start + self.length]
    else:
      pad_width = self.length - current_len
      pad_left = pad_width // 2
      pad_right = pad_width - pad_left
      wav = F.pad(wav, (pad_left, pad_right), mode='constant', value=0)
    return wav

In [19]:
class TimitDataset(Dataset):
  def __init__(self, data_df, cfg: Config, mode='train'):
    self.data_df = data_df.reset_index(drop=True)
    self.cfg = cfg
    self.mode = mode
    self.pad_crop = PadCrop(cfg.WAV_LEN, mode)
    self.target_cols = list(cfg.TASKS.keys())

    for col in self.target_cols + ['FilePath']:
        if col not in self.data_df.columns:
          print(f"Column '{col}' missing in DataFrame for mode '{mode}'.")

  def __len__(self):
    return len(self.data_df)

  def __getitem__(self, idx):
    if idx >= len(self.data_df):
      print("Index out of bounds")

    row = self.data_df.iloc[idx]
    wav_relative_path = row['FilePath']
    full_wav_path = os.path.normpath(os.path.join(self.cfg.AUDIO_ROOT_DIR, wav_relative_path))

    try:
      wav, sr = torchaudio.load(full_wav_path)
    except Exception as e:
      print(f"ERROR loading audio idx {idx}, path: {full_wav_path}. Error: {e}. Returning None.")
      return None

    if sr != self.cfg.SAMPLE_RATE:
      if not hasattr(self, f'resampler_{sr}'):
        setattr(self, f'resampler_{sr}', torchaudio.transforms.Resample(sr, self.cfg.SAMPLE_RATE))
      wav = getattr(self, f'resampler_{sr}')(wav)

    if wav.shape[0] > 1:
      wav = torch.mean(wav, dim=0, keepdim=True)
    wav = self.pad_crop(wav)
    wav = wav.squeeze(0)  # Remove channel dim -> (wav_len,)

    if torch.isnan(wav).any():
      print(f"CRITICAL NAN DETECTED in wav for item {idx}, path: {full_wav_path} BEFORE returning from dataset.")

    targets = {}
    for task_name, task_info in self.cfg.TASKS.items():
      value = row[task_name]
      if task_info['type'] == 'regression':
        mean = self.cfg.NORM_STATS[task_name]['mean']
        std = self.cfg.NORM_STATS[task_name]['std']
        norm_value = (value - mean) / std
        targets[task_name] = torch.tensor(norm_value, dtype=torch.float32)

      elif task_info['type'] == 'classification':
        value_upper = str(value).upper()
        # mapping = self.cfg.GENDER_MAP if task_name == 'Gender' else self.cfg.ETHNICITY_MAP
        if task_name == 'Gender':
          mapping = self.cfg.GENDER_MAP

        idx_value = mapping.get(value_upper, 0)  # Class 0 if unseen

        if value_upper not in mapping:
          print(f"Warning: Unmapped category '{value}' (upper: '{value_upper}') for task '{task_name}' at idx {idx}. Defaulting to 0.")
        targets[task_name] = torch.tensor(idx_value, dtype=torch.long)

    return wav, targets

In [20]:
def collate_fn(batch):  # Collate to handle None returns from dataset
  batch = [item for item in batch if item is not None]
  if not batch:
    return {}, {}

  wavs = [item[0] for item in batch]
  target_dicts = [item[1] for item in batch]
  padded_wavs = torch.nn.utils.rnn.pad_sequence(wavs, batch_first=True, padding_value=0.0)
  collated_targets = {key: torch.stack([d[key] for d in target_dicts]) for key in target_dicts[0].keys()}

  return padded_wavs, collated_targets

In [21]:
train_dataset = TimitDataset(train_df, cfg, mode='train')
val_dataset = TimitDataset(val_df, cfg, mode='eval')
test_dataset = TimitDataset(test_df, cfg, mode='eval')

train_loader = DataLoader(
    train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True,
    num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, drop_last=True, collate_fn=collate_fn
)
val_loader = DataLoader(
    val_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False,
    num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False,
    num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, collate_fn=collate_fn
)

In [22]:
class spsBiLSTM(nn.Module):
  def __init__(self, cfg: Config):
    super().__init__()
    self.cfg = cfg

    # Wav2Vec2.0
    self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cfg.PRETRAINED_W2V2)
    self.encoder = Wav2Vec2Model.from_pretrained(cfg.PRETRAINED_W2V2)
    if cfg.FREEZE_ENCODER:
      for param in self.encoder.parameters(): param.requires_grad = False
    encoder_dim = self.encoder.config.hidden_size

    # LSTM layers
    self.lstm = nn.LSTM(
        input_size=encoder_dim, hidden_size=cfg.LSTM_HIDDEN, num_layers=cfg.LSTM_LAYERS,
        batch_first=True, bidirectional=True, dropout=cfg.DROPOUT_RATE if cfg.LSTM_LAYERS > 1 else 0
    )
    lstm_output_dim = cfg.LSTM_HIDDEN * 2

    # Prediction heads
    self.heads = nn.ModuleDict()
    head_input_dim = lstm_output_dim
    head_hidden_dim = head_input_dim // 2

    for task_name, task_info in cfg.TASKS.items():
      if task_info['type'] == 'regression':
        output_dim = 1

      elif task_info['type'] == 'classification':
          num_classes = task_info.get('num_classes')
          if num_classes is None or num_classes == 0:
            print(f"Num classes not set or is 0 for {task_name}")
          output_dim = num_classes
      else:
        print(f"Unknown task type: {task_info['type']}")

      self.heads[task_name] = nn.Sequential(
          nn.Linear(head_input_dim, head_hidden_dim),
          nn.ReLU(),
          nn.Dropout(cfg.DROPOUT_RATE),
          nn.Linear(head_hidden_dim, output_dim)
      )


  def forward(self, waveform):
    dev = next(self.parameters()).device

    if waveform.ndim == 2:      # (Batch, Seq_len)
      waveform_list = [wav.cpu().numpy() for wav in waveform]
    elif waveform.ndim == 1:    # (Seq_len) - for a batch size of 1, if collate_fn doesn't batch
      waveform_list = [waveform.cpu().numpy()]
    else:
      print(f"Unexpected waveform ndim: {waveform.ndim}")

    inputs = self.feature_extractor(
        waveform_list,
        sampling_rate=self.cfg.SAMPLE_RATE,
        return_tensors="pt",
        padding="longest",
        return_attention_mask=True
    )
    inputs = {k: v.to(dev) for k, v in inputs.items()}
    # if torch.isnan(inputs['input_values']).any():
    #   print("DEBUG NAN: NaNs found in 'input_values' AFTER Wav2Vec2FeatureExtractor!")

    attention_mask = inputs.get('attention_mask')
    if attention_mask is None:
      print("WARN: 'attention_mask' is None after feature_extractor. Creating default (all ones).")
      attention_mask = torch.ones_like(inputs['input_values'], device=dev, dtype=torch.long)

    hidden_states = self.encoder(inputs['input_values'], attention_mask=attention_mask).last_hidden_state
    # print(f"DEBUG STATS: hidden_states - Min: {hidden_states.min().item()}, Max: {hidden_states.max().item()}, Mean: {hidden_states.mean().item()}, Std: {hidden_states.std().item()}")
    # if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
    #   print("CRITICAL DEBUG: NaNs or Infs found in 'hidden_states' (input to LSTM)!")

    # if torch.isnan(inputs['input_values']).any():
    #   print("DEBUG NAN: NaNs found in 'input_values' AFTER Wav2Vec2FeatureExtractor!")

    lstm_output, _ = self.lstm(hidden_states)
    # if torch.isnan(lstm_output).any():
    #   print("DEBUG NAN: NaNs found in 'lstm_output' AFTER LSTM!")

    # mask = inputs['attention_mask'].unsqueeze(-1).repeat(1, 1, lstm_output.shape[-1])
    mask = (lstm_output != 0).type(torch.float32)

    pooled = torch.sum(lstm_output * mask, dim=1) / mask.sum(dim=1).clamp(min=1e-6)
    # if torch.isnan(pooled).any():
    #   print("DEBUG NAN: NaNs found in 'pooled' output!")

    preds = {name: head(pooled).squeeze(-1) if self.cfg.TASKS[name]['type'] == 'regression' else head(pooled) for name, head in self.heads.items()}
    return preds

In [23]:
model = spsBiLSTM(cfg).to(cfg.DEVICE)
# print(model)

preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

In [31]:
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=cfg.LEARNING_RATE, weight_decay=cfg.OPTIMIZER_WEIGHT_DECAY
)
print(f"Optimizer: AdamW (LR={cfg.LEARNING_RATE}, WD={cfg.OPTIMIZER_WEIGHT_DECAY})")

criterion_reg = nn.MSELoss()
criterion_cls = nn.CrossEntropyLoss()

scaler = torch.cuda.amp.GradScaler() if cfg.USE_MIXED_PRECISION and cfg.DEVICE == 'cuda' else None
if scaler:
  print("Using Mixed Precision training.")

Optimizer: AdamW (LR=3e-05, WD=0.01)


In [30]:
start_epoch = 1
best_val_metric = float('inf')

In [None]:
checkpoint = torch.load(cfg.CKPT_PATH, map_location=cfg.DEVICE)
if 'model_state_dict' in checkpoint:
  model.load_state_dict(checkpoint['model_state_dict'], strict=False) # Use strict=False if adding/removing heads
  print("  Model state loaded.")

  if 'optimizer_state_dict' in checkpoint:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("  Optimizer state loaded.")
  if 'epoch' in checkpoint:
    start_epoch = checkpoint['epoch'] + 1
    print(f"  Resuming from epoch {start_epoch}")
  if 'best_val_metric' in checkpoint:
    best_val_metric = checkpoint['best_val_metric']
    print(f"  Best previous metric: {best_val_metric:.4f}")
  if scaler and 'scaler_state_dict' in checkpoint:
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    print("  Scaler state loaded.")

else:
  model.load_state_dict(checkpoint, strict=False)
  print("Loaded model state_dict directly.")

In [None]:
del checkpoint
torch.cuda.empty_cache()

In [None]:
def train_epoch(model, loader, optimizer, scaler, device, cfg):
    model.train()

    total_loss = 0.0
    task_losses = {task: 0.0 for task in cfg.TASKS}
    num_samples = 0

    criterion_reg = nn.MSELoss()
    criterion_cls = nn.CrossEntropyLoss()
    pbar = tqdm(
        loader, desc=f"Epoch {epoch} Training", leave=False, dynamic_ncols=True)

    for batch_data in pbar:
        if batch_data is None:
            continue
        wav, targets = batch_data
        wav = wav.to(device)

        # if torch.isnan(wav).any():
        #     print(
        #         f" NAN DETECTED in batched 'wav' tensor at start of train_epoch.")

        targets = {k: v.to(device) for k, v in targets.items()}
        optimizer.zero_grad()

        with torch.amp.autocast(device_type=cfg.DEVICE, enabled=(scaler is not None)):
            predictions = model(wav)

            combined_loss = 0.0
            current_batch_losses = {}

            for task_name, task_info in cfg.TASKS.items():
                pred = predictions[task_name]
                target = targets[task_name]
                loss = criterion_reg(
                    pred, target) if task_info['type'] == 'regression' else criterion_cls(pred, target)
                weight = task_info['loss_weight']
                combined_loss += weight * loss
                current_batch_losses[task_name] = loss.item()
                task_losses[task_name] += loss.item() * wav.size(0)

        if scaler:
            scaler.scale(combined_loss).backward()
            if cfg.GRADIENT_CLIP_VAL > 0:
                scaler.unscale_(optimizer)

            torch.nn.utils.clip_grad_norm_(
                model.parameters(), cfg.GRADIENT_CLIP_VAL)
            scaler.step(optimizer)
            scaler.update()

        else:
            combined_loss.backward()
            if cfg.GRADIENT_CLIP_VAL > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.GRADIENT_CLIP_VAL)

            optimizer.step()

        valid_samples_in_batch = wav.size(0)
        total_loss += combined_loss.item() * valid_samples_in_batch
        num_samples += valid_samples_in_batch

        pbar_postfix = {f"{k[:3]}_L": f"{v:.2f}" for k,
                        v in current_batch_losses.items()}
        pbar_postfix["Total_L"] = f"{combined_loss.item():.3f}"
        pbar.set_postfix(pbar_postfix)

    avg_loss = total_loss / num_samples if num_samples > 0 else 0.0
    avg_task_losses = {k: v / num_samples if num_samples >
                       0 else 0.0 for k, v in task_losses.items()}
    return avg_loss, avg_task_losses

In [27]:
def evaluate(model, loader, device, cfg):
  model.eval()

  all_targets = {task: [] for task in cfg.TASKS}
  all_preds = {task: [] for task in cfg.TASKS}

  pbar = tqdm(loader, desc="Evaluating", leave=False, dynamic_ncols=True)
  with torch.no_grad():
    for batch_data in pbar:
      if batch_data is None:
        print("batch_data is None")
        continue

      wav, targets = batch_data
      wav = wav.to(device)
      predictions = model(wav)
      for task_name, task_info in cfg.TASKS.items():
        pred = predictions[task_name].cpu()
        target = targets[task_name].cpu()

        if task_info['type'] == 'regression':
          mean=cfg.NORM_STATS[task_name]['mean']
          std=cfg.NORM_STATS[task_name]['std']
          pred_denorm = (pred * std) + mean
          target_denorm = (target * std) + mean
          all_preds[task_name].extend(pred_denorm.tolist())
          all_targets[task_name].extend(target_denorm.tolist())

        else:     # classification
          pred_labels = torch.argmax(pred, dim=1)
          all_preds[task_name].extend(pred_labels.tolist())
          all_targets[task_name].extend(target.tolist())


  metrics = {}
  for task_name, task_info in cfg.TASKS.items():
      targets_np = np.array(all_targets[task_name])
      preds_np = np.array(all_preds[task_name])
      # print(all_targets)
      # print(all_preds)
      # print(targets_np)
      # print(preds_np)


      if len(targets_np) == 0:
        # print(f"Warning: No targets/preds for metric: {task_name}")
        continue
      if task_info['type'] == 'regression':
        metrics[f"{task_name}_mse"] = mean_squared_error(targets_np, preds_np)
      else: # classification
        metrics[f"{task_name}_acc"] = accuracy_score(targets_np, preds_np)

  return metrics

if the train losses are coming out to be NaN, then restart the session and run all cells

In [28]:
print(f"Running from epoch {start_epoch} to {cfg.EPOCHS}.")

for epoch in range(start_epoch, cfg.EPOCHS + 1):
  epoch_start_time = time.time()
  avg_train_loss, avg_task_losses = train_epoch(model, train_loader, optimizer, scaler, cfg.DEVICE, cfg)
  val_metrics = evaluate(model, val_loader, cfg.DEVICE, cfg)
  epoch_duration = time.time() - epoch_start_time

  task_loss_str = " | ".join([f"{k[:3]}L={v:.3f}" for k, v in avg_task_losses.items()])
  val_metrics_str = " | ".join([f"{k}={v:.3f}" for k, v in val_metrics.items()])
  print(f"Epoch {epoch}/{cfg.EPOCHS} [{epoch_duration:.1f}s] => Loss: {avg_train_loss:.4f} | {task_loss_str}")
  print(f"Validation => {val_metrics_str}")

  current_val_metric = val_metrics.get('age_mse', float('inf'))

  if current_val_metric < best_val_metric:
    best_val_metric = current_val_metric
    print(f"  Saving best model (Age MSE={best_val_metric:.3f}) to {cfg.CKPT_PATH}")
    save_content = {
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'best_val_metric': best_val_metric,
          'config': {k: v for k, v in vars(cfg).items() if not k.startswith('__') and type(v) in [str, int, float, bool, dict, list]},
          'norm_stats': cfg.NORM_STATS,
          'gender_map': cfg.GENDER_MAP
          # 'ethnicity_map': cfg.ETHNICITY_MAP
      }
    if scaler:
      save_content['scaler_state_dict'] = scaler.state_dict()
    torch.save(save_content, cfg.CKPT_PATH)


Running from epoch 1 to 50.


Epoch 1 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 1/50 [149.9s] => Loss: 2.5850 | ageL=0.922 | GenL=0.680 | heiL=0.983
Validation => age_mse=74.861 | Gender_acc=0.655 | height_mse=95.609
  Saving best model (Age MSE=74.861) to ./sps_model.pth


Epoch 2 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 2/50 [147.9s] => Loss: 2.5461 | ageL=0.921 | GenL=0.642 | heiL=0.982
Validation => age_mse=74.577 | Gender_acc=0.655 | height_mse=96.237
  Saving best model (Age MSE=74.577) to ./sps_model.pth


Epoch 3 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 3/50 [147.2s] => Loss: 2.5086 | ageL=0.922 | GenL=0.605 | heiL=0.982
Validation => age_mse=74.842 | Gender_acc=0.655 | height_mse=96.622


Epoch 4 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 4/50 [147.4s] => Loss: 2.4962 | ageL=0.922 | GenL=0.592 | heiL=0.982
Validation => age_mse=74.708 | Gender_acc=0.655 | height_mse=96.456


Epoch 5 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 5/50 [147.4s] => Loss: 2.4936 | ageL=0.921 | GenL=0.591 | heiL=0.982
Validation => age_mse=74.643 | Gender_acc=0.655 | height_mse=96.753


Epoch 6 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 6/50 [147.6s] => Loss: 2.4939 | ageL=0.922 | GenL=0.589 | heiL=0.983
Validation => age_mse=74.846 | Gender_acc=0.655 | height_mse=96.606


Epoch 7 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 7/50 [147.4s] => Loss: 2.4924 | ageL=0.922 | GenL=0.589 | heiL=0.982
Validation => age_mse=74.786 | Gender_acc=0.655 | height_mse=96.617


Epoch 8 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 8/50 [148.6s] => Loss: 2.4895 | ageL=0.921 | GenL=0.587 | heiL=0.981
Validation => age_mse=74.734 | Gender_acc=0.655 | height_mse=96.206


Epoch 9 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 9/50 [148.5s] => Loss: 2.4937 | ageL=0.922 | GenL=0.590 | heiL=0.982
Validation => age_mse=74.790 | Gender_acc=0.655 | height_mse=96.285


Epoch 10 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 10/50 [151.0s] => Loss: 2.4918 | ageL=0.922 | GenL=0.589 | heiL=0.981
Validation => age_mse=74.684 | Gender_acc=0.655 | height_mse=96.669


Epoch 11 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 11/50 [150.4s] => Loss: 2.4906 | ageL=0.921 | GenL=0.588 | heiL=0.981
Validation => age_mse=74.752 | Gender_acc=0.655 | height_mse=96.502


Epoch 12 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 12/50 [150.1s] => Loss: 2.4910 | ageL=0.922 | GenL=0.588 | heiL=0.981
Validation => age_mse=74.784 | Gender_acc=0.655 | height_mse=96.459


Epoch 13 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 13/50 [148.9s] => Loss: 2.4919 | ageL=0.922 | GenL=0.589 | heiL=0.981
Validation => age_mse=74.813 | Gender_acc=0.655 | height_mse=96.662


Epoch 14 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 14/50 [148.9s] => Loss: 2.4917 | ageL=0.921 | GenL=0.589 | heiL=0.981
Validation => age_mse=74.767 | Gender_acc=0.655 | height_mse=96.406


Epoch 15 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 15/50 [148.8s] => Loss: 2.4909 | ageL=0.921 | GenL=0.588 | heiL=0.982
Validation => age_mse=74.749 | Gender_acc=0.655 | height_mse=96.671


Epoch 16 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 16/50 [149.2s] => Loss: 2.4926 | ageL=0.922 | GenL=0.589 | heiL=0.981
Validation => age_mse=74.713 | Gender_acc=0.655 | height_mse=96.478


Epoch 17 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 17/50 [148.3s] => Loss: 2.4918 | ageL=0.922 | GenL=0.589 | heiL=0.982
Validation => age_mse=74.765 | Gender_acc=0.655 | height_mse=96.552


Epoch 18 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 18/50 [148.4s] => Loss: 2.4911 | ageL=0.921 | GenL=0.588 | heiL=0.981
Validation => age_mse=74.781 | Gender_acc=0.655 | height_mse=96.484


Epoch 19 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 19/50 [147.8s] => Loss: 2.4913 | ageL=0.922 | GenL=0.588 | heiL=0.982
Validation => age_mse=74.759 | Gender_acc=0.655 | height_mse=96.546


Epoch 20 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch 20/50 [148.4s] => Loss: 2.4925 | ageL=0.922 | GenL=0.589 | heiL=0.982
Validation => age_mse=74.798 | Gender_acc=0.655 | height_mse=96.498


Epoch 21 Training:   0%|          | 0/105 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [29]:
if os.path.isfile(cfg.CKPT_PATH):
  print(f"Loading best model from: {cfg.CKPT_PATH}")
  try:
    ckpt = torch.load(cfg.CKPT_PATH, map_location=cfg.DEVICE)
    model.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)

    del ckpt
    torch.cuda.empty_cache()
    gc.collect()

  except Exception as e:
    print(f"  Error loading best model: {e}. Evaluating current state.")

else:
  print(f"Warning: Best checkpoint '{cfg.CKPT_PATH}' not found. Evaluating final model state.")


if test_loader is None or len(test_loader.dataset) == 0:
  print("Skipping test evaluation: Test data empty.")

else:
  print("Running final evaluation...")
  test_metrics = evaluate(model, test_loader, cfg.DEVICE, cfg)

  print(f"-------------------------")
  print(f"Test Set Results:")
  for metric, value in test_metrics.items():
    print(f"  {metric.replace('_',' ').capitalize():<20}: {value:.3f}")
  print(f"-------------------------")

Loading best model from: ./sps_model.pth
  Error loading best model: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` or the `torch.serialization.safe_globals([scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to 

Evaluating:   0%|          | 0/52 [00:00<?, ?it/s]

-------------------------
Test Set Results:
  Age mse             : 72.888
  Gender acc          : 0.659
  Height mse          : 83.043
-------------------------


## Conclusion

### After training spsBiLSTM model on T4 GPU for 20 epochs, the final evaluation metrics are:

* RMSE Age = 8.53, implies on average model's guess on age is off by 8.5 years

* RMSE Age = 9.11, implies on average model's guess on age is off by 9.11 cm

* Gender accuracy = 65.9%