## Аудио → мел-спектрограмма → преобразование в log-mel и нормализация → positional encoding → audio Transformer (AST), предобученный на AudioSet → классификатор

In [None]:
import os
import torch
import torchaudio
import soundfile as sf
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from transformers import ASTFeatureExtractor, ASTForAudioClassification
from tqdm.auto import tqdm

if torch.cuda.is_available():
    DEVICE = "cuda"  # GPU в Colab (бесплатно доступна T4 или K80)
    print(f"✅ Используется GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    DEVICE = "mps"  # Только для Mac с чипом Apple Silicon
    print("✅ Используется MPS (Apple Silicon)")
else:
    DEVICE = "cpu"
    print("⚠️ GPU не найден, используется CPU (будет очень медленно)")

print(f"Итоговое устройство: {DEVICE}")
SAMPLE_RATE = 16000
BATCH_SIZE = 8
EPOCHS = 10
NUM_CLASSES = 7

CLASSES = ['down', 'left', 'off', 'on', 'right', 'stop', 'up']

feature_extractor = ASTFeatureExtractor.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593"
)

✅ Используется GPU: Tesla T4
Итоговое устройство: cuda


In [None]:
class SpeechCommandsDataset(Dataset):
    def __init__(self, root_dir, classes):
        self.files = []
        self.labels = []
        self.classes = classes

        for label, cls in enumerate(classes):
            cls_dir = os.path.join(root_dir, cls)
            for fname in os.listdir(cls_dir):
                if fname.endswith(".wav"):
                    self.files.append(os.path.join(cls_dir, fname))
                    self.labels.append(label)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        wav_np, sr = sf.read(self.files[idx], dtype="float32")
        waveform = torch.from_numpy(wav_np)

        if waveform.ndim == 1:
            pass
        else:
            waveform = waveform.mean(dim=1)  # to mono

        inputs = feature_extractor(
            waveform,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt"
        )

        return {
            "input_values": inputs["input_values"].squeeze(0),
            "labels": torch.tensor(self.labels[idx])
        }

In [None]:
from sklearn.model_selection import train_test_split

DATA_ROOT = "speech_commands_data/speech_commands"

full_dataset = SpeechCommandsDataset(DATA_ROOT, CLASSES)

indices = list(range(len(full_dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)

train_ds = torch.utils.data.Subset(full_dataset, train_idx)
val_ds   = torch.utils.data.Subset(full_dataset, val_idx)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
model = ASTForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
)

model.to(DEVICE)

for param in model.parameters():
    param.requires_grad = False

for layer in model.audio_spectrogram_transformer.encoder.layer[-2:]:
    for p in layer.parameters():
        p.requires_grad = True

for p in model.classifier.parameters():
    p.requires_grad = True

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,
    weight_decay=1e-4
)

Loading weights:   0%|          | 0/203 [00:00<?, ?it/s]

ASTForAudioClassification LOAD REPORT from: MIT/ast-finetuned-audioset-10-10-0.4593
Key                     | Status   |                                                                                       
------------------------+----------+---------------------------------------------------------------------------------------
classifier.dense.bias   | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([527]) vs model:torch.Size([7])          
classifier.dense.weight | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([527, 768]) vs model:torch.Size([7, 768])

Notes:
- MISMATCH	:ckpt weights were loaded, but they did not match the original empty weight shapes.


In [None]:
best_acc = 0.0
best_epoch = 0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for batch in tqdm(train_loader, desc=f"Train {epoch+1}/{EPOCHS}"):
        optimizer.zero_grad()

        input_values = batch["input_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(
            input_values=input_values,
            labels=labels
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds = torch.argmax(outputs.logits, dim=1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)

    train_loss /= len(train_loader)
    train_acc = train_correct / train_total

    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Val   {epoch+1}/{EPOCHS}"):
            input_values = batch["input_values"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(
                input_values=input_values,
                labels=labels
            )

            loss = outputs.loss
            val_loss += loss.item()

            preds = torch.argmax(outputs.logits, dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss /= len(val_loader)
    val_acc = val_correct / val_total

    print(
        f"[{epoch+1:2d}/{EPOCHS}] "
        f"train loss: {train_loss:.4f}  acc: {train_acc:.4f} | "
        f"val loss: {val_loss:.4f}  acc: {val_acc:.4f}"
    )

    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch + 1
        torch.save(model.state_dict(), "best_ast_speech.pt")
        print(f"→ новая лучшая модель, сохранена (acc = {val_acc:.4f})")

print(f"\nЛучшая валидационная точность: {best_acc:.4f} на эпохе {best_epoch}")

Train 1/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   1/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 1/10] train loss: 0.1238  acc: 0.9569 | val loss: 0.2056  acc: 0.9317
→ новая лучшая модель, сохранена (acc = 0.9317)


Train 2/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   2/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 2/10] train loss: 0.0879  acc: 0.9706 | val loss: 0.1727  acc: 0.9411
→ новая лучшая модель, сохранена (acc = 0.9411)


Train 3/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   3/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 3/10] train loss: 0.0645  acc: 0.9783 | val loss: 0.2032  acc: 0.9399


Train 4/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   4/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 4/10] train loss: 0.0563  acc: 0.9799 | val loss: 0.1699  acc: 0.9569
→ новая лучшая модель, сохранена (acc = 0.9569)


Train 5/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   5/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 5/10] train loss: 0.0431  acc: 0.9849 | val loss: 0.2501  acc: 0.9381


Train 6/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   6/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 6/10] train loss: 0.0374  acc: 0.9867 | val loss: 0.1970  acc: 0.9511


Train 7/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   7/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 7/10] train loss: 0.0360  acc: 0.9882 | val loss: 0.1854  acc: 0.9542


Train 8/10:   0%|          | 0/1647 [00:00<?, ?it/s]

Val   8/10:   0%|          | 0/412 [00:00<?, ?it/s]

[ 8/10] train loss: 0.0229  acc: 0.9928 | val loss: 0.1806  acc: 0.9605
→ новая лучшая модель, сохранена (acc = 0.9605)


Train 9/10:   0%|          | 0/1647 [00:00<?, ?it/s]

In [1]:
model.eval()

y_true = []
y_pred = []

with torch.no_grad():
    for batch in val_loader:
        input_values = batch["input_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(input_values=input_values)
        preds = torch.argmax(outputs.logits, dim=1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

acc = accuracy_score(y_true, y_pred)
print(f"Validation accuracy: {acc:.4f}")

print(
    classification_report(
        y_true,
        y_pred,
        target_names=CLASSES,
        digits=4
    )
)

NameError: name 'model' is not defined

In [None]:
cm = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    xticklabels=CLASSES,
    yticklabels=CLASSES,
    cmap="Blues"
)

plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix (AST)")
plt.tight_layout()
plt.show()