<a href="https://colab.research.google.com/github/DavidToth23/music_instrument_classification/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import files
uploaded = files.upload()

Saving nsynth-test.jsonwav.tar.gz to nsynth-test.jsonwav.tar.gz


In [3]:
import tarfile

with tarfile.open("nsynth-test.jsonwav.tar.gz", "r:gz") as tar:
    tar.extractall("nsynth-test")

  tar.extractall("nsynth-test")


In [4]:
import os

print(os.listdir("nsynth-test")[:10])


['nsynth-test']


In [6]:
import json, os
from pathlib import Path
import pandas as pd

SPLIT_DIR = Path("nsynth-test/nsynth-test")  # ili "nsynth-valid" ako koristiš valid
META_PATH = SPLIT_DIR / "examples.json"

# map porodica (family_id -> ime)
FAMILY_MAP = {
    0: "bass", 1: "brass", 2: "flute", 3: "guitar", 4: "keyboard",
    5: "mallet", 6: "organ", 7: "reed", 8: "string", 9: "synth_lead", 10: "vocal"
}

with open(META_PATH, "r") as f:
    meta = json.load(f)

rows = []
for key, m in meta.items():
    # u jsonwav verziji obično postoji 'audio_path'
    rel = m.get("audio_path")
    if not rel:
        # fallback: probaj da nađeš wav po id-u
        cand = list((SPLIT_DIR / "audio").rglob(f"{key}.wav"))
        if not cand:
            continue
        rel = cand[0].relative_to(SPLIT_DIR).as_posix()

    rows.append({
        "id": key,
        "wav": str((SPLIT_DIR / rel).resolve()),
        "family_id": int(m["instrument_family"]),
        "family": FAMILY_MAP[int(m["instrument_family"])],
        "pitch": int(m["pitch"]),
        "velocity": int(m["velocity"])
    })

df = pd.DataFrame(rows)
df.head(), df["family"].value_counts().sort_index()


(                                id  \
 0       bass_synthetic_068-049-025   
 1  keyboard_electronic_001-021-127   
 2      guitar_acoustic_010-066-100   
 3        reed_acoustic_037-068-127   
 4       flute_acoustic_002-077-100   
 
                                                  wav  family_id    family  \
 0  /content/nsynth-test/nsynth-test/audio/bass_sy...          0      bass   
 1  /content/nsynth-test/nsynth-test/audio/keyboar...          4  keyboard   
 2  /content/nsynth-test/nsynth-test/audio/guitar_...          3    guitar   
 3  /content/nsynth-test/nsynth-test/audio/reed_ac...          7      reed   
 4  /content/nsynth-test/nsynth-test/audio/flute_a...          2     flute   
 
    pitch  velocity  
 0     49        25  
 1     21       127  
 2     66       100  
 3     68       127  
 4     77       100  ,
 family
 bass        843
 brass       269
 flute       180
 guitar      652
 keyboard    766
 mallet      202
 organ       502
 reed        235
 string      306


In [7]:
# koliko primera po porodici (podesi po želji)
K = 300   # ~ 11*300 ≈ 3300 uzoraka (ako ih ima dovoljno u test splitu)
mini = (df.groupby("family_id", group_keys=False)
          .apply(lambda g: g.sample(min(K, len(g)), random_state=42))
          .reset_index(drop=True))

mini["family"].value_counts().sort_index()
mini.to_csv("nsynth_mini.csv", index=False)


  .apply(lambda g: g.sample(min(K, len(g)), random_state=42))


In [8]:
!pip -q install torch torchaudio librosa soundfile --upgrade

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import librosa

# audio -> mel config
SR = 16000
N_FFT = 1024
HOP = 256
N_MELS = 64
FMIN = 20
FMAX = 8000

LABELS = sorted(mini["family"].unique())
label2idx = {l:i for i,l in enumerate(LABELS)}
idx2label = {i:l for l,i in label2idx.items()}

class NSynthMelDataset(Dataset):
    def __init__(self, table, augment=False):
        self.table = table.reset_index(drop=True)
        self.augment = augment

    def __len__(self):
        return len(self.table)

    def __getitem__(self, i):
        row = self.table.iloc[i]
        wav_path = row["wav"]
        y, sr = librosa.load(wav_path, sr=SR, mono=True)
        # opcione sitne augmentacije (za start ne preteruj)
        if self.augment:
            # random gain ±3 dB
            gain = 10**(np.random.uniform(-3,3)/20)
            y = y * gain

        # log-mel
        S = librosa.feature.melspectrogram(
            y=y, sr=SR, n_fft=N_FFT, hop_length=HOP,
            n_mels=N_MELS, fmin=FMIN, fmax=FMAX
        )
        S_db = librosa.power_to_db(S, ref=np.max).astype(np.float32)  # [n_mels, time]
        # standardizacija po-sample
        mu, sigma = S_db.mean(), S_db.std() + 1e-6
        S_norm = (S_db - mu) / sigma
        # PyTorch očekuje [C, H, W]
        x = torch.from_numpy(S_norm).unsqueeze(0)  # [1, n_mels, time]
        y_lbl = torch.tensor(label2idx[row["family"]], dtype=torch.long)
        return x, y_lbl


In [9]:
from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(mini, test_size=0.2, stratify=mini["family"], random_state=42)
train_ds = NSynthMelDataset(train_df, augment=True)
val_ds   = NSynthMelDataset(val_df, augment=False)

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
val_dl   = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

len(train_ds), len(val_ds), LABELS


(2021,
 506,
 ['bass',
  'brass',
  'flute',
  'guitar',
  'keyboard',
  'mallet',
  'organ',
  'reed',
  'string',
  'vocal'])

In [10]:
import torch.nn as nn
import torch.nn.functional as F

class SmallCNN(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # [1,64,T] -> [16,32,T/2]
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> [32,16,T/4]
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> [64,8,T/8]
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((1,1)) # -> [128,1,1]
        )
        self.fc = nn.Linear(128, n_classes)

    def forward(self, x):
        x = self.net(x)
        x = x.flatten(1)
        return self.fc(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SmallCNN(n_classes=len(LABELS)).to(device)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


In [11]:
from tqdm.auto import tqdm

def run_epoch(dl, train=True):
    model.train(train)
    total, correct, loss_sum = 0, 0, 0.0
    for xb, yb in tqdm(dl, leave=False):
        xb, yb = xb.to(device), yb.to(device)
        if train:
            opt.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        if train:
            loss.backward()
            opt.step()
        loss_sum += loss.item() * yb.size(0)
        preds = logits.argmax(1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
    return loss_sum/total, correct/total

EPOCHS = 5
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = run_epoch(train_dl, train=True)
    va_loss, va_acc = run_epoch(val_dl, train=False)
    print(f"Ep {ep}: train loss {tr_loss:.3f} acc {tr_acc:.3f} | val loss {va_loss:.3f} acc {va_acc:.3f}")

# sačuvaj model i label map
torch.save({"state_dict": model.state_dict(), "labels": LABELS}, "nsynth_cnn_baseline.pt")




  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700> 
 Traceback (most recent call last):
    File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^self._shutdown_workers()
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^ ^^  ^Exception ignored in:  
<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 
Traceback (most recent

Ep 1: train loss 2.123 acc 0.187 | val loss 1.843 acc 0.277


  0%|          | 0/64 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
 Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700> 
  Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      self._shutdown_workers()^^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^^^ ^ ^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process' 
      ^ ^  ^ ^ ^^Exception ignored in: 

  0%|          | 0/16 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>        
self._shutdown_workers()Traceback (most recent call last):

self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():self._sh

Ep 2: train loss 1.718 acc 0.351 | val loss 1.496 acc 0.462


  0%|          | 0/64 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive(): 
            ^^ ^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._par

  0%|          | 0/16 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
if w.is_alive(): 
          ^ ^  ^^^^^^^^^^^^^^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>

Exception ignored in: 
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160,

Ep 3: train loss 1.485 acc 0.444 | val loss 1.409 acc 0.482


  0%|          | 0/64 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()    
self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():    
if w.is_alive():
             Exception ignored in:  ^^^<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>^
^^^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/datalo

  0%|          | 0/16 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
self._shutdown_workers()    if w.is_alive():

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
        ^ ^  ^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>^
^^Exception ignored in: Traceback (most recent call last):
^<function _MultiProcessingDataLoaderIter.__del_

Ep 4: train loss 1.335 acc 0.500 | val loss 1.233 acc 0.561


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    

  0%|          | 0/64 [00:00<?, ?it/s]

self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>if w.is_alive():
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    
self._shutdown_workers()
    File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive(): 
^ ^ ^  ^ ^^ ^^ ^^^^^^^^^^Exception ignored in: ^
^<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive

^Exception ignored in: Traceback (most recent call last):
    <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^
assert self.

  0%|          | 0/16 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    
self._shutdown_workers()Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        if w.is_alive():self._shutdown_workers()

   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive():
          Exception ignored in: ^ <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>^^
^
^Traceback (most re

Ep 5: train loss 1.155 acc 0.569 | val loss 1.004 acc 0.668


In [12]:
import itertools
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

# skupi sve predviđanja na validation
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for xb, yb in val_dl:
        xb = xb.to(device)
        logits = model(xb)
        preds = logits.argmax(1).cpu().numpy()
        y_pred.extend(preds.tolist())
        y_true.extend(yb.numpy().tolist())

print(classification_report(y_true, y_pred, target_names=LABELS))
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(LABELS))))
cm


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x79ed203f2700>    
Traceback (most recent call last):
if w.is_alive():  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

     self._shutdown_workers() 
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
  ^ ^  ^ ^  ^^ ^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^
    File "/usr/lib/pyt

              precision    recall  f1-score   support

        bass       0.82      0.75      0.78        60
       brass       0.46      1.00      0.63        54
       flute       0.95      0.56      0.70        36
      guitar       0.64      0.68      0.66        60
    keyboard       0.56      0.30      0.39        60
      mallet       0.83      0.73      0.78        41
       organ       0.85      0.75      0.80        60
        reed       0.79      0.47      0.59        47
      string       0.59      0.77      0.67        60
       vocal       0.81      0.61      0.69        28

    accuracy                           0.67       506
   macro avg       0.73      0.66      0.67       506
weighted avg       0.71      0.67      0.66       506



array([[45,  3,  0,  2,  1,  0,  1,  0,  8,  0],
       [ 0, 54,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  6, 20,  4,  0,  0,  1,  1,  4,  0],
       [ 2,  3,  0, 41,  7,  1,  0,  1,  5,  0],
       [ 0,  7,  0, 12, 18,  5,  4,  3, 11,  0],
       [ 3,  0,  0,  2,  5, 30,  1,  0,  0,  0],
       [ 0,  7,  0,  2,  0,  0, 45,  1,  1,  4],
       [ 0, 22,  0,  0,  1,  0,  0, 22,  2,  0],
       [ 5,  7,  1,  1,  0,  0,  0,  0, 46,  0],
       [ 0,  9,  0,  0,  0,  0,  1,  0,  1, 17]])