# Reqs

In [2]:
# !pip install comet_ml

In [3]:
import pandas as pd
import numpy as np
import time
import gc
import pickle

from collections import Counter
from tqdm import tqdm
from sklearn.metrics import f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from typing import Any, Dict, Union, List, Optional, Tuple
from IPython.display import Audio

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import Wav2Vec2Model, Wav2Vec2Config

from comet_ml import Experiment, init
from comet_ml.integration.pytorch import log_model, watch

from utils import seed_everything, empty_cache

device = torch.device('cuda' if torch.cuda.is_available()  else 'cpu')

In [4]:
# !gdown https://drive.google.com/drive/folders/1eGRZFSyvzbLLiY_zSmZJNuxwbzOWVZPe -O data --folder

In [5]:
base_dir = ".."
path_to_data_folder = f'{base_dir}/data_processed/words'
weights_folder = f"{base_dir}/weights"

train = pd.read_parquet(f'{path_to_data_folder}/train.parquet')
val = pd.read_parquet(f'{path_to_data_folder}/val.parquet')
test = pd.read_parquet(f'{path_to_data_folder}/test.parquet')

train

Unnamed: 0,file,text,start,end,confidence,label,data_type,subset,р_count,г_count,path
0,611b27e7-0019-4fc6-9622-21d9647c45f0.mp3,нет,0.00,0.60,0.504,1.0,first_stage_train,train,0,0,data_wav/train/611b27e7-0019-4fc6-9622-21d9647...
1,611b27e7-0019-4fc6-9622-21d9647c45f0.mp3,такого,0.80,1.04,0.990,1.0,first_stage_train,train,0,1,data_wav/train/611b27e7-0019-4fc6-9622-21d9647...
2,611b27e7-0019-4fc6-9622-21d9647c45f0.mp3,планируется,1.32,1.88,0.971,1.0,first_stage_train,train,1,0,data_wav/train/611b27e7-0019-4fc6-9622-21d9647...
3,67465147-b88c-4acd-bb91-a78340a9bde7.mp3,посмотреть,0.00,1.06,0.728,0.0,first_stage_train,train,1,0,data_wav/train/67465147-b88c-4acd-bb91-a78340a...
4,67465147-b88c-4acd-bb91-a78340a9bde7.mp3,пока,1.06,1.24,0.644,0.0,first_stage_train,train,0,0,data_wav/train/67465147-b88c-4acd-bb91-a78340a...
...,...,...,...,...,...,...,...,...,...,...,...
111405,67ee2e50-8c8a-49d8-b8f8-75cdda336930.mp3,потом,11.12,12.40,0.994,1.0,final_train,train,0,0,data_wav/train/67ee2e50-8c8a-49d8-b8f8-75cdda3...
111406,67ee2e50-8c8a-49d8-b8f8-75cdda336930.mp3,оплату,12.40,13.18,0.944,1.0,final_train,train,0,0,data_wav/train/67ee2e50-8c8a-49d8-b8f8-75cdda3...
111407,67ee2e50-8c8a-49d8-b8f8-75cdda336930.mp3,выгодную,13.32,13.84,0.677,1.0,final_train,train,0,1,data_wav/train/67ee2e50-8c8a-49d8-b8f8-75cdda3...
111408,67ee2e50-8c8a-49d8-b8f8-75cdda336930.mp3,работу,13.84,14.08,0.999,1.0,final_train,train,1,0,data_wav/train/67ee2e50-8c8a-49d8-b8f8-75cdda3...


In [6]:
with open(f"{base_dir}/data_processed/target_letters.pkl", "rb") as f:
    target_letters = pickle.load(f)

In [7]:
%%time

# with open('data/train_arrays.pkl', 'rb') as f1, open('data/val_arrays.pkl', 'rb') as f2, open('data/final_test_arrays.pkl', 'rb') as f3:
with open(f'{path_to_data_folder}/train_arrays.pkl', 'rb') as f1, open(f'{path_to_data_folder}/val_arrays.pkl', 'rb') as f2,  \
                                                                    open(f'{path_to_data_folder}/test_arrays.pkl', 'rb') as f3:
    train_arrays = pickle.load(f1)
    val_arrays = pickle.load(f2)
    test_arrays = pickle.load(f3)

CPU times: total: 7.36 s
Wall time: 7.4 s


In [8]:
def compute_class_weights_sqrt(y, degree=0.5):
    n_classes = y.nunique()
    
    weights = len(y) / (n_classes * np.bincount(y).astype(np.float64))
    weights = weights ** degree
    
    return weights

In [9]:
train['label'].value_counts(normalize=True)

label
0.0    0.633530
1.0    0.335964
2.0    0.027213
3.0    0.003293
Name: proportion, dtype: float64

In [10]:
class cfg:

    model_type = "wav2vec"
    model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-russian"
    weights_folder = weights_folder

    batch_size = 32 if torch.cuda.is_available() else 2

    max_length = 16000 * 5

    target_letters = target_letters
    letter_count_weights = {}
    letters_num_classes = {}

    disorders_class_weights = torch.tensor(compute_class_weights_sqrt(train['label'].dropna()),
                                 device=device, dtype=torch.float32)

    label_smoothing_pretrain = 0.0
    label_smoothing_train = 0.0

    linear_probing_frac = 0.1 # Часть первой эпохи, в течение которой все веса, кроме головы, замораживаются
    zero_epoch_evaluation_frac = 0.1 # На какой части данных оценивать модель перед обучением (в начале 0-й эпохи)

    head_dim = 256
    dropout = 0.25
    lr_pretrain = 1e-4
    lr_train = 1e-4
    num_epochs_pretrain = 10
    num_epochs_train = 10
    metric_computation_times_per_epoch_train = 4
    metric_computation_times_per_epoch_val = 1

    early_stopping_pretrain = 1
    early_stopping_train = 3

cfg.save_model_name = f'{cfg.model_type}'
cfg.save_model_name

'wav2vec'

# Dataset

In [12]:
def get_rare_classes(data, target_letters):
    rare_borders = {}
    for letter in target_letters:
        count_sizes = data[f"{letter}_count"].value_counts(normalize=True).to_dict()
        rare_classes = [i for i, size in count_sizes.items() if size < 0.01]
        if rare_classes:
            rare_borders[letter] = min(rare_classes) - 1
        else:
            rare_borders[letter] = max(count_sizes.keys())

    return rare_borders

In [13]:
rare_borders = get_rare_classes(train, target_letters=target_letters)
rare_borders

{'р': 2, 'г': 1}

In [14]:
letter_count_weights = {}
letters_num_classes = {}

for letter in target_letters:
    train[f"{letter}_count"] = train[f"{letter}_count"].apply(lambda x: min(x, rare_borders[letter]))
    val[f"{letter}_count"] = val[f"{letter}_count"].apply(lambda x: min(x, rare_borders[letter]))
    test[f"{letter}_count"] = test[f"{letter}_count"].apply(lambda x: min(x, rare_borders[letter]))

    letter_count_weights[letter] = torch.tensor(compute_class_weights_sqrt(train[f"{letter}_count"]),
                             device=device, dtype=torch.float32)
    letters_num_classes[letter] = train[f"{letter}_count"].nunique()

cfg.letter_count_weights = letter_count_weights
cfg.letters_num_classes = letters_num_classes

In [15]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df, arrays, target_letters=target_letters):

        # self.scale_factor = 2**15 # На что делим при переводе аудио из int во float
        df = df.reset_index(drop=True)

        self.files = df['file'].values
        self.arrays = arrays # [array/self.scale_factor for array in arrays]
        self.texts = df['text'].values
        self.target_letters = target_letters
        for letter in target_letters:
            setattr(self, f"{letter}_count", df[f"{letter}_count"])
            
        self.labels = df['label'].values

        self.text_lengths = [len(text) for text in df['text']]


    def __len__(self):
        return len(self.arrays)

    def __getitem__(self, idx):


        file = self.files[idx]
        array = self.arrays[idx]
        text = self.texts[idx]
        label = self.labels[idx]# if self.labels is not None else None


        batch = {
            'file': file,
            'input_values': array,
            'text': text,
            'label': label,
            'text_length': self.text_lengths[idx],
        }

        for letter in target_letters:
            batch[f"{letter}_counts"] = getattr(self, f"{letter}_count")[idx]

        return batch

train.fillna({"label": -100}, inplace=True)

dataset_train = CustomDataset(train, train_arrays, target_letters=target_letters)
dataset_val = CustomDataset(val, val_arrays, target_letters=target_letters)

In [16]:
class DataCollator:
    def __init__(self, cfg):
        self.cfg = cfg
        self.max_length = cfg.max_length

    def pad_arrays(self, arrays):
        max_batch_length = max(len(array) for array in arrays)
    
        arrays = torch.stack([torch.cat([array, torch.zeros(max_batch_length - len(array))]) for array in arrays])
        return arrays

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:

        arrays = [torch.tensor(feature["input_values"][:self.max_length], dtype=torch.float32) for feature in features]
        arrays = self.pad_arrays(arrays)

        labels = [feature["label"] for feature in features]

        batch = {'input_values': arrays}

        batch["labels"] = torch.tensor(labels, dtype=torch.long) if not np.isnan(labels[0]) else None

        for letter in cfg.target_letters:
            batch[f"{letter}_counts"] = torch.tensor([feature[f"{letter}_counts"] for feature in features], dtype=torch.long)
            
        batch["files"] = [feature["file"] for feature in features]

        return batch


data_collator = DataCollator(cfg=cfg)

In [17]:
dataloader_train = DataLoader(dataset_train, batch_size=cfg.batch_size, collate_fn=data_collator, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=cfg.batch_size*2, collate_fn=data_collator, shuffle=False)

In [18]:
for b in tqdm(dataloader_train):
    break
b

  0%|                                                                                         | 0/3468 [00:00<?, ?it/s]


{'input_values': tensor([[-3.7964e-02, -3.3569e-02, -2.7802e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.8127e-02,  1.2024e-02,  4.1809e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.1200e-02,  1.3336e-02,  1.0254e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [-9.1553e-04, -2.6245e-03, -3.3569e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-3.0518e-04, -9.1553e-05,  2.1362e-04,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 7.8247e-02,  7.9315e-02,  7.7087e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]),
 'labels': tensor([-100,    1,    0, -100, -100, -100,    0,    0,    1,    0,    1,    1,
            0, -100, -100,    0, -100,    0,    0,    0,    0,    0, -100,    0,
         -100,    0, -100,    1,    1, -100,    1, -100]),
 'р_counts': tensor([1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
         0, 0,

# Model

In [20]:
class DisordersDetector(nn.Module):

    def __init__(self, cfg, stage):
        super().__init__()
        self.cfg = cfg
        if stage == "pretrain":
            self.backbone = Wav2Vec2Model.from_pretrained(cfg.model_name)
        else:
            # Не подгружаем модель с HF, ведь мы подгрузим претрейн модель
            model_cfg = Wav2Vec2Config.from_pretrained(cfg.model_name)
            self.backbone = Wav2Vec2Model(model_cfg)
            
        self.stage = stage

        dropout = cfg.dropout
        hidden_dim = 1024
        head_dim = cfg.head_dim

        if stage == "pretrain":
            self.letter_count_heads = nn.ModuleDict({
                f"{letter}_count_head": nn.Sequential(
                    nn.Linear(hidden_dim, head_dim),
                    nn.Dropout(dropout),
                    nn.Linear(head_dim, cfg.letters_num_classes[letter])
                )
                for letter in cfg.target_letters
            })

            
            # for letter in cfg.target_letters:
            #     num_classes = cfg.letters_num_classes[letter]
            #     setattr(self, f"{letter}_count_head", nn.Sequential(nn.Linear(hidden_size, head_dim),
            #                                     nn.Dropout(dropout),
            #                                     nn.Linear(head_dim, num_classes)))
        else:
            self.disorders_head = nn.Sequential(nn.Linear(hidden_dim, head_dim),
                                            nn.Dropout(dropout),
                                            nn.Linear(head_dim, len(cfg.disorders_class_weights)))

    def forward(self, x):

        hidden_state = self.backbone(x).last_hidden_state
        pooled_output = torch.mean(hidden_state, dim=1)

        output = {}

        if self.stage == 'pretrain':
            # for letter in self.cfg.target_letters:
            #     head = getattr(self, f"{letter}_count_head")
            #     output[f'{letter}_count_output'] = self.g_count_head(pooled_output)
            for letter_head, head in self.letter_count_heads.items():
                output[letter_head[0]] = head(pooled_output)

        else:

            output['disorders'] = self.disorders_head(pooled_output)

        return output

    def freeze_feature_extractor(self):
        self.backbone.feature_extractor._freeze_parameters()

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
            for param in self.backbone.parameters():
                param.requires_grad = True

            if self.cfg.model_type == "wav2vec":
                self.freeze_feature_extractor()


# Functions for train

In [103]:
def get_metric_pretrain(all_predictions, all_targets, is_val=False):

    metrics = []
    
    for letter in all_predictions.keys():
        # print(all_predictions, letter)
        if len(all_predictions[letter]) == 0: continue
        predictions = np.array(all_predictions[letter]).argmax(axis=-1)
        targets = np.array(all_targets[letter])

        metric = f1_score(targets, predictions, average='macro')
        metrics.append(metric)
        # metrics['letter'] = letter
        print(f"{letter} metric\n", metric)
        if is_val:
            print(confusion_matrix(targets, predictions))

    return metrics, np.mean(metrics)


def get_metric_train(predictions, targets):
    predictions = np.array(predictions['disorders']).argmax(axis=-1)
    targets = np.array(targets['disorders'])

    metric = f1_score(targets, predictions, average='macro')

    print(confusion_matrix(targets, predictions))

    return metric

In [104]:
def model_step(model, stage, batch, cfg, predictions, targets, all_files, criterions):
    loss = 0

    if stage == 'pretrain':

        output = model(batch['input_values'].to(model.backbone.device))

        for letter in cfg.target_letters:
            loss += criterions[f"{letter}_count_head"](output[letter], batch[f'{letter}_counts'].to(model.backbone.device))

            predictions[f'{letter}_count'].extend(F.softmax(output[letter], dim=-1).detach().cpu().numpy())
            targets[f'{letter}_count'].extend(batch[f'{letter}_counts'].cpu().numpy().flatten())

        all_files.extend(np.array(batch['files']))

    else:

        batch_targets = batch['labels'].to(model.backbone.device)

        output = model(batch['input_values'].to(model.backbone.device))

        loss += criterions['disorders'](output['disorders'], batch_targets)

        predictions['disorders'].extend(F.softmax(output['disorders'], dim=-1).detach().cpu().numpy())
        targets['disorders'].extend(batch_targets.cpu().numpy().flatten())

        all_files.extend(np.array(batch['files']))


    return loss

In [105]:
def train_model(model, cfg, dataloader_train, dataloader_val, stage='pretrain'):

    if stage == 'pretrain':
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr_pretrain)
        num_epochs = cfg.num_epochs_pretrain
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr_train)
        num_epochs = cfg.num_epochs_train

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


    best_metric = -1
    best_epoch = -1

    if cfg.linear_probing_frac > 0.0:
        # Замораживаем тело в начале обучения
        model.freeze_backbone()

    num_batches = len(dataloader_train)
    unfreeze_backbone_batch = int(num_batches * min(cfg.linear_probing_frac, 1.0))

    criterions = {}
    if stage == 'pretrain':
        save_weights_name = f"{cfg.save_model_name}-pretrain"
        for letter in cfg.target_letters:
            criterions[f"{letter}_count_head"] = torch.nn.CrossEntropyLoss(weight=cfg.letter_count_weights[letter],
                                                      label_smoothing=cfg.label_smoothing_pretrain)

    elif stage == 'train':

        save_weights_name = f"{cfg.save_model_name}-train"
        
        criterions['disorders'] = torch.nn.CrossEntropyLoss(weight=cfg.disorders_class_weights,
                                                   label_smoothing=cfg.label_smoothing_train, reduction='mean')
        
    for epoch in range(num_epochs):
        if stage == "pretrain" and epoch > best_epoch + cfg.early_stopping_pretrain: # Метрика не улучшается долго
            break
        elif stage == "train" and epoch > best_epoch + cfg.early_stopping_train: # Метрика не улучшается долго
            break

        print('*'*50)
        print('EPOCH', epoch)
        print('*'*50)

        model.train()
        
        predictions = {'disorders': []}
        targets = {'disorders': []}

        for letter in cfg.target_letters:
            predictions[f'{letter}_count'] = []
            targets[f'{letter}_count'] = []
            

        all_files = []

        for batch_idx, batch in enumerate(tqdm(dataloader_train, total=len(dataloader_train), desc='Training')):


            if sanity_checking and batch_idx  >= 2: break
            if epoch == 0 and cfg.linear_probing_frac > 0.0 and batch_idx == unfreeze_backbone_batch:
                model.unfreeze_backbone()

            optimizer.zero_grad()

            loss = model_step(model, stage, batch, cfg, predictions,
                                              targets, all_files,
                                              criterions)

            experiment.log_metric("loss", loss, step=batch_idx + (num_batches * epoch))

            del batch
            torch.cuda.empty_cache()

            loss.backward()
            optimizer.step()


            if batch_idx > 0 and batch_idx % max(1, len(dataloader_train) // cfg.metric_computation_times_per_epoch_train) == 0:
                
                if stage == 'pretrain':
                    metrics, mean_metric = get_metric_pretrain(predictions, targets, is_val=False)
                    for letter, metric in zip(cfg.target_letters, metrics):
                        experiment.log_metric(f"f1_{letter}_count", metric, step=batch_idx + (num_batches * epoch))
                        
                    experiment.log_metric("f1_mean", mean_metric, step=batch_idx + (num_batches * epoch))
                    metric = mean_metric
                else:
                    metric = get_metric_train(predictions, targets)
                    experiment.log_metric("f1_disorders", metric, step=batch_idx + (num_batches * epoch))

                print(metric)


            
            if (batch_idx == 0 and epoch == 0):
                metric = evaluate(model, cfg, dataloader_val, criterions, stage=stage, is_beggining=True, step_idx=batch_idx + (num_batches*epoch))
                model.train()

            elif (batch_idx + 1) % (len(dataloader_train) // cfg.metric_computation_times_per_epoch_val) == 0:
                metric = evaluate(model, cfg, dataloader_val, criterions, stage=stage, step_idx=batch_idx + (num_batches*epoch))
                model.train()
                
                if metric > best_metric:
                    best_metric = metric
                    best_epoch = epoch
                    torch.save(model.state_dict(), f'{cfg.weights_folder}/{save_weights_name}.pt')


        if sanity_checking:
            break

        scheduler.step()


In [106]:
def evaluate(model, cfg, dataloader_val, criterions, is_beggining=False, limit_num_batches=-1, stage='pretrain', step_idx=0):



    num_batches = len(dataloader_val)
    zero_epoch_evaluation_batches = np.ceil(num_batches * min(cfg.zero_epoch_evaluation_frac, 1.0))

    model.eval()

    predictions = {'disorders': []}
    targets = {'disorders': []}

    for letter in cfg.target_letters:
        predictions[f'{letter}_count'] = []
        targets[f'{letter}_count'] = []

    all_files = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader_val):
            if sanity_checking and batch_idx  == 5: break
            if batch_idx == limit_num_batches: break # Оценивал по 50 батчам, т.к. иначе времени бы много занимало, можно исправить и оценивать по фулл даталоадеру, но раз в эпоху
            if is_beggining and batch_idx == zero_epoch_evaluation_batches: break

            loss = model_step(model, stage, batch, cfg, 
                              predictions, targets, 
                              all_files, criterions)


            del batch
            torch.cuda.empty_cache()

    print('\n VALIDATION')
    print('=*'*50)
    
    if stage == 'pretrain':
        metrics, mean_metric = get_metric_pretrain(predictions, targets, is_val=True)
        
        for letter, metric in zip(cfg.target_letters, metrics):
            experiment.log_metric(f"val_f1_{letter}_count", metric, step=step_idx)
            
        experiment.log_metric("val_f1_mean", mean_metric, step=step_idx)

        metric = mean_metric
    else:
        metric = get_metric_train(predictions, targets)
        experiment.log_metric("val_f1_disorders", metric, step=step_idx)

    print(metric)
    print('=*'*50)

    return metric

In [107]:
sanity_checking = False

# PreTraining

In [28]:
experiment = Experiment(
  api_key="rpI0PuxxYkKMtiy42g1oIfLI1",
  project_name="aiijc-final-pretrain",
  workspace="ugryumnik"
)

experiment.log_parameters(dict(vars(cfg)))

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/ugryumnik/aiijc-final-pretrain/312cb9acf6a942fd810d67ca5b2acf82

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in 'C:\\Users\\User\\jupyter_lab\\aiijc_tbank\\train_notebooks' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


In [29]:
empty_cache()

seed_everything(42)

model = DisordersDetector(cfg=cfg, stage='pretrain')
model.to(device)
if cfg.model_type == "wav2vec":
    model.freeze_feature_extractor()

In [30]:
len(dataset_train)

110974

In [31]:
train_model(model, cfg, dataloader_train, dataloader_val, stage='pretrain')

**************************************************
EPOCH 0
**************************************************


Training:   0%|                                                                     | 2/3468 [00:02<1:04:00,  1.11s/it]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
р_count metric
 0.30920009822963407
[[236   6   0]
 [ 69   3   0]
 [  6   0   0]]
г_count metric
 0.477124183006536
[[292   0]
 [ 28   0]]
0.393162140618085
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training:  25%|█████████████████▎                                                   | 868/3468 [04:08<20:47,  2.08it/s]

р_count metric
 0.4062452101035658
г_count metric
 0.5063225053557862
0.45628385772967595


Training:  50%|██████████████████████████████████                                  | 1735/3468 [09:23<13:53,  2.08it/s]

р_count metric
 0.506941099814998
г_count metric
 0.5962561732652625
0.5515986365401302


Training:  75%|███████████████████████████████████████████████████                 | 2602/3468 [14:38<06:13,  2.32it/s]

р_count metric
 0.5545223317861101
г_count metric
 0.6324155262431461
0.5934689290146281


Training: 100%|███████████████████████████████████████████████████████████████████▉| 3467/3468 [19:55<00:00,  3.13it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
р_count metric
 0.7660659491158719
[[1898  126   18]
 [ 176  582   17]
 [  10   10   41]]
г_count metric
 0.7647133298023441
[[2444  180]
 [  77  177]]
0.765389639459108
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|████████████████████████████████████████████████████████████████████| 3468/3468 [20:14<00:00,  2.86it/s]


**************************************************
EPOCH 1
**************************************************


Training:  25%|█████████████████▎                                                   | 868/3468 [05:07<19:21,  2.24it/s]

р_count metric
 0.6851089952851647
г_count metric
 0.7273139108021558
0.7062114530436603


Training:  50%|██████████████████████████████████                                  | 1735/3468 [10:18<10:03,  2.87it/s]

р_count metric
 0.691678778226183
г_count metric
 0.7348482957246811
0.7132635369754321


Training:  75%|███████████████████████████████████████████████████                 | 2602/3468 [15:37<06:15,  2.31it/s]

р_count metric
 0.6959897015022348
г_count metric
 0.7398648965058336
0.7179272990040342


Training: 100%|███████████████████████████████████████████████████████████████████▉| 3467/3468 [21:01<00:00,  2.84it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
р_count metric
 0.8075735779660694
[[1947   88    7]
 [ 147  598   30]
 [   8    4   49]]
г_count metric
 0.8012781336310748
[[2524  100]
 [  88  166]]
0.8044258557985722
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|████████████████████████████████████████████████████████████████████| 3468/3468 [21:20<00:00,  2.71it/s]


**************************************************
EPOCH 2
**************************************************


Training:  25%|█████████████████▎                                                   | 868/3468 [05:13<13:31,  3.20it/s]

р_count metric
 0.7480940529018283
г_count metric
 0.7869540300670861
0.7675240414844572


Training:  50%|██████████████████████████████████                                  | 1735/3468 [10:16<14:32,  1.99it/s]

р_count metric
 0.7421788901303765
г_count metric
 0.7828874944755628
0.7625331923029697


Training:  75%|███████████████████████████████████████████████████                 | 2602/3468 [15:38<04:59,  2.89it/s]

р_count metric
 0.7425995247793922
г_count metric
 0.7833372952703059
0.762968410024849


Training: 100%|███████████████████████████████████████████████████████████████████▉| 3467/3468 [21:12<00:00,  2.49it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
р_count metric
 0.84571784109392
[[1900  140    2]
 [ 108  658    9]
 [   5   11   45]]
г_count metric
 0.8287225468268122
[[2562   62]
 [  89  165]]
0.8372201939603661
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|████████████████████████████████████████████████████████████████████| 3468/3468 [21:32<00:00,  2.68it/s]


**************************************************
EPOCH 3
**************************************************


Training:  25%|█████████████████▎                                                   | 868/3468 [05:36<15:16,  2.84it/s]

р_count metric
 0.7811660694532504
г_count metric
 0.8159817587953467
0.7985739141242986


Training:  50%|██████████████████████████████████                                  | 1735/3468 [11:04<11:45,  2.46it/s]

р_count metric
 0.7782818660486367
г_count metric
 0.8125232452555559
0.7954025556520963


Training:  75%|███████████████████████████████████████████████████                 | 2602/3468 [16:44<05:53,  2.45it/s]

р_count metric
 0.7795350722791078
г_count metric
 0.808774153448008
0.7941546128635579


Training: 100%|███████████████████████████████████████████████████████████████████▉| 3467/3468 [22:25<00:00,  2.26it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
р_count metric
 0.8484981445921239
[[1902  139    1]
 [ 100  665   10]
 [   6   10   45]]
г_count metric
 0.8285940196407633
[[2535   89]
 [  74  180]]
0.8385460821164437
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|████████████████████████████████████████████████████████████████████| 3468/3468 [22:44<00:00,  2.54it/s]


**************************************************
EPOCH 4
**************************************************


Training:  25%|█████████████████▎                                                   | 868/3468 [05:28<18:57,  2.29it/s]

р_count metric
 0.8107262261233833
г_count metric
 0.8303043667698173
0.8205152964466003


Training:  50%|██████████████████████████████████                                  | 1735/3468 [11:06<16:38,  1.74it/s]

р_count metric
 0.8019005456650415
г_count metric
 0.8342780442239921
0.8180892949445169


Training:  75%|███████████████████████████████████████████████████                 | 2602/3468 [16:42<07:18,  1.98it/s]

р_count metric
 0.8028603798906789
г_count metric
 0.8346591990638679
0.8187597894772733


Training: 100%|████████████████████████████████████████████████████████████████████| 3468/3468 [23:01<00:00,  2.51it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
р_count metric
 0.8173635275414481
[[1897  134   11]
 [  85  661   29]
 [   5    5   51]]
г_count metric
 0.8339517131191194
[[2542   82]
 [  74  180]]
0.8256576203302838
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*





In [32]:
log_model(experiment, model, "pretrain-stage")
experiment.end()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : worthy_mammal_343
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/ugryumnik/aiijc-final-pretrain/312cb9acf6a942fd810d67ca5b2acf82
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     f1_mean [15]       : (0.45628385772967595, 0.8205152964466003)
[1;38;5;39mCOMET INFO:[0m     f1_г_count [15]    : (0.5063225053557862, 0.8346591990638679)
[1;38;5;39mCOMET INFO:[0m     f1_р_count [15]    : (0.4062452101035658, 0.8107262261233833)
[1;38;5;39mCOMET INFO:[0m     loss [17340]       : (0.10222

# Training

## Data Cleaning
Немного очистим от ненужных для обучения данных

In [35]:
train = train.reset_index(drop=True)
train.shape

(110974, 11)

In [36]:
train = train[train['label'] != -100] # Удаляем данные, для которых нет таргета
train.shape

(82607, 11)

In [37]:
classes_count = train['label'].value_counts(normalize=True)
rare_classes = classes_count[classes_count < 0.05].index.values
rare_classes

array([2., 3.])

In [38]:
target_letters_count = train[[f"{letter}_count" for letter in target_letters]].sum(axis=1)


ids_to_drop = train[(target_letters_count == 0) & (~train['label'].isin(rare_classes))].index
ids_to_drop, _ = train_test_split(ids_to_drop, test_size=0.05, random_state=42)


train.drop(index=ids_to_drop, inplace=True)
train.shape

(31575, 11)

In [39]:
train_arrays = [train_arrays[i] for i in train.index]
train.reset_index(drop=True, inplace=True)
assert train.shape[0] == len(train_arrays)

In [40]:
dataset_train = CustomDataset(train, train_arrays, target_letters=target_letters)
dataset_val = CustomDataset(val, val_arrays, target_letters=target_letters)
dataset_test = CustomDataset(test, test_arrays, target_letters=target_letters)


dataloader_train = DataLoader(dataset_train, batch_size=cfg.batch_size, collate_fn=data_collator, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=cfg.batch_size*2, collate_fn=data_collator, shuffle=False)
dataloader_test = DataLoader(dataset_test, batch_size=cfg.batch_size*2, collate_fn=data_collator, shuffle=False)

## Теперь обучим модель

In [42]:
experiment = Experiment(
  api_key="rpI0PuxxYkKMtiy42g1oIfLI1",
  project_name="aiijc-final-train",
  workspace="ugryumnik"
)

experiment.log_parameters(dict(vars(cfg)))

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/ugryumnik/aiijc-final-train/c8ed27196b344832adb5dc7c854b785f

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in 'C:\\Users\\User\\jupyter_lab\\aiijc_tbank\\train_notebooks' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


In [130]:
empty_cache()

seed_everything(42)

dataloader_train = DataLoader(dataset_train, batch_size=cfg.batch_size, collate_fn=data_collator, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=cfg.batch_size*2, collate_fn=data_collator, shuffle=False)

model = DisordersDetector(cfg=cfg, stage='train')
model.to(device)

current_model_dict = model.state_dict()
loaded_state_dict = torch.load(f"{cfg.weights_folder}/{cfg.save_model_name}-pretrain.pt", map_location=device, weights_only=True)

for k in current_model_dict.keys():
    if k in loaded_state_dict.keys():
        current_model_dict[k] = loaded_state_dict[k]

del loaded_state_dict
empty_cache()

model.load_state_dict(current_model_dict, strict=False)

if cfg.model_type == "wav2vec":
    model.freeze_feature_extractor()

In [131]:
train_model(model, cfg, dataloader_train, dataloader_val, stage='train')

**************************************************
EPOCH 0
**************************************************


Training:   0%|                                                                        | 1/987 [00:02<38:12,  2.33s/it]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[155  18   0  20]
 [ 87  10   0  11]
 [ 13   3   0   3]
 [  0   0   0   0]]
0.20896229445015416
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training:  25%|█████████████████▌                                                    | 247/987 [01:15<04:48,  2.57it/s]

[[4198   77  409   57]
 [2249   39  200   36]
 [ 358    4  181    6]
 [  53    3   34    0]]
0.25424476338013746


Training:  50%|██████████████████████████████████▉                                   | 493/987 [02:53<03:32,  2.32it/s]

[[8207  342  899   62]
 [4302  202  491   38]
 [ 655   30  384   10]
 [  73    5   74    2]]
0.26775013606302556


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:32<01:42,  2.43it/s]

[[11740   908  1523    63]
 [ 5919   686   902    38]
 [  893    85   659    10]
 [  102     9   109     2]]
0.286697496232548


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:07<00:00,  2.93it/s]

[[15044  1723  2132    63]
 [ 7208  1489  1309    38]
 [ 1086   181   965    10]
 [  116    17   137     2]]
0.3063855942451737


Training: 100%|█████████████████████████████████████████████████████████████████████▉| 986/987 [06:07<00:00,  2.82it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[894 195 670   0]
 [338 181 358   0]
 [ 63  18 109   0]
 [ 16   9  27   0]]
0.25655683869184803
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:26<00:00,  2.56it/s]


**************************************************
EPOCH 1
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:36<05:02,  2.44it/s]

[[3394  848  543   10]
 [1207  943  325   11]
 [ 188   74  278   13]
 [   8    9   48    5]]
0.3875612153898521


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:19<03:40,  2.24it/s]

[[6600 1822 1052   21]
 [2375 1992  624   24]
 [ 354  157  603   18]
 [  19   21   89    5]]
0.38423110276506867


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:59<01:49,  2.26it/s]

[[9966 2680 1539   44]
 [3605 3011  888   39]
 [ 531  226  889   28]
 [  38   30  117   17]]
0.40004563801016435


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:59<00:00,  2.44it/s]

[[13320  3534  2031    80]
 [ 4773  4000  1204    64]
 [  710   289  1203    40]
 [   49    51   147    25]]
0.401843140139002


Training: 100%|█████████████████████████████████████████████████████████████████████▉| 986/987 [06:59<00:00,  2.77it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[930 180 639  10]
 [348 221 301   7]
 [ 60  18 111   1]
 [ 12  16  21   3]]
0.2983140959534336
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [07:19<00:00,  2.25it/s]


**************************************************
EPOCH 2
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:38<05:43,  2.15it/s]

[[3349  895  463   38]
 [1078 1194  261   36]
 [ 129   56  318   19]
 [   7   10   31   20]]
0.4612101033461853


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:17<03:21,  2.45it/s]

[[6650 1796  925   95]
 [2119 2393  502   69]
 [ 277  113  648   52]
 [  15   15   62   45]]
0.4636905520501708


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:53<01:36,  2.56it/s]

[[10041  2656  1401   126]
 [ 3211  3508   743    97]
 [  451   155   976    79]
 [   20    21    93    70]]
0.4667267314026739


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:33<00:00,  2.23it/s]

[[13442  3427  1927   164]
 [ 4367  4556  1000   125]
 [  582   201  1365    92]
 [   28    25   118   101]]
0.4741886089480144


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:51<00:00,  2.40it/s]



 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[708 428 522 101]
 [228 323 232  94]
 [ 47  33  93  17]
 [  7  13  17  15]]
0.2962489749273696
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
**************************************************
EPOCH 3
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:38<05:45,  2.14it/s]

[[3445  875  396   47]
 [ 996 1291  201   23]
 [ 115   28  391   22]
 [   4    7   16   47]]
0.557397938524423


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:14<03:28,  2.36it/s]

[[6830 1756  795   85]
 [2038 2596  390   48]
 [ 221   78  766   39]
 [   9   11   32   82]]
0.5501418417676339


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:53<01:39,  2.48it/s]

[[10276  2605  1181   140]
 [ 3065  3827   589    80]
 [  342   107  1173    59]
 [   16    16    50   122]]
0.5460143364987211


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:31<00:00,  2.21it/s]

[[13797  3410  1598   164]
 [ 4109  5054   772   104]
 [  465   144  1553    79]
 [   21    18    66   166]]
0.5505376946727866


Training: 100%|█████████████████████████████████████████████████████████████████████▉| 986/987 [06:32<00:00,  1.80it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[746 520 444  49]
 [242 399 193  43]
 [ 42  41  91  16]
 [ 11  21  11   9]]
0.31619241961152206
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:50<00:00,  2.40it/s]


**************************************************
EPOCH 4
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:36<04:16,  2.88it/s]

[[3631  805  308   26]
 [ 979 1402  134   17]
 [  79   29  409   13]
 [   1    3   14   54]]
0.6337893145824182


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:15<03:18,  2.49it/s]

[[7180 1656  620   54]
 [1953 2732  280   32]
 [ 176   57  865   30]
 [   3    4   23  111]]
0.6344805562180286


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:52<01:37,  2.54it/s]

[[10744  2487   949    79]
 [ 2966  4056   427    56]
 [  274    82  1278    49]
 [    9     7    34   151]]
0.6206481917257269


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:33<00:01,  1.89it/s]

[[14250  3340  1268   105]
 [ 3914  5467   585    76]
 [  345   109  1723    67]
 [   13     8    43   207]]
0.6237027726392196


Training: 100%|█████████████████████████████████████████████████████████████████████▉| 986/987 [06:34<00:00,  1.96it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[1057  268  410   24]
 [ 368  266  229   14]
 [  54   27  105    4]
 [  17   12   17    6]]
0.33929477664749597
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:52<00:00,  2.39it/s]


**************************************************
EPOCH 5
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:39<04:28,  2.76it/s]

[[3570  860  298   22]
 [ 890 1500  137   19]
 [  64   20  453    4]
 [   1    2    8   56]]
0.6689310970505578


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:19<03:46,  2.18it/s]

[[7116 1690  583   48]
 [1808 3005  246   35]
 [ 134   31  930   15]
 [   4    3   14  114]]
0.6708475180659781


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:53<01:26,  2.88it/s]

[[10921  2436   845    76]
 [ 2720  4366   358    50]
 [  204    54  1391    21]
 [    7     4    19   176]]
0.6749374292652399


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:30<00:00,  2.80it/s]

[[14477  3237  1145   102]
 [ 3611  5878   491    65]
 [  275    81  1855    32]
 [    9     5    26   231]]
0.6723861075547751


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:48<00:00,  2.42it/s]



 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[927 267 550  15]
 [334 275 256  12]
 [ 51  27 109   3]
 [ 19  14  16   3]]
0.31044244699211665
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
**************************************************
EPOCH 6
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:36<04:11,  2.94it/s]

[[3754  717  214   17]
 [ 898 1547   97   12]
 [  57   13  505    2]
 [   0    2    4   65]]
0.7348482063275527


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:13<03:43,  2.21it/s]

[[7411 1491  489   38]
 [1751 3066  203   27]
 [ 122   34  985   11]
 [   2    3    6  137]]
0.7204766942116169


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:52<01:32,  2.67it/s]

[[11294  2196   719    58]
 [ 2596  4536   288    40]
 [  174    45  1473    15]
 [    6     4    11   193]]
0.7195561813217024


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:31<00:00,  2.71it/s]

[[14916  3029   935    79]
 [ 3449  6165   381    52]
 [  227    63  1931    23]
 [    8     7    15   240]]
0.7142686861316726


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:49<00:00,  2.41it/s]



 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[1051  343  347   18]
 [ 377  317  174    9]
 [  70   31   85    4]
 [  18   16   14    4]]
0.33597878812990245
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
**************************************************
EPOCH 7
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:39<05:24,  2.28it/s]

[[3670  811  180   24]
 [ 801 1712   68    9]
 [  49   14  494    5]
 [   3    3    3   58]]
0.7330413432905468


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:21<03:08,  2.62it/s]

[[7466 1554  352   34]
 [1576 3340  151   19]
 [  87   31 1021    7]
 [   3    3    4  128]]
0.7553681945111712


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:58<01:49,  2.26it/s]

[[11345  2276   525    46]
 [ 2377  4933   227    27]
 [  137    42  1508    12]
 [    4     4     8   177]]
0.7532074088404356


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:32<00:00,  2.41it/s]

[[15167  2997   726    69]
 [ 3171  6526   311    39]
 [  194    52  1982    16]
 [    8     7    11   244]]
0.747591183948346


Training: 100%|█████████████████████████████████████████████████████████████████████▉| 986/987 [06:32<00:00,  2.24it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[1186  269  285   19]
 [ 447  263  153   14]
 [  75   32   79    4]
 [  22   15   10    5]]
0.3419650742959768
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:51<00:00,  2.40it/s]


**************************************************
EPOCH 8
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:37<04:14,  2.91it/s]

[[3838  683  159   16]
 [ 733 1719   69   14]
 [  37   18  542    8]
 [   1    0    4   63]]
0.7626041076740562


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:14<04:39,  1.76it/s]

[[7668 1391  307   31]
 [1472 3468  124   23]
 [  78   23 1038   11]
 [   3    1    5  133]]
0.7732293801440168


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [04:50<01:40,  2.46it/s]

[[11536  2110   462    46]
 [ 2205  5178   195    31]
 [  121    33  1504    16]
 [    7     2     8   194]]
0.769676515647681


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:26<00:00,  2.88it/s]

[[15528  2769   598    65]
 [ 2983  6773   247    40]
 [  159    41  2026    19]
 [    9     2    11   250]]
0.769889424387192


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:44<00:00,  2.44it/s]



 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[937 489 325   8]
 [323 376 172   6]
 [ 57  47  85   1]
 [ 21  16  12   3]]
0.33106830758916755
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
**************************************************
EPOCH 9
**************************************************


Training:  25%|█████████████████▌                                                    | 247/987 [01:36<04:26,  2.77it/s]

[[3919  674  153   18]
 [ 695 1746   58   11]
 [  38    6  515    5]
 [   4    1    2   59]]
0.7696808919820314


Training:  50%|██████████████████████████████████▉                                   | 493/987 [03:16<03:27,  2.38it/s]

[[7829 1325  291   28]
 [1390 3504  109   26]
 [  76   14 1060    6]
 [   6    1    3  108]]
0.7771386111652563


Training:  75%|████████████████████████████████████████████████████▍                 | 739/987 [05:04<01:34,  2.64it/s]

[[11723  1979   450    34]
 [ 2079  5292   165    32]
 [  111    25  1559     8]
 [    7     3     3   178]]
0.7883328444820222


Training: 100%|█████████████████████████████████████████████████████████████████████▊| 985/987 [06:38<00:00,  2.60it/s]

[[15618  2690   613    44]
 [ 2812  6971   220    38]
 [  145    35  2052    10]
 [   10     4     5   253]]
0.7894705562887242


Training: 100%|█████████████████████████████████████████████████████████████████████▉| 986/987 [06:38<00:00,  2.81it/s]


 VALIDATION
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
[[1083  359  301   16]
 [ 384  309  177    7]
 [  64   34   90    2]
 [  21   19    8    4]]
0.3436672050788729
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


Training: 100%|██████████████████████████████████████████████████████████████████████| 987/987 [06:56<00:00,  2.37it/s]


In [142]:
log_model(experiment, model, "train-stage")
experiment.end()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : distant_barracuda_917
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/ugryumnik/aiijc-final-train/c8ed27196b344832adb5dc7c854b785f
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     f1_disorders [40]     : (0.25424476338013746, 0.7894705562887242)
[1;38;5;39mCOMET INFO:[0m     loss [10611]          : (0.09200619906187057, 2.7596871852874756)
[1;38;5;39mCOMET INFO:[0m     val_f1_disorders [14] : (0.20896229445015416, 0.3436672050788729)
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39

In [143]:
cfg_dict = {attr: getattr(cfg, attr) for attr in dir(cfg) if not callable(getattr(cfg, attr)) and not attr.startswith("__")}

with open(f"{base_dir}/data_processed/cfg.pkl", 'wb') as f:
    pickle.dump(cfg_dict, f)

# Inference

In [144]:
empty_cache()

model = DisordersDetector(cfg=cfg, stage='train')
model.to(device)

model.load_state_dict(torch.load(f"{cfg.weights_folder}/{cfg.save_model_name}-train.pt", map_location=device, weights_only=True), strict=False)
model = model.eval()

In [145]:
def get_predictions(model, dataloader):
    model.eval()
    all_predictions = []

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):

            output = model(batch['input_values'].to(device))

            all_predictions.extend(F.softmax(output['disorders'], dim=-1).cpu().numpy())
            
            del batch
            torch.cuda.empty_cache()

    return all_predictions

In [146]:
all_predictions_val = get_predictions(model, dataloader_val)

with open(f"{base_dir}/preds/all_predictions_val.pkl", "wb") as f:
    pickle.dump(all_predictions_val, f)

100%|██████████████████████████████████████████████████████████████████████████████████| 45/45 [00:17<00:00,  2.63it/s]


In [147]:
all_predictions_test = get_predictions(model, dataloader_test)

with open(f"{base_dir}/preds/all_predictions_test.pkl", "wb") as f:
    pickle.dump(all_predictions_test, f)

100%|██████████████████████████████████████████████████████████████████████████████████| 72/72 [00:27<00:00,  2.59it/s]
