In [3]:
!pip install wfdb

Collecting wfdb
  Downloading wfdb-4.2.0-py3-none-any.whl.metadata (3.7 kB)
Downloading wfdb-4.2.0-py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.3/162.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: wfdb
Successfully installed wfdb-4.2.0


In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys, os
import scipy.io
import scipy.signal as signal
import pickle as dill
from tqdm import tqdm
from time import localtime, strftime
import random

from shutil import copyfile

from sklearn.metrics import log_loss
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split


from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence

from tensorflow.keras.preprocessing.sequence import pad_sequences  # Для выравнивания длин сигналов
import wfdb
import dill
from glob import glob
import csv

from collections import OrderedDict, Counter

import scipy.io
from scipy.signal import butter, lfilter, periodogram

In [5]:
df = pd.read_csv("/kaggle/input/dataset-mina/CPSC/labels.csv")
df.head(20)

Unnamed: 0,patient_id,SNR,AF,IAVB,LBBB,RBBB,PAC,PVC,STD,STE,fold
0,A0001,0,0,0,0,1,0,0,0,0,5
1,A0002,1,0,0,0,0,0,0,0,0,3
2,A0003,0,1,0,0,0,0,0,0,0,1
3,A0004,0,1,0,0,0,0,0,0,0,7
4,A0005,0,0,0,0,0,0,1,0,0,1
5,A0006,0,0,0,0,1,0,0,0,0,10
6,A0007,0,1,0,0,0,0,0,0,0,3
7,A0008,0,0,0,0,0,0,0,1,0,6
8,A0009,0,1,0,0,0,0,0,0,0,2
9,A0010,0,0,0,0,1,0,0,0,0,10


In [14]:
def preprocess_physionet(data_path, output_path='/kaggle/working/', max_length=9000):
    """
    Обрабатывает PhysioNet ECG Dataset и сохраняет в .pkl
    :param data_path: путь к данным
    :param max_length: длина для padding/truncating сигналов
    """

    # Читаем метки из REFERENCE-v3.csv
    label_df = pd.read_csv(os.path.join(data_path, 'labels.csv'))
    
    # Удаляем 'patient_id', оставляем бинарные метки
    labels = label_df.drop(columns=['patient_id']).values
    print(f"Загружено {labels.shape[0]} меток, {labels.shape[1]} классов.")

    # Читаем список файлов
    labels = label_df.iloc[:, 1].values
    filenames = label_df.iloc[:, 0].values
    print(f"Файлы: {filenames[:5]}")  # Вывод первых 5 файлов

    all_data = []
    for filename in tqdm(filenames, desc="Чтение .mat файлов"):
        mat = scipy.io.loadmat(os.path.join(data_path, f'{filename}.mat'))
        mat = np.array(mat['val'])[0]  # Берем только первый канал
        all_data.append(mat)

    # Приведение всех последовательностей к одинаковой длине
    all_data = pad_sequences(all_data, maxlen=max_length, padding='post', truncating='post')

    # Сохраняем данные и метки
    res = {'data': all_data, 'label': labels}
    with open(os.path.join(output_path, 'challenge2018.pkl'), 'wb') as fout:
        dill.dump(res, fout)

    print(f"Файл сохранен: {os.path.join(output_path, 'challenge2018.pkl')}")


In [None]:
'''def preprocess_physionet(data_path, output_path='/kaggle/working/'):
    """
    Перед обработкой данных скачайте их с https://physionet.org/content/challenge-2017/1.0.0/ 
    и поместите в data_path.
    """

    label_df = pd.read_csv(os.path.join(data_path, 'labels.csv'))
    
    # Удаляем столбец 'patient_id' и сохраняем метки в виде массива
    label = label_df.drop(columns=['patient_id']).values
    print(f"Метки загружены: {label.shape}")

    # Читаем список файлов
    labels = label_df.iloc[:, 1].values  # Categories: N, A, O, P
    filenames = label_df.iloc[:, 0].values
    print(f"Файлы: {filenames[:5]}")  # Вывод первых 5 файлов

    all_data = []
    for filename in tqdm(filenames, desc="Чтение .mat файлов"):
        mat = scipy.io.loadmat(os.path.join(data_path, f'{filename}.mat'))
        mat = np.array(mat['val'])[0]
        all_data.append(mat)

    all_data = np.array(all_data)

    # Сохраняем данные и метки
    res = {'data': all_data, 'label': label}
    with open(os.path.join(output_path, 'challenge2018.pkl'), 'wb') as fout:
        dill.dump(res, fout)

    print(f"Файл сохранен: {os.path.join(output_path, 'challenge2018.pkl')}")'''

In [19]:
def filter_channel(x):
    
    signal_freq = 300
    
    ### candidate channels for ECG
    P_wave = (0.67,5)
    QRS_complex = (10,50)
    T_wave = (1,7)
    muscle = (5,50)
    resp = (0.12,0.5)
    ECG_preprocessed = (0.5, 50)
    wander = (0.001, 0.5)
    noise = 50
    
    ### use low (wander), middle (ECG_preprocessed) and high (noise) for example
    bandpass_list = [wander, ECG_preprocessed]
    highpass_list = [noise]
    
    nyquist_freq = 0.5 * signal_freq
    filter_order = 1
    ### out including original x
    out_list = [x]
    
    for bandpass in bandpass_list:
        low = bandpass[0] / nyquist_freq
        high = bandpass[1] / nyquist_freq
        b, a = butter(filter_order, [low, high], btype="band")
        y = lfilter(b, a, x)
        out_list.append(y)
        
    for highpass in highpass_list:
        high = highpass / nyquist_freq
        b, a = butter(filter_order, high, btype="high")
        y = lfilter(b, a, x)
        out_list.append(y)
        
    out = np.array(out_list)
    
    return out

def slide_and_cut(X, Y, window_size, stride, output_pid=False):
    out_X = []
    out_Y = []
    out_pid = []
    n_sample = X.shape[0]
    mode = 0
    for i in range(n_sample):
        tmp_ts = X[i]
        tmp_Y = Y[i]
        if tmp_Y == 0:
            i_stride = stride
        elif tmp_Y == 1:
            i_stride = stride//10
        for j in range(0, len(tmp_ts)-window_size, i_stride):
            out_X.append(tmp_ts[j:j+window_size])
            out_Y.append(tmp_Y)
            out_pid.append(i)
    if output_pid:
        return np.array(out_X), np.array(out_Y), np.array(out_pid)
    else:
        return np.array(out_X), np.array(out_Y)

def compute_beat(X):
    out = np.zeros((X.shape[0], X.shape[1], X.shape[2]))
    for i in tqdm(range(out.shape[0]), desc="compute_beat"):
        for j in range(out.shape[1]):
            out[i, j] = np.concatenate([[0], np.abs(np.diff(X[i,j,:]))])
    return out

def compute_rhythm(X, n_split):
    cnt_split = int(X.shape[2]/n_split)
    out = np.zeros((X.shape[0], X.shape[1], cnt_split))
    for i in tqdm(range(out.shape[0]), desc="compute_rhythm"):
        for j in range(out.shape[1]):
            tmp_ts = X[i,j,:]
            tmp_ts_cut = np.split(tmp_ts, X.shape[2]/n_split)
            for k in range(cnt_split):
                out[i, j, k] = np.std(tmp_ts_cut[k])
    return out

def compute_freq(X):
    out = np.zeros((X.shape[0], X.shape[1], 1))
    fs = 300
    for i in tqdm(range(out.shape[0]), desc="compute_freq"):
        for j in range(out.shape[1]):
            _, Pxx_den = periodogram(X[i,j,:], fs)
            out[i, j, 0] = np.sum(Pxx_den)
    return out

def make_data_physionet(data_path, n_split=50, window_size=3000, stride=500, output_path='/kaggle/working/'):

    # read pkl
    with open(os.path.join(output_path, 'challenge2018.pkl'), 'rb') as fin:
        res = dill.load(fin)
    ## scale data
    all_data = res['data']
    for i in range(len(all_data)):
        tmp_data = all_data[i]
        tmp_std = np.std(tmp_data)
        tmp_mean = np.mean(tmp_data)
        all_data[i] = (tmp_data - tmp_mean) / tmp_std # normalize
    all_data = res['data']
    all_data = np.array(all_data)
    ## encode label
    all_label = []
    for i in res['label']:
        if i == 'A':
            all_label.append(1)
        else:
            all_label.append(0)
    all_label = np.array(all_label)

    # split train test
    n_sample = len(all_label)
    split_idx_1 = int(0.75 * n_sample)
    split_idx_2 = int(0.85 * n_sample)
    
    shuffle_idx = np.random.permutation(n_sample)
    all_data = all_data[shuffle_idx]
    all_label = all_label[shuffle_idx]
    
    X_train = all_data[:split_idx_1]
    X_val = all_data[split_idx_1:split_idx_2]
    X_test = all_data[split_idx_2:]
    Y_train = all_label[:split_idx_1]
    Y_val = all_label[split_idx_1:split_idx_2]
    Y_test = all_label[split_idx_2:]
    
    # slide and cut
    print(Counter(Y_train), Counter(Y_val), Counter(Y_test))
    X_train, Y_train = slide_and_cut(X_train, Y_train, window_size=window_size, stride=stride)
    X_val, Y_val = slide_and_cut(X_val, Y_val, window_size=window_size, stride=stride)
    X_test, Y_test, pid_test = slide_and_cut(X_test, Y_test, window_size=window_size, stride=stride, output_pid=True)
    print('after: ')
    print(Counter(Y_train), Counter(Y_val), Counter(Y_test))
    
    # shuffle train
    shuffle_pid = np.random.permutation(Y_train.shape[0])
    X_train = X_train[shuffle_pid]
    Y_train = Y_train[shuffle_pid]

    # multi-level
    X_train_ml = []
    X_val_ml = []
    X_test_ml = []
    for i in tqdm(X_train, desc="X_train_ml"):
        tmp = filter_channel(i)
        X_train_ml.append(tmp)
    X_train_ml = np.array(X_train_ml)
    for i in tqdm(X_val, desc="X_val_ml"):
        tmp = filter_channel(i)
        X_val_ml.append(tmp)
    X_val_ml = np.array(X_val_ml)
    for i in tqdm(X_test, desc="X_test_ml"):
        tmp = filter_channel(i)
        X_test_ml.append(tmp)
    X_test_ml = np.array(X_test_ml)
    print(X_train_ml.shape, X_val_ml.shape, X_test_ml.shape)

    # save
    res = {'Y_train': Y_train, 'Y_val': Y_val, 'Y_test': Y_test, 'pid_test': pid_test}
    with open(os.path.join(output_path, 'mina_info.pkl'), 'wb') as fout:
        dill.dump(res, fout)
        
    fout = open(os.path.join(output_path, 'mina_X_train.bin'), 'wb')
    np.save(fout, X_train_ml)
    fout.close()

    fout = open(os.path.join(output_path, 'mina_X_val.bin'), 'wb')
    np.save(fout, X_val_ml)
    fout.close()

    fout = open(os.path.join(output_path, 'mina_X_test.bin'), 'wb')
    np.save(fout, X_test_ml)
    fout.close()

def make_knowledge_physionet(data_path, n_split=50, output_path='/kaggle/working/'):

    # read
    fin = open(os.path.join(output_path, 'mina_X_train.bin'), 'rb')
    X_train = np.load(fin)
    fin.close()
    fin = open(os.path.join(output_path, 'mina_X_val.bin'), 'rb')
    X_val = np.load(fin)
    fin.close()
    fin = open(os.path.join(output_path, 'mina_X_test.bin'), 'rb')
    X_test = np.load(fin)
    fin.close()

    # compute knowledge
    K_train_beat = compute_beat(X_train)
    K_train_rhythm = compute_rhythm(X_train, n_split)
    K_train_freq = compute_freq(X_train)

    K_val_beat = compute_beat(X_val)
    K_val_rhythm = compute_rhythm(X_val, n_split)
    K_val_freq = compute_freq(X_val)

    K_test_beat = compute_beat(X_test)
    K_test_rhythm = compute_rhythm(X_test, n_split)
    K_test_freq = compute_freq(X_test)

    # save
    fout = open(os.path.join(output_path, 'mina_K_train_beat.bin'), 'wb')
    np.save(fout, K_train_beat)
    fout.close()
    fout = open(os.path.join(output_path, 'mina_K_val_beat.bin'), 'wb')
    np.save(fout, K_val_beat)
    fout.close()
    fout = open(os.path.join(output_path, 'mina_K_test_beat.bin'), 'wb')
    np.save(fout, K_test_beat)
    fout.close()

    res = {'K_train_rhythm': K_train_rhythm, 'K_train_freq': K_train_freq, 
    'K_val_rhythm': K_val_rhythm, 'K_val_freq': K_val_freq, 
    'K_test_rhythm': K_test_rhythm, 'K_test_freq': K_test_freq}
    with open(os.path.join(output_path, 'mina_knowledge.pkl'), 'wb') as fout:
        dill.dump(res, fout)

"""def evaluate(gt, pred):
    '''
    gt is (0, C-1)
    pred is list of probability
    '''

    pred_label = []
    for i in pred:
        pred_label.append(np.argmax(i))
    pred_label = np.array(pred_label)

    res = OrderedDict({})
    
    res['auroc'] = roc_auc_score(gt, pred[:,1])
    res['auprc'] = average_precision_score(gt, pred[:,1])
    res['f1'] = f1_score(gt, pred_label)
    
    res['\nmat'] = confusion_matrix(gt, pred_label)
    
    for k, v in res.items():
        print(k, ':', v, '|', end='')
    print()
    
    return list(res.values())"""

"def evaluate(gt, pred):\n    '''\n    gt is (0, C-1)\n    pred is list of probability\n    '''\n\n    pred_label = []\n    for i in pred:\n        pred_label.append(np.argmax(i))\n    pred_label = np.array(pred_label)\n\n    res = OrderedDict({})\n    \n    res['auroc'] = roc_auc_score(gt, pred[:,1])\n    res['auprc'] = average_precision_score(gt, pred[:,1])\n    res['f1'] = f1_score(gt, pred_label)\n    \n    res['\nmat'] = confusion_matrix(gt, pred_label)\n    \n    for k, v in res.items():\n        print(k, ':', v, '|', end='')\n    print()\n    \n    return list(res.values())"

In [16]:
def evaluate(gt, pred):
    res = OrderedDict({})

    # Проверяем, есть ли оба класса (0 и 1) в `gt`
    unique_classes = np.unique(gt)
    if len(unique_classes) < 2:
        print(f"Warning: Only one class {unique_classes} in y_true. ROC AUC cannot be computed.")
        res['auroc'] = None
        res['auprc'] = None
    else:
        res['auroc'] = roc_auc_score(gt, pred[:, 1])
        res['auprc'] = average_precision_score(gt, pred[:, 1])

    # Вычисляем F1-метрику независимо от наличия одного класса
    pred_label = (pred[:, 1] > 0.5).astype(int)  # Бинаризация предсказаний
    res['f1'] = f1_score(gt, pred_label) if len(unique_classes) > 1 else None

    return res

In [8]:
class Net(nn.Module):
    def __init__(self, n_channel, n_dim, n_split):
        super(Net, self).__init__()
        
        self.n_channel = n_channel
        self.n_dim = n_dim
        self.n_split = n_split
        self.n_class = 2
        
        self.base_net_0 = BaseNet(self.n_dim, self.n_split)
        self.base_net_1 = BaseNet(self.n_dim, self.n_split)
        self.base_net_2 = BaseNet(self.n_dim, self.n_split)
        self.base_net_3 = BaseNet(self.n_dim, self.n_split)
            
        ### attention
        self.out_size = 8
        self.att_channel_dim = 2
        self.W_att_channel = nn.Parameter(torch.randn(self.out_size+1, self.att_channel_dim))
        self.v_att_channel = nn.Parameter(torch.randn(self.att_channel_dim, 1))
        
        ### fc
        self.fc = nn.Linear(self.out_size, self.n_class)
        
    def forward(self, x_0, x_1, x_2, x_3, 
                k_beat_0, k_beat_1, k_beat_2, k_beat_3, 
                k_rhythm_0, k_rhythm_1, k_rhythm_2, k_rhythm_3, 
                k_freq):

        x_0, alpha_0, beta_0 = self.base_net_0(x_0, k_beat_0, k_rhythm_0)
        x_1, alpha_1, beta_1 = self.base_net_1(x_1, k_beat_1, k_rhythm_1)
        x_2, alpha_2, beta_2 = self.base_net_2(x_2, k_beat_2, k_rhythm_2)
        x_3, alpha_3, beta_3 = self.base_net_3(x_3, k_beat_3, k_rhythm_3)
        
        x = torch.stack([x_0, x_1, x_2, x_3], 1)

        # ############################################
        # ### attention on channel
        # ############################################
        k_freq = k_freq.permute(1, 0, 2)

        tmp_x = torch.cat((x, k_freq), dim=-1)
        e = torch.matmul(tmp_x, self.W_att_channel)
        e = torch.matmul(torch.tanh(e), self.v_att_channel)
        n1 = torch.exp(e)
        n2 = torch.sum(torch.exp(e), 1, keepdim=True)
        gama = torch.div(n1, n2)
        x = torch.sum(torch.mul(gama, x), 1)
        
        ############################################
        ### fc
        ############################################
        x = F.softmax(self.fc(x), 1)
        
        ############################################
        ### return 
        ############################################
        
        att_dic = {"alpha_0":alpha_0, "beta_0":beta_0, 
                  "alpha_1":alpha_1, "beta_1":beta_1, 
                  "alpha_2":alpha_2, "beta_2":beta_2, 
                  "alpha_3":alpha_3, "beta_3":beta_3, 
                  "gama":gama}
        
        return x, att_dic

In [9]:
class BaseNet(nn.Module):
    def __init__(self, n_dim, n_split):
        super(BaseNet, self).__init__()
        
        self.n_dim = n_dim
        self.n_split = n_split
        self.n_seg = int(n_dim/n_split)
        
        ### Input: (batch size, number of channels, length of signal sequence)
        self.conv_out_channels = 64
        self.conv_kernel_size = 32
        self.conv_stride = 2
        self.conv = nn.Conv1d(in_channels=1, 
                              out_channels=self.conv_out_channels, 
                              kernel_size=self.conv_kernel_size, 
                              stride=self.conv_stride)
        self.conv_k = nn.Conv1d(in_channels=1, 
                                out_channels=1, 
                                kernel_size=self.conv_kernel_size, 
                                stride=self.conv_stride)
        self.att_cnn_dim = 8
        self.W_att_cnn = nn.Parameter(torch.randn(self.conv_out_channels+1, self.att_cnn_dim))
        self.v_att_cnn = nn.Parameter(torch.randn(self.att_cnn_dim, 1))
        
        ### Input: (batch size, length of signal sequence, input_size)
        self.rnn_hidden_size = 32
        self.lstm = nn.LSTM(input_size=(self.conv_out_channels), 
                            hidden_size=self.rnn_hidden_size, 
                            num_layers=1, batch_first=True, bidirectional=True)
        self.att_rnn_dim = 8
        self.W_att_rnn = nn.Parameter(torch.randn(2*self.rnn_hidden_size+1, self.att_rnn_dim))
        self.v_att_rnn = nn.Parameter(torch.randn(self.att_rnn_dim, 1))
        
        ### fc
        self.do = nn.Dropout(p=0.5)
        self.out_size = 8
        self.fc = nn.Linear(2*self.rnn_hidden_size, self.out_size)
    
    def forward(self, x, k_beat, k_rhythm):
        
        self.batch_size = x.size()[0]

        ############################################
        ### reshape
        ############################################
        # print('orignial x:', x.size())
        x = x.view(-1, self.n_split)
        x = x.unsqueeze(1)
        
        k_beat = k_beat.view(-1, self.n_split)
        k_beat = k_beat.unsqueeze(1)
        
        ############################################
        ### conv
        ############################################
        x = F.relu(self.conv(x))
        
        k_beat = F.relu(self.conv_k(k_beat))
        
        ############################################
        ### attention conv
        ############################################
        x = x.permute(0, 2, 1)
        k_beat = k_beat.permute(0, 2, 1)
        tmp_x = torch.cat((x, k_beat), dim=-1)
        e = torch.matmul(tmp_x, self.W_att_cnn)
        e = torch.matmul(torch.tanh(e), self.v_att_cnn)
        n1 = torch.exp(e)
        n2 = torch.sum(torch.exp(e), 1, keepdim=True)
        alpha = torch.div(n1, n2)
        x = torch.sum(torch.mul(alpha, x), 1)
        
        ############################################
        ### reshape for rnn
        ############################################
        x = x.view(self.batch_size, self.n_seg, -1)
    
        ############################################
        ### rnn        
        ############################################
        
        k_rhythm = k_rhythm.unsqueeze(-1)
        o, (ht, ct) = self.lstm(x)
        tmp_o = torch.cat((o, k_rhythm), dim=-1)
        e = torch.matmul(tmp_o, self.W_att_rnn)
        e = torch.matmul(torch.tanh(e), self.v_att_rnn)
        n1 = torch.exp(e)
        n2 = torch.sum(torch.exp(e), 1, keepdim=True)
        beta = torch.div(n1, n2)
        x = torch.sum(torch.mul(beta, o), 1)
        
        ############################################
        ### fc
        ############################################
        x = F.relu(self.fc(x))
        x = self.do(x)
        
        return x, alpha, beta        

In [11]:
def train(model, optimizer, loss_func, epoch, batch_size, 
          X_train, Y_train, K_train_beat, K_train_rhythm, K_train_freq, 
          log_file):
    """
    X_train: (n_channel, n_sample, n_dim)
    Y_train: (n_sample,)
    
    K_train_beat: (n_channel, n_sample, n_dim)
    K_train_rhythm: (n_channel, n_sample, n_dim/n_split)
    K_train_freq: (n_channel, n_sample)
    """
    model.train()
    
    n_train = len(Y_train)
    
    pred_all = []
    batch_start_idx = 0
    batch_end_idx = 0
    loss_all = []
    for _ in tqdm(range(n_train//batch_size+1), desc="train"):
    # while batch_end_idx < n_train:
        # print('.', end="")
        batch_end_idx = batch_end_idx + batch_size
        if batch_end_idx >= n_train:
            batch_end_idx = n_train
            
        ### input data
        batch_input_0 = Variable(torch.FloatTensor(X_train[0, batch_start_idx: batch_end_idx, :])).cuda()
        batch_input_1 = Variable(torch.FloatTensor(X_train[1, batch_start_idx: batch_end_idx, :])).cuda()
        batch_input_2 = Variable(torch.FloatTensor(X_train[2, batch_start_idx: batch_end_idx, :])).cuda()
        batch_input_3 = Variable(torch.FloatTensor(X_train[3, batch_start_idx: batch_end_idx, :])).cuda()
        
        ### input K_beat
        batch_K_beat_0 = Variable(torch.FloatTensor(K_train_beat[0, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_beat_1 = Variable(torch.FloatTensor(K_train_beat[1, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_beat_2 = Variable(torch.FloatTensor(K_train_beat[2, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_beat_3 = Variable(torch.FloatTensor(K_train_beat[3, batch_start_idx: batch_end_idx, :])).cuda()

        ### input K_rhythm
        batch_K_rhythm_0 = Variable(torch.FloatTensor(K_train_rhythm[0, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_rhythm_1 = Variable(torch.FloatTensor(K_train_rhythm[1, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_rhythm_2 = Variable(torch.FloatTensor(K_train_rhythm[2, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_rhythm_3 = Variable(torch.FloatTensor(K_train_rhythm[3, batch_start_idx: batch_end_idx, :])).cuda()        
        
        ### input K_freq
        batch_K_freq = Variable(torch.FloatTensor(K_train_freq[:, batch_start_idx: batch_end_idx, :])).cuda()  
        
        ### gt
        batch_gt = Variable(torch.LongTensor(Y_train[batch_start_idx: batch_end_idx])).cuda()
        
        pred, _ = model(batch_input_0, batch_input_1, batch_input_2, batch_input_3, 
                        batch_K_beat_0, batch_K_beat_1, batch_K_beat_2, batch_K_beat_3, 
                        batch_K_rhythm_0, batch_K_rhythm_1, batch_K_rhythm_2, batch_K_rhythm_3, 
                        batch_K_freq)
        
        pred_all.append(pred.cpu().data.numpy())
        # print(pred, batch_gt)

        loss = loss_func(pred, batch_gt)
        loss_all.append(loss.cpu().data.numpy())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_start_idx = batch_start_idx + batch_size

    loss_res = np.mean(loss_all)
    print('epoch {0} '.format(epoch))
    print('loss ', np.mean(loss_all))
    print('train | ', end='')
    pred_all = np.concatenate(pred_all, axis=0)
    # print(Y_train.shape, pred_all.shape)
    res = evaluate(Y_train, pred_all)
    res['loss_res'] = loss_res
    res['pred_all'] = pred_all
    # res.append(loss_res)
    # res.append(pred_all)
    
    with open(log_file, 'a') as fout:
        print('epoch {0} '.format(epoch), 'train | ', res, file=fout)
        print('loss_all ', np.mean(loss_all), file=fout)
        
    return res
    

def test(model, batch_size, 
         X_test, Y_test, K_test_beat, K_test_rhythm, K_test_freq, 
         log_file):
    
    model.eval()
    
    n_test = len(Y_test)
    
    pred_all = []
    att_dic_all = []
    
    batch_start_idx = 0
    batch_end_idx = 0
    for _ in tqdm(range(n_test//batch_size+1), desc="test"):
    # while batch_end_idx < n_test:
        # print('.', end="")
        batch_end_idx = batch_end_idx + batch_size
        if batch_end_idx >= n_test:
            batch_end_idx = n_test
            
        ### input data
        batch_input_0 = Variable(torch.FloatTensor(X_test[0, batch_start_idx: batch_end_idx, :])).cuda()
        batch_input_1 = Variable(torch.FloatTensor(X_test[1, batch_start_idx: batch_end_idx, :])).cuda()
        batch_input_2 = Variable(torch.FloatTensor(X_test[2, batch_start_idx: batch_end_idx, :])).cuda()
        batch_input_3 = Variable(torch.FloatTensor(X_test[3, batch_start_idx: batch_end_idx, :])).cuda()
        
        ### input K_beat
        batch_K_beat_0 = Variable(torch.FloatTensor(K_test_beat[0, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_beat_1 = Variable(torch.FloatTensor(K_test_beat[1, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_beat_2 = Variable(torch.FloatTensor(K_test_beat[2, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_beat_3 = Variable(torch.FloatTensor(K_test_beat[3, batch_start_idx: batch_end_idx, :])).cuda()

        ### input K_rhythm
        batch_K_rhythm_0 = Variable(torch.FloatTensor(K_test_rhythm[0, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_rhythm_1 = Variable(torch.FloatTensor(K_test_rhythm[1, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_rhythm_2 = Variable(torch.FloatTensor(K_test_rhythm[2, batch_start_idx: batch_end_idx, :])).cuda()
        batch_K_rhythm_3 = Variable(torch.FloatTensor(K_test_rhythm[3, batch_start_idx: batch_end_idx, :])).cuda()
        
        ### input K_freq
        batch_K_freq = Variable(torch.FloatTensor(K_test_freq[:, batch_start_idx: batch_end_idx, :])).cuda()
        
        ### gt
        batch_gt = Variable(torch.LongTensor(Y_test[batch_start_idx: batch_end_idx])).cuda()

        pred, att_dic = model(batch_input_0, batch_input_1, batch_input_2, batch_input_3, 
                              batch_K_beat_0, batch_K_beat_1, batch_K_beat_2, batch_K_beat_3, 
                              batch_K_rhythm_0, batch_K_rhythm_1, batch_K_rhythm_2, batch_K_rhythm_3, 
                              batch_K_freq)
            
        for k, v in att_dic.items():
            att_dic[k] = v.cpu().data.numpy()
        att_dic_all.append(att_dic)
        pred_all.append(pred.cpu().data.numpy())

        batch_start_idx = batch_start_idx + batch_size

    print('test | ', end='')
    pred_all = np.concatenate(pred_all, axis=0)
    res = evaluate(Y_test, pred_all)
    res['pred_all'] = pred_all
    # res.append(pred_all)
    
    with open(log_file, 'a') as fout:
        print('test | ', res, file=fout)

    return res, att_dic_all

def run(data_path, output_path='/kaggle/working/'):

    n_epoch = 200
    lr = 0.003
    n_split = 50

    ##################################################################
    ### par
    ##################################################################
    run_id = 'mina_{0}'.format(strftime("%Y-%m-%d-%H-%M-%S", localtime()))
    directory = 'res/{0}'.format(run_id)
    try:
        os.stat('res/')
    except:
        os.mkdir('res/')    
    try:
        os.stat(directory)
    except:
        os.mkdir(directory)
    
    log_file = '{0}/log.txt'.format(directory)
    model_file = '/kaggle/input/dataset-mina/CPSC/mina.py'
    destination_file = os.path.join(directory, os.path.basename(model_file))  # Правильный путь

    copyfile(model_file, destination_file)  # Копируем корректно
    #copyfile(model_file, '{0}/{1}'.format(directory, model_file))

    n_dim = 3000
    batch_size = 128

    with open(log_file, 'a') as fout:
        print(run_id, file=fout)

    ##################################################################
    ### read data
    ##################################################################
    with open(os.path.join(output_path, 'mina_info.pkl'), 'rb') as fin:
        res = dill.load(fin)    
    Y_train = res['Y_train']
    Y_val = res['Y_val']
    Y_test = res['Y_test']
    print(Counter(Y_train), Counter(Y_val), Counter(Y_test))

    fin = open(os.path.join(output_path, 'mina_X_train.bin'), 'rb')
    X_train = np.load(fin)
    fin.close()
    fin = open(os.path.join(output_path, 'mina_X_val.bin'), 'rb')
    X_val = np.load(fin)
    fin.close()
    fin = open(os.path.join(output_path, 'mina_X_test.bin'), 'rb')
    X_test = np.load(fin)
    fin.close()
    X_train = np.swapaxes(X_train, 0, 1)
    X_val = np.swapaxes(X_val, 0, 1)
    X_test = np.swapaxes(X_test, 0, 1)
    print(X_train.shape, X_val.shape, X_test.shape)

    fin = open(os.path.join(output_path, 'mina_K_train_beat.bin'), 'rb')
    K_train_beat = np.load(fin)
    fin.close()
    fin = open(os.path.join(output_path, 'mina_K_val_beat.bin'), 'rb')
    K_val_beat = np.load(fin)
    fin.close()
    fin = open(os.path.join(output_path, 'mina_K_test_beat.bin'), 'rb')
    K_test_beat = np.load(fin)
    fin.close()
    with open(os.path.join(output_path, 'mina_knowledge.pkl'), 'rb') as fin:
        res = dill.load(fin)    
    K_train_rhythm = res['K_train_rhythm']
    K_train_freq = res['K_train_freq']
    K_val_rhythm = res['K_val_rhythm']
    K_val_freq = res['K_val_freq']
    K_test_rhythm = res['K_test_rhythm']
    K_test_freq = res['K_test_freq']
    K_train_beat = np.swapaxes(K_train_beat, 0, 1)
    K_train_rhythm = np.swapaxes(K_train_rhythm, 0, 1)
    K_train_freq = np.swapaxes(K_train_freq, 0, 1)
    K_val_beat = np.swapaxes(K_val_beat, 0, 1)
    K_val_rhythm = np.swapaxes(K_val_rhythm, 0, 1)
    K_val_freq = np.swapaxes(K_val_freq, 0, 1)
    K_test_beat = np.swapaxes(K_test_beat, 0, 1)
    K_test_rhythm = np.swapaxes(K_test_rhythm, 0, 1)
    K_test_freq = np.swapaxes(K_test_freq, 0, 1)
    print(K_train_beat.shape, K_train_rhythm.shape, K_train_freq.shape)
    print(K_val_beat.shape, K_val_rhythm.shape, K_val_freq.shape)
    print(K_test_beat.shape, K_test_rhythm.shape, K_test_freq.shape)

    print('load data done!')
    
    ##################################################################
    ### train
    ##################################################################

    n_channel = X_train.shape[0]
    print('n_channel:', n_channel)

    torch.cuda.manual_seed(0)

    model = Net(n_channel, n_dim, n_split)
    model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # weight = Variable(torch.FloatTensor([n_train/cnter[0], n_train/cnter[1]])).cuda()
    loss_func = torch.nn.CrossEntropyLoss()

    train_res_list = []
    val_res_list = []
    test_res_list = []
    val_att_list = []
    test_att_list = []
    for epoch in range(n_epoch):
        tmp_train = train(model, optimizer, loss_func, epoch, batch_size, 
                          X_train, Y_train, K_train_beat, K_train_rhythm, K_train_freq, 
                          log_file)
        tmp_val, tmp_att_val = test(model, batch_size, 
                                    X_val, Y_val, K_val_beat, K_val_rhythm, K_val_freq, 
                                    log_file)
        tmp_test, tmp_att_test = test(model, batch_size, 
                                      X_test, Y_test, K_test_beat, K_test_rhythm, K_test_freq, 
                                      log_file)
        
        train_res_list.append(tmp_train)
        val_res_list.append(tmp_val)
        test_res_list.append(tmp_test)
        # val_att_list.append(tmp_att_val)
        test_att_list.append(tmp_att_test)
        torch.save(model, '{0}/model_{1}.pt'.format(directory, epoch))
    
    ##################################################################
    ### save results
    ##################################################################
    res_mat = []
    for i in range(n_epoch):
        train_res = train_res_list[i]
        val_res = val_res_list[i]
        test_res = test_res_list[i]
        res_mat.append([
            train_res[0], train_res[1], 
            val_res[0], val_res[1], 
            test_res[0], test_res[1]])
    res_mat = np.array(res_mat)

    res = {'train_res_list':train_res_list, 
           'val_res_list':val_res_list, 
           'test_res_list':test_res_list}
    with open('{0}/res.pkl'.format(directory), 'wb') as fout:
        dill.dump(res, fout)
    
    np.savetxt('{0}/res_mat.csv'.format(directory), res_mat, delimiter=',')
    
    try:
        res = {'test_att_list':test_att_list}
        with open('{0}/res_att.pkl'.format(directory), 'wb') as fout:
            dill.dump(res, fout)
    except:
        print('error in saving attention file')

In [1]:
# !rm -rf /kaggle/working/res/*

In [12]:
# prepare data
data_path = '/kaggle/input/dataset-mina/CPSC'
output_path = '/kaggle/working/'

In [20]:
preprocess_physionet(data_path)
make_data_physionet(data_path) 
make_knowledge_physionet(data_path)

Загружено 6877 меток, 10 классов.
Файлы: ['A0001' 'A0002' 'A0003' 'A0004' 'A0005']


Чтение .mat файлов: 100%|██████████| 6877/6877 [00:08<00:00, 850.75it/s]


Файл сохранен: /kaggle/working/challenge2018.pkl
Counter({0: 5157}) Counter({0: 688}) Counter({0: 1032})
after: 
Counter({0: 61884}) Counter({0: 8256}) Counter({0: 12384})


X_train_ml: 100%|██████████| 61884/61884 [00:47<00:00, 1309.13it/s]
X_val_ml: 100%|██████████| 8256/8256 [00:05<00:00, 1511.39it/s]
X_test_ml: 100%|██████████| 12384/12384 [00:08<00:00, 1503.73it/s]


(61884, 4, 3000) (8256, 4, 3000) (12384, 4, 3000)


compute_beat: 100%|██████████| 61884/61884 [00:07<00:00, 7747.56it/s]
compute_rhythm: 100%|██████████| 61884/61884 [05:16<00:00, 195.53it/s]
compute_freq: 100%|██████████| 61884/61884 [00:53<00:00, 1165.17it/s]
compute_beat: 100%|██████████| 8256/8256 [00:00<00:00, 15417.60it/s]
compute_rhythm: 100%|██████████| 8256/8256 [00:42<00:00, 192.57it/s]
compute_freq: 100%|██████████| 8256/8256 [00:07<00:00, 1160.61it/s]
compute_beat: 100%|██████████| 12384/12384 [00:00<00:00, 14889.24it/s]
compute_rhythm: 100%|██████████| 12384/12384 [01:04<00:00, 190.60it/s]
compute_freq: 100%|██████████| 12384/12384 [00:10<00:00, 1165.78it/s]


In [None]:
 # run
for i_run in range(1):
    run(data_path)

Counter({0: 61884}) Counter({0: 8256}) Counter({0: 12384})
(4, 61884, 3000) (4, 8256, 3000) (4, 12384, 3000)
(4, 61884, 3000) (4, 61884, 60) (4, 61884, 1)
(4, 8256, 3000) (4, 8256, 60) (4, 8256, 1)
(4, 12384, 3000) (4, 12384, 60) (4, 12384, 1)
load data done!
n_channel: 4


train: 100%|██████████| 484/484 [00:13<00:00, 34.61it/s]


epoch 0 
loss  0.32370442


test: 100%|██████████| 65/65 [00:00<00:00, 78.32it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.77it/s]




train: 100%|██████████| 484/484 [00:12<00:00, 37.38it/s]


epoch 1 
loss  0.31354338


test: 100%|██████████| 65/65 [00:00<00:00, 74.69it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 70.34it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 37.12it/s]


epoch 2 
loss  0.31342363


test: 100%|██████████| 65/65 [00:00<00:00, 76.47it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.69it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 37.01it/s]


epoch 3 
loss  0.31348503


test: 100%|██████████| 65/65 [00:00<00:00, 75.95it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.39it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.61it/s]


epoch 4 
loss  0.31339282


test: 100%|██████████| 65/65 [00:00<00:00, 74.16it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.13it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.64it/s]


epoch 5 
loss  0.31338903


test: 100%|██████████| 65/65 [00:00<00:00, 75.51it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.73it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.03it/s]


epoch 6 
loss  0.31336305


test: 100%|██████████| 65/65 [00:00<00:00, 75.73it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.28it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.28it/s]


epoch 7 
loss  0.3133765


test: 100%|██████████| 65/65 [00:00<00:00, 74.47it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.98it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.88it/s]


epoch 8 
loss  0.31335765


test: 100%|██████████| 65/65 [00:00<00:00, 75.51it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.14it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.27it/s]


epoch 9 
loss  0.31335917


test: 100%|██████████| 65/65 [00:00<00:00, 75.62it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.74it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.42it/s]


epoch 10 
loss  0.3133367


test: 100%|██████████| 65/65 [00:00<00:00, 76.06it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.04it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.54it/s]


epoch 11 
loss  0.31331518


test: 100%|██████████| 65/65 [00:00<00:00, 76.23it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.48it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.28it/s]


epoch 12 
loss  0.3133253


test: 100%|██████████| 65/65 [00:00<00:00, 76.34it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 78.13it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.40it/s]


epoch 13 
loss  0.3133171


test: 100%|██████████| 65/65 [00:00<00:00, 76.16it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.53it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.20it/s]


epoch 14 
loss  0.3133138


test: 100%|██████████| 65/65 [00:00<00:00, 75.88it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.00it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.33it/s]


epoch 15 
loss  0.31330368


test: 100%|██████████| 65/65 [00:00<00:00, 74.04it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.88it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.25it/s]


epoch 16 
loss  0.31329742


test: 100%|██████████| 65/65 [00:00<00:00, 75.82it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.58it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.33it/s]


epoch 17 
loss  0.31328914


test: 100%|██████████| 65/65 [00:00<00:00, 75.76it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.43it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.29it/s]


epoch 18 
loss  0.31328335


test: 100%|██████████| 65/65 [00:00<00:00, 75.59it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.77it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.45it/s]


epoch 19 
loss  0.31328058


test: 100%|██████████| 65/65 [00:00<00:00, 76.19it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.78it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.15it/s]


epoch 20 
loss  0.31327653


test: 100%|██████████| 65/65 [00:00<00:00, 75.18it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.76it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.48it/s]


epoch 21 
loss  0.31327227


test: 100%|██████████| 65/65 [00:00<00:00, 76.25it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.22it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.20it/s]


epoch 22 
loss  0.3132717


test: 100%|██████████| 65/65 [00:00<00:00, 75.84it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.08it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.20it/s]


epoch 23 
loss  0.31326804


test: 100%|██████████| 65/65 [00:00<00:00, 75.33it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.09it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.20it/s]


epoch 24 
loss  0.31326723


test: 100%|██████████| 65/65 [00:00<00:00, 73.29it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.31it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.43it/s]


epoch 25 
loss  0.31326786


test: 100%|██████████| 65/65 [00:00<00:00, 76.02it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.11it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.01it/s]


epoch 26 
loss  0.31326678


test: 100%|██████████| 65/65 [00:00<00:00, 75.80it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.14it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.47it/s]


epoch 27 
loss  0.31326547


test: 100%|██████████| 65/65 [00:00<00:00, 75.79it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.63it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.87it/s]


epoch 28 
loss  0.31326482


test: 100%|██████████| 65/65 [00:00<00:00, 73.52it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.31it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.46it/s]


epoch 29 
loss  0.313264


test: 100%|██████████| 65/65 [00:00<00:00, 72.54it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.69it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.28it/s]


epoch 30 
loss  0.3132635


test: 100%|██████████| 65/65 [00:00<00:00, 73.87it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.53it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.90it/s]


epoch 31 
loss  0.31326345


test: 100%|██████████| 65/65 [00:00<00:00, 74.20it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 73.46it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.76it/s]


epoch 32 
loss  0.31326333


test: 100%|██████████| 65/65 [00:00<00:00, 75.12it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.24it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.08it/s]


epoch 33 
loss  0.31326282


test: 100%|██████████| 65/65 [00:00<00:00, 75.54it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.04it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.99it/s]


epoch 34 
loss  0.3132623


test: 100%|██████████| 65/65 [00:00<00:00, 74.95it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.41it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.73it/s]


epoch 35 
loss  0.3132625


test: 100%|██████████| 65/65 [00:00<00:00, 74.66it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.34it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.77it/s]


epoch 36 
loss  0.3132623


test: 100%|██████████| 65/65 [00:00<00:00, 70.27it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.03it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.13it/s]


epoch 37 
loss  0.3132622


test: 100%|██████████| 65/65 [00:00<00:00, 75.03it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.77it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.19it/s]


epoch 38 
loss  0.31326225


test: 100%|██████████| 65/65 [00:00<00:00, 69.26it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.48it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.18it/s]


epoch 39 
loss  0.3132622


test: 100%|██████████| 65/65 [00:00<00:00, 75.05it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.83it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.12it/s]


epoch 40 
loss  0.31326216


test: 100%|██████████| 65/65 [00:00<00:00, 65.61it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.21it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.19it/s]


epoch 41 
loss  0.31326202


test: 100%|██████████| 65/65 [00:00<00:00, 76.13it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.87it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.24it/s]


epoch 42 
loss  0.31326202


test: 100%|██████████| 65/65 [00:00<00:00, 74.72it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 67.69it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.22it/s]


epoch 43 
loss  0.31326202


test: 100%|██████████| 65/65 [00:00<00:00, 75.97it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.65it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.46it/s]


epoch 44 
loss  0.31326202


test: 100%|██████████| 65/65 [00:00<00:00, 75.70it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.58it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.25it/s]


epoch 45 
loss  0.313262


test: 100%|██████████| 65/65 [00:00<00:00, 75.51it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.75it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.49it/s]


epoch 46 
loss  0.313262


test: 100%|██████████| 65/65 [00:00<00:00, 75.38it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.78it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.22it/s]


epoch 47 
loss  0.31326196


test: 100%|██████████| 65/65 [00:00<00:00, 75.38it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.46it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.50it/s]


epoch 48 
loss  0.31326196


test: 100%|██████████| 65/65 [00:00<00:00, 76.43it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.13it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.19it/s]


epoch 49 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.50it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.93it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.19it/s]


epoch 50 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 76.07it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.94it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.94it/s]


epoch 51 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 74.13it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.73it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.26it/s]


epoch 52 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.82it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.41it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.94it/s]


epoch 53 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 74.35it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.07it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.43it/s]


epoch 54 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.36it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.02it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.19it/s]


epoch 55 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.38it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.16it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.51it/s]


epoch 56 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.78it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.71it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.03it/s]


epoch 57 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.06it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.47it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.40it/s]


epoch 58 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 71.13it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.59it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.01it/s]


epoch 59 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.05it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.55it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.41it/s]


epoch 60 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.33it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.79it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.13it/s]


epoch 61 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.19it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.15it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.36it/s]


epoch 62 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.42it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.14it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.09it/s]


epoch 63 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.62it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.22it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.53it/s]


epoch 64 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.38it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.87it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.21it/s]


epoch 65 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.27it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.05it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.49it/s]


epoch 66 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 70.82it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.95it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.15it/s]


epoch 67 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 70.58it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.21it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.61it/s]


epoch 68 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.60it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.26it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.93it/s]


epoch 69 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 74.36it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.72it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.57it/s]


epoch 70 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.41it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 76.06it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.20it/s]


epoch 71 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.37it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.61it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.26it/s]


epoch 72 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.36it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.55it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.99it/s]


epoch 73 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 71.52it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 77.16it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.37it/s]


epoch 74 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.17it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.24it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.21it/s]


epoch 75 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.74it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.64it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.29it/s]


epoch 76 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 71.60it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.98it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.16it/s]


epoch 77 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 71.38it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 73.12it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.85it/s]


epoch 78 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 68.54it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.13it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.84it/s]


epoch 79 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 73.55it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.97it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.00it/s]


epoch 80 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 71.15it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 74.81it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.69it/s]


epoch 81 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 70.09it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 73.80it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.14it/s]


epoch 82 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.05it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.27it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.15it/s]


epoch 83 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 68.48it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.44it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.72it/s]


epoch 84 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.63it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.20it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.96it/s]


epoch 85 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 69.13it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.18it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 35.93it/s]


epoch 86 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 72.00it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.19it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.03it/s]


epoch 87 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 66.62it/s]




test: 100%|██████████| 97/97 [00:01<00:00, 75.04it/s]




train: 100%|██████████| 484/484 [00:13<00:00, 36.13it/s]


epoch 88 
loss  0.31326193


test: 100%|██████████| 65/65 [00:00<00:00, 65.24it/s]




test:  89%|████████▊ | 86/97 [06:17<10:17, 56.15s/it]