In [None]:
import os
import zipfile

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torchaudio import datasets, transforms

In [None]:
# Unpack the GTZAN archive. The context manager closes the archive handle
# deterministically instead of leaving it to the garbage collector.
# NOTE(review): this extracts into ../../Downloads/Data, but the dataset cell reads
# /content/data/Data/genres_original — confirm which location is intended.
with zipfile.ZipFile("../../Downloads/Data.zip") as archive:
    archive.extractall("../../Downloads/Data")

In [None]:
# Fix the global torch RNG so weight initialisation, data splitting and loader
# shuffling are repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [None]:
class GTZANDataset(Dataset):
    """Folder-per-genre audio dataset laid out as ``root_dir/<genre>/<clip>.wav``.

    Every sub-directory of ``root_dir`` is one class. Classes are sorted by name
    and numbered from 0 in that order; ``self.files`` holds ``(path, label)`` pairs.

    Args:
        root_dir: Directory whose sub-directories are the genres.
        transform: Optional callable applied to the loaded waveform tensor.
        target_transform: Optional callable applied to the integer label.
    """

    def __init__(self, root_dir, transform=None, target_transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform

        # Only directories are genres: a stray top-level file (e.g. .DS_Store)
        # would otherwise become a "class" and make os.listdir(genre_dir) fail.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

        # Clips are sorted because os.listdir order is arbitrary; a stable
        # index -> clip mapping keeps seeded splits identical across machines.
        self.files = []
        for label, genre in enumerate(self.classes):
            genre_dir = os.path.join(root_dir, genre)
            for file in sorted(os.listdir(genre_dir)):
                if file.endswith(".wav"):
                    self.files.append((os.path.join(genre_dir, file), label))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` for clip ``idx``.

        The waveform is the tensor returned by ``torchaudio.load``. A clip that
        fails to decode is replaced by 22050 zero samples (shape ``(1, 22050)``)
        so a single corrupt file does not abort an epoch; the placeholder goes
        through the same ``transform`` / ``target_transform`` as a real clip.
        """
        file_path, label = self.files[idx]
        try:
            waveform, _ = torchaudio.load(file_path)
        except Exception as e:
            print(f"Ошибка загрузки {file_path}: {e}")
            waveform = torch.zeros(1, 22050)

        if self.transform:
            waveform = self.transform(waveform)

        if self.target_transform:
            label = self.target_transform(label)

        return waveform, label


In [None]:
# Log-mel spectrogram front end; collate_fn applies it to each clip.
# MelSpectrogram outputs power values spanning many orders of magnitude, which feeds
# the CNN very large activations. AmplitudeToDB maps them to decibels, and top_db
# bounds how far below each spectrogram's peak values may fall.
transform = nn.Sequential(
    transforms.MelSpectrogram(
        sample_rate=22050,
        n_mels=128,
        n_fft=2048,
        hop_length=512,
    ),
    transforms.AmplitudeToDB(stype="power", top_db=80.0),
)

# The dataset yields raw waveforms and collate_fn computes the spectrogram, so no
# transform is passed here (passing it as well would apply it twice).
dataset = GTZANDataset("/content/data/Data/genres_original")


In [None]:
# Decode every clip once and drop the ones torchaudio cannot read, so training
# never falls back to GTZANDataset's zero-waveform placeholder.
bad_files = []
for path, _ in dataset.files:
    try:
        torchaudio.load(path)
    except Exception as e:
        bad_files.append(path)
        print(f"Ошибка загрузки {path}: {e}")

# Remove the broken clips from the dataset's file list; a set gives O(1) lookups.
bad_file_set = set(bad_files)
dataset.files = [(p, l) for p, l in dataset.files if p not in bad_file_set]

print(f"Всего удалено битых файлов: {len(bad_files)}")


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Ошибка загрузки /content/data/Data/genres_original/jazz/jazz.00054.wav: Failed to open the input "/content/data/Data/genres_original/jazz/jazz.00054.wav" (Invalid data found when processing input).
Всего удалено битых файлов: 1


In [None]:
# 70 / 15 / 15 split; the remainder left by int() truncation goes to the test set,
# so the three sizes always sum to len(dataset).
train_size = int(0.7 * len(dataset))
val_size   = int(0.15 * len(dataset))
test_size  = len(dataset) - train_size - val_size

# A dedicated seeded generator makes the split identical on every run, so the
# test clips stay unseen across kernel restarts and results are comparable.
split_generator = torch.Generator().manual_seed(42)
train_data, val_data, test_data = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

In [None]:
# Map dataset label ids to model output indices. GTZANDataset numbers its classes
# 0..len(classes)-1 via enumerate, so this is the identity map over all classes.
# Building it from dataset.classes (instead of iterating train_data) avoids decoding
# every training clip just to read labels, and gives every class an entry even if it
# is absent from the training split, so val/test batches cannot raise KeyError.
labels = list(range(len(dataset.classes)))
label_to_index = {lab: ind for ind, lab in enumerate(labels)}

In [None]:
# Number of spectrogram frames every example is cropped / zero-padded to, so a
# batch can be stacked into a single tensor.
max_len = 660


def collate_fn(batch):
    """Turn a list of ``(waveform, label, ...)`` items into a batch.

    Each waveform is mixed down to one channel, converted with the global
    ``transform``, and cropped or right-padded with zeros along its last axis to
    ``max_len``. Labels are mapped through the global ``label_to_index``.

    Returns:
        ``(spectrograms, targets)``: ``spectrograms`` stacks the per-clip 2-D
        spectrograms into ``(batch, F, max_len)`` (F is the transform's frequency
        axis, n_mels for MelSpectrogram); ``targets`` is a 1-D integer tensor.
    """
    spectrograms, targets = [], []
    for waveform, label, *_ in batch:
        # Average multi-channel audio to mono: squeeze(0) below only removes a
        # size-1 channel axis, so a stereo clip would stay 3-D and the crop/pad
        # would act on the frequency axis instead of time.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        spec = transform(waveform).squeeze(0)

        # Crop to max_len frames, then pad the shortfall (a no-op when already long).
        spec = spec[:, :max_len]
        spec = F.pad(spec, (0, max_len - spec.shape[1]))

        spectrograms.append(spec)
        targets.append(label_to_index[label])

    return torch.stack(spectrograms), torch.tensor(targets)

In [None]:
# Batch size shared by all three loaders.
BATCH_SIZE = 64

# Only the training loader reshuffles each epoch; validation and test keep a fixed order.
train = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid = DataLoader(val_data, collate_fn=collate_fn, batch_size=BATCH_SIZE)
test = DataLoader(test_data, collate_fn=collate_fn, batch_size=BATCH_SIZE)

In [None]:
# Genre names, in the same order as the integer label ids (GTZANDataset sorts the
# genre folder names and numbers them with enumerate).
# NOTE: this rebinds `labels`, which the label-mapping cell set to integer ids;
# from here on `labels` holds the genre-name strings.
labels = dataset.classes
labels

['blues',
 'classical',
 'country',
 'disco',
 'hiphop',
 'jazz',
 'metal',
 'pop',
 'reggae',
 'rock']

In [None]:
# Number of model outputs: one per genre folder. Read from dataset.classes rather
# than `labels`, which is rebound between cells and so depends on execution order.
num_classes = len(dataset.classes)

In [None]:
class CheckAudio(nn.Module):
  """Small 2-D CNN classifier over spectrogram batches.

  Input: a ``(batch, H, W)`` tensor — in this notebook, the spectrogram batch from
  ``collate_fn``. ``forward`` adds the single input channel that ``Conv2d`` expects.
  Output: ``(batch, n_classes)`` unnormalised logits (suitable for ``nn.CrossEntropyLoss``).

  Args:
    n_classes: Number of output classes. When omitted, the notebook-level
      ``num_classes`` is used, so ``CheckAudio()`` behaves as before.
  """
  def __init__(self, n_classes=None):
    super().__init__()
    if n_classes is None:
      n_classes = num_classes
    # Two conv/ReLU/max-pool stages. AdaptiveAvgPool2d then fixes the feature map
    # at 4x4 regardless of input size, so the first Linear layer's width does not
    # depend on the spectrogram shape.
    self.first = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(32, 64, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.AdaptiveAvgPool2d((4, 4))
    )
    self.second = nn.Sequential(
        nn.Flatten(),
        nn.Linear(64 * 4 * 4, 128),
        nn.ReLU(),
        nn.Linear(128, n_classes)
    )

  def forward(self, x):
    x = x.unsqueeze(1)  # (batch, H, W) -> (batch, 1, H, W)
    x = self.first(x)
    x = self.second(x)
    return x

In [None]:
# Build the classifier, then move its parameters to the selected device
# (nn.Module.to returns the module itself).
model = CheckAudio()
model = model.to(device)

In [None]:
# Multi-class objective on raw logits (CrossEntropyLoss applies log-softmax itself).
loss_fn = nn.CrossEntropyLoss()

# Adam over every model parameter.
LEARNING_RATE = 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
  # Call train() (merely referencing `model.train` does nothing) so the model is
  # back in training mode after the validation pass below or any earlier eval cell.
  model.train()
  total_loss = 0.0
  for x_batch, y_batch in train:
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    y_pred_train = model(x_batch)
    loss = loss_fn(y_pred_train, y_batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
  # Mean per-batch loss: comparable across epochs, unlike the raw sum over batches.
  avg_loss = total_loss / max(len(train), 1)

  # Accuracy on the validation split, which is otherwise never used.
  model.eval()
  val_correct, val_seen = 0, 0
  with torch.no_grad():
    for x_val, y_val in valid:
      x_val, y_val = x_val.to(device), y_val.to(device)
      val_correct += (model(x_val).argmax(dim=1) == y_val).sum().item()
      val_seen += y_val.size(0)
  val_acc = 100 * val_correct / val_seen if val_seen else float("nan")

  print(f'Эпоха {epoch+1}, Потери: {avg_loss:.4f}, Точность на валидации: {val_acc:.2f}%')

  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Эпоха 1, Потери: 586.1523303985596
Эпоха 2, Потери: 130.43727445602417
Эпоха 3, Потери: 38.75723624229431
Эпоха 4, Потери: 27.510332822799683
Эпоха 5, Потери: 23.98690915107727
Эпоха 6, Потери: 22.45406150817871
Эпоха 7, Потери: 21.83602750301361
Эпоха 8, Потери: 21.264204621315002
Эпоха 9, Потери: 20.447927474975586
Эпоха 10, Потери: 19.817460536956787
Эпоха 11, Потери: 19.118231534957886
Эпоха 12, Потери: 18.379493713378906
Эпоха 13, Потери: 17.889331698417664
Эпоха 14, Потери: 17.178139328956604
Эпоха 15, Потери: 16.62355077266693
Эпоха 16, Потери: 16.230235695838928
Эпоха 17, Потери: 15.426336288452148
Эпоха 18, Потери: 14.882123470306396
Эпоха 19, Потери: 14.493664979934692
Эпоха 20, Потери: 13.855992555618286
Эпоха 21, Потери: 13.40796971321106
Эпоха 22, Потери: 12.50597733259201
Эпоха 23, Потери: 12.155968725681305
Эпоха 24, Потери: 11.245635867118835
Эпоха 25, Потери: 10.873333215713501
Эпоха 26, Потери: 10.357626795768738
Эпоха 27, Потери: 10.420064747333527
Эпоха 28, Потери: 

In [None]:
# Final accuracy on the held-out test split.
model.eval()  # inference mode for any layers that behave differently during training
correct = 0
total = 0

with torch.no_grad():  # no autograd graph is needed for evaluation
  for x_batch, y_batch in test:
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)
    y_pred_test = model(x_batch)
    predicted = torch.argmax(y_pred_test, dim=1)

    total += y_batch.size(0)
    correct += (predicted == y_batch).sum().item()

# Guard against an empty test loader, which would otherwise raise ZeroDivisionError.
accuracy = 100 * correct / total if total else float("nan")
print(f'Test accuracy: {accuracy:.2f}%')