In [1]:
import warnings
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import hub
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchaudio

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, recall_score, precision_score, balanced_accuracy_score, accuracy_score, classification_report
from sklearn.utils import shuffle

import scipy

from tqdm import tqdm

from datasets import load_dataset, Dataset, Audio
import librosa
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

from models.basic_transformer import BasicTransformer

from src.utils import AphasiaDatasetMFCC, AphasiaDatasetSpectrogram, AphasiaDatasetWaveform

from collections import Counter
from models.wav2vecClassifier import Wav2vecClassifier
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, Gain

2025-04-15 22:41:12.859446: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-15 22:41:12.975621: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744746073.040935   12758 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744746073.055814   12758 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-15 22:41:13.167011: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Compute device: CUDA when torch can see a GPU, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): AUDIO_LENGTH, SEQUENCE_LENGTH and MFCC are not referenced by any
# visible cell — confirm they are still needed.
AUDIO_LENGTH = 6000
SEQUENCE_LENGTH = 31
MFCC = 128

print(f"It's {DEVICE} time!!!")

It's cuda time!!!


In [3]:
# All data lives in ../data relative to the notebook's working directory.
DATA_DIR = os.path.join(os.pardir, "data")

# Speech recordings, split into per-group subfolders.
VOICES_DIR = os.path.join(DATA_DIR, "Voices_wav")
APHASIA_DIR, NORM_DIR = (os.path.join(VOICES_DIR, sub) for sub in ("Aphasia", "Norm"))

# Auxiliary audio: room impulse responses and noise recordings.
RIR_DIR, NOISE_DIR = (os.path.join(DATA_DIR, sub) for sub in ("RIRs", "noise"))

In [4]:
# Waveform augmentation chain (audiomentations) applied to the training split.
# Each transform fires independently with its own probability `p`.
# NOTE(review): TimeStretch down to rate 0.4 and Shift of up to ±80% are quite
# aggressive for speech — confirm these ranges are intended.
audio_augment = Compose(
    [
        AddGaussianNoise(p=0.7, min_amplitude=0.001, max_amplitude=0.02),
        TimeStretch(p=0.5, min_rate=0.4, max_rate=1.5),
        PitchShift(p=0.5, min_semitones=-8, max_semitones=8),
        Shift(p=0.5, min_shift=-0.8, max_shift=0.8),
        Gain(p=0.5, min_gain_db=-14, max_gain_db=14),
    ]
)
augmentations = audio_augment  # both names refer to the same Compose object

In [5]:
# Training split. snr / rirs_dir / transforms presumably enable additive noise,
# reverberation and waveform augmentation inside AphasiaDatasetSpectrogram — confirm.
train_dataset = AphasiaDatasetSpectrogram(
    os.path.join(DATA_DIR, "train_filenames_mc_1.csv"),
    VOICES_DIR,
    snr=(10, 20),
    rirs_dir=RIR_DIR,
    target_sample_rate=8_000,
    file_format="wav",
    transforms=augmentations,
)

KeyboardInterrupt: 

In [None]:
test_dataset = AphasiaDatasetSpectrogram(os.path.join(DATA_DIR, "val_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, file_format="wav")

In [None]:
val_dataset = AphasiaDatasetSpectrogram(os.path.join(DATA_DIR, "test_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, file_format="wav")

In [None]:
# Балансировка классов для train
train_labels = [label for _, label in train_dataset.data]
class_counts = Counter(train_labels)
if len(class_counts) < 4:
    raise ValueError("Один из классов отсутствует в тренировочном наборе")

class_weights = {label: 1.0 / count for label, count in class_counts.items()}
weights = [class_weights[label] for _, label in train_dataset.data]
train_sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset), replacement=True)

In [None]:
# NOTE(review): not referenced by any visible cell (pad_sequence pads each batch to
# its own longest item instead) — confirm before relying on it.
MAX_LEN = 120000

In [None]:
def pad_sequence(batch):
    """Collate (spectrogram, label) pairs into one zero-padded batch.

    Each spectrogram is indexed as (channel, time, freq). It is right-padded with
    zeros along the time axis (dim 1) up to the longest item in the batch.

    Args:
        batch: sequence of (spectrogram, label) pairs. Labels may be 0-d tensors
            or plain Python ints.

    Returns:
        (padded, labels): ``padded`` has shape (batch, channel, max_time, freq) in
        torch's default float dtype. ``labels`` is the 1-D stack of the labels.
        Returns two empty tensors for an empty batch.

    NOTE(review): assumes every item shares the channel and freq-bin counts of the
    first item — confirm against AphasiaDatasetSpectrogram.
    """
    if not batch:
        return torch.zeros(0), torch.zeros(0)

    spectrograms, labels = zip(*batch)
    # Derive the channel count from the data instead of hard-coding 1, so
    # multi-channel inputs are padded instead of failing to broadcast.
    channels = spectrograms[0].shape[0]
    freq_bins = spectrograms[0].shape[2]
    max_len = max(s.shape[1] for s in spectrograms)

    padded = torch.zeros(len(spectrograms), channels, max_len, freq_bins)
    for i, s in enumerate(spectrograms):
        padded[i, :, :s.shape[1], :] = s

    # as_tensor lets labels be plain ints as well as 0-d tensors.
    return padded, torch.stack([torch.as_tensor(label) for label in labels])

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler, collate_fn=pad_sequence, drop_last=True, num_workers=6)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=pad_sequence, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=pad_sequence, drop_last=True, num_workers=6)

In [None]:
def train_model(model, dl_train, dl_val, epochs=1, lr=0.001, device="cpu"):
      
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(list(class_weights.values()))).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)
        
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-7, last_epoch=-1) # torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3, threshold=1e-3)
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3, threshold=1e-3)
        
    train_loss_list = []
    val_loss_list = []
    train_acc_list = []
    val_acc_list = []
    for epoch in tqdm(range(epochs), desc="Training model"):
        model.train()
        total_train_loss = 0
        total_val_loss = 0
        
        train_acc = []
        val_acc = []
        for features, target in dl_train:
            features, target = features.to(device), target.to(device)
            optimizer.zero_grad()
            
            output = model(features).squeeze()

            preds = torch.argmax(output, dim=1).cpu().detach().numpy()
            train_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
            
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.detach().item()
        
        avg_train_acc = np.stack(train_acc, axis=0).mean()

        avg_train_loss = total_train_loss / len(train_dataloader)
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
                
        model.eval()
        
        with torch.no_grad():
            for features, target in dl_val:
                features, target = features.to(device), target.to(device)
                
                output = model(features).squeeze()
                
                preds = torch.argmax(output, dim=1).cpu().detach().numpy()
                val_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
                loss = criterion(output, target)
                total_val_loss += loss.detach().item()
        
        avg_val_acc = np.stack(val_acc, axis=0).mean()
        avg_val_loss = total_val_loss / len(dl_val)
        val_loss_list.append(avg_val_loss)
        val_acc_list.append(avg_val_acc)
        
        if scheduler:
            try:
                scheduler.step()
            except:
                scheduler.step(avg_val_loss)
        
        # if epoch % 10 == 0:
        tqdm.write(f"Epoch {epoch}: train loss: {avg_train_loss:.3f}, train balanced acc: {avg_train_acc:.2f}, test loss: {avg_val_loss:.3f}, test balanced acc: {avg_val_acc:.2f}, lr: {optimizer.param_groups[0]['lr']}")

            
    return model, train_loss_list, val_loss_list, train_acc_list, val_acc_list

In [13]:
from models.cnn import MobileNet

cnn = MobileNet(
    num_classes=4,  # one logit per class; keep in sync with the class-balancing check
)

In [14]:
# NOTE(review): in the logged run, validation loss bottoms out around epoch 4 and
# rises afterwards while training loss keeps falling (overfitting). Consider early
# stopping / keeping the best-val checkpoint instead of 130 fixed epochs.
cnn, train_l, val_l, train_accuracy, val_accuracy = train_model(
    cnn,
    train_dataloader,
    val_dataloader,
    epochs=130,
    lr=1e-3,
    device=DEVICE,
)

Training model:   1%|          | 1/130 [00:13<28:26, 13.23s/it]

Epoch 0: train loss: 1.360, train balanced acc: 0.26, test loss: 1.346, test balanced acc: 0.30, lr: 0.001


Training model:   2%|▏         | 2/130 [00:26<27:55, 13.09s/it]

Epoch 1: train loss: 1.314, train balanced acc: 0.28, test loss: 1.346, test balanced acc: 0.30, lr: 0.001


Training model:   2%|▏         | 3/130 [00:39<27:36, 13.04s/it]

Epoch 2: train loss: 1.268, train balanced acc: 0.33, test loss: 1.333, test balanced acc: 0.30, lr: 0.001


Training model:   3%|▎         | 4/130 [00:52<27:19, 13.01s/it]

Epoch 3: train loss: 1.229, train balanced acc: 0.37, test loss: 1.317, test balanced acc: 0.30, lr: 0.001


Training model:   4%|▍         | 5/130 [01:05<27:03, 12.99s/it]

Epoch 4: train loss: 1.211, train balanced acc: 0.39, test loss: 1.309, test balanced acc: 0.30, lr: 0.001


Training model:   5%|▍         | 6/130 [01:18<26:52, 13.01s/it]

Epoch 5: train loss: 1.156, train balanced acc: 0.40, test loss: 1.318, test balanced acc: 0.30, lr: 0.001


Training model:   5%|▌         | 7/130 [01:31<26:41, 13.02s/it]

Epoch 6: train loss: 1.122, train balanced acc: 0.41, test loss: 1.320, test balanced acc: 0.30, lr: 0.001


Training model:   6%|▌         | 8/130 [01:44<26:27, 13.01s/it]

Epoch 7: train loss: 1.085, train balanced acc: 0.43, test loss: 1.313, test balanced acc: 0.30, lr: 0.001


Training model:   7%|▋         | 9/130 [01:57<26:14, 13.01s/it]

Epoch 8: train loss: 1.083, train balanced acc: 0.43, test loss: 1.322, test balanced acc: 0.30, lr: 0.0005


Training model:   8%|▊         | 10/130 [02:10<26:00, 13.00s/it]

Epoch 9: train loss: 0.974, train balanced acc: 0.49, test loss: 1.398, test balanced acc: 0.19, lr: 0.0005


Training model:   8%|▊         | 11/130 [02:23<25:47, 13.00s/it]

Epoch 10: train loss: 0.897, train balanced acc: 0.56, test loss: 1.463, test balanced acc: 0.19, lr: 0.0005


Training model:   9%|▉         | 12/130 [02:36<25:33, 13.00s/it]

Epoch 11: train loss: 0.861, train balanced acc: 0.57, test loss: 1.598, test balanced acc: 0.19, lr: 0.0005


Training model:  10%|█         | 13/130 [02:49<25:22, 13.01s/it]

Epoch 12: train loss: 0.832, train balanced acc: 0.59, test loss: 1.640, test balanced acc: 0.19, lr: 0.00025


Training model:  11%|█         | 14/130 [03:02<25:09, 13.01s/it]

Epoch 13: train loss: 0.708, train balanced acc: 0.64, test loss: 1.592, test balanced acc: 0.25, lr: 0.00025


Training model:  12%|█▏        | 15/130 [03:15<24:56, 13.01s/it]

Epoch 14: train loss: 0.606, train balanced acc: 0.69, test loss: 1.738, test balanced acc: 0.36, lr: 0.00025


Training model:  12%|█▏        | 16/130 [03:28<24:44, 13.03s/it]

Epoch 15: train loss: 0.573, train balanced acc: 0.71, test loss: 1.884, test balanced acc: 0.30, lr: 0.00025


Training model:  13%|█▎        | 17/130 [03:41<24:31, 13.03s/it]

Epoch 16: train loss: 0.521, train balanced acc: 0.72, test loss: 1.888, test balanced acc: 0.31, lr: 0.000125


Training model:  14%|█▍        | 18/130 [03:54<24:26, 13.09s/it]

Epoch 17: train loss: 0.473, train balanced acc: 0.75, test loss: 2.020, test balanced acc: 0.31, lr: 0.000125


Training model:  15%|█▍        | 19/130 [04:07<24:17, 13.13s/it]

Epoch 18: train loss: 0.448, train balanced acc: 0.76, test loss: 2.056, test balanced acc: 0.36, lr: 0.000125


Training model:  15%|█▌        | 20/130 [04:20<24:04, 13.13s/it]

Epoch 19: train loss: 0.418, train balanced acc: 0.77, test loss: 2.284, test balanced acc: 0.33, lr: 0.000125


Training model:  16%|█▌        | 21/130 [04:34<23:55, 13.17s/it]

Epoch 20: train loss: 0.403, train balanced acc: 0.79, test loss: 2.180, test balanced acc: 0.37, lr: 6.25e-05


Training model:  17%|█▋        | 22/130 [04:47<23:43, 13.18s/it]

Epoch 21: train loss: 0.365, train balanced acc: 0.80, test loss: 2.358, test balanced acc: 0.36, lr: 6.25e-05


Training model:  18%|█▊        | 23/130 [05:00<23:33, 13.21s/it]

Epoch 22: train loss: 0.378, train balanced acc: 0.80, test loss: 2.460, test balanced acc: 0.36, lr: 6.25e-05


Training model:  18%|█▊        | 24/130 [05:13<23:21, 13.22s/it]

Epoch 23: train loss: 0.342, train balanced acc: 0.83, test loss: 2.386, test balanced acc: 0.35, lr: 6.25e-05


Training model:  19%|█▉        | 25/130 [05:27<23:15, 13.29s/it]

Epoch 24: train loss: 0.323, train balanced acc: 0.83, test loss: 2.542, test balanced acc: 0.37, lr: 3.125e-05


Training model:  20%|██        | 26/130 [05:40<22:55, 13.23s/it]

Epoch 25: train loss: 0.324, train balanced acc: 0.84, test loss: 2.614, test balanced acc: 0.38, lr: 3.125e-05


Training model:  21%|██        | 27/130 [05:53<22:41, 13.22s/it]

Epoch 26: train loss: 0.327, train balanced acc: 0.84, test loss: 2.644, test balanced acc: 0.36, lr: 3.125e-05


Training model:  22%|██▏       | 28/130 [06:07<22:38, 13.32s/it]

Epoch 27: train loss: 0.330, train balanced acc: 0.84, test loss: 2.578, test balanced acc: 0.35, lr: 3.125e-05


Training model:  22%|██▏       | 29/130 [06:20<22:22, 13.30s/it]

Epoch 28: train loss: 0.317, train balanced acc: 0.84, test loss: 2.575, test balanced acc: 0.36, lr: 1.5625e-05


Training model:  23%|██▎       | 30/130 [06:33<22:11, 13.31s/it]

Epoch 29: train loss: 0.319, train balanced acc: 0.84, test loss: 2.674, test balanced acc: 0.38, lr: 1.5625e-05


Training model:  24%|██▍       | 31/130 [06:47<21:55, 13.29s/it]

Epoch 30: train loss: 0.301, train balanced acc: 0.85, test loss: 2.664, test balanced acc: 0.38, lr: 1.5625e-05


Training model:  25%|██▍       | 32/130 [07:00<21:42, 13.29s/it]

Epoch 31: train loss: 0.338, train balanced acc: 0.83, test loss: 2.674, test balanced acc: 0.38, lr: 1.5625e-05


Training model:  25%|██▌       | 33/130 [07:13<21:29, 13.29s/it]

Epoch 32: train loss: 0.312, train balanced acc: 0.84, test loss: 2.714, test balanced acc: 0.39, lr: 7.8125e-06


Training model:  26%|██▌       | 34/130 [07:26<21:16, 13.30s/it]

Epoch 33: train loss: 0.314, train balanced acc: 0.83, test loss: 2.723, test balanced acc: 0.39, lr: 7.8125e-06


Training model:  27%|██▋       | 35/130 [07:40<21:07, 13.34s/it]

Epoch 34: train loss: 0.292, train balanced acc: 0.85, test loss: 2.728, test balanced acc: 0.38, lr: 7.8125e-06


Training model:  28%|██▊       | 36/130 [07:53<20:54, 13.35s/it]

Epoch 35: train loss: 0.302, train balanced acc: 0.86, test loss: 2.730, test balanced acc: 0.38, lr: 7.8125e-06


Training model:  28%|██▊       | 37/130 [08:07<20:45, 13.39s/it]

Epoch 36: train loss: 0.327, train balanced acc: 0.85, test loss: 2.784, test balanced acc: 0.38, lr: 3.90625e-06


Training model:  29%|██▉       | 38/130 [08:20<20:27, 13.34s/it]

Epoch 37: train loss: 0.274, train balanced acc: 0.87, test loss: 2.780, test balanced acc: 0.37, lr: 3.90625e-06


Training model:  30%|███       | 39/130 [08:33<20:12, 13.33s/it]

Epoch 38: train loss: 0.271, train balanced acc: 0.87, test loss: 2.767, test balanced acc: 0.37, lr: 3.90625e-06


Training model:  31%|███       | 40/130 [08:47<20:03, 13.37s/it]

Epoch 39: train loss: 0.278, train balanced acc: 0.88, test loss: 2.768, test balanced acc: 0.37, lr: 3.90625e-06


Training model:  32%|███▏      | 41/130 [09:00<19:41, 13.27s/it]

Epoch 40: train loss: 0.308, train balanced acc: 0.85, test loss: 2.760, test balanced acc: 0.38, lr: 1.953125e-06


Training model:  32%|███▏      | 42/130 [09:13<19:29, 13.29s/it]

Epoch 41: train loss: 0.285, train balanced acc: 0.86, test loss: 2.786, test balanced acc: 0.37, lr: 1.953125e-06


Training model:  33%|███▎      | 43/130 [09:26<19:13, 13.26s/it]

Epoch 42: train loss: 0.279, train balanced acc: 0.87, test loss: 2.777, test balanced acc: 0.37, lr: 1.953125e-06


Training model:  34%|███▍      | 44/130 [09:40<19:02, 13.28s/it]

Epoch 43: train loss: 0.264, train balanced acc: 0.88, test loss: 2.776, test balanced acc: 0.38, lr: 1.953125e-06


Training model:  35%|███▍      | 45/130 [09:53<18:53, 13.33s/it]

Epoch 44: train loss: 0.263, train balanced acc: 0.88, test loss: 2.779, test balanced acc: 0.38, lr: 9.765625e-07


Training model:  35%|███▌      | 46/130 [10:06<18:36, 13.29s/it]

Epoch 45: train loss: 0.292, train balanced acc: 0.87, test loss: 2.778, test balanced acc: 0.38, lr: 9.765625e-07


Training model:  36%|███▌      | 47/130 [10:20<18:22, 13.29s/it]

Epoch 46: train loss: 0.301, train balanced acc: 0.85, test loss: 2.771, test balanced acc: 0.38, lr: 9.765625e-07


Training model:  37%|███▋      | 48/130 [10:33<18:14, 13.34s/it]

Epoch 47: train loss: 0.291, train balanced acc: 0.86, test loss: 2.760, test balanced acc: 0.38, lr: 9.765625e-07


Training model:  38%|███▊      | 49/130 [10:46<18:00, 13.33s/it]

Epoch 48: train loss: 0.306, train balanced acc: 0.85, test loss: 2.776, test balanced acc: 0.38, lr: 4.8828125e-07


Training model:  38%|███▊      | 50/130 [11:00<17:44, 13.30s/it]

Epoch 49: train loss: 0.270, train balanced acc: 0.88, test loss: 2.781, test balanced acc: 0.37, lr: 4.8828125e-07


Training model:  39%|███▉      | 51/130 [11:13<17:32, 13.33s/it]

Epoch 50: train loss: 0.287, train balanced acc: 0.86, test loss: 2.791, test balanced acc: 0.38, lr: 4.8828125e-07


Training model:  40%|████      | 52/130 [11:27<17:25, 13.41s/it]

Epoch 51: train loss: 0.293, train balanced acc: 0.86, test loss: 2.793, test balanced acc: 0.37, lr: 4.8828125e-07


Training model:  41%|████      | 53/130 [11:40<17:09, 13.37s/it]

Epoch 52: train loss: 0.307, train balanced acc: 0.85, test loss: 2.802, test balanced acc: 0.38, lr: 2.44140625e-07


Training model:  41%|████      | 53/130 [11:48<17:09, 13.37s/it]


KeyboardInterrupt: 