In [1]:
import warnings

from pyroomacoustics import room

for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

import os
import torch
from random import uniform, choice

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import hub
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchaudio
from torchaudio.transforms import Resample

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, recall_score, precision_score, balanced_accuracy_score, accuracy_score, classification_report
from sklearn.utils import shuffle

import scipy

from tqdm import tqdm

from datasets import load_dataset, Dataset, Audio
import librosa
import pyroomacoustics as pra 

from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

from models.basic_transformer import BasicTransformer

from src.utils import AphasiaDatasetMFCC, AphasiaDatasetSpectrogram, AphasiaDatasetWaveform

from collections import Counter

2025-04-09 17:21:29.513452: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-09 17:21:29.524751: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744208489.537358  114027 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744208489.541145  114027 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-09 17:21:29.553524: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Sample rate (Hz) every recording is resampled to before further processing.
SR = 8_000
# NOTE(review): AUDIO_LENGTH is not referenced by any cell visible here — confirm it is still needed.
AUDIO_LENGTH = 6_000
# Feature dimension of the zero-padded MFCC batches built in the evaluation helpers.
MFCC = 128
# Time width that MobileNet inputs are padded / truncated to.
SEQUENCE_LENGTH = 31

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"It's {DEVICE} time!!!")

It's cuda time!!!


In [3]:
# All data lives one level above the notebook directory.
DATA_DIR = os.path.join(os.pardir, "data")

# Augmentation material: background noise and simulated room impulse responses.
NOISE_DIR = os.path.join(DATA_DIR, "noise")
RIR_DIR = os.path.join(DATA_DIR, "RIRs")

# Speech recordings, split into one sub-directory per class.
VOICES_DIR = os.path.join(DATA_DIR, "Voices")
APHASIA_DIR = os.path.join(VOICES_DIR, "Aphasia")
NORM_DIR = os.path.join(VOICES_DIR, "Norm")

# Directory holding the wav2vec checkpoints.
CHKP_PATH = os.path.join(os.pardir, "checkpoints", "wav2vec_chkp")

In [4]:
# Reference utterance that is placed as the sound source inside the simulated rooms.
# The location used to be a hard-coded absolute path on one machine; it can now be
# overridden with the RIR_SOURCE_WAV environment variable (the old path stays the default).
RIR_SOURCE_WAV = os.environ.get(
    "RIR_SOURCE_WAV",
    "/home/zakhar/ems_dereverb/data/data_thchs30/train/A2_0.wav",
)

signal_rir, sample_rate = torchaudio.load(RIR_SOURCE_WAV)
# Bring the reference signal to the notebook-wide sample rate if the file differs.
if sample_rate != SR:
    resampler = Resample(sample_rate, SR)
    signal_rir = resampler(signal_rir)

In [5]:
# Material names handed to pra.Material when a random room is built.
WALLS_KEYWORDS = [
    "hard_surface",
    "ceramic_tiles",
    "plasterboard",
    "wooden_lining",
    "glass_3mm",
]
FLOOR_KEYWORDS = [
    "linoleum_on_concrete",
    "carpet_cotton",
]
CEILING_KEYWORDS = [
    "ceiling_plasterboard",
    "ceiling_fissured_tile",
    "ceiling_metal_panel",
]

# NOTE(review): `snr` is not read by any cell visible here — presumably an SNR range in dB; confirm.
snr = (-5, 10)
# (min, max) range the room's floor area is drawn from.
room_square = (7., 14.)
# (min, max) range the room height is drawn from.
room_height = (3., 4.)

def simulate_rir_shoebox(signal: torch.Tensor) -> np.ndarray:
    """Simulate one random shoebox room and return its room impulse response.

    The floor area, width, height, target RT60, surface materials and source
    position are all drawn at random (via ``random.uniform`` / ``random.choice``),
    so every call yields a different room.

    Args:
        signal: Source signal placed in the room; it is squeezed before being
            handed to pyroomacoustics (assumes a mono ``(1, T)`` tensor — TODO confirm).

    Returns:
        ``room.rir[0][0]`` — the impulse response between the single microphone
        and the single source. This is a NumPy array, not a tensor (the caller
        wraps it with ``torch.from_numpy``).
    """
    floor_area = uniform(*room_square)
    width = uniform(2.5, floor_area * 0.75)
    length = floor_area / width
    height = uniform(*room_height)

    rt60 = uniform(0.3, 1.25)   # aim for a long reverb tail
    room_dim = [length, width, height]

    # Only the reflection order from the Sabine inversion is used; the absorption
    # it returns is discarded because the surfaces get explicit materials below.
    _, max_order = pra.inverse_sabine(rt60, room_dim)

    wall = pra.Material(choice(WALLS_KEYWORDS))
    ceil = pra.Material(choice(CEILING_KEYWORDS))
    floor = pra.Material(choice(FLOOR_KEYWORDS))

    material = {"east": wall, "west": wall, "north": wall, "south": wall, "ceiling": ceil, "floor": floor}

    room = pra.ShoeBox(room_dim, fs=SR, materials=material, max_order=max_order,
                       use_rand_ism=True, max_rand_disp=0.05, ray_tracing=False)

    source_locs = [uniform(0.01, length), uniform(0.01, width), uniform(1.0, 2.0)]
    # The microphone sits at 98% of each source coordinate, i.e. right next to the source.
    mic_locs = np.array([x * 0.98 for x in source_locs])[:, None]

    room.add_source(source_locs, signal=signal.squeeze(), delay=0.5)

    room.add_microphone_array(mic_locs)
    room.compute_rir()
    room.simulate()     # simulate() also accepts an snr parameter, which may come in handy later

    return room.rir[0][0]

In [6]:
# Number of random room impulse responses to generate.
N_RIRS = 100

# torchaudio.save does not create missing directories, so make sure the target exists.
os.makedirs(RIR_DIR, exist_ok=True)

for i in range(N_RIRS):
    rir = simulate_rir_shoebox(signal_rir)
    # NOTE(review): files are written without a ".wav" suffix (the container is chosen
    # via format='wav'); the names are kept as-is in case other code lists RIR_DIR.
    torchaudio.save(os.path.join(RIR_DIR, f'rir_{i}'), src=torch.from_numpy(rir[None, :]), sample_rate=SR, format='wav')

In [7]:
# Test-split manifest; the evaluation helpers below read its `file_name` and `label` columns.
test_data = pd.read_csv(os.path.join(DATA_DIR, 'test_filenames.csv'))

In [8]:
# Waveform test set built from the same manifest, with noise augmentation switched on
# (snr=(5, 20) is presumably an SNR range in dB — confirm against AphasiaDatasetWaveform).
# NOTE(review): target_sample_rate repeats the value of SR as a literal.
test_dataset = AphasiaDatasetWaveform(os.path.join(DATA_DIR, "test_filenames.csv"), VOICES_DIR, target_sample_rate=8_000, noise_dir=NOISE_DIR, add_noise=True, snr=(5, 20))

In [9]:
# Sanity check: show only the shape of the first item's features instead of dumping the whole tensor
# (the recorded output below is a torch.Size, i.e. what `.shape` displays).
test_dataset[0][0].shape

torch.Size([1, 100688])

In [10]:
# from IPython.display import Audio

In [11]:
# Audio(test_dataset[0][0], rate=8_000)

### Wav2vec robustness test

In [12]:
from models.wav2vecClassifier import Wav2vecClassifier

wav2vec = Wav2vecClassifier(unfreeze=0.75)

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict copies the values into the model's existing parameters either way.
# NOTE: weights_only=False unpickles arbitrary objects — only load checkpoints you trust.
model_sd = torch.load(os.path.join(CHKP_PATH, "wav2vec_0.75_0.9375.pt"), map_location="cpu", weights_only=False)

wav2vec.load_state_dict(model_sd["model_state_dict"])

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [64]:
def test_model(model, test_dataset, max_seq=420):
    """Evaluate ``model`` item by item on ``test_dataset`` and print a report.

    The model is moved to ``DEVICE`` and put in eval mode. How an item is fed to
    the model depends on the model's class name:

    * ``Wav2vecClassifier`` — features are passed as-is and ``.logits`` is read.
    * ``MobileNet`` — a leading axis is added to the features.
    * anything else — features are copied into a zero tensor of shape
      ``(1, features.shape[0], max_seq)``, truncated to ``max_seq`` on the last axis.

    Args:
        model: Network to evaluate.
        test_dataset: Iterable of ``(features, target)`` pairs; both must support
            ``.to(device)`` and ``target`` must be convertible with ``.item()``.
        max_seq: Last-axis width used for padding in the generic branch.

    Returns:
        NumPy array with the arg-max prediction for every dataset item.
    """
    model = model.to(DEVICE)
    model.eval()

    model_kind = model.__class__.__name__
    predicted, expected = [], []

    with torch.no_grad():
        for sample, gold in tqdm(test_dataset):
            sample = sample.to(DEVICE)
            gold = gold.to(DEVICE)

            if model_kind == "Wav2vecClassifier":
                raw_output = model(sample).logits
            elif model_kind == "MobileNet":
                raw_output = model(sample[None, ...])
            else:
                canvas = torch.zeros(1, sample.shape[0], max_seq, device=DEVICE)
                canvas[0, :, :sample.shape[-1]] = sample[..., :max_seq]
                raw_output = model(canvas)

            scores = raw_output.to("cpu").detach().numpy().squeeze()
            predicted.append(scores.argmax(axis=-1))
            expected.append(gold.to("cpu").item())

    preds = np.array(predicted)
    print(classification_report(expected, preds))

    return preds

In [14]:
# Per-item evaluation of the wav2vec classifier on the noise-augmented waveform test set.
# test_model prints the classification report; as the cell's last expression the returned
# prediction array is echoed below as well.
test_model(wav2vec, test_dataset)

100%|██████████| 955/955 [00:15<00:00, 60.19it/s]

              precision    recall  f1-score   support

           0       0.75      0.74      0.75       136
           1       0.96      0.96      0.96       819

    accuracy                           0.93       955
   macro avg       0.85      0.85      0.85       955
weighted avg       0.93      0.93      0.93       955






array([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
       1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,
       0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,
       1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
       1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
       1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0,
       1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,
       0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0,
       1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1,

In [15]:
# Width (in samples) that waveform chunks are zero-padded / truncated to in the
# wav2vec branch of test_model_for_each_participant.
MAX_LEN = 120_000
# NOTE: VOICES_DIR and the derived APHASIA_DIR / NORM_DIR are re-pointed from 'Voices'
# to 'Voices_wav' here. Later cells read whichever assignment ran last, so the
# execution order of these path cells matters.
DATA_DIR = os.path.join('..', 'data')
VOICES_DIR = os.path.join(DATA_DIR, 'Voices_wav')
APHASIA_DIR = os.path.join(VOICES_DIR, 'Aphasia')
NORM_DIR = os.path.join(VOICES_DIR, 'Norm')

In [66]:
def test_model_for_each_participant(model, test_data, test_dataset, norm_dir, aphasia_dir, max_seq=420):
    """Evaluate ``model`` per participant using a majority vote over audio chunks.

    Every file of a participant is split into chunks with
    ``test_dataset.process_audio``; the chunks are zero-padded into one batch,
    classified, and the most frequent predicted class over all chunks of all the
    participant's files becomes that participant's prediction. A classification
    report over participants is printed.

    The model is moved to the CPU and put in eval mode. ``test_data`` is not
    modified.

    Args:
        model: Network to evaluate. Input preparation is selected by class name
            (``Wav2vecClassifier``, ``MobileNet``, or a generic fallback).
        test_data: DataFrame with ``file_name`` and ``label`` columns. The
            participant ID is the concatenation of the first two ``-``-separated
            fields of ``file_name``.
        test_dataset: Object exposing ``process_audio(path)`` that returns a sized
            collection of chunk tensors (assumed ``(1, T)`` waveforms for wav2vec
            and ``(MFCC, T)`` features otherwise — TODO confirm).
        norm_dir: Directory holding the files whose label is 0.
        aphasia_dir: Directory holding the files with any other label.
        max_seq: Last-axis padding width used by the generic branch.

    Returns:
        NumPy array with one majority-vote prediction per participant, in order
        of first appearance in ``test_data``.
    """
    model = model.to("cpu")
    model.eval()
    model_kind = model.__class__.__name__

    # Computed as a separate Series so the caller's DataFrame is left untouched.
    participant_ids = test_data["file_name"].apply(
        lambda x: str(x).split("-")[0] + str(x).split("-")[1])

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for participant_id in tqdm(participant_ids.unique()):
            participant_samples = test_data[participant_ids == participant_id]
            preds = []
            for _, participant_sample in participant_samples.iterrows():
                class_dir = norm_dir if participant_sample["label"] == 0 else aphasia_dir
                sgnl_path = os.path.join(class_dir, participant_sample["file_name"])

                chunks = test_dataset.process_audio(sgnl_path)
                if len(chunks) == 0:
                    # Nothing to classify for this file; an empty batch would add no votes.
                    continue

                if model_kind == "Wav2vecClassifier":
                    padded = torch.zeros(len(chunks), MAX_LEN)
                    for i, s in enumerate(chunks):
                        padded[i, :s.shape[1]] = s[0, :MAX_LEN]
                    logits = model(padded).logits
                elif model_kind == "MobileNet":
                    padded = torch.zeros(len(chunks), MFCC, SEQUENCE_LENGTH)
                    for i, s in enumerate(chunks):
                        padded[i, ..., :s.shape[-1]] = s[..., :SEQUENCE_LENGTH]
                    logits = model(padded.unsqueeze(1))
                else:
                    padded = torch.zeros(len(chunks), MFCC, max_seq)
                    for i, s in enumerate(chunks):
                        padded[i, ..., :s.shape[-1]] = s[..., :max_seq]
                    logits = model(padded)

                # arg-max over the class axis, flattened so that a single-chunk batch
                # and a multi-chunk batch are handled the same way.
                chunk_preds = logits.detach().numpy().argmax(axis=-1).reshape(-1)
                preds.extend(chunk_preds.tolist())

            # Older SciPy returns the mode as a length-1 array, newer SciPy as a scalar;
            # atleast_1d(...)[0] yields a scalar in both cases.
            majority = np.atleast_1d(scipy.stats.mode(np.array(preds)).mode)[0]

            all_preds.append(majority)
            # The first row's label is used as the participant's label.
            all_labels.append(participant_samples["label"].values[0])

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    print(classification_report(all_labels, all_preds))

    return all_preds

In [18]:
# Participant-level (majority vote over chunks) evaluation of the wav2vec classifier.
# Uses whatever NORM_DIR / APHASIA_DIR currently point to — run the path cell above first.
test_model_for_each_participant(wav2vec, test_data, test_dataset, NORM_DIR, APHASIA_DIR)

100%|██████████| 72/72 [19:46<00:00, 16.48s/it]

              precision    recall  f1-score   support

           0       0.89      0.76      0.82        21
           1       0.91      0.96      0.93        51

    accuracy                           0.90        72
   macro avg       0.90      0.86      0.88        72
weighted avg       0.90      0.90      0.90        72






array([0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1])

In [73]:
from pydub import AudioSegment

def test_model_no_chunks(model, test_data, test_dataset, norm_dir, aphasia_dir, max_seq=420):
    """Evaluate ``model`` on whole recordings, per file and per participant.

    Each file is decoded with pydub, run through ``test_dataset.preprocess`` and
    classified in one piece (no chunking). Two classification reports are
    printed: one over individual files ("Without aggr") and one over
    participants, where a participant's prediction is the most frequent
    prediction over their files ("With aggr").

    The model is moved to the CPU and put in eval mode. ``test_data`` is not
    modified.

    Args:
        model: Network to evaluate. Input preparation is selected by class name
            (``Wav2vecClassifier``, ``MobileNet``, or a generic fallback).
        test_data: DataFrame with ``file_name`` and ``label`` columns. The
            participant ID is the concatenation of the first two ``-``-separated
            fields of ``file_name``.
        test_dataset: Object exposing ``preprocess(audio_segment)`` that returns
            the model-ready tensor for one recording.
        norm_dir: Directory holding the files whose label is 0.
        aphasia_dir: Directory holding the files with any other label.
        max_seq: Last-axis padding width used by the generic branch.

    Returns:
        Tuple ``(all_preds, all_preds_aggr)``: a NumPy array of per-file
        predictions and a list of per-participant majority predictions.
    """
    model = model.to("cpu")
    model.eval()
    model_kind = model.__class__.__name__

    # Computed as a separate Series so the caller's DataFrame is left untouched.
    participant_ids = test_data["file_name"].apply(
        lambda x: str(x).split("-")[0] + str(x).split("-")[1])

    all_preds_aggr = []
    all_labels_aggr = []

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for participant_id in tqdm(participant_ids.unique()):
            participant_samples = test_data[participant_ids == participant_id]
            preds_aggr = []
            labels_aggr = []
            for _, participant_sample in participant_samples.iterrows():
                class_dir = norm_dir if participant_sample['label'] == 0 else aphasia_dir
                sgnl_path = os.path.join(class_dir, participant_sample["file_name"])

                # NOTE(review): the container format is fixed to "3gp" regardless of the
                # file's extension — confirm this matches the files in the given directories.
                audio = AudioSegment.from_file(sgnl_path, format="3gp")
                y = test_dataset.preprocess(audio)

                if model_kind == "Wav2vecClassifier":
                    logits = model(y).logits
                elif model_kind == "MobileNet":
                    logits = model(y[None, None, ...])
                else:
                    padded_y = torch.zeros(1, MFCC, max_seq)
                    padded_y[..., :y.shape[-1]] = y[..., :max_seq]
                    logits = model(padded_y)

                pred = logits.detach().numpy().squeeze().argmax(axis=-1)

                labels_aggr.append(participant_sample['label'])
                preds_aggr.append(pred)

            # Older SciPy returns the mode as a length-1 array, newer SciPy as a scalar;
            # atleast_1d(...)[0] yields a scalar in both cases.
            majority = np.atleast_1d(scipy.stats.mode(np.array(preds_aggr)).mode)[0]

            all_preds_aggr.append(majority)
            # The first file's label is used as the participant's label.
            all_labels_aggr.append(labels_aggr[0])

            all_preds.extend(preds_aggr)
            all_labels.extend(labels_aggr)

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    print("Without aggr")
    print(classification_report(all_labels, all_preds))

    print("With aggr")
    print(classification_report(all_labels_aggr, all_preds_aggr))

    return all_preds, all_preds_aggr

In [22]:
# Whole-recording evaluation of the wav2vec classifier; prints a per-file report and a
# per-participant (majority vote) report. Reads the current NORM_DIR / APHASIA_DIR values.
test_model_no_chunks(wav2vec, test_data, test_dataset, NORM_DIR, APHASIA_DIR)

100%|██████████| 72/72 [21:34<00:00, 17.99s/it]

Without aggr
              precision    recall  f1-score   support

           0       0.88      0.83      0.85        42
           1       0.95      0.96      0.95       130

    accuracy                           0.93       172
   macro avg       0.91      0.90      0.90       172
weighted avg       0.93      0.93      0.93       172

With aggr
              precision    recall  f1-score   support

           0       0.95      0.90      0.93        21
           1       0.96      0.98      0.97        51

    accuracy                           0.96        72
   macro avg       0.96      0.94      0.95        72
weighted avg       0.96      0.96      0.96        72






(array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1])

In [24]:
# Point the voice directories back at 'Voices' before building the MFCC dataset.
# NOTE: these globals are re-assigned in several cells; later cells see the last value set.
DATA_DIR = os.path.join('..', 'data')
VOICES_DIR = os.path.join(DATA_DIR, 'Voices')
APHASIA_DIR = os.path.join(VOICES_DIR, 'Aphasia')
NORM_DIR = os.path.join(VOICES_DIR, 'Norm')
# MFCC test set built from the same manifest as the waveform one, with noise augmentation on
# (snr=(5, 20) is presumably an SNR range in dB — confirm against AphasiaDatasetMFCC).
test_dataset_mfcc = AphasiaDatasetMFCC(os.path.join(DATA_DIR, "test_filenames.csv"), VOICES_DIR, target_sample_rate=8_000, noise_dir=NOISE_DIR, add_noise=True, snr=(5, 20))

### MobileNet robustness test

In [25]:
from models.cnn import MobileNet

mobilenet = MobileNet()

# Kept under its own name so the wav2vec checkpoint directory stored in CHKP_PATH is not
# overwritten (re-running the wav2vec loading cell would otherwise look in the wrong place).
MOBILENET_CHKP_PATH = os.path.join("..", 'checkpoints', "mobilenet_chkp")
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: weights_only=False unpickles arbitrary objects — only load checkpoints you trust.
model_sd = torch.load(os.path.join(MOBILENET_CHKP_PATH, "mobilenet_0.94.pt"), map_location="cpu", weights_only=False)

mobilenet.load_state_dict(model_sd["model_state_dict"])

<All keys matched successfully>

In [26]:
# Per-item evaluation of MobileNet on the noise-augmented MFCC test set
# (prints the classification report and echoes the returned prediction array).
test_model(mobilenet, test_dataset_mfcc)

100%|██████████| 967/967 [00:02<00:00, 378.29it/s]

              precision    recall  f1-score   support

           0       0.90      0.56      0.69       138
           1       0.93      0.99      0.96       829

    accuracy                           0.93       967
   macro avg       0.91      0.77      0.82       967
weighted avg       0.93      0.93      0.92       967






array([1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,

In [27]:
# Switch the voice directories to 'Voices_wav' for the participant-level evaluations below.
# NOTE: same globals as in the earlier path cells — the last executed assignment wins.
DATA_DIR = os.path.join('..', 'data')
VOICES_DIR = os.path.join(DATA_DIR, 'Voices_wav')
APHASIA_DIR = os.path.join(VOICES_DIR, 'Aphasia')
NORM_DIR = os.path.join(VOICES_DIR, 'Norm')

In [32]:
# Participant-level (majority vote over chunks) evaluation of MobileNet on MFCC features.
# The MobileNet branch pads chunks to the global SEQUENCE_LENGTH, so its current value matters.
test_model_for_each_participant(mobilenet, test_data, test_dataset_mfcc, NORM_DIR, APHASIA_DIR)

100%|██████████| 72/72 [02:36<00:00,  2.18s/it] 

              precision    recall  f1-score   support

           0       0.33      0.05      0.08        21
           1       0.71      0.96      0.82        51

    accuracy                           0.69        72
   macro avg       0.52      0.50      0.45        72
weighted avg       0.60      0.69      0.60        72






array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 1])

In [33]:
# Whole-recording evaluation of MobileNet; prints a per-file report and a per-participant report.
test_model_no_chunks(mobilenet, test_data, test_dataset_mfcc, NORM_DIR, APHASIA_DIR)

100%|██████████| 72/72 [00:24<00:00,  2.98it/s]

Without aggr
              precision    recall  f1-score   support

           0       0.94      0.40      0.57        42
           1       0.84      0.99      0.91       130

    accuracy                           0.85       172
   macro avg       0.89      0.70      0.74       172
weighted avg       0.86      0.85      0.82       172

With aggr
              precision    recall  f1-score   support

           0       1.00      0.48      0.65        21
           1       0.82      1.00      0.90        51

    accuracy                           0.85        72
   macro avg       0.91      0.74      0.77        72
weighted avg       0.87      0.85      0.83        72






(array([0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 [0,
  0,
  1,
  1,
  1,
  1,
  0,
  1,
  1,
  1,
  0,
  0,
  0,
  1,
  1,
  1,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1])

In [34]:
# Point the voice directories back at 'Voices' (relies on DATA_DIR from an earlier cell).
# NOTE: these globals are toggled in several cells; the last executed assignment wins.
VOICES_DIR = os.path.join(DATA_DIR, 'Voices')
APHASIA_DIR = os.path.join(VOICES_DIR, 'Aphasia')
NORM_DIR = os.path.join(VOICES_DIR, 'Norm')

### SwishNet robustness test

In [58]:
from models.swishnet import SwishNet

# SwishNet's input width gets its own name: assigning 420 to the global SEQUENCE_LENGTH
# (31) would silently change the padding width of the MobileNet branch in
# test_model_for_each_participant if that evaluation is re-run after this cell.
SWISHNET_SEQUENCE_LENGTH = 420
swishnet = SwishNet(MFCC, 2, input_size=SWISHNET_SEQUENCE_LENGTH, dropout_rate=0.5)

# Separate name so the checkpoint directories used by the other loading cells are not overwritten.
SWISHNET_CHKP_PATH = os.path.join("..", 'checkpoints', "swishnet_chkp")
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: weights_only=False unpickles arbitrary objects — only load checkpoints you trust.
model_sd = torch.load(os.path.join(SWISHNET_CHKP_PATH, "swishnet_0.919_best.pt"), map_location="cpu", weights_only=False)
swishnet.load_state_dict(model_sd["model_state_dict"])

<All keys matched successfully>

In [65]:
# Per-item evaluation of SwishNet on the noise-augmented MFCC test set; SwishNet goes through
# test_model's generic branch, which pads / truncates features to max_seq (default 420).
test_model(swishnet, test_dataset_mfcc)

100%|██████████| 967/967 [00:00<00:00, 1063.63it/s]


              precision    recall  f1-score   support

           0       0.55      0.30      0.38       138
           1       0.89      0.96      0.92       829

    accuracy                           0.86       967
   macro avg       0.72      0.63      0.65       967
weighted avg       0.84      0.86      0.85       967



array([1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
       1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,

In [68]:
# Switch the voice directories to 'Voices_wav' for the participant-level evaluations below
# (relies on DATA_DIR from an earlier cell; the last executed assignment of these globals wins).
VOICES_DIR = os.path.join(DATA_DIR, 'Voices_wav')
APHASIA_DIR = os.path.join(VOICES_DIR, 'Aphasia')
NORM_DIR = os.path.join(VOICES_DIR, 'Norm')

In [69]:
# Participant-level (majority vote over chunks) evaluation of SwishNet on MFCC features.
test_model_for_each_participant(swishnet, test_data, test_dataset_mfcc, NORM_DIR, APHASIA_DIR)

100%|██████████| 72/72 [01:21<00:00,  1.13s/it]

              precision    recall  f1-score   support

           0       0.75      0.14      0.24        21
           1       0.74      0.98      0.84        51

    accuracy                           0.74        72
   macro avg       0.74      0.56      0.54        72
weighted avg       0.74      0.74      0.67        72






array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1])

In [74]:
# Whole-recording evaluation of SwishNet; prints a per-file report and a per-participant report.
test_model_no_chunks(swishnet, test_data, test_dataset_mfcc, NORM_DIR, APHASIA_DIR)

100%|██████████| 72/72 [00:16<00:00,  4.45it/s]

Without aggr
              precision    recall  f1-score   support

           0       0.68      0.31      0.43        42
           1       0.81      0.95      0.88       130

    accuracy                           0.80       172
   macro avg       0.75      0.63      0.65       172
weighted avg       0.78      0.80      0.77       172

With aggr
              precision    recall  f1-score   support

           0       0.82      0.43      0.56        21
           1       0.80      0.96      0.88        51

    accuracy                           0.81        72
   macro avg       0.81      0.69      0.72        72
weighted avg       0.81      0.81      0.78        72






(array([1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 [1,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  0,
  1,
  1,
  1,
  0,
  1,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1])