In [1]:
import kagglehub

path = kagglehub.dataset_download("andradaolteanu/gtzan-dataset-music-genre-classification")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'gtzan-dataset-music-genre-classification' dataset.
Path to dataset files: /kaggle/input/gtzan-dataset-music-genre-classification


In [2]:
import os, random, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import librosa
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

In [3]:
import soundfile as sf

In [4]:
sounds_df = pd.read_csv('/kaggle/input/gtzan-dataset-music-genre-classification/Data/features_3_sec.csv')
sounds_df

Unnamed: 0,filename,length,chroma_stft_mean,chroma_stft_var,rms_mean,rms_var,spectral_centroid_mean,spectral_centroid_var,spectral_bandwidth_mean,spectral_bandwidth_var,...,mfcc16_var,mfcc17_mean,mfcc17_var,mfcc18_mean,mfcc18_var,mfcc19_mean,mfcc19_var,mfcc20_mean,mfcc20_var,label
0,blues.00000.0.wav,66149,0.335406,0.091048,0.130405,0.003521,1773.065032,167541.630869,1972.744388,117335.771563,...,39.687145,-3.241280,36.488243,0.722209,38.099152,-5.050335,33.618073,-0.243027,43.771767,blues
1,blues.00000.1.wav,66149,0.343065,0.086147,0.112699,0.001450,1816.693777,90525.690866,2010.051501,65671.875673,...,64.748276,-6.055294,40.677654,0.159015,51.264091,-2.837699,97.030830,5.784063,59.943081,blues
2,blues.00000.2.wav,66149,0.346815,0.092243,0.132003,0.004620,1788.539719,111407.437613,2084.565132,75124.921716,...,67.336563,-1.768610,28.348579,2.378768,45.717648,-1.938424,53.050835,2.517375,33.105122,blues
3,blues.00000.3.wav,66149,0.363639,0.086856,0.132565,0.002448,1655.289045,111952.284517,1960.039988,82913.639269,...,47.739452,-3.841155,28.337118,1.218588,34.770935,-3.580352,50.836224,3.630866,32.023678,blues
4,blues.00000.4.wav,66149,0.335579,0.088129,0.143289,0.001701,1630.656199,79667.267654,1948.503884,60204.020268,...,30.336359,0.664582,45.880913,1.689446,51.363583,-3.392489,26.738789,0.536961,29.146694,blues
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9985,rock.00099.5.wav,66149,0.349126,0.080515,0.050019,0.000097,1499.083005,164266.886443,1718.707215,85931.574523,...,42.485981,-9.094270,38.326839,-4.246976,31.049839,-5.625813,48.804092,1.818823,38.966969,rock
9986,rock.00099.6.wav,66149,0.372564,0.082626,0.057897,0.000088,1847.965128,281054.935973,1906.468492,99727.037054,...,32.415203,-12.375726,66.418587,-3.081278,54.414265,-11.960546,63.452255,0.428857,18.697033,rock
9987,rock.00099.7.wav,66149,0.347481,0.089019,0.052403,0.000701,1346.157659,662956.246325,1561.859087,138762.841945,...,78.228149,-2.524483,21.778994,4.809936,25.980829,1.775686,48.582378,-0.299545,41.586990,rock
9988,rock.00099.8.wav,66149,0.387527,0.084815,0.066430,0.000320,2084.515327,203891.039161,2018.366254,22860.992562,...,28.323744,-5.363541,17.209942,6.462601,21.442928,2.354765,24.843613,0.675824,12.787750,rock


In [5]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SR = 22050

N_MELS = 128
N_FFT = 2048
HOP = 512

SEG_SEC = 3.0
SEG_SAMPLES = int(SR * SEG_SEC)

ROOT_OF_GENRES = "/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original"

In [6]:
# Индексация датасета (аудиофайлы + метки)
def list_gztan_files(root_dir):
    items = [] # (path, label_id, label_name),
    bad = []
    genres = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
    g2i = {g:i for i,g in enumerate(genres)}

    for g in genres:
        folder = os.path.join(root_dir, g)
        for fn in sorted(os.listdir(folder)):
            if not fn.lower().endswith(".wav"):
                continue

            path = os.path.join(folder, fn)

            try:
                sf.info(path)
                items.append((path, g2i[g], g))
            except Exception:
                bad.append(path)

    if bad:
        print(f"[WARN] skipped bad audio files: {len(bad)}")
        print("\n".join(bad[:10]))

    return items, g2i

In [7]:
#Split по трекам
def split_by_track(items, test_size=0.2):
    paths = [p for p, y, g in items]
    labels = [y for p, y, g in items]

    train_idx, val_idx = train_test_split(
        np.arange(len(items)),
        test_size=test_size,
        random_state=SEED,
        stratify=labels
    )
    train_items = [items[i] for i in train_idx]
    val_items = [items[i] for i in val_idx]
    return train_items, val_items


In [10]:
class GTZANDataset(Dataset):
    def __init__(self, items, augment=False):
        self.items = items
        self.augment = augment

        # (path, label, offset_sample)
        self.segments = []
        for path, y, g in items:
            # 30 сек превращаем в 10 сегментов по 3 секунды
            num_seg = 10
            for k in range(num_seg):
                offset = k * SEG_SAMPLES
                self.segments.append((path, y, offset))

    def __len__(self):
        return len(self.segments)

    def _load_segment(self, path, offset):
        wav, _ = librosa.load(path, sr=SR, mono=True)
        if len(wav) < SEG_SAMPLES:
          wav = np.pad(wav, (0, SEG_SAMPLES - len(wav)))

        wav = wav[offset:offset + SEG_SAMPLES]
        # Мы решили, что бесполезно, т.к. уже выше удостоверились в 3 секундах
        # if len(wav) < SEG_SAMPLES:
        #     wav = np.pad(wav, (0, SEG_SAMPLES - len(wav)))
        return wav

    def _wav_to_mel(self, wav):
        mel = librosa.feature.melspectrogram(
            y=wav, sr=SR, n_fft=N_FFT, hop_length=HOP, n_mels=N_MELS, power=2.0
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)
        # нормализация
        mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-6)
        return mel_db.astype(np.float32)

    def __getitem__(self, idx):
        path, y, offset = self.segments[idx]
        wav = self._load_segment(path, offset)

        # аугментация
        if self.augment:
            if random.random() < 0.3:
                wav = wav + 0.005 * np.random.randn(len(wav))
            if random.random() < 0.3:
                rate = random.uniform(0.9, 1.1)
                wav = librosa.effects.time_stretch(wav, rate=rate)
                wav = librosa.util.fix_length(wav, size=SEG_SAMPLES)

        mel = self._wav_to_mel(wav)
        x = torch.from_numpy(mel).unsqueeze(0)  # [1, n_mels, time]
        y = torch.tensor(y, dtype=torch.long)
        return x, y


In [11]:
class TemporalAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 1)

    def forward(self, h): # h: [B, T, H]
        scores = self.proj(h).squeeze(-1) # [B, T]
        alpha = torch.softmax(scores, dim=1) # [B, T]
        context = (h * alpha.unsqueeze(-1)).sum(dim=1) # [B, H]
        return context, alpha

In [12]:
class CNNRNNClassifier(nn.Module):
    def __init__(self, n_classes=10, use_attention=False):
        super().__init__()
        self.use_attention = use_attention

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d((2,2)),
        )

        # вычисляем размер фич для RNN автоматически
        with torch.no_grad():
            dummy = torch.zeros(1, 1, N_MELS, math.ceil(SEG_SAMPLES / HOP) + 1)
            z = self.cnn(dummy)
            _, C, Fp, Tp = z.shape
            rnn_in = C * Fp

        self.rnn = nn.GRU(
            input_size=rnn_in,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        rnn_out = 256

        self.dropout = nn.Dropout(0.3)

        if use_attention:
            self.attn = TemporalAttention(rnn_out)
            self.fc = nn.Linear(rnn_out, n_classes)
        else:
            self.fc = nn.Linear(rnn_out, n_classes)

    def forward(self, x):
        z = self.cnn(x) # [B, C, F', T']
        B, C, Fp, Tp = z.shape
        z = z.permute(0, 3, 1, 2).contiguous().view(B, Tp, C*Fp) # [B, T', C*F']
        h, _ = self.rnn(z) # [B, T', 256]

        if self.use_attention:
            ctx, alpha = self.attn(h)
            return self.fc(self.dropout(ctx)), alpha
        else:
            last = h[:, -1, :]
            return self.fc(self.dropout(last)), None


In [13]:
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in tqdm(loader, leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        logits, _ = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total += x.size(0)

    return total_loss/total, total_correct/total

@torch.no_grad()
def eval_model(model, loader, criterion):
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []
    for x, y in tqdm(loader, leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = model(x)
        loss = criterion(logits, y)

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total += x.size(0)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(y.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    return total_loss/total, total_correct/total, all_preds, all_labels


In [14]:
def run_experiment(train_items, val_items, use_attention=False, epochs=12, batch_size=32, lr=1e-3):
    train_ds = GTZANDataset(train_items, augment=True)
    val_ds = GTZANDataset(val_items, augment=False)

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    model = CNNRNNClassifier(n_classes=10, use_attention=use_attention).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.0
    best_state = None

    for ep in range(1, epochs+1):
        tr_loss, tr_acc = train_one_epoch(model, train_dl, optimizer, criterion)
        va_loss, va_acc, _, _ = eval_model(model, val_dl, criterion)

        if va_acc > best_val_acc:
            best_val_acc = va_acc
            best_state = {k: v.detach().cpu().clone() for k,v in model.state_dict().items()}

        print(f"epoch {ep:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f}")

    model.load_state_dict(best_state)
    va_loss, va_acc, preds, labels = eval_model(model, val_dl, criterion)
    print("\nVAL ACC:", va_acc)
    print(classification_report(labels, preds, digits=4))
    return model, va_acc


In [15]:
items, g2i = list_gztan_files(ROOT_OF_GENRES)
train_items, val_items = split_by_track(items, test_size=0.2)

print("tracks:", len(items), "train:", len(train_items), "val:", len(val_items))
print("genres:", list(g2i.keys()))

model_noattn, acc_noattn = run_experiment(train_items, val_items, use_attention=False, epochs=12)
model_attn, acc_attn = run_experiment(train_items, val_items, use_attention=True, epochs=12)

print("\n=== RESULT ===")
print("no attention:", acc_noattn)
print("with attention:", acc_attn)

[WARN] skipped bad audio files: 1
/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/jazz/jazz.00054.wav
tracks: 999 train: 799 val: 200
genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']




epoch 01 | train loss 1.7015 acc 0.3892 | val loss 1.3922 acc 0.5215




epoch 02 | train loss 1.2751 acc 0.5411 | val loss 1.1960 acc 0.5795




epoch 03 | train loss 1.0904 acc 0.6223 | val loss 1.1056 acc 0.6235




epoch 04 | train loss 0.9673 acc 0.6712 | val loss 1.0372 acc 0.6420




epoch 05 | train loss 0.8501 acc 0.7140 | val loss 0.9031 acc 0.6885




epoch 06 | train loss 0.7428 acc 0.7489 | val loss 0.8736 acc 0.6980




epoch 07 | train loss 0.6525 acc 0.7809 | val loss 0.8314 acc 0.7225




epoch 08 | train loss 0.5980 acc 0.8043 | val loss 0.9571 acc 0.6680




epoch 09 | train loss 0.5227 acc 0.8260 | val loss 0.7910 acc 0.7485




epoch 10 | train loss 0.4602 acc 0.8484 | val loss 0.7773 acc 0.7425




epoch 11 | train loss 0.4209 acc 0.8637 | val loss 0.7727 acc 0.7550




epoch 12 | train loss 0.3762 acc 0.8776 | val loss 0.8033 acc 0.7470





VAL ACC: 0.755
              precision    recall  f1-score   support

           0     0.7304    0.8400    0.7814       200
           1     0.8981    0.9700    0.9327       200
           2     0.7135    0.6350    0.6720       200
           3     0.6422    0.7000    0.6699       200
           4     0.7745    0.7900    0.7822       200
           5     0.9176    0.8350    0.8743       200
           6     0.9202    0.8650    0.8918       200
           7     0.6250    0.8500    0.7203       200
           8     0.6597    0.7850    0.7169       200
           9     0.7568    0.2800    0.4088       200

    accuracy                         0.7550      2000
   macro avg     0.7638    0.7550    0.7450      2000
weighted avg     0.7638    0.7550    0.7450      2000





epoch 01 | train loss 1.3919 acc 0.5018 | val loss 1.0406 acc 0.6210




epoch 02 | train loss 0.9258 acc 0.6822 | val loss 0.8216 acc 0.7245




epoch 03 | train loss 0.7336 acc 0.7483 | val loss 0.8284 acc 0.7375




epoch 04 | train loss 0.5974 acc 0.8039 | val loss 0.8828 acc 0.7325




epoch 05 | train loss 0.5089 acc 0.8309 | val loss 0.7332 acc 0.7835




epoch 06 | train loss 0.4469 acc 0.8511 | val loss 0.7785 acc 0.7475




epoch 07 | train loss 0.3822 acc 0.8768 | val loss 0.7226 acc 0.7740




epoch 08 | train loss 0.3332 acc 0.8887 | val loss 0.8062 acc 0.7650




epoch 09 | train loss 0.3122 acc 0.8979 | val loss 1.0630 acc 0.6880




epoch 10 | train loss 0.2675 acc 0.9101 | val loss 0.8274 acc 0.7655




epoch 11 | train loss 0.2479 acc 0.9175 | val loss 0.8670 acc 0.7690




epoch 12 | train loss 0.2002 acc 0.9339 | val loss 0.7900 acc 0.7940


                                               


VAL ACC: 0.794
              precision    recall  f1-score   support

           0     0.9040    0.8950    0.8995       200
           1     0.9037    0.9850    0.9426       200
           2     0.6295    0.8750    0.7322       200
           3     0.9036    0.3750    0.5300       200
           4     0.7739    0.8900    0.8279       200
           5     0.8428    0.9650    0.8998       200
           6     0.9827    0.8500    0.9115       200
           7     0.7048    0.8000    0.7494       200
           8     0.7574    0.7650    0.7612       200
           9     0.6667    0.5400    0.5967       200

    accuracy                         0.7940      2000
   macro avg     0.8069    0.7940    0.7851      2000
weighted avg     0.8069    0.7940    0.7851      2000


=== RESULT ===
no attention: 0.755
with attention: 0.794


