In [1]:
import warnings
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import hub
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchaudio

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, recall_score, precision_score, balanced_accuracy_score, accuracy_score, classification_report
from sklearn.utils import shuffle

import scipy

from tqdm import tqdm

from datasets import load_dataset, Dataset, Audio
import librosa
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

from models.basic_transformer import BasicTransformer

from src.utils import AphasiaDatasetMFCC, AphasiaDatasetSpectrogram, AphasiaDatasetWaveform

from collections import Counter

2025-03-27 19:30:06.896791: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-27 19:30:06.980427: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743093007.029269   79148 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743093007.041129   79148 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-27 19:30:07.149152: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Numeric settings kept together; none of them is referenced by the cells visible below.
AUDIO_LENGTH, SEQUENCE_LENGTH, MFCC = 6_000, 31, 128

print(f"It's {DEVICE} time!!!")

It's cuda time!!!


In [3]:
# Data locations, all relative to the directory the notebook runs from:
#   ../data/Voices/Aphasia and ../data/Voices/Norm
DATA_DIR = os.path.join(os.pardir, 'data')
VOICES_DIR = os.path.join(DATA_DIR, 'Voices')
APHASIA_DIR, NORM_DIR = (
    os.path.join(VOICES_DIR, class_dir) for class_dir in ('Aphasia', 'Norm')
)

In [4]:
# One waveform dataset per split, all built with target_sample_rate=8_000 (8 kHz).
# Fix: the validation and test CSVs used to be swapped between `test_dataset` and
# `val_dataset`; each variable now loads the split its name says.
train_dataset = AphasiaDatasetWaveform(os.path.join(DATA_DIR, "train_filenames.csv"), VOICES_DIR, target_sample_rate=8_000)
val_dataset = AphasiaDatasetWaveform(os.path.join(DATA_DIR, "val_filenames.csv"), VOICES_DIR, target_sample_rate=8_000)
test_dataset = AphasiaDatasetWaveform(os.path.join(DATA_DIR, "test_filenames.csv"), VOICES_DIR, target_sample_rate=8_000)

# Class balancing for the training split: every sample is weighted by the inverse
# frequency of its class, so sampling with replacement draws the classes about equally often.
train_labels = [label for _, label in train_dataset.data]
class_counts = Counter(train_labels)
if len(class_counts) < 2:
    # The message reads: "One of the classes is missing from the training set".
    raise ValueError("Один из классов отсутствует в тренировочном наборе")

class_weights = {label: 1.0 / count for label, count in class_counts.items()}
# Reuse the labels collected above instead of walking the dataset a second time.
weights = [class_weights[label] for label in train_labels]
train_sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset), replacement=True)

In [5]:
# Fixed width, in samples, of every zero-padded waveform row built in this notebook
# (used by the collate function and by the chunked evaluation loop).
# 120_000 samples is 15 s at the 8 kHz target_sample_rate the datasets are created with.
MAX_LEN = 120_000

In [6]:
def pad_sequence(batch, max_len=None):
    """Collate ``(waveform, label)`` pairs into one fixed-width, zero-padded batch.

    Only row 0 of each waveform is used (it is indexed as ``wave[0, :n]``), so
    each waveform must be 2-D. Waveforms shorter than ``max_len`` are
    right-padded with zeros; longer ones are truncated to ``max_len`` samples.

    Args:
        batch: Sequence of ``(waveform, label)`` tuples. Labels must be tensors,
            because they are combined with ``torch.stack``.
        max_len: Width of the output in samples. ``None`` (the default) uses the
            module-level ``MAX_LEN``, which is what ``DataLoader`` gets when it
            calls this function with the batch alone.

    Returns:
        ``(padded, labels)`` where ``padded`` has shape ``(len(batch), max_len)``
        and ``labels`` is the stacked label tensor. An empty batch yields two
        empty tensors.
    """
    if not batch:
        return torch.zeros(0), torch.zeros(0)

    if max_len is None:
        max_len = MAX_LEN

    waveforms, labels = zip(*batch)

    padded = torch.zeros(len(waveforms), max_len)
    for row, wave in enumerate(waveforms):
        # Copy at most `max_len` samples; the rest of the row stays zero.
        n_samples = min(wave.shape[1], max_len)
        padded[row, :n_samples] = wave[0, :n_samples]

    return padded, torch.stack(labels)

In [7]:
# Training batches come from the class-balancing sampler; the incomplete final batch is dropped.
train_dataloader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler, collate_fn=pad_sequence, drop_last=True, num_workers=6)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=pad_sequence, drop_last=False)
# Evaluation has to see every sample, so the last partial batch is kept (drop_last=False);
# with drop_last=True up to batch_size - 1 samples were silently left out of the metrics.
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=pad_sequence, drop_last=False, num_workers=6)

In [8]:
# Directory the wav2vec checkpoint is loaded from, relative to the notebook's working directory.
CHKP_PATH = os.path.join("..", 'checkpoints', "wav2vec_chkp")

In [9]:
from models.wav2vecClassifier import Wav2vecClassifier

# Build the project's wav2vec classifier; its weights are replaced by a checkpoint in a later cell.
# NOTE(review): `unfreeze=0.75` presumably sets the share of the backbone left trainable —
# confirm against models/wav2vecClassifier.py. The checkpoint file name also carries "0.75".
wav2vec = Wav2vecClassifier(unfreeze=0.75)

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Restore the fine-tuned weights.
# map_location="cpu" lets the checkpoint deserialize even if its tensors were saved from a
# CUDA device that is not available here; load_state_dict then copies the values into the
# model's existing parameters.
# NOTE(security): weights_only=False unpickles arbitrary Python objects — only load
# checkpoints from a trusted source.
model_sd = torch.load(os.path.join(CHKP_PATH, "wav2vec_0.75_0.9375.pt"), map_location="cpu", weights_only=False)

wav2vec.load_state_dict(model_sd["model_state_dict"])

<All keys matched successfully>

In [11]:
# Re-point the path constants at the `Voices_wav` tree for the evaluation functions below.
# NOTE: this rebinds VOICES_DIR / APHASIA_DIR / NORM_DIR, which an earlier cell set to the
# `Voices` tree, so any cell that reads them depends on execution order.
DATA_DIR = os.path.join(os.pardir, 'data')
VOICES_DIR = os.path.join(DATA_DIR, 'Voices_wav')
APHASIA_DIR, NORM_DIR = (
    os.path.join(VOICES_DIR, class_dir) for class_dir in ('Aphasia', 'Norm')
)

In [12]:
# Test-split manifest; the evaluation functions below read its `file_name` and `label` columns.
test_data = pd.read_csv(os.path.join(DATA_DIR, 'test_filenames.csv'))

In [78]:
def test_model_for_each_participant(model, test_data, device="cpu"):
    """Evaluate ``model`` per participant by majority vote over audio chunks.

    Each file of a participant is split into chunks with
    ``train_dataset.process_audio``. Every chunk is zero-padded or truncated to
    ``MAX_LEN`` samples and classified. The participant's prediction is the
    most frequent chunk-level class; on a tie the smallest class index wins. A
    classification report over all participants is printed.

    Relies on the notebook globals ``train_dataset``, ``MAX_LEN``, ``NORM_DIR``
    and ``APHASIA_DIR``.

    Args:
        model: Callable module whose output exposes ``.logits``. It is moved to
            ``device`` and put in eval mode; this affects the caller's object.
        test_data: DataFrame with ``file_name`` and ``label`` columns. Files with
            label 0 are read from ``NORM_DIR``, all others from ``APHASIA_DIR``.
            The frame is not modified.
        device: Device to run inference on. Defaults to ``"cpu"``.

    Returns:
        1-D ``numpy`` array with one predicted class per participant, in order
        of first appearance in ``test_data``.

    Raises:
        ValueError: If no chunk predictions were produced for a participant.
    """
    model = model.to(device)
    model.eval()

    def _participant_id(file_name):
        # Participant ID = first two dash-separated fields of the file name, concatenated.
        parts = str(file_name).split("-")
        return parts[0] + parts[1]

    # Kept as a separate Series so the caller's DataFrame does not gain an "ID" column.
    participant_ids = test_data["file_name"].apply(_participant_id)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for participant_id in tqdm(participant_ids.unique()):
            participant_samples = test_data[participant_ids == participant_id]

            preds = []
            for file_name, label in zip(participant_samples["file_name"],
                                        participant_samples["label"]):
                audio_dir = NORM_DIR if label == 0 else APHASIA_DIR
                sgnl_path = os.path.join(audio_dir, file_name)

                # assumes process_audio returns a sized collection of 2-D tensors
                # (row 0 holds the samples) — TODO confirm in src/utils.py
                chunks = train_dataset.process_audio(sgnl_path)

                padded = torch.zeros(len(chunks), MAX_LEN)
                for i, chunk in enumerate(chunks):
                    n_samples = min(chunk.shape[1], MAX_LEN)
                    padded[i, :n_samples] = chunk[0, :n_samples]

                # argmax over the class axis gives one prediction per chunk, also
                # when there is a single chunk, so no squeeze/isinstance branching.
                logits = model(padded.to(device)).logits
                preds.extend(logits.argmax(dim=-1).reshape(-1).cpu().tolist())

            if not preds:
                raise ValueError(
                    f"No audio chunks were produced for participant {participant_id}")

            # Majority vote. bincount().argmax() picks the smallest class on ties,
            # like scipy.stats.mode, but always returns a scalar.
            votes = np.bincount(np.asarray(preds, dtype=np.int64))
            all_preds.append(votes.argmax())

            # The first row's label is used as the participant's label.
            all_labels.append(participant_samples["label"].iloc[0])

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    print(classification_report(all_labels, all_preds))

    return all_preds

In [79]:
# Participant-level evaluation with chunked audio; the recorded run below took about 19 minutes.
test_model_for_each_participant(wav2vec, test_data)

100%|██████████| 72/72 [19:05<00:00, 15.90s/it]

              precision    recall  f1-score   support

           0       1.00      0.86      0.92        21
           1       0.94      1.00      0.97        51

    accuracy                           0.96        72
   macro avg       0.97      0.93      0.95        72
weighted avg       0.96      0.96      0.96        72






array([0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1])

In [25]:
def test_model_no_chunks(model, test_data, sample_rate=8_000, device="cpu"):
    """Evaluate ``model`` on whole recordings, per file and per participant.

    Each file is loaded in full with ``librosa.load`` at ``sample_rate`` and
    classified in a single forward pass, with no chunking or padding. Two
    classification reports are printed: one over individual files
    ("Without aggr") and one over participants ("With aggr"). A participant's
    prediction is the most frequent file-level class; on a tie the smallest
    class index wins.

    Relies on the notebook globals ``NORM_DIR`` and ``APHASIA_DIR``.

    Args:
        model: Callable module whose output exposes ``.logits``. It is moved to
            ``device`` and put in eval mode; this affects the caller's object.
        test_data: DataFrame with ``file_name`` and ``label`` columns. Files with
            label 0 are read from ``NORM_DIR``, all others from ``APHASIA_DIR``.
            The frame is not modified.
        sample_rate: Rate the audio is resampled to on load. Defaults to 8 kHz.
        device: Device to run inference on. Defaults to ``"cpu"``.

    Returns:
        ``(all_preds, all_preds_aggr)``: a 1-D ``numpy`` array of per-file
        predictions and a list of per-participant predictions. Participants are
        in order of first appearance in ``test_data``.
    """
    model = model.to(device)
    model.eval()

    def _participant_id(file_name):
        # Participant ID = first two dash-separated fields of the file name, concatenated.
        parts = str(file_name).split("-")
        return parts[0] + parts[1]

    # Kept as a separate Series so the caller's DataFrame does not gain an "ID" column.
    participant_ids = test_data["file_name"].apply(_participant_id)

    all_preds_aggr = []
    all_labels_aggr = []

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for participant_id in tqdm(participant_ids.unique()):
            participant_samples = test_data[participant_ids == participant_id]

            preds_aggr = []
            labels_aggr = []
            for file_name, label in zip(participant_samples["file_name"],
                                        participant_samples["label"]):
                audio_dir = NORM_DIR if label == 0 else APHASIA_DIR
                sgnl_path = os.path.join(audio_dir, file_name)

                waveform, _ = librosa.load(sgnl_path, sr=sample_rate)

                # Add a batch axis of size 1; argmax over the class axis then yields
                # exactly one prediction, read out with .item().
                batch = torch.from_numpy(waveform)[None, :].to(device)
                pred = model(batch).logits.argmax(dim=-1).item()

                labels_aggr.append(label)
                preds_aggr.append(pred)

            # Majority vote. bincount().argmax() picks the smallest class on ties,
            # like scipy.stats.mode, but always returns a scalar.
            votes = np.bincount(np.asarray(preds_aggr, dtype=np.int64))
            all_preds_aggr.append(votes.argmax())
            # The first file's label is used as the participant's label.
            all_labels_aggr.append(labels_aggr[0])

            all_preds.extend(preds_aggr)
            all_labels.extend(labels_aggr)

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    print("Without aggr")
    print(classification_report(all_labels, all_preds))

    print("With aggr")
    print(classification_report(all_labels_aggr, all_preds_aggr))

    return all_preds, all_preds_aggr

In [26]:
# Whole-recording evaluation (no chunking); only the printed reports are needed here,
# so both returned prediction collections are discarded.
_, _ = test_model_no_chunks(wav2vec, test_data)

100%|██████████| 72/72 [21:09<00:00, 17.63s/it]

Without aggr
              precision    recall  f1-score   support

           0       0.84      0.88      0.86        42
           1       0.96      0.95      0.95       130

    accuracy                           0.93       172
   macro avg       0.90      0.91      0.91       172
weighted avg       0.93      0.93      0.93       172

With aggr
              precision    recall  f1-score   support

           0       0.86      0.90      0.88        21
           1       0.96      0.94      0.95        51

    accuracy                           0.93        72
   macro avg       0.91      0.92      0.92        72
weighted avg       0.93      0.93      0.93        72




