In [1]:
import warnings
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import hub
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchaudio

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, recall_score, precision_score, balanced_accuracy_score, accuracy_score, classification_report
from sklearn.utils import shuffle

import scipy

from tqdm import tqdm

from datasets import load_dataset, Dataset, Audio
import librosa
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

from models.basic_transformer import BasicTransformer

from src.utils import AphasiaDatasetMFCC, AphasiaDatasetSpectrogram, AphasiaDatasetWaveform

from collections import Counter
from models.wav2vecClassifier import Wav2vecClassifier
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, Gain

2025-04-16 00:18:33.223557: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 00:18:33.313815: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744751913.351490   18618 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744751913.367236   18618 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-16 00:18:33.474794: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
AUDIO_LENGTH = 6_000
SEQUENCE_LENGTH = 31
MFCC = 128
print(f"It's {DEVICE} time!!!")

It's cuda time!!!


In [3]:
DATA_DIR = os.path.join('..', 'data')
VOICES_DIR = os.path.join(DATA_DIR, 'Voices_wav')
APHASIA_DIR = os.path.join(VOICES_DIR, 'Aphasia')
NORM_DIR = os.path.join(VOICES_DIR, 'Norm')
RIR_DIR = os.path.join(DATA_DIR, 'RIRs')
NOISE_DIR = os.path.join(DATA_DIR, 'noise')

In [4]:
augmentations = audio_augment = Compose([
                AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.02, p=0.7),
                TimeStretch(min_rate=0.8, max_rate=1.2, p=0.5),
                PitchShift(min_semitones=-3, max_semitones=3, p=0.5),
                Shift(min_shift=-0.3, max_shift=0.3, p=0.5),
                Gain(min_gain_db=-8, max_gain_db=8, p=0.5),
            ])

In [5]:
train_dataset = AphasiaDatasetWaveform(os.path.join(DATA_DIR, "train_filenames_mc_1.csv"), VOICES_DIR, snr=(10, 20), noise_dir=NORM_DIR, rirs_dir=RIR_DIR, target_sample_rate=8_000, file_format="wav", transforms=augmentations)

In [6]:
test_dataset = AphasiaDatasetWaveform(os.path.join(DATA_DIR, "val_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, file_format="wav")

In [7]:
val_dataset = AphasiaDatasetWaveform(os.path.join(DATA_DIR, "test_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, file_format="wav")

In [8]:
# Балансировка классов для train
train_labels = [label for _, label in train_dataset.data]
class_counts = Counter(train_labels)
if len(class_counts) < 4:
    raise ValueError("Один из классов отсутствует в тренировочном наборе")

class_weights = {label: 1.0 / count for label, count in class_counts.items()}
weights = [class_weights[label] for _, label in train_dataset.data]
train_sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset), replacement=True)

In [9]:
MAX_LEN = 120_000

In [10]:
def pad_sequence(batch):
    if not batch:
        return torch.zeros(0), torch.zeros(0)
    
    seq, labels = zip(*batch)
    # print(seq[1], labels)
    max_len = max(s.shape[1] for s in seq)
    # print(seq[0].shape)

    # print(seq[0].shape)
    padded = torch.zeros(len(seq), max_len)
    for i, s in enumerate(seq):
        padded[i, :s.shape[1]] = s[0, :MAX_LEN]
    
    return padded, torch.stack(labels) 

In [11]:
train_dataloader = DataLoader(train_dataset, batch_size=4, sampler=train_sampler, collate_fn=pad_sequence, drop_last=True, num_workers=6)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=pad_sequence, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=pad_sequence, drop_last=True, num_workers=6)

In [12]:
def train_model(model, dl_train, dl_val, epochs=1, lr=0.001, device="cpu"):
      
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(list(class_weights.values()))).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
        
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, threshold=1e-3)
    # scheduler = step_scheduler(optimizer, 30) 
        
    train_loss_list = []
    val_loss_list = []
    train_acc_list = []
    val_acc_list = []
    for epoch in tqdm(range(epochs), desc="Training model"):
        model.train()
        total_train_loss = 0
        total_val_loss = 0
        
        train_acc = []
        val_acc = []
        for features, target in dl_train:
            features, target = features.to(device), target.to(device)
            optimizer.zero_grad()
            
            output = model(features).logits
            # print(output)
            # train_acc.append(torch.tensor(torch.argmax(output, dim=1) == target).cpu().detach().numpy())
            preds = torch.argmax(output, dim=1).cpu().detach().numpy()
            train_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
            
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.detach().item()
        # print(f"LR before: {optimizer.param_groups[0]['lr']}")
        # scheduler.step(epoch=epoch)
        # print(f"LR after: {optimizer.param_groups[0]['lr']}")
        
        avg_train_acc = np.stack(train_acc, axis=0).mean()

        avg_train_loss = total_train_loss / len(train_dataloader)
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
                
        model.eval()
        
        with torch.no_grad():
            for features, target in dl_val:
                features, target = features.to(device), target.to(device)
                
                output = model(features).logits
                
                preds = torch.argmax(output, dim=1).cpu().detach().numpy()
                val_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
                loss = criterion(output, target)
                total_val_loss += loss.detach().item()
        
        avg_val_acc = np.stack(val_acc, axis=0).mean()
        avg_val_loss = total_val_loss / len(dl_val)
        val_loss_list.append(avg_val_loss)
        val_acc_list.append(avg_val_acc)
        
        if scheduler:
            try:
                scheduler.step()
            except:
                scheduler.step(avg_val_loss)
        
        # if epoch % 10 == 0:
        tqdm.write(f"Epoch {epoch}: train loss: {avg_train_loss:.3f}, train balanced acc: {avg_train_acc:.2f}, test loss: {avg_val_loss:.3f}, test balanced acc: {avg_val_acc:.2f}, lr: {optimizer.param_groups[0]['lr']}")

            
    return model, train_loss_list, val_loss_list, train_acc_list, val_acc_list

In [13]:
wav2vec = Wav2vecClassifier(num_labels=4, unfreeze=0.0)

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
from termcolor import colored
from collections import defaultdict

def beautiful_int(i):
    i = str(i)
    return ".".join(reversed([i[max(j, 0):j+3] for j in range(len(i) - 3, -3, -3)]))

# Считаем общее число параметров в нашей модели
def model_num_params(model, verbose_all=True, verbose_only_learnable=False):
    sum_params = 0
    sum_learnable_params = 0
    submodules = defaultdict(lambda : [0, 0])
    for name, param in model.named_parameters():
        num_params = np.prod(param.shape)
        if verbose_all or (verbose_only_learnable and param[1].requires_grad):
            print(
                colored(
                    '{: <42} ~  {: <9} params ~ grad: {}'.format(
                        name,
                        beautiful_int(num_params),
                        param.requires_grad,
                    ),
                    {True: "green", False: "red"}[param.requires_grad],
                )
            )
        sum_params += num_params
        sm = name.split(".")[0]
        submodules[sm][0] += num_params
        if param.requires_grad:
            sum_learnable_params += num_params
            submodules[sm][1] += num_params
    print(
        f'\nIn total:\n  - {beautiful_int(sum_params)} params\n  - {beautiful_int(sum_learnable_params)} learnable params'
    )

    for sm, v in submodules.items():
        print(
            f"\n . {sm}:\n .   - {beautiful_int(submodules[sm][0])} params\n .   - {beautiful_int(submodules[sm][1])} learnable params"
        )
    return sum_params, sum_learnable_params


sum_params, sum_learnable_params = model_num_params(wav2vec)

[32mwav2vec.wav2vec2.masked_spec_embed         ~  768       params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.0.conv.weight ~  5.120     params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight ~  512       params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias ~  512       params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.1.conv.weight ~  786.432   params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.2.conv.weight ~  786.432   params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.3.conv.weight ~  786.432   params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.4.conv.weight ~  786.432   params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.5.conv.weight ~  524.288   params ~ grad: True[0m
[32mwav2vec.wav2vec2.feature_extractor.conv_layers.6.conv.weight ~  524.288   pa

In [15]:
wav2vec, train_l, val_l, train_acc, val_acc = train_model(wav2vec, train_dataloader, val_dataloader, epochs=30, lr=0.001, device=DEVICE)

Training model:   3%|▎         | 1/30 [03:26<1:39:54, 206.72s/it]

Epoch 0: train loss: 1.370, train balanced acc: 0.26, test loss: 1.328, test balanced acc: 0.31, lr: 0.001


Training model:   7%|▋         | 2/30 [06:59<1:38:00, 210.01s/it]

Epoch 1: train loss: 1.362, train balanced acc: 0.26, test loss: 1.326, test balanced acc: 0.31, lr: 0.001


Training model:  10%|█         | 3/30 [10:31<1:35:04, 211.28s/it]

Epoch 2: train loss: 1.367, train balanced acc: 0.24, test loss: 1.344, test balanced acc: 0.31, lr: 0.001


Training model:  13%|█▎        | 4/30 [14:06<1:32:04, 212.49s/it]

Epoch 3: train loss: 1.362, train balanced acc: 0.25, test loss: 1.331, test balanced acc: 0.31, lr: 0.001


Training model:  17%|█▋        | 5/30 [17:41<1:28:56, 213.46s/it]

Epoch 4: train loss: 1.367, train balanced acc: 0.24, test loss: 1.318, test balanced acc: 0.31, lr: 0.001


Training model:  20%|██        | 6/30 [21:16<1:25:36, 214.04s/it]

Epoch 5: train loss: 1.367, train balanced acc: 0.24, test loss: 1.321, test balanced acc: 0.31, lr: 0.001


Training model:  23%|██▎       | 7/30 [24:50<1:22:05, 214.16s/it]

Epoch 6: train loss: 1.360, train balanced acc: 0.26, test loss: 1.302, test balanced acc: 0.31, lr: 0.001


Training model:  27%|██▋       | 8/30 [28:24<1:18:24, 213.82s/it]

Epoch 7: train loss: 1.366, train balanced acc: 0.24, test loss: 1.318, test balanced acc: 0.31, lr: 0.001


Training model:  30%|███       | 9/30 [31:56<1:14:40, 213.37s/it]

Epoch 8: train loss: 1.353, train balanced acc: 0.26, test loss: 1.312, test balanced acc: 0.31, lr: 0.001


Training model:  33%|███▎      | 10/30 [35:28<1:11:01, 213.09s/it]

Epoch 9: train loss: 1.354, train balanced acc: 0.25, test loss: 1.313, test balanced acc: 0.31, lr: 0.0005


Training model:  37%|███▋      | 11/30 [39:02<1:07:32, 213.27s/it]

Epoch 10: train loss: 1.361, train balanced acc: 0.25, test loss: 1.480, test balanced acc: 0.31, lr: 0.0005


Training model:  40%|████      | 12/30 [42:35<1:03:58, 213.24s/it]

Epoch 11: train loss: 1.359, train balanced acc: 0.24, test loss: 1.311, test balanced acc: 0.31, lr: 0.0005


Training model:  43%|████▎     | 13/30 [46:08<1:00:20, 212.96s/it]

Epoch 12: train loss: 1.362, train balanced acc: 0.24, test loss: 1.314, test balanced acc: 0.31, lr: 0.00025


Training model:  43%|████▎     | 13/30 [48:23<1:03:16, 223.33s/it]

KeyboardInterrupt



In [16]:
wav2vec, train_l, val_l, train_acc, val_acc = train_model(wav2vec, train_dataloader, val_dataloader, epochs=30, lr=0.001, device=DEVICE)

Training model:   3%|▎         | 1/30 [02:43<1:19:05, 163.63s/it]

Epoch 0: train loss: 1.349, train balanced acc: 0.25, test loss: 1.294, test balanced acc: 0.31, lr: 0.001


Training model:   7%|▋         | 2/30 [05:28<1:16:42, 164.38s/it]

Epoch 1: train loss: 1.351, train balanced acc: 0.24, test loss: 1.335, test balanced acc: 0.31, lr: 0.001


Training model:  10%|█         | 3/30 [08:14<1:14:21, 165.24s/it]

Epoch 2: train loss: 1.345, train balanced acc: 0.25, test loss: 1.311, test balanced acc: 0.31, lr: 0.001


Training model:  13%|█▎        | 4/30 [11:04<1:12:24, 167.11s/it]

Epoch 3: train loss: 1.334, train balanced acc: 0.26, test loss: 1.300, test balanced acc: 0.31, lr: 0.0005


Training model:  17%|█▋        | 5/30 [13:59<1:10:44, 169.79s/it]

Epoch 4: train loss: 1.347, train balanced acc: 0.24, test loss: 1.311, test balanced acc: 0.31, lr: 0.0005


Training model:  20%|██        | 6/30 [16:52<1:08:22, 170.92s/it]

Epoch 5: train loss: 1.331, train balanced acc: 0.26, test loss: 1.289, test balanced acc: 0.31, lr: 0.0005


Training model:  23%|██▎       | 7/30 [19:44<1:05:38, 171.23s/it]

Epoch 6: train loss: 1.324, train balanced acc: 0.27, test loss: 1.282, test balanced acc: 0.31, lr: 0.0005


Training model:  27%|██▋       | 8/30 [22:33<1:02:32, 170.58s/it]

Epoch 7: train loss: 1.339, train balanced acc: 0.25, test loss: 1.305, test balanced acc: 0.31, lr: 0.0005


Training model:  30%|███       | 9/30 [25:23<59:36, 170.29s/it]  

Epoch 8: train loss: 1.336, train balanced acc: 0.26, test loss: 1.285, test balanced acc: 0.31, lr: 0.0005


Training model:  33%|███▎      | 10/30 [28:14<56:53, 170.66s/it]

Epoch 9: train loss: 1.328, train balanced acc: 0.26, test loss: 1.283, test balanced acc: 0.31, lr: 0.00025


Training model:  37%|███▋      | 11/30 [31:07<54:12, 171.21s/it]

Epoch 10: train loss: 1.343, train balanced acc: 0.24, test loss: 1.291, test balanced acc: 0.31, lr: 0.00025


Training model:  40%|████      | 12/30 [33:57<51:17, 170.95s/it]

Epoch 11: train loss: 1.332, train balanced acc: 0.25, test loss: 1.281, test balanced acc: 0.31, lr: 0.00025


Training model:  43%|████▎     | 13/30 [36:44<48:06, 169.78s/it]

Epoch 12: train loss: 1.329, train balanced acc: 0.26, test loss: 1.282, test balanced acc: 0.31, lr: 0.00025


Training model:  47%|████▋     | 14/30 [39:31<45:01, 168.85s/it]

Epoch 13: train loss: 1.353, train balanced acc: 0.23, test loss: 1.301, test balanced acc: 0.31, lr: 0.00025


Training model:  50%|█████     | 15/30 [42:17<42:03, 168.21s/it]

Epoch 14: train loss: 1.339, train balanced acc: 0.25, test loss: 1.293, test balanced acc: 0.31, lr: 0.000125


Training model:  53%|█████▎    | 16/30 [45:04<39:06, 167.64s/it]

Epoch 15: train loss: 1.337, train balanced acc: 0.25, test loss: 1.290, test balanced acc: 0.31, lr: 0.000125


Training model:  57%|█████▋    | 17/30 [47:51<36:16, 167.39s/it]

Epoch 16: train loss: 1.345, train balanced acc: 0.24, test loss: 1.294, test balanced acc: 0.31, lr: 0.000125


Training model:  60%|██████    | 18/30 [50:40<33:37, 168.15s/it]

Epoch 17: train loss: 1.344, train balanced acc: 0.24, test loss: 1.294, test balanced acc: 0.31, lr: 6.25e-05


Training model:  63%|██████▎   | 19/30 [53:32<30:59, 169.06s/it]

Epoch 18: train loss: 1.344, train balanced acc: 0.24, test loss: 1.281, test balanced acc: 0.31, lr: 6.25e-05


Training model:  67%|██████▋   | 20/30 [56:23<28:18, 169.86s/it]

Epoch 19: train loss: 1.243, train balanced acc: 0.34, test loss: 1.037, test balanced acc: 0.48, lr: 6.25e-05


Training model:  70%|███████   | 21/30 [59:15<25:34, 170.48s/it]

Epoch 20: train loss: 1.149, train balanced acc: 0.40, test loss: 1.085, test balanced acc: 0.45, lr: 6.25e-05


Training model:  73%|███████▎  | 22/30 [1:02:07<22:46, 170.77s/it]

Epoch 21: train loss: 1.106, train balanced acc: 0.42, test loss: 1.077, test balanced acc: 0.46, lr: 6.25e-05


Training model:  77%|███████▋  | 23/30 [1:05:00<20:00, 171.52s/it]

Epoch 22: train loss: 1.091, train balanced acc: 0.43, test loss: 1.088, test balanced acc: 0.50, lr: 3.125e-05


Training model:  80%|████████  | 24/30 [1:07:52<17:10, 171.74s/it]

Epoch 23: train loss: 1.088, train balanced acc: 0.44, test loss: 1.022, test balanced acc: 0.47, lr: 3.125e-05


Training model:  83%|████████▎ | 25/30 [1:10:45<14:19, 171.99s/it]

Epoch 24: train loss: 1.056, train balanced acc: 0.45, test loss: 1.041, test balanced acc: 0.47, lr: 3.125e-05


Training model:  87%|████████▋ | 26/30 [1:13:38<11:28, 172.19s/it]

Epoch 25: train loss: 1.043, train balanced acc: 0.45, test loss: 1.079, test balanced acc: 0.47, lr: 3.125e-05


Training model:  90%|█████████ | 27/30 [1:16:31<08:37, 172.59s/it]

Epoch 26: train loss: 1.043, train balanced acc: 0.47, test loss: 0.999, test balanced acc: 0.48, lr: 3.125e-05


Training model:  93%|█████████▎| 28/30 [1:19:25<05:45, 172.89s/it]

Epoch 27: train loss: 1.060, train balanced acc: 0.45, test loss: 1.034, test balanced acc: 0.48, lr: 3.125e-05


Training model:  97%|█████████▋| 29/30 [1:22:18<02:52, 172.91s/it]

Epoch 28: train loss: 1.036, train balanced acc: 0.46, test loss: 1.065, test balanced acc: 0.50, lr: 3.125e-05


Training model: 100%|██████████| 30/30 [1:25:12<00:00, 170.41s/it]

Epoch 29: train loss: 1.029, train balanced acc: 0.47, test loss: 1.148, test balanced acc: 0.48, lr: 1.5625e-05





In [17]:
wav2vec, train_l, val_l, train_acc, val_acc = train_model(wav2vec, train_dataloader, val_dataloader, epochs=60, lr=1.5625e-05, device=DEVICE)

Training model:   2%|▏         | 1/60 [02:44<2:42:02, 164.78s/it]

Epoch 0: train loss: 1.011, train balanced acc: 0.49, test loss: 1.105, test balanced acc: 0.49, lr: 1.5625e-05


Training model:   3%|▎         | 2/60 [05:32<2:40:57, 166.50s/it]

Epoch 1: train loss: 1.017, train balanced acc: 0.47, test loss: 1.117, test balanced acc: 0.49, lr: 1.5625e-05


Training model:   5%|▌         | 3/60 [08:19<2:38:25, 166.77s/it]

Epoch 2: train loss: 0.994, train balanced acc: 0.50, test loss: 1.100, test balanced acc: 0.49, lr: 1.5625e-05


Training model:   7%|▋         | 4/60 [11:07<2:36:08, 167.30s/it]

Epoch 3: train loss: 0.975, train balanced acc: 0.51, test loss: 1.056, test balanced acc: 0.50, lr: 1.5625e-05


Training model:   8%|▊         | 5/60 [13:59<2:34:58, 169.06s/it]

Epoch 4: train loss: 0.991, train balanced acc: 0.50, test loss: 1.069, test balanced acc: 0.50, lr: 1.5625e-05


Training model:  10%|█         | 6/60 [16:50<2:32:30, 169.45s/it]

Epoch 5: train loss: 0.954, train balanced acc: 0.52, test loss: 1.062, test balanced acc: 0.49, lr: 1.5625e-05


Training model:  12%|█▏        | 7/60 [19:44<2:31:04, 171.03s/it]

Epoch 6: train loss: 0.965, train balanced acc: 0.51, test loss: 1.093, test balanced acc: 0.49, lr: 7.8125e-06


Training model:  13%|█▎        | 8/60 [22:33<2:27:50, 170.59s/it]

Epoch 7: train loss: 0.953, train balanced acc: 0.51, test loss: 1.094, test balanced acc: 0.50, lr: 7.8125e-06


Training model:  15%|█▌        | 9/60 [25:23<2:24:42, 170.25s/it]

Epoch 8: train loss: 0.943, train balanced acc: 0.52, test loss: 1.069, test balanced acc: 0.49, lr: 7.8125e-06


Training model:  17%|█▋        | 10/60 [28:13<2:21:45, 170.10s/it]

Epoch 9: train loss: 0.948, train balanced acc: 0.52, test loss: 1.073, test balanced acc: 0.50, lr: 3.90625e-06


Training model:  18%|█▊        | 11/60 [31:04<2:19:08, 170.37s/it]

Epoch 10: train loss: 0.966, train balanced acc: 0.51, test loss: 1.059, test balanced acc: 0.50, lr: 3.90625e-06


Training model:  20%|██        | 12/60 [33:53<2:16:06, 170.13s/it]

Epoch 11: train loss: 0.948, train balanced acc: 0.52, test loss: 1.056, test balanced acc: 0.50, lr: 3.90625e-06


Training model:  22%|██▏       | 13/60 [36:42<2:13:00, 169.79s/it]

Epoch 12: train loss: 0.929, train balanced acc: 0.52, test loss: 1.075, test balanced acc: 0.50, lr: 1.953125e-06


Training model:  23%|██▎       | 14/60 [39:32<2:10:12, 169.83s/it]

Epoch 13: train loss: 0.944, train balanced acc: 0.52, test loss: 1.068, test balanced acc: 0.50, lr: 1.953125e-06


Training model:  25%|██▌       | 15/60 [42:23<2:07:40, 170.24s/it]

Epoch 14: train loss: 0.920, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.50, lr: 1.953125e-06


Training model:  27%|██▋       | 16/60 [45:13<2:04:43, 170.08s/it]

Epoch 15: train loss: 0.948, train balanced acc: 0.51, test loss: 1.059, test balanced acc: 0.51, lr: 9.765625e-07


Training model:  28%|██▊       | 17/60 [48:03<2:01:51, 170.04s/it]

Epoch 16: train loss: 0.918, train balanced acc: 0.53, test loss: 1.072, test balanced acc: 0.50, lr: 9.765625e-07


Training model:  30%|███       | 18/60 [50:53<1:59:02, 170.05s/it]

Epoch 17: train loss: 0.898, train balanced acc: 0.54, test loss: 1.068, test balanced acc: 0.50, lr: 9.765625e-07


Training model:  32%|███▏      | 19/60 [53:45<1:56:36, 170.65s/it]

Epoch 18: train loss: 0.918, train balanced acc: 0.53, test loss: 1.067, test balanced acc: 0.50, lr: 4.8828125e-07


Training model:  33%|███▎      | 20/60 [56:36<1:53:43, 170.59s/it]

Epoch 19: train loss: 0.907, train balanced acc: 0.56, test loss: 1.068, test balanced acc: 0.50, lr: 4.8828125e-07


Training model:  35%|███▌      | 21/60 [59:26<1:50:43, 170.36s/it]

Epoch 20: train loss: 0.919, train balanced acc: 0.53, test loss: 1.070, test balanced acc: 0.50, lr: 4.8828125e-07


Training model:  37%|███▋      | 22/60 [1:02:17<1:48:03, 170.63s/it]

Epoch 21: train loss: 0.931, train balanced acc: 0.54, test loss: 1.069, test balanced acc: 0.51, lr: 2.44140625e-07


Training model:  38%|███▊      | 23/60 [1:05:08<1:45:14, 170.66s/it]

Epoch 22: train loss: 0.921, train balanced acc: 0.54, test loss: 1.072, test balanced acc: 0.51, lr: 2.44140625e-07


Training model:  40%|████      | 24/60 [1:07:57<1:42:14, 170.40s/it]

Epoch 23: train loss: 0.922, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.50, lr: 2.44140625e-07


Training model:  42%|████▏     | 25/60 [1:10:48<1:39:28, 170.54s/it]

Epoch 24: train loss: 0.915, train balanced acc: 0.54, test loss: 1.069, test balanced acc: 0.51, lr: 1.220703125e-07


Training model:  43%|████▎     | 26/60 [1:13:40<1:36:51, 170.92s/it]

Epoch 25: train loss: 0.931, train balanced acc: 0.54, test loss: 1.068, test balanced acc: 0.51, lr: 1.220703125e-07


Training model:  45%|████▌     | 27/60 [1:16:31<1:34:00, 170.93s/it]

Epoch 26: train loss: 0.916, train balanced acc: 0.54, test loss: 1.070, test balanced acc: 0.51, lr: 1.220703125e-07


Training model:  47%|████▋     | 28/60 [1:19:24<1:31:27, 171.47s/it]

Epoch 27: train loss: 0.933, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 6.103515625e-08


Training model:  48%|████▊     | 29/60 [1:22:17<1:28:55, 172.10s/it]

Epoch 28: train loss: 0.926, train balanced acc: 0.53, test loss: 1.072, test balanced acc: 0.50, lr: 6.103515625e-08


Training model:  50%|█████     | 30/60 [1:25:08<1:25:50, 171.70s/it]

Epoch 29: train loss: 0.928, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 6.103515625e-08


Training model:  52%|█████▏    | 31/60 [1:27:59<1:22:54, 171.54s/it]

Epoch 30: train loss: 0.922, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.50, lr: 3.0517578125e-08


Training model:  53%|█████▎    | 32/60 [1:30:51<1:20:01, 171.50s/it]

Epoch 31: train loss: 0.920, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.50, lr: 3.0517578125e-08


Training model:  55%|█████▌    | 33/60 [1:33:40<1:16:53, 170.88s/it]

Epoch 32: train loss: 0.917, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 3.0517578125e-08


Training model:  57%|█████▋    | 34/60 [1:36:29<1:13:49, 170.38s/it]

Epoch 33: train loss: 0.913, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  58%|█████▊    | 35/60 [1:39:21<1:11:10, 170.83s/it]

Epoch 34: train loss: 0.929, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  60%|██████    | 36/60 [1:42:11<1:08:12, 170.54s/it]

Epoch 35: train loss: 0.929, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  62%|██████▏   | 37/60 [1:45:01<1:05:18, 170.38s/it]

Epoch 36: train loss: 0.911, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  63%|██████▎   | 38/60 [1:47:51<1:02:27, 170.33s/it]

Epoch 37: train loss: 0.932, train balanced acc: 0.52, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  65%|██████▌   | 39/60 [1:50:42<59:37, 170.35s/it]  

Epoch 38: train loss: 0.924, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  67%|██████▋   | 40/60 [1:53:33<56:54, 170.72s/it]

Epoch 39: train loss: 0.952, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  68%|██████▊   | 41/60 [1:56:22<53:54, 170.22s/it]

Epoch 40: train loss: 0.936, train balanced acc: 0.51, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  70%|███████   | 42/60 [1:59:11<50:54, 169.72s/it]

Epoch 41: train loss: 0.925, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  72%|███████▏  | 43/60 [2:02:00<48:05, 169.71s/it]

Epoch 42: train loss: 0.923, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  73%|███████▎  | 44/60 [2:04:55<45:39, 171.19s/it]

Epoch 43: train loss: 0.926, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  75%|███████▌  | 45/60 [2:07:48<42:53, 171.59s/it]

Epoch 44: train loss: 0.943, train balanced acc: 0.51, test loss: 1.071, test balanced acc: 0.50, lr: 1.52587890625e-08


Training model:  77%|███████▋  | 46/60 [2:10:41<40:11, 172.25s/it]

Epoch 45: train loss: 0.926, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.51, lr: 1.52587890625e-08


Training model:  78%|███████▊  | 47/60 [2:13:39<37:39, 173.82s/it]

Epoch 46: train loss: 0.935, train balanced acc: 0.52, test loss: 1.071, test balanced acc: 0.51, lr: 1.52587890625e-08


Training model:  80%|████████  | 48/60 [2:16:34<34:51, 174.27s/it]

Epoch 47: train loss: 0.914, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.51, lr: 1.52587890625e-08


Training model:  82%|████████▏ | 49/60 [2:19:29<31:59, 174.52s/it]

Epoch 48: train loss: 0.923, train balanced acc: 0.54, test loss: 1.071, test balanced acc: 0.51, lr: 1.52587890625e-08


Training model:  83%|████████▎ | 50/60 [2:22:25<29:07, 174.77s/it]

Epoch 49: train loss: 0.934, train balanced acc: 0.53, test loss: 1.071, test balanced acc: 0.51, lr: 1.52587890625e-08


Training model:  83%|████████▎ | 50/60 [2:23:44<28:44, 172.49s/it]

KeyboardInterrupt



In [18]:
def test_model(model, test_dataset, max_seq=420):
    model = model.to(DEVICE)
        
    model.eval()
    preds = []
    targets = []
    with torch.no_grad():
        for features, target in tqdm(test_dataset):
            features = features.to(DEVICE)
            target = target.to(DEVICE)
            if model.__class__.__name__ == "Wav2vecClassifier":
                label = model(features).logits.to("cpu").detach().numpy().squeeze()
            elif model.__class__.__name__ == "MobileNet":
                label = model(features[None, ...]).to("cpu").detach().numpy().squeeze()
            else:
                padded_features = torch.zeros(1, features.shape[0], max_seq, device=DEVICE)
                padded_features[0, :, :features.shape[-1]] = features[..., :max_seq]
                label = model(padded_features).to("cpu").detach().numpy().squeeze()
            preds.append(label.argmax(axis=-1))
            targets.append(target.to("cpu").item())

    preds = np.array(preds)
    # print(targets)
    print(classification_report(targets, preds))    

    return preds

In [19]:
test_model(wav2vec, test_dataset)

100%|██████████| 806/806 [00:14<00:00, 54.41it/s]


              precision    recall  f1-score   support

           0       0.80      0.89      0.85       104
           1       0.24      0.07      0.11       185
           2       0.41      0.73      0.52       268
           3       0.60      0.37      0.46       249

    accuracy                           0.49       806
   macro avg       0.51      0.52      0.48       806
weighted avg       0.48      0.49      0.45       806



array([0, 2, 3, 3, 2, 0, 2, 3, 2, 0, 2, 2, 3, 3, 2, 2, 3, 2, 3, 2, 2, 2,
       3, 2, 2, 2, 2, 0, 3, 3, 2, 2, 0, 2, 2, 3, 0, 2, 3, 3, 2, 2, 0, 2,
       2, 3, 2, 2, 2, 2, 3, 3, 1, 2, 2, 2, 2, 2, 2, 0, 2, 3, 3, 1, 0, 2,
       2, 2, 2, 3, 2, 2, 2, 2, 2, 0, 2, 2, 2, 3, 3, 1, 2, 2, 2, 2, 2, 2,
       2, 2, 1, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 3, 3, 3, 2, 2, 3, 2,
       2, 2, 2, 0, 2, 0, 0, 2, 0, 3, 2, 2, 0, 2, 2, 2, 3, 2, 0, 0, 0, 3,
       2, 0, 3, 2, 0, 3, 0, 2, 1, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 0, 2, 3,
       2, 1, 1, 0, 2, 1, 1, 2, 0, 2, 2, 0, 0, 2, 2, 1, 2, 2, 2, 2, 2, 0,
       3, 3, 2, 0, 0, 2, 2, 1, 2, 2, 2, 0, 3, 3, 3, 0, 3, 0, 2, 2, 2, 3,
       2, 0, 3, 2, 2, 0, 2, 2, 2, 1, 2, 3, 2, 2, 3, 0, 2, 2, 2, 2, 3, 3,
       2, 3, 0, 2, 2, 2, 2, 2, 0, 0, 1, 2, 0, 2, 2, 2, 2, 2, 2, 0, 1, 2,
       2, 0, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 1, 2, 2, 3, 1,
       2, 2, 2, 2, 2, 1, 2, 0, 3, 0, 3, 3, 2, 2, 2, 2, 0, 2, 0, 3, 0, 0,
       2, 3, 2, 0, 2, 3, 3, 0, 3, 3, 1, 2, 1, 2, 3,

In [45]:
augmentations = Compose([
                AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.02, p=0.5),
                TimeStretch(min_rate=0.3, max_rate=1.9, p=0.5),
                PitchShift(min_semitones=-6, max_semitones=6, p=0.5),
                Shift(min_shift=-0.6, max_shift=0.6, p=0.5),
                Gain(min_gain_db=-14, max_gain_db=14, p=0.5),
            ])

In [46]:
train_dataset = AphasiaDatasetMFCC(os.path.join(DATA_DIR, "train_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, mfcc=MFCC, n_mels=128, fft_size=512,
                 hop_length=256, win_length=512, min_duration=10, max_duration=15, file_format="wav", rirs_dir=RIR_DIR, transforms=augmentations)
test_dataset = AphasiaDatasetMFCC(os.path.join(DATA_DIR, "val_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, mfcc=MFCC, n_mels=128, fft_size=512,
                 hop_length=256, win_length=512, min_duration=10, max_duration=15, file_format="wav")
val_dataset = AphasiaDatasetMFCC(os.path.join(DATA_DIR, "test_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, mfcc=MFCC, n_mels=128, fft_size=512,
                 hop_length=256, win_length=512, min_duration=10, max_duration=15, file_format="wav")

# Балансировка классов для train
train_labels = [label for _, label in train_dataset.data]
class_counts = Counter(train_labels)
if len(class_counts) < 4:
    raise ValueError("Один из классов отсутствует в тренировочном наборе")

class_weights = {label: 1.0 / count for label, count in class_counts.items()}
weights = [class_weights[label] for _, label in train_dataset.data]
train_sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset), replacement=True)

In [47]:
def pad_sequence(batch):
    if not batch:
        return torch.zeros(0), torch.zeros(0)
    
    seq, labels = zip(*batch)
    max_len = max(s.shape[1] for s in seq)

    padded = torch.zeros(len(seq), MFCC, SEQUENCE_LENGTH)
    for i, s in enumerate(seq):
        padded[i, ..., :s.shape[-1]] = s[..., :SEQUENCE_LENGTH]
    
    return padded, torch.stack(labels) 

In [48]:
train_dataloader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler, collate_fn=pad_sequence, drop_last=True, num_workers=6)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=pad_sequence, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=pad_sequence, drop_last=True, num_workers=6)

In [49]:
def train_model(model, dl_train, dl_val, epochs=1, lr=0.001, device="cpu"):
      
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(list(class_weights.values()))).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)
        
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3, threshold=1e-3)
        
    train_loss_list = []
    val_loss_list = []
    train_acc_list = []
    val_acc_list = []
    for epoch in tqdm(range(epochs), desc="Training model"):
        model.train()
        total_train_loss = 0
        total_val_loss = 0
        
        train_acc = []
        val_acc = []
        for features, target in dl_train:
            features, target = features.to(device), target.to(device)
            optimizer.zero_grad()
            
            output = model(features.unsqueeze(1)).squeeze()

            preds = torch.argmax(output, dim=1).cpu().detach().numpy()
            train_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
            
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.detach().item()
        
        avg_train_acc = np.stack(train_acc, axis=0).mean()

        avg_train_loss = total_train_loss / len(train_dataloader)
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
                
        model.eval()
        
        with torch.no_grad():
            for features, target in dl_val:
                features, target = features.to(device), target.to(device)
                
                output = model(features.unsqueeze(1)).squeeze()
                
                preds = torch.argmax(output, dim=1).cpu().detach().numpy()
                val_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
                loss = criterion(output, target)
                total_val_loss += loss.detach().item()
        
        avg_val_acc = np.stack(val_acc, axis=0).mean()
        avg_val_loss = total_val_loss / len(dl_val)
        val_loss_list.append(avg_val_loss)
        val_acc_list.append(avg_val_acc)
        
        if scheduler:
            try:
                scheduler.step()
            except:
                scheduler.step(avg_val_loss)
        
        # if epoch % 10 == 0:
        tqdm.write(f"Epoch {epoch}: train loss: {avg_train_loss:.3f}, train balanced acc: {avg_train_acc:.2f}, test loss: {avg_val_loss:.3f}, test balanced acc: {avg_val_acc:.2f}, lr: {optimizer.param_groups[0]['lr']}")

            
    return model, train_loss_list, val_loss_list, train_acc_list, val_acc_list

In [50]:
from models.cnn import MobileNet

cnn = MobileNet(num_classes=4)

In [51]:
cnn, train_l, val_l, train_accuracy, val_accuracy = train_model(cnn, train_dataloader, val_dataloader, epochs=10, lr=1e-4, device=DEVICE)

Training model:  10%|█         | 1/10 [00:02<00:21,  2.41s/it]

Epoch 0: train loss: 1.352, train balanced acc: 0.24, test loss: 1.366, test balanced acc: 0.31, lr: 0.0001


Training model:  20%|██        | 2/10 [00:04<00:19,  2.39s/it]

Epoch 1: train loss: 1.269, train balanced acc: 0.33, test loss: 1.331, test balanced acc: 0.31, lr: 0.0001


Training model:  30%|███       | 3/10 [00:07<00:16,  2.40s/it]

Epoch 2: train loss: 1.132, train balanced acc: 0.45, test loss: 1.306, test balanced acc: 0.31, lr: 0.0001


Training model:  40%|████      | 4/10 [00:09<00:14,  2.41s/it]

Epoch 3: train loss: 0.987, train balanced acc: 0.54, test loss: 1.277, test balanced acc: 0.31, lr: 0.0001


Training model:  50%|█████     | 5/10 [00:12<00:12,  2.42s/it]

Epoch 4: train loss: 0.851, train balanced acc: 0.63, test loss: 1.318, test balanced acc: 0.32, lr: 0.0001


Training model:  60%|██████    | 6/10 [00:14<00:09,  2.42s/it]

Epoch 5: train loss: 0.735, train balanced acc: 0.68, test loss: 1.612, test balanced acc: 0.31, lr: 0.0001


Training model:  70%|███████   | 7/10 [00:16<00:07,  2.43s/it]

Epoch 6: train loss: 0.664, train balanced acc: 0.72, test loss: 1.686, test balanced acc: 0.30, lr: 0.0001


Training model:  80%|████████  | 8/10 [00:19<00:04,  2.43s/it]

Epoch 7: train loss: 0.617, train balanced acc: 0.74, test loss: 1.815, test balanced acc: 0.31, lr: 5e-05


Training model:  90%|█████████ | 9/10 [00:21<00:02,  2.43s/it]

Epoch 8: train loss: 0.509, train balanced acc: 0.79, test loss: 1.879, test balanced acc: 0.31, lr: 5e-05


Training model: 100%|██████████| 10/10 [00:24<00:00,  2.42s/it]

Epoch 9: train loss: 0.393, train balanced acc: 0.84, test loss: 2.085, test balanced acc: 0.31, lr: 5e-05





In [69]:
test_model(cnn, test_dataset)

100%|██████████| 807/807 [00:01<00:00, 436.07it/s]

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       106
           1       0.08      0.01      0.01       186
           2       0.00      0.00      0.00       270
           3       0.30      0.97      0.46       245

    accuracy                           0.29       807
   macro avg       0.09      0.24      0.12       807
weighted avg       0.11      0.29      0.14       807






array([3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

In [5]:
train_dataset = AphasiaDatasetMFCC(os.path.join(DATA_DIR, "train_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, mfcc=MFCC, n_mels=128, fft_size=512,
                 hop_length=256, win_length=512, min_duration=10, max_duration=15, file_format="wav", transforms=augmentations)#  noise_dir=NORM_DIR, rirs_dir=RIR_DIR, snr=(-5, 20))
test_dataset = AphasiaDatasetMFCC(os.path.join(DATA_DIR, "val_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, mfcc=MFCC, n_mels=128, fft_size=512,
                 hop_length=256, win_length=512, min_duration=10, max_duration=15, file_format="wav")
val_dataset = AphasiaDatasetMFCC(os.path.join(DATA_DIR, "test_filenames_mc_1.csv"), VOICES_DIR, target_sample_rate=8_000, mfcc=MFCC, n_mels=128, fft_size=512,
                 hop_length=256, win_length=512, min_duration=10, max_duration=15, file_format="wav")
# Балансировка классов для train
train_labels = [label for _, label in train_dataset.data]
class_counts = Counter(train_labels)
if len(class_counts) < 2:
    raise ValueError("Один из классов отсутствует в тренировочном наборе")

class_weights = {label: 1.0 / count for label, count in class_counts.items()}
weights = [class_weights[label] for _, label in train_dataset.data]
train_sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset), replacement=True)

In [6]:
def pad_sequence(batch):
    if not batch:
        return torch.zeros(0), torch.zeros(0)
    
    seq, labels = zip(*batch)
    # print(seq[1], labels)
    max_len = max(s.shape[1] for s in seq)
    # print(max_len)

    # print(seq[0].shape)
    padded = torch.zeros(len(seq), MFCC, SEQUENCE_LENGTH)
    for i, s in enumerate(seq):
        # print(s.shape)
        padded[i, ..., :s.shape[-1]] = s[..., :SEQUENCE_LENGTH]
    
    return padded, torch.stack(labels) 

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler, collate_fn=pad_sequence, drop_last=True, num_workers=6)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=pad_sequence, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=pad_sequence, drop_last=True, num_workers=6)

In [16]:
def train_model(model, dl_train, dl_val, epochs=1, lr=0.001, device="cpu"):
      
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)
        
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3, threshold=1e-3)
    # scheduler = step_scheduler(optimizer, 30) 
        
    train_loss_list = []
    val_loss_list = []
    train_acc_list = []
    val_acc_list = []
    for epoch in tqdm(range(epochs), desc="Training model"):
        model.train()
        total_train_loss = 0
        total_val_loss = 0
        
        train_acc = []
        val_acc = []
        for features, target in dl_train:
            features, target = features.to(device), target.to(device)
            optimizer.zero_grad()
            
            output = model(features).squeeze()
            # print(output.shape)
            # train_acc.append(torch.tensor(torch.argmax(output, dim=1) == target).cpu().detach().numpy())
            preds = torch.argmax(output, dim=1).cpu().detach().numpy()
            train_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
            
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.detach().item()
        # print(f"LR before: {optimizer.param_groups[0]['lr']}")
        # scheduler.step(epoch=epoch)
        # print(f"LR after: {optimizer.param_groups[0]['lr']}")
        
        avg_train_acc = np.stack(train_acc, axis=0).mean()

        avg_train_loss = total_train_loss / len(train_dataloader)
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
                
        model.eval()
        
        with torch.no_grad():
            for features, target in dl_val:
                features, target = features.to(device), target.to(device)
                
                output = model(features).squeeze()
                
                preds = torch.argmax(output, dim=1).cpu().detach().numpy()
                val_acc.append(accuracy_score(target.cpu().detach().numpy(), preds))
                loss = criterion(output, target)
                total_val_loss += loss.detach().item()
        
        avg_val_acc = np.stack(val_acc, axis=0).mean()
        avg_val_loss = total_val_loss / len(dl_val)
        val_loss_list.append(avg_val_loss)
        val_acc_list.append(avg_val_acc)
        
        if scheduler:
            try:
                scheduler.step()
            except:
                scheduler.step(avg_val_loss)
        
        # if epoch % 10 == 0:
        tqdm.write(f"Epoch {epoch}: train loss: {avg_train_loss:.3f}, train balanced acc: {avg_train_acc:.2f}, test loss: {avg_val_loss:.3f}, test balanced acc: {avg_val_acc:.2f}, lr: {optimizer.param_groups[0]['lr']}")

            
    return model, train_loss_list, val_loss_list, train_acc_list, val_acc_list

In [17]:
from models.swishnet import SwishNet

swishnet = SwishNet(MFCC, 4, input_size=SEQUENCE_LENGTH, dropout_rate=0.55)

In [18]:
swishnet, train_l, val_l, train_accuracy, val_accuracy = train_model(swishnet, train_dataloader, val_dataloader, epochs=100, lr=1e-4, device=DEVICE)

Training model:   1%|          | 1/100 [00:00<01:08,  1.45it/s]

Epoch 0: train loss: 1.384, train balanced acc: 0.28, test loss: 1.384, test balanced acc: 0.31, lr: 0.0001


Training model:   2%|▏         | 2/100 [00:01<01:13,  1.33it/s]

Epoch 1: train loss: 1.383, train balanced acc: 0.28, test loss: 1.383, test balanced acc: 0.32, lr: 0.0001


Training model:   3%|▎         | 3/100 [00:02<01:10,  1.37it/s]

Epoch 2: train loss: 1.379, train balanced acc: 0.30, test loss: 1.382, test balanced acc: 0.32, lr: 0.0001


Training model:   4%|▍         | 4/100 [00:02<01:09,  1.38it/s]

Epoch 3: train loss: 1.376, train balanced acc: 0.31, test loss: 1.381, test balanced acc: 0.32, lr: 0.0001


Training model:   5%|▌         | 5/100 [00:03<01:08,  1.39it/s]

Epoch 4: train loss: 1.370, train balanced acc: 0.32, test loss: 1.379, test balanced acc: 0.33, lr: 0.0001


Training model:   6%|▌         | 6/100 [00:04<01:05,  1.43it/s]

Epoch 5: train loss: 1.367, train balanced acc: 0.33, test loss: 1.377, test balanced acc: 0.33, lr: 0.0001


Training model:   7%|▋         | 7/100 [00:04<01:04,  1.45it/s]

Epoch 6: train loss: 1.361, train balanced acc: 0.33, test loss: 1.377, test balanced acc: 0.33, lr: 0.0001


Training model:   8%|▊         | 8/100 [00:05<01:03,  1.46it/s]

Epoch 7: train loss: 1.358, train balanced acc: 0.34, test loss: 1.374, test balanced acc: 0.33, lr: 0.0001


Training model:   9%|▉         | 9/100 [00:06<01:01,  1.48it/s]

Epoch 8: train loss: 1.352, train balanced acc: 0.34, test loss: 1.372, test balanced acc: 0.34, lr: 0.0001


Training model:  10%|█         | 10/100 [00:07<01:02,  1.44it/s]

Epoch 9: train loss: 1.347, train balanced acc: 0.34, test loss: 1.372, test balanced acc: 0.33, lr: 0.0001


Training model:  11%|█         | 11/100 [00:07<01:01,  1.45it/s]

Epoch 10: train loss: 1.351, train balanced acc: 0.34, test loss: 1.369, test balanced acc: 0.34, lr: 0.0001


Training model:  12%|█▏        | 12/100 [00:08<00:59,  1.48it/s]

Epoch 11: train loss: 1.339, train balanced acc: 0.35, test loss: 1.369, test balanced acc: 0.33, lr: 0.0001


Training model:  13%|█▎        | 13/100 [00:08<00:56,  1.53it/s]

Epoch 12: train loss: 1.338, train balanced acc: 0.37, test loss: 1.367, test balanced acc: 0.34, lr: 0.0001


Training model:  14%|█▍        | 14/100 [00:09<00:55,  1.55it/s]

Epoch 13: train loss: 1.331, train balanced acc: 0.38, test loss: 1.368, test balanced acc: 0.33, lr: 0.0001


Training model:  15%|█▌        | 15/100 [00:10<00:54,  1.55it/s]

Epoch 14: train loss: 1.329, train balanced acc: 0.37, test loss: 1.364, test balanced acc: 0.34, lr: 0.0001


Training model:  16%|█▌        | 16/100 [00:10<00:53,  1.58it/s]

Epoch 15: train loss: 1.318, train balanced acc: 0.38, test loss: 1.368, test balanced acc: 0.33, lr: 0.0001


Training model:  17%|█▋        | 17/100 [00:11<00:52,  1.59it/s]

Epoch 16: train loss: 1.321, train balanced acc: 0.38, test loss: 1.362, test balanced acc: 0.34, lr: 0.0001


Training model:  18%|█▊        | 18/100 [00:12<00:50,  1.61it/s]

Epoch 17: train loss: 1.317, train balanced acc: 0.37, test loss: 1.359, test balanced acc: 0.34, lr: 0.0001


Training model:  19%|█▉        | 19/100 [00:12<00:50,  1.61it/s]

Epoch 18: train loss: 1.313, train balanced acc: 0.38, test loss: 1.364, test balanced acc: 0.33, lr: 0.0001


Training model:  20%|██        | 20/100 [00:13<00:49,  1.62it/s]

Epoch 19: train loss: 1.300, train balanced acc: 0.39, test loss: 1.363, test balanced acc: 0.33, lr: 0.0001


Training model:  21%|██        | 21/100 [00:13<00:48,  1.63it/s]

Epoch 20: train loss: 1.303, train balanced acc: 0.39, test loss: 1.365, test balanced acc: 0.33, lr: 0.0001


Training model:  22%|██▏       | 22/100 [00:14<00:47,  1.65it/s]

Epoch 21: train loss: 1.290, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 5e-05


Training model:  23%|██▎       | 23/100 [00:15<00:47,  1.64it/s]

Epoch 22: train loss: 1.292, train balanced acc: 0.40, test loss: 1.362, test balanced acc: 0.33, lr: 5e-05


Training model:  24%|██▍       | 24/100 [00:15<00:45,  1.65it/s]

Epoch 23: train loss: 1.283, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 5e-05


Training model:  25%|██▌       | 25/100 [00:16<00:45,  1.64it/s]

Epoch 24: train loss: 1.287, train balanced acc: 0.40, test loss: 1.360, test balanced acc: 0.33, lr: 5e-05


Training model:  26%|██▌       | 26/100 [00:16<00:44,  1.65it/s]

Epoch 25: train loss: 1.290, train balanced acc: 0.40, test loss: 1.362, test balanced acc: 0.33, lr: 2.5e-05


Training model:  27%|██▋       | 27/100 [00:17<00:45,  1.62it/s]

Epoch 26: train loss: 1.288, train balanced acc: 0.41, test loss: 1.359, test balanced acc: 0.33, lr: 2.5e-05


Training model:  28%|██▊       | 28/100 [00:18<00:44,  1.63it/s]

Epoch 27: train loss: 1.294, train balanced acc: 0.39, test loss: 1.363, test balanced acc: 0.33, lr: 2.5e-05


Training model:  29%|██▉       | 29/100 [00:18<00:44,  1.61it/s]

Epoch 28: train loss: 1.281, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 2.5e-05


Training model:  30%|███       | 30/100 [00:19<00:44,  1.59it/s]

Epoch 29: train loss: 1.276, train balanced acc: 0.42, test loss: 1.360, test balanced acc: 0.33, lr: 1.25e-05


Training model:  31%|███       | 31/100 [00:20<00:43,  1.58it/s]

Epoch 30: train loss: 1.288, train balanced acc: 0.40, test loss: 1.359, test balanced acc: 0.33, lr: 1.25e-05


Training model:  32%|███▏      | 32/100 [00:20<00:43,  1.57it/s]

Epoch 31: train loss: 1.289, train balanced acc: 0.39, test loss: 1.361, test balanced acc: 0.33, lr: 1.25e-05


Training model:  33%|███▎      | 33/100 [00:21<00:42,  1.57it/s]

Epoch 32: train loss: 1.276, train balanced acc: 0.43, test loss: 1.362, test balanced acc: 0.33, lr: 1.25e-05


Training model:  34%|███▍      | 34/100 [00:21<00:41,  1.59it/s]

Epoch 33: train loss: 1.274, train balanced acc: 0.43, test loss: 1.362, test balanced acc: 0.33, lr: 6.25e-06


Training model:  35%|███▌      | 35/100 [00:22<00:40,  1.62it/s]

Epoch 34: train loss: 1.281, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 6.25e-06


Training model:  36%|███▌      | 36/100 [00:23<00:39,  1.63it/s]

Epoch 35: train loss: 1.271, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 6.25e-06


Training model:  37%|███▋      | 37/100 [00:23<00:38,  1.64it/s]

Epoch 36: train loss: 1.284, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 6.25e-06


Training model:  38%|███▊      | 38/100 [00:24<00:37,  1.64it/s]

Epoch 37: train loss: 1.273, train balanced acc: 0.43, test loss: 1.362, test balanced acc: 0.33, lr: 3.125e-06


Training model:  39%|███▉      | 39/100 [00:24<00:36,  1.65it/s]

Epoch 38: train loss: 1.279, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 3.125e-06


Training model:  40%|████      | 40/100 [00:25<00:36,  1.66it/s]

Epoch 39: train loss: 1.273, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 3.125e-06


Training model:  41%|████      | 41/100 [00:26<00:36,  1.60it/s]

Epoch 40: train loss: 1.277, train balanced acc: 0.43, test loss: 1.361, test balanced acc: 0.33, lr: 3.125e-06


Training model:  42%|████▏     | 42/100 [00:26<00:35,  1.61it/s]

Epoch 41: train loss: 1.284, train balanced acc: 0.40, test loss: 1.361, test balanced acc: 0.33, lr: 1.5625e-06


Training model:  43%|████▎     | 43/100 [00:27<00:35,  1.61it/s]

Epoch 42: train loss: 1.275, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 1.5625e-06


Training model:  44%|████▍     | 44/100 [00:28<00:34,  1.61it/s]

Epoch 43: train loss: 1.282, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 1.5625e-06


Training model:  45%|████▌     | 45/100 [00:28<00:34,  1.61it/s]

Epoch 44: train loss: 1.278, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 1.5625e-06


Training model:  46%|████▌     | 46/100 [00:29<00:33,  1.62it/s]

Epoch 45: train loss: 1.276, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 7.8125e-07


Training model:  47%|████▋     | 47/100 [00:29<00:32,  1.62it/s]

Epoch 46: train loss: 1.267, train balanced acc: 0.44, test loss: 1.362, test balanced acc: 0.33, lr: 7.8125e-07


Training model:  48%|████▊     | 48/100 [00:30<00:32,  1.62it/s]

Epoch 47: train loss: 1.278, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 7.8125e-07


Training model:  49%|████▉     | 49/100 [00:31<00:31,  1.63it/s]

Epoch 48: train loss: 1.277, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 7.8125e-07


Training model:  50%|█████     | 50/100 [00:31<00:30,  1.63it/s]

Epoch 49: train loss: 1.284, train balanced acc: 0.40, test loss: 1.359, test balanced acc: 0.33, lr: 3.90625e-07


Training model:  51%|█████     | 51/100 [00:32<00:31,  1.57it/s]

Epoch 50: train loss: 1.274, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 3.90625e-07


Training model:  52%|█████▏    | 52/100 [00:33<00:30,  1.58it/s]

Epoch 51: train loss: 1.289, train balanced acc: 0.39, test loss: 1.359, test balanced acc: 0.33, lr: 3.90625e-07


Training model:  53%|█████▎    | 53/100 [00:33<00:29,  1.59it/s]

Epoch 52: train loss: 1.267, train balanced acc: 0.43, test loss: 1.362, test balanced acc: 0.33, lr: 3.90625e-07


Training model:  54%|█████▍    | 54/100 [00:34<00:28,  1.59it/s]

Epoch 53: train loss: 1.273, train balanced acc: 0.43, test loss: 1.360, test balanced acc: 0.33, lr: 1.953125e-07


Training model:  55%|█████▌    | 55/100 [00:34<00:28,  1.56it/s]

Epoch 54: train loss: 1.277, train balanced acc: 0.40, test loss: 1.361, test balanced acc: 0.33, lr: 1.953125e-07


Training model:  56%|█████▌    | 56/100 [00:35<00:27,  1.57it/s]

Epoch 55: train loss: 1.271, train balanced acc: 0.42, test loss: 1.362, test balanced acc: 0.33, lr: 1.953125e-07


Training model:  57%|█████▋    | 57/100 [00:36<00:27,  1.58it/s]

Epoch 56: train loss: 1.266, train balanced acc: 0.44, test loss: 1.362, test balanced acc: 0.33, lr: 1.953125e-07


Training model:  58%|█████▊    | 58/100 [00:36<00:26,  1.58it/s]

Epoch 57: train loss: 1.278, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 9.765625e-08


Training model:  59%|█████▉    | 59/100 [00:37<00:26,  1.56it/s]

Epoch 58: train loss: 1.276, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 9.765625e-08


Training model:  60%|██████    | 60/100 [00:38<00:25,  1.57it/s]

Epoch 59: train loss: 1.277, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 9.765625e-08


Training model:  61%|██████    | 61/100 [00:38<00:24,  1.59it/s]

Epoch 60: train loss: 1.280, train balanced acc: 0.40, test loss: 1.362, test balanced acc: 0.33, lr: 9.765625e-08


Training model:  62%|██████▏   | 62/100 [00:39<00:23,  1.62it/s]

Epoch 61: train loss: 1.276, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 4.8828125e-08


Training model:  63%|██████▎   | 63/100 [00:39<00:22,  1.62it/s]

Epoch 62: train loss: 1.272, train balanced acc: 0.43, test loss: 1.360, test balanced acc: 0.33, lr: 4.8828125e-08


Training model:  64%|██████▍   | 64/100 [00:40<00:22,  1.61it/s]

Epoch 63: train loss: 1.273, train balanced acc: 0.43, test loss: 1.360, test balanced acc: 0.33, lr: 4.8828125e-08


Training model:  65%|██████▌   | 65/100 [00:41<00:21,  1.63it/s]

Epoch 64: train loss: 1.273, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 4.8828125e-08


Training model:  66%|██████▌   | 66/100 [00:41<00:20,  1.65it/s]

Epoch 65: train loss: 1.277, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 2.44140625e-08


Training model:  67%|██████▋   | 67/100 [00:42<00:20,  1.62it/s]

Epoch 66: train loss: 1.273, train balanced acc: 0.40, test loss: 1.360, test balanced acc: 0.33, lr: 2.44140625e-08


Training model:  68%|██████▊   | 68/100 [00:43<00:20,  1.58it/s]

Epoch 67: train loss: 1.277, train balanced acc: 0.42, test loss: 1.360, test balanced acc: 0.33, lr: 2.44140625e-08


Training model:  69%|██████▉   | 69/100 [00:43<00:19,  1.60it/s]

Epoch 68: train loss: 1.271, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 2.44140625e-08


Training model:  70%|███████   | 70/100 [00:44<00:18,  1.59it/s]

Epoch 69: train loss: 1.282, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  71%|███████   | 71/100 [00:44<00:18,  1.61it/s]

Epoch 70: train loss: 1.266, train balanced acc: 0.43, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  72%|███████▏  | 72/100 [00:45<00:17,  1.60it/s]

Epoch 71: train loss: 1.276, train balanced acc: 0.42, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  73%|███████▎  | 73/100 [00:46<00:18,  1.49it/s]

Epoch 72: train loss: 1.281, train balanced acc: 0.41, test loss: 1.359, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  74%|███████▍  | 74/100 [00:47<00:17,  1.49it/s]

Epoch 73: train loss: 1.275, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  75%|███████▌  | 75/100 [00:47<00:16,  1.52it/s]

Epoch 74: train loss: 1.271, train balanced acc: 0.43, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  76%|███████▌  | 76/100 [00:48<00:15,  1.54it/s]

Epoch 75: train loss: 1.278, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  77%|███████▋  | 77/100 [00:48<00:14,  1.57it/s]

Epoch 76: train loss: 1.278, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  78%|███████▊  | 78/100 [00:49<00:14,  1.56it/s]

Epoch 77: train loss: 1.278, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  79%|███████▉  | 79/100 [00:50<00:13,  1.52it/s]

Epoch 78: train loss: 1.278, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  80%|████████  | 80/100 [00:50<00:12,  1.55it/s]

Epoch 79: train loss: 1.280, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  81%|████████  | 81/100 [00:51<00:12,  1.56it/s]

Epoch 80: train loss: 1.269, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  82%|████████▏ | 82/100 [00:52<00:11,  1.59it/s]

Epoch 81: train loss: 1.277, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  83%|████████▎ | 83/100 [00:52<00:10,  1.58it/s]

Epoch 82: train loss: 1.271, train balanced acc: 0.43, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  84%|████████▍ | 84/100 [00:53<00:10,  1.58it/s]

Epoch 83: train loss: 1.272, train balanced acc: 0.42, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  85%|████████▌ | 85/100 [00:53<00:09,  1.59it/s]

Epoch 84: train loss: 1.277, train balanced acc: 0.43, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  86%|████████▌ | 86/100 [00:54<00:08,  1.61it/s]

Epoch 85: train loss: 1.267, train balanced acc: 0.43, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  87%|████████▋ | 87/100 [00:55<00:08,  1.61it/s]

Epoch 86: train loss: 1.279, train balanced acc: 0.40, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  88%|████████▊ | 88/100 [00:55<00:07,  1.62it/s]

Epoch 87: train loss: 1.278, train balanced acc: 0.40, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  89%|████████▉ | 89/100 [00:56<00:06,  1.62it/s]

Epoch 88: train loss: 1.278, train balanced acc: 0.40, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  90%|█████████ | 90/100 [00:57<00:06,  1.62it/s]

Epoch 89: train loss: 1.271, train balanced acc: 0.42, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  91%|█████████ | 91/100 [00:57<00:05,  1.60it/s]

Epoch 90: train loss: 1.274, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  92%|█████████▏| 92/100 [00:58<00:04,  1.60it/s]

Epoch 91: train loss: 1.276, train balanced acc: 0.42, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  93%|█████████▎| 93/100 [00:58<00:04,  1.62it/s]

Epoch 92: train loss: 1.270, train balanced acc: 0.43, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  94%|█████████▍| 94/100 [00:59<00:03,  1.60it/s]

Epoch 93: train loss: 1.278, train balanced acc: 0.42, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  95%|█████████▌| 95/100 [01:00<00:03,  1.56it/s]

Epoch 94: train loss: 1.270, train balanced acc: 0.41, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  96%|█████████▌| 96/100 [01:00<00:02,  1.58it/s]

Epoch 95: train loss: 1.273, train balanced acc: 0.40, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  97%|█████████▋| 97/100 [01:01<00:01,  1.59it/s]

Epoch 96: train loss: 1.270, train balanced acc: 0.42, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  98%|█████████▊| 98/100 [01:02<00:01,  1.58it/s]

Epoch 97: train loss: 1.280, train balanced acc: 0.41, test loss: 1.360, test balanced acc: 0.33, lr: 1.220703125e-08


Training model:  99%|█████████▉| 99/100 [01:02<00:00,  1.58it/s]

Epoch 98: train loss: 1.275, train balanced acc: 0.40, test loss: 1.361, test balanced acc: 0.33, lr: 1.220703125e-08


Training model: 100%|██████████| 100/100 [01:03<00:00,  1.58it/s]

Epoch 99: train loss: 1.278, train balanced acc: 0.41, test loss: 1.362, test balanced acc: 0.33, lr: 1.220703125e-08



