In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchaudio import datasets, transforms, info, load
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn.functional as F

In [None]:
import kagglehub

path = kagglehub.dataset_download("andradaolteanu/gtzan-dataset-music-genre-classification")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'gtzan-dataset-music-genre-classification' dataset.
Path to dataset files: /kaggle/input/gtzan-dataset-music-genre-classification


In [None]:
data_path = '/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original'

In [None]:
import os
genres = sorted(os.listdir(data_path))
genres

['blues',
 'classical',
 'country',
 'disco',
 'hiphop',
 'jazz',
 'metal',
 'pop',
 'reggae',
 'rock']

In [None]:
len(genres)

10

In [None]:
label_to_index = {lab: ind for ind, lab in enumerate(genres)}
label_to_index

{'blues': 0,
 'classical': 1,
 'country': 2,
 'disco': 3,
 'hiphop': 4,
 'jazz': 5,
 'metal': 6,
 'pop': 7,
 'reggae': 8,
 'rock': 9}

In [None]:
transform = transforms.MelSpectrogram(
    sample_rate=22050,
    n_mels=64
)

In [None]:
max_len = 500

In [None]:
import os
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchaudio
from torchaudio import transforms

class GTZAN(Dataset):
    def __init__(self, root_path, transform, max_len):
        self.root_path = root_path
        self.transform = transform
        self.max_len = max_len
        self.audios = []

        self.genres = sorted(os.listdir(root_path))
        self.label_to_index = {genre: i for i, genre in enumerate(self.genres)}

        for genre in self.genres:
            genre_path = os.path.join(root_path, genre)
            for file in os.listdir(genre_path):
                if file.endswith('.wav'):
                    file_path = os.path.join(genre_path, file)
                    try:
                        torchaudio.info(file_path)
                        self.audios.append((file_path, genre))
                    except Exception as e:
                        print(f'Ошибка {e}')

    def __len__(self):
        return len(self.audios)

    def __getitem__(self, ind):
        file_path, genre = self.audios[ind]
        waveform, sr = torchaudio.load(file_path)

        if sr != 22050:
            resample = transforms.Resample(orig_freq=sr, new_freq=22050)
            waveform = resample(waveform)

        spec = self.transform(waveform).squeeze(0)

        if spec.shape[1] > self.max_len:
            spec = spec[:, :self.max_len]
        if spec.shape[1] < self.max_len:
            count_len = self.max_len - spec.shape[1]
            spec = F.pad(spec, (0, count_len))

        return spec, self.label_to_index[genre]


In [None]:
dataset = GTZAN(data_path, transform, max_len)
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size

train_data, test_data = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

  torchaudio.info(file_path)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  return AudioMetaData(


Ошибка Failed to open the input "/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/jazz/jazz.00054.wav" (Invalid data found when processing input).


In [None]:
train = DataLoader(train_data, batch_size=32, shuffle=True)
test = DataLoader(test_data, batch_size=32, )

In [None]:
device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [56]:
class CheckMelodia(nn.Module):
  def __init__(self):
    super().__init__()
    self.first = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.AdaptiveAvgPool2d((8, 8))
    )
    self.second = nn.Sequential(
        nn.Flatten(),
        nn.Linear(16 * 8 * 8, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )

  def forward(self, x):
    x = x.unsqueeze(1)
    x = self.first(x)
    x = self.second(x)
    return x

In [57]:
model = CheckMelodia().to(device)


In [58]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [67]:
for epoch in range(25):
  model.train
  total_loss = 0

  for x_batch, y_batch in train:
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    y_pred = model(x_batch)
    loss = loss_fn(y_pred, y_batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
print(f'Эпоха {epoch+1}, Потери: {total_loss}')

Эпоха 25, Потери: 0.4494849946349859


In [68]:
model.eval()
correct, total = 0, 0
with torch.no_grad():
  for x_batch, y_batch in test:
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)
    y_pred = model(x_batch)
    pred = torch.argmax(y_pred, dim=1)

    correct += (pred == y_batch).sum().item()
    total += y_batch.size(0)

accuracy = correct * 100 / total
print(f'Точность: {accuracy}%')

Точность: 54.0%


In [70]:
torch.save(model.state_dict(), 'model.pth')
torch.save(genres, 'labels.pth')