# This is the dev workbook for music detection

## RUN ONLY ONCE

In [228]:
# !pip install -r ../requirements.txt -v

## Dataset

### Imports

In [229]:
import torch
from torch.utils.data import Dataset
import torchaudio
from pathlib import Path
from torch.utils.data import DataLoader
import os
import time

### AudioDataset class

In [230]:
class AudioDataset(Dataset):
    def __init__(self, folder_path, target_sr=22050):
        self.folder_path = folder_path
        self.target_sr = target_sr
        self.file_paths = self.get_all_filepaths()

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        init_time = time.time()
        print("loading file n°", idx)
        file_path = self.file_paths[idx]
        wav, sr = torchaudio.load(file_path)  # wav shape: (channels, time)
        wav = torch.mean(wav, dim=0)  # mono
        if sr != self.target_sr:
            wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sr)(wav)
            sr = self.target_sr
        print(f"file loaded at sr={sr} Hz in", time.time() - init_time, "seconds")
        return wav, sr, os.path.basename(file_path)

    def get_all_filepaths(self):
        folder = Path(self.folder_path)
        file_paths = [str(p) for p in folder.rglob("*")]
        return file_paths

    def get_time_length(self, wav):
        return len(wav[0]) / self.target_sr

### test dataset

In [231]:
dataset = AudioDataset(folder_path = "../../../Music/full_test")

mus1 = dataset[0]
mus2 = dataset[1]
print(mus2[0], len(mus2[0]))
print(mus2[1])
print(mus2[2])
print(mus1[0], len(mus1[0]))
print(mus1[1])
print(mus1[2])

loading file n° 0
file loaded at sr=22050 Hz in 0.35961246490478516 seconds
loading file n° 1
file loaded at sr=22050 Hz in 0.5574502944946289 seconds
tensor([ 5.4581e-14,  2.0858e-14, -8.9946e-14,  ...,  0.0000e+00,
         0.0000e+00,  0.0000e+00]) 7093440
22050
20050904_56110_kson_kda_14.mp3
tensor([-2.1066e-13, -3.4545e-13,  3.6385e-13,  ...,  0.0000e+00,
         0.0000e+00,  0.0000e+00]) 4858560
22050
20040905_56110_kson_kme_04.mp3


## Augmenter

In [232]:
import numpy as np
import librosa
import torch
import random
import torchaudio
from dataset import AudioDataset
from utils import save_wav

In [233]:
def normalize(wav):
    # Normalize to [-1, 1]
    max_val = wav.abs().max()
    if max_val > 0:
        wav = wav / max_val
    else:
        raise Exception("Audio is empty ...")
    return wav

class RandomClip:
    def __init__(self, sample_rate, clip_length):
        self.clip_length = sample_rate * clip_length
        self.sr = sample_rate
        self.vad = torchaudio.transforms.Vad(
            sample_rate=sample_rate, trigger_level=7.0)

    def __call__(self, audio_data):
        audio_length = audio_data.shape[0]
        print(audio_length)
        print(len(audio_data))
        print(audio_data)
        print(len(audio_data))
        if audio_length > self.clip_length:
            offset = random.randint(0, audio_length-self.clip_length)
            print(f"Audio cut between {offset/self.sr}s and, {(offset+self.clip_length)/self.sr}s")
            audio_data = audio_data[offset:(offset+self.clip_length)]
        else :
            raise Exception("Audio shorter than clip ...")
        return audio_data # remove silences at the beggining/end  self.vad()

def augment_audio(wav, sr, test = False): # TO DO implement noise adding
    if test:
        save_wav(path="tmp/test_augmenter/raw.wav", wav=wav, sr=sr)

    arr = wav.numpy().squeeze()
    n_steps=random.choice([-2, -1, 0, 1, 2])
    print(f"pitch schift of {n_steps} steps")
    arr = librosa.effects.pitch_shift(y=arr, sr=sr, n_steps=n_steps)
    if test:
        save_wav(path="tmp/test_augmenter/pitch.wav", wav=torch.tensor(arr).unsqueeze(0), sr=sr)


    rate = random.uniform(0.9, 1.1)
    print(f"time stretch rate of {rate}")
    arr = librosa.effects.time_stretch(y=arr, rate=rate)
    wav = torch.tensor(arr) #.unsqueeze(0)
    if test:
        save_wav(path="tmp/test_augmenter/stretch.wav", wav=wav, sr=sr)

    clip_transform = RandomClip(sr, clip_length = 7)
    len(wav)
    clipped_audio = clip_transform(wav)
    if test:
        save_wav(path="tmp/test_augmenter/clipped.wav", wav=clipped_audio, sr=sr)

    return clipped_audio.unsqueeze(0)

In [234]:
dataset = AudioDataset(folder_path = "../../../Music/full_test")
test_audio = dataset[0]
print("label",test_audio[2])
wav = augment_audio(wav=test_audio[0], sr=test_audio[1], test = True)
len(wav)

loading file n° 0
file loaded at sr=22050 Hz in 0.3621690273284912 seconds
label 20040905_56110_kson_kme_04.mp3
Successfully saved tmp/test_augmenter/raw.wav with shape (4858560,) at 22050Hz
pitch schift of -2 steps
Successfully saved tmp/test_augmenter/pitch.wav with shape (4858560,) at 22050Hz
time stretch rate of 1.0391863370571826
Successfully saved tmp/test_augmenter/stretch.wav with shape (4675350,) at 22050Hz
4675350
4675350
tensor([ 2.1534e-07, -5.3995e-08, -7.9182e-07,  ...,  0.0000e+00,
         0.0000e+00,  0.0000e+00])
4675350
Audio cut between 198.1906575963719s and, 205.1906575963719s
Successfully saved tmp/test_augmenter/clipped.wav with shape (154350,) at 22050Hz


1

## Model and training

In [235]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [236]:
class MusicEncoder(nn.Module):
    def __init__(self, in_channels=1, out_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        # works for any spectrogram shape
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, out_dim)

    def forward(self, x):
        x = self.conv(x)                # (B, 128, H’, W’)
        x = self.pool(x)                # (B, 128, 1, 1)
        x = x.view(x.size(0), -1)       # (B, 128)
        x = self.fc(x)                  # (B, out_dim)
        return F.normalize(x, dim=-1)   # unit vector embeddings


In [237]:
def contrastive_loss(z1, z2, same_label: bool, margin=1.0):
    dist = F.pairwise_distance(z1, z2)
    if same_label:
        return torch.mean(dist ** 2)  # positive → bring closer
    else:
        return torch.mean(F.relu(margin - dist) ** 2)  # negative → push apart

In [238]:
def get_mel_spec(wav, sr):
    spec = T.MelSpectrogram(sr, n_mels=64)(wav)
    return T.AmplitudeToDB()(spec)

In [239]:
encoder = MusicEncoder().cuda()  # Instantiate the audio encoder and move it to the GPU
opt = torch.optim.Adam(encoder.parameters(), 1e-3)  # Adam optimizer for training the encoder
database = {}  # Dictionary to store embeddings for each audio clip


In [240]:
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129


In [241]:
len(dataset)

6

In [242]:
def complete_batch_index_generator(batch_size, dataset):
    batches = []
    total_len = len(dataset)
    total_list = list(np.arange(total_len))
    while total_len > 0 :
        print("new batch")
        batch = []
        cpt = 0
        start_len = total_len
        while cpt < min(batch_size, start_len):
            index = np.random.randint(total_len)
            number = total_list.pop(index)
            batch.append(int(number))
            print(number, "added to batch")
            total_len -= 1
            cpt += 1
        batches.append(batch)
    return batches

print(complete_batch_index_generator(batch_size=1, dataset=dataset))
print(complete_batch_index_generator(batch_size=2, dataset=dataset))

new batch
0 added to batch
new batch
2 added to batch
new batch
1 added to batch
new batch
3 added to batch
new batch
4 added to batch
new batch
5 added to batch
[[0], [2], [1], [3], [4], [5]]
new batch
1 added to batch
3 added to batch
new batch
0 added to batch
4 added to batch
new batch
5 added to batch
2 added to batch
[[1, 3], [0, 4], [5, 2]]


In [243]:
def multiple_batch_index_generator(batch_size, loop_nb, dataset):
    batch_loop = []
    for i in range(loop_nb):
        batches = complete_batch_index_generator(batch_size, dataset)
        batch_loop += batches
    return batch_loop

batches = multiple_batch_index_generator(batch_size=2, loop_nb=5, dataset=dataset)
print(batches)


new batch
4 added to batch
0 added to batch
new batch
1 added to batch
5 added to batch
new batch
2 added to batch
3 added to batch
new batch
1 added to batch
4 added to batch
new batch
2 added to batch
5 added to batch
new batch
0 added to batch
3 added to batch
new batch
3 added to batch
2 added to batch
new batch
4 added to batch
5 added to batch
new batch
0 added to batch
1 added to batch
new batch
0 added to batch
1 added to batch
new batch
4 added to batch
2 added to batch
new batch
3 added to batch
5 added to batch
new batch
0 added to batch
4 added to batch
new batch
1 added to batch
3 added to batch
new batch
2 added to batch
5 added to batch
[[4, 0], [1, 5], [2, 3], [1, 4], [2, 5], [0, 3], [3, 2], [4, 5], [0, 1], [0, 1], [4, 2], [3, 5], [0, 4], [1, 3], [2, 5]]


In [244]:
def counter_sample(audio_index, dataset):
    total_len = len(dataset)
    total_list = list(np.arange(total_len))
    same = np.random.randint(2)
    if not same:
        print("Counter example")
        total_list.pop(audio_index)
        audio2_index = np.random.randint(total_len-1)
        audio2_index = total_list[audio2_index]
    else :
        print("Same audio")
        audio2_index = audio_index
    return int(audio2_index), same

print(counter_sample(0, dataset))

Counter example
(2, 0)


In [245]:
# all_batch = multiple_batch_index_generator(batch_size=2, loop_nb=5, dataset=dataset)
# encoder.train()

# for batch_index, batch in enumerate(all_batch): 
#     all_embeddings = []
#     all_labels = [] #collate to fit largest size
#     # Generate p augmentations for each sample
#     for i in tqdm(range(p)):
#         augmented_batch = [augment_audio(w, sr) for w in wav_batch]   # list of len n
#         mel_batch = [get_mel_spec(w, sr) for w in augmented_batch]    # list of len n
#         mel_batch = torch.stack(mel_batch).cuda()  # shape: (n, mel_bins, time_frames)
#         z_batch = encoder(mel_batch)               # shape: (n, embedding_dim)

#         all_embeddings.append(z_batch)
#         all_labels.append(torch.arange(len(wav_batch)))  # labels: 0..n-1
#         print(all_labels)

#     # Concatenate all p augmentations along the batch dimension
#     all_embeddings = torch.cat(all_embeddings, dim=0)  # shape: (n*p, embedding_dim)
#     all_labels = torch.cat(all_labels, dim=0).cuda()    # shape: (n*p,)

#     # Compute contrastive loss (e.g., NT-Xent)
#     loss = contrastive_loss(all_embeddings, all_labels)

#     # Backprop
#     opt.zero_grad()
#     loss.backward()
#     opt.step()

# print(f"Batch {batch_index}: Loss = {loss.item():.4f}")

In [None]:
import torchaudio.transforms as T

: 

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)
encoder = encoder.to(device)

all_batch = multiple_batch_index_generator(batch_size=2, loop_nb=500, dataset=dataset)
encoder.train()

for batch_index, batch in enumerate(all_batch): 
    all_embeddings = []
    all_labels = []
    # Generate p augmentations for each sample
    for audio_index in batch:
        wav, sr, label = dataset[audio_index]
        counter_audio_index, is_same = counter_sample(audio_index, dataset)
        counter_wav, counter_sr, counter_label = dataset[counter_audio_index]
        augmented_counter_wav = augment_audio(wav=counter_wav, sr=counter_sr)
        mel_spec = get_mel_spec(wav, sr).unsqueeze(0).unsqueeze(0).to(device) #to add the channel dimension for the model conv2D
        counter_mel_spec = get_mel_spec(augmented_counter_wav, counter_sr).unsqueeze(0).to(device)
        z = encoder(mel_spec)
        counter_z = encoder(counter_mel_spec)

        # Compute contrastive loss (e.g., NT-Xent)
        print("is_same =",is_same)
        loss = contrastive_loss(z, counter_z, is_same)
        print("z", z.cpu().detach().numpy()[0][:4], "\ncounter_z", counter_z.cpu().detach().numpy()[0][:4])
        print(f"Batch {batch_index}: Loss = {loss:.10f}, same={is_same}")

    # Backprop
    opt.zero_grad()
    loss.backward()
    opt.step()

    

device cuda
new batch
0 added to batch
5 added to batch
new batch
1 added to batch
4 added to batch
new batch
3 added to batch
2 added to batch
new batch
4 added to batch
3 added to batch
new batch
0 added to batch
2 added to batch
new batch
5 added to batch
1 added to batch
new batch
3 added to batch
4 added to batch
new batch
5 added to batch
0 added to batch
new batch
2 added to batch
1 added to batch
new batch
4 added to batch
2 added to batch
new batch
0 added to batch
5 added to batch
new batch
1 added to batch
3 added to batch
new batch
3 added to batch
4 added to batch
new batch
1 added to batch
5 added to batch
new batch
0 added to batch
2 added to batch
new batch
0 added to batch
5 added to batch
new batch
3 added to batch
1 added to batch
new batch
2 added to batch
4 added to batch
new batch
1 added to batch
5 added to batch
new batch
0 added to batch
4 added to batch
new batch
2 added to batch
3 added to batch
new batch
3 added to batch
2 added to batch
new batch
1 added to

In [None]:
import faiss

In [None]:
# --- Build Embedding Database ---
encoder.eval()  # Set encoder to evaluation mode
embs = []  # List to store embeddings
labels = []  # List to store corresponding labels


# For each clip, compute and store its embedding
for wav, sr, label in dataset:
    m = get_mel_spec(wav, sr).unsqueeze(0).unsqueeze(0).cuda()  # Mel spectrogram, batch dimension, move to GPU
    with torch.no_grad():
        em = encoder(m).cpu().numpy()[0]  # Get embedding, move to CPU, convert to numpy
    database[label] = em  # Store in database dictionary
    embs.append(em)  # Add to embedding list
    labels.append(label)  

dim = embs[0].shape[0]  # Dimensionality of embeddings
index = faiss.IndexFlatL2(dim)  # Create a FAISS index for fast nearest neighbor search (L2 distance)
index.add(np.array(embs))  # Add all embeddings to the index


loading file n° 0
file loaded at sr=22050 Hz in 0.31972718238830566 seconds
loading file n° 1
file loaded at sr=22050 Hz in 0.30269622802734375 seconds
loading file n° 2


In [None]:
import gradio as gr

In [None]:
# --- Gradio Demo ---
def recognize(inp):
    if inp is None:
        return "Please upload an audio file"
    
    try:
        wav2, sr2 = torchaudio.load(inp)
        if sr2 != 22050:
            wav2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=22050)(wav2)
        wav2 = torch.mean(wav2, dim=0)
        m = get_mel_spec(wav2, sr2).unsqueeze(0).unsqueeze(0).cuda()
        with torch.no_grad():
            em2 = encoder(m).cpu().numpy()
        D, I = index.search(em2, 1)
        return f"Closest Match: {labels[I[0][0]]} (dist {D[0][0]:.3f})"
    except Exception as e:
        return f"Error processing audio: {str(e)}"

gr.Interface(
    fn=recognize,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Music Recognizer Demo"
).launch()

* Running on local URL:  http://127.0.0.1:7864
* To create a public link, set `share=True` in `launch()`.




