We train the Siamese model on multiple randomly sampled short audio segments per song. Positive pairs are drawn both from temporally corresponding regions between originals and covers, and from non-aligned regions to encourage robustness. This allows the model to first learn harmonic similarity and then generalize across structural variation, while negatives are sampled from different songs to enforce discrimination.

In [None]:
# ==============================================================================
# 1. SETUP & DEPENDENCIES
# ==============================================================================
import os
import subprocess
import sys

# Install necessary libraries if not present
packages = ["audiomentations", "torchaudio"]
for package in packages:
    try:
        __import__(package)
    except ImportError:
        print(f"üì¶ Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import random
import glob
import time
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from audiomentations import Compose, AddBackgroundNoise, PitchShift, TimeStretch, Gain, PolarityInversion

# Mount Drive
from google.colab import drive
drive.mount('/content/drive')

# ==============================================================================
 #DATA INGESTION (Drive -> Local)
# ==============================================================================
print("\nüöÄ STARTING DATA EXTRACTION...")

# Define Source (Drive) and Destination (Local)
ZIP_SOURCE_DIR = '/content/drive/MyDrive/FINE_TUNE_V3'
LOCAL_BASE_DIR = '/content/data'

# Paths for specific zips on Drive
ZIP_ORIGINALS = os.path.join(ZIP_SOURCE_DIR, 'train_originals_1300.zip')
ZIP_COVERS = os.path.join(ZIP_SOURCE_DIR, 'train_covers_1300.zip')
ZIP_NOISE_16K = os.path.join(ZIP_SOURCE_DIR, 'noise_data_16k.zip')

# Define Extraction Destinations
DIR_ORIGINALS = os.path.join(LOCAL_BASE_DIR, 'originals')
DIR_COVERS = os.path.join(LOCAL_BASE_DIR, 'covers')
DIR_NOISE_16K = os.path.join(LOCAL_BASE_DIR, 'noise_16k')

# Create Directories
for d in [DIR_ORIGINALS, DIR_COVERS, DIR_NOISE_16K]:
    os.makedirs(d, exist_ok=True)

# Unzip Functions
def safe_unzip(zip_path, dest_path, desc):
    if not os.listdir(dest_path):
        print(f"üìÇ Unzipping {desc}...")
        subprocess.run(f"unzip -q -n '{zip_path}' -d '{dest_path}'", shell=True)
    else:
        print(f"‚úÖ {desc} already extracted.")

safe_unzip(ZIP_ORIGINALS, DIR_ORIGINALS, "Training Originals")
safe_unzip(ZIP_COVERS, DIR_COVERS, "Training Covers")
safe_unzip(ZIP_NOISE_16K, DIR_NOISE_16K, "Noise Data")

# --- NOISE CONVERSION (MP3 -> WAV) ---
# Ensures multi-thread safety for DataLoader
def convert_mp3_to_wav(noise_dir):
    mp3_files = glob.glob(os.path.join(noise_dir, '**', '*.mp3'), recursive=True)
    if not mp3_files:
        return
    print(f"üîÑ Converting {len(mp3_files)} noise MP3s to WAV...")
    for mp3_path in tqdm(mp3_files):
        wav_path = mp3_path.replace('.mp3', '.wav')
        try:
            waveform, sr = torchaudio.load(mp3_path)
            torchaudio.save(wav_path, waveform, sr)
            os.remove(mp3_path)
        except Exception:
            continue

convert_mp3_to_wav(DIR_NOISE_16K)

# ==============================================================================
# 3. ROBUST CSV MAPPING & VALIDATION
# ==============================================================================
def load_verified_pairs(csv_path, originals_dir, covers_dir):
    """
    Reads the CSV and validates that the file pairs actually exist on the local disk.
    Adapted for columns: 'original_filename', 'augmented_filename'
    """
    if not os.path.exists(csv_path):
        print(f"‚ùå Error: CSV not found at {csv_path}")
        return [], {}

    print(f"üìñ Reading CSV: {csv_path}")
    df = pd.read_csv(csv_path)

    valid_anchors = []
    pair_map = {}
    missing_count = 0

    # --- DETECTED COLUMNS FROM YOUR INFO ---
    col_orig = 'original_filename'
    col_pair = 'augmented_filename'

    print(f"üîç Validating {len(df)} pairs...")
    print(f"   Anchor Col: '{col_orig}'")
    print(f"   Pair Col:   '{col_pair}'")

    for index, row in df.iterrows():
        orig_name = row[col_orig]
        pair_name = row[col_pair]

        # 1. Check Anchor in Originals Dir
        path_orig = os.path.join(originals_dir, str(orig_name))

        # 2. Check Pair in Covers Dir (or fallback to the absolute path column if present)
        path_pair = os.path.join(covers_dir, str(pair_name))

        # Fallback: If not in covers_dir, checks if 'path' column has a valid full path
        if not os.path.exists(path_pair) and 'path' in row:
             if os.path.exists(row['path']):
                 path_pair = row['path']

        # 3. Validate existence
        if os.path.exists(path_orig) and os.path.exists(path_pair):
            valid_anchors.append(orig_name)
            pair_map[orig_name] = pair_name
        else:
            missing_count += 1
            if missing_count <= 5:
                # Debug print to help you see WHICH path is wrong
                if not os.path.exists(path_orig):
                    print(f"   ‚ö†Ô∏è Missing Anchor: {path_orig}")
                if not os.path.exists(path_pair):
                    print(f"   ‚ö†Ô∏è Missing Pair:   {path_pair}")

    print("-" * 40)
    print(f"‚úÖ Found {len(valid_anchors)} valid pairs locally.")
    if missing_count > 0:
        print(f"‚ùå Skipped {missing_count} pairs (files missing).")

    return valid_anchors, pair_map

# --- EXECUTE PAIRING ---
CSV_PATH = "/content/drive/Othercomputers/My laptop/Desktop/FINE-TUNE/Data/dataset_tracking.csv"
DIR_ORIGINALS = '/content/data/originals'
DIR_COVERS = '/content/data/covers' # Points to where 'augmented_filename' files live
NOISE_PATH_FINAL = DIR_NOISE_16K

# Run Validation
VALID_ANCHORS, PAIR_MAP = load_verified_pairs(CSV_PATH, DIR_ORIGINALS, DIR_COVERS)

In [None]:
import os
import subprocess

# Define Paths
ZIP_SOURCE_DIR = '/content/drive/MyDrive/FINE_TUNE_V3'
ZIP_EVAL = os.path.join(ZIP_SOURCE_DIR, 'eval_originals_300.zip')
DIR_EVAL = '/content/data/eval'

# Unzip
if not os.path.exists(DIR_EVAL):
    os.makedirs(DIR_EVAL, exist_ok=True)

print(f"üìÇ Unzipping Evaluation Set to {DIR_EVAL}...")
if os.path.exists(ZIP_EVAL):
    subprocess.run(f"unzip -q -n '{ZIP_EVAL}' -d '{DIR_EVAL}'", shell=True)
    num_files = len(os.listdir(DIR_EVAL))
    print(f"‚úÖ Success! Found {num_files} wav files ready for evaluation.")
else:
    print(f"‚ùå Error: Could not find {ZIP_EVAL}. Check your Drive.")

In [None]:
import torch
import torchaudio
import numpy as np
import random
import os
from torch.utils.data import Dataset
from audiomentations import (
    Compose,
    AddBackgroundNoise,
    PitchShift,
    TimeStretch,
    Gain,
    PolarityInversion
)
import torch.nn.functional as F


class DualObjectiveSiameseDataset(Dataset):
    """
    Triplet dataset with:
    - multiple stochastic 3s crops per song
    - aligned + unaligned cover positives
    - self-invariance task
    """

    def __init__(
        self,
        anchor_list,
        pair_map,
        originals_dir,
        covers_dir,
        noise_dir,
        sample_rate=16000,
        duration=3.0,
        aligned_cover_prob=0.6,
        max_align_jitter_sec=2.0
    ):
        self.anchor_list = anchor_list
        self.pair_map = pair_map
        self.originals_dir = originals_dir
        self.covers_dir = covers_dir

        self.sample_rate = sample_rate
        self.num_samples = int(sample_rate * duration)
        self.aligned_cover_prob = aligned_cover_prob
        self.max_align_jitter = int(max_align_jitter_sec * sample_rate)

        self.num_songs = len(anchor_list)

        # Augmentation
        self.augment = Compose([
            AddBackgroundNoise(
                sounds_path=noise_dir,
                min_snr_db=3.0,
                max_snr_db=15.0,
                p=0.8
            ),
            Gain(min_gain_db=-6.0, max_gain_db=6.0, p=0.2),
            PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
            TimeStretch(min_rate=0.9, max_rate=1.1, p=0.4),
            PolarityInversion(p=0.2),
        ])

        # Spectrogram
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=1024,
            hop_length=512,
            n_mels=64
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB()

    # --------------------------------------------------
    def _load_audio(self, path):
        try:
            wav, sr = torchaudio.load(path)
        except Exception:
            return None

        if sr != self.sample_rate:
            wav = torchaudio.transforms.Resample(sr, self.sample_rate)(wav)

        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        return wav

    # --------------------------------------------------
    def _crop(self, waveform, start_sample):
        total_len = waveform.shape[1]

        if total_len < self.num_samples:
            waveform = F.pad(waveform, (0, self.num_samples - total_len))
            return waveform[:, :self.num_samples]

        start_sample = max(0, min(start_sample, total_len - self.num_samples))
        return waveform[:, start_sample:start_sample + self.num_samples]

    # --------------------------------------------------
    def _to_spec(self, audio_np):
        tensor = torch.from_numpy(audio_np).unsqueeze(0)
        spec = self.mel_transform(tensor)
        spec = self.db_transform(spec)
        spec = (spec - spec.mean()) / (spec.std() + 1e-6)
        return spec

    # --------------------------------------------------
    def __getitem__(self, idx):

        # -------- task selection --------
        if idx < self.num_songs:
            real_idx = idx
            is_cover_task = True
        else:
            real_idx = idx - self.num_songs
            is_cover_task = False

        anchor_name = self.anchor_list[real_idx]
        anchor_path = os.path.join(self.originals_dir, anchor_name)

        anchor_wav = self._load_audio(anchor_path)
        if anchor_wav is None:
            anchor_wav = torch.zeros(1, self.num_samples)

        total_len = anchor_wav.shape[1]
        anchor_start = (
            0 if total_len <= self.num_samples
            else random.randint(0, total_len - self.num_samples)
        )

        anchor_crop = self._crop(anchor_wav, anchor_start)
        anchor_raw = anchor_crop.squeeze(0).numpy()

        # -------- positive --------
        if is_cover_task:
            cover_name = self.pair_map[anchor_name]
            cover_path = os.path.join(self.covers_dir, cover_name)
            cover_wav = self._load_audio(cover_path)

            if cover_wav is None:
                positive_raw = anchor_raw.copy()
            else:
                if random.random() < self.aligned_cover_prob:
                    jitter = random.randint(-self.max_align_jitter, self.max_align_jitter)
                    pos_start = anchor_start + jitter
                else:
                    pos_start = random.randint(
                        0,
                        max(0, cover_wav.shape[1] - self.num_samples)
                    )

                pos_crop = self._crop(cover_wav, pos_start)
                positive_raw = pos_crop.squeeze(0).numpy()
        else:
            # self-invariance
            positive_raw = anchor_raw.copy()

        try:
            positive_aug = self.augment(samples=positive_raw, sample_rate=self.sample_rate)
        except Exception:
            positive_aug = positive_raw

        # -------- negative --------
        neg_idx = random.randint(0, self.num_songs - 1)
        while neg_idx == real_idx:
            neg_idx = random.randint(0, self.num_songs - 1)

        neg_path = os.path.join(self.originals_dir, self.anchor_list[neg_idx])
        neg_wav = self._load_audio(neg_path)

        if neg_wav is None:
            negative_raw = anchor_raw.copy()
        else:
            neg_start = random.randint(
                0,
                max(0, neg_wav.shape[1] - self.num_samples)
            )
            neg_crop = self._crop(neg_wav, neg_start)
            negative_raw = neg_crop.squeeze(0).numpy()

        try:
            negative_aug = self.augment(samples=negative_raw, sample_rate=self.sample_rate)
        except Exception:
            negative_aug = negative_raw

        return (
            self._to_spec(anchor_raw),
            self._to_spec(positive_aug),
            self._to_spec(negative_aug),
        )

    def __len__(self):
        return self.num_songs * 2


In [None]:
# ==============================================================================
# üß± BLOCK 5: FINAL TRAINING LOOP (WITH MODEL B)
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import os

# --- 1. MODEL DEFINITION (MODEL B - CHOSEN) ---
class AudioSiameseNet(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()

        # ----------- CONV BLOCK 1 -----------
        # Input: (Batch, 1, 64, ~94)
        """
        1 ‚Üí mono
        64 ‚Üí mel bins
        ~94 ‚Üí time frames (3 sec @ hop 512)

        Conv Block 1 (learns Local time‚Äìfrequency edges)
          Conv2d(1 ‚Üí 32)
          BatchNorm
          ReLU
          MaxPool(2√ó2) (Early pooling removes noise)

        Conv Block 2 (learns Harmonic stacks)
          Conv2d(32 ‚Üí 64)
          BatchNorm
          ReLU
          MaxPool(2√ó2)

        Conv Block 3 (no pooling)
          Conv2d(64 ‚Üí 128)
          BatchNorm
          ReLU
        Why no MaxPool here?
          Pooling here would destroy: melody contour | rhythmic micro-patterns
          At this stage, the model is learning song identity, not noise suppression

        Global Average Pooling
          AdaptiveAvgPool2d((1,1)) (B, 128, H, W) ‚Üí (B, 128)

        Projection head (embedding layer)
          Linear(128 ‚Üí 256) (Increase representational capacity)
          ReLU
          Dropout(0.3)
          Linear(256 ‚Üí embed_dim)

        F.normalize(x, p=2)
        This enforces:||embedding|| = 1
            So:
            cosine similarity = dot product
            perfect for FAISS
            stable for contrastive / triplet loss

        """
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)          # Downsample: 32 mels, 47 time
        )

        # ----------- CONV BLOCK 2 -----------
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)          # Downsample: 16 mels, 23 time
        )

        # ----------- CONV BLOCK 3 -----------
        # No MaxPool here to preserve melody resolution
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
            # Output: (Batch, 128, 16, 23)
        )

        # ----------- GLOBAL POOL -----------
        # Averages the feature map into a single vector
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # ----------- PROJECTION HEAD -----------
        self.fc = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),       # Safety against noise overfitting
            nn.Linear(256, embed_dim),
        )

    def forward_one(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.global_pool(x)          # -> (B, 128, 1, 1)
        x = x.view(x.size(0), -1)        # -> (B, 128)
        x = self.fc(x)                   # -> (B, embed_dim)
        return F.normalize(x, p=2, dim=1)

    def forward(self, anchor, positive, negative,*args):
        return self.forward_one(anchor), self.forward_one(positive), self.forward_one(negative)


In [None]:
# ==============================================================================
# üß± SHALLOW CNN TRAINING LOOP (TRIPLET LOSS + RESUME + AMP)
# ==============================================================================

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler

def train_siamese_network():
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nüî• Training on: {DEVICE}")

    # ---------------- CONFIG ----------------
    BATCH_SIZE = 64
    EPOCHS = 100
    LR = 1e-4
    MARGIN = 0.75
    PATIENCE = 4

    # ---------------- OUTPUT PATHS ----------------
    BASE_DIR = "/content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN"
    os.makedirs(BASE_DIR, exist_ok=True)

    BEST_MODEL_PATH = os.path.join(BASE_DIR, "best_shallow_cnn.pth")
    LATEST_CKPT_PATH = os.path.join(BASE_DIR, "checkpoint_latest.pth")

    print(f"üìÇ Models will be saved to: {BASE_DIR}")

    # ---------------- DATASET ----------------
    if "VALID_ANCHORS" not in globals() or not VALID_ANCHORS:
        raise RuntimeError("‚ùå VALID_ANCHORS not found. Run CSV mapping first.")

    dataset = DualObjectiveSiameseDataset(
        anchor_list=VALID_ANCHORS,
        pair_map=PAIR_MAP,
        originals_dir="/content/data/originals",
        covers_dir="/content/data/covers",
        noise_dir="/content/data/noise_16k",
        sample_rate=16000,
        duration=3.0
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    print(f"‚úÖ Dataset ready: {len(dataset)} samples")

    # ---------------- MODEL ----------------
    model = AudioSiameseNet().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.TripletMarginLoss(margin=MARGIN)

    scaler = GradScaler()
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=PATIENCE,
        min_lr=1e-7
    )

    # ---------------- RESUME LOGIC ----------------
    start_epoch = 0
    best_loss = float("inf")

    if os.path.exists(LATEST_CKPT_PATH):
        print(f"üîÑ Resuming from {LATEST_CKPT_PATH}")
        ckpt = torch.load(LATEST_CKPT_PATH, map_location=DEVICE)

        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])

        start_epoch = ckpt["epoch"] + 1
        best_loss = ckpt["best_loss"]

        print(f"   ‚ñ∂ Resumed at epoch {start_epoch}")
        print(f"   ‚ñ∂ Best loss so far: {best_loss:.4f}")
    else:
        print("üÜï No checkpoint found. Starting fresh training.")

    # ---------------- TRAINING ----------------
    print("üöÄ Starting Shallow CNN Training...")

    for epoch in range(start_epoch, EPOCHS):
        model.train()
        running_loss = 0.0

        for batch_idx, (anc, pos, neg) in enumerate(dataloader):
            anc = anc.to(DEVICE, non_blocking=True)
            pos = pos.to(DEVICE, non_blocking=True)
            neg = neg.to(DEVICE, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with autocast():
                emb_a, emb_p, emb_n = model(anc, pos, neg)
                loss = criterion(emb_a, emb_p, emb_n)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            if batch_idx % 20 == 0:
                print(
                    f"   Epoch {epoch+1} | Batch {batch_idx}/{len(dataloader)} "
                    f"| Loss: {loss.item():.4f}"
                )

        avg_loss = running_loss / len(dataloader)
        print(f"\nüì¢ Epoch {epoch+1}/{EPOCHS} | Avg Loss: {avg_loss:.4f}")

        # ---------------- LR Scheduler ----------------
        old_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_loss)
        new_lr = optimizer.param_groups[0]["lr"]

        if new_lr != old_lr:
            print(f"üìâ LR reduced: {old_lr:.2e} ‚Üí {new_lr:.2e}")

        # ---------------- SAVE BEST ----------------
        if avg_loss < best_loss:
            print(f"‚≠ê New BEST: {best_loss:.4f} ‚Üí {avg_loss:.4f}")
            best_loss = avg_loss
            torch.save(model.state_dict(), BEST_MODEL_PATH)

        # ---------------- SAVE CHECKPOINTS ----------------
        epoch_ckpt_path = os.path.join(BASE_DIR, f"checkpoint_epoch_{epoch}.pth")

        ckpt = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": avg_loss,
            "best_loss": best_loss
        }

        torch.save(ckpt, LATEST_CKPT_PATH)      # rolling checkpoint
        torch.save(ckpt, epoch_ckpt_path)       # backup checkpoint

        print(f"üíæ Saved checkpoint: {epoch_ckpt_path}\n")

if __name__ == "__main__":
    train_siamese_network()


üî• Training on: cuda
üìÇ Models will be saved to: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN
‚úÖ Dataset ready: 1900 samples
üÜï No checkpoint found. Starting fresh training.
üöÄ Starting Shallow CNN Training...


  scaler = GradScaler()
  with autocast():


   Epoch 1 | Batch 0/30 | Loss: 0.7440
   Epoch 1 | Batch 20/30 | Loss: 0.7497

üì¢ Epoch 1/100 | Avg Loss: 0.7470
‚≠ê New BEST: inf ‚Üí 0.7470
üíæ Saved checkpoint: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_0.pth

   Epoch 2 | Batch 0/30 | Loss: 0.7324
   Epoch 2 | Batch 20/30 | Loss: 0.7319

üì¢ Epoch 2/100 | Avg Loss: 0.7439
‚≠ê New BEST: 0.7470 ‚Üí 0.7439
üíæ Saved checkpoint: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_1.pth

   Epoch 3 | Batch 0/30 | Loss: 0.7440
   Epoch 3 | Batch 20/30 | Loss: 0.7407

üì¢ Epoch 3/100 | Avg Loss: 0.7426
‚≠ê New BEST: 0.7439 ‚Üí 0.7426
üíæ Saved checkpoint: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_2.pth

   Epoch 4 | Batch 0/30 | Loss: 0.7451
   Epoch 4 | Batch 20/30 | Loss: 0.7343

üì¢ Epoch 4/100 | Avg Loss: 0.7383
‚≠ê New BEST: 0.7426 ‚Üí 0.7383
üíæ Saved checkpoint: /content/drive/MyDrive/FIND_TUNE/spectro

KeyboardInterrupt: 

EVAL

In [None]:
# ==============================================================================
# üß™ BLOCK 8: EVALUATE CHECKPOINT 34
# ==============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import soundfile as sf
import numpy as np
import os, glob, random, math
import gc
from tqdm import tqdm
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from audiomentations import Compose, AddBackgroundNoise, PitchShift, TimeStretch, Gain

# ------------------------------------------------------------------------------
# ‚öôÔ∏è CONFIG
# ------------------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EVAL_DIR = "/content/data/eval"
NOISE_DIR = "/content/data/noise_16k"

# üîë TARGET CHECKPOINT: EPOCH 34
# Update this path if your checkpoints are stored elsewhere
CHECKPOINT_DIR = "/content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN"
MODEL_PATH = os.path.join(CHECKPOINT_DIR, "checkpoint_epoch_34.pth")

SAMPLE_RATE = 16000
WIN_SEC = 3.0
HOP_SEC = 1.5
QUERY_LEN = 15
INFERENCE_BATCH_SIZE = 64
TOLERANCE = 1.5
SIGMA = 0.5
SPREAD_FACTOR = 0.3

# ------------------------------------------------------------------------------
# 1. ARCHITECTURE: SHALLOW CNN
# ------------------------------------------------------------------------------
class AudioSiameseNet(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU()
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(128, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, embed_dim)
        )

    def forward_one(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.pool(x).view(x.size(0), -1)
        return F.normalize(self.fc(x), p=2, dim=1)

# ------------------------------------------------------------------------------
# 2. HELPER: ROBUST LOADING
# ------------------------------------------------------------------------------
def robust_load(path):
    try:
        wav_np, sr = sf.read(path)
        wav_np = wav_np.astype(np.float32)
        wav = torch.from_numpy(wav_np)
        if wav.ndim == 1: wav = wav.unsqueeze(0)
        else: wav = wav.t()
        if wav.shape[0] > 1: wav = wav.mean(0, keepdim=True)
        return wav, sr
    except: return None, 0

# ------------------------------------------------------------------------------
# 3. HELPER: EMBEDDING
# ------------------------------------------------------------------------------
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=1024, hop_length=512, n_mels=64
).to(DEVICE)
db = torchaudio.transforms.AmplitudeToDB().to(DEVICE)

def audio_to_embedding(model, wav):
    samples_win = int(SAMPLE_RATE * WIN_SEC)
    samples_hop = int(SAMPLE_RATE * HOP_SEC)
    if wav.shape[1] < samples_win: wav = F.pad(wav, (0, samples_win - wav.shape[1]))

    windows, times = [], []
    for i in range(0, wav.shape[1] - samples_win + 1, samples_hop):
        windows.append(wav[:, i:i + samples_win])
        times.append(i / SAMPLE_RATE)

    if not windows: return None, None

    all_embeddings = []
    for i in range(0, len(windows), INFERENCE_BATCH_SIZE):
        batch = torch.stack(windows[i : i + INFERENCE_BATCH_SIZE]).to(DEVICE)
        spec = db(mel(batch))
        spec = (spec - spec.mean(dim=(2,3), keepdim=True)) / (spec.std(dim=(2,3), keepdim=True) + 1e-6)
        with torch.no_grad():
            all_embeddings.append(model.forward_one(spec).cpu())

    return torch.cat(all_embeddings), times

# ------------------------------------------------------------------------------
# 4. BUILD DB
# ------------------------------------------------------------------------------
class EvalDataset(Dataset):
    def __init__(self, file_paths): self.files = file_paths
    def __len__(self): return len(self.files)
    def __getitem__(self, idx):
        path = self.files[idx]
        wav, sr = robust_load(path)
        if wav is None: return torch.zeros(1, SAMPLE_RATE), "ERROR"
        if sr != SAMPLE_RATE: wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        return wav, os.path.basename(path)

def build_database(model):
    print("üèóÔ∏è Building vector database...")
    files = glob.glob(os.path.join(EVAL_DIR, "*.wav"))
    loader = DataLoader(EvalDataset(files), batch_size=1, shuffle=False, num_workers=2)
    vectors, metadata = [], []

    model.eval()
    with torch.no_grad():
        for wav, name in tqdm(loader):
            if name[0] == "ERROR": continue
            emb, times = audio_to_embedding(model, wav.squeeze(0))
            if emb is None: continue
            for i, t in enumerate(times):
                vectors.append(emb[i])
                metadata.append({"name": name[0], "offset": t})

    if not vectors: return None, None
    return torch.stack(vectors).to(DEVICE), metadata

# ------------------------------------------------------------------------------
# 5. RUNNER
# ------------------------------------------------------------------------------
def calculate_scores(matches):
    scores = defaultdict(lambda: defaultdict(float))
    for dist, meta, q_t in matches:
        w = math.exp(-(dist**2) / (2 * SIGMA**2))
        if w < 0.01: continue
        b = int(round((meta["offset"] - q_t) / TOLERANCE))
        scores[meta["name"]][b] += w
        scores[meta["name"]][b-1] += w * SPREAD_FACTOR
        scores[meta["name"]][b+1] += w * SPREAD_FACTOR
    return sorted([(k, max(v.values())) for k, v in scores.items()], key=lambda x: x[1], reverse=True)

def run_evaluation(model, db_vecs, db_meta, trials=100):
    if db_vecs is None: return
    db_vecs = db_vecs.to(DEVICE)

    modes = ["clean", "soft", "hard"]
    augmenters = {
        "soft": Compose([Gain(-3, 3, p=0.5), PitchShift(-1, 1, p=0.3)]),
        "hard": Compose([AddBackgroundNoise(NOISE_DIR, 5, 15, p=1.0) if os.path.exists(NOISE_DIR) else Gain(0,0,p=0), PitchShift(-2, 2, p=0.8)])
    }

    results = {m: {"t1":0, "t5":0} for m in modes}
    songs = list(set(m["name"] for m in db_meta))

    print(f"\n‚ö° Eval: {trials} trials/mode...")
    for mode in modes:
        print(f"‚ñ∂ {mode.upper()}")
        aug = augmenters.get(mode, None)
        for _ in tqdm(range(trials)):
            target = random.choice(songs)
            wav, sr = robust_load(os.path.join(EVAL_DIR, target))
            if wav is None: continue
            if sr != SAMPLE_RATE: wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)

            max_len = int(QUERY_LEN * SAMPLE_RATE)
            if wav.shape[1] > max_len:
                s = random.randint(0, wav.shape[1] - max_len)
                wav = wav[:, s:s + max_len]

            if aug:
                try: wav = torch.from_numpy(aug(samples=wav.squeeze(0).numpy(), sample_rate=SAMPLE_RATE)).unsqueeze(0)
                except: pass

            q_emb, q_times = audio_to_embedding(model, wav)
            if q_emb is None: continue

            dists = torch.cdist(q_emb.to(DEVICE), db_vecs)
            vals, idxs = torch.topk(dists, k=5, largest=False)

            matches = []
            for i in range(q_emb.shape[0]):
                for k in range(5):
                    matches.append((vals[i,k].item(), db_meta[idxs[i,k]], q_times[i]))

            ranked = calculate_scores(matches)
            if not ranked: continue

            if target == ranked[0][0]: results[mode]["t1"] += 1
            if target in [x[0] for x in ranked[:5]]: results[mode]["t5"] += 1

    print("\nüèÜ RESULTS (Epoch 34)")
    print(f"{'MODE':<10} | {'TOP-1':<8} | {'TOP-5':<8}")
    print("-" * 32)
    for m in modes:
        print(f"{m.upper():<10} | {results[m]['t1']/trials*100:.1f}%     | {results[m]['t5']/trials*100:.1f}%")

if __name__ == "__main__":
    if not os.path.exists(MODEL_PATH):
        print(f"‚ùå Checkpoint not found: {MODEL_PATH}")
        print(f"   Check your drive path or epoch number.")
    else:
        print(f"üìÇ Loading Checkpoint 34: {MODEL_PATH}")
        model = AudioSiameseNet(embed_dim=128).to(DEVICE)

        # üîë Extract 'model_state_dict' from checkpoint
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            print(f"‚úÖ Weights loaded from Epoch {checkpoint.get('epoch', '?')}")
        else:
            model.load_state_dict(checkpoint) # Legacy/Raw weights fallback

        model.eval()

        db_vecs, db_meta = build_database(model)
        run_evaluation(model, db_vecs, db_meta, trials=100)

üìÇ Loading Checkpoint 34: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_34.pth
‚úÖ Weights loaded from Epoch 34
üèóÔ∏è Building vector database...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 284/284 [00:48<00:00,  5.90it/s]



‚ö° Eval: 100 trials/mode...
‚ñ∂ CLEAN


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:07<00:00, 12.95it/s]


‚ñ∂ SOFT


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:12<00:00,  8.15it/s]


‚ñ∂ HARD


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:14<00:00,  6.82it/s]


üèÜ RESULTS (Epoch 34)
MODE       | TOP-1    | TOP-5   
--------------------------------
CLEAN      | 98.0%     | 100.0%
SOFT       | 97.0%     | 100.0%
HARD       | 39.0%     | 52.0%





COMAPRISON IN BOTH MODEL(shallow CNN and CRNN(CNN+bilstm+attention )) ACCURACY

SHALLOW CNN
        MODE       | TOP-1    | TOP-5   
        
        CLEAN      | 98.0%     | 100.0%
        SOFT       | 97.0%     | 100.0%
        HARD       | 39.0%     | 52.0%

        Audio (3 sec)
          ‚Üì
        Mel Spectrogram (64 √ó ~94)
          ‚Üì
        CNN Block 1
          ‚Üì
        CNN Block 2
          ‚Üì
        CNN Block 3
          ‚Üì
        Global Average Pool
          ‚Üì
        Projection Head
          ‚Üì
        L2-Normalized 128-D Embedding


CRNN
      
      CLEAN  | Top1: 27.0% | Top5: 43.0% | Top10: 48.0%
      SOFT   | Top1: 25.0% | Top5: 32.0% | Top10: 40.0%
      HARD   | Top1: 4.0% | Top5: 13.0% | Top10: 20.0%

      Mel Spectrogram (64 √ó ~94)
              ‚Üì
      CNN  ‚Üí local timbre + pitch invariance
              ‚Üì
      BiLSTM ‚Üí temporal melody progression
              ‚Üì
      Attention ‚Üí focus on salient moments
              ‚Üì
      Projection ‚Üí fixed 128-D embedding


SHALLOW CNN benefits

      Early pooling removes noise

      No temporal modeling = no overthinking

      GlobalAvgPool enforces invariance

      Embedding represents what, not when

Why the BiLSTM CRNN underperforms (despite being ‚Äúsmarter‚Äù)


      BiLSTM introduces ordering sensitivity

      Your CRNN explicitly models:

      melody progression over time

      But in retrieval:
        Covers reorder sections
        Queries start at arbitrary points
        Chorus ‚â† verse
        3 seconds ‚â† meaningful musical sentence

So the LSTM is forced to answer:‚ÄúDoes this sequence match another sequence?‚Äù

But the correct question is:‚ÄúDoes this fragment contain song-identity evidence?‚Äù

‚ÄúAttention focuses on salient moments like chorus‚Äù

      That is correct for classification.

      But for retrieval:
      Problem:

          Attention suppresses ‚Äúboring‚Äù frames
          But boring frames still contain fingerprint info

      In shallow CNN:

          Everything votes equally
          Weak evidence still accumulates

In CRNN:

Attention throws information away

If the ‚Äúsalient‚Äù moment is not present in the 3s clip ‚Üí embedding collapses

This explains:

CLEAN drops from 98% ‚Üí 27%

HARD collapses to 4%

In [None]:

"""import torch
import os

# 1. Define Paths
CHECKPOINT_DIR = "/content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN"
SOURCE_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "checkpoint_epoch_34.pth")
DESTINATION_MODEL = os.path.join(CHECKPOINT_DIR, "best_model.pth")

print(f"üìÇ Source: {SOURCE_CHECKPOINT}")
print(f"üìÇ Dest:   {DESTINATION_MODEL}")

if not os.path.exists(SOURCE_CHECKPOINT):
    print("‚ùå Error: Source checkpoint not found.")
else:
    # 2. Load the checkpoint
    # Map to CPU to avoid GPU OOM if you are just doing file ops
    checkpoint = torch.load(SOURCE_CHECKPOINT, map_location="cpu")

    # 3. Extract the weights
    if "model_state_dict" in checkpoint:
        print(f"‚úÖ Found state dict for Epoch {checkpoint.get('epoch', '?')}")
        clean_weights = checkpoint["model_state_dict"]
    else:
        # If it was already a clean weight file
        clean_weights = checkpoint

    # 4. Save as best_model.pth
    torch.save(clean_weights, DESTINATION_MODEL)
    print(f"üíæ Saved successfully to: {DESTINATION_MODEL}")

    # 5. Verify file size
    size_mb = os.path.getsize(DESTINATION_MODEL) / (1024 * 1024)
    print(f"üì¶ File Size: {size_mb:.2f} MB")"""

üìÇ Source: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_34.pth
üìÇ Dest:   /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/best_model.pth
‚úÖ Found state dict for Epoch 34
üíæ Saved successfully to: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/best_model.pth
üì¶ File Size: 0.62 MB


Model B ‚Äî CRNN (CNN + BiLSTM + Attention)

Initially I used a CNN-based Siamese network where temporal alignment was handled entirely in post-processing via offset clustering, similar to Shazam.
Later, I explored a CRNN architecture with BiLSTM and attention, which pushes temporal modeling into the embedding itself. This improved robustness for noisy and cover-style audio, at the cost of higher latency. In the final design, I treat the CRNN as a high-confidence fallback when classical fingerprinting fails.


Key properties

        Preserves time ordering

        Learns melodic evolution

        Embedding answers: ‚Äúhow does this evolve over time?‚Äù


In [None]:
# ==============================================================================
# 1. SETUP & DEPENDENCIES
# ==============================================================================
import os
import subprocess
import sys

# Install necessary libraries if not present
packages = ["audiomentations", "torchaudio"]
for package in packages:
    try:
        __import__(package)
    except ImportError:
        print(f"üì¶ Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import random
import glob
import time
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from audiomentations import Compose, AddBackgroundNoise, PitchShift, TimeStretch, Gain, PolarityInversion

# Mount Drive
from google.colab import drive
drive.mount('/content/drive')

# ==============================================================================
 #DATA INGESTION (Drive -> Local)
# ==============================================================================
print("\nüöÄ STARTING DATA EXTRACTION...")

# Define Source (Drive) and Destination (Local)
ZIP_SOURCE_DIR = '/content/drive/MyDrive/FINE_TUNE_V3'
LOCAL_BASE_DIR = '/content/data'

# Paths for specific zips on Drive
ZIP_ORIGINALS = os.path.join(ZIP_SOURCE_DIR, 'train_originals_1300.zip')
ZIP_COVERS = os.path.join(ZIP_SOURCE_DIR, 'train_covers_1300.zip')
ZIP_NOISE_16K = os.path.join(ZIP_SOURCE_DIR, 'noise_data_16k.zip')

# Define Extraction Destinations
DIR_ORIGINALS = os.path.join(LOCAL_BASE_DIR, 'originals')
DIR_COVERS = os.path.join(LOCAL_BASE_DIR, 'covers')
DIR_NOISE_16K = os.path.join(LOCAL_BASE_DIR, 'noise_16k')

# Create Directories
for d in [DIR_ORIGINALS, DIR_COVERS, DIR_NOISE_16K]:
    os.makedirs(d, exist_ok=True)

# Unzip Functions
def safe_unzip(zip_path, dest_path, desc):
    if not os.listdir(dest_path):
        print(f"üìÇ Unzipping {desc}...")
        subprocess.run(f"unzip -q -n '{zip_path}' -d '{dest_path}'", shell=True)
    else:
        print(f"‚úÖ {desc} already extracted.")

safe_unzip(ZIP_ORIGINALS, DIR_ORIGINALS, "Training Originals")
safe_unzip(ZIP_COVERS, DIR_COVERS, "Training Covers")
safe_unzip(ZIP_NOISE_16K, DIR_NOISE_16K, "Noise Data")

# --- NOISE CONVERSION (MP3 -> WAV) ---
# Ensures multi-thread safety for DataLoader
def convert_mp3_to_wav(noise_dir):
    mp3_files = glob.glob(os.path.join(noise_dir, '**', '*.mp3'), recursive=True)
    if not mp3_files:
        return
    print(f"üîÑ Converting {len(mp3_files)} noise MP3s to WAV...")
    for mp3_path in tqdm(mp3_files):
        wav_path = mp3_path.replace('.mp3', '.wav')
        try:
            waveform, sr = torchaudio.load(mp3_path)
            torchaudio.save(wav_path, waveform, sr)
            os.remove(mp3_path)
        except Exception:
            continue

convert_mp3_to_wav(DIR_NOISE_16K)

# ==============================================================================
# 3. ROBUST CSV MAPPING & VALIDATION
# ==============================================================================
def load_verified_pairs(csv_path, originals_dir, covers_dir):
    """
    Reads the CSV and validates that the file pairs actually exist on the local disk.
    Adapted for columns: 'original_filename', 'augmented_filename'
    """
    if not os.path.exists(csv_path):
        print(f"‚ùå Error: CSV not found at {csv_path}")
        return [], {}

    print(f"üìñ Reading CSV: {csv_path}")
    df = pd.read_csv(csv_path)

    valid_anchors = []
    pair_map = {}
    missing_count = 0

    # --- DETECTED COLUMNS FROM YOUR INFO ---
    col_orig = 'original_filename'
    col_pair = 'augmented_filename'

    print(f"üîç Validating {len(df)} pairs...")
    print(f"   Anchor Col: '{col_orig}'")
    print(f"   Pair Col:   '{col_pair}'")

    for index, row in df.iterrows():
        orig_name = row[col_orig]
        pair_name = row[col_pair]

        # 1. Check Anchor in Originals Dir
        path_orig = os.path.join(originals_dir, str(orig_name))

        # 2. Check Pair in Covers Dir (or fallback to the absolute path column if present)
        path_pair = os.path.join(covers_dir, str(pair_name))

        # Fallback: If not in covers_dir, checks if 'path' column has a valid full path
        if not os.path.exists(path_pair) and 'path' in row:
             if os.path.exists(row['path']):
                 path_pair = row['path']

        # 3. Validate existence
        if os.path.exists(path_orig) and os.path.exists(path_pair):
            valid_anchors.append(orig_name)
            pair_map[orig_name] = pair_name
        else:
            missing_count += 1
            if missing_count <= 5:
                # Debug print to help you see WHICH path is wrong
                if not os.path.exists(path_orig):
                    print(f"   ‚ö†Ô∏è Missing Anchor: {path_orig}")
                if not os.path.exists(path_pair):
                    print(f"   ‚ö†Ô∏è Missing Pair:   {path_pair}")

    print("-" * 40)
    print(f"‚úÖ Found {len(valid_anchors)} valid pairs locally.")
    if missing_count > 0:
        print(f"‚ùå Skipped {missing_count} pairs (files missing).")

    return valid_anchors, pair_map

# --- EXECUTE PAIRING ---
CSV_PATH = "/content/drive/Othercomputers/My laptop/Desktop/FINE-TUNE/Data/dataset_tracking.csv"
DIR_ORIGINALS = '/content/data/originals'
DIR_COVERS = '/content/data/covers' # Points to where 'augmented_filename' files live
NOISE_PATH_FINAL = DIR_NOISE_16K

# Run Validation
VALID_ANCHORS, PAIR_MAP = load_verified_pairs(CSV_PATH, DIR_ORIGINALS, DIR_COVERS)

Mounted at /content/drive

üöÄ STARTING DATA EXTRACTION...
üìÇ Unzipping Training Originals...
üìÇ Unzipping Training Covers...
üìÇ Unzipping Noise Data...
üîÑ Converting 110 noise MP3s to WAV...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 110/110 [00:00<00:00, 10179.00it/s]


üìñ Reading CSV: /content/drive/Othercomputers/My laptop/Desktop/FINE-TUNE/Data/dataset_tracking.csv
üîç Validating 950 pairs...
   Anchor Col: 'original_filename'
   Pair Col:   'augmented_filename'
----------------------------------------
‚úÖ Found 950 valid pairs locally.


In [None]:
import os
import subprocess

# Define Paths
ZIP_SOURCE_DIR = '/content/drive/MyDrive/FINE_TUNE_V3'
ZIP_EVAL = os.path.join(ZIP_SOURCE_DIR, 'eval_originals_300.zip')
DIR_EVAL = '/content/data/eval'

# Unzip
if not os.path.exists(DIR_EVAL):
    os.makedirs(DIR_EVAL, exist_ok=True)

print(f"üìÇ Unzipping Evaluation Set to {DIR_EVAL}...")
if os.path.exists(ZIP_EVAL):
    subprocess.run(f"unzip -q -n '{ZIP_EVAL}' -d '{DIR_EVAL}'", shell=True)
    num_files = len(os.listdir(DIR_EVAL))
    print(f"‚úÖ Success! Found {num_files} wav files ready for evaluation.")
else:
    print(f"‚ùå Error: Could not find {ZIP_EVAL}. Check your Drive.")

üìÇ Unzipping Evaluation Set to /content/data/eval...
‚úÖ Success! Found 284 wav files ready for evaluation.


In [None]:
import torch
import torchaudio
import numpy as np
import random
import os
from torch.utils.data import Dataset
from audiomentations import (
    Compose,
    AddBackgroundNoise,
    PitchShift,
    TimeStretch,
    Gain,
    PolarityInversion
)
import torch.nn.functional as F


class DualObjectiveSiameseDataset(Dataset):
    """
    Triplet Dataset for CRNN-based Audio Embedding

    Key properties:
    - Multiple random 3-sec clips per song across epochs
    - Aligned + unaligned cover positives (curriculum)
    - Self-invariance task included
    """

    def __init__(
        self,
        anchor_list,
        pair_map,
        originals_dir,
        covers_dir,
        noise_dir,
        sample_rate=16000,
        duration=3.0,
        aligned_cover_prob=0.6,   # üîë aligned vs random cover positives
        max_align_jitter_sec=2.0 # üîë allow loose alignment (¬± seconds)
    ):
        self.anchor_list = anchor_list
        self.pair_map = pair_map
        self.originals_dir = originals_dir
        self.covers_dir = covers_dir

        self.sample_rate = sample_rate
        self.num_samples = int(sample_rate * duration)
        self.max_align_jitter = int(max_align_jitter_sec * sample_rate)

        self.aligned_cover_prob = aligned_cover_prob
        self.num_songs = len(anchor_list)

        # üîä Strong but realistic augmentation
        self.augment = Compose([
            AddBackgroundNoise(
                sounds_path=noise_dir,
                min_snr_db=3.0,
                max_snr_db=15.0,
                p=0.8
            ),
            Gain(min_gain_db=-6.0, max_gain_db=6.0, p=0.2),
            PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
            TimeStretch(min_rate=0.9, max_rate=1.1, p=0.4),
            PolarityInversion(p=0.2),
        ])

        # Spectrogram pipeline (CRNN-friendly)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=1024,
            hop_length=512,
            n_mels=64
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB()

    # ------------------------------------------------------------------
    # Audio loading + stochastic cropping
    # ------------------------------------------------------------------
    def _load_crop_process(self, path, start_sample=None):
        try:
            waveform, sr = torchaudio.load(path)
        except Exception:
            return np.zeros(self.num_samples, dtype=np.float32)

        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        total_len = waveform.shape[1]

        # Pad if too short
        if total_len < self.num_samples:
            waveform = F.pad(waveform, (0, self.num_samples - total_len))
            return waveform.squeeze(0).numpy()

        # Decide crop start
        if start_sample is None:
            start_sample = random.randint(0, total_len - self.num_samples)
        else:
            start_sample = max(0, min(start_sample, total_len - self.num_samples))

        crop = waveform[:, start_sample:start_sample + self.num_samples]
        return crop.squeeze(0).numpy()

    # ------------------------------------------------------------------
    # Waveform ‚Üí Mel Spectrogram
    # ------------------------------------------------------------------
    def _to_spec(self, audio_np):
        tensor = torch.from_numpy(audio_np).unsqueeze(0)
        spec = self.mel_transform(tensor)
        spec = self.db_transform(spec)
        spec = (spec - spec.mean()) / (spec.std() + 1e-6)
        return spec

    # ------------------------------------------------------------------
    # Triplet sampling logic
    # ------------------------------------------------------------------
    def __getitem__(self, idx):

        # --------------------------------------------------------------
        # 1. Task selection (dual objective)
        # --------------------------------------------------------------
        if idx < self.num_songs:
            real_idx = idx
            is_cover_task = True
        else:
            real_idx = idx - self.num_songs
            is_cover_task = False

        anchor_name = self.anchor_list[real_idx]
        anchor_path = os.path.join(self.originals_dir, anchor_name)

        # --------------------------------------------------------------
        # 2. Anchor: RANDOM 3-sec window (every call)
        # --------------------------------------------------------------
        anchor_raw = self._load_crop_process(anchor_path)

        # --------------------------------------------------------------
        # 3. Positive sampling
        # --------------------------------------------------------------
        if is_cover_task:
            cover_name = self.pair_map[anchor_name]
            cover_path = os.path.join(self.covers_dir, cover_name)

            if random.random() < self.aligned_cover_prob:
                # üîë ALIGNED (loose) ‚Äî SAFE VERSION
                anchor_total_len = len(anchor_raw)

                anchor_start = random.randint(
                    0, max(0, anchor_total_len - self.num_samples)
                )

                jitter = random.randint(-self.max_align_jitter, self.max_align_jitter)
                cover_start = anchor_start + jitter

                positive_raw = self._load_crop_process(
                    cover_path,
                    start_sample=cover_start
                )
            else:
                # üîë UNALIGNED
                positive_raw = self._load_crop_process(cover_path)
        else:
            # Self-invariance task
            positive_raw = anchor_raw.copy()

        # Augment positive
        try:
            positive_aug = self.augment(
                samples=positive_raw,
                sample_rate=self.sample_rate
            )
        except Exception:
            positive_aug = positive_raw

        # --------------------------------------------------------------
        # 4. Negative sampling (different song)
        # --------------------------------------------------------------
        neg_idx = random.randint(0, self.num_songs - 1)
        while neg_idx == real_idx:
            neg_idx = random.randint(0, self.num_songs - 1)

        neg_name = self.anchor_list[neg_idx]
        neg_path = os.path.join(self.originals_dir, neg_name)

        negative_raw = self._load_crop_process(neg_path)

        try:
            negative_aug = self.augment(
                samples=negative_raw,
                sample_rate=self.sample_rate
            )
        except Exception:
            negative_aug = negative_raw

        return (
            self._to_spec(anchor_raw),
            self._to_spec(positive_aug),
            self._to_spec(negative_aug)
        )

    def __len__(self):
        # Each song participates in:
        # 1√ó cover task + 1√ó self task per epoch
        return self.num_songs * 2


In [None]:
"""
Mel Spectrogram (64 √ó ~94)
‚Üí CNN backbone (freq reduced ‚Üí 1)
‚Üí Sequence of length T
‚Üí BiLSTM (2 layers)
‚Üí Attention pooling
‚Üí FC ‚Üí 128-D embedding
"""

'\nMel Spectrogram (64 √ó ~94)\n‚Üí CNN backbone (freq reduced ‚Üí 1)\n‚Üí Sequence of length T\n‚Üí BiLSTM (2 layers)\n‚Üí Attention pooling\n‚Üí FC ‚Üí 128-D embedding\n'

In [None]:
"""
Mel Spectrogram (64 √ó ~94)
        ‚Üì
CNN  ‚Üí local timbre + pitch invariance
        ‚Üì
BiLSTM ‚Üí temporal melody progression
        ‚Üì
Attention ‚Üí focus on salient moments
        ‚Üì
Projection ‚Üí fixed 128-D embedding

Input tensor
(B, 1, F, T) ‚Üí (B, 1, 64, ~94)
F = 64 Mel bands ‚Üí perceptual frequency scale
T ‚âà 94 frames (3 sec @ hop=1.5 sec)

Block 1
Conv2d(1 ‚Üí 32, 3√ó3) BatchNorm || ReLU || MaxPool(2√ó2)(improves invariance)

Input ‚Üí Output
(B, 1, 64, 94) ‚Üí (B, 32, 32, 47)

Block 2
Conv2d(32 ‚Üí 64, 3√ó3)
    BatchNorm
    ReLU
    MaxPool(2√ó2)
    Shape

(B, 32, 32, 47)‚Üí (B, 64, 16, 23)


Block 3 + Adaptive Pool
Conv2d(64 ‚Üí 128)
BatchNorm
ReLU
AdaptiveAvgPool((1, None))


Shape

(B, 64, 16, 23)
‚Üí (B, 128, 1, 23)

Reshape for Temporal Modelingx = x.squeeze(2).permute(0, 2, 1)

Shape
(B, 128, 1, 23)
‚Üí (B, 23, 128)

BiLSTM (Temporal Modeling)
LSTM(
  input_size=128,
  hidden_size=128,
  num_layers=2,
  bidirectional=True
)


Output shape

(B, 23, 256)


(128 forward + 128 backward)

Why BiLSTM? Music is not causal:chorus explains verse ,melody context matters ,backward info is important

Attention Pooling
attn = softmax(W2(tanh(W1(x))))
embedding = Œ£(attn_t * x_t)


Shape (B, 23, 256)--> (B, 256)

Why attention instead of average pooling?
Because not all moments matter.
Attention learns to emphasize:

    hooks
    chorus
    strong melodic regions

Projection Head (Metric Learning)
Linear(256 ‚Üí 256)
ReLU
Dropout(0.3)
Linear(256 ‚Üí 128)

Why projection?
decouples representation learning from embedding geometry
improves triplet loss convergence
standard practice (SimCLR, CLIP, FaceNet)
"""

# --- 2. NEW MODEL: AudioCRNN_Fast (CNN + BiLSTM + Attention) ---
class AudioCRNN_Fast(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()

        #
        # -------- 1. FAST CNN BACKBONE --------
        self.cnn = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),

            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),

            # Block 3 + Adaptive Pooling (Freq -> 1)
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, None))
        )

        # -------- 2. TEMPORAL MODEL (BiLSTM) --------
        # Input: 128 channels
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        # Output dim: 128 * 2 = 256

        # -------- 3. ATTENTION POOLING --------
        self.attention = nn.Sequential(
            nn.Linear(256, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        # -------- 4. PROJECTION HEAD --------
        self.fc = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, embed_dim)
        )

    def forward_one(self, x):
        # 1. CNN: (B, 1, F, T) -> (B, 128, 1, T)
        x = self.cnn(x)

        # 2. Reshape: (B, 128, 1, T) -> (B, T, 128)
        x = x.squeeze(2).permute(0, 2, 1)

        # 3. LSTM
        x, _ = self.lstm(x)         # (B, T, 256)

        # 4. Attention
        attn_weights = self.attention(x)       # (B, T, 1)
        attn_weights = F.softmax(attn_weights, dim=1)
        x = torch.sum(x * attn_weights, dim=1) # (B, 256)

        # 5. Projection
        x = self.fc(x)
        return F.normalize(x, p=2, dim=1)

    def forward(self, anchor, positive, negative, *args):
        return self.forward_one(anchor), self.forward_one(positive), self.forward_one(negative)

print("‚úÖ Model and Dataset definitions ready.")

‚úÖ Model and Dataset definitions ready.


In [None]:
#!pip install --upgrade torch torchaudio

In [None]:
# ==============================================================================
# üß± BLOCK 3: FINAL TRAINING LOOP (CRNN + TRIPLET LOSS)
# ==============================================================================

import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

# ------------------------------------------------------------------------------
# Utility: find latest checkpoint
# ------------------------------------------------------------------------------
def get_latest_checkpoint(ckpt_dir):
    pattern = re.compile(r"checkpoint_epoch_(\d+)\.pth")
    checkpoints = []

    for f in os.listdir(ckpt_dir):
        match = pattern.match(f)
        if match:
            checkpoints.append((int(match.group(1)), f))

    if not checkpoints:
        return None, 0

    checkpoints.sort()
    epoch, fname = checkpoints[-1]
    return os.path.join(ckpt_dir, fname), epoch


# ------------------------------------------------------------------------------
# Training Loop
# ------------------------------------------------------------------------------
def train_audio_crnn():

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nüî• Training on: {DEVICE}")

    # ---------------- CONFIG ----------------
    BATCH_SIZE = 64
    EPOCHS = 100
    LR = 3e-4
    MARGIN = 0.75
    PATIENCE = 6
    GRAD_CLIP = 1.0

    # ---------------- OUTPUT DIRS ----------------
    BASE_DRIVE_DIR = "/content/drive/MyDrive/FIND_TUNE"
    MODEL_DIR = os.path.join(BASE_DRIVE_DIR, "spectrogram_based_model")
    os.makedirs(MODEL_DIR, exist_ok=True)

    BEST_MODEL_PATH = os.path.join(
        MODEL_DIR, "best_spectrogram_model.pth"
    )

    print(f"üìÇ Models will be saved to: {MODEL_DIR}")

    # ---------------- DATASET ----------------
    if "VALID_ANCHORS" not in globals() or not VALID_ANCHORS:
        print("‚ùå VALID_ANCHORS not found. Run CSV mapping first.")
        return

    dataset = DualObjectiveSiameseDataset(
        anchor_list=VALID_ANCHORS,
        pair_map=PAIR_MAP,
        originals_dir="/content/data/originals",
        covers_dir="/content/data/covers",
        noise_dir="/content/data/noise_16k",
        sample_rate=16000,
        duration=3.0,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    # ---------------- MODEL ----------------
    model = AudioCRNN_Fast(embed_dim=128).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.TripletMarginLoss(margin=MARGIN)

    scaler = torch.amp.GradScaler("cuda")

    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=PATIENCE,
        min_lr=1e-7,
    )

    # ---------------- RESUME LOGIC ----------------
    start_epoch = 0
    best_loss = float("inf")

    latest_ckpt, last_epoch = get_latest_checkpoint(MODEL_DIR)

    if latest_ckpt:
        print(f"üîÑ Resuming from {latest_ckpt}")
        ckpt = torch.load(latest_ckpt, map_location=DEVICE)

        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])

        start_epoch = last_epoch + 1
        best_loss = ckpt.get("best_loss", float("inf"))

        print(f"   Resumed at epoch {start_epoch}, best loss {best_loss:.4f}")
    else:
        print("üÜï No checkpoint found. Starting fresh training.")

    # ---------------- TRAIN ----------------
    print("üöÄ Starting CRNN Training...")

    for epoch in range(start_epoch, EPOCHS):
        model.train()
        running_loss = 0.0

        for batch_idx, (anc, pos, neg) in enumerate(dataloader):
            anc = anc.to(DEVICE, non_blocking=True)
            pos = pos.to(DEVICE, non_blocking=True)
            neg = neg.to(DEVICE, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast("cuda"):
                emb_a, emb_p, emb_n = model(anc, pos, neg)
                loss = criterion(emb_a, emb_p, emb_n)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            if batch_idx % 20 == 0:
                print(
                    f"   Epoch {epoch+1} | Batch {batch_idx}/{len(dataloader)} "
                    f"| Loss: {loss.item():.4f}"
                )

        avg_loss = running_loss / len(dataloader)
        print(f"\nüì¢ Epoch {epoch+1}/{EPOCHS} | Avg Loss: {avg_loss:.4f}")

        # ---- Scheduler ----
        old_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_loss)
        new_lr = optimizer.param_groups[0]["lr"]

        if new_lr != old_lr:
            print(f"üìâ LR reduced: {old_lr:.2e} ‚Üí {new_lr:.2e}")

        # ---- Save Best ----
        if avg_loss < best_loss:
            print(f"‚≠ê New BEST: {best_loss:.4f} ‚Üí {avg_loss:.4f}")
            best_loss = avg_loss
            torch.save(model.state_dict(), BEST_MODEL_PATH)

        # ---- Epoch Checkpoint ----
        ckpt_path = os.path.join(
            MODEL_DIR, f"checkpoint_epoch_{epoch}.pth"
        )

        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "loss": avg_loss,
                "best_loss": best_loss,
            },
            ckpt_path,
        )

        print(f"üíæ Saved checkpoint: {ckpt_path}\n")


# ------------------------------------------------------------------------------
# Entry Point
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    train_audio_crnn()



üî• Training on: cuda
üìÇ Models will be saved to: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model
üÜï No checkpoint found. Starting fresh training.
üöÄ Starting CRNN Training...
   Epoch 1 | Batch 0/30 | Loss: 0.7609
   Epoch 1 | Batch 20/30 | Loss: 0.7268

üì¢ Epoch 1/100 | Avg Loss: 0.7455
‚≠ê New BEST: inf ‚Üí 0.7455
üíæ Saved checkpoint: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/checkpoint_epoch_0.pth

   Epoch 2 | Batch 0/30 | Loss: 0.7438
   Epoch 2 | Batch 20/30 | Loss: 0.6903

üì¢ Epoch 2/100 | Avg Loss: 0.7092
‚≠ê New BEST: 0.7455 ‚Üí 0.7092
üíæ Saved checkpoint: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/checkpoint_epoch_1.pth

   Epoch 3 | Batch 0/30 | Loss: 0.6619
   Epoch 3 | Batch 20/30 | Loss: 0.5423

üì¢ Epoch 3/100 | Avg Loss: 0.5830
‚≠ê New BEST: 0.7092 ‚Üí 0.5830
üíæ Saved checkpoint: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/checkpoint_epoch_2.pth

   Epoch 4 | Batch 0/30 | Loss: 0.5905
   Epoch 4 | Ba

KeyboardInterrupt: 

EVAL

In [None]:
# 1. Uninstall the broken decoder
!pip uninstall torchcodec -y

# 2. Install the stable decoder (libsndfile based)
!pip install soundfile

# 3. Re-install torchaudio to ensure clean bindings
!pip install --upgrade torchaudio

In [None]:
!apt-get update
!apt-get install -y ffmpeg
!pip install soundfile torchcodec


In [None]:
!pip install soundfile



In [None]:
import torchaudio
import glob
import os

# 1. Grab the first file
files = glob.glob("/content/data/eval/*.wav")
if not files:
    raise ValueError("‚ùå No files found in directory!")

test_file = files[0]
print(f"üßê Inspecting: {test_file}")
print(f"   Size: {os.path.getsize(test_file)} bytes")

# 2. Try to load it EXACTLY like the dataset does (no try/except)
print("\nüí• Attempting to load with backend='soundfile'...")
try:
    wav, sr = torchaudio.load(test_file, backend="soundfile")
    print(f"‚úÖ Success! Shape: {wav.shape}, SR: {sr}")
except Exception as e:
    print(f"\n‚ùå CRITICAL FAILURE (This is why your DB is empty):")
    print(f"Error Type: {type(e).__name__}")
    print(f"Error Message: {e}")

    # Heuristic Help
    if "backend" in str(e).lower():
        print("\nüí° HINT: You did not restart the runtime after 'pip install soundfile'.")
        print("   Action: Runtime > Restart Session")
    elif "header" in str(e).lower():
        print("\nüí° HINT: The file is likely corrupted or 0 bytes.")

üßê Inspecting: /content/data/eval/6c39a423-0487-443f-8d71-a6c92b29760b.wav
   Size: 6038348 bytes

üí• Attempting to load with backend='soundfile'...

‚ùå CRITICAL FAILURE (This is why your DB is empty):
Error Type: ImportError
Error Message: TorchCodec is required for load_with_torchcodec. Please install torchcodec to use this function.


In [None]:
# ==============================================================================
# üß™ BLOCK 6: V3 EVALUATION SYSTEM (MEMORY SAFE + SOUNDFILE)
# ==============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import soundfile as sf
import numpy as np
import os, glob, random, math
import gc
from tqdm import tqdm
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from audiomentations import Compose, AddBackgroundNoise, PitchShift, TimeStretch, Gain

# ------------------------------------------------------------------------------
# CONFIG
# ------------------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EVAL_DIR = "/content/data/eval"
NOISE_DIR = "/content/data/noise_16k"

MODEL_BASE_DIR = "/content/drive/MyDrive/FIND_TUNE/spectrogram_based_model"
MODEL_PATH = os.path.join(MODEL_BASE_DIR, "best_spectrogram_model.pth")

SAMPLE_RATE = 16000
WIN_SEC = 3.0
HOP_SEC = 1.5
QUERY_LEN = 15

TOLERANCE = 1.5
SIGMA = 0.5
SPREAD_FACTOR = 0.3
INFERENCE_BATCH_SIZE = 32  # üîë NEW: Process song in chunks to save RAM

# ------------------------------------------------------------------------------
# 1. MODEL ARCHITECTURE
# ------------------------------------------------------------------------------
class AudioCRNN_Fast(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.AdaptiveAvgPool2d((1, None))
        )
        self.lstm = nn.LSTM(128, 128, num_layers=2, batch_first=True, bidirectional=True)
        self.attention = nn.Sequential(nn.Linear(256, 64), nn.Tanh(), nn.Linear(64, 1))
        self.fc = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, embed_dim))

    def forward_one(self, x):
        x = self.cnn(x)
        x = x.squeeze(2).permute(0, 2, 1)
        x, _ = self.lstm(x)
        attn = F.softmax(self.attention(x), dim=1)
        x = torch.sum(x * attn, dim=1)
        return F.normalize(self.fc(x), p=2, dim=1)

# ------------------------------------------------------------------------------
# 2. AUGMENTATION
# ------------------------------------------------------------------------------
def get_augmenter(level):
    if level == "soft":
        return Compose([
            Gain(min_gain_db=-3, max_gain_db=3, p=0.5),
            PitchShift(min_semitones=-1, max_semitones=1, p=0.3)
        ])
    if level == "hard":
        return Compose([
            AddBackgroundNoise(NOISE_DIR, min_snr_db=5, max_snr_db=15, p=1.0),
            PitchShift(min_semitones=-2, max_semitones=2, p=0.8),
            TimeStretch(min_rate=0.9, max_rate=1.1, p=0.5)
        ])
    return None

# ------------------------------------------------------------------------------
# 3. HELPER: DIRECT LOADING
# ------------------------------------------------------------------------------
def robust_load(path, target_sr=16000):
    try:
        wav_np, sr = sf.read(path)
        wav_np = wav_np.astype(np.float32)
        wav = torch.from_numpy(wav_np)
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        else:
            wav = wav.t()
        return wav, sr
    except Exception as e:
        return None, 0

# ------------------------------------------------------------------------------
# 4. AUDIO ‚Üí EMBEDDINGS (Memory Safe)
# ------------------------------------------------------------------------------
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=1024, hop_length=512, n_mels=64
).to(DEVICE)
db = torchaudio.transforms.AmplitudeToDB().to(DEVICE)

def audio_to_embedding(model, wav):
    # üîë INPUT: wav is on CPU (1, T)
    samples_win = int(SAMPLE_RATE * WIN_SEC)
    samples_hop = int(SAMPLE_RATE * HOP_SEC)

    if wav.shape[1] < samples_win:
        wav = F.pad(wav, (0, samples_win - wav.shape[1]))

    # Slice on CPU to avoid GPU OOM
    windows = []
    times = []

    for i in range(0, wav.shape[1] - samples_win + 1, samples_hop):
        windows.append(wav[:, i:i + samples_win])
        times.append(i / SAMPLE_RATE)

    if not windows:
        return None, None

    # üîë PROCESS IN MINI-BATCHES
    all_embeddings = []

    for i in range(0, len(windows), INFERENCE_BATCH_SIZE):
        # Create mini-batch
        chunk = windows[i : i + INFERENCE_BATCH_SIZE]
        batch = torch.stack(chunk).to(DEVICE) # Move only 32 items to GPU

        # Transform
        spec = mel(batch)
        spec = db(spec)
        spec = (spec - spec.mean(dim=(2,3), keepdim=True)) / (spec.std(dim=(2,3), keepdim=True) + 1e-6)

        # Inference
        with torch.no_grad():
            emb = model.forward_one(spec)

        # Move back to CPU immediately
        all_embeddings.append(emb.cpu())

    # Concatenate all CPU results
    return torch.cat(all_embeddings), times

# ------------------------------------------------------------------------------
# 5. DATABASE BUILDER
# ------------------------------------------------------------------------------
class ReferenceDatabaseDataset(Dataset):
    def __init__(self, file_paths, sample_rate):
        self.file_paths = file_paths
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        name = os.path.basename(path)

        wav, sr = robust_load(path, self.sample_rate)

        if wav is None:
            return torch.zeros(1, self.sample_rate), "ERROR"

        if sr != self.sample_rate:
            wav = torchaudio.transforms.Resample(sr, self.sample_rate)(wav)
        if wav.shape[0] > 1:
            wav = wav.mean(0, keepdim=True)

        return wav, name

def build_database(model):
    # üßπ CLEANUP BEFORE STARTING
    torch.cuda.empty_cache()
    gc.collect()

    print("üèóÔ∏è Building vector database (Memory Safe)...")
    files = glob.glob(os.path.join(EVAL_DIR, "*.wav"))
    if not files:
        print("‚ùå No .wav files found")
        return None, None

    dataset = ReferenceDatabaseDataset(files, SAMPLE_RATE)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

    vectors = []
    metadata = []

    model.eval()
    with torch.no_grad():
        for wav, name in tqdm(loader):
            if name[0] == "ERROR": continue

            # üîë KEEP WAV ON CPU HERE
            wav = wav.squeeze(0)

            emb, times = audio_to_embedding(model, wav)
            if emb is None: continue

            for i, t in enumerate(times):
                vectors.append(emb[i]) # Already on CPU
                metadata.append({"name": name[0], "offset": t})

    if not vectors:
        print("‚ùå DB Build Failed.")
        return None, None

    print(f"‚úÖ Indexed {len(metadata)} segments.")
    return torch.stack(vectors).to(DEVICE), metadata

# ------------------------------------------------------------------------------
# 6. SCORING & EVAL LOOP
# ------------------------------------------------------------------------------
def calculate_v3_scores(matches):
    song_scores = defaultdict(lambda: defaultdict(float))
    for dist, meta, q_t in matches:
        w = math.exp(-(dist ** 2) / (2 * SIGMA ** 2))
        if w < 0.01: continue
        delta = meta["offset"] - q_t
        b = int(round(delta / TOLERANCE))
        song_scores[meta["name"]][b] += w
        song_scores[meta["name"]][b-1] += w * SPREAD_FACTOR
        song_scores[meta["name"]][b+1] += w * SPREAD_FACTOR

    ranked = []
    for song, buckets in song_scores.items():
        ranked.append((song, max(buckets.values())))
    ranked.sort(key=lambda x: x[1], reverse=True)
    return ranked

# ------------------------------------------------------------------------------
# 6. SCORING & EVAL LOOP (FIXED DEVICE MISMATCH)
# ------------------------------------------------------------------------------
def run_evaluation(model, db_vecs, db_meta, trials=100):
    if db_vecs is None: return

    # Ensure Database is on the correct device
    db_vecs = db_vecs.to(DEVICE)

    modes = ["clean", "soft", "hard"]
    results = {m: {"top1":0,"top5":0,"top10":0} for m in modes}
    songs = list(set(m["name"] for m in db_meta))

    print(f"\n‚ö° Starting Evaluation ({trials} trials per mode)...")

    for mode in modes:
        print(f"\n‚ñ∂ MODE: {mode.upper()}")
        aug = get_augmenter(mode)

        for _ in tqdm(range(trials)):
            target = random.choice(songs)

            # Load Audio
            wav, sr = robust_load(os.path.join(EVAL_DIR, target))
            if wav is None: continue

            # Resample & Mix
            if sr != SAMPLE_RATE:
                wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
            if wav.shape[0] > 1:
                wav = wav.mean(0, keepdim=True)

            # Random Crop (Query Length)
            max_len = int(QUERY_LEN * SAMPLE_RATE)
            if wav.shape[1] > max_len:
                s = random.randint(0, wav.shape[1] - max_len)
                wav = wav[:, s:s + max_len]

            # Augmentation
            if aug:
                try:
                    wav_np = aug(samples=wav.squeeze(0).numpy(), sample_rate=SAMPLE_RATE)
                    wav = torch.from_numpy(wav_np).unsqueeze(0)
                except: pass

            # Inference
            # q_emb comes back on CPU (from our memory-safe function)
            q_emb, q_times = audio_to_embedding(model, wav)
            if q_emb is None: continue

            # üîë FIX: Move Query to GPU for distance calculation
            q_emb = q_emb.to(DEVICE)

            # Distance Calculation
            dists = torch.cdist(q_emb, db_vecs)
            vals, idxs = torch.topk(dists, k=5, largest=False)

            matches = []
            for i in range(q_emb.shape[0]):
                for k in range(5):
                    matches.append((vals[i,k].item(), db_meta[idxs[i,k]], q_times[i]))

            ranked = calculate_v3_scores(matches)
            if not ranked: continue

            names = [x[0] for x in ranked]
            if target == names[0]: results[mode]["top1"] += 1
            if target in names[:5]: results[mode]["top5"] += 1
            if target in names[:10]: results[mode]["top10"] += 1

    print("\nüèÜ FINAL RESULTS")
    for m in modes:
        print(f"{m.upper():<6} | "
              f"Top1: {results[m]['top1']/trials*100:.1f}% | "
              f"Top5: {results[m]['top5']/trials*100:.1f}% | "
              f"Top10: {results[m]['top10']/trials*100:.1f}%")

# ------------------------------------------------------------------------------
# RUN
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Assuming model and db_vecs are already loaded/built from your previous cell
    # If not, reload them:
    if 'model' not in globals():
        model = AudioCRNN_Fast(embed_dim=128).to(DEVICE)
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        model.eval()

    if 'db_vecs' in globals() and db_vecs is not None:
        run_evaluation(model, db_vecs, db_meta, trials=100)
    else:
        print("‚ö†Ô∏è Please run build_database() first.")


‚ö° Starting Evaluation (100 trials per mode)...

‚ñ∂ MODE: CLEAN


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:09<00:00, 10.40it/s]



‚ñ∂ MODE: SOFT


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:09<00:00, 10.82it/s]



‚ñ∂ MODE: HARD


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:16<00:00,  5.90it/s]


üèÜ FINAL RESULTS
CLEAN  | Top1: 27.0% | Top5: 43.0% | Top10: 48.0%
SOFT   | Top1: 25.0% | Top5: 32.0% | Top10: 40.0%
HARD   | Top1: 4.0% | Top5: 13.0% | Top10: 20.0%





FINAL EVAL

In [None]:
# ==============================================================================
# üß™ FINAL SPECTROGRAM EVALUATION (SMART RESUME + ULTRA MODE)
# ==============================================================================

import os
import subprocess
import sys
import glob
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import soundfile as sf
from tqdm import tqdm
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader

# 1. INSTALL DEPENDENCIES
packages = ["audiomentations", "torchaudio"]
for package in packages:
    try:
        __import__(package)
    except ImportError:
        print(f"üì¶ Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

from audiomentations import Compose, AddBackgroundNoise, PitchShift, TimeStretch, Gain, PolarityInversion

# ------------------------------------------------------------------------------
# ‚öôÔ∏è CONFIGURATION
# ------------------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Input Zips (Drive)
ZIP_SOURCE_DIR = '/content/drive/MyDrive/FINE_TUNE_V3'
ZIP_EVAL_ORIGINALS = os.path.join(ZIP_SOURCE_DIR, 'eval_originals_300.zip')
ZIP_TRAIN_COVERS = os.path.join(ZIP_SOURCE_DIR, 'train_covers_1300.zip')
ZIP_NOISE = os.path.join(ZIP_SOURCE_DIR, 'noise_data_16k.zip')

# Local Paths
EVAL_DIR = "/content/data/eval_combined"
NOISE_DIR = "/content/data/noise_16k"

# Model Checkpoint
CHECKPOINT_DIR = "/content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN"
MODEL_PATH = os.path.join(CHECKPOINT_DIR, "checkpoint_epoch_34.pth")

# Audio Params
SAMPLE_RATE = 16000
WIN_SEC = 3.0
HOP_SEC = 1.5
QUERY_LEN = 15
INFERENCE_BATCH_SIZE = 64
SIGMA = 0.5

# ------------------------------------------------------------------------------
# 2. DATA PREPARATION (SMART RESUME)
# ------------------------------------------------------------------------------
def cleanup_noise_mp3s(noise_dir):
    """Recursively finds MP3s, converts to WAV 16k, and deletes MP3."""
    mp3s = glob.glob(os.path.join(noise_dir, "**/*.mp3"), recursive=True)
    if not mp3s:
        return # Nothing to do

    print(f"üßπ Cleaning up {len(mp3s)} MP3 files in noise directory...")
    for m in tqdm(mp3s, desc="Converting MP3->WAV"):
        try:
            w, sr = torchaudio.load(m)
            if sr != SAMPLE_RATE:
                w = torchaudio.functional.resample(w, sr, SAMPLE_RATE)
            torchaudio.save(m.replace('.mp3','.wav'), w, SAMPLE_RATE)
            os.remove(m)
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to convert {os.path.basename(m)}: {e}")

def setup_data():
    print(f"\nüöÄ SETTING UP DATA...")
    os.makedirs(EVAL_DIR, exist_ok=True)
    os.makedirs(NOISE_DIR, exist_ok=True)

    # --- A. NOISE SETUP ---
    # Check if noise dir is populated (contains at least 5 files)
    noise_files = glob.glob(os.path.join(NOISE_DIR, "**/*.wav"), recursive=True)
    if len(noise_files) > 5:
        print(f"‚úÖ Noise directory seems populated ({len(noise_files)} files). Skipping Unzip.")
    else:
        print(f"üìÇ Unzipping Noise Data...")
        subprocess.run(f"unzip -q -n '{ZIP_NOISE}' -d '{NOISE_DIR}'", shell=True)

    # Always run MP3 cleanup to be safe (it's fast if no MP3s exist)
    cleanup_noise_mp3s(NOISE_DIR)

    # --- B. EVALUATION DB SETUP ---
    # Check if Eval dir is populated (contains > 100 wavs)
    eval_files = glob.glob(os.path.join(EVAL_DIR, "*.wav"))
    if len(eval_files) > 100:
        print(f"‚úÖ Eval directory seems populated ({len(eval_files)} files). Skipping Unzip.")
    else:
        print(f"üìÇ Unzipping Eval Originals (Targets)...")
        subprocess.run(f"unzip -q -n '{ZIP_EVAL_ORIGINALS}' -d '{EVAL_DIR}'", shell=True)

        print(f"üìÇ Unzipping Train Covers (Distractors)...")
        subprocess.run(f"unzip -q -n '{ZIP_TRAIN_COVERS}' -d '{EVAL_DIR}'", shell=True)

        # Flatten directory (move subfolder contents to root)
        print("   Flattening directory structure...")
        for root, dirs, files in os.walk(EVAL_DIR):
            for file in files:
                if file.endswith(".wav"):
                    src = os.path.join(root, file)
                    dst = os.path.join(EVAL_DIR, file)
                    if src != dst:
                        try: os.rename(src, dst)
                        except: pass

    final_count = len(glob.glob(os.path.join(EVAL_DIR, "*.wav")))
    print(f"üìä Final Database Count: {final_count} files in {EVAL_DIR}")

# ------------------------------------------------------------------------------
# 3. MODEL ARCHITECTURE (SHALLOW CNN)
# ------------------------------------------------------------------------------
class AudioSiameseNet(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU()
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(128, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, embed_dim)
        )

    def forward_one(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.pool(x).view(x.size(0), -1)
        return F.normalize(self.fc(x), p=2, dim=1)

# ------------------------------------------------------------------------------
# 4. INFERENCE ENGINE
# ------------------------------------------------------------------------------
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=1024, hop_length=512, n_mels=64
).to(DEVICE)
db_transform = torchaudio.transforms.AmplitudeToDB().to(DEVICE)

def robust_load(path):
    try:
        wav, sr = torchaudio.load(path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
        if wav.shape[0] > 1:
            wav = wav.mean(0, keepdim=True)
        return wav, SAMPLE_RATE
    except: return None, 0

def audio_to_embedding(model, wav):
    samples_win = int(SAMPLE_RATE * WIN_SEC)
    samples_hop = int(SAMPLE_RATE * HOP_SEC)

    if wav.shape[1] < samples_win:
        wav = F.pad(wav, (0, samples_win - wav.shape[1]))

    windows = []
    times = []
    for i in range(0, wav.shape[1] - samples_win + 1, samples_hop):
        windows.append(wav[:, i:i + samples_win])
        times.append(i / SAMPLE_RATE)

    if not windows: return None, None

    all_embeds = []
    with torch.no_grad():
        for i in range(0, len(windows), INFERENCE_BATCH_SIZE):
            batch = torch.stack(windows[i : i + INFERENCE_BATCH_SIZE]).to(DEVICE)
            spec = db_transform(mel(batch))
            mean = spec.mean(dim=(2,3), keepdim=True)
            std = spec.std(dim=(2,3), keepdim=True)
            spec = (spec - mean) / (std + 1e-6)
            emb = model.forward_one(spec)
            all_embeds.append(emb.cpu())

    return torch.cat(all_embeds), times

# ------------------------------------------------------------------------------
# 5. DATABASE BUILDER
# ------------------------------------------------------------------------------
def build_database(model):
    print("\nüèóÔ∏è  Indexing Database (Originals + Covers)...")
    files = glob.glob(os.path.join(EVAL_DIR, "*.wav"))

    db_vecs = []
    db_meta = []

    model.eval()

    for f in tqdm(files, desc="Indexing"):
        wav, _ = robust_load(f)
        if wav is None: continue

        emb, times = audio_to_embedding(model, wav)
        if emb is None: continue

        base_name = os.path.basename(f)

        for i in range(len(times)):
            db_vecs.append(emb[i])
            db_meta.append({"name": base_name, "offset": times[i]})

    if not db_vecs: return None, None
    print(f"‚úÖ Indexed {len(db_vecs)} vectors from {len(files)} songs.")
    return torch.stack(db_vecs).to(DEVICE), db_meta

# ------------------------------------------------------------------------------
# 6. EVALUATION LOOP
# ------------------------------------------------------------------------------
def run_evaluation(model, db_vecs, db_meta, trials=200):
    if db_vecs is None: return

    augmenters = {
        "clean": None,
        "soft": Compose([Gain(min_gain_db=-3, max_gain_db=3, p=0.5)]),
        "hard": Compose([
            AddBackgroundNoise(sounds_path=NOISE_DIR, min_snr_db=5, max_snr_db=15, p=1.0),
            PitchShift(min_semitones=-2, max_semitones=2, p=0.5)
        ]),
        "ultra": Compose([
            AddBackgroundNoise(sounds_path=NOISE_DIR, min_snr_db=3, max_snr_db=10, p=1.0),
            PitchShift(min_semitones=-3, max_semitones=3, p=0.8),
            TimeStretch(min_rate=0.85, max_rate=1.15, p=0.5),
            PolarityInversion(p=0.5)
        ])
    }

    unique_songs = list(set([m['name'] for m in db_meta]))

    print(f"\n‚ö° STARTING EVALUATION ({trials} trials per mode)")
    print(f"   Database Size: {len(unique_songs)} unique tracks")

    results = defaultdict(lambda: {"top1": 0, "top5": 0})

    for mode, aug in augmenters.items():
        print(f"\n‚ñ∂ MODE: {mode.upper()}")

        for _ in tqdm(range(trials)):
            target_name = random.choice(unique_songs)
            target_path = os.path.join(EVAL_DIR, target_name)

            wav, _ = robust_load(target_path)
            if wav is None: continue

            req_samples = int(QUERY_LEN * SAMPLE_RATE)
            if wav.shape[1] > req_samples:
                start = random.randint(0, wav.shape[1] - req_samples)
                query_wav = wav[:, start:start+req_samples]
                query_offset = start / SAMPLE_RATE
            else:
                query_wav = wav
                query_offset = 0.0

            if aug:
                try:
                    q_np = query_wav.squeeze().numpy()
                    q_aug = aug(samples=q_np, sample_rate=SAMPLE_RATE)
                    query_wav = torch.from_numpy(q_aug).unsqueeze(0)
                except: pass

            q_emb, q_times = audio_to_embedding(model, query_wav)
            if q_emb is None: continue
            q_emb = q_emb.to(DEVICE)

            dists = torch.cdist(q_emb, db_vecs)
            scores = defaultdict(float)

            topk_vals, topk_idxs = torch.topk(dists, k=10, largest=False)

            for i in range(len(q_times)):
                for k in range(10):
                    match_idx = topk_idxs[i, k].item()
                    match_dist = topk_vals[i, k].item()
                    match_meta = db_meta[match_idx]
                    weight = math.exp(-(match_dist**2) / (2 * SIGMA**2))
                    scores[match_meta['name']] += weight

            ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
            if not ranked: continue

            top1_name = ranked[0][0]
            top5_names = [r[0] for r in ranked[:5]]

            if top1_name == target_name: results[mode]["top1"] += 1
            if target_name in top5_names: results[mode]["top5"] += 1

    print("\n" + "="*40)
    print(f"üèÜ FINAL RESULTS (Epoch 34)")
    print("="*40)
    print(f"{'MODE':<10} | {'TOP-1':<8} | {'TOP-5':<8}")
    print("-" * 32)
    for m in ["clean", "soft", "hard", "ultra"]:
        t1 = results[m]["top1"] / trials * 100
        t5 = results[m]["top5"] / trials * 100
        print(f"{m.upper():<10} | {t1:.1f}%     | {t5:.1f}%")
    print("-" * 32)

# ------------------------------------------------------------------------------
# 7. EXECUTION
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    setup_data()

    if os.path.exists(MODEL_PATH):
        print(f"üìÇ Loading Model: {MODEL_PATH}")
        model = AudioSiameseNet(embed_dim=128).to(DEVICE)
        ckpt = torch.load(MODEL_PATH, map_location=DEVICE)

        if "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
        else:
            model.load_state_dict(ckpt)

        model.eval()

        db_vecs, db_meta = build_database(model)
        run_evaluation(model, db_vecs, db_meta, trials=200)
    else:
        print(f"‚ùå Error: Model not found at {MODEL_PATH}")


üöÄ SETTING UP DATA...
‚úÖ Noise directory seems populated (2129 files). Skipping Unzip.
üìÇ Unzipping Eval Originals (Targets)...
üìÇ Unzipping Train Covers (Distractors)...
   Flattening directory structure...
üìä Final Database Count: 1612 files in /content/data/eval_combined
üìÇ Loading Model: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_34.pth

üèóÔ∏è  Indexing Database (Originals + Covers)...


Indexing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1612/1612 [02:30<00:00, 10.74it/s]


‚úÖ Indexed 288216 vectors from 1612 songs.

‚ö° STARTING EVALUATION (200 trials per mode)
   Database Size: 1612 unique tracks

‚ñ∂ MODE: CLEAN


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200/200 [00:11<00:00, 16.77it/s]



‚ñ∂ MODE: SOFT


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200/200 [00:13<00:00, 14.65it/s]



‚ñ∂ MODE: HARD


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200/200 [00:39<00:00,  5.02it/s]



‚ñ∂ MODE: ULTRA


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200/200 [00:36<00:00,  5.52it/s]


üèÜ FINAL RESULTS (Epoch 34)
MODE       | TOP-1    | TOP-5   
--------------------------------
CLEAN      | 91.5%     | 99.0%
SOFT       | 93.5%     | 99.5%
HARD       | 37.5%     | 54.5%
ULTRA      | 22.0%     | 35.0%
--------------------------------





FINAL best models to onnx conversion


In [5]:
!pip install onnxscript

Collecting onnxscript
  Downloading onnxscript-0.6.2-py3-none-any.whl.metadata (13 kB)
Collecting onnx_ir<2,>=0.1.15 (from onnxscript)
  Downloading onnx_ir-0.1.16-py3-none-any.whl.metadata (3.2 kB)
Downloading onnxscript-0.6.2-py3-none-any.whl (689 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m689.1/689.1 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnx_ir-0.1.16-py3-none-any.whl (159 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m159.3/159.3 kB[0m [31m22.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx_ir, onnxscript
Successfully installed onnx_ir-0.1.16 onnxscript-0.6.2


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

# ==========================================
# 1. DEFINE ARCHITECTURES
# ==========================================

# --- PITCH MODEL (CRNN) ---
# --- UPDATED PITCH MODEL (CRNN) ---
class CRNN(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(1, 64, 5, padding=2), nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, 3, padding=1), nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(128, 256, 3, padding=1), nn.BatchNorm1d(256), nn.ReLU(),
        )
        self.lstm = nn.LSTM(256, 128, num_layers=2, batch_first=True, bidirectional=True)
        self.fc = nn.Sequential(
            nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, embed_dim)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = x.permute(0, 2, 1)

        # ‚ùå REMOVED: self.lstm.flatten_parameters()
        # (It breaks ONNX export by mutating state during forward pass)

        out, _ = self.lstm(x)
        out = torch.mean(out, dim=1)
        return F.normalize(self.fc(out), p=2, dim=1)

# --- SPECTROGRAM MODEL (Shallow CNN) ---
class AudioSiameseNet(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU()
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(128, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, embed_dim)
        )

    # Renamed from forward_one for ONNX tracing
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.pool(x).view(x.size(0), -1)
        return F.normalize(self.fc(x), p=2, dim=1)


# ==========================================
# 2. EXPORT FUNCTION
# ==========================================
def export_to_onnx(model, weights_path, output_path, dummy_input, dynamic_axes):
    print(f"\nüîÑ Loading weights from: {weights_path}")

    # Load weights safely
    checkpoint = torch.load(weights_path, map_location="cpu")
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)

    model.eval()

    print(f"üì¶ Exporting to ONNX: {output_path}")
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes=dynamic_axes,
        # üöÄ THE FIX: Tell PyTorch to use the stable TorchScript exporter, not Dynamo
        dynamo=False
    )
    print("‚úÖ Export Successful!")

# ==========================================
# 3. EXECUTE EXPORTS
# ==========================================
if __name__ == "__main__":
    # Ensure output directory exists
    OUT_DIR = "/content/drive/MyDrive/FIND_TUNE/onnx_models"
    os.makedirs(OUT_DIR, exist_ok=True)

    # --- Export CRNN (Pitch) ---
    crnn_model = CRNN(embed_dim=128)
    crnn_weights = "/content/drive/MyDrive/FIND_TUNE/pitch_based_model/finetuned_models(BEST)/FINETUNED_CRNN_Smooth.pth"
    crnn_out = os.path.join(OUT_DIR, "pitch_crnn.onnx")

    # Dummy input: (Batch=1, Channels=1, Sequence Length=1000)
    crnn_dummy = torch.randn(1, 1, 1000)
    crnn_axes = {'input': {0: 'batch_size', 2: 'seq_length'}, 'output': {0: 'batch_size'}}

    export_to_onnx(crnn_model, crnn_weights, crnn_out, crnn_dummy, crnn_axes)


    # --- Export AudioSiameseNet (Spectrogram) ---
    spec_model = AudioSiameseNet(embed_dim=128)
    spec_weights = "/content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_34.pth"
    spec_out = os.path.join(OUT_DIR, "spectrogram_cnn.onnx")

    # Dummy input: (Batch=1, Channels=1, Mels=64, Time_Frames=94)
    spec_dummy = torch.randn(1, 1, 64, 94)
    spec_axes = {'input': {0: 'batch_size', 3: 'time_frames'}, 'output': {0: 'batch_size'}}

    export_to_onnx(spec_model, spec_weights, spec_out, spec_dummy, spec_axes)


üîÑ Loading weights from: /content/drive/MyDrive/FIND_TUNE/pitch_based_model/finetuned_models(BEST)/FINETUNED_CRNN_Smooth.pth
üì¶ Exporting to ONNX: /content/drive/MyDrive/FIND_TUNE/onnx_models/pitch_crnn.onnx


  torch.onnx.export(


‚úÖ Export Successful!

üîÑ Loading weights from: /content/drive/MyDrive/FIND_TUNE/spectrogram_based_model/shallow_CNN/checkpoint_epoch_34.pth
üì¶ Exporting to ONNX: /content/drive/MyDrive/FIND_TUNE/onnx_models/spectrogram_cnn.onnx
‚úÖ Export Successful!
