# This is the dev workbook for music detection

## RUN ONLY ONCE

In [1]:
# !pip install -r ../requirements.txt -v

## Dataset

### Imports

In [2]:
import torch
from torch.utils.data import Dataset
import torchaudio
from pathlib import Path
from torch.utils.data import DataLoader
import os
import time

### AudioDataset class

In [3]:
class AudioDataset(Dataset):
    def __init__(self, folder_path, target_sr=22050):
        self.folder_path = folder_path
        self.target_sr = target_sr
        self.file_paths = self.get_all_filepaths()

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        init_time = time.time()
        print("loading file n°", idx)
        file_path = self.file_paths[idx]
        wav, sr = torchaudio.load(file_path)  # wav shape: (channels, time)
        wav = torch.mean(wav, dim=0)  # mono
        if sr != self.target_sr:
            wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sr)(wav)
            sr = self.target_sr
        print(f"file loaded at sr={sr} Hz in", time.time() - init_time, "seconds")
        return wav, sr, os.path.basename(file_path)

    def get_all_filepaths(self):
        folder = Path(self.folder_path)
        file_paths = [str(p) for p in folder.rglob("*")]
        return file_paths

    def get_time_length(self, wav):
        return len(wav[0]) / self.target_sr

### test dataset

In [4]:
dataset = AudioDataset(folder_path = "../../../Music/full_test")

mus1 = dataset[0]
mus2 = dataset[1]
print(mus2[0], len(mus2[0]))
print(mus2[1])
print(mus2[2])
print(mus1[0], len(mus1[0]))
print(mus1[1])
print(mus1[2])

loading file n° 0




file loaded at sr=22050 Hz in 0.5561816692352295 seconds
loading file n° 1
file loaded at sr=22050 Hz in 0.5975611209869385 seconds
tensor([ 5.4581e-14,  2.0858e-14, -8.9946e-14,  ...,  0.0000e+00,
         0.0000e+00,  0.0000e+00]) 7093440
22050
20050904_56110_kson_kda_14.mp3
tensor([-2.1066e-13, -3.4545e-13,  3.6385e-13,  ...,  0.0000e+00,
         0.0000e+00,  0.0000e+00]) 4858560
22050
20040905_56110_kson_kme_04.mp3


## Augmenter

In [5]:
import numpy as np
import librosa
import torch
import random
import torchaudio
from dataset import AudioDataset
from utils import save_wav

In [6]:
def normalize(wav):
    # Normalize to [-1, 1]
    max_val = wav.abs().max()
    if max_val > 0:
        wav = wav / max_val
    else:
        raise Exception("Audio is empty ...")
    return wav

class RandomClip:
    def __init__(self, sample_rate, clip_length):
        self.clip_length = sample_rate * clip_length
        self.sr = sample_rate
        self.vad = torchaudio.transforms.Vad(
            sample_rate=sample_rate, trigger_level=7.0)

    def __call__(self, audio_data):
        audio_length = audio_data.shape[0]
        print(audio_length)
        print(len(audio_data))
        print(audio_data)
        print(len(audio_data))
        if audio_length > self.clip_length:
            offset = random.randint(0, audio_length-self.clip_length)
            print(f"Audio cut between {offset/self.sr}s and, {(offset+self.clip_length)/self.sr}s")
            audio_data = audio_data[offset:(offset+self.clip_length)]
        else :
            raise Exception("Audio shorter than clip ...")
        return audio_data # remove silences at the beggining/end  self.vad()

def augment_audio(wav, sr, test = False): # TO DO implement noise adding
    if test:
        save_wav(path="tmp/test_augmenter/raw.wav", wav=wav, sr=sr)

    arr = wav.numpy().squeeze()
    n_steps=random.choice([-2, -1, 0, 1, 2])
    print(f"pitch schift of {n_steps} steps")
    arr = librosa.effects.pitch_shift(y=arr, sr=sr, n_steps=n_steps)
    if test:
        save_wav(path="tmp/test_augmenter/pitch.wav", wav=torch.tensor(arr).unsqueeze(0), sr=sr)


    rate = random.uniform(0.9, 1.1)
    print(f"time stretch rate of {rate}")
    arr = librosa.effects.time_stretch(y=arr, rate=rate)
    wav = torch.tensor(arr) #.unsqueeze(0)
    if test:
        save_wav(path="tmp/test_augmenter/stretch.wav", wav=wav, sr=sr)

    clip_transform = RandomClip(sr, clip_length = 7)
    len(wav)
    clipped_audio = clip_transform(wav)
    if test:
        save_wav(path="tmp/test_augmenter/clipped.wav", wav=clipped_audio, sr=sr)

    return clipped_audio.unsqueeze(0)

In [29]:
dataset = AudioDataset(folder_path = "../../../Music/full_test")
test_audio = dataset[5]
print("label",test_audio[2])
wav = augment_audio(wav=test_audio[0], sr=test_audio[1], test = True)
len(wav)

loading file n° 5
file loaded at sr=22050 Hz in 1.4726614952087402 seconds
label 20170903_56110_kson_kme_04.mp3
Successfully saved tmp/test_augmenter/raw.wav with shape (5918400,) at 22050Hz
pitch schift of 2 steps
Successfully saved tmp/test_augmenter/pitch.wav with shape (5918400,) at 22050Hz
time stretch rate of 0.96150071513127
Successfully saved tmp/test_augmenter/stretch.wav with shape (6155378,) at 22050Hz
6155378
6155378
tensor([-2.8399e-09, -3.6703e-08, -4.3963e-08,  ...,  7.1318e-05,
         7.5353e-05,  8.3140e-05])
6155378
Audio cut between 237.4568253968254s and, 244.4568253968254s
Successfully saved tmp/test_augmenter/clipped.wav with shape (154350,) at 22050Hz


1

## Model and training

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
class MusicEncoder(nn.Module):
    def __init__(self, in_channels=1, out_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        # works for any spectrogram shape
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, out_dim)

    def forward(self, x):
        x = self.conv(x)                # (B, 128, H’, W’)
        x = self.pool(x)                # (B, 128, 1, 1)
        x = x.view(x.size(0), -1)       # (B, 128)
        x = self.fc(x)                  # (B, out_dim)
        return F.normalize(x, dim=-1)   # unit vector embeddings


In [10]:
def contrastive_loss(z1, z2, same_label: bool, margin=1.0):
    dist = F.pairwise_distance(z1, z2)
    if same_label:
        return torch.mean(dist ** 2)  # positive → bring closer
    else:
        return torch.mean(F.relu(margin - dist) ** 2)  # negative → push apart

In [11]:
def get_mel_spec(wav, sr):
    spec = T.MelSpectrogram(sr, n_mels=64)(wav)
    return T.AmplitudeToDB()(spec)

In [12]:
encoder = MusicEncoder().cuda()  # Instantiate the audio encoder and move it to the GPU
opt = torch.optim.Adam(encoder.parameters(), 1e-3)  # Adam optimizer for training the encoder
database = {}  # Dictionary to store embeddings for each audio clip


In [13]:
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129


In [14]:
len(dataset)

6

In [15]:
def complete_batch_index_generator(batch_size, dataset):
    batches = []
    total_len = len(dataset)
    total_list = list(np.arange(total_len))
    while total_len > 0 :
        print("new batch")
        batch = []
        cpt = 0
        start_len = total_len
        while cpt < min(batch_size, start_len):
            index = np.random.randint(total_len)
            number = total_list.pop(index)
            batch.append(int(number))
            print(number, "added to batch")
            total_len -= 1
            cpt += 1
        batches.append(batch)
    return batches

print(complete_batch_index_generator(batch_size=1, dataset=dataset))
print(complete_batch_index_generator(batch_size=2, dataset=dataset))

new batch
2 added to batch
new batch
0 added to batch
new batch
3 added to batch
new batch
4 added to batch
new batch
1 added to batch
new batch
5 added to batch
[[2], [0], [3], [4], [1], [5]]
new batch
4 added to batch
5 added to batch
new batch
0 added to batch
3 added to batch
new batch
1 added to batch
2 added to batch
[[4, 5], [0, 3], [1, 2]]


In [16]:
def multiple_batch_index_generator(batch_size, loop_nb, dataset):
    batch_loop = []
    for i in range(loop_nb):
        batches = complete_batch_index_generator(batch_size, dataset)
        batch_loop += batches
    return batch_loop

batches = multiple_batch_index_generator(batch_size=2, loop_nb=5, dataset=dataset)
print(batches)


new batch
5 added to batch
2 added to batch
new batch
1 added to batch
3 added to batch
new batch
4 added to batch
0 added to batch
new batch
1 added to batch
2 added to batch
new batch
4 added to batch
3 added to batch
new batch
5 added to batch
0 added to batch
new batch
5 added to batch
3 added to batch
new batch
4 added to batch
1 added to batch
new batch
0 added to batch
2 added to batch
new batch
1 added to batch
4 added to batch
new batch
2 added to batch
3 added to batch
new batch
0 added to batch
5 added to batch
new batch
3 added to batch
2 added to batch
new batch
0 added to batch
4 added to batch
new batch
5 added to batch
1 added to batch
[[5, 2], [1, 3], [4, 0], [1, 2], [4, 3], [5, 0], [5, 3], [4, 1], [0, 2], [1, 4], [2, 3], [0, 5], [3, 2], [0, 4], [5, 1]]


In [17]:
def counter_sample(audio_index, dataset):
    total_len = len(dataset)
    total_list = list(np.arange(total_len))
    same = np.random.randint(2)
    if not same:
        print("Counter example")
        total_list.pop(audio_index)
        audio2_index = np.random.randint(total_len-1)
        audio2_index = total_list[audio2_index]
    else :
        print("Same audio")
        audio2_index = audio_index
    return int(audio2_index), same

print(counter_sample(0, dataset))

Counter example
(1, 0)


In [18]:
# all_batch = multiple_batch_index_generator(batch_size=2, loop_nb=5, dataset=dataset)
# encoder.train()

# for batch_index, batch in enumerate(all_batch): 
#     all_embeddings = []
#     all_labels = [] #collate to fit largest size
#     # Generate p augmentations for each sample
#     for i in tqdm(range(p)):
#         augmented_batch = [augment_audio(w, sr) for w in wav_batch]   # list of len n
#         mel_batch = [get_mel_spec(w, sr) for w in augmented_batch]    # list of len n
#         mel_batch = torch.stack(mel_batch).cuda()  # shape: (n, mel_bins, time_frames)
#         z_batch = encoder(mel_batch)               # shape: (n, embedding_dim)

#         all_embeddings.append(z_batch)
#         all_labels.append(torch.arange(len(wav_batch)))  # labels: 0..n-1
#         print(all_labels)

#     # Concatenate all p augmentations along the batch dimension
#     all_embeddings = torch.cat(all_embeddings, dim=0)  # shape: (n*p, embedding_dim)
#     all_labels = torch.cat(all_labels, dim=0).cuda()    # shape: (n*p,)

#     # Compute contrastive loss (e.g., NT-Xent)
#     loss = contrastive_loss(all_embeddings, all_labels)

#     # Backprop
#     opt.zero_grad()
#     loss.backward()
#     opt.step()

# print(f"Batch {batch_index}: Loss = {loss.item():.4f}")

In [19]:
import torchaudio.transforms as T

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)
encoder = encoder.to(device)

all_batch = multiple_batch_index_generator(batch_size=2, loop_nb=500, dataset=dataset)
encoder.train()

for batch_index, batch in enumerate(all_batch): 
    all_embeddings = []
    all_labels = []
    # Generate p augmentations for each sample
    for audio_index in batch:
        wav, sr, label = dataset[audio_index]
        counter_audio_index, is_same = counter_sample(audio_index, dataset)
        counter_wav, counter_sr, counter_label = dataset[counter_audio_index]
        augmented_counter_wav = augment_audio(wav=counter_wav, sr=counter_sr)
        mel_spec = get_mel_spec(wav, sr).unsqueeze(0).unsqueeze(0).to(device) #to add the channel dimension for the model conv2D
        counter_mel_spec = get_mel_spec(augmented_counter_wav, counter_sr).unsqueeze(0).to(device)
        z = encoder(mel_spec)
        counter_z = encoder(counter_mel_spec)

        # Compute contrastive loss (e.g., NT-Xent)
        print("is_same =",is_same)
        loss = contrastive_loss(z, counter_z, is_same)
        print("z", z.cpu().detach().numpy()[0][:4], "\ncounter_z", counter_z.cpu().detach().numpy()[0][:4])
        print(f"Batch {batch_index}: Loss = {loss:.10f}, same={is_same}")

    # Backprop
    opt.zero_grad()
    loss.backward()
    opt.step()

    

device cuda
new batch
1 added to batch
4 added to batch
new batch
3 added to batch
5 added to batch
new batch
0 added to batch
2 added to batch
new batch
3 added to batch
2 added to batch
new batch
4 added to batch
1 added to batch
new batch
5 added to batch
0 added to batch
new batch
0 added to batch
3 added to batch
new batch
1 added to batch
4 added to batch
new batch
2 added to batch
5 added to batch
new batch
1 added to batch
3 added to batch
new batch
0 added to batch
4 added to batch
new batch
2 added to batch
5 added to batch
new batch
4 added to batch
1 added to batch
new batch
2 added to batch
5 added to batch
new batch
0 added to batch
3 added to batch
new batch
4 added to batch
3 added to batch
new batch
1 added to batch
2 added to batch
new batch
5 added to batch
0 added to batch
new batch
3 added to batch
1 added to batch
new batch
5 added to batch
4 added to batch
new batch
0 added to batch
2 added to batch
new batch
4 added to batch
2 added to batch
new batch
0 added to



file loaded at sr=22050 Hz in 0.5494625568389893 seconds
Counter example
loading file n° 3
file loaded at sr=22050 Hz in 0.28600072860717773 seconds
pitch schift of -1 steps
time stretch rate of 1.0900163482584122
3664151
3664151
tensor([ 9.1244e-05,  9.6710e-05,  5.6769e-05,  ..., -5.1553e-03,
        -4.9675e-03, -4.4932e-03])
3664151
Audio cut between 150.60721088435375s and, 157.60721088435375s
is_same = 0
z [-0.15838465 -0.15756871  0.01529315  0.06343004] 
counter_z [-0.15925588 -0.16053553  0.01414535  0.06552351]
Batch 0: Loss = 0.9489345551, same=0
loading file n° 4
file loaded at sr=22050 Hz in 0.3549995422363281 seconds
Same audio
loading file n° 4
file loaded at sr=22050 Hz in 0.2409989833831787 seconds
pitch schift of 2 steps
time stretch rate of 1.055139210165941
3277035
3277035
tensor([0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.7338e-04, 1.6639e-04,
        3.8720e-05])
3277035
Audio cut between 34.65859410430839s and, 41.65859410430839s
is_same = 1
z [-0.15643731 -0.16

In [21]:
import faiss

In [22]:
# --- Build Embedding Database ---
encoder.eval()  # Set encoder to evaluation mode
embs = []  # List to store embeddings
labels = []  # List to store corresponding labels


# For each clip, compute and store its embedding
for wav, sr, label in dataset:
    m = get_mel_spec(wav, sr).unsqueeze(0).unsqueeze(0).cuda()  # Mel spectrogram, batch dimension, move to GPU
    with torch.no_grad():
        em = encoder(m).cpu().numpy()[0]  # Get embedding, move to CPU, convert to numpy
    database[label] = em  # Store in database dictionary
    embs.append(em)  # Add to embedding list
    labels.append(label)  

dim = embs[0].shape[0]  # Dimensionality of embeddings
index = faiss.IndexFlatL2(dim)  # Create a FAISS index for fast nearest neighbor search (L2 distance)
index.add(np.array(embs))  # Add all embeddings to the index


loading file n° 0
file loaded at sr=22050 Hz in 0.3297698497772217 seconds
loading file n° 1
file loaded at sr=22050 Hz in 0.43454694747924805 seconds
loading file n° 2
file loaded at sr=22050 Hz in 0.2452101707458496 seconds
loading file n° 3
file loaded at sr=22050 Hz in 0.24305248260498047 seconds
loading file n° 4
file loaded at sr=22050 Hz in 0.2157001495361328 seconds
loading file n° 5
file loaded at sr=22050 Hz in 0.3741321563720703 seconds
loading file n° 6


In [23]:
import gradio as gr

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
# --- Gradio Demo ---
def recognize(inp):
    if inp is None:
        return "Please upload an audio file"
    
    try:
        wav2, sr2 = torchaudio.load(inp)
        if sr2 != 22050:
            wav2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=22050)(wav2)
        wav2 = torch.mean(wav2, dim=0)
        m = get_mel_spec(wav2, sr2).unsqueeze(0).unsqueeze(0).cuda()
        with torch.no_grad():
            em2 = encoder(m).cpu().numpy()
        D, I = index.search(em2, 1)
        return f"Closest Match: {labels[I[0][0]]} (dist {D[0][0]:.3f})"
    except Exception as e:
        return f"Error processing audio: {str(e)}"

gr.Interface(
    fn=recognize,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Music Recognizer Demo"
).launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.






In [31]:
torch.save(encoder.state_dict(), "../saved_weights/first_long_try")

In [None]:
# model = TheModelClass(*args, **kwargs)
# model.load_state_dict(torch.load(PATH, weights_only=True))
# model.eval()