In [1]:
import numpy as np
import pandas as pd
import os 
import pickle
import sys
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
import warnings

torch.set_printoptions(profile="full")
np.set_printoptions(threshold=sys.maxsize)
warnings.filterwarnings("ignore")

In [2]:
from src.attn import FixedPositionalEncoding, LearnablePositionalEncoding, TemporalEmbedding, MultiheadAttention, Decoder
from src.loss import ContrastiveLoss, FocalLoss

In [3]:
from datetime import datetime

In [4]:
seed = 777
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(seed)

In [5]:
path = '/data/notebook/shared/MIMIC-IV'

In [6]:
with open(os.path.join(path, 'dict_types_nomedi_mimic_240408_clinic_3_years.pkl'), 'rb') as f:
    dtype_dict = pickle.load(f)
f.close()

with open(os.path.join('./data', 'preprocessed_nomedi_240423_clinic_3_years.pkl'), 'rb') as f:
    data_dict_d = pickle.load(f)
f.close()

In [7]:
remain_year_clinical = pd.read_csv(os.path.join(path, "240408_remain_year_clinical_diag_seq_4.csv"))
del remain_year_clinical['clinical_label']

In [8]:
len(dtype_dict)

12702

In [9]:
labels = []
code_labels = []
length_list = []
clinical_labels = []
code_length_list = []
for sample_id, visits in tqdm(data_dict_d.items()):
    # 레이블 추가
    label = visits['label']
    code_label = visits['code_label']
    clinical_label = visits['clinical_label']
    labels.append(label)
    code_labels.append(code_label)
    clinical_labels.append(clinical_labels)

100%|██████████| 8037/8037 [00:00<00:00, 702225.26it/s]


In [10]:
train_indices, test_indices, train_y, test_y = train_test_split(list(data_dict_d.keys()), labels, test_size=0.1, random_state=777, stratify=labels)
train_indices, valid_indices, valid_y, valid_y = train_test_split(train_indices, train_y, test_size=(len(test_indices)/len(train_indices)), random_state=777, stratify=train_y) 

In [11]:
train_data = {}
valid_data = {}
test_data = {}
for sample in tqdm(train_indices):
    train_data[sample] = data_dict_d[sample]

for sample in tqdm(valid_indices):
    valid_data[sample] = data_dict_d[sample]

for sample in tqdm(test_indices):
    test_data[sample] = data_dict_d[sample]

100%|██████████| 6429/6429 [00:00<00:00, 2235177.42it/s]
100%|██████████| 804/804 [00:00<00:00, 1726687.36it/s]
100%|██████████| 804/804 [00:00<00:00, 1722278.05it/s]


In [12]:
train_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(train_indices)].reset_index(drop=True)
valid_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(valid_indices)].reset_index(drop=True)
test_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(test_indices)].reset_index(drop=True)

In [16]:
class CustomDataset(Dataset):
    def __init__(self, data, clinical_df, scaler, mode='train'):       
        self.keys = list(data.keys())  # 딕셔너리의 키 목록 저장
        self.data = data  # 딕셔너리에서 데이터만 추출하여 저장
        self.scaler = scaler
        if mode == 'train':
            scaled_data = self.scaler.fit_transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)
        else:
            scaled_data = self.scaler.transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # x는 현재 방문까지의 모든 방문 데이터, y는 다음 방문 데이터
        padding_temp = torch.zeros((1, len(self.data[self.keys[idx]]['code_index'][0])), dtype=torch.long)
        
        origin_visit = torch.tensor(self.data[self.keys[idx]]['code_index'], dtype=torch.long)
        origin_mask = torch.tensor(self.data[self.keys[idx]]['seq_mask'], dtype=torch.long)
        origin_mask_final = torch.tensor(self.data[self.keys[idx]]['seq_mask_final'], dtype=torch.long)
        origin_mask_code = torch.tensor(self.data[self.keys[idx]]['seq_mask_code'], dtype=torch.long)
        
        next_visit = torch.cat((origin_visit[1:], padding_temp), dim=0)
        next_mask =torch.cat((origin_mask[1:], torch.tensor([0], dtype=torch.long)), dim=0)
        next_mask_code = torch.cat((origin_mask_code[1:], padding_temp), dim=0)
        
        clinical_data = self.scaled_clinical_df.loc[self.scaled_clinical_df['subject_id'] == self.keys[idx], self.scaled_clinical_df.columns[2:]].values
        clinical_data = torch.tensor(clinical_data, dtype=torch.float)
        visit_index = torch.tensor(self.data[self.keys[idx]]['year_onehot'], dtype=torch.long)
        last_visit_index = torch.tensor(self.data[self.keys[idx]]['last_year_onehot'], dtype=torch.float)
        time_feature = torch.tensor(self.data[self.keys[idx]]['time_feature'], dtype=torch.long)
        label_per_sample = self.data[self.keys[idx]]['label']
        key_per_sample = self.keys[idx]  # 해당 샘플의 키
        # 키 값도 함께 반환
        return {'sample_id': key_per_sample, 'origin_visit': origin_visit, 'next_visit': next_visit,\
                'origin_mask': origin_mask, 'origin_mask_final': origin_mask_final, 'next_mask': next_mask, \
                'origin_mask_code': origin_mask_code, 'next_mask_code': next_mask_code, 'time_feature': time_feature, \
                'clinical_data': clinical_data, 'visit_index': visit_index, 'last_visit_index': last_visit_index, \
                'label': label_per_sample}

In [17]:
scaler = StandardScaler()
train_dataset = CustomDataset(train_data, train_clinical, scaler, mode='train')
valid_dataset = CustomDataset(valid_data, valid_clinical, scaler, mode='valid')
test_dataset = CustomDataset(test_data, test_clinical, scaler, mode='test')

In [18]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=512, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

In [30]:
class CustomTransformerModel(nn.Module):
    def __init__(self, code_size, ninp, nhead, nhid, nlayers, dropout=0.5, device='cuda:0', pe='fixed'):
        super(CustomTransformerModel, self).__init__()
        self.ninp = ninp  # 축소된 차원 및 Transformer 입력 차원
        self.device = device
        # 차원 축소를 위한 선형 레이어
        self.pre_embedding = nn.Embedding(code_size, self.ninp)
        self.pe = pe
        if self.pe == 'fixed':
            self.pos_encoder = FixedPositionalEncoding(ninp, dropout)
        elif self.pe == 'time_feature':
            self.pos_encoder = TemporalEmbedding(ninp, embed_type='embed', dropout=dropout)
        else:
            self.pos_encoder = LearnablePositionalEncoding(ninp, dropout)
            
        self.transformer_decoder = Decoder(ninp, nhead, nhid, nlayers, dropout)
        self.decoder = nn.Linear(ninp, ninp, bias=False)  # 최종 출력 차원을 설정 (여기서는 ninp로 설정)
        self.classification_layer = nn.Linear(ninp*2, 2)
        self.clinical_transform = nn.Linear(18, ninp) 
        self.cross_attn = MultiheadAttention(ninp, nhead, dropout=0.3)
        self.init_weights()
        
    def init_weights(self):
        nn.init.xavier_uniform_(self.pre_embedding.weight)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.xavier_uniform_(self.classification_layer.weight)
    
    def forward(self, batch_data):
        # 차원 축소
        origin_visit = batch_data['origin_visit'].to(self.device)
        next_visit = batch_data['next_visit'].to(self.device)
        clinical_tensor = batch_data['clinical_data'].to(self.device)
        mask_code = batch_data['origin_mask_code'].unsqueeze(3).to(self.device)
        next_mask_code = batch_data['next_mask_code'].unsqueeze(3).to(self.device)
        mask_final = batch_data['origin_mask_final'].unsqueeze(2).to(self.device)
        mask_final_year = batch_data['last_visit_index'].to(self.device)
                
        origin_visit_emb = (self.pre_embedding(origin_visit) * mask_code).sum(dim=2)
        next_visit_emb = (self.pre_embedding(next_visit) * next_mask_code).sum(dim=2)
        # 위치 인코딩 및 Transformer 디코더 적용
        if self.pe == 'time_feature':
            time_feature = batch_data['time_feature'].to(self.device)
            src = self.pos_encoder(origin_visit_emb, time_feature)
        else:
            src = self.pos_encoder(origin_visit_emb)
        
        seq_len = src.shape[1]
        src_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=self.device), diagonal=1)
        output = self.transformer_decoder(src, attention_mask=src_mask)
        output = self.decoder(output)
        
        output_batch, output_visit_num, output_dim = output.size()
        next_visit_output_batch, next_visit_output_num, next_visit_output_dim = next_visit_emb.size()
        
        # print(output[:, :torch.,:])
        temp_output = output.reshape(-1, output_dim)
        temp_next_visit = next_visit_emb.reshape(-1, next_visit_output_dim)
        final_visit = (output * mask_final).sum(dim=1)
        year_emb = torch.bmm(output.transpose(1,2), mask_final_year).transpose(1,2)[:,:2,:]
        
        transformed_clinical = self.clinical_transform(clinical_tensor)
        mixed_output, mixed_cross_attn = self.cross_attn(year_emb, transformed_clinical, transformed_clinical)
        mixed_output = mixed_output[:,-1,:]
        mixed_final_emb = torch.cat((final_visit, mixed_output), dim=-1)
        classification_result = self.classification_layer(mixed_final_emb)    
        return temp_output, temp_next_visit, classification_result, final_visit

In [31]:
lr = 0.001
ninp = 64
nhid = 256
nlayer = 6
gamma = 0.5
model_name = 'no_center_total_label'
pe = 'time_feature'

In [32]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_1 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.2, device=device, pe=pe).to(device)
# model_2 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.6, device=device, pe='fixed').to(device)

optimizer = torch.optim.Adam(model_1.parameters(), lr=lr, weight_decay=1e-5)

# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(2, gamma=gamma)
# cosine_loss = CosineSimilarityLoss()
cosine_embedding_loss = nn.CosineEmbeddingLoss()
# const_loss = ContrastiveLoss(temperature=0.05)

In [33]:
# 얼리 스타핑 설정
num_epochs = 200
patience = 30
best_loss = float("inf")
counter = 0
epoch_temp = 0

best_f1 = 0.0
best_combined_score = 0.0
best_epoch = 0

view_total_loss = []
view_cos_loss = []
view_classi_loss = []

cos_lambda = 1
classi_lambda = 1


for epoch in tqdm(range(num_epochs)):
    model_1.train()

    total_train_loss = 0
    total_cos_loss = 0
    total_classi_loss = 0
    total_const_loss = 0
    # key_per_sample, origin_visit, origin_mask, origin_mask_code, origin_mask_final, next_visit, next_mask, next_mask_code, next_mask_final, label_per_sample
    for batch_data in train_loader:
        tr_labels = batch_data['label'].to(device)
        optimizer.zero_grad()
        output_1, next_visit_output_1, final_visit_classification_1, final_visit_1 = model_1(batch_data)
        y = torch.ones(output_1.size(0), dtype=torch.float, device=device)
        y = y.to(device)   
        cosine_loss_mean_1 = cosine_embedding_loss(output_1, next_visit_output_1, y)
        classification_loss_1 = criterion(final_visit_classification_1.squeeze(), tr_labels.long())
        loss = (cos_lambda * cosine_loss_mean_1) + (classi_lambda * classification_loss_1)

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        total_cos_loss += (cosine_loss_mean_1.item())
        total_classi_loss += (classification_loss_1.item())
        # total_const_loss += train_const_loss.item()

        view_total_loss.append(total_train_loss)
        view_cos_loss.append(total_cos_loss)
        view_classi_loss.append(total_classi_loss)

    # 평균 손실 계산
    avg_train_loss = total_train_loss / len(train_loader)
    avg_cos_loss = total_cos_loss / len(train_loader)
    avg_classi_loss = total_classi_loss / len(train_loader)

    print("Epoch:", epoch+1, ", Total loss:", round(avg_train_loss, 4), ", cos_loss : ",\
          round(avg_cos_loss, 4), ", classification_loss :",round(avg_classi_loss, 4))

    model_1.eval()
    total_val_loss = 0
    total_val_cos_loss = 0
    total_val_classi_loss = 0
    total_val_const_loss = 0

    val_labels_list = []
    val_predictions_list = []
    val_probabilities_list = []


    with torch.no_grad():
        for batch_data in valid_loader:
            val_labels = batch_data['label'].to(device)
            val_output_1, val_next_visit_output_1, val_final_visit_classification_1, val_final_visit_1 = model_1(batch_data)
            
            y_val = torch.ones(val_output_1.size(0), dtype=torch.float, device=device) 
            val_cosine_loss_mean_1 = cosine_embedding_loss(val_output_1, val_next_visit_output_1, y_val)
            
            classification_loss_val_1 = criterion(val_final_visit_classification_1.squeeze(), val_labels.long())            

            val_loss = (cos_lambda * val_cosine_loss_mean_1) + (classi_lambda * classification_loss_val_1) 
            
            total_val_loss += val_loss.item()
            total_val_cos_loss += (val_cosine_loss_mean_1.item())
            total_val_classi_loss += (classification_loss_val_1.item())
            # total_val_const_loss += val_const_loss.item()

            val_probs = F.softmax(val_final_visit_classification_1)
            val_predictions = torch.max(val_probs, 1)[1].view((len(val_labels),))

            val_labels_list.extend(val_labels.view(-1).cpu().numpy())
            val_predictions_list.extend(val_predictions.cpu().numpy())
            
            ########
            val_probabilities_list.extend(val_probs[:,1].cpu().numpy())
            ######## 
            
            
        avg_val_loss = total_val_loss / len(valid_loader)
        avg_val_cos_loss = total_val_cos_loss / len(valid_loader)
        avg_val_classi_loss = total_val_classi_loss / len(valid_loader)

    # 성능 지표 계산
    accuracy = accuracy_score(val_labels_list, val_predictions_list)
    auc = roc_auc_score(val_labels_list, val_probabilities_list)
    f1 = f1_score(val_labels_list, val_predictions_list)
    precision = precision_score(val_labels_list,val_predictions_list)
    recall = recall_score(val_labels_list, val_predictions_list)

    print("Epoch:", epoch+1, ", valid loss:", round(avg_val_loss, 4), ", cos_loss : ",\
          round(avg_val_cos_loss, 4), ", classification_loss :",round(avg_val_classi_loss, 4))
    print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

    current_auc = roc_auc_score(val_labels_list, val_probabilities_list)
    current_f1 = f1_score(val_labels_list, val_predictions_list)
    current_combined_score = current_f1

    # 모델 저장 경로 설정
    model_save_path = f'results'
    date_dir = datetime.today().strftime("%Y%m%d")
    model_time =  datetime.today().strftime("%H%M%S")
    # 학습률과 분류 가중치를 파일명에 포함시키기 위한 문자열 포맷
    model_filename_format = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{{gamma}}_{model_time}.pth'

    # 모델 저장 폴더가 없으면 생성
    os.makedirs(os.path.join(model_save_path, date_dir), exist_ok=True)

    if current_combined_score > best_combined_score:
        best_combined_score = current_combined_score
        best_epoch = epoch
        counter = 0
        # 모델 저장 경로와 파일명을 결합하여 전체 파일 경로 생성
        model_1_save_path = os.path.join(model_save_path, date_dir, model_filename_format.format(epoch=best_epoch, model=model_name, pe=pe, gamma=gamma))
        # 모델 저장
        torch.save(model_1.state_dict(), f'{model_1_save_path}')
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered at epoch {}".format(best_epoch))
        print(f"Best Combined Score (AUC): {best_combined_score:.4f}")
        break

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1 , Total loss: 1.3366 , cos_loss :  0.9664 , classification_loss : 0.3702


  0%|          | 1/200 [00:03<12:13,  3.69s/it]

Epoch: 1 , valid loss: 1.1399 , cos_loss :  0.9341 , classification_loss : 0.2059
Accuracy: 0.9129 , AUC:  0.6121 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 2 , Total loss: 1.1452 , cos_loss :  0.927 , classification_loss : 0.2182


  1%|          | 2/200 [00:07<11:33,  3.50s/it]

Epoch: 2 , valid loss: 1.122 , cos_loss :  0.9178 , classification_loss : 0.2042
Accuracy: 0.9129 , AUC:  0.5535 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 3 , Total loss: 1.132 , cos_loss :  0.9186 , classification_loss : 0.2134


  2%|▏         | 3/200 [00:10<11:31,  3.51s/it]

Epoch: 3 , valid loss: 1.1164 , cos_loss :  0.9143 , classification_loss : 0.2021
Accuracy: 0.9129 , AUC:  0.5755 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 4 , Total loss: 1.1251 , cos_loss :  0.9162 , classification_loss : 0.2088


  2%|▏         | 4/200 [00:13<11:18,  3.46s/it]

Epoch: 4 , valid loss: 1.1158 , cos_loss :  0.9131 , classification_loss : 0.2027
Accuracy: 0.9129 , AUC:  0.5776 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 5 , Total loss: 1.1261 , cos_loss :  0.9154 , classification_loss : 0.2106


  2%|▎         | 5/200 [00:17<11:19,  3.49s/it]

Epoch: 5 , valid loss: 1.1249 , cos_loss :  0.9125 , classification_loss : 0.2123
Accuracy: 0.9129 , AUC:  0.583 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 6 , Total loss: 1.1229 , cos_loss :  0.9152 , classification_loss : 0.2077


  3%|▎         | 6/200 [00:20<11:10,  3.46s/it]

Epoch: 6 , valid loss: 1.111 , cos_loss :  0.9122 , classification_loss : 0.1988
Accuracy: 0.9129 , AUC:  0.591 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 7 , Total loss: 1.1188 , cos_loss :  0.915 , classification_loss : 0.2038


  4%|▎         | 7/200 [00:24<11:03,  3.44s/it]

Epoch: 7 , valid loss: 1.11 , cos_loss :  0.9119 , classification_loss : 0.1981
Accuracy: 0.9129 , AUC:  0.5973 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 8 , Total loss: 1.119 , cos_loss :  0.9145 , classification_loss : 0.2045


  4%|▍         | 8/200 [00:27<11:04,  3.46s/it]

Epoch: 8 , valid loss: 1.1102 , cos_loss :  0.9118 , classification_loss : 0.1984
Accuracy: 0.9129 , AUC:  0.5952 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 9 , Total loss: 1.1147 , cos_loss :  0.9147 , classification_loss : 0.1999


  4%|▍         | 9/200 [00:31<10:54,  3.42s/it]

Epoch: 9 , valid loss: 1.1085 , cos_loss :  0.9117 , classification_loss : 0.1968
Accuracy: 0.9129 , AUC:  0.6127 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 10 , Total loss: 1.1144 , cos_loss :  0.9145 , classification_loss : 0.1998


  5%|▌         | 10/200 [00:34<10:45,  3.40s/it]

Epoch: 10 , valid loss: 1.1098 , cos_loss :  0.9117 , classification_loss : 0.1981
Accuracy: 0.9129 , AUC:  0.6405 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 11 , Total loss: 1.1111 , cos_loss :  0.9143 , classification_loss : 0.1969


  6%|▌         | 11/200 [00:37<10:45,  3.41s/it]

Epoch: 11 , valid loss: 1.1017 , cos_loss :  0.9118 , classification_loss : 0.19
Accuracy: 0.9129 , AUC:  0.6918 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 12 , Total loss: 1.0851 , cos_loss :  0.9156 , classification_loss : 0.1694


  6%|▌         | 12/200 [00:41<10:39,  3.40s/it]

Epoch: 12 , valid loss: 1.0643 , cos_loss :  0.9155 , classification_loss : 0.1487
Accuracy: 0.9415 , AUC:  0.7091 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 13 , Total loss: 1.0627 , cos_loss :  0.9171 , classification_loss : 0.1455


  6%|▋         | 13/200 [00:44<10:32,  3.38s/it]

Epoch: 13 , valid loss: 1.0611 , cos_loss :  0.9145 , classification_loss : 0.1467
Accuracy: 0.9415 , AUC:  0.7288 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 14 , Total loss: 1.0578 , cos_loss :  0.9165 , classification_loss : 0.1413


  7%|▋         | 14/200 [00:48<10:34,  3.41s/it]

Epoch: 14 , valid loss: 1.0628 , cos_loss :  0.9141 , classification_loss : 0.1488
Accuracy: 0.9415 , AUC:  0.7351 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 15 , Total loss: 1.0555 , cos_loss :  0.9159 , classification_loss : 0.1396


  8%|▊         | 15/200 [00:51<10:28,  3.40s/it]

Epoch: 15 , valid loss: 1.0607 , cos_loss :  0.9142 , classification_loss : 0.1465
Accuracy: 0.9415 , AUC:  0.7463 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 16 , Total loss: 1.0542 , cos_loss :  0.9159 , classification_loss : 0.1384


  8%|▊         | 16/200 [00:55<10:31,  3.43s/it]

Epoch: 16 , valid loss: 1.0628 , cos_loss :  0.9138 , classification_loss : 0.149
Accuracy: 0.9415 , AUC:  0.7511 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 17 , Total loss: 1.0539 , cos_loss :  0.9156 , classification_loss : 0.1383


  8%|▊         | 17/200 [00:58<10:24,  3.41s/it]

Epoch: 17 , valid loss: 1.0603 , cos_loss :  0.9141 , classification_loss : 0.1462
Accuracy: 0.9415 , AUC:  0.7559 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 18 , Total loss: 1.0519 , cos_loss :  0.9156 , classification_loss : 0.1362


  9%|▉         | 18/200 [01:01<10:18,  3.40s/it]

Epoch: 18 , valid loss: 1.0611 , cos_loss :  0.9139 , classification_loss : 0.1472
Accuracy: 0.9415 , AUC:  0.7587 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 19 , Total loss: 1.0506 , cos_loss :  0.9155 , classification_loss : 0.1351


 10%|▉         | 19/200 [01:05<10:18,  3.42s/it]

Epoch: 19 , valid loss: 1.0606 , cos_loss :  0.9136 , classification_loss : 0.147
Accuracy: 0.9415 , AUC:  0.7627 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 20 , Total loss: 1.0568 , cos_loss :  0.9153 , classification_loss : 0.1414


 10%|█         | 20/200 [01:08<10:11,  3.40s/it]

Epoch: 20 , valid loss: 1.0587 , cos_loss :  0.9139 , classification_loss : 0.1448
Accuracy: 0.9415 , AUC:  0.7706 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 21 , Total loss: 1.0507 , cos_loss :  0.9158 , classification_loss : 0.1349


 10%|█         | 21/200 [01:11<10:07,  3.39s/it]

Epoch: 21 , valid loss: 1.0599 , cos_loss :  0.9134 , classification_loss : 0.1464
Accuracy: 0.9415 , AUC:  0.7608 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 22 , Total loss: 1.0514 , cos_loss :  0.9152 , classification_loss : 0.1362


 11%|█         | 22/200 [01:15<10:09,  3.43s/it]

Epoch: 22 , valid loss: 1.0602 , cos_loss :  0.9136 , classification_loss : 0.1465
Accuracy: 0.9415 , AUC:  0.7615 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 23 , Total loss: 1.0485 , cos_loss :  0.9155 , classification_loss : 0.133


 12%|█▏        | 23/200 [01:19<10:16,  3.49s/it]

Epoch: 23 , valid loss: 1.0593 , cos_loss :  0.9138 , classification_loss : 0.1455
Accuracy: 0.9415 , AUC:  0.7686 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 24 , Total loss: 1.0457 , cos_loss :  0.9155 , classification_loss : 0.1302


 12%|█▏        | 24/200 [01:22<10:08,  3.45s/it]

Epoch: 24 , valid loss: 1.0638 , cos_loss :  0.9138 , classification_loss : 0.15
Accuracy: 0.9415 , AUC:  0.7618 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 25 , Total loss: 1.0495 , cos_loss :  0.9154 , classification_loss : 0.1341


 12%|█▎        | 25/200 [01:25<10:07,  3.47s/it]

Epoch: 25 , valid loss: 1.0595 , cos_loss :  0.9134 , classification_loss : 0.146
Accuracy: 0.9415 , AUC:  0.77 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 26 , Total loss: 1.0473 , cos_loss :  0.9154 , classification_loss : 0.1319


 13%|█▎        | 26/200 [01:29<09:59,  3.45s/it]

Epoch: 26 , valid loss: 1.0607 , cos_loss :  0.913 , classification_loss : 0.1477
Accuracy: 0.9415 , AUC:  0.7654 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 27 , Total loss: 1.0455 , cos_loss :  0.9156 , classification_loss : 0.1299


 14%|█▎        | 27/200 [01:32<09:52,  3.42s/it]

Epoch: 27 , valid loss: 1.0639 , cos_loss :  0.914 , classification_loss : 0.1499
Accuracy: 0.9415 , AUC:  0.7791 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 28 , Total loss: 1.0464 , cos_loss :  0.9156 , classification_loss : 0.1307


 14%|█▍        | 28/200 [01:36<09:53,  3.45s/it]

Epoch: 28 , valid loss: 1.0626 , cos_loss :  0.9134 , classification_loss : 0.1492
Accuracy: 0.9428 , AUC:  0.7646 , F1:  0.5106 , Precision:  1.0 , recall:  0.3429
Epoch: 29 , Total loss: 1.0483 , cos_loss :  0.9155 , classification_loss : 0.1328


 14%|█▍        | 29/200 [01:39<09:49,  3.45s/it]

Epoch: 29 , valid loss: 1.0709 , cos_loss :  0.9137 , classification_loss : 0.1572
Accuracy: 0.9415 , AUC:  0.7723 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 30 , Total loss: 1.0448 , cos_loss :  0.9156 , classification_loss : 0.1292


 15%|█▌        | 30/200 [01:43<10:05,  3.56s/it]

Epoch: 30 , valid loss: 1.0598 , cos_loss :  0.9134 , classification_loss : 0.1464
Accuracy: 0.9428 , AUC:  0.7605 , F1:  0.5106 , Precision:  1.0 , recall:  0.3429
Epoch: 31 , Total loss: 1.0412 , cos_loss :  0.9155 , classification_loss : 0.1257


 16%|█▌        | 31/200 [01:47<10:02,  3.56s/it]

Epoch: 31 , valid loss: 1.0622 , cos_loss :  0.9139 , classification_loss : 0.1483
Accuracy: 0.9415 , AUC:  0.7659 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 32 , Total loss: 1.042 , cos_loss :  0.9155 , classification_loss : 0.1266


 16%|█▌        | 32/200 [01:50<09:50,  3.52s/it]

Epoch: 32 , valid loss: 1.0597 , cos_loss :  0.9134 , classification_loss : 0.1463
Accuracy: 0.9415 , AUC:  0.7729 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 33 , Total loss: 1.0381 , cos_loss :  0.9155 , classification_loss : 0.1226


 16%|█▋        | 33/200 [01:53<09:46,  3.51s/it]

Epoch: 33 , valid loss: 1.0601 , cos_loss :  0.9137 , classification_loss : 0.1464
Accuracy: 0.9428 , AUC:  0.7649 , F1:  0.5106 , Precision:  1.0 , recall:  0.3429
Epoch: 34 , Total loss: 1.0372 , cos_loss :  0.9158 , classification_loss : 0.1214


 17%|█▋        | 34/200 [01:57<09:36,  3.47s/it]

Epoch: 34 , valid loss: 1.0569 , cos_loss :  0.9138 , classification_loss : 0.1431
Accuracy: 0.944 , AUC:  0.767 , F1:  0.5263 , Precision:  1.0 , recall:  0.3571
Epoch: 35 , Total loss: 1.0341 , cos_loss :  0.9156 , classification_loss : 0.1185


 18%|█▊        | 35/200 [02:00<09:29,  3.45s/it]

Epoch: 35 , valid loss: 1.0562 , cos_loss :  0.9139 , classification_loss : 0.1423
Accuracy: 0.944 , AUC:  0.7697 , F1:  0.5263 , Precision:  1.0 , recall:  0.3571
Epoch: 36 , Total loss: 1.0345 , cos_loss :  0.9157 , classification_loss : 0.1188


 18%|█▊        | 36/200 [02:04<09:28,  3.47s/it]

Epoch: 36 , valid loss: 1.0641 , cos_loss :  0.914 , classification_loss : 0.1501
Accuracy: 0.9428 , AUC:  0.7766 , F1:  0.5106 , Precision:  1.0 , recall:  0.3429
Epoch: 37 , Total loss: 1.038 , cos_loss :  0.9156 , classification_loss : 0.1224


 18%|█▊        | 37/200 [02:07<09:20,  3.44s/it]

Epoch: 37 , valid loss: 1.0644 , cos_loss :  0.9141 , classification_loss : 0.1503
Accuracy: 0.9428 , AUC:  0.7591 , F1:  0.5106 , Precision:  1.0 , recall:  0.3429
Epoch: 38 , Total loss: 1.0313 , cos_loss :  0.9157 , classification_loss : 0.1156


 19%|█▉        | 38/200 [02:10<09:13,  3.42s/it]

Epoch: 38 , valid loss: 1.0561 , cos_loss :  0.9138 , classification_loss : 0.1423
Accuracy: 0.949 , AUC:  0.7712 , F1:  0.5859 , Precision:  1.0 , recall:  0.4143
Epoch: 39 , Total loss: 1.0295 , cos_loss :  0.9159 , classification_loss : 0.1136


 20%|█▉        | 39/200 [02:14<09:13,  3.44s/it]

Epoch: 39 , valid loss: 1.0705 , cos_loss :  0.9141 , classification_loss : 0.1565
Accuracy: 0.944 , AUC:  0.7655 , F1:  0.5263 , Precision:  1.0 , recall:  0.3571
Epoch: 40 , Total loss: 1.0362 , cos_loss :  0.9159 , classification_loss : 0.1204


 20%|██        | 40/200 [02:17<09:07,  3.42s/it]

Epoch: 40 , valid loss: 1.0628 , cos_loss :  0.9149 , classification_loss : 0.1479
Accuracy: 0.944 , AUC:  0.7772 , F1:  0.5263 , Precision:  1.0 , recall:  0.3571
Epoch: 41 , Total loss: 1.0342 , cos_loss :  0.9164 , classification_loss : 0.1178


 20%|██        | 41/200 [02:21<09:01,  3.41s/it]

Epoch: 41 , valid loss: 1.0577 , cos_loss :  0.9141 , classification_loss : 0.1437
Accuracy: 0.9478 , AUC:  0.7748 , F1:  0.58 , Precision:  0.9667 , recall:  0.4143
Epoch: 42 , Total loss: 1.0345 , cos_loss :  0.9156 , classification_loss : 0.1189


 21%|██        | 42/200 [02:24<09:03,  3.44s/it]

Epoch: 42 , valid loss: 1.0582 , cos_loss :  0.9144 , classification_loss : 0.1438
Accuracy: 0.9478 , AUC:  0.7697 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 43 , Total loss: 1.0269 , cos_loss :  0.9159 , classification_loss : 0.111


 22%|██▏       | 43/200 [02:28<08:57,  3.43s/it]

Epoch: 43 , valid loss: 1.06 , cos_loss :  0.9147 , classification_loss : 0.1453
Accuracy: 0.949 , AUC:  0.7698 , F1:  0.5859 , Precision:  1.0 , recall:  0.4143
Epoch: 44 , Total loss: 1.0281 , cos_loss :  0.9159 , classification_loss : 0.1121


 22%|██▏       | 44/200 [02:31<08:54,  3.43s/it]

Epoch: 44 , valid loss: 1.0574 , cos_loss :  0.9142 , classification_loss : 0.1432
Accuracy: 0.9465 , AUC:  0.7856 , F1:  0.5657 , Precision:  0.9655 , recall:  0.4
Epoch: 45 , Total loss: 1.0249 , cos_loss :  0.9158 , classification_loss : 0.1091


 22%|██▎       | 45/200 [02:35<08:55,  3.46s/it]

Epoch: 45 , valid loss: 1.0613 , cos_loss :  0.9138 , classification_loss : 0.1475
Accuracy: 0.949 , AUC:  0.7687 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 46 , Total loss: 1.0316 , cos_loss :  0.9157 , classification_loss : 0.1159


 23%|██▎       | 46/200 [02:38<08:48,  3.43s/it]

Epoch: 46 , valid loss: 1.0625 , cos_loss :  0.9141 , classification_loss : 0.1484
Accuracy: 0.9478 , AUC:  0.7787 , F1:  0.5962 , Precision:  0.9118 , recall:  0.4429
Epoch: 47 , Total loss: 1.0283 , cos_loss :  0.9155 , classification_loss : 0.1128


 24%|██▎       | 47/200 [02:41<08:41,  3.41s/it]

Epoch: 47 , valid loss: 1.0607 , cos_loss :  0.9139 , classification_loss : 0.1469
Accuracy: 0.9465 , AUC:  0.7742 , F1:  0.5657 , Precision:  0.9655 , recall:  0.4
Epoch: 48 , Total loss: 1.0278 , cos_loss :  0.9155 , classification_loss : 0.1123


 24%|██▍       | 48/200 [02:45<08:41,  3.43s/it]

Epoch: 48 , valid loss: 1.0635 , cos_loss :  0.9143 , classification_loss : 0.1492
Accuracy: 0.9478 , AUC:  0.774 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 49 , Total loss: 1.0294 , cos_loss :  0.9157 , classification_loss : 0.1138


 24%|██▍       | 49/200 [02:48<08:34,  3.41s/it]

Epoch: 49 , valid loss: 1.0631 , cos_loss :  0.914 , classification_loss : 0.1491
Accuracy: 0.9478 , AUC:  0.777 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 50 , Total loss: 1.0248 , cos_loss :  0.9156 , classification_loss : 0.1092


 25%|██▌       | 50/200 [02:52<08:28,  3.39s/it]

Epoch: 50 , valid loss: 1.064 , cos_loss :  0.9137 , classification_loss : 0.1503
Accuracy: 0.9478 , AUC:  0.7729 , F1:  0.58 , Precision:  0.9667 , recall:  0.4143
Epoch: 51 , Total loss: 1.0278 , cos_loss :  0.9157 , classification_loss : 0.1121


 26%|██▌       | 51/200 [02:55<08:29,  3.42s/it]

Epoch: 51 , valid loss: 1.0623 , cos_loss :  0.9145 , classification_loss : 0.1478
Accuracy: 0.9453 , AUC:  0.7835 , F1:  0.551 , Precision:  0.9643 , recall:  0.3857
Epoch: 52 , Total loss: 1.0235 , cos_loss :  0.9159 , classification_loss : 0.1076


 26%|██▌       | 52/200 [02:58<08:23,  3.40s/it]

Epoch: 52 , valid loss: 1.0677 , cos_loss :  0.9142 , classification_loss : 0.1535
Accuracy: 0.9478 , AUC:  0.7547 , F1:  0.5882 , Precision:  0.9375 , recall:  0.4286
Epoch: 53 , Total loss: 1.0263 , cos_loss :  0.9156 , classification_loss : 0.1106


 26%|██▋       | 53/200 [03:02<08:17,  3.39s/it]

Epoch: 53 , valid loss: 1.0589 , cos_loss :  0.9143 , classification_loss : 0.1446
Accuracy: 0.9478 , AUC:  0.7861 , F1:  0.5882 , Precision:  0.9375 , recall:  0.4286
Epoch: 54 , Total loss: 1.0251 , cos_loss :  0.9158 , classification_loss : 0.1093


 27%|██▋       | 54/200 [03:05<08:18,  3.42s/it]

Epoch: 54 , valid loss: 1.0638 , cos_loss :  0.914 , classification_loss : 0.1498
Accuracy: 0.9502 , AUC:  0.7746 , F1:  0.6 , Precision:  1.0 , recall:  0.4286
Epoch: 55 , Total loss: 1.0232 , cos_loss :  0.9156 , classification_loss : 0.1076


 28%|██▊       | 55/200 [03:09<08:12,  3.40s/it]

Epoch: 55 , valid loss: 1.0622 , cos_loss :  0.9144 , classification_loss : 0.1478
Accuracy: 0.9478 , AUC:  0.7857 , F1:  0.5882 , Precision:  0.9375 , recall:  0.4286
Epoch: 56 , Total loss: 1.0213 , cos_loss :  0.9154 , classification_loss : 0.1059


 28%|██▊       | 56/200 [03:12<08:12,  3.42s/it]

Epoch: 56 , valid loss: 1.0652 , cos_loss :  0.9141 , classification_loss : 0.1511
Accuracy: 0.9465 , AUC:  0.771 , F1:  0.5825 , Precision:  0.9091 , recall:  0.4286
Epoch: 57 , Total loss: 1.0199 , cos_loss :  0.9155 , classification_loss : 0.1044


 28%|██▊       | 57/200 [03:15<08:06,  3.40s/it]

Epoch: 57 , valid loss: 1.0709 , cos_loss :  0.9139 , classification_loss : 0.157
Accuracy: 0.9478 , AUC:  0.7829 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 58 , Total loss: 1.0283 , cos_loss :  0.9158 , classification_loss : 0.1125


 29%|██▉       | 58/200 [03:19<08:01,  3.39s/it]

Epoch: 58 , valid loss: 1.0626 , cos_loss :  0.9139 , classification_loss : 0.1487
Accuracy: 0.949 , AUC:  0.7796 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 59 , Total loss: 1.0247 , cos_loss :  0.9156 , classification_loss : 0.1091


 30%|██▉       | 59/200 [03:22<08:02,  3.42s/it]

Epoch: 59 , valid loss: 1.0595 , cos_loss :  0.914 , classification_loss : 0.1455
Accuracy: 0.949 , AUC:  0.7845 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 60 , Total loss: 1.0222 , cos_loss :  0.9156 , classification_loss : 0.1066


 30%|███       | 60/200 [03:26<07:57,  3.41s/it]

Epoch: 60 , valid loss: 1.0663 , cos_loss :  0.9142 , classification_loss : 0.1521
Accuracy: 0.949 , AUC:  0.7744 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 61 , Total loss: 1.0227 , cos_loss :  0.9158 , classification_loss : 0.1069


 30%|███       | 61/200 [03:29<07:51,  3.39s/it]

Epoch: 61 , valid loss: 1.0632 , cos_loss :  0.914 , classification_loss : 0.1492
Accuracy: 0.949 , AUC:  0.7802 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 62 , Total loss: 1.0187 , cos_loss :  0.9156 , classification_loss : 0.1031


 31%|███       | 62/200 [03:33<07:53,  3.43s/it]

Epoch: 62 , valid loss: 1.0618 , cos_loss :  0.9139 , classification_loss : 0.1479
Accuracy: 0.949 , AUC:  0.7805 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 63 , Total loss: 1.0234 , cos_loss :  0.9155 , classification_loss : 0.108


 32%|███▏      | 63/200 [03:36<07:48,  3.42s/it]

Epoch: 63 , valid loss: 1.0673 , cos_loss :  0.9139 , classification_loss : 0.1534
Accuracy: 0.9502 , AUC:  0.7791 , F1:  0.6 , Precision:  1.0 , recall:  0.4286
Epoch: 64 , Total loss: 1.0219 , cos_loss :  0.9156 , classification_loss : 0.1063


 32%|███▏      | 64/200 [03:39<07:41,  3.39s/it]

Epoch: 64 , valid loss: 1.061 , cos_loss :  0.9139 , classification_loss : 0.1471
Accuracy: 0.949 , AUC:  0.7861 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 65 , Total loss: 1.0185 , cos_loss :  0.9157 , classification_loss : 0.1027


 32%|███▎      | 65/200 [03:43<07:42,  3.42s/it]

Epoch: 65 , valid loss: 1.0651 , cos_loss :  0.9139 , classification_loss : 0.1512
Accuracy: 0.949 , AUC:  0.7909 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 66 , Total loss: 1.0174 , cos_loss :  0.9156 , classification_loss : 0.1018


 33%|███▎      | 66/200 [03:46<07:35,  3.40s/it]

Epoch: 66 , valid loss: 1.0631 , cos_loss :  0.9139 , classification_loss : 0.1493
Accuracy: 0.9502 , AUC:  0.7917 , F1:  0.6078 , Precision:  0.9688 , recall:  0.4429
Epoch: 67 , Total loss: 1.02 , cos_loss :  0.9156 , classification_loss : 0.1044


 34%|███▎      | 67/200 [03:49<07:29,  3.38s/it]

Epoch: 67 , valid loss: 1.0649 , cos_loss :  0.9138 , classification_loss : 0.1511
Accuracy: 0.949 , AUC:  0.7847 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 68 , Total loss: 1.0215 , cos_loss :  0.9153 , classification_loss : 0.1062


 34%|███▍      | 68/200 [03:53<07:29,  3.40s/it]

Epoch: 68 , valid loss: 1.0678 , cos_loss :  0.9141 , classification_loss : 0.1538
Accuracy: 0.9502 , AUC:  0.7856 , F1:  0.6 , Precision:  1.0 , recall:  0.4286
Epoch: 69 , Total loss: 1.0214 , cos_loss :  0.916 , classification_loss : 0.1054


 34%|███▍      | 69/200 [03:56<07:23,  3.39s/it]

Epoch: 69 , valid loss: 1.0625 , cos_loss :  0.9141 , classification_loss : 0.1484
Accuracy: 0.9465 , AUC:  0.7903 , F1:  0.5905 , Precision:  0.8857 , recall:  0.4429
Epoch: 70 , Total loss: 1.0211 , cos_loss :  0.9158 , classification_loss : 0.1053


 35%|███▌      | 70/200 [04:00<07:19,  3.38s/it]

Epoch: 70 , valid loss: 1.0623 , cos_loss :  0.9138 , classification_loss : 0.1485
Accuracy: 0.949 , AUC:  0.7881 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 71 , Total loss: 1.0174 , cos_loss :  0.9156 , classification_loss : 0.1019


 36%|███▌      | 71/200 [04:03<07:18,  3.40s/it]

Epoch: 71 , valid loss: 1.0652 , cos_loss :  0.914 , classification_loss : 0.1512
Accuracy: 0.949 , AUC:  0.7815 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 72 , Total loss: 1.0177 , cos_loss :  0.9156 , classification_loss : 0.1021


 36%|███▌      | 72/200 [04:06<07:13,  3.39s/it]

Epoch: 72 , valid loss: 1.061 , cos_loss :  0.9137 , classification_loss : 0.1473
Accuracy: 0.949 , AUC:  0.7975 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 73 , Total loss: 1.0151 , cos_loss :  0.9156 , classification_loss : 0.0995


 36%|███▋      | 73/200 [04:10<07:08,  3.38s/it]

Epoch: 73 , valid loss: 1.0592 , cos_loss :  0.9139 , classification_loss : 0.1454
Accuracy: 0.949 , AUC:  0.7931 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 74 , Total loss: 1.0183 , cos_loss :  0.9155 , classification_loss : 0.1027


 37%|███▋      | 74/200 [04:13<07:08,  3.40s/it]

Epoch: 74 , valid loss: 1.0641 , cos_loss :  0.9137 , classification_loss : 0.1504
Accuracy: 0.9502 , AUC:  0.7854 , F1:  0.6078 , Precision:  0.9688 , recall:  0.4429
Epoch: 75 , Total loss: 1.0197 , cos_loss :  0.9153 , classification_loss : 0.1043


 38%|███▊      | 75/200 [04:17<07:08,  3.43s/it]

Epoch: 75 , valid loss: 1.0652 , cos_loss :  0.9138 , classification_loss : 0.1514
Accuracy: 0.949 , AUC:  0.7919 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 76 , Total loss: 1.0134 , cos_loss :  0.9156 , classification_loss : 0.0977


 38%|███▊      | 76/200 [04:20<07:06,  3.44s/it]

Epoch: 76 , valid loss: 1.0624 , cos_loss :  0.9138 , classification_loss : 0.1486
Accuracy: 0.949 , AUC:  0.7899 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 77 , Total loss: 1.0162 , cos_loss :  0.9153 , classification_loss : 0.1009


 38%|███▊      | 77/200 [04:24<07:00,  3.42s/it]

Epoch: 77 , valid loss: 1.0604 , cos_loss :  0.9138 , classification_loss : 0.1466
Accuracy: 0.949 , AUC:  0.798 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 78 , Total loss: 1.015 , cos_loss :  0.9156 , classification_loss : 0.0994


 39%|███▉      | 78/200 [04:27<06:56,  3.42s/it]

Epoch: 78 , valid loss: 1.0612 , cos_loss :  0.9138 , classification_loss : 0.1474
Accuracy: 0.949 , AUC:  0.7987 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 79 , Total loss: 1.0155 , cos_loss :  0.9157 , classification_loss : 0.0998


 40%|███▉      | 79/200 [04:31<07:01,  3.48s/it]

Epoch: 79 , valid loss: 1.0645 , cos_loss :  0.914 , classification_loss : 0.1505
Accuracy: 0.9465 , AUC:  0.7978 , F1:  0.5743 , Precision:  0.9355 , recall:  0.4143
Epoch: 80 , Total loss: 1.0178 , cos_loss :  0.9156 , classification_loss : 0.1022


 40%|████      | 80/200 [04:34<07:07,  3.56s/it]

Epoch: 80 , valid loss: 1.0627 , cos_loss :  0.9136 , classification_loss : 0.149
Accuracy: 0.949 , AUC:  0.793 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 81 , Total loss: 1.0176 , cos_loss :  0.9154 , classification_loss : 0.1022


 40%|████      | 81/200 [04:38<06:56,  3.50s/it]

Epoch: 81 , valid loss: 1.0688 , cos_loss :  0.9138 , classification_loss : 0.155
Accuracy: 0.9415 , AUC:  0.7896 , F1:  0.5688 , Precision:  0.7949 , recall:  0.4429
Epoch: 82 , Total loss: 1.0189 , cos_loss :  0.9156 , classification_loss : 0.1033


 41%|████      | 82/200 [04:41<06:52,  3.50s/it]

Epoch: 82 , valid loss: 1.0649 , cos_loss :  0.9137 , classification_loss : 0.1511
Accuracy: 0.949 , AUC:  0.7947 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 83 , Total loss: 1.0146 , cos_loss :  0.9156 , classification_loss : 0.099


 42%|████▏     | 83/200 [04:45<06:45,  3.46s/it]

Epoch: 83 , valid loss: 1.0676 , cos_loss :  0.9136 , classification_loss : 0.154
Accuracy: 0.9478 , AUC:  0.782 , F1:  0.5962 , Precision:  0.9118 , recall:  0.4429
Epoch: 84 , Total loss: 1.0244 , cos_loss :  0.9151 , classification_loss : 0.1093


 42%|████▏     | 84/200 [04:48<06:39,  3.44s/it]

Epoch: 84 , valid loss: 1.0646 , cos_loss :  0.9131 , classification_loss : 0.1515
Accuracy: 0.9465 , AUC:  0.7934 , F1:  0.5743 , Precision:  0.9355 , recall:  0.4143
Epoch: 85 , Total loss: 1.0174 , cos_loss :  0.9154 , classification_loss : 0.1019


 42%|████▎     | 85/200 [04:51<06:37,  3.46s/it]

Epoch: 85 , valid loss: 1.0645 , cos_loss :  0.9139 , classification_loss : 0.1506
Accuracy: 0.949 , AUC:  0.7882 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 86 , Total loss: 1.0207 , cos_loss :  0.9156 , classification_loss : 0.1052


 43%|████▎     | 86/200 [04:55<06:30,  3.42s/it]

Epoch: 86 , valid loss: 1.0637 , cos_loss :  0.9137 , classification_loss : 0.15
Accuracy: 0.9465 , AUC:  0.7829 , F1:  0.5825 , Precision:  0.9091 , recall:  0.4286
Epoch: 87 , Total loss: 1.0185 , cos_loss :  0.9153 , classification_loss : 0.1032


 44%|████▎     | 87/200 [04:58<06:28,  3.44s/it]

Epoch: 87 , valid loss: 1.0655 , cos_loss :  0.9136 , classification_loss : 0.1519
Accuracy: 0.9428 , AUC:  0.79 , F1:  0.5741 , Precision:  0.8158 , recall:  0.4429
Epoch: 88 , Total loss: 1.0137 , cos_loss :  0.9153 , classification_loss : 0.0984


 44%|████▍     | 88/200 [05:02<06:22,  3.42s/it]

Epoch: 88 , valid loss: 1.061 , cos_loss :  0.9137 , classification_loss : 0.1473
Accuracy: 0.9453 , AUC:  0.7977 , F1:  0.5769 , Precision:  0.8824 , recall:  0.4286
Epoch: 89 , Total loss: 1.0167 , cos_loss :  0.9154 , classification_loss : 0.1013


 44%|████▍     | 89/200 [05:05<06:17,  3.40s/it]

Epoch: 89 , valid loss: 1.0635 , cos_loss :  0.9133 , classification_loss : 0.1502
Accuracy: 0.9465 , AUC:  0.7976 , F1:  0.5825 , Precision:  0.9091 , recall:  0.4286
Epoch: 90 , Total loss: 1.0181 , cos_loss :  0.9153 , classification_loss : 0.1027


 45%|████▌     | 90/200 [05:08<06:16,  3.43s/it]

Epoch: 90 , valid loss: 1.0685 , cos_loss :  0.9138 , classification_loss : 0.1547
Accuracy: 0.949 , AUC:  0.7851 , F1:  0.5941 , Precision:  0.9677 , recall:  0.4286
Epoch: 91 , Total loss: 1.0128 , cos_loss :  0.9156 , classification_loss : 0.0972


 46%|████▌     | 91/200 [05:12<06:10,  3.40s/it]

Epoch: 91 , valid loss: 1.0649 , cos_loss :  0.9136 , classification_loss : 0.1513
Accuracy: 0.949 , AUC:  0.7905 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 92 , Total loss: 1.0119 , cos_loss :  0.9155 , classification_loss : 0.0965


 46%|████▌     | 92/200 [05:15<06:05,  3.39s/it]

Epoch: 92 , valid loss: 1.0655 , cos_loss :  0.9137 , classification_loss : 0.1518
Accuracy: 0.949 , AUC:  0.7889 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 93 , Total loss: 1.011 , cos_loss :  0.9153 , classification_loss : 0.0956


 46%|████▋     | 93/200 [05:19<06:05,  3.42s/it]

Epoch: 93 , valid loss: 1.0684 , cos_loss :  0.9136 , classification_loss : 0.1548
Accuracy: 0.9465 , AUC:  0.7842 , F1:  0.5905 , Precision:  0.8857 , recall:  0.4429
Epoch: 94 , Total loss: 1.012 , cos_loss :  0.9153 , classification_loss : 0.0967


 47%|████▋     | 94/200 [05:22<06:00,  3.40s/it]

Epoch: 94 , valid loss: 1.0691 , cos_loss :  0.9138 , classification_loss : 0.1553
Accuracy: 0.949 , AUC:  0.7879 , F1:  0.6019 , Precision:  0.9394 , recall:  0.4429
Epoch: 95 , Total loss: 1.0145 , cos_loss :  0.9154 , classification_loss : 0.0991


 48%|████▊     | 95/200 [05:25<05:55,  3.39s/it]

Epoch: 95 , valid loss: 1.0649 , cos_loss :  0.9137 , classification_loss : 0.1513
Accuracy: 0.944 , AUC:  0.7927 , F1:  0.5714 , Precision:  0.8571 , recall:  0.4286
Epoch: 96 , Total loss: 1.0128 , cos_loss :  0.9154 , classification_loss : 0.0975


 48%|████▊     | 95/200 [05:29<06:04,  3.47s/it]

Epoch: 96 , valid loss: 1.0699 , cos_loss :  0.9134 , classification_loss : 0.1565
Accuracy: 0.9453 , AUC:  0.7846 , F1:  0.5849 , Precision:  0.8611 , recall:  0.4429
Early stopping triggered at epoch 65
Best Combined Score (AUC): 0.6078





In [34]:
# model_1_save_path = './results/20240423/model_lr0.001_classi1_dim64_hid256_layer5_epoch35_no_center_total_label_time_feature_0.5_180651.pth'

In [35]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_const_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path

model_1.load_state_dict(torch.load(model_1_path))

model_1.eval()

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)

        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        te_loss = (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1)
        
        if batch_idx == 0:
            test_visit_embedding = te_final_visit_1.detach().cpu().numpy()
            test_label_numpy = te_labels.detach().cpu().numpy()
        else:
            add_visit_embedding = te_final_visit_1.detach().cpu().numpy()
            add_label_numpy = te_labels.detach().cpu().numpy()
            test_visit_embedding = np.concatenate((test_visit_embedding, add_visit_embedding), axis=0)
            test_label_numpy = np.concatenate((test_label_numpy, add_label_numpy))
        

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)

  0%|          | 0/2 [00:00<?, ?it/s]


NameError: name 'batch_idx' is not defined

In [None]:
# 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), 
      # ", contrastive_loss : ", round(avg_te_const_loss,4)
     )
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

In [None]:
confusion_matrix(te_labels_list, te_predictions_list)

In [81]:
print(model_1_path)

results/20240424/model_lr0.001_classi1_dim64_hid256_layer5_epoch35_no_center_total_label_time_feature_0.5_140528.pth


In [85]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_const_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path

model_1.load_state_dict(torch.load(model_1_path))

model_1.eval()

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)

        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        te_loss = (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1) 

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)

100%|██████████| 2/2 [00:00<00:00,  5.32it/s]


In [86]:
# 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), 
      # ", contrastive_loss : ", round(avg_te_const_loss,4)
     )
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

test loss: 1.0313 , cos_loss :  0.9141 , classification_loss : 0.1172
Accuracy: 0.9614 , AUC:  0.8807 , F1:  0.7156 , Precision:  1.0 , recall:  0.5571


(array([0, 1]), array([734,  70]))

In [87]:
log_path = './logs'
logs = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{model_time}.txt'
os.makedirs(os.path.join(log_path, date_dir), exist_ok=True)

results = [accuracy, auc, f1, precision, recall]
cm = confusion_matrix(te_labels_list, te_predictions_list)
with open(os.path.join(log_path, date_dir, logs.format(epoch=best_epoch, model=model_name, pe=pe)), 'w') as f:
    f.write(logs.format(epoch=best_epoch, model=model_name, pe=pe))
    f.write('\n')
    f.write(str(results))
    f.write('\n')
    f.write(str(cm))

In [88]:
confusion_matrix(te_labels_list, te_predictions_list)

array([[734,   0],
       [ 31,  39]])