# Model classifying speakers as in-group or out-group based on short voice snippets

### TBD
- [x] general architecture
- [x] data preprocessing pipeline (including creating the spectrograms)
- [X] loader
- [x] training
- [x] validation
- [x] logging to TensorBoard
- [x] dataset script
- [x] big-enough dataset
- [ ] data augmentation pipeline to diversify the dataset
- [ ] tuning the CNN layer and RNN layer so that it's lightweight enough to process but deep enough to generalize (try to see if it overfits at some point, if not add layers)
- [ ] inference pipeline (described in section inference pipeline)


## DATASET
### Creating the .h5 dataset from the audio files stored in Drive

  labels should contain the speaker id with -1 for speakers not in the in-group
      
      /train/logmel      -> float (N, T, F) or (N, F, T)
      /train/label       -> int64  (N,)           [optional]
      /train/length      -> int64  (N,)           [optional, original T]
    and the equivalent for /val and /test, plus
      /meta/sample_rate  -> attribute (scalar)
      /meta/feature_type -> "log-mel spectrogram"
      /meta/description.yaml with values
      - number_of_speakers
      - number_of_mels
      - frequency (a boolean True, False)
      - dictionary of speaker ids to speaker names

In [None]:
import os
import random
import warnings
import numpy as np
import librosa
import h5py
import yaml
from datetime import datetime
from pydub import AudioSegment
from scipy.signal import butter, lfilter
from collections import defaultdict
import json

warnings.filterwarnings("ignore", category=FutureWarning)

SUPPORTED_EXTS = (".wav", ".mp3", ".m4a", ".wma")
from google.colab import drive
drive.mount('/content/drive')


  m = re.match('([su]([0-9]{1,2})p?) \(([0-9]{1,2}) bit\)$', token)
  m2 = re.match('([su]([0-9]{1,2})p?)( \(default\))?$', token)
  elif re.match('(flt)p?( \(default\))?$', token):
  elif re.match('(dbl)p?( \(default\))?$', token):


Mounted at /content/drive


In [None]:
# ==========================================================
# INFERENCE-COMPATIBLE PREPROCESSING CLASS
# ==========================================================
class AudioPreprocessor:
    """Configurable audio-to-log-mel preprocessing steps.

    Every setting is stored on the instance so the exact same configuration
    can be exported with ``get_config`` and reused elsewhere.
    """

    def __init__(self, sr=16000, n_mels=64, n_fft=2048, hop_length=512,
                 chunk_duration=1.0, remove_silence=True, normalize=True,
                 lowcut=None, highcut=None, filter_order=4):
        # Sampling rate and spectrogram geometry.
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Length of one chunk, in seconds.
        self.chunk_duration = chunk_duration
        # Switches read by the caller that drives the pipeline.
        self.remove_silence = remove_silence
        self.normalize = normalize
        # Optional Butterworth filter settings (cut-offs in Hz, None = off).
        self.lowcut = lowcut
        self.highcut = highcut
        self.filter_order = filter_order

    def load_audio(self, path):
        """Read ``path`` with pydub and return mono float32 samples at ``self.sr``."""
        segment = AudioSegment.from_file(path)
        if segment.channels > 1:
            segment = segment.set_channels(1)

        pcm = np.array(segment.get_array_of_samples()).astype(np.float32)
        # Divide by 2**(bits - 1) so integer PCM is scaled by its full-scale value.
        full_scale = 1 << (8 * segment.sample_width - 1)
        pcm /= full_scale

        if segment.frame_rate != self.sr:
            pcm = librosa.resample(pcm, orig_sr=segment.frame_rate, target_sr=self.sr)

        return pcm

    def trim_silence(self, samples, top_db=30):
        """Return ``samples`` after ``librosa.effects.trim`` with ``top_db``."""
        trimmed, _interval = librosa.effects.trim(samples, top_db=top_db)
        return trimmed

    def normalize_volume(self, samples):
        """Scale ``samples`` by their peak absolute value."""
        # The small constant keeps the division defined for all-zero input.
        peak = np.max(np.abs(samples)) + 1e-9
        return samples / peak

    def apply_filter(self, samples):
        """Run the configured Butterworth filter; a no-op when no cut-off is set."""
        low, high = self.lowcut, self.highcut
        if not low and not high:
            return samples

        nyquist = 0.5 * self.sr
        # Pick the filter type from whichever cut-offs are configured.
        if low and high:
            cutoff, kind = [low / nyquist, high / nyquist], "band"
        elif low:
            cutoff, kind = low / nyquist, "high"
        else:
            cutoff, kind = high / nyquist, "low"

        b, a = butter(self.filter_order, cutoff, btype=kind)
        return lfilter(b, a, samples)

    def chunk_audio(self, samples):
        """Cut ``samples`` into ``chunk_duration``-second pieces.

        The final piece is zero-padded on the right to the full chunk length.
        Returns a list of 1-D arrays (empty when ``samples`` is empty).
        """
        chunk_len = int(self.chunk_duration * self.sr)
        pieces = []

        for start in range(0, len(samples), chunk_len):
            piece = samples[start:start + chunk_len]
            if len(piece) < chunk_len:
                padded = np.zeros(chunk_len, dtype=samples.dtype)
                padded[:len(piece)] = piece
                piece = padded
            pieces.append(piece)

        return pieces

    def to_logmel(self, samples):
        """Return the float32 log-mel spectrogram (dB relative to its own max)."""
        mel_power = librosa.feature.melspectrogram(
            y=samples, sr=self.sr, n_mels=self.n_mels,
            n_fft=self.n_fft, hop_length=self.hop_length
        )
        in_db = librosa.power_to_db(mel_power, ref=np.max)
        return in_db.astype(np.float32)

    def get_config(self):
        """Return every constructor setting as a plain dict."""
        setting_names = (
            "sr", "n_mels", "n_fft", "hop_length", "chunk_duration",
            "remove_silence", "normalize", "lowcut", "highcut", "filter_order",
        )
        return {name: getattr(self, name) for name in setting_names}


# ==========================================================
# AUGMENTATION FUNCTIONS
# ==========================================================
def add_ambient_noise(samples, noise_factor=0.005):
    """Return ``samples`` plus Gaussian noise scaled by ``noise_factor``.

    One noise value per sample is drawn from NumPy's global random state.
    """
    n_samples = len(samples)
    scaled_noise = noise_factor * np.random.randn(n_samples)
    return samples + scaled_noise


def speed_perturbation(samples, sr, speed_factor=1.0):
    """Time-stretch ``samples`` by ``speed_factor`` via librosa.

    ``sr`` is part of the signature but is not used by this function.
    """
    stretched = librosa.effects.time_stretch(samples, rate=speed_factor)
    return stretched


def spectral_augmentation(logmel, freq_mask_param=10, time_mask_param=20, n_masks=2):
    """Apply SpecAugment-style frequency and time masking to a copy of ``logmel``.

    Args:
        logmel: 2-D array indexed as ``[mel_bin, frame]``. It is not modified.
        freq_mask_param: exclusive upper bound on a frequency mask's height.
        time_mask_param: exclusive upper bound on a time mask's width.
        n_masks: number of masks drawn per axis.

    Returns:
        A new array of the same shape and dtype with the masked bands replaced
        by the mean of the working copy at the moment each mask is applied.

    A mask parameter of 0 (or a zero-sized axis) disables masking on that axis
    instead of raising from ``np.random.randint(0, 0)``.
    """
    augmented = logmel.copy()
    n_mels, n_frames = augmented.shape

    # Nothing to mask, and the mean of an empty array is undefined.
    if augmented.size == 0:
        return augmented

    # Frequency masks: horizontal bands across all frames.
    max_height = min(freq_mask_param, n_mels)
    if max_height > 0:
        for _ in range(n_masks):
            height = np.random.randint(0, max_height)
            first_bin = np.random.randint(0, max(1, n_mels - height))
            augmented[first_bin:first_bin + height, :] = augmented.mean()

    # Time masks: vertical bands across all mel bins.
    max_width = min(time_mask_param, n_frames)
    if max_width > 0:
        for _ in range(n_masks):
            width = np.random.randint(0, max_width)
            first_frame = np.random.randint(0, max(1, n_frames - width))
            augmented[:, first_frame:first_frame + width] = augmented.mean()

    return augmented


def vocal_tract_length_perturbation(samples, sr, alpha=1.0):
    """Warp the frequency axis of ``samples`` by ``alpha`` (VTLP).

    Each STFT bin ``i`` is filled from the bin found by ``np.searchsorted`` for
    the frequency ``freqs[i] * alpha``. Bins whose warped frequency is not
    below the top bin frequency, or whose source index is the last bin or
    beyond, are left at zero. The result is inverted back to a waveform of the
    original length.
    """
    spectrum = librosa.stft(samples)
    n_bins = spectrum.shape[0]
    bin_freqs = librosa.fft_frequencies(sr=sr, n_fft=(n_bins - 1) * 2)
    source_freqs = bin_freqs * alpha

    # Look up every source bin at once, then keep only the usable ones.
    source_idx = np.searchsorted(bin_freqs, source_freqs)
    usable = (source_freqs < bin_freqs[-1]) & (source_idx < len(bin_freqs) - 1)

    warped = np.zeros_like(spectrum)
    warped[usable] = spectrum[source_idx[usable]]

    return librosa.istft(warped, length=len(samples))

In [None]:
# ==========================================================
# METADATA PARSING
# ==========================================================
def parse_labels_yaml(yaml_path):
    """Read a ``labels.yaml`` file into ``{filename: metadata}``.

    Each entry may be either a sequence ``[gender, in_group, speaker_name, ...]``
    or a mapping with those keys; any other entry is skipped. For mapping
    entries, missing keys default to ``"Unknown"``, ``False`` and the filename
    without its extension.

    Returns an empty dict when the file does not exist or is empty.
    """
    if not os.path.exists(yaml_path):
        return {}

    with open(yaml_path, 'r', encoding='utf-8') as handle:
        raw = yaml.safe_load(handle)

    if not raw:
        return {}

    entries = {}
    for filename, value in raw.items():
        if isinstance(value, (list, tuple)) and len(value) >= 3:
            gender, in_group, speaker_name = value[:3]
        elif isinstance(value, dict):
            gender = value.get("gender", "Unknown")
            in_group = value.get("in_group", False)
            speaker_name = value.get("speaker_name", os.path.splitext(filename)[0])
        else:
            # Unrecognised entry shape: leave it out of the result.
            continue

        entries[filename] = {
            "gender": gender,
            "in_group": in_group,
            "speaker_name": speaker_name,
        }

    return entries


def get_speaker_label_mapping(data_root="/content/drive/MyDrive/ML_PW/Recordings"):
    """Assign a consecutive integer label to every speaker under ``data_root``.

    Speaker names found in each sub-folder's ``labels.yaml`` are numbered
    first, in directory-listing order; sub-folder names that are not already
    known are numbered afterwards.

    Returns:
        ``(speaker_to_label, label_to_speaker)``, two dicts that are inverses
        of each other.
    """
    speaker_to_label = {}

    def _subfolders():
        # Yield (name, path) for each directory directly under data_root.
        for name in os.listdir(data_root):
            path = os.path.join(data_root, name)
            if os.path.isdir(path):
                yield name, path

    def _register(speaker):
        # The next free label always equals the number of speakers seen so far.
        if speaker not in speaker_to_label:
            speaker_to_label[speaker] = len(speaker_to_label)

    # Pass 1: speaker names declared in the YAML metadata.
    for _name, path in _subfolders():
        metadata = parse_labels_yaml(os.path.join(path, "labels.yaml"))
        for info in metadata.values():
            _register(info["speaker_name"])

    # Pass 2: folder names themselves count as speakers.
    for name, _path in _subfolders():
        _register(name)

    label_to_speaker = {label: speaker for speaker, label in speaker_to_label.items()}

    return speaker_to_label, label_to_speaker


def process_audio_file(path, preprocessor, speaker_label, speaker_info,
                      with_noise=False, with_speed=False, with_vtlp=False,
                      with_spectral=False):
    """Turn one audio file into log-mel chunks, optionally with augmented copies.

    Args:
        path: audio file to load through ``preprocessor.load_audio``.
        preprocessor: object providing ``load_audio``, ``trim_silence``,
            ``normalize_volume``, ``apply_filter``, ``chunk_audio``,
            ``to_logmel`` and the ``remove_silence`` / ``normalize`` /
            ``lowcut`` / ``highcut`` / ``sr`` attributes.
        speaker_label: label copied into every result tuple.
        speaker_info: metadata copied into every result tuple.
        with_noise: add one noisy copy (noise factor drawn from U(0.002, 0.008)).
        with_speed: add time-stretched copies at rates 0.9 and 1.1.
        with_vtlp: add frequency-warped copies at alpha 0.9 and 1.1.
        with_spectral: additionally mask the spectrograms of augmented copies
            (never the original ones).

    Returns:
        List of ``(spec, aug_type, is_vtlp, speaker_label, speaker_info, path)``
        tuples; ``is_vtlp`` is True only for the VTLP copies.

    A failing speed or VTLP augmentation is reported and skipped, so the
    remaining results for the file are still returned.
    """
    samples = preprocessor.load_audio(path)
    if preprocessor.remove_silence:
        samples = preprocessor.trim_silence(samples)

    processed = samples
    if preprocessor.normalize:
        processed = preprocessor.normalize_volume(processed)
    if preprocessor.lowcut or preprocessor.highcut:
        processed = preprocessor.apply_filter(processed)

    all_results = []

    def _emit(signal, aug_type, is_vtlp, augmented):
        # Chunk -> log-mel -> (optional masking); results are appended only
        # after every chunk succeeded, so a failure leaves no partial output.
        specs = [preprocessor.to_logmel(chunk)
                 for chunk in preprocessor.chunk_audio(signal)]
        if augmented and with_spectral:
            specs = [spectral_augmentation(spec) for spec in specs]
        all_results.extend(
            (spec, aug_type, is_vtlp, speaker_label, speaker_info, path)
            for spec in specs
        )

    _emit(processed, "original", False, augmented=False)

    if with_noise:
        noise_factor = np.random.uniform(0.002, 0.008)
        _emit(add_ambient_noise(processed, noise_factor),
              f"noise_{noise_factor:.4f}", False, augmented=True)

    if with_speed:
        for speed_factor in (0.9, 1.1):
            aug_type = f"speed_{speed_factor:.1f}"
            try:
                sped = speed_perturbation(processed, preprocessor.sr, speed_factor)
                _emit(sped, aug_type, False, augmented=True)
            except Exception as exc:
                # Best effort: skip this variant but make the failure visible.
                print(f"    Warning: {aug_type} augmentation failed for {path}: {exc}")

    if with_vtlp:
        for alpha in (0.9, 1.1):
            aug_type = f"vtlp_{alpha:.1f}"
            try:
                warped = vocal_tract_length_perturbation(processed, preprocessor.sr, alpha)
                _emit(warped, aug_type, True, augmented=True)
            except Exception as exc:
                # Best effort: skip this variant but make the failure visible.
                print(f"    Warning: {aug_type} augmentation failed for {path}: {exc}")

    return all_results

In [None]:
# ==========================================================
# PROCESS ENTIRE DATASET
# ==========================================================
def process_dataset_with_speaker_names(data_root="/content/drive/MyDrive/ML_PW/Recordings", preprocessor=None,
                                      with_noise=False, with_speed=False,
                                      with_vtlp=False, with_spectral=False):
    """Process every supported audio file under ``data_root`` into spectrograms.

    Each sub-folder is scanned for files whose extension is in
    ``SUPPORTED_EXTS``. Speaker metadata comes from the folder's
    ``labels.yaml`` when the filename is listed there.

    NOTE(review): a file that is NOT listed in ``labels.yaml`` is attributed
    to a speaker named after its folder, with ``in_group=True`` and unknown
    gender; confirm this is intended for folders that mix several speakers.

    VTLP-augmented chunks get their own synthetic labels, allocated from 10000
    upwards, one per ``(base_label, aug_type)`` pair, and are added to both
    speaker mappings under the name ``"<speaker>_<aug_type>"``.

    Raises:
        ValueError: if ``preprocessor`` is None.

    Returns:
        ``(specs, labels, aug_types, is_vtlp, speaker_info, file_paths,
        speaker_to_label, label_to_speaker, vtlp_mapping)`` where ``specs``,
        ``labels`` and ``is_vtlp`` are NumPy arrays and the rest are plain
        lists / dicts. ``specs`` is built with ``dtype=object``.
    """
    if preprocessor is None:
        raise ValueError("Preprocessor is required")

    speaker_to_label, label_to_speaker = get_speaker_label_mapping(data_root)

    print(f"\nSpeaker mapping:")
    for known_speaker, known_label in speaker_to_label.items():
        print(f"  {known_speaker} -> label {known_label}")

    # Parallel per-chunk columns.
    all_specs, all_labels, all_aug_types = [], [], []
    all_is_vtlp, all_speaker_info, all_file_paths = [], [], []

    # (base_label, aug_type) -> synthetic VTLP label; labels start at 10000.
    vtlp_mapping = {}

    for folder in os.listdir(data_root):
        folder_path = os.path.join(data_root, folder)
        if not os.path.isdir(folder_path):
            continue

        print(f"\nProcessing folder: {folder}")
        file_metadata = parse_labels_yaml(os.path.join(folder_path, "labels.yaml"))

        for filename in os.listdir(folder_path):
            if not filename.lower().endswith(SUPPORTED_EXTS):
                continue

            file_path = os.path.join(folder_path, filename)

            try:
                speaker_info = file_metadata[filename]
            except KeyError:
                # No metadata for this file: fall back to the folder name.
                speaker_info = {
                    "gender": "Unknown",
                    "in_group": True,
                    "speaker_name": folder
                }
            speaker_name = speaker_info["speaker_name"]

            speaker_label = speaker_to_label.get(speaker_name, len(speaker_to_label))

            print(f"  {filename}: speaker '{speaker_name}' -> label {speaker_label}")

            try:
                results = process_audio_file(
                    file_path, preprocessor, speaker_label, speaker_info,
                    with_noise=with_noise, with_speed=with_speed,
                    with_vtlp=with_vtlp, with_spectral=with_spectral
                )

                for spec, aug_type, vtlp_flag, spk_label, spk_info, fpath in results:
                    if vtlp_flag:
                        # The next free VTLP label is 10000 + number allocated so far.
                        final_label = vtlp_mapping.setdefault(
                            (spk_label, aug_type), 10000 + len(vtlp_mapping)
                        )
                    else:
                        final_label = spk_label

                    all_specs.append(spec)
                    all_labels.append(final_label)
                    all_aug_types.append(aug_type)
                    all_is_vtlp.append(vtlp_flag)
                    all_speaker_info.append(spk_info)
                    all_file_paths.append(fpath)

                print(f"    {len(results)} chunks")
            except Exception as e:
                print(f"    Error: {e}")

    # Publish the synthetic VTLP speakers in both mappings.
    for (base_label, aug_type), vtlp_label in vtlp_mapping.items():
        synthetic_name = f"{label_to_speaker[base_label]}_{aug_type}"
        speaker_to_label[synthetic_name] = vtlp_label
        label_to_speaker[vtlp_label] = synthetic_name

    print(f"\nDone. Total spectrograms: {len(all_specs)}")
    print(f"Base speakers: {len(speaker_to_label) - len(vtlp_mapping)}")
    print(f"VTLP speakers: {len(vtlp_mapping)}")
    print(f"Total speakers: {len(speaker_to_label)}")

    return (np.array(all_specs, dtype=object),
            np.array(all_labels),
            all_aug_types,
            np.array(all_is_vtlp),
            all_speaker_info,
            all_file_paths,
            speaker_to_label,
            label_to_speaker,
            vtlp_mapping)


def create_80_10_10_splits_by_file(spectrograms, labels, aug_types, is_vtlp,
                                   speaker_info_list, file_paths, random_seed=42):
    """Split chunks into train/val/test so that no source file spans two splits.

    Files (not chunks) are shuffled and divided: the first ``int(0.8 * n)``
    go to train, the next ``int(0.1 * n)`` to test, and the remainder to val.

    Args:
        spectrograms, labels, aug_types, is_vtlp, speaker_info_list:
            parallel per-chunk sequences.
        file_paths: per-chunk source file path; this is the grouping key.
        random_seed: seeds both ``random`` and ``np.random`` (global state).

    Returns:
        ``{"train"|"val"|"test": {...}}`` where each split holds ``specs``,
        ``labels``, ``aug_types``, ``is_vtlp``, ``speaker_info`` and
        ``file_paths``. For non-empty splits ``specs`` is a 1-D object array
        whose elements are the individual 2-D spectrograms, and ``labels`` /
        ``is_vtlp`` are NumPy arrays; empty splits keep empty lists.

    An empty input returns three empty splits instead of dividing by zero.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)

    fields = ("specs", "labels", "aug_types", "is_vtlp", "speaker_info")
    columns = (spectrograms, labels, aug_types, is_vtlp, speaker_info_list)

    # Group every per-chunk column by source file, preserving first-seen order.
    file_to_data = defaultdict(lambda: {name: [] for name in fields})
    for i, file_path in enumerate(file_paths):
        bucket = file_to_data[file_path]
        for name, column in zip(fields, columns):
            bucket[name].append(column[i])

    all_files = list(file_to_data.keys())
    random.shuffle(all_files)

    n_files = len(all_files)
    n_train = int(0.8 * n_files)
    n_test = int(0.1 * n_files)

    train_files = all_files[:n_train]
    test_files = all_files[n_train:n_train + n_test]
    val_files = all_files[n_train + n_test:]

    def _pct(count):
        # Guard the empty-dataset case, which would otherwise divide by zero.
        return count / n_files * 100 if n_files else 0.0

    print(f"\nFile-level 80-10-10 split:")
    print(f"  Total files: {n_files}")
    print(f"  Train: {len(train_files)} files ({_pct(len(train_files)):.1f}%)")
    print(f"  Test:  {len(test_files)} files ({_pct(len(test_files)):.1f}%)")
    print(f"  Val:   {len(val_files)} files ({_pct(len(val_files)):.1f}%)")

    splits = {
        split_name: {name: [] for name in fields + ("file_paths",)}
        for split_name in ("train", "val", "test")
    }

    file_to_split = {}
    for split_name, members in (("train", train_files),
                                ("val", val_files),
                                ("test", test_files)):
        for member in members:
            file_to_split[member] = split_name

    for file_path, data in file_to_data.items():
        target = splits[file_to_split[file_path]]
        for name in fields:
            target[name].extend(data[name])
        target["file_paths"].extend([file_path] * len(data["specs"]))

    def _object_array(items):
        # np.array(items, dtype=object) raises when the arrays share their
        # first dimension but differ in the second (variable-length
        # spectrograms), and boxes every scalar when all shapes match.
        # Filling element by element always yields one slot per spectrogram.
        out = np.empty(len(items), dtype=object)
        for slot, item in enumerate(items):
            out[slot] = item
        return out

    for split in splits.values():
        if len(split["specs"]) > 0:
            split["specs"] = _object_array(split["specs"])
            split["labels"] = np.array(split["labels"])
            split["is_vtlp"] = np.array(split["is_vtlp"])

    print(f"\nSample distribution:")
    total = sum(len(split["specs"]) for split in splits.values())
    for split_name in ["train", "val", "test"]:
        n = len(splits[split_name]["specs"])
        if n > 0:
            print(f"  {split_name}: {n} samples ({n/total*100:.1f}%)")

    return splits

In [None]:
# ==========================================================
# SAVE TO HDF5
# ==========================================================
def save_h5_dataset(splits, speaker_to_label, label_to_speaker, vtlp_mapping,
                   preprocessor, out_dir="/content/drive/MyDrive/ML_PW/outputs"):
    """Write the split dataset and its metadata to a single HDF5 file.

    The file is named ``logmels_<augmentations>_<dd-mm-yy>.h5`` inside
    ``out_dir`` (created if missing). Every non-empty split becomes a group
    holding ``logmel`` (float32, zero-padded along the last axis to the
    longest spectrogram of that split), ``label`` (int64), ``length`` (frame
    count before padding), ``augmentation`` and ``file_path`` (UTF-8 strings).
    A ``meta`` group carries summary attributes plus YAML/JSON documents
    describing preprocessing, speakers, VTLP speakers and per-split statistics.

    Labels of 10000 and above are treated as VTLP (synthetic) speakers.

    Args:
        splits: mapping of split name to a dict with ``specs``, ``labels``,
            ``aug_types`` and ``file_paths``.
        speaker_to_label: speaker name -> integer label.
        label_to_speaker: integer label -> speaker name.
        vtlp_mapping: ``(base_label, aug_type)`` -> VTLP label.
        preprocessor: object whose ``get_config()`` returns the preprocessing
            settings (must include ``"sr"``).
        out_dir: destination directory.

    Returns:
        Path of the written HDF5 file.
    """
    os.makedirs(out_dir, exist_ok=True)

    date_tag = datetime.now().strftime("%d-%m-%y")

    # Collect which augmentation families occur anywhere, by substring match
    # on the per-sample augmentation tags.
    aug_types_present = set()
    for split_data in splits.values():
        for aug in split_data["aug_types"]:
            if "noise" in aug:
                aug_types_present.add("noise")
            if "speed" in aug:
                aug_types_present.add("speed")
            if "vtlp" in aug:
                aug_types_present.add("vtlp")

    aug_tag = "_".join(sorted(aug_types_present)) if aug_types_present else "no_aug"
    out_name = f"logmels_{aug_tag}_{date_tag}.h5"
    out_path = os.path.join(out_dir, out_name)

    # Labels below 10000 are real speakers; 10000 and above are VTLP speakers.
    base_labels = [l for l in speaker_to_label.values() if l < 10000]
    vtlp_labels = [l for l in speaker_to_label.values() if l >= 10000]

    with h5py.File(out_path, 'w') as f:
        for split_name, split_data in splits.items():
            # Empty splits get no group at all.
            if len(split_data["specs"]) == 0:
                continue

            specs = split_data["specs"]
            labels = split_data["labels"]
            aug_types = split_data["aug_types"]
            file_paths = split_data["file_paths"]

            # The mel-bin count is taken from the first spectrogram; the time
            # axis is padded to the longest spectrogram in this split.
            n_mels = specs[0].shape[0]
            max_T = max(s.shape[1] for s in specs)

            arr = np.zeros((len(specs), n_mels, max_T), dtype=np.float32)
            lengths = np.zeros(len(specs), dtype=np.int64)

            for i, s in enumerate(specs):
                T = s.shape[1]
                arr[i, :, :T] = s
                lengths[i] = T  # unpadded frame count

            grp = f.create_group(split_name)
            grp.create_dataset("logmel", data=arr, compression="gzip")
            grp.create_dataset("label", data=np.array(labels, dtype=np.int64), compression="gzip")
            grp.create_dataset("length", data=lengths, compression="gzip")

            # Per-sample strings are stored as UTF-8 encoded bytes.
            dt = h5py.string_dtype(encoding='utf-8')
            aug_array = np.array([str(a).encode('utf-8') if not isinstance(a, bytes) else a
                                 for a in aug_types], dtype=dt)
            grp.create_dataset("augmentation", data=aug_array, compression="gzip")

            file_array = np.array([str(p).encode('utf-8') if not isinstance(p, bytes) else p
                                  for p in file_paths], dtype=dt)
            grp.create_dataset("file_path", data=file_array, compression="gzip")

        meta = f.create_group("meta")
        cfg = preprocessor.get_config()

        # Per-split sample and speaker counts (non-empty splits only).
        split_stats = {}
        for split_name in splits:
            if len(splits[split_name]["specs"]) > 0:
                split_labels = splits[split_name]["labels"]
                unique_labels = set(split_labels)
                base_count = len([l for l in unique_labels if l < 10000])
                vtlp_count = len([l for l in unique_labels if l >= 10000])

                split_stats[split_name] = {
                    "num_samples": len(split_labels),
                    "num_speakers": len(unique_labels),
                    "num_base_speakers": base_count,
                    "num_vtlp_speakers": vtlp_count
                }

        # Scalar summary values live as attributes on the meta group.
        meta.attrs.update({
            "sample_rate": cfg["sr"],
            "feature_type": "log-mel spectrogram",
            "split_strategy": "80-10-10 by file",
            "num_base_speakers": len(base_labels),
            "num_vtlp_speakers": len(vtlp_labels),
            "total_speakers": len(speaker_to_label)
        })

        file_desc = {
            "dataset_name": out_name,
            "preprocessing_config": cfg,
            "split_strategy": "80-10-10 by file (random split)",
            "split_statistics": split_stats,
            "total_base_speakers": len(base_labels),
            "total_vtlp_speakers": len(vtlp_labels),
            "total_speakers": len(speaker_to_label),
            "vtlp_speaker_id_start": 10000,
            "note": "File-level split ensures no temporal overlap"
        }

        # Speaker name -> label, flagged as VTLP when the label is >= 10000.
        speaker_mapping = {}
        for speaker_name, label in speaker_to_label.items():
            speaker_mapping[speaker_name] = {
                "label": int(label),
                "is_vtlp": label >= 10000
            }

        # Re-key the VTLP mapping by "<speaker>_<aug_type>" so it can be
        # dumped as YAML (the original keys are tuples).
        vtlp_mapping_clean = {}
        for (base_label, aug_type), vtlp_label in vtlp_mapping.items():
            base_speaker = label_to_speaker.get(base_label, f"Unknown_{base_label}")
            vtlp_mapping_clean[f"{base_speaker}_{aug_type}"] = {
                "vtlp_label": int(vtlp_label),
                "base_label": int(base_label),
                "augmentation": aug_type
            }

        # The metadata documents are stored as byte-string datasets.
        meta.create_dataset("file_description.yaml",
                           data=np.bytes_(yaml.safe_dump(file_desc, default_flow_style=False)))
        meta.create_dataset("speaker_mapping.yaml",
                           data=np.bytes_(yaml.safe_dump({"speakers": speaker_mapping})))
        meta.create_dataset("vtlp_mapping.yaml",
                           data=np.bytes_(yaml.safe_dump({"vtlp_speakers": vtlp_mapping_clean})))
        meta.create_dataset("split_statistics.json",
                           data=np.bytes_(json.dumps(split_stats, indent=2).encode('utf-8')))

    print(f"\nSaved dataset → {out_path}")
    print(f"  Base speakers: {len(base_labels)}")
    print(f"  VTLP speakers: {len(vtlp_labels)}")
    print(f"  Total speakers: {len(speaker_to_label)}")

    return out_path

In [None]:
# ==========================================================
# MAIN EXECUTION
# ==========================================================

# Create preprocessor
preprocessor = AudioPreprocessor(
    sr=16000,
    n_mels=64,
    n_fft=2048,
    hop_length=512,
    chunk_duration=1.0,
    remove_silence=True,
    normalize=True,
    lowcut=None,
    highcut=None
)

# AUGMENTATION CONTROLS - Set these booleans
WITH_NOISE = True
WITH_SPEED = True
WITH_VTLP = True
WITH_SPECTRAL = True

# Process dataset
specs, labels, aug_types, is_vtlp, spk_info, file_paths, \
speaker_to_label, label_to_speaker, vtlp_mapping = process_dataset_with_speaker_names(
    data_root="/content/drive/MyDrive/ML_PW/Recordings",
    preprocessor=preprocessor,
    with_noise=WITH_NOISE,
    with_speed=WITH_SPEED,
    with_vtlp=WITH_VTLP,
    with_spectral=WITH_SPECTRAL
)

# Split and save only when something was produced: the split step computes
# per-file percentages and cannot run on an empty dataset.
if len(specs) > 0:
    # Create 80-10-10 splits by file
    print("\nCreating 80-10-10 splits by file...")
    splits = create_80_10_10_splits_by_file(
        specs, labels, aug_types, is_vtlp, spk_info, file_paths, random_seed=42
    )

    # Save to HDF5
    save_h5_dataset(
        splits, speaker_to_label, label_to_speaker, vtlp_mapping,
        preprocessor, out_dir="/content/drive/MyDrive/ML_PW/outputs"
    )
    print("\n✅ Dataset generation complete!")
else:
    print("\nNo spectrograms were produced; nothing to split or save.")


Speaker mapping:
  Piotr -> label 0
  Aleksander -> label 1
  Mantas -> label 2
  Rafał -> label 3
  michał -> label 4

Processing folder: Piotr
  Emma_F_False.m4a: speaker 'Piotr' -> label 0
    78 chunks
  Greta2_F_False.m4a: speaker 'Piotr' -> label 0
    1186 chunks
  Obama2_M_False.m4a: speaker 'Piotr' -> label 0
    1565 chunks
  Anne_F_False.m4a: speaker 'Piotr' -> label 0
    1566 chunks
  Greta4_F_False.m4a: speaker 'Piotr' -> label 0
    1560 chunks
  Greta1_F_False.m4a: speaker 'Piotr' -> label 0
    1385 chunks
  Julian_M_False.m4a: speaker 'Piotr' -> label 0
    1475 chunks
  Greta3_F_False.m4a: speaker 'Piotr' -> label 0
    120 chunks
  Piotr_M_True.m4a: speaker 'Piotr' -> label 0
    7254 chunks
  marzena-2.mp3: speaker 'Piotr' -> label 0
    1114 chunks
  Dominika_F_False.m4a: speaker 'Piotr' -> label 0
    1831 chunks
  Lara_F_False.m4a: speaker 'Piotr' -> label 0
    1523 chunks
  Natalia_F_False.m4a: speaker 'Piotr' -> label 0
    1469 chunks
  Obama_M_False.m4a: s

## MODEL ARCHITECTURE
### Simple Outline
- CNN block that pools only along the frequency axis, preserving the time axis
- RNN block that processes the data along the time axis
- if training, then AAM-Softmax; otherwise just output
- all of that is contained in the SpeakerClassifier class

In [None]:
!pip install  scikit-learn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

from sklearn.metrics import precision_recall_fscore_support

# mount Drive once per session
from google.colab import drive  # type: ignore

drive.mount('/content/drive/',force_remount=True)


In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gating for 4-D feature maps.

    Each channel is averaged over its last two axes, passed through a
    bottleneck MLP ending in a sigmoid, and the resulting per-channel weight
    rescales the input.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        # The bottleneck never shrinks below 4 units.
        bottleneck = max(channels // reduction, 4)
        layers = [
            nn.Linear(channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        # x: [B, C, T, F]; unpacking four dims rejects inputs of other ranks.
        batch, channels, _, _ = x.shape
        # Squeeze: one average per channel -> [B, C]
        pooled = x.mean(dim=(2, 3))
        # Excite: per-channel gate in (0, 1), broadcast as [B, C, 1, 1]
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate

In [None]:
# Input is a log-mel spectrogram with shape [B, 1, T, F]:
# B - batch size, 1 - channel dimension (mono audio),
# T - time frames, F - frequency bins
class Backbone(nn.Module):
  """
  Backbone of the model: CNN block -> GRU -> attentive statistics pooling ->
  projection to an L2-normalised embedding.

  Args:
    no_mels: number of frequency bins F of the input spectrogram.
    embed_dim: size of the returned embedding.
    rnn_hidden: GRU hidden size (per direction).
    rnn_layers: number of stacked GRU layers.
    bidir: whether the GRU is bidirectional.
  """
  def __init__(self, no_mels,embed_dim,rnn_hidden,rnn_layers,bidir):
    super().__init__()

    # Three conv stages; each pools by 2 along the last (frequency) axis only,
    # so the time axis keeps its length.
    self.cnn_block = nn.Sequential(

      nn.Conv2d(1, 32, 3, padding=1),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      SEBlock(32, reduction=8),
      nn.MaxPool2d(kernel_size=(1, 2)),

      nn.Conv2d(32, 64, 3, padding=1),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      SEBlock(64, reduction=8),
      nn.MaxPool2d(kernel_size=(1, 2)),

      nn.Conv2d(64, 128, 3, padding=1),
      nn.BatchNorm2d(128),
      nn.ReLU(),
      SEBlock(128, reduction=8),
      nn.MaxPool2d(kernel_size=(1, 2)),

    )

    self.rnn_hidden = rnn_hidden  # kept for easier reconfiguration later
    self.rnn = nn.GRU(
        input_size=128 * (no_mels // 8),  # 8 = 2^3: three poolings by 2
        hidden_size=self.rnn_hidden,
        num_layers=rnn_layers,
        bidirectional=bidir,
        batch_first=True,
        # GRU dropout only acts between stacked layers; passing a non-zero
        # value with a single layer has no effect and only triggers a warning.
        dropout=0.2 if rnn_layers > 1 else 0.0,
        )

    # EMBEDDING HEAD
    out_dim = (2 if bidir else 1) * rnn_hidden

    self.rnn_ln = nn.LayerNorm(out_dim)

    # Scoring network for attentive statistics pooling: one scalar per frame.
    self.att = nn.Sequential(
              nn.Linear(out_dim, 128),
              nn.Tanh(),
              nn.Linear(128, 1)
          )

    # Pooled statistics are [mean, std], hence out_dim * 2 inputs.
    self.proj = nn.Sequential(
        nn.Linear(out_dim*2, 256),
        nn.BatchNorm1d(256), nn.ReLU(),
        nn.Linear(256, embed_dim)
              )


  def forward(self,x, lengths: torch.Tensor | None = None, mc_dropout: bool | None = None):
    """
    Args:
      x: input of shape [B, 1, T, F].
      lengths: optional per-sample number of valid time frames; when given,
        the GRU runs on packed sequences and padded frames are excluded from
        pooling. May live on any device.
      mc_dropout: force dropout on/off; defaults to ``self.training``.

    Returns:
      L2-normalised embeddings of shape [B, embed_dim].
    """
    if mc_dropout is None:
        mc_dropout = self.training

    h = self.cnn_block(x)  # [B, C, T, F_pooled]
    if mc_dropout:
      h = F.dropout(h, p=0.3, training=True)

    # Flatten channels and frequency into one feature vector per time step.
    B,C,T,Fp = h.shape
    h = h.permute(0,2,1,3).contiguous().view(B,T,C*Fp)

    if lengths is not None:
      packed = nn.utils.rnn.pack_padded_sequence(
          h,
          lengths.cpu(),  # packing requires the lengths on the CPU
          batch_first=True,
          enforce_sorted=False,
          )
      packed_out, _ = self.rnn(packed)

      rnn_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
      Tmax = rnn_out.size(1)
      # Build the mask on the same device as the GRU output so the comparison
      # works whether the caller supplied CPU or GPU lengths.
      valid_lengths = lengths.to(rnn_out.device)
      mask = (torch.arange(Tmax, device=rnn_out.device).unsqueeze(0).expand(B, Tmax)) < valid_lengths.unsqueeze(1)

    else:
       rnn_out, _ = self.rnn(h)  # [B, T, out_dim]
       mask = torch.ones(rnn_out.size(0), rnn_out.size(1),
                        dtype=torch.bool, device=rnn_out.device)

    if mc_dropout:
            rnn_out = F.dropout(rnn_out, p=0.3, training=True)

    rnn_out = self.rnn_ln(rnn_out)

    # ATTENTIVE STATISTICS POOLING
    a = self.att(rnn_out).squeeze(-1)           # [B, T] frame scores
    a = a.masked_fill(~mask, float('-inf'))     # padded frames get zero weight
    w = torch.softmax(a,dim=1).unsqueeze(-1)    # [B, T, 1]

    mean = torch.sum(w*rnn_out,dim=1)           # [B, out_dim]
    var  = torch.sum(w * (rnn_out - mean.unsqueeze(1))**2, dim=1)
    std  = torch.sqrt(var + 1e-5)               # epsilon keeps sqrt differentiable at 0
    stats = torch.cat([mean, std], 1)

    if mc_dropout:
            stats = F.dropout(stats, p=0.3, training=True)

    z = self.proj(stats)                        # [B, embed_dim]
    z = nn.functional.normalize(z, p=2, dim=1)

    return z




In [None]:
#AAMSoftmax head to enhance class identification using margins, https://medium.com/@zhaomin.chen/additive-margin-softmax-loss-3c78e37b08ed
#essentially tighter intra-class clusters and larger inter-class gaps
#used only for training, discarded at inference time
class AAMSoftmax(nn.Module):
  """
  Additive angular margin softmax head, used at training time only.

  The target-class logit is computed from ``cos(theta + m)`` instead of
  ``cos(theta)``, which tightens intra-class clusters and widens inter-class
  gaps. All logits are scaled by ``s`` before being returned.

  Args:
      in_features: embedding dimension.
      out_features: number of classes (speakers).
      s: logit scale.
      m: angular margin in radians.
  """
  def __init__(self, in_features, out_features, s=30.0, m=0.20):
    super().__init__()
    self.s = s
    self.m = m
    self.in_features = in_features
    self.out_features = out_features
    self.weight = nn.Parameter(torch.empty(out_features, in_features))
    nn.init.xavier_uniform_(self.weight)

  def forward(self, emb, labels):
    """
    Args:
        emb: [B, in_features] embeddings, expected to be L2-normalised already.
        labels: [B] int64 class indices in [0, out_features).

    Returns:
        [B, out_features] scaled logits with the margin applied to the true class.
    """
    import math  # local import: only needed for the two margin constants below

    W = F.normalize(self.weight, dim=1)  # class prototypes on the unit sphere
    cos_theta = emb @ W.T

    # Under autocast the matmul yields fp16/bf16, where the clamp bound 1 - 1e-7
    # rounds to exactly 1 and acos gets an infinite gradient; do the margin
    # arithmetic in float32 instead.
    if cos_theta.dtype in (torch.float16, torch.bfloat16):
      cos_theta = cos_theta.float()

    theta = torch.acos(cos_theta.clamp(-1 + 1e-7, 1 - 1e-7))
    target_logits = torch.cos(theta + self.m)

    # cos(theta + m) stops decreasing once theta + m > pi, which would hand the
    # worst-aligned samples a *smaller* penalty. Past that point fall back to a
    # linear penalty, as in the ArcFace reference implementation.
    threshold = math.cos(math.pi - self.m)
    linear_penalty = math.sin(math.pi - self.m) * self.m
    target_logits = torch.where(cos_theta > threshold,
                                target_logits,
                                cos_theta - linear_penalty)

    # substitute the margin logit only at the true-speaker position
    one_hot = F.one_hot(labels, num_classes=W.size(0)).to(cos_theta.dtype)
    output = cos_theta * (1 - one_hot) + target_logits * one_hot

    return output * self.s


In [None]:
class SpeakerClassifier(nn.Module):
  """
  Backbone plus AAMSoftmax head, used for both training and inference.

  Training: ``forward(x, labels)`` returns ``(margin_logits, embeddings)``.
  Inference: the L2-normalised AAMSoftmax class weights act as a speaker bank;
  a probe is scored by its maximum cosine similarity against that bank.

  The bank is registered as a non-persistent buffer: it follows ``.to(device)``
  but is not written to the state_dict, so checkpoint keys are unchanged. It is
  rebuilt on every ``eval()`` call so it never lags behind the trained weights.
  """
  def __init__(self,backbone,num_speakers,aamsm_scaler,aamsm_margin):
    super().__init__()
    self.backbone=backbone
    self.aamsm = AAMSoftmax(backbone.proj[-1].out_features,num_speakers,aamsm_scaler,aamsm_margin)
    self._inference_prepared = False
    self.score_alpha = nn.Parameter(torch.tensor(1.0))  # score calibration: scale
    self.score_beta  = nn.Parameter(torch.tensor(0.0))  # score calibration: bias
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Defined up front so "not built yet" is detectable as None instead of
    # surfacing as an AttributeError from nn.Module.__getattr__.
    self.register_buffer("bank", None, persistent=False)
    self.inference_threshold = None

  def forward(self,x,labels=None,lengths=None):
    """
    Return ``(logits, emb)`` when labels are given, otherwise just ``emb``.
    Backbone dropout follows ``self.training`` (mc_dropout=None).
    """
    emb = self.backbone(x,lengths=lengths,mc_dropout=None )
    if labels is not None:
      logits = self.aamsm(emb,labels)
      return logits,emb
    else:
      return emb

  def eval(self):
    """
    Switch to eval mode and (re)build the speaker bank from the current AAM weights.

    The bank is rebuilt on every call: validation calls eval() each epoch and
    weights may be replaced by load_state_dict, so a build-once cache goes stale.
    """
    super().eval()
    self.bank = self.build_bank_from_aam()
    self._inference_prepared = True
    return self

  @torch.no_grad()
  def prepare_for_inference(self, threshold: float = 0.35):
    """Put the model in eval mode (building the bank) and set the default threshold."""
    self.eval()
    self.inference_threshold = threshold
    return self

  @torch.no_grad()
  def embed(self, x, lengths=None):
    """Return L2-normalised embeddings [B, D] without tracking gradients."""
    z = self.backbone(x, lengths=lengths)
    return F.normalize(z, dim=1)

  @torch.no_grad()
  def mc_embed(self, x, lengths=None, n_samples: int = 10):
    """
    Monte Carlo dropout embeddings.
    Returns:
        embs: [n_samples, B, D]
    """
    embs = []
    for _ in range(n_samples):
        z = self.backbone(x, lengths=lengths, mc_dropout=True)
        z = F.normalize(z, dim=1)
        embs.append(z)
    return torch.stack(embs, dim=0)

  @torch.no_grad()
  def build_bank_from_aam(self):
    """Return the row-normalised AAMSoftmax weight matrix [S, D]."""
    return F.normalize(self.aamsm.weight, dim=1)

  def set_default_inference_threshold(self, threshold):
    """Set the threshold used when verify/infer calls do not pass one."""
    self.inference_threshold = threshold

  def _resolve_bank(self, bank):
    """Return the explicit bank if given, else the cached one; raise if neither exists."""
    if bank is not None:
        return bank
    if self.bank is None:
        raise RuntimeError('The bank has not been built before inference. Make sure the model has been set to eval mode!')
    return self.bank

  def _resolve_threshold(self, threshold):
    """Return the explicit threshold if given, else the default; raise if neither exists."""
    if threshold is not None:
        return threshold
    if self.inference_threshold is None:
        raise RuntimeError('No decision threshold available: pass threshold=... or call set_default_inference_threshold() first.')
    return self.inference_threshold

  @torch.no_grad()
  def verify_any(self, x, lengths=None, *, threshold=None, bank=None, return_index=False):
    """
    Compare input embeddings against a speaker bank (default: self.bank built from AAM weights).

    x:       [B, 1, T, F]
    lengths: [B] (optional, time lengths)
    Returns:
        decisions: [B] bool
        scores:    [B] float (calibrated max cosine similarity)
        (optional) indices: [B] long (argmax speaker index)
    Raises:
        RuntimeError: if no bank or no threshold is available.
    """
    bank = self._resolve_bank(bank)
    thr = self._resolve_threshold(threshold)

    probe = self.embed(x, lengths=lengths)      # [B, D]
    # a caller-supplied bank may live on another device than the model
    bank = F.normalize(bank.to(probe.device), dim=1)
    sims = probe @ bank.T                       # [B, S]
    scores, idx = sims.max(dim=1)               # best-matching speaker per probe

    scores = self.score_alpha * scores + self.score_beta
    decisions = scores >= thr

    if return_index:
        return decisions, scores, idx
    return decisions, scores

  @torch.no_grad()
  def mc_verify_any(self, x, lengths=None, n_samples: int = 10,
                      threshold=None, bank=None, return_index=False):
    """
    MC Dropout version of verify_any. Puts the model in eval mode first.

    Returns:
        decisions: [B] bool (based on mean score)
        mean_scores: [B] float
        var_scores: [B] float (variance over the n_samples passes; NaN for n_samples=1)
        (optional) pred_idx: [B] long from mean scores
    Raises:
        RuntimeError: if no bank or no threshold is available.
    """
    self.eval()
    bank = self._resolve_bank(bank)
    thr = self._resolve_threshold(threshold)

    embs = self.mc_embed(x, lengths=lengths, n_samples=n_samples)  # [K, B, D]
    bank = F.normalize(bank.to(embs.device), dim=1)
    sims = torch.einsum("kbd,sd->kbs", embs, bank)  # [K, B, S]
    mean_sims = sims.mean(dim=0)  # [B, S]
    var_sims  = sims.var(dim=0)   # [B, S]
    mean_scores, pred_idx = mean_sims.max(dim=1)  # [B]

    mean_scores = self.score_alpha * mean_scores + self.score_beta
    score_var = var_sims.gather(1, pred_idx.view(-1, 1)).squeeze(1)  # [B]

    decisions = mean_scores >= thr

    if return_index:
        return decisions, mean_scores, score_var, pred_idx
    return decisions, mean_scores, score_var

  @torch.no_grad()
  def infer(self, x, lengths=None, threshold=None):
    """
    High-level inference method.

    Input:
        x:       tensor [B, 1, T, F] (log-mel spectrograms)
        lengths: tensor [B] with valid time lengths (optional)
    Output:
        pred_ids:   [B] long  -- predicted speaker index (0..num_speakers-1)
        scores:     [B] float -- calibrated cosine similarity to predicted prototype
        decisions:  [B] bool  -- whether score >= threshold
    """
    decisions, scores, pred_ids = self.verify_any(
        x, lengths=lengths, threshold=threshold, return_index=True
    )
    return pred_ids, scores, decisions

  @torch.no_grad()
  def test(self, loaders, threshold=0.7):
    """
    Run inference over the test split and print/return per-item results.

    Args:
        loaders: either the dict returned by build_h5_loaders (its 'test' entry
                 is used) or the test DataLoader itself.
        threshold: decision threshold forwarded to infer().
    Returns:
        list of (predicted speaker index, score, true label) tuples.
    """
    loader = loaders['test'] if isinstance(loaders, dict) else loaders
    device = self.device if self.device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # move first, then eval(): the bank is built on the device the weights live on
    self.to(device)
    self.eval()
    results = []

    for X, y, lengths in loader:
        X = X.to(device, non_blocking=True)
        lengths = lengths.to(device, non_blocking=True)

        pred_ids, scores, _ = self.infer(x=X, lengths=lengths, threshold=threshold)

        truths = y.tolist() if y is not None else [None] * X.size(0)
        results.extend(zip(pred_ids.tolist(),   # predicted speaker index
                           scores.tolist(),     # similarity score
                           truths))             # true label
    print(results)
    return results

##Model Training

In [None]:
#utilities for saving locally and in colab
def in_colab():
    """Return True when the google.colab package can be imported, else False."""
    try:
        import google.colab  # type: ignore
    except Exception:
        return False
    return True

def prepare_save_dir(checkpoint_dir: str, use_drive: bool=False, drive_dir: str="/content/drive/MyDrive/ML_PW/spk_checkpoints"):
    """
    Create and return a fresh run folder ``<root>/train<N>``.

    The root is ``drive_dir`` when ``use_drive`` is set and the code runs in
    Colab, otherwise ``checkpoint_dir``. N is the first positive integer whose
    folder does not exist yet.

    Previously the run folder was always created under ``checkpoint_dir`` and
    ``drive_dir`` was only created, never used, so checkpoints did not reach Drive.
    """
    # choose the root before numbering so the run folder lands where requested
    save_root = drive_dir if (use_drive and in_colab()) else checkpoint_dir
    os.makedirs(save_root, exist_ok=True)

    run_index = 1
    run_dir = os.path.join(save_root, "train" + str(run_index))
    while os.path.exists(run_dir):
        run_index += 1
        run_dir = os.path.join(save_root, "train" + str(run_index))
    os.makedirs(run_dir)
    return run_dir

def save_checkpoint(model: torch.nn.Module, path: str):
    """Write only the model's state_dict (no optimizer or epoch info) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)

In [None]:
def unpack(batch):
    """Normalise a batch to ``(X, y, lengths)``; ``lengths`` is None for 2-item batches."""
    if len(batch) != 3:
        X, y = batch
        return X, y, None
    X, y, lengths = batch
    return X, y, lengths

#TRAINING AND VALIDATION funcs
def train_one_epoch(model, loader, optimizer, scheduler, writer, epoch, device,global_step,loss_fn,scaler,use_amp=True):
  """
  Train ``model`` for one pass over ``loader`` and log metrics to TensorBoard.

  Per batch: forward under autocast (CUDA only), scaled backward, optimizer
  step, and ``train/loss`` / ``train/acc`` logged at ``global_step``. After the
  loop the scheduler is stepped once and per-class plus macro
  precision/recall/F1 are logged at ``epoch``.

  Returns:
      (global_step, avg_loss, avg_acc) with sample-weighted averages.
  Raises:
      ValueError: if ``loader`` yields no samples.
  """
  model.train()
  total_loss, total_correct, total_samples = 0.0, 0, 0
  predicted_labels, ground_truth_labels = [], []
  total_time = 0.0
  batch_count = 0
  amp_enabled = use_amp and device.type == "cuda"

  for batch in loader:
    batch_count += 1
    batch_start = time.perf_counter()
    X, y, lengths = unpack(batch)

    X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
    if lengths is not None:
        lengths = lengths.to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)

    # Only the forward pass and loss belong under autocast.
    with autocast(enabled=amp_enabled):
      logits, _ = model(X, y, lengths=lengths)
      loss = loss_fn(logits, y)

    # Backward and the optimizer step run outside the autocast region,
    # as the PyTorch AMP documentation recommends.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    pred = logits.argmax(1)

    predicted_labels.append(pred.cpu().detach().numpy())
    ground_truth_labels.append(y.cpu().detach().numpy())

    batch_size = y.size(0)
    loss_value = loss.item()
    correct = (pred == y).sum().item()
    total_correct += correct
    total_samples += batch_size
    total_loss += loss_value * batch_size

    total_time += time.perf_counter() - batch_start

    # TensorBoard logging
    writer.add_scalar("train/loss", loss_value, global_step)
    writer.add_scalar("train/acc", correct / batch_size, global_step)
    global_step += 1

  if total_samples == 0:
    raise ValueError("train_one_epoch received an empty loader; nothing to train on.")

  scheduler.step()

  y_true = np.concatenate(ground_truth_labels)
  y_pred = np.concatenate(predicted_labels)
  print(f"Training batches took {total_time/batch_count*1000}ms per batch.")
  # Per-class metrics
  labels = np.unique(y_true)
  prec, rec, f1, supp = precision_recall_fscore_support(
      y_true, y_pred,
      beta=1.0,
      average=None,
      labels=labels,
      zero_division=0,
  )
  # Macro averages
  prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
      y_true, y_pred, beta=1.0, average="macro", zero_division=0
  )

  for label, p, r, f, s in zip(labels, prec, rec, f1, supp):
      writer.add_scalar(f"train/precision_class_{label}", p, epoch)
      writer.add_scalar(f"train/recall_class_{label}",    r, epoch)
      writer.add_scalar(f"train/f1_class_{label}",        f, epoch)
      writer.add_scalar(f"train/support_class_{label}",   s, epoch)
  writer.add_scalar("train/precision_macro", prec_macro, epoch)
  writer.add_scalar("train/recall_macro",    rec_macro,  epoch)
  writer.add_scalar("train/f1_macro",        f1_macro,   epoch)

  avg_loss = total_loss / total_samples
  avg_acc = total_correct / total_samples

  # single quotes inside the f-string: nested double quotes need Python >= 3.12
  print(f"{'Train':<20}| Epoch {epoch:03d} | loss {avg_loss:.4f} | acc {avg_acc:.4f} | precision {prec_macro:.3f} | f1-score {f1_macro:.3f} |")
  return global_step,avg_loss, avg_acc


@torch.no_grad()
def validate(model, loader, writer, epoch, device,loss_fn):
  """
  Evaluate ``model`` on ``loader`` and log metrics to TensorBoard at step ``epoch``.

  Logs ``validate/loss`` and ``validate/acc`` once per call (sample-weighted
  averages over the whole loader), plus per-class and macro precision/recall/F1.
  Labels are passed to the model, so the logits include the AAM margin.

  Returns:
      (avg_loss, avg_acc)
  Raises:
      ValueError: if ``loader`` yields no samples.
  """
  model.eval()
  total_loss, total_correct, total_samples = 0.0, 0, 0
  predicted_labels, ground_truth_labels = [], []

  total_time = 0.0
  batch_count = 0
  for batch in loader:
    batch_count += 1
    batch_start = time.perf_counter()
    Xv, yv, lengths = unpack(batch)
    Xv, yv = Xv.to(device, non_blocking=True), yv.to(device, non_blocking=True)
    if lengths is not None:
        lengths = lengths.to(device, non_blocking=True)
    logits, _ = model(Xv, yv, lengths=lengths)
    loss = loss_fn(logits, yv)

    pred = logits.argmax(1)

    predicted_labels.append(pred.cpu().detach().numpy())
    ground_truth_labels.append(yv.cpu().detach().numpy())
    total_time += time.perf_counter() - batch_start

    # running totals
    total_correct += (pred == yv).sum().item()
    total_samples += yv.size(0)
    total_loss += loss.item() * yv.size(0)

  if total_samples == 0:
    raise ValueError("validate received an empty loader; nothing to evaluate.")

  avg_loss = total_loss / total_samples
  avg_acc = total_correct / total_samples

  # Logged once after the loop: writing the running average per batch put many
  # points on the same TensorBoard step.
  writer.add_scalar("validate/loss", avg_loss, epoch)
  writer.add_scalar("validate/acc",  avg_acc,  epoch)
  print(f"Validation performed in {(total_time/batch_count)*1000} ms per batch on average")

  y_true = np.concatenate(ground_truth_labels)
  y_pred = np.concatenate(predicted_labels)

  # Per-class metrics
  labels = np.unique(y_true)
  prec, rec, f1, supp = precision_recall_fscore_support(
      y_true, y_pred,
      beta=1.0,
      average=None,
      labels=labels,
      zero_division=0,
  )
  # Macro averages
  prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
      y_true, y_pred, beta=1.0, average="macro", zero_division=0
  )

  for label, p, r, f, s in zip(labels, prec, rec, f1, supp):
      writer.add_scalar(f"validate/precision_class_{label}", p, epoch)
      writer.add_scalar(f"validate/recall_class_{label}",    r, epoch)
      writer.add_scalar(f"validate/f1_class_{label}",        f, epoch)
      writer.add_scalar(f"validate/support_class_{label}",   s, epoch)
  writer.add_scalar("validate/precision_macro", prec_macro, epoch)
  writer.add_scalar("validate/recall_macro",    rec_macro,  epoch)
  writer.add_scalar("validate/f1_macro",        f1_macro,   epoch)

  # single quotes inside the f-string: nested double quotes need Python >= 3.12
  print(f"{'Validate':<20}| Epoch {epoch:03d} | loss {avg_loss:.4f} | acc {avg_acc:.4f} | precision {prec_macro:.3f} | f1-score {f1_macro:.3f} |")
  return avg_loss, avg_acc

def train_model(
    train_loader,
    val_loader,
    N_MELS,
    NUM_SPK,
    writer,
    loss_fn=nn.CrossEntropyLoss(),
    epochs=30,
    patience = 5,
    rnn_hidden=256,
    rnn_layers=2,
    bidir=True,
    use_drive=True, #true if in colab, false if local
    checkpoint_dir="./checkpoints",
    drive_dir ="/content/drive/MyDrive/ML_PW/spk_checkpoints",
    optimizer = "AdamW",
    lr = 1e-3,
    weight_decay =1e-4,
    T_max = 50
    ):
  """
  Build a Backbone + SpeakerClassifier, train it for ``epochs`` epochs and return it.

  After every epoch the weights are saved to ``last_model.pt``; whenever the
  validation accuracy strictly improves they are also saved to ``best_model.pt``.
  Both files live in the run folder returned by ``prepare_save_dir``.

  Args:
      optimizer: "AdamW" or "SGD" (SGD uses momentum 0.9).
      T_max: period passed to CosineAnnealingLR, which is stepped once per epoch.
      rnn_hidden, rnn_layers, bidir: forwarded to Backbone.
      patience: NOTE(review): accepted but currently unused; no early stopping
          is performed. Left as-is so the default training length is unchanged.

  Returns:
      The trained model (in eval mode, as left by the last validation pass).
  Raises:
      ValueError: if ``optimizer`` is not a supported name.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Seed before the model is built so weight initialisation is reproducible too.
  torch.manual_seed(1337)
  # speeds up training if cuda is available, which it should be on a Colab GPU
  torch.backends.cudnn.benchmark = True

  backbone = Backbone(no_mels=N_MELS, embed_dim=256, rnn_hidden=rnn_hidden, rnn_layers=rnn_layers, bidir=bidir)
  model = SpeakerClassifier(backbone, num_speakers=NUM_SPK, aamsm_scaler=30.0, aamsm_margin=0.25).to(device)

  # optimizers: AdamW, Stochastic Gradient Descent
  if optimizer == "SGD":
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
  elif optimizer == "AdamW":
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
  else:
    # fail here instead of handing a bare string to the scheduler below
    raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'SGD' or 'AdamW'.")

  # scheduler and scaler
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=T_max)
  scaler = torch.amp.GradScaler("cuda",enabled=torch.cuda.is_available())

  global_step = 0
  save_root = prepare_save_dir(checkpoint_dir=checkpoint_dir, use_drive=use_drive, drive_dir=drive_dir)
  best_val_acc = 0.0
  best_path = os.path.join(save_root, "best_model.pt")
  last_path = os.path.join(save_root, "last_model.pt")

  for epoch in range(1, epochs+1):
    global_step, _, _ = train_one_epoch(model, train_loader, optim, scheduler, writer, epoch, device,global_step,loss_fn,scaler)
    val_loss, val_acc = validate(model, val_loader, writer, epoch, device, loss_fn)
    save_checkpoint(model, last_path)
    if val_acc > best_val_acc:
      best_val_acc = val_acc
      save_checkpoint(model, best_path)
      print(f"Saved new best model (acc={val_acc:.4f})")

  print(f"Training complete. Best validation accuracy: {best_val_acc:.4f}")
  return model




##Data Loader

In [None]:
import h5py
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [None]:
class LMDataset(Dataset):
  """
  Map-style dataset over one split of an HDF5 file of log-mel features.

  Expects layout like:
    /train/logmel      -> float (N, T, F) or (N, F, T)
    /train/label       -> int64  (N,)           [optional]
    /train/length      -> int64  (N,)           [optional, original T]
  Equivalent for /val, and /test

  The file is opened lazily on first item access. Open HDF5 handles are dropped
  when the dataset is pickled, so a copy sent to a DataLoader worker reopens
  the file itself.
  """

  # attributes holding live h5py objects; never pickled
  _HANDLE_ATTRS = ("_h5", "_grp", "_X", "_Y", "_L")

  def __init__(self,
                h5_path,
                split="train",
                feature_key="logmel",
                label_key="label",
                length_key="length",
                time_dim="F",   # "T" if feature dataset is (N, T, F); "F" if (N, F, T)
                dtype=np.float32):
    super().__init__()
    self.h5_path = h5_path
    self.split = split
    self.feature_key = feature_key
    self.label_key = label_key
    self.length_key = length_key
    self.time_dim = time_dim
    self.dtype = dtype

    # read the size and the available keys once, then close the file again
    with h5py.File(self.h5_path, "r") as f:
      grp = f[self.split]
      self.N = grp[self.feature_key].shape[0]
      self.has_labels = self.label_key in grp
      self.has_lengths = self.length_key in grp

    self._clear_handles()

  def _clear_handles(self):
    """Forget all h5py objects so the next access reopens the file."""
    for name in self._HANDLE_ATTRS:
      setattr(self, name, None)

  def __getstate__(self):
    """Pickle everything except live h5py handles, which cannot be pickled."""
    state = self.__dict__.copy()
    for name in self._HANDLE_ATTRS:
      state[name] = None
    return state

  def close(self):
    """Close the HDF5 file if it is open; the dataset can still be used afterwards."""
    handle = getattr(self, "_h5", None)
    if handle is not None:
      handle.close()
    self._clear_handles()

  def __len__(self):
    return self.N

  def _ensure_open(self):
    """Open the file and cache the split's datasets on first use."""
    if self._h5 is None:
      self._h5 = h5py.File(self.h5_path, "r", swmr=True, libver="latest")
      self._grp = self._h5[self.split]
      self._X = self._grp[self.feature_key]
      self._Y = self._grp[self.label_key] if self.has_labels else None
      self._L = self._grp[self.length_key] if self.has_lengths else None

  def __getitem__(self, idx):
    """
    Returns:
        X:     float tensor (1, T, F)
        y:     long scalar tensor, or None when the split has no labels
        t_len: long scalar tensor, number of valid frames (never more than T)
    Raises:
        ValueError: if the stored item is not 2-D.
    """
    self._ensure_open()
    X = np.array(self._X[idx], dtype=self.dtype)
    if X.ndim != 2:
      raise ValueError(f"{self.split}/{self.feature_key} must be 2D per item; got {X.shape}")

    if self.time_dim == "F":    # data stored (F, T)
      X = X.T                   # transpose to (T, F)

    X = torch.from_numpy(X[None, ...])  # add channel dim for Conv2D: (1, T, F)

    y = torch.tensor(self._Y[idx], dtype=torch.long) if self.has_labels else None

    # Prefer the saved length, but never report more frames than are stored:
    # a larger value would make pack_padded_sequence fail downstream.
    n_frames = X.shape[1]
    t_len = min(int(self._L[idx]), n_frames) if self.has_lengths else n_frames
    t_len = torch.tensor(t_len, dtype=torch.long)

    return X, y, t_len

  def _pad_collate(self,batch, pad_value=-80.0):
    """
    Right-pad a list of (X, y, t_len) items along time to the longest item.

    Returns:
        X:       (B, 1, Tmax, F), padding filled with ``pad_value``
        y:       (B,) stacked labels, or None when items carry no labels
        lengths: (B,) stacked valid lengths
    """
    xs, ys, lens = zip(*batch)
    batch_size = len(xs)
    max_frames = max(int(x.shape[1]) for x in xs)
    n_feats = xs[0].shape[2]

    X = xs[0].new_full((batch_size, 1, max_frames, n_feats), fill_value=pad_value)
    for i, x in enumerate(xs):
      X[i, :, :x.shape[1], :] = x

    y = torch.stack(ys, dim=0) if ys[0] is not None else None
    lengths = torch.stack(lens, dim=0)

    return X, y, lengths


def build_h5_loaders(
    h5_path: str,
    splits=("train", "val", "test"),
    feature_key="logmel",
    label_key="label",
    length_key="length",
    time_dim="F",           # "T" if (N,T,F); "F" if (N,F,T)
    batch_sizes=None,
    num_workers=8,
    pad_value=-80.0,
    pin_memory=True,
    persistent_workers=True,
    shuffle_train=True,
    ):
  """
  Returns a dict of {split: DataLoader}. Only creates loaders for existing splits in the file.

  Args:
      batch_sizes: optional {split: batch size}; missing entries default to 32.
      pad_value: fill value used when padding items of a batch to equal length.
      shuffle_train: shuffle the "train" split only; other splits keep file order.
      persistent_workers: ignored when num_workers is 0.
  """
  from functools import partial  # local import keeps this block self-contained

  loaders = {}

  if batch_sizes is None:
    batch_sizes = {s: 32 for s in splits}

  with h5py.File(h5_path, "r") as f:
    available = set(f.keys())  # top-level groups (e.g., "train", "val", "test")

  for split in splits:
    if split not in available:
      continue

    dataset = LMDataset(
        h5_path=h5_path,
        split=split,
        feature_key=feature_key,
        label_key=label_key,
        length_key=length_key,
        time_dim=time_dim,
    )

    # partial binds *this* split's dataset (a lambda would late-bind the loop
    # variable) and, unlike a lambda, can be pickled for spawn-based workers.
    collate = partial(dataset._pad_collate, pad_value=pad_value)

    loaders[split] = DataLoader(
        dataset,
        batch_size=batch_sizes.get(split, 32),
        shuffle=(shuffle_train and split == "train"),
        collate_fn=collate,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers if num_workers > 0 else False,
    )

  return loaders




##Inference Pipeline

A pipeline that takes a multi-second voice recording and splits it into a batch of 10-millisecond snippets, which are then fed into the model in parallel.

The average of the per-snippet decisions should then be returned (where P=1, F=0, so that if half of the inferences return False the average is 0.5).

Probably only a subset of the snippets should be processed, so that we are not running inference on 500 samples.

##Implementation

In [None]:
# from google.colab import drive  # type: ignore

# drive.mount("/content/drive",force_remount=True)

In [None]:
import yaml
def get_yaml_meta(h5_path):
    """Return the parsed YAML document stored at /meta/file_description.yaml in the HDF5 file."""
    with h5py.File(h5_path, "r") as h5_file:
        payload = h5_file["/meta/file_description.yaml"][()]
    # the value is fully read into memory, so it can be decoded after the file is closed
    text = payload.decode("utf-8")
    return yaml.safe_load(text)


In [None]:
from torch.utils.tensorboard import SummaryWriter
import time, os
# Project folder on the mounted Drive.
working_directory = "/content/drive/MyDrive/ML_PW"
# One timestamped sub-folder per run under the local log root, so runs show up
# as separate entries in TensorBoard.
log_dir = os.path.join("/content/logs", time.strftime("%Y%m%d-%H%M%S"))
# Writer handed to train_model for all scalar logging.
writer = SummaryWriter(log_dir=log_dir)


In [None]:
# Reload the TensorBoard notebook extension and open the dashboard on the local log root.
%reload_ext tensorboard
%tensorboard --logdir /content/logs


In [None]:
# List the project folder on the mounted Drive.
!ls /content/drive/MyDrive/ML_PW/

In [None]:
import yaml
def get_yaml_meta(h5_path):
    """Return the parsed YAML document stored at /meta/file_description.yaml in the HDF5 file."""
    with h5py.File(h5_path, "r") as h5_file:
        payload = h5_file["/meta/file_description.yaml"][()]
    # the value is fully read into memory, so it can be decoded after the file is closed
    text = payload.decode("utf-8")
    return yaml.safe_load(text)

dataset_adress = "/content/drive/MyDrive/ML_PW/outputs/logmels_volnorm_silrem_04-12-25.h5"

meta = get_yaml_meta(dataset_adress)
no_mels = meta["n_mels"]                        # stored key name
no_speakers = meta["num_speakers"]              # stored key name
# yaml.safe_load returns a real bool for `frequency: True`, so comparing against
# the string "True" was always False. Accept both a bool and its string spelling.
frequency_first = str(meta.get("frequency", True)).strip().lower() == "true"
# NOTE(review): frequency_first is not consumed below; time_dim stays hard-coded
# to "F". Confirm the flag's meaning before wiring it into build_h5_loaders.


loaders = build_h5_loaders(
    dataset_adress,
    splits=("train","val","test"),
    feature_key="logmel",
    label_key="label",                # set to a missing key if you have no labels
    length_key="length",              # optional in file; if absent, we infer from T
    time_dim="F",                     # set "F" if stored as (N,F,T)
    num_workers=2
)
train_loader = loaders["train"]
val_loader   = loaders["val"]


# NOTE(review): this drive_dir differs from the ML_PW/spk_checkpoints folder the
# loading cell reads from; confirm which location is intended.
model = train_model(train_loader,val_loader,no_mels, no_speakers,writer,use_drive=True,drive_dir="/content/drive/MyDrive/spk_checkpoints")

In [None]:
# Copy the local TensorBoard logs to the project folder on Drive.
!cp -r /content/logs "/content/drive/MyDrive/ML_PW/"

#Loading model from drive

In [None]:
# Rebuild the model with the same hyper-parameters used in train_model and load
# the saved weights on CPU.
# NOTE(review): this dataset file is dated 04-11-25 while the training cell uses
# 04-12-25; confirm both describe the same n_mels / num_speakers.
dataset_adress = "/content/drive/MyDrive/ML_PW/outputs/logmels_volnorm_silrem_04-11-25.h5"

meta = get_yaml_meta(dataset_adress)
no_mels = meta["n_mels"]
no_speakers = meta["num_speakers"]
backbone = Backbone(no_mels=no_mels,
                    embed_dim=256,
                    rnn_hidden=256,
                    rnn_layers=2,
                    bidir=True)

model = SpeakerClassifier(backbone,
                          num_speakers=no_speakers,
                          aamsm_scaler=30.0,
                          aamsm_margin=0.25)

# NOTE(review): prepare_save_dir returns a numbered train<N> sub-folder, so the
# checkpoint may not sit directly in spk_checkpoints; verify this path.
state_dict = torch.load("/content/drive/MyDrive/ML_PW/spk_checkpoints/best_model.pt",
                        map_location="cpu")
model.load_state_dict(state_dict)

In [None]:
# SpeakerClassifier.test() looks up the 'test' entry itself, so pass the whole dict.
model.test(loaders)

In [None]:
test_recording_path = "/content/test-recording-michal.m4a"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The inputs are created on `device`, so the model must be there too; eval()
# builds the speaker bank that infer() scores against.
model.to(device)
model.eval()
results = []
X, lengths = logmels_to_batch(process_audio_to_logmels(test_recording_path),device)
for i in range(X.size(0)): # per-item verify
  x_single = X[i:i+1] # keep batch dim [1,1,T,F]
  l_single = lengths[i:i+1] # keep batch dim [1]
  pred_ids, scores, decisions = model.infer( x=x_single, lengths=l_single, threshold=0.7 )
  results.append(( int(pred_ids.item()), # predicted speaker index
                  float(scores.item()),)) # similarity to that speaker's prototype

for res in results:
  print(f"predicted id: {res[0]} | similarity: {res[1]}")

### Export Model to ONNX

This section exports the `backbone` of the trained `SpeakerClassifier` model to ONNX format. The backbone is responsible for generating speaker embeddings, which is typically what you need for inference tasks like speaker verification or identification.

The export process requires a dummy input tensor that matches the expected shape of the model's input (log-mel spectrograms). We also specify `dynamic_axes` to allow for variable batch sizes and time steps in the ONNX model, making it more flexible for inference with different input lengths.

In [None]:
import torch.onnx

def export_to_onnx(model, onnx_path, n_mels, dummy_time_steps=32):
    """
    Exports the model's backbone to ONNX format.

    Args:
        model (SpeakerClassifier): The trained PyTorch SpeakerClassifier model.
        onnx_path (str): The path where the ONNX model will be saved.
        n_mels (int): The number of mel bands used for spectrograms.
        dummy_time_steps (int): A representative number of time steps for the dummy input.
    """
    model.eval()  # eval mode: BatchNorm uses running stats, backbone dropout is off

    # Put the dummy inputs where the weights actually are. model.device is fixed
    # at construction time and can disagree with the real placement (for example
    # a checkpoint loaded on CPU on a machine that has a GPU).
    export_device = next(model.backbone.parameters()).device

    # The input to the backbone is [B, 1, T, F]: batch 1, 1 channel,
    # dummy_time_steps frames, n_mels features.
    dummy_input_x = torch.randn(1, 1, dummy_time_steps, n_mels, device=export_device)
    dummy_lengths = torch.tensor([dummy_time_steps], dtype=torch.long, device=export_device)

    # Names of the graph inputs/outputs in the ONNX file
    input_names = ["input_features", "input_lengths"]
    output_names = ["output_embedding"]

    # Batch size and number of time steps stay dynamic in the exported graph
    dynamic_axes = {
        "input_features": {0: "batch_size", 2: "time_steps"},
        "input_lengths": {0: "batch_size"},
        "output_embedding": {0: "batch_size"},
    }

    print(f"Exporting model to ONNX at: {onnx_path}")
    torch.onnx.export(
        model.backbone,  # export only the backbone
        (dummy_input_x, dummy_lengths),  # positional inputs: (x, lengths)
        onnx_path,
        verbose=False,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=11
    )
    print("ONNX export complete!")

# Define the output path for the ONNX model
# Destination of the exported backbone on the mounted Drive.
onnx_output_path = "/content/drive/MyDrive/ML_PW/speaker_classifier_backbone.onnx"

# Export the backbone of the currently loaded `model`; `no_mels` comes from the
# dataset metadata read in an earlier cell.
export_to_onnx(model, onnx_output_path, no_mels)
