<a href="https://colab.research.google.com/github/DavidToth23/music_instrument_classification/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
# Colab-only step: files.upload() prompts for local file(s) and saves them into the
# runtime's current working directory (here: the NSynth test-split archive).
from google.colab import files
uploaded = files.upload()

Saving nsynth-test.jsonwav.tar.gz to nsynth-test.jsonwav.tar.gz


In [9]:
import tarfile

# Unpack the uploaded NSynth archive into ./nsynth-test.
with tarfile.open("nsynth-test.jsonwav.tar.gz", "r:gz") as tar:
    # A bare extractall() trusts every member path stored in the archive and emits a
    # DeprecationWarning on Python 3.12+. The "data" filter rejects absolute paths,
    # ".." traversal and special files. Extraction filters only exist on interpreters
    # that ship them, so fall back to the legacy call when they are unavailable.
    if hasattr(tarfile, "data_filter"):
        tar.extractall("nsynth-test", filter="data")
    else:
        tar.extractall("nsynth-test")

  tar.extractall("nsynth-test")


In [10]:
import os

# Sanity check: show (at most) the first ten entries of the extraction directory.
first_entries = os.listdir("nsynth-test")[:10]
print(first_entries)


['nsynth-test']


In [12]:
import json, os
from pathlib import Path
import pandas as pd

# Root of the extracted split (or "nsynth-valid" if you use the validation split).
SPLIT_DIR = Path("nsynth-test") / "nsynth-test"
META_PATH = SPLIT_DIR.joinpath("examples.json")

# Instrument family lookup (family_id -> name); the list position is the id.
FAMILY_MAP = dict(enumerate([
    "bass",
    "brass",
    "flute",
    "guitar",
    "keyboard",
    "mallet",
    "organ",
    "reed",
    "string",
    "synth_lead",
    "vocal",
]))

# Load the per-example metadata: a JSON object keyed by example id.
with open(META_PATH, "r", encoding="utf-8") as f:
    meta = json.load(f)

# Index every wav below audio/ once: file name -> path relative to SPLIT_DIR.
# A per-example rglob() would re-walk the whole directory for each metadata entry
# that lacks 'audio_path', which is quadratic in the size of the split.
wav_index = {}
for wav_file in (SPLIT_DIR / "audio").rglob("*.wav"):
    # setdefault keeps the first hit for a given file name
    wav_index.setdefault(wav_file.name, wav_file.relative_to(SPLIT_DIR).as_posix())

rows = []
for key, m in meta.items():
    # Prefer an explicit 'audio_path' when the metadata has one; otherwise fall back
    # to locating the wav by example id.
    rel = m.get("audio_path") or wav_index.get(f"{key}.wav")
    if not rel:
        continue  # no audio file for this example: skip it

    family_id = int(m["instrument_family"])
    rows.append({
        "id": key,
        "wav": str((SPLIT_DIR / rel).resolve()),
        "family_id": family_id,
        "family": FAMILY_MAP[family_id],
        "pitch": int(m["pitch"]),
        "velocity": int(m["velocity"])
    })

# Fail with a readable message instead of a KeyError on the missing 'family' column.
if not rows:
    raise FileNotFoundError(f"No usable examples found under {SPLIT_DIR}")

df = pd.DataFrame(rows)
df.head(), df["family"].value_counts().sort_index()


(                                id  \
 0       bass_synthetic_068-049-025   
 1  keyboard_electronic_001-021-127   
 2      guitar_acoustic_010-066-100   
 3        reed_acoustic_037-068-127   
 4       flute_acoustic_002-077-100   
 
                                                  wav  family_id    family  \
 0  /content/nsynth-test/nsynth-test/audio/bass_sy...          0      bass   
 1  /content/nsynth-test/nsynth-test/audio/keyboar...          4  keyboard   
 2  /content/nsynth-test/nsynth-test/audio/guitar_...          3    guitar   
 3  /content/nsynth-test/nsynth-test/audio/reed_ac...          7      reed   
 4  /content/nsynth-test/nsynth-test/audio/flute_a...          2     flute   
 
    pitch  velocity  
 0     49        25  
 1     21       127  
 2     66       100  
 3     68       127  
 4     77       100  ,
 family
 bass        843
 brass       269
 flute       180
 guitar      652
 keyboard    766
 mallet      202
 organ       502
 reed        235
 string      306


In [13]:
# How many examples to keep per family (tune as needed).
K = 300   # at most ~ 11*300 ≈ 3300 samples; families with fewer than K keep all rows

# Draw up to K rows from each family with a fixed seed. Iterating over the groups
# (rather than groupby.apply) keeps the 'family_id' column in the result and avoids
# the pandas deprecation of apply() operating on the grouping columns.
per_family = [
    group.sample(min(K, len(group)), random_state=42)
    for _, group in df.groupby("family_id")
]
mini = pd.concat(per_family).reset_index(drop=True)

mini.to_csv("nsynth_mini.csv", index=False)

# Kept as the cell's last expression so the per-family counts are actually displayed.
mini["family"].value_counts().sort_index()


  .apply(lambda g: g.sample(min(K, len(g)), random_state=42))


In [14]:
!pip -q install torch torchaudio librosa soundfile --upgrade

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import librosa

# audio -> mel config (module-level constants read by NSynthMelDataset.__getitem__)
SR = 16000  # sample rate passed as sr= to librosa.load and melspectrogram
N_FFT = 1024  # passed as n_fft= to melspectrogram
HOP = 256  # passed as hop_length= to melspectrogram
N_MELS = 64  # passed as n_mels=; becomes the first axis of the spectrogram
FMIN = 20  # passed as fmin= to melspectrogram
FMAX = 8000  # passed as fmax= to melspectrogram; equals SR / 2

# Class vocabulary: the distinct family names in the subsample, sorted, plus the
# forward (name -> index) and reverse (index -> name) lookups.
LABELS = sorted(mini["family"].unique())
label2idx = dict(zip(LABELS, range(len(LABELS))))
idx2label = dict(enumerate(LABELS))

class NSynthMelDataset(Dataset):
    """Dataset yielding (standardized log-mel spectrogram, family label) pairs.

    Every row of ``table`` must provide a ``"wav"`` file path and a ``"family"``
    name that is a key of the module-level ``label2idx``. Audio is decoded and
    transformed on each access using the module-level mel configuration
    (``SR``, ``N_FFT``, ``HOP``, ``N_MELS``, ``FMIN``, ``FMAX``).

    Args:
        table: DataFrame of examples; its index is reset so positions
            ``0..len(table)-1`` are valid.
        augment: When True, apply a small random time shift to the waveform
            before the spectrogram is computed.
    """

    # Largest time shift, in seconds and in either direction, used when augment=True.
    MAX_SHIFT_S = 0.1

    def __init__(self, table, augment=False):
        self.table = table.reset_index(drop=True)
        self.augment = augment

    def __len__(self):
        """Return the number of examples in the table."""
        return len(self.table)

    @staticmethod
    def _time_shift(y, shift):
        """Shift a 1-D signal by ``shift`` samples, zero-filling the gap.

        Positive values delay the signal, negative values advance it. The output
        has the same length and dtype as ``y``.
        """
        if shift == 0:
            return y
        shifted = np.zeros_like(y)
        if shift > 0:
            shifted[shift:] = y[:-shift]
        else:
            shifted[:shift] = y[-shift:]
        return shifted

    def __getitem__(self, i):
        """Return ``(x, y_lbl)`` for row ``i``.

        ``x`` is a tensor of shape ``[1, n_mels, time]`` holding the per-sample
        standardized log-mel spectrogram; ``y_lbl`` is a scalar long tensor with
        the class index of the row's family.
        """
        row = self.table.iloc[i]
        y, _ = librosa.load(row["wav"], sr=SR, mono=True)

        # optional light augmentation (keep it mild to start with)
        if self.augment:
            # A random gain is not useful here: power_to_db(ref=np.max) followed by
            # per-sample standardization cancels any constant scale applied to y.
            # A small time shift survives that normalization.
            max_shift = int(self.MAX_SHIFT_S * SR)
            shift = np.random.randint(-max_shift, max_shift + 1)
            y = self._time_shift(y, shift)

        # log-mel
        S = librosa.feature.melspectrogram(
            y=y, sr=SR, n_fft=N_FFT, hop_length=HOP,
            n_mels=N_MELS, fmin=FMIN, fmax=FMAX
        )
        S_db = librosa.power_to_db(S, ref=np.max).astype(np.float32)  # [n_mels, time]

        # per-sample standardization; the epsilon guards against a zero std
        mu, sigma = S_db.mean(), S_db.std() + 1e-6
        S_norm = (S_db - mu) / sigma

        # PyTorch expects [C, H, W]
        x = torch.from_numpy(S_norm).unsqueeze(0)  # [1, n_mels, time]
        y_lbl = torch.tensor(label2idx[row["family"]], dtype=torch.long)
        return x, y_lbl
