In [None]:
pip install --upgrade kagglehub



In [None]:
import kagglehub

# Fetch GTZAN through kagglehub (a cached copy is reused when one is available)
# and keep the local dataset root in `path`.
path = kagglehub.dataset_download(
    "andradaolteanu/gtzan-dataset-music-genre-classification"
)
print("Path to dataset files:", path)

Using Colab cache for faster access to the 'gtzan-dataset-music-genre-classification' dataset.
Path to dataset files: /kaggle/input/gtzan-dataset-music-genre-classification


In [None]:
import os, random, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import librosa
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

In [None]:
import soundfile as sf

In [None]:
# Pre-extracted hand-crafted features shipped with the dataset (one row per
# 3-second segment, e.g. blues.00000.0.wav ... blues.00000.9.wav). Only shown
# for reference; the models below work on raw audio instead.
# The path is built from the kagglehub download location rather than the
# hard-coded Kaggle/Colab mount point, so the cell also works elsewhere.
sounds_df = pd.read_csv(os.path.join(path, "Data", "features_3_sec.csv"))
sounds_df.head()

Unnamed: 0,filename,length,chroma_stft_mean,chroma_stft_var,rms_mean,rms_var,spectral_centroid_mean,spectral_centroid_var,spectral_bandwidth_mean,spectral_bandwidth_var,...,mfcc16_var,mfcc17_mean,mfcc17_var,mfcc18_mean,mfcc18_var,mfcc19_mean,mfcc19_var,mfcc20_mean,mfcc20_var,label
0,blues.00000.0.wav,66149,0.335406,0.091048,0.130405,0.003521,1773.065032,167541.630869,1972.744388,117335.771563,...,39.687145,-3.241280,36.488243,0.722209,38.099152,-5.050335,33.618073,-0.243027,43.771767,blues
1,blues.00000.1.wav,66149,0.343065,0.086147,0.112699,0.001450,1816.693777,90525.690866,2010.051501,65671.875673,...,64.748276,-6.055294,40.677654,0.159015,51.264091,-2.837699,97.030830,5.784063,59.943081,blues
2,blues.00000.2.wav,66149,0.346815,0.092243,0.132003,0.004620,1788.539719,111407.437613,2084.565132,75124.921716,...,67.336563,-1.768610,28.348579,2.378768,45.717648,-1.938424,53.050835,2.517375,33.105122,blues
3,blues.00000.3.wav,66149,0.363639,0.086856,0.132565,0.002448,1655.289045,111952.284517,1960.039988,82913.639269,...,47.739452,-3.841155,28.337118,1.218588,34.770935,-3.580352,50.836224,3.630866,32.023678,blues
4,blues.00000.4.wav,66149,0.335579,0.088129,0.143289,0.001701,1630.656199,79667.267654,1948.503884,60204.020268,...,30.336359,0.664582,45.880913,1.689446,51.363583,-3.392489,26.738789,0.536961,29.146694,blues
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9985,rock.00099.5.wav,66149,0.349126,0.080515,0.050019,0.000097,1499.083005,164266.886443,1718.707215,85931.574523,...,42.485981,-9.094270,38.326839,-4.246976,31.049839,-5.625813,48.804092,1.818823,38.966969,rock
9986,rock.00099.6.wav,66149,0.372564,0.082626,0.057897,0.000088,1847.965128,281054.935973,1906.468492,99727.037054,...,32.415203,-12.375726,66.418587,-3.081278,54.414265,-11.960546,63.452255,0.428857,18.697033,rock
9987,rock.00099.7.wav,66149,0.347481,0.089019,0.052403,0.000701,1346.157659,662956.246325,1561.859087,138762.841945,...,78.228149,-2.524483,21.778994,4.809936,25.980829,1.775686,48.582378,-0.299545,41.586990,rock
9988,rock.00099.8.wav,66149,0.387527,0.084815,0.066430,0.000320,2084.515327,203891.039161,2018.366254,22860.992562,...,28.323744,-5.363541,17.209942,6.462601,21.442928,2.354765,24.843613,0.675824,12.787750,rock


In [None]:
# Reproducibility: seed every RNG used below (Python `random` and NumPy drive the
# augmentations, torch drives weight init and shuffling).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Audio loading: librosa resamples every file to this rate.
SR = 22050

# Log-mel spectrogram parameters.
N_MELS = 128
N_FFT = 2048
HOP = 512

# Each track is cut into fixed windows of SEG_SEC seconds.
SEG_SEC = 3.0
SEG_SAMPLES = int(SR * SEG_SEC)

# Derived from the directory returned by kagglehub.dataset_download instead of
# a hard-coded Kaggle mount point, so the notebook runs outside Kaggle/Colab too.
ROOT_OF_GENRES = os.path.join(path, "Data", "genres_original")

In [None]:
# Dataset indexing (audio files + labels)
def list_gztan_files(root_dir):
    """Index the GTZAN tree: one ``(path, label_id, label_name)`` per readable ``.wav``.

    Genres are the sub-directories of ``root_dir``; their position in sorted
    order is the integer label. Every ``.wav`` is probed with ``sf.info``; files
    for which that raises are skipped and reported (at most ten paths printed).

    Returns:
        (items, g2i): the list of tuples and the ``genre -> label_id`` mapping.
    """
    genre_names = sorted(
        entry for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )
    g2i = dict(zip(genre_names, range(len(genre_names))))

    items, bad = [], []
    for genre in genre_names:
        genre_dir = os.path.join(root_dir, genre)
        wav_names = [fn for fn in sorted(os.listdir(genre_dir)) if fn.lower().endswith(".wav")]
        for fn in wav_names:
            full_path = os.path.join(genre_dir, fn)
            try:
                sf.info(full_path)  # probe the file; an exception marks it unreadable
            except Exception:
                bad.append(full_path)
            else:
                items.append((full_path, g2i[genre], genre))

    if bad:
        print(f"[WARN] skipped bad audio files: {len(bad)}")
        print("\n".join(bad[:10]))

    return items, g2i

In [None]:
# Split by track
def split_by_track(items, test_size=0.2, seed=None):
    """Stratified train/validation split at the *track* level.

    Splitting whole tracks (before they are cut into segments by the dataset)
    keeps every segment of a recording on one side of the split, so validation
    never sees audio from a training track.

    Args:
        items: list of ``(path, label_id, label_name)`` tuples.
        test_size: validation share, passed to sklearn's ``train_test_split``.
        seed: ``random_state`` for the split; ``None`` uses the global ``SEED``.

    Returns:
        (train_items, val_items): two lists of the original tuples.
    """
    labels = [y for _, y, _ in items]

    train_idx, val_idx = train_test_split(
        np.arange(len(items)),
        test_size=test_size,
        random_state=SEED if seed is None else seed,
        stratify=labels,
    )
    return [items[i] for i in train_idx], [items[i] for i in val_idx]


In [None]:
class GTZANDataset(Dataset):
    """Segment-level GTZAN dataset.

    Every track in ``items`` is cut into ``n_segments`` consecutive,
    non-overlapping windows of ``SEG_SAMPLES`` samples. An item is ``(x, y)``:
    ``x`` is a float32 tensor ``[1, N_MELS, target_frames]`` holding a
    per-segment standardized log-mel spectrogram; ``y`` is the integer label.

    Args:
        items: list of ``(path, label_id, label_name)`` tuples, one per track.
        augment: apply random noise / time-stretch on the waveform and
            SpecAugment-style masking on the spectrogram (training split only).
        n_segments: windows taken per track (default 10, i.e. 10 x 3 s).
    """

    def __init__(self, items, augment=False, n_segments=10):
        self.items = items
        self.augment = augment
        self.n_segments = n_segments

        # Fixed spectrogram width so every item batches to the same shape.
        self.target_frames = 1 + SEG_SAMPLES // HOP

        self.segments = [
            (path, y, k * SEG_SAMPLES)
            for path, y, _ in items
            for k in range(n_segments)
        ]

    def __len__(self):
        return len(self.segments)

    @staticmethod
    def _fit_length(wav, length):
        """Trim ``wav`` to ``length`` samples, or right-pad it with zeros."""
        wav = wav[:length]
        if len(wav) < length:
            wav = np.pad(wav, (0, length - len(wav)))
        return wav

    def _load_segment(self, path, offset):
        """Return exactly ``SEG_SAMPLES`` samples of ``path`` starting at sample ``offset``.

        Only the requested window is decoded (not the whole track). Padding is
        applied to the window itself: clips are not guaranteed to be exactly
        ``n_segments * SEG_SAMPLES`` long, so the last window of a short clip
        can come back shorter than ``SEG_SAMPLES``.
        """
        wav, _ = librosa.load(
            path, sr=SR, mono=True, offset=offset / SR, duration=SEG_SAMPLES / SR
        )
        return self._fit_length(wav, SEG_SAMPLES)

    def _wav_to_mel(self, wav):
        """Waveform -> standardized log-mel spectrogram ``[N_MELS, target_frames]`` (float32)."""
        mel = librosa.feature.melspectrogram(
            y=wav, sr=SR, n_fft=N_FFT, hop_length=HOP, n_mels=N_MELS, power=2.0
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)

        # Per-segment standardization over the whole spectrogram.
        mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-6)

        # Safety net: force exactly target_frames time columns.
        mel_db = librosa.util.fix_length(mel_db, size=self.target_frames, axis=1)

        # SpecAugment on the spectrogram (training only).
        if self.augment:
            mel_db = self._spec_augment(mel_db)

        return mel_db.astype(np.float32)

    def _spec_augment(
        self,
        mel_db: np.ndarray,
        time_mask_param: int = 24,
        freq_mask_param: int = 8,
        num_time_masks: int = 2,
        num_freq_masks: int = 2
    ) -> np.ndarray:
        """Zero out random frequency bands and time spans of a copy of ``mel_db``.

        mel_db: [n_mels, time]. Each mask has a random width in
        ``[0, *_mask_param]``; width 0 means that mask is skipped.
        """
        m = mel_db.copy()
        n_mels, t = m.shape

        # frequency masking
        for _ in range(num_freq_masks):
            f = random.randint(0, freq_mask_param)
            if f == 0:
                continue
            f0 = random.randint(0, max(0, n_mels - f))
            m[f0:f0 + f, :] = 0.0

        # time masking
        for _ in range(num_time_masks):
            tt = random.randint(0, time_mask_param)
            if tt == 0:
                continue
            t0 = random.randint(0, max(0, t - tt))
            m[:, t0:t0 + tt] = 0.0

        return m

    def __getitem__(self, idx):
        path, y, offset = self.segments[idx]
        wav = self._load_segment(path, offset)

        # Waveform augmentation (training only): each applied with probability 0.3.
        if self.augment:
            if random.random() < 0.3:
                wav = wav + 0.005 * np.random.randn(len(wav))
            if random.random() < 0.3:
                rate = random.uniform(0.9, 1.1)
                wav = librosa.effects.time_stretch(wav, rate=rate)
                # Time-stretch changes the length; restore exactly SEG_SAMPLES.
                wav = self._fit_length(wav, SEG_SAMPLES)

        mel = self._wav_to_mel(wav)
        x = torch.from_numpy(mel).unsqueeze(0)  # [1, n_mels, time]
        y = torch.tensor(y, dtype=torch.long)
        return x, y


### Модели
**только CNN**, **CNN+RNN**, **CNN+RNN+Attention**

In [None]:
class CNNBackbone(nn.Module):
    """Shared convolutional feature extractor used by every model variant.

    Four Conv3x3 -> BatchNorm -> ReLU stages with 32/64/128/256 channels; the
    first three are followed by 2x2 max-pooling. Input ``[B, 1, F, T]`` maps to
    ``[B, 256, F', T']`` with ``F' = F // 8`` and ``T' = T // 8`` (floor at each pool).
    """
    def __init__(self):
        super().__init__()
        stages = [(1, 32), (32, 64), (64, 128), (128, 256)]
        layers = []
        for stage_idx, (c_in, c_out) in enumerate(stages):
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU()]
            # Every stage except the last halves both spatial dimensions.
            if stage_idx < len(stages) - 1:
                layers.append(nn.MaxPool2d((2, 2)))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # [B, C, F', T']


class CNNClassifier(nn.Module):
    """CNN-only baseline: backbone -> global mean & max pooling -> linear head.

    ``forward`` returns ``(logits [B, n_classes], None)``; the ``None`` keeps the
    return shape aligned with the attention model, which also returns weights.
    """
    def __init__(self, n_classes=10, dropout=0.3):
        super().__init__()
        self.cnn = CNNBackbone()
        self.dropout = nn.Dropout(dropout)
        # 256 backbone channels, pooled two ways (mean and max) and concatenated.
        self.fc = nn.Linear(2 * 256, n_classes)

    def forward(self, x):
        fmap = self.cnn(x)  # [B, 256, F', T']
        spatial_dims = (2, 3)
        pooled = torch.cat(
            (fmap.mean(dim=spatial_dims), fmap.amax(dim=spatial_dims)), dim=1
        )  # [B, 512]
        logits = self.fc(self.dropout(pooled))
        return logits, None


class TemporalAttention(nn.Module):
    """Soft attention pooling over the time axis.

    A linear layer scores each time step, a softmax over time turns the scores
    into weights, and the output is the weighted sum of the inputs.
    Returns ``(context [B, H], weights [B, T])``.
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 1)  # one scalar score per time step

    def forward(self, h):  # h: [B, T, H]
        weights = self.proj(h).squeeze(-1).softmax(dim=1)  # [B, T]
        pooled = (weights.unsqueeze(-1) * h).sum(dim=1)    # [B, H]
        return pooled, weights


class CNNRNNClassifier(nn.Module):
    """CNN -> time sequence -> BiGRU -> (mean+max pooling | attention) -> classifier.

    The backbone map ``[B, C, F', T']`` is read as a sequence of ``T'`` vectors of
    size ``C * F'`` (channels and remaining mel bins flattened), fed to a
    bidirectional GRU, and pooled over time.

    Args:
        n_classes: number of output classes.
        use_attention: pool GRU outputs with ``TemporalAttention`` instead of
            concatenated mean + max pooling.
        rnn_hidden: GRU hidden size per direction.
        rnn_layers: number of stacked GRU layers.
        dropout: dropout applied before the final linear layer.
        rnn_dropout: dropout between stacked GRU layers (unused if ``rnn_layers == 1``).

    ``forward`` returns ``(logits [B, n_classes], attention weights [B, T'] or None)``.
    """
    def __init__(self, n_classes=10, use_attention=False, rnn_hidden=160, rnn_layers=2, dropout=0.3,
                 rnn_dropout=0.2):
        super().__init__()
        self.use_attention = use_attention

        self.cnn = CNNBackbone()

        # Probe the backbone once to learn C and F' (the GRU input size).
        # Run the probe in eval mode: in train mode BatchNorm would fold this
        # dummy batch into its running_mean / running_var buffers.
        was_training = self.cnn.training
        self.cnn.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 1, N_MELS, 1 + SEG_SAMPLES // HOP)
            _, C, Fp, _ = self.cnn(dummy).shape
        self.cnn.train(was_training)
        rnn_in = C * Fp

        self.rnn = nn.GRU(
            input_size=rnn_in,
            hidden_size=rnn_hidden,
            num_layers=rnn_layers,
            batch_first=True,
            bidirectional=True,
            dropout=rnn_dropout if rnn_layers > 1 else 0.0
        )

        rnn_out = 2 * rnn_hidden  # forward + backward directions
        self.dropout = nn.Dropout(dropout)

        if use_attention:
            self.attn = TemporalAttention(rnn_out)
            self.fc = nn.Linear(rnn_out, n_classes)
        else:
            # mean- and max-pooled GRU outputs are concatenated
            self.fc = nn.Linear(rnn_out * 2, n_classes)

    def forward(self, x):
        z = self.cnn(x)  # [B, C, F', T']
        B, C, Fp, Tp = z.shape
        # time becomes the sequence axis; (C, F') is flattened into the feature axis
        z = z.permute(0, 3, 1, 2).reshape(B, Tp, C * Fp)  # [B, T', C*F']
        h, _ = self.rnn(z)  # [B, T', 2H]

        if self.use_attention:
            ctx, alpha = self.attn(h)
            return self.fc(self.dropout(ctx)), alpha

        feats = torch.cat([h.mean(dim=1), h.amax(dim=1)], dim=1)
        return self.fc(self.dropout(feats)), None


In [None]:
def train_one_epoch(model, loader, optimizer, criterion, max_grad_norm=1.0):
    """Run one optimization pass over ``loader``.

    Gradients are clipped to global L2 norm ``max_grad_norm`` unless it is None.

    Returns:
        (mean loss per sample, accuracy) over every sample seen this epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in tqdm(loader, leave=False):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        logits, _ = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        batch_size = inputs.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def eval_model(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean loss per sample, accuracy, predictions, labels); the last two are
        1-D numpy arrays concatenated in loader order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pred_batches = []
    label_batches = []

    for inputs, targets in tqdm(loader, leave=False):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        logits, _ = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_preds = logits.argmax(dim=1)

        loss_sum += batch_loss.item() * inputs.size(0)
        n_correct += (batch_preds == targets).sum().item()
        n_seen += inputs.size(0)

        pred_batches.append(batch_preds.cpu().numpy())
        label_batches.append(targets.cpu().numpy())

    return (
        loss_sum / n_seen,
        n_correct / n_seen,
        np.concatenate(pred_batches),
        np.concatenate(label_batches),
    )


In [None]:
def _build_model(model_kind, n_classes=10):
    """Instantiate the architecture named by ``model_kind`` and move it to DEVICE."""
    if model_kind == "cnn":
        model = CNNClassifier(n_classes=n_classes)
    elif model_kind == "cnn_rnn":
        model = CNNRNNClassifier(n_classes=n_classes, use_attention=False)
    elif model_kind == "cnn_rnn_attn":
        model = CNNRNNClassifier(n_classes=n_classes, use_attention=True)
    else:
        raise ValueError(f"Unknown model_kind: {model_kind}")
    return model.to(DEVICE)


def _make_loader(ds, batch_size, shuffle, num_workers):
    """DataLoader whose worker-only options are set only when workers are used."""
    # persistent_workers / prefetch_factor are rejected by DataLoader when num_workers == 0.
    worker_kwargs = dict(persistent_workers=True, prefetch_factor=2) if num_workers > 0 else {}
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=str(DEVICE).startswith("cuda"),  # pinning only helps host->GPU copies
        **worker_kwargs,
    )


def run_experiment(train_items, val_items, model_kind="cnn", epochs=18, batch_size=64, lr=3e-4,
                   weight_decay=1e-4, num_workers=2, n_classes=10, class_names=None, return_preds=False):
    """Train one architecture and report validation metrics of its best epoch.

    Trains with AdamW + cosine LR schedule and label-smoothed cross-entropy,
    keeps the weights of the epoch with the highest validation accuracy, reloads
    them and prints a classification report.

    NOTE: the same validation split is used both to pick the best epoch and to
    report the final accuracy, so the reported number is optimistic (there is
    no separate held-out test set).

    Args:
        train_items, val_items: track lists from ``split_by_track``.
        model_kind: "cnn", "cnn_rnn" or "cnn_rnn_attn".
        epochs, batch_size, lr, weight_decay: optimization settings.
        num_workers: DataLoader workers (0 loads data in the main process).
        n_classes: number of output classes.
        class_names: optional names (in label-id order) for the report.
        return_preds: also return the validation predictions and labels.

    Returns:
        ``(model, val_acc)``, or ``(model, val_acc, preds, labels)`` when
        ``return_preds`` is True.
    """
    train_dl = _make_loader(GTZANDataset(train_items, augment=True), batch_size, True, num_workers)
    val_dl = _make_loader(GTZANDataset(val_items, augment=False), batch_size, False, num_workers)

    model = _build_model(model_kind, n_classes=n_classes)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Below any real accuracy, so the first epoch is always a checkpoint candidate.
    best_val_acc = -1.0
    best_state = None

    for ep in range(1, epochs+1):
        tr_loss, tr_acc = train_one_epoch(model, train_dl, optimizer, criterion, max_grad_norm=1.0)
        va_loss, va_acc, _, _ = eval_model(model, val_dl, criterion)
        scheduler.step()

        if va_acc > best_val_acc:
            best_val_acc = va_acc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        # The printed lr is the one scheduled for the *next* epoch (after scheduler.step()).
        print(f"epoch {ep:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f} | lr {scheduler.get_last_lr()[0]:.2e}")

    if best_state is not None:  # epochs == 0 leaves the freshly initialized weights
        model.load_state_dict(best_state)
    va_loss, va_acc, preds, labels = eval_model(model, val_dl, criterion)
    print("\nVAL ACC:", va_acc)

    report_kwargs = {}
    if class_names is not None:
        # explicit labels keep names aligned even if a class is absent from preds/labels
        report_kwargs = dict(labels=list(range(len(class_names))), target_names=list(class_names))
    print(classification_report(labels, preds, digits=4, **report_kwargs))

    if return_preds:
        return model, va_acc, preds, labels
    return model, va_acc


In [None]:
items, g2i = list_gztan_files(ROOT_OF_GENRES)
train_items, val_items = split_by_track(items, test_size=0.2)

# Invariants: the split partitions the tracks and no track lands on both sides
# (otherwise segments of one recording would leak into validation).
assert len(train_items) + len(val_items) == len(items)
assert not {p for p, *_ in train_items} & {p for p, *_ in val_items}, "track leakage between train and val"

[WARN] skipped bad audio files: 1
/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/jazz/jazz.00054.wav


In [None]:
print("tracks:", len(items), "train:", len(train_items), "val:", len(val_items))
print("genres:", list(g2i))

# Baseline: CNN only
model_cnn, acc_cnn = run_experiment(
    train_items, val_items, model_kind="cnn", epochs=18
)

# CNN + recurrent (BiGRU) temporal modelling
model_cnn_rnn, acc_cnn_rnn = run_experiment(
    train_items, val_items, model_kind="cnn_rnn", epochs=18
)


[WARN] skipped bad audio files: 1
/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/jazz/jazz.00054.wav
tracks: 999 train: 799 val: 200
genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']




epoch 01 | train loss 1.9198 acc 0.3212 | val loss 1.3894 acc 0.5450 | lr 2.98e-04




epoch 02 | train loss 1.5246 acc 0.4879 | val loss 1.3965 acc 0.5200 | lr 2.91e-04




epoch 03 | train loss 1.3545 acc 0.5631 | val loss 1.4737 acc 0.5590 | lr 2.80e-04




epoch 04 | train loss 1.2494 acc 0.6176 | val loss 1.1689 acc 0.6950 | lr 2.65e-04




epoch 05 | train loss 1.1773 acc 0.6537 | val loss 1.1612 acc 0.6845 | lr 2.46e-04




epoch 06 | train loss 1.1288 acc 0.6703 | val loss 1.1604 acc 0.6895 | lr 2.25e-04




epoch 07 | train loss 1.0777 acc 0.6889 | val loss 1.0688 acc 0.7140 | lr 2.01e-04




epoch 08 | train loss 1.0288 acc 0.7220 | val loss 1.0867 acc 0.7380 | lr 1.76e-04




epoch 09 | train loss 1.0087 acc 0.7300 | val loss 1.0369 acc 0.7500 | lr 1.50e-04




epoch 10 | train loss 0.9859 acc 0.7335 | val loss 0.9911 acc 0.7765 | lr 1.24e-04




epoch 11 | train loss 0.9482 acc 0.7557 | val loss 0.9927 acc 0.7745 | lr 9.87e-05




epoch 12 | train loss 0.9367 acc 0.7587 | val loss 0.9739 acc 0.7690 | lr 7.50e-05




epoch 13 | train loss 0.9198 acc 0.7623 | val loss 1.0391 acc 0.7545 | lr 5.36e-05




epoch 14 | train loss 0.9070 acc 0.7705 | val loss 1.0024 acc 0.7730 | lr 3.51e-05




epoch 15 | train loss 0.8800 acc 0.7846 | val loss 0.9524 acc 0.7870 | lr 2.01e-05




epoch 16 | train loss 0.8719 acc 0.7911 | val loss 0.9258 acc 0.7940 | lr 9.05e-06




epoch 17 | train loss 0.8795 acc 0.7850 | val loss 0.9322 acc 0.7970 | lr 2.28e-06




epoch 18 | train loss 0.8586 acc 0.7914 | val loss 0.9351 acc 0.7960 | lr 0.00e+00





VAL ACC: 0.797
              precision    recall  f1-score   support

           0     0.7979    0.7700    0.7837       200
           1     0.8899    0.9700    0.9282       200
           2     0.7293    0.8350    0.7786       200
           3     0.7746    0.6700    0.7185       200
           4     0.7797    0.9200    0.8440       200
           5     0.9064    0.9200    0.9132       200
           6     0.9153    0.8650    0.8895       200
           7     0.6943    0.7950    0.7413       200
           8     0.8103    0.7050    0.7540       200
           9     0.6667    0.5200    0.5843       200

    accuracy                         0.7970      2000
   macro avg     0.7964    0.7970    0.7935      2000
weighted avg     0.7964    0.7970    0.7935      2000





epoch 01 | train loss 1.5789 acc 0.4706 | val loss 1.2315 acc 0.6195 | lr 2.98e-04




epoch 02 | train loss 1.0982 acc 0.6852 | val loss 1.2039 acc 0.6440 | lr 2.91e-04




epoch 03 | train loss 0.9404 acc 0.7493 | val loss 1.0124 acc 0.7305 | lr 2.80e-04




epoch 04 | train loss 0.8111 acc 0.8014 | val loss 0.9658 acc 0.7585 | lr 2.65e-04




epoch 05 | train loss 0.7208 acc 0.8370 | val loss 1.0151 acc 0.7265 | lr 2.46e-04




epoch 06 | train loss 0.6479 acc 0.8708 | val loss 1.0615 acc 0.7310 | lr 2.25e-04




epoch 07 | train loss 0.5944 acc 0.8875 | val loss 0.9968 acc 0.7605 | lr 2.01e-04




epoch 08 | train loss 0.5466 acc 0.9075 | val loss 1.0043 acc 0.7675 | lr 1.76e-04




epoch 09 | train loss 0.4914 acc 0.9330 | val loss 0.8746 acc 0.8015 | lr 1.50e-04




epoch 10 | train loss 0.4602 acc 0.9419 | val loss 0.9423 acc 0.7890 | lr 1.24e-04




epoch 11 | train loss 0.4259 acc 0.9601 | val loss 0.9420 acc 0.7975 | lr 9.87e-05




epoch 12 | train loss 0.4075 acc 0.9651 | val loss 0.9112 acc 0.8005 | lr 7.50e-05




epoch 13 | train loss 0.3893 acc 0.9732 | val loss 0.9023 acc 0.8015 | lr 5.36e-05




epoch 14 | train loss 0.3699 acc 0.9793 | val loss 0.9743 acc 0.7915 | lr 3.51e-05




epoch 15 | train loss 0.3607 acc 0.9834 | val loss 0.9367 acc 0.7990 | lr 2.01e-05




epoch 16 | train loss 0.3540 acc 0.9854 | val loss 0.8974 acc 0.8010 | lr 9.05e-06




epoch 17 | train loss 0.3503 acc 0.9861 | val loss 0.8844 acc 0.8105 | lr 2.28e-06




epoch 18 | train loss 0.3465 acc 0.9881 | val loss 0.8837 acc 0.8105 | lr 0.00e+00





VAL ACC: 0.8105
              precision    recall  f1-score   support

           0     0.8964    0.8650    0.8804       200
           1     0.8899    0.9700    0.9282       200
           2     0.6992    0.8250    0.7569       200
           3     0.7647    0.5850    0.6629       200
           4     0.8599    0.8900    0.8747       200
           5     0.8762    0.9200    0.8976       200
           6     0.9634    0.9200    0.9412       200
           7     0.6654    0.8450    0.7445       200
           8     0.9167    0.7150    0.8034       200
           9     0.6264    0.5700    0.5969       200

    accuracy                         0.8105      2000
   macro avg     0.8158    0.8105    0.8087      2000
weighted avg     0.8158    0.8105    0.8087      2000





epoch 01 | train loss 1.5518 acc 0.4791 | val loss 1.1737 acc 0.6430 | lr 2.98e-04




epoch 02 | train loss 1.0822 acc 0.6901 | val loss 1.1905 acc 0.6715 | lr 2.91e-04




epoch 03 | train loss 0.8982 acc 0.7665 | val loss 1.2056 acc 0.6765 | lr 2.80e-04




epoch 04 | train loss 0.7902 acc 0.8135 | val loss 1.1274 acc 0.6760 | lr 2.65e-04




epoch 05 | train loss 0.7114 acc 0.8423 | val loss 1.1632 acc 0.6920 | lr 2.46e-04




epoch 06 | train loss 0.6565 acc 0.8623 | val loss 0.9823 acc 0.7560 | lr 2.25e-04




epoch 07 | train loss 0.5821 acc 0.8917 | val loss 0.9314 acc 0.7755 | lr 2.01e-04


 71%|███████   | 89/125 [03:25<01:22,  2.30s/it]

In [None]:
print("tracks:", len(items), "train:", len(train_items), "val:", len(val_items))
print("genres:", list(g2i))

# CNN + BiGRU + temporal attention pooling
model_attn, acc_attn = run_experiment(
    train_items, val_items, model_kind="cnn_rnn_attn", epochs=18
)

tracks: 999 train: 799 val: 200
genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']




epoch 01 | train loss 1.5115 acc 0.4991 | val loss 1.2255 acc 0.6305 | lr 2.98e-04




epoch 02 | train loss 1.1019 acc 0.6886 | val loss 1.3138 acc 0.6185 | lr 2.91e-04




epoch 03 | train loss 0.9236 acc 0.7557 | val loss 1.1695 acc 0.6805 | lr 2.80e-04




epoch 04 | train loss 0.8178 acc 0.8008 | val loss 1.2044 acc 0.6670 | lr 2.65e-04




epoch 05 | train loss 0.7298 acc 0.8317 | val loss 0.9907 acc 0.7575 | lr 2.46e-04




epoch 06 | train loss 0.6527 acc 0.8651 | val loss 1.0691 acc 0.7325 | lr 2.25e-04




epoch 07 | train loss 0.5839 acc 0.8945 | val loss 0.9472 acc 0.7715 | lr 2.01e-04




epoch 08 | train loss 0.5512 acc 0.9053 | val loss 0.9577 acc 0.7660 | lr 1.76e-04




epoch 09 | train loss 0.4901 acc 0.9310 | val loss 0.9939 acc 0.7795 | lr 1.50e-04




epoch 10 | train loss 0.4658 acc 0.9423 | val loss 0.9022 acc 0.7980 | lr 1.24e-04




epoch 11 | train loss 0.4301 acc 0.9562 | val loss 0.8898 acc 0.8050 | lr 9.87e-05




epoch 12 | train loss 0.4095 acc 0.9623 | val loss 0.9137 acc 0.8030 | lr 7.50e-05




epoch 13 | train loss 0.3903 acc 0.9722 | val loss 0.8835 acc 0.8185 | lr 5.36e-05




epoch 14 | train loss 0.3742 acc 0.9786 | val loss 0.9030 acc 0.8125 | lr 3.51e-05




epoch 15 | train loss 0.3623 acc 0.9825 | val loss 0.8807 acc 0.8150 | lr 2.01e-05




epoch 16 | train loss 0.3529 acc 0.9859 | val loss 0.8600 acc 0.8215 | lr 9.05e-06




epoch 17 | train loss 0.3458 acc 0.9880 | val loss 0.8579 acc 0.8160 | lr 2.28e-06




epoch 18 | train loss 0.3435 acc 0.9889 | val loss 0.8650 acc 0.8175 | lr 0.00e+00


                                               


VAL ACC: 0.8215
              precision    recall  f1-score   support

           0     0.9556    0.8600    0.9053       200
           1     0.8909    0.9800    0.9333       200
           2     0.7149    0.8400    0.7724       200
           3     0.7128    0.6700    0.6907       200
           4     0.8873    0.9050    0.8960       200
           5     0.9000    0.9000    0.9000       200
           6     0.9490    0.9300    0.9394       200
           7     0.6942    0.8400    0.7602       200
           8     0.8862    0.7400    0.8065       200
           9     0.6548    0.5500    0.5978       200

    accuracy                         0.8215      2000
   macro avg     0.8246    0.8215    0.8202      2000
weighted avg     0.8246    0.8215    0.8202      2000


=== SUMMARY ===




NameError: name 'acc_cnn' is not defined

In [None]:
# Summary: validation accuracy of the best epoch of each experiment.
# acc_cnn / acc_cnn_rnn / acc_attn are set by the experiment cells above,
# which must have run to completion first.
print("\n=== SUMMARY ===")
print("CNN only: ", acc_cnn)
print("CNN + RNN: ", acc_cnn_rnn)
print("CNN + RNN + Attention:", acc_attn)


=== SUMMARY ===
CNN + RNN + Attention: 0.8215


## Когда лучше CNN / CNN+RNN / CNN+RNN+Attention (по смыслу звука)

**Только CNN (по спектрограмме)** лучше всего, когда класс определяется в основном **локальными «рисунками»** на спектрограмме и порядок событий внутри окна не критичен:
- короткие команды/слова (keyword spotting), отдельные звуки/удары, тип «тембра» инструмента;
- жанр/сцена в аудио, где достаточно статистики текстур (гармоники, шум, «зерно» спектра);
- когда нужны **максимальная скорость** и простая модель (работа в реальном времени, мобильные устройства).

**CNN + RNN** полезно, когда важна **последовательность** и длительные зависимости:
- речь с более длинными фразами, фонемная/слоговая динамика, интонация;
- события типа «A потом B» (например, шаги → дверь → тишина);
- данные, где один и тот же «рисунок» может быть в разном порядке, и это меняет класс.

**CNN + RNN + Attention** выигрывает, когда сигнал содержит **много «мусора»/тишины** и нужно *выделять* информативные моменты:
- длинные записи, где полезное событие редкое и короткое (звонок, сирена, кашель);
- ситуации с неоднородностью: начало/конец записи менее информативны, важны отдельные фрагменты;
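- когда в классе есть несколько под‑паттернов, и модель должна «переключаться» между ними.

**В этом эксперименте (GTZAN, 3‑секундные сегменты)** получилось: CNN — 0.797, CNN+RNN — 0.8105, CNN+RNN+Attention — 0.8215. Моделирование времени дало прирост примерно на 1–2.5 п.п. Лучшая эпоха выбиралась по той же валидационной выборке, на которой измерена точность, поэтому эти оценки несколько завышены.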

In [None]:
# === CONFUSION MATRIX ===
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

def plot_confusion(y_true, y_pred, title, display_labels=None):
    """Plot the confusion matrix of integer labels ``y_true`` vs ``y_pred``.

    Args:
        y_true, y_pred: 1-D sequences of class ids.
        title: figure title.
        display_labels: optional tick labels (e.g. genre names) in class-id order.
    """
    # With names given, fix the label set so rows/columns line up with them
    # even when a class never appears in y_true or y_pred.
    labels = None if display_labels is None else list(range(len(display_labels)))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fig, ax = plt.subplots(figsize=(8, 8))
    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels).plot(
        ax=ax, xticks_rotation=45
    )
    ax.set_title(title)
    plt.show()

# Confusion matrices on the validation split for the three trained models;
# segment-level predictions are recomputed with eval_model.
cm_val_dl = DataLoader(GTZANDataset(val_items, augment=False), batch_size=64, shuffle=False, num_workers=2)
cm_criterion = nn.CrossEntropyLoss()
genre_names = list(g2i)  # g2i was built in class-id order
for cm_title, cm_model in [
    ("CNN Confusion Matrix", model_cnn),
    ("CNN+RNN Confusion Matrix", model_cnn_rnn),
    ("CNN+RNN+Attention Confusion Matrix", model_attn),
]:
    _, _, cm_preds, cm_labels = eval_model(cm_model, cm_val_dl, cm_criterion)
    plot_confusion(cm_labels, cm_preds, cm_title, display_labels=genre_names)
