In [3]:
import os
import sys
import time
import torch
from datetime import datetime
from torch.utils.data import DataLoader

# Import utils
from utils.Logger import Logger
from utils.Seed import set_seed
from utils.Splitter import stratified_split
from classes.FeatureDataset.CombinedFeatureDataset import CombinedFeatureDataset

# Single source of truth for the seed: it is passed both to set_seed and to the
# dataset split, so changing it here changes both consistently.
seed = 42
set_seed(seed)

# Load full dataset
print("==================== LOADING DATASET ====================\n")
full_dataset = CombinedFeatureDataset("preprocessed_data/combined")
print("\n==================== DATASET LOADED ====================\n")

# Stratified dataset split; fractions are given in (train, validation, test) order
print("\n==================== SPLITTING DATASET ====================\n")
train_dataset, val_dataset, test_dataset = stratified_split(full_dataset, splits=(0.7, 0.15, 0.15), seed=seed)

# Hyperparameter candidates. batch_sizes and learning_rates are lists so that
# additional values can be added; presumably a later training cell iterates
# over them — confirm against that cell.
batch_sizes = [32]
learning_rates = [0.0001]
epochs = 5

# Print dataset sizes
print(f"Total samples: {len(full_dataset)}")
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Sanity check (after the report so the sizes are visible if it fails):
# the three partition sizes must add up to the size of the full dataset.
assert len(train_dataset) + len(val_dataset) + len(test_dataset) == len(full_dataset), \
    "split sizes do not add up to the full dataset size"


Skipped 0 files due to NaNs/Infs.




Total samples: 71118
Train samples: 49782
Validation samples: 10667
Test samples: 10669


In [5]:
# Show class frequency for each dataset (label mapping: Bonafide=1, Spoof=0)
from collections import Counter

# Display names for the numeric class labels used by the datasets.
label_map = {
    1: "Bonafide",  # label 1 is reported as "Bonafide"
    0: "Spoof",     # label 0 is reported as "Spoof"
}

def get_class_counts(dataset):
    """Count how many samples of each class label a dataset contains.

    The label is read from position 1 of every item (``dataset[i][1]``).
    Labels that expose an ``item()`` method (for example tensor or NumPy
    scalars) are unwrapped to plain Python scalars first; otherwise objects
    that hash by identity would each be counted as a separate class and
    would never match plain integer keys.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing whose
            items hold the label at index 1.

    Returns:
        collections.Counter mapping each label to its number of occurrences,
        with keys ordered by first appearance. Empty for an empty dataset.
    """
    counts = Counter()
    # NOTE(review): indexing the dataset may load the full sample just to read
    # its label; if the dataset exposes its labels directly, prefer that.
    for index in range(len(dataset)):
        label = dataset[index][1]
        unwrap = getattr(label, "item", None)
        if callable(unwrap):
            # Scalar wrapper -> plain Python value, so equal labels share one key.
            label = unwrap()
        counts[label] += 1
    return counts

print("\nClass frequencies:")
# Report the per-class sample count of each split, in train/validation/test order.
for name, ds in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    counts = get_class_counts(ds)
    # Swap numeric labels for their display names; labels missing from
    # label_map are shown unchanged.
    named_counts = dict(
        (label_map.get(label, label), total) for label, total in counts.items()
    )
    print(f"{name} set class counts: {named_counts}")


Class frequencies:
Train set class counts: {'Bonafide': 12978, 'Spoof': 36804}
Validation set class counts: {'Bonafide': 2781, 'Spoof': 7886}
Test set class counts: {'Bonafide': 2781, 'Spoof': 7888}
