In [1]:
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models
from datasets import load_dataset
import os

In [23]:
# ====== Параметры ======
SR = 16000  # Частота дискретизации
BATCH_SIZE = 32
EPOCHS = 10
N_MELS = 128
MAX_TIME_FRAMES = 3000  # Максимальная длина по времени для padding
CHECKPOINT_PATH = "checkpoint_vggish.pth"  # Путь для сохранения чекпоинта

In [3]:
def compute_mel_spectrogram(audio, sr=SR, n_fft=1024, hop_length=512, n_mels=N_MELS):
    """Вычисляет Mel-спектрограмму аудиосигнала."""
    mel_spec = librosa.feature.melspectrogram(
        y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    return mel_spec_db

# ====== 2. Функция для выравнивания спектрограммы ======
def pad_spectrogram(mel_spec, max_frames=MAX_TIME_FRAMES):
    """Добавляет padding к Mel-спектрограмме, чтобы выровнять ее длину."""
    if mel_spec.shape[1] < max_frames:
        pad_width = max_frames - mel_spec.shape[1]
        mel_spec = np.pad(mel_spec, ((0, 0), (0, pad_width)), mode='constant')
    else:
        mel_spec = mel_spec[:, :max_frames]
    return mel_spec

In [4]:
class AudioDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        # Создаем словарь для преобразования жанров в числовые метки
        self.genre_to_id = {genre: idx for idx, genre in enumerate(set(dataset["genre"]))}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        audio = self.dataset[idx]["audio"]["array"]
        genre = self.dataset[idx]["genre"]
        label = self.genre_to_id[genre]
        mel_spec = compute_mel_spectrogram(audio)
        mel_spec = pad_spectrogram(mel_spec)  # Добавляем padding
        mel_spec = np.expand_dims(mel_spec, axis=0)
        mel_spec = np.repeat(mel_spec, 3, axis=0)  # Повторяем канал 3 раза для совместимости с VGG16  # Добавляем канал
        return torch.tensor(mel_spec, dtype=torch.float32), torch.tensor(label, dtype=torch.long)


In [5]:
dataset = load_dataset("lewtun/music_genres", split="train")

README.md:   0%|          | 0.00/545 [00:00<?, ?B/s]

Downloading data:   0%|          | 0/16 [00:00<?, ?files/s]

(…)-00000-of-00016-6b5481c76a3d2702.parquet:   0%|          | 0.00/489M [00:00<?, ?B/s]

(…)-00001-of-00016-438bb9cb7b06002c.parquet:   0%|          | 0.00/488M [00:00<?, ?B/s]

(…)-00002-of-00016-c1a616564aeae4b0.parquet:   0%|          | 0.00/490M [00:00<?, ?B/s]

(…)-00003-of-00016-73a29e154975c452.parquet:   0%|          | 0.00/488M [00:00<?, ?B/s]

(…)-00004-of-00016-db37b9fc5526f40b.parquet:   0%|          | 0.00/489M [00:00<?, ?B/s]

(…)-00005-of-00016-93716f278704089e.parquet:   0%|          | 0.00/488M [00:00<?, ?B/s]

(…)-00006-of-00016-5d90eeed316ceb16.parquet:   0%|          | 0.00/487M [00:00<?, ?B/s]

(…)-00007-of-00016-92ae3797361c8db8.parquet:   0%|          | 0.00/492M [00:00<?, ?B/s]

(…)-00008-of-00016-26222f0024734427.parquet:   0%|          | 0.00/484M [00:00<?, ?B/s]

(…)-00009-of-00016-54aecc0dd7ee2005.parquet:   0%|          | 0.00/488M [00:00<?, ?B/s]

(…)-00010-of-00016-3828cc45b664a4c8.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

(…)-00011-of-00016-6e3dc52ec46ea765.parquet:   0%|          | 0.00/490M [00:00<?, ?B/s]

(…)-00012-of-00016-610de7a10e23d537.parquet:   0%|          | 0.00/487M [00:00<?, ?B/s]

(…)-00013-of-00016-b0af0a9e4b167ba8.parquet:   0%|          | 0.00/488M [00:00<?, ?B/s]

(…)-00014-of-00016-f224b1ff8d7d444e.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

(…)-00015-of-00016-f1c3cbe5e2cccd73.parquet:   0%|          | 0.00/488M [00:00<?, ?B/s]

(…)-00000-of-00004-2e0b9f634f1aec7a.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

(…)-00001-of-00004-f199d70ca2b53305.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

(…)-00002-of-00004-f34d5fd400a9f24b.parquet:   0%|          | 0.00/498M [00:00<?, ?B/s]

(…)-00003-of-00004-00213d3ef9894abe.parquet:   0%|          | 0.00/494M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/19909 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5076 [00:00<?, ? examples/s]

In [6]:
dataset

Dataset({
    features: ['audio', 'song_id', 'genre_id', 'genre'],
    num_rows: 19909
})

In [7]:
NUM_CLASSES = len(set(dataset['genre']))
NUM_CLASSES

19

In [30]:
class VGGishClassifier(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(VGGishClassifier, self).__init__()
        self.vggish = models.vgg16(pretrained=True)
        for param in self.vggish.features.parameters():
            param.requires_grad = False  # Замораживаем сверточные слои
        self.embeddings_layer = nn.Identity()  # Слой для извлечения эмбеддингов
        self.vggish.classifier[6] = nn.Linear(4096, num_classes)  # Заменяем последний слой для классификации жанров

    def forward(self, x, return_embeddings=False):
        x = self.vggish.features(x)
        x = self.vggish.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.vggish.classifier[:-1](x)  # Пропускаем через все слои, кроме последнего
        if return_embeddings:
            return x  # Возвращаем эмбеддинги
        x = self.vggish.classifier[-1](x)
        return x

In [14]:
audio_dataset = AudioDataset(dataset)

In [16]:
train_size = int(0.8 * len(audio_dataset))
test_size = len(audio_dataset) - train_size

In [17]:
train_size, test_size

(15927, 3982)

In [18]:
train_dataset, test_dataset = random_split(audio_dataset, [train_size, test_size])


In [24]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [31]:
model = VGGishClassifier(num_classes=NUM_CLASSES).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [32]:
# Загрузка чекпоинта, если он существует
if os.path.exists(CHECKPOINT_PATH):
    model.load_state_dict(torch.load(CHECKPOINT_PATH))
    print("Checkpoint loaded.")

  model.load_state_dict(torch.load(CHECKPOINT_PATH))


Checkpoint loaded.


In [33]:
model.train()
for epoch in range(EPOCHS):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Loss: {running_loss / len(train_loader):.4f}")

    # Сохранение чекпоинта
    torch.save(model.state_dict(), CHECKPOINT_PATH)
    print(f"Checkpoint saved at epoch {epoch + 1}.")

Epoch [1/10], Loss: 2.4132
Checkpoint saved at epoch 1.


In [34]:
torch.save(model, "vggish_classifier_full.pth")