In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchaudio import datasets, transforms

In [None]:
# Build the three Speech Commands splits from the same local root, in the
# order training -> testing -> validation. Every call passes download=True.
train_data, test_data, valid_data = (
    datasets.SPEECHCOMMANDS(root='./data', download=True, subset=split_name)
    for split_name in ('training', 'testing', 'validation')
)

100%|██████████| 2.26G/2.26G [00:20<00:00, 120MB/s]


In [None]:
# Collect the distinct class names (third field of every training item).
#
# The names are sorted so that the class order -- and therefore every
# name -> index assignment derived from it -- is the same after each kernel
# restart. A bare set() iterates strings in an order that depends on Python's
# per-process hash randomization, which would silently change the meaning of
# the indices a saved model was trained with.
#
# NOTE(review): this walks the whole training split item by item, which
# appears to load every audio clip, so expect the cell to be slow.
label = sorted({item[2] for item in train_data})

  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


In [None]:
# Assign each class name its position in `label` as the integer class id.
label_to_index = dict(zip(label, range(len(label))))
label_to_index

{'bird': 0,
 'four': 1,
 'stop': 2,
 'go': 3,
 'on': 4,
 'tree': 5,
 'five': 6,
 'down': 7,
 'two': 8,
 'backward': 9,
 'bed': 10,
 'zero': 11,
 'house': 12,
 'one': 13,
 'off': 14,
 'marvin': 15,
 'forward': 16,
 'six': 17,
 'follow': 18,
 'visual': 19,
 'three': 20,
 'eight': 21,
 'up': 22,
 'yes': 23,
 'sheila': 24,
 'learn': 25,
 'dog': 26,
 'no': 27,
 'left': 28,
 'cat': 29,
 'wow': 30,
 'happy': 31,
 'seven': 32,
 'nine': 33,
 'right': 34}

In [None]:
# Waveform -> mel-spectrogram front end configured for 16 kHz audio and
# 64 mel bands; every other MelSpectrogram option keeps its library default.
transform = transforms.MelSpectrogram(sample_rate=16000, n_mels=64)

In [None]:
# Number of spectrogram time frames every example is cropped or padded to.
max_len = 100


def collate_fn(batch):
    """Collate dataset items into one batch of fixed-width mel spectrograms.

    Each item is unpacked as ``(waveform, sample_rate, label, ...)``; the
    sample rate and any trailing fields are ignored. The waveform is passed
    through the module-level ``transform``, the leading axis is dropped, and
    the result is cropped or zero-padded on the right along its last axis so
    that it is exactly ``max_len`` frames wide.

    Args:
        batch: list of dataset items handed over by the ``DataLoader``.

    Returns:
        A ``(spectrograms, targets)`` pair. ``spectrograms`` stacks the
        per-item spectrograms along a new leading batch axis; ``targets`` is
        a tensor with the ``label_to_index`` id of each item's label.

    Raises:
        KeyError: if an item's label is not a key of ``label_to_index``.
    """
    spectrograms, targets = [], []
    for waveform, _sample_rate, word, *_ in batch:
        # Drop only the leading (channel) axis. A bare squeeze() removes
        # *every* size-1 axis, so a clip short enough to produce a single
        # frame would lose its time axis and break the shape[1] lookup below.
        # Assumes a single-channel waveform -- TODO confirm for other datasets.
        spec = transform(waveform).squeeze(0)

        # Crop to at most max_len frames (a no-op for shorter clips) ...
        spec = spec[:, :max_len]
        # ... then right-pad with zeros up to exactly max_len frames.
        missing = max_len - spec.shape[1]
        if missing > 0:
            spec = F.pad(spec, (0, missing))

        spectrograms.append(spec)
        targets.append(label_to_index[word])

    spectrograms = torch.stack(spectrograms)
    targets = torch.tensor(targets)
    return spectrograms, targets

In [None]:
# Number of distinct class names collected from the training split.
len(label)

35

In [None]:
# Batch both splits 64 items at a time through collate_fn. Only the training
# loader reshuffles; the test loader keeps the dataset order.
train = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    collate_fn=collate_fn,
)
test = DataLoader(
    test_data,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_fn,
)


In [None]:
class CheckAudio(nn.Module):
    """Small convolutional classifier over batches of 2-D spectrograms.

    Two conv -> ReLU -> max-pool stages are followed by an adaptive average
    pool to a fixed 8x8 grid, so the fully connected head sees 32 * 8 * 8
    features regardless of the spatial size that reaches the pooling layer.

    Args:
        num_classes: width of the output layer, i.e. the number of logits
            produced per example. Defaults to 35.
    """

    def __init__(self, num_classes=35):
        super().__init__()
        # Attribute names and layer positions are unchanged, so state_dict
        # keys stay the same as before.
        self.first = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((8, 8))
        )

        self.second = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            # Honour the constructor argument; this used to be hard-coded to
            # 35, which made `num_classes` a silent no-op.
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of spectrograms.

        Args:
            x: tensor without a channel axis; a size-1 channel axis is
                inserted at position 1 before the convolutions.

        Returns:
            Tensor of raw (un-normalised) scores with ``num_classes`` entries
            per example.
        """
        x = x.unsqueeze(1)
        x = self.first(x)
        x = self.second(x)
        return x


In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [None]:
# Instantiate the network with its default arguments and move it to `device`.
model = CheckAudio().to(device)

In [None]:
# Training objective: cross-entropy over the model's raw logits.
loss_fn = nn.CrossEntropyLoss()

# Adam over all model parameters with a starting learning rate of 0.003.
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

# Halve the learning rate every 10 scheduler steps.
# NOTE(review): the training loop below runs 10 epochs with one
# scheduler.step() per epoch, so the first halving only happens after the last
# epoch has finished -- the schedule looks like it never affects training.
# Confirm whether a smaller step_size (or more epochs) was intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [None]:
# Train for 10 epochs. After each epoch the scheduler is advanced once and the
# *sum* of the per-batch losses (not their mean) is printed.
for epoch in range(10):
    model.train()
    total_loss = 0
    for x_batch, y_batch in train:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Clear gradients left over from the previous step before the new
        # backward pass accumulates into them.
        optimizer.zero_grad()

        y_pred = model(x_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    scheduler.step()
    print(f"Эпоха {epoch + 1} - Потери: {round(total_loss, 2)}")

Эпоха 1 - Потери: 551.04
Эпоха 2 - Потери: 549.0
Эпоха 3 - Потери: 542.11
Эпоха 4 - Потери: 540.87
Эпоха 5 - Потери: 540.36
Эпоха 6 - Потери: 537.18
Эпоха 7 - Потери: 536.84
Эпоха 8 - Потери: 536.5
Эпоха 9 - Потери: 535.33


In [None]:
# Measure top-1 accuracy on the test loader with gradient tracking disabled.
model.eval()
correct = total = 0
with torch.no_grad():
    for x_batch, y_batch in test:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        y_pred = model(x_batch)
        # The predicted class is the index of the largest logit per example.
        predicted = y_pred.argmax(dim=1)

        total += y_batch.shape[0]
        correct += (predicted == y_batch).sum().item()

accuracy = 100 * correct / total
print(f"Точность: {round(accuracy, 2)}%")

Точность: 77.99%


In [None]:
# Persist the class-name -> index mapping next to the weights so predictions
# can be decoded later. `label_to_index` is a plain Python mapping, not an
# nn.Module, so it has no state_dict(); the object itself is serialised.
torch.save(label_to_index, 'label.pth')

AttributeError: 'list' object has no attribute 'state_dict'

In [None]:
# Save only the model's parameters/buffers (its state_dict), not the module object.
torch.save(model.state_dict(), 'torch_audio.pth')