In [None]:
import os
import torch
import zipfile
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchaudio import datasets, transforms, load

In [None]:
import kagglehub

# Fetch the GTZAN genre-classification dataset through kagglehub and show
# where the files ended up; later cells locate the audio files on disk.
dataset_handle = "andradaolteanu/gtzan-dataset-music-genre-classification"
path = kagglehub.dataset_download(dataset_handle)
print("Path to dataset files:", path)

Using Colab cache for faster access to the 'gtzan-dataset-music-genre-classification' dataset.
Path to dataset files: /kaggle/input/gtzan-dataset-music-genre-classification


In [None]:
# Folder holding one sub-directory of .wav clips per genre. It is derived from
# the location kagglehub reported in `path` rather than a hard-coded Kaggle
# mount point, so the notebook also works where the download lands elsewhere.
root_path = os.path.join(path, 'Data', 'genres_original')

In [None]:
# Genre names are the entries of the dataset root, put in alphabetical order
# so that the label indices built from them are stable between runs.
genres = os.listdir(root_path)
genres.sort()
genres

['blues',
 'classical',
 'country',
 'disco',
 'hiphop',
 'jazz',
 'metal',
 'pop',
 'reggae',
 'rock']

In [None]:
# Number of genre classes found under the dataset root.
num_genres = len(genres)
num_genres

10

In [None]:
# Map every genre name to its position in the sorted `genres` list; these
# integers are the class targets the dataset returns.
label_to_index = dict(zip(genres, range(len(genres))))
label_to_index

{'blues': 0,
 'classical': 1,
 'country': 2,
 'disco': 3,
 'hiphop': 4,
 'jazz': 5,
 'metal': 6,
 'pop': 7,
 'reggae': 8,
 'rock': 9}

In [None]:
# Waveform -> mel-spectrogram front end with 64 mel bands, configured for
# 22.05 kHz audio.
# NOTE(review): no log/dB scaling is applied after this transform - confirm
# that feeding the unscaled spectrogram to the network is intended.
transform = transforms.MelSpectrogram(n_mels=64, sample_rate=22050)

In [None]:
# Number of spectrogram time frames every clip is cropped or zero-padded to
# (applied in TrainTestSplitter.__getitem__).
# NOTE(review): with the transform's default hop length this presumably covers
# only the first few seconds of each clip - confirm that is intended.
max_len = 500

In [None]:
class TrainTestSplitter(Dataset):
    """Dataset of fixed-width spectrograms for genre-labelled ``.wav`` files.

    ``root_path`` is expected to contain one sub-directory per genre, each
    holding ``.wav`` clips. Every clip is probed once at construction time;
    files that cannot be opened are reported and left out, so a corrupt clip
    does not crash a ``DataLoader`` in the middle of training.

    Args:
        root_path: Directory whose sub-directories are named after genres.
        transform: Callable turning a waveform tensor into a spectrogram
            (for example ``torchaudio.transforms.MelSpectrogram``).
        max_len: Number of time frames each spectrogram is cropped or
            zero-padded to.

    Note:
        ``__getitem__`` resolves labels through the notebook-level
        ``label_to_index`` mapping, which must exist before items are read.
    """

    # Sample rate used when the transform does not expose ``sample_rate``.
    DEFAULT_SAMPLE_RATE = 22050

    def __init__(self, root_path, transform, max_len):
        self.root_path = root_path
        self.transform = transform
        self.max_len = max_len
        self.audios = []  # list of (path to .wav, genre name)

        # Both listings are sorted: os.listdir order is arbitrary, and the
        # seeded random_split is only reproducible if index -> file is stable.
        for genre in sorted(os.listdir(root_path)):
            genre_path = os.path.join(root_path, genre)
            if not os.path.isdir(genre_path):
                # Stray files next to the genre folders are not classes.
                continue
            for audio in sorted(os.listdir(genre_path)):
                if not audio.endswith('.wav'):
                    continue
                audio_path = os.path.join(genre_path, audio)
                if self._is_readable(audio_path):
                    self.audios.append((audio_path, genre))

    @staticmethod
    def _is_readable(audio_path):
        """Return True when the file can be opened and decoded; report it otherwise."""
        try:
            # Decoding a single frame is enough to detect a broken file and
            # relies only on `load`, which the notebook already imports.
            load(audio_path, frame_offset=0, num_frames=1)
        except Exception as e:
            # Deliberately broad: the decoding backend decides the exception
            # type, and one unreadable clip must not abort the whole scan.
            print(f'Ошибка при чтении файла ".wav": {e}')
            return False
        return True

    def __len__(self):
        """Number of readable clips found under ``root_path``."""
        return len(self.audios)

    def __getitem__(self, index):
        """Return ``(spectrogram, label_index)`` for the clip at ``index``.

        The spectrogram's last axis is cropped or right-padded with zeros to
        exactly ``max_len`` frames.

        Raises:
            IndexError: If ``index`` is out of range.
            KeyError: If the clip's genre is missing from ``label_to_index``.
        """
        audio_path, genre = self.audios[index]
        waveform, sample_rate = load(audio_path)

        # Mix multi-channel audio down to one channel; otherwise squeeze(0)
        # below would keep the channel axis and the crop/pad logic would act
        # on the wrong dimension. Mono input is left untouched.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to the rate the transform was configured for, so the two
        # cannot drift apart; fall back to 22050 Hz if it exposes none.
        target_rate = getattr(self.transform, 'sample_rate', self.DEFAULT_SAMPLE_RATE)
        if sample_rate != target_rate:
            resample = transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
            waveform = resample(waveform)

        input_spectrogram = self.transform(waveform).squeeze(0)

        n_frames = input_spectrogram.shape[-1]
        if n_frames > self.max_len:
            input_spectrogram = input_spectrogram[..., :self.max_len]
        elif n_frames < self.max_len:
            # (0, k) pads only the end of the last (time) axis.
            input_spectrogram = F.pad(input_spectrogram, (0, self.max_len - n_frames))

        return input_spectrogram, label_to_index[genre]

In [None]:
dataset = TrainTestSplitter(root_path, transform, max_len)

# 80 % of the clips go to training, the remainder to testing. The generator
# is seeded so the same indices are drawn on every run.
n_clips = len(dataset)
train_size = int(n_clips * 0.8)
test_size = n_clips - train_size
split_rng = torch.Generator().manual_seed(42)

train_data, test_data = random_split(
    dataset,
    [train_size, test_size],
    generator=split_rng,
)

  info(audio_path)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  return AudioMetaData(


Ошибка при чтении файла ".wav": Failed to open the input "/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/jazz/jazz.00054.wav" (Invalid data found when processing input).


In [None]:
# Mini-batches of 32: the training loader reshuffles every epoch, the test
# loader keeps a fixed order.
train = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test = DataLoader(dataset=test_data, batch_size=32, shuffle=False)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [None]:
class CheckAudio(nn.Module):
    """Small convolutional classifier for batches of 2-D spectrograms.

    Args:
        num_classes: Number of output logits (one per genre). Defaults to 10,
            the value that used to be hard-coded.

    The attribute names ``first`` / ``second`` and the order of the layers in
    each ``nn.Sequential`` are kept as they were, because the ``state_dict``
    keys of saved checkpoints are derived from them.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.first = nn.Sequential(  # Input: (N, 1, H, W)
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),          # Output: (N, 16, H/2, W/2)
            nn.Conv2d(16, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),          # Output: (N, 64, H/4, W/4)
            nn.AdaptiveAvgPool2d((8, 8))  # Output: (N, 64, 8, 8) for any H, W
        )
        self.second = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, audio):
        """Return class logits of shape ``(N, num_classes)``.

        Args:
            audio: Batch of spectrograms without a channel axis, ``(N, H, W)``.
        """
        # Conv2d expects a channel axis: (N, H, W) -> (N, 1, H, W).
        audio = audio.unsqueeze(1)
        audio = self.first(audio)
        audio = self.second(audio)
        return audio

In [None]:
# Build the classifier, then move its parameters to the selected device.
model = CheckAudio()
model = model.to(device)

In [None]:
# Cross-entropy on the raw logits against integer class targets, optimised
# with Adam at a learning rate of 1e-3.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train for 50 epochs. After each epoch the printed figure is the SUM of the
# per-batch losses, not their mean.
for epoch in range(50):
    model.train()
    total_loss = 0.0

    for inputs, targets in train:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear gradients left over from the previous step before the new
        # forward/backward pass.
        optimizer.zero_grad()

        logits = model(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Эпоха {epoch + 1}, Потери: {total_loss:.2f}')

Эпоха 1, Потери: 1.73
Эпоха 2, Потери: 2.15
Эпоха 3, Потери: 7.25
Эпоха 4, Потери: 9.46
Эпоха 5, Потери: 5.57
Эпоха 6, Потери: 3.59
Эпоха 7, Потери: 2.57
Эпоха 8, Потери: 1.60
Эпоха 9, Потери: 1.53
Эпоха 10, Потери: 1.13
Эпоха 11, Потери: 1.09
Эпоха 12, Потери: 0.80
Эпоха 13, Потери: 0.87
Эпоха 14, Потери: 0.91
Эпоха 15, Потери: 0.74
Эпоха 16, Потери: 0.96
Эпоха 17, Потери: 0.68
Эпоха 18, Потери: 0.67
Эпоха 19, Потери: 0.66
Эпоха 20, Потери: 0.61
Эпоха 21, Потери: 0.56
Эпоха 22, Потери: 0.55
Эпоха 23, Потери: 0.67
Эпоха 24, Потери: 0.41
Эпоха 25, Потери: 0.50
Эпоха 26, Потери: 0.46
Эпоха 27, Потери: 0.65
Эпоха 28, Потери: 0.49
Эпоха 29, Потери: 0.67
Эпоха 30, Потери: 0.66
Эпоха 31, Потери: 1.20
Эпоха 32, Потери: 2.10
Эпоха 33, Потери: 6.16
Эпоха 34, Потери: 7.16
Эпоха 35, Потери: 4.01
Эпоха 36, Потери: 2.50
Эпоха 37, Потери: 1.46
Эпоха 38, Потери: 0.78
Эпоха 39, Потери: 0.51
Эпоха 40, Потери: 0.37
Эпоха 41, Потери: 0.44
Эпоха 42, Потери: 0.23
Эпоха 43, Потери: 0.24
Эпоха 44, Потери: 0.

In [None]:
# Accuracy on the held-out split: count predictions whose arg-max class
# equals the target.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, targets in test:  # `test` is assumed to be a DataLoader
        inputs = inputs.to(device)
        targets = targets.to(device)

        predicted = model(inputs).argmax(dim=1)

        correct += int((predicted == targets).sum())
        total += targets.shape[0]

accuracy = correct * 100 / total
print(f'Точность предположения модели: {accuracy:.2f}%')

Точность предположения модели: 57.50%


In [None]:
# Persist the trained weights and the ordered genre list; the list gives the
# meaning of each output index of the model.
artifacts = {
    'audio_model.pth': model.state_dict(),
    'label.pth': genres,
}
for artifact_path, payload in artifacts.items():
    torch.save(payload, artifact_path)

In [None]:
from google.colab import files

# Offer both saved artifacts as browser downloads (works only inside Colab).
for artifact_name in ('audio_model.pth', 'label.pth'):
    files.download(artifact_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>