In [1]:
import numpy as np
import pandas as pd
import os 
import pickle
import sys
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
import warnings

torch.set_printoptions(profile="full")
np.set_printoptions(threshold=sys.maxsize)
warnings.filterwarnings("ignore")

In [2]:
from src.attn import FixedPositionalEncoding, LearnablePositionalEncoding, TemporalEmbedding, MultiheadAttention, Decoder
from src.loss import ContrastiveLoss, FocalLoss

In [3]:
from datetime import datetime

In [4]:
seed = 777
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(seed)

In [5]:
path = '/data/notebook/shared/MIMIC-IV'

In [6]:
with open(os.path.join(path, 'dict_types_nomedi_mimic_240408_clinic_3_years.pkl'), 'rb') as f:
    dtype_dict = pickle.load(f)
f.close()

with open(os.path.join('./data', 'preprocessed_nomedi_240423_clinic_3_years.pkl'), 'rb') as f:
    data_dict_d = pickle.load(f)
f.close()

In [7]:
remain_year_clinical = pd.read_csv(os.path.join(path, "240408_remain_year_clinical_diag_seq_4.csv"))
del remain_year_clinical['clinical_label']

In [8]:
len(dtype_dict)

12702

In [9]:
labels = []
code_labels = []
length_list = []
clinical_labels = []
code_length_list = []
for sample_id, visits in tqdm(data_dict_d.items()):
    # 레이블 추가
    label = visits['label']
    code_label = visits['code_label']
    clinical_label = visits['clinical_label']
    labels.append(label)
    code_labels.append(code_label)
    clinical_labels.append(clinical_labels)

100%|██████████| 8037/8037 [00:00<00:00, 571620.79it/s]


In [10]:
train_indices, test_indices, train_y, test_y = train_test_split(list(data_dict_d.keys()), labels, test_size=0.1, random_state=777, stratify=labels)
train_indices, valid_indices, valid_y, valid_y = train_test_split(train_indices, train_y, test_size=(len(test_indices)/len(train_indices)), random_state=777, stratify=train_y) 

In [11]:
train_data = {}
valid_data = {}
test_data = {}
for sample in tqdm(train_indices):
    train_data[sample] = data_dict_d[sample]

for sample in tqdm(valid_indices):
    valid_data[sample] = data_dict_d[sample]

for sample in tqdm(test_indices):
    test_data[sample] = data_dict_d[sample]

100%|██████████| 6429/6429 [00:00<00:00, 1319429.49it/s]
100%|██████████| 804/804 [00:00<00:00, 1091333.47it/s]
100%|██████████| 804/804 [00:00<00:00, 1262531.04it/s]


In [12]:
train_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(train_indices)].reset_index(drop=True)
valid_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(valid_indices)].reset_index(drop=True)
test_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(test_indices)].reset_index(drop=True)

In [36]:
class CustomDataset(Dataset):
    def __init__(self, data, clinical_df, scaler, mode='train'):       
        self.keys = list(data.keys())  # 딕셔너리의 키 목록 저장
        self.data = data  # 딕셔너리에서 데이터만 추출하여 저장
        self.scaler = scaler
        if mode == 'train':
            scaled_data = self.scaler.fit_transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)
        else:
            scaled_data = self.scaler.transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # x는 현재 방문까지의 모든 방문 데이터, y는 다음 방문 데이터
        padding_temp = torch.zeros((1, len(self.data[self.keys[idx]]['code_index'][0])), dtype=torch.long)
        
        origin_visit = torch.tensor(self.data[self.keys[idx]]['code_index'], dtype=torch.long)
        origin_mask = torch.tensor(self.data[self.keys[idx]]['seq_mask'], dtype=torch.long)
        origin_mask_final = torch.tensor(self.data[self.keys[idx]]['seq_mask_final'], dtype=torch.long)
        origin_mask_code = torch.tensor(self.data[self.keys[idx]]['seq_mask_code'], dtype=torch.long)
        
        next_visit = torch.cat((origin_visit[1:], padding_temp), dim=0)
        next_mask =torch.cat((origin_mask[1:], torch.tensor([0], dtype=torch.long)), dim=0)
        next_mask_code = torch.cat((origin_mask_code[1:], padding_temp), dim=0)
        
        clinical_data = self.scaled_clinical_df.loc[self.scaled_clinical_df['subject_id'] == self.keys[idx], self.scaled_clinical_df.columns[2:]].values
        clinical_data = torch.tensor(clinical_data, dtype=torch.float)
        visit_index = torch.tensor(self.data[self.keys[idx]]['year_onehot'], dtype=torch.long)
        last_visit_index = torch.tensor(self.data[self.keys[idx]]['last_year_onehot'], dtype=torch.float)
        time_feature = torch.tensor(self.data[self.keys[idx]]['time_feature'], dtype=torch.long)
        label_per_sample = self.data[self.keys[idx]]['code_label']
        key_per_sample = self.keys[idx]  # 해당 샘플의 키
        # 키 값도 함께 반환
        return {'sample_id': key_per_sample, 'origin_visit': origin_visit, 'next_visit': next_visit,\
                'origin_mask': origin_mask, 'origin_mask_final': origin_mask_final, 'next_mask': next_mask, \
                'origin_mask_code': origin_mask_code, 'next_mask_code': next_mask_code, 'time_feature': time_feature, \
                'clinical_data': clinical_data, 'visit_index': visit_index, 'last_visit_index': last_visit_index, \
                'label': label_per_sample}

In [37]:
scaler = StandardScaler()
train_dataset = CustomDataset(train_data, train_clinical, scaler, mode='train')
valid_dataset = CustomDataset(valid_data, valid_clinical, scaler, mode='valid')
test_dataset = CustomDataset(test_data, test_clinical, scaler, mode='test')

In [38]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=512, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

In [39]:
class CustomTransformerModel(nn.Module):
    def __init__(self, code_size, ninp, nhead, nhid, nlayers, dropout=0.5, device='cuda:0', pe='fixed'):
        super(CustomTransformerModel, self).__init__()
        self.ninp = ninp  # 축소된 차원 및 Transformer 입력 차원
        self.device = device
        # 차원 축소를 위한 선형 레이어
        self.pre_embedding = nn.Embedding(code_size, self.ninp)
        self.pe = pe
        if self.pe == 'fixed':
            self.pos_encoder = FixedPositionalEncoding(ninp, dropout)
        elif self.pe == 'time_feature':
            self.pos_encoder = TemporalEmbedding(ninp, embed_type='embed', dropout=dropout)
        else:
            self.pos_encoder = LearnablePositionalEncoding(ninp, dropout)
            
        self.transformer_decoder = Decoder(ninp, nhead, nhid, nlayers, dropout)
        self.decoder = nn.Linear(ninp, ninp, bias=False)  # 최종 출력 차원을 설정 (여기서는 ninp로 설정)
        self.classification_layer = nn.Linear(ninp*2, 2)
        self.clinical_transform = nn.Linear(18, ninp) 
        self.cross_attn = MultiheadAttention(ninp, nhead, dropout=0.3)
        self.init_weights()
        
    def init_weights(self):
        nn.init.xavier_uniform_(self.pre_embedding.weight)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.xavier_uniform_(self.classification_layer.weight)
    
    def forward(self, batch_data):
        # 차원 축소
        origin_visit = batch_data['origin_visit'].to(self.device)
        next_visit = batch_data['next_visit'].to(self.device)
        clinical_tensor = batch_data['clinical_data'].to(self.device)
        mask_code = batch_data['origin_mask_code'].unsqueeze(3).to(self.device)
        next_mask_code = batch_data['next_mask_code'].unsqueeze(3).to(self.device)
        mask_final = batch_data['origin_mask_final'].unsqueeze(2).to(self.device)
        mask_final_year = batch_data['last_visit_index'].to(self.device)
                
        origin_visit_emb = (self.pre_embedding(origin_visit) * mask_code).sum(dim=2)
        next_visit_emb = (self.pre_embedding(next_visit) * next_mask_code).sum(dim=2)
        # 위치 인코딩 및 Transformer 디코더 적용
        if self.pe == 'time_feature':
            time_feature = batch_data['time_feature'].to(self.device)
            src = self.pos_encoder(origin_visit_emb, time_feature)
        else:
            src = self.pos_encoder(origin_visit_emb)
        
        seq_len = src.shape[1]
        src_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=self.device), diagonal=1)
        output = self.transformer_decoder(src, attention_mask=src_mask)
        output = self.decoder(output)
        
        output_batch, output_visit_num, output_dim = output.size()
        next_visit_output_batch, next_visit_output_num, next_visit_output_dim = next_visit_emb.size()
        
        # print(output[:, :torch.,:])
        temp_output = output.reshape(-1, output_dim)
        temp_next_visit = next_visit_emb.reshape(-1, next_visit_output_dim)
        final_visit = (output * mask_final).sum(dim=1)
        year_emb = torch.bmm(output.transpose(1,2), mask_final_year).transpose(1,2)[:,:2,:]
        
        transformed_clinical = self.clinical_transform(clinical_tensor)
        mixed_output, mixed_cross_attn = self.cross_attn(year_emb, transformed_clinical, transformed_clinical)
        mixed_output = mixed_output[:,-1,:]
        mixed_final_emb = torch.cat((final_visit, mixed_output), dim=-1)
        classification_result = self.classification_layer(mixed_final_emb)    
        return temp_output, temp_next_visit, classification_result, final_visit

In [40]:
lr = 0.001
ninp = 64
nhid = 256
nlayer = 6
gamma = 0.5
model_name = 'no_center_total_label'
pe = 'time_feature'

In [41]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_1 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.2, device=device, pe=pe).to(device)
# model_2 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.6, device=device, pe='fixed').to(device)

optimizer = torch.optim.Adam(model_1.parameters(), lr=lr, weight_decay=1e-5)

# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(2, gamma=gamma)
# cosine_loss = CosineSimilarityLoss()
cosine_embedding_loss = nn.CosineEmbeddingLoss()
# const_loss = ContrastiveLoss(temperature=0.05)

In [42]:
# 얼리 스타핑 설정
num_epochs = 200
patience = 30
best_loss = float("inf")
counter = 0
epoch_temp = 0

best_f1 = 0.0
best_combined_score = 0.0
best_epoch = 0

view_total_loss = []
view_cos_loss = []
view_classi_loss = []

cos_lambda = 1
classi_lambda = 1


for epoch in tqdm(range(num_epochs)):
    model_1.train()

    total_train_loss = 0
    total_cos_loss = 0
    total_classi_loss = 0
    total_const_loss = 0
    # key_per_sample, origin_visit, origin_mask, origin_mask_code, origin_mask_final, next_visit, next_mask, next_mask_code, next_mask_final, label_per_sample
    for batch_data in train_loader:
        tr_labels = batch_data['label'].to(device)
        optimizer.zero_grad()
        output_1, next_visit_output_1, final_visit_classification_1, final_visit_1 = model_1(batch_data)
        y = torch.ones(output_1.size(0), dtype=torch.float, device=device)
        y = y.to(device)   
        cosine_loss_mean_1 = cosine_embedding_loss(output_1, next_visit_output_1, y)
        classification_loss_1 = criterion(final_visit_classification_1.squeeze(), tr_labels.long())
        loss = (cos_lambda * cosine_loss_mean_1) + (classi_lambda * classification_loss_1)

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        total_cos_loss += (cosine_loss_mean_1.item())
        total_classi_loss += (classification_loss_1.item())
        # total_const_loss += train_const_loss.item()

        view_total_loss.append(total_train_loss)
        view_cos_loss.append(total_cos_loss)
        view_classi_loss.append(total_classi_loss)

    # 평균 손실 계산
    avg_train_loss = total_train_loss / len(train_loader)
    avg_cos_loss = total_cos_loss / len(train_loader)
    avg_classi_loss = total_classi_loss / len(train_loader)

    print("Epoch:", epoch+1, ", Total loss:", round(avg_train_loss, 4), ", cos_loss : ",\
          round(avg_cos_loss, 4), ", classification_loss :",round(avg_classi_loss, 4))

    model_1.eval()
    total_val_loss = 0
    total_val_cos_loss = 0
    total_val_classi_loss = 0
    total_val_const_loss = 0

    val_labels_list = []
    val_predictions_list = []
    val_probabilities_list = []


    with torch.no_grad():
        for batch_data in valid_loader:
            val_labels = batch_data['label'].to(device)
            val_output_1, val_next_visit_output_1, val_final_visit_classification_1, val_final_visit_1 = model_1(batch_data)
            
            y_val = torch.ones(val_output_1.size(0), dtype=torch.float, device=device) 
            val_cosine_loss_mean_1 = cosine_embedding_loss(val_output_1, val_next_visit_output_1, y_val)
            
            classification_loss_val_1 = criterion(val_final_visit_classification_1.squeeze(), val_labels.long())            

            val_loss = (cos_lambda * val_cosine_loss_mean_1) + (classi_lambda * classification_loss_val_1) 
            
            total_val_loss += val_loss.item()
            total_val_cos_loss += (val_cosine_loss_mean_1.item())
            total_val_classi_loss += (classification_loss_val_1.item())
            # total_val_const_loss += val_const_loss.item()

            val_probs = F.softmax(val_final_visit_classification_1)
            val_predictions = torch.max(val_probs, 1)[1].view((len(val_labels),))

            val_labels_list.extend(val_labels.view(-1).cpu().numpy())
            val_predictions_list.extend(val_predictions.cpu().numpy())
            
            ########
            val_probabilities_list.extend(val_probs[:,1].cpu().numpy())
            ######## 
            
            
        avg_val_loss = total_val_loss / len(valid_loader)
        avg_val_cos_loss = total_val_cos_loss / len(valid_loader)
        avg_val_classi_loss = total_val_classi_loss / len(valid_loader)

    # 성능 지표 계산
    accuracy = accuracy_score(val_labels_list, val_predictions_list)
    auc = roc_auc_score(val_labels_list, val_probabilities_list)
    f1 = f1_score(val_labels_list, val_predictions_list)
    precision = precision_score(val_labels_list,val_predictions_list)
    recall = recall_score(val_labels_list, val_predictions_list)

    print("Epoch:", epoch+1, ", valid loss:", round(avg_val_loss, 4), ", cos_loss : ",\
          round(avg_val_cos_loss, 4), ", classification_loss :",round(avg_val_classi_loss, 4))
    print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

    current_auc = roc_auc_score(val_labels_list, val_probabilities_list)
    current_f1 = f1_score(val_labels_list, val_predictions_list)
    current_combined_score = current_f1

    # 모델 저장 경로 설정
    model_save_path = f'results'
    date_dir = datetime.today().strftime("%Y%m%d")
    model_time =  datetime.today().strftime("%H%M%S")
    # 학습률과 분류 가중치를 파일명에 포함시키기 위한 문자열 포맷
    model_filename_format = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{{gamma}}_{model_time}.pth'

    # 모델 저장 폴더가 없으면 생성
    os.makedirs(os.path.join(model_save_path, date_dir), exist_ok=True)

    if current_combined_score > best_combined_score:
        best_combined_score = current_combined_score
        best_epoch = epoch
        counter = 0
        # 모델 저장 경로와 파일명을 결합하여 전체 파일 경로 생성
        model_1_save_path = os.path.join(model_save_path, date_dir, model_filename_format.format(epoch=best_epoch, model=model_name, pe=pe, gamma=gamma))
        # 모델 저장
        torch.save(model_1.state_dict(), f'{model_1_save_path}')
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered at epoch {}".format(best_epoch))
        print(f"Best Combined Score (AUC): {best_combined_score:.4f}")
        break

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1 , Total loss: 1.4943 , cos_loss :  0.971 , classification_loss : 0.5233


  0%|          | 1/200 [00:04<13:46,  4.15s/it]

Epoch: 1 , valid loss: 1.2802 , cos_loss :  0.9389 , classification_loss : 0.3413
Accuracy: 0.9254 , AUC:  0.6459 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 2 , Total loss: 1.1926 , cos_loss :  0.9322 , classification_loss : 0.2604


  1%|          | 2/200 [00:08<13:13,  4.01s/it]

Epoch: 2 , valid loss: 1.1259 , cos_loss :  0.921 , classification_loss : 0.2049
Accuracy: 0.9254 , AUC:  0.5899 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 3 , Total loss: 1.1332 , cos_loss :  0.9208 , classification_loss : 0.2124


  2%|▏         | 3/200 [00:11<12:38,  3.85s/it]

Epoch: 3 , valid loss: 1.0956 , cos_loss :  0.9157 , classification_loss : 0.1799
Accuracy: 0.9254 , AUC:  0.6071 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 4 , Total loss: 1.1124 , cos_loss :  0.9173 , classification_loss : 0.1951


  2%|▏         | 4/200 [00:15<12:53,  3.95s/it]

Epoch: 4 , valid loss: 1.095 , cos_loss :  0.9137 , classification_loss : 0.1813
Accuracy: 0.9254 , AUC:  0.5872 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 5 , Total loss: 1.1114 , cos_loss :  0.9159 , classification_loss : 0.1955


  2%|▎         | 5/200 [00:20<13:13,  4.07s/it]

Epoch: 5 , valid loss: 1.0921 , cos_loss :  0.9129 , classification_loss : 0.1792
Accuracy: 0.9254 , AUC:  0.613 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 6 , Total loss: 1.109 , cos_loss :  0.9154 , classification_loss : 0.1936


  3%|▎         | 6/200 [00:24<13:14,  4.10s/it]

Epoch: 6 , valid loss: 1.0863 , cos_loss :  0.9124 , classification_loss : 0.1738
Accuracy: 0.9254 , AUC:  0.6405 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 7 , Total loss: 1.1076 , cos_loss :  0.915 , classification_loss : 0.1926


  4%|▎         | 7/200 [00:28<13:22,  4.16s/it]

Epoch: 7 , valid loss: 1.0895 , cos_loss :  0.9121 , classification_loss : 0.1774
Accuracy: 0.9254 , AUC:  0.6134 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 8 , Total loss: 1.1056 , cos_loss :  0.9147 , classification_loss : 0.1908


  4%|▍         | 8/200 [00:32<13:11,  4.12s/it]

Epoch: 8 , valid loss: 1.0849 , cos_loss :  0.9119 , classification_loss : 0.173
Accuracy: 0.9254 , AUC:  0.6641 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 9 , Total loss: 1.1049 , cos_loss :  0.9147 , classification_loss : 0.1903


  4%|▍         | 9/200 [00:36<12:42,  3.99s/it]

Epoch: 9 , valid loss: 1.0936 , cos_loss :  0.9118 , classification_loss : 0.1819
Accuracy: 0.9254 , AUC:  0.6046 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 10 , Total loss: 1.1022 , cos_loss :  0.9144 , classification_loss : 0.1877


  5%|▌         | 10/200 [00:40<12:33,  3.97s/it]

Epoch: 10 , valid loss: 1.0887 , cos_loss :  0.9117 , classification_loss : 0.177
Accuracy: 0.9254 , AUC:  0.6291 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 11 , Total loss: 1.1033 , cos_loss :  0.9146 , classification_loss : 0.1886


  6%|▌         | 11/200 [00:44<12:39,  4.02s/it]

Epoch: 11 , valid loss: 1.0878 , cos_loss :  0.9116 , classification_loss : 0.1763
Accuracy: 0.9254 , AUC:  0.6462 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 12 , Total loss: 1.1075 , cos_loss :  0.9146 , classification_loss : 0.1929


  6%|▌         | 12/200 [00:48<12:36,  4.03s/it]

Epoch: 12 , valid loss: 1.088 , cos_loss :  0.9115 , classification_loss : 0.1765
Accuracy: 0.9254 , AUC:  0.6718 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 13 , Total loss: 1.1064 , cos_loss :  0.9143 , classification_loss : 0.1921


  6%|▋         | 13/200 [00:52<12:31,  4.02s/it]

Epoch: 13 , valid loss: 1.0956 , cos_loss :  0.9115 , classification_loss : 0.1842
Accuracy: 0.9254 , AUC:  0.6031 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 14 , Total loss: 1.1013 , cos_loss :  0.9142 , classification_loss : 0.1871


  7%|▋         | 14/200 [00:56<12:40,  4.09s/it]

Epoch: 14 , valid loss: 1.0852 , cos_loss :  0.9114 , classification_loss : 0.1737
Accuracy: 0.9254 , AUC:  0.6879 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 15 , Total loss: 1.1033 , cos_loss :  0.9142 , classification_loss : 0.189


  8%|▊         | 15/200 [01:00<12:36,  4.09s/it]

Epoch: 15 , valid loss: 1.0858 , cos_loss :  0.9114 , classification_loss : 0.1744
Accuracy: 0.9254 , AUC:  0.6757 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 16 , Total loss: 1.1032 , cos_loss :  0.9145 , classification_loss : 0.1888


  8%|▊         | 16/200 [01:04<12:30,  4.08s/it]

Epoch: 16 , valid loss: 1.0893 , cos_loss :  0.9114 , classification_loss : 0.178
Accuracy: 0.9254 , AUC:  0.6353 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 17 , Total loss: 1.1003 , cos_loss :  0.9145 , classification_loss : 0.1858


  8%|▊         | 17/200 [01:08<12:17,  4.03s/it]

Epoch: 17 , valid loss: 1.0857 , cos_loss :  0.9113 , classification_loss : 0.1743
Accuracy: 0.9241 , AUC:  0.6749 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 18 , Total loss: 1.094 , cos_loss :  0.9144 , classification_loss : 0.1796


  9%|▉         | 18/200 [01:12<12:21,  4.08s/it]

Epoch: 18 , valid loss: 1.0697 , cos_loss :  0.9115 , classification_loss : 0.1582
Accuracy: 0.9254 , AUC:  0.7296 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 19 , Total loss: 1.0591 , cos_loss :  0.9156 , classification_loss : 0.1435


 10%|▉         | 19/200 [01:16<12:00,  3.98s/it]

Epoch: 19 , valid loss: 1.0359 , cos_loss :  0.9152 , classification_loss : 0.1207
Accuracy: 0.954 , AUC:  0.7827 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 20 , Total loss: 1.0457 , cos_loss :  0.9171 , classification_loss : 0.1286


 10%|█         | 20/200 [01:20<12:00,  4.00s/it]

Epoch: 20 , valid loss: 1.0356 , cos_loss :  0.9143 , classification_loss : 0.1214
Accuracy: 0.954 , AUC:  0.7919 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 21 , Total loss: 1.0407 , cos_loss :  0.9162 , classification_loss : 0.1245


 10%|█         | 21/200 [01:25<12:14,  4.10s/it]

Epoch: 21 , valid loss: 1.0363 , cos_loss :  0.9138 , classification_loss : 0.1225
Accuracy: 0.954 , AUC:  0.7878 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 22 , Total loss: 1.0403 , cos_loss :  0.9159 , classification_loss : 0.1244


 11%|█         | 22/200 [01:28<11:49,  3.98s/it]

Epoch: 22 , valid loss: 1.0329 , cos_loss :  0.9137 , classification_loss : 0.1192
Accuracy: 0.954 , AUC:  0.795 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 23 , Total loss: 1.0425 , cos_loss :  0.9159 , classification_loss : 0.1266


 12%|█▏        | 23/200 [01:32<11:29,  3.89s/it]

Epoch: 23 , valid loss: 1.0418 , cos_loss :  0.9142 , classification_loss : 0.1276
Accuracy: 0.954 , AUC:  0.7946 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 24 , Total loss: 1.0403 , cos_loss :  0.9155 , classification_loss : 0.1248


 12%|█▏        | 24/200 [01:36<11:44,  4.00s/it]

Epoch: 24 , valid loss: 1.0341 , cos_loss :  0.9139 , classification_loss : 0.1202
Accuracy: 0.954 , AUC:  0.7962 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 25 , Total loss: 1.0349 , cos_loss :  0.9156 , classification_loss : 0.1193


 12%|█▎        | 25/200 [01:40<11:24,  3.91s/it]

Epoch: 25 , valid loss: 1.0349 , cos_loss :  0.9143 , classification_loss : 0.1207
Accuracy: 0.954 , AUC:  0.8163 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 26 , Total loss: 1.0366 , cos_loss :  0.9159 , classification_loss : 0.1207


 13%|█▎        | 26/200 [01:44<11:07,  3.84s/it]

Epoch: 26 , valid loss: 1.0301 , cos_loss :  0.9139 , classification_loss : 0.1162
Accuracy: 0.9527 , AUC:  0.83 , F1:  0.5476 , Precision:  0.9583 , recall:  0.3833
Epoch: 27 , Total loss: 1.0334 , cos_loss :  0.9156 , classification_loss : 0.1178


 14%|█▎        | 27/200 [01:48<11:15,  3.90s/it]

Epoch: 27 , valid loss: 1.0279 , cos_loss :  0.9138 , classification_loss : 0.1141
Accuracy: 0.954 , AUC:  0.8388 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 28 , Total loss: 1.0318 , cos_loss :  0.9158 , classification_loss : 0.116


 14%|█▍        | 28/200 [01:51<11:07,  3.88s/it]

Epoch: 28 , valid loss: 1.0273 , cos_loss :  0.9138 , classification_loss : 0.1136
Accuracy: 0.954 , AUC:  0.8378 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 29 , Total loss: 1.0347 , cos_loss :  0.9155 , classification_loss : 0.1192


 14%|█▍        | 29/200 [01:55<10:54,  3.82s/it]

Epoch: 29 , valid loss: 1.0257 , cos_loss :  0.9139 , classification_loss : 0.1118
Accuracy: 0.954 , AUC:  0.8492 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 30 , Total loss: 1.0303 , cos_loss :  0.9158 , classification_loss : 0.1145


 15%|█▌        | 30/200 [01:59<11:03,  3.90s/it]

Epoch: 30 , valid loss: 1.0258 , cos_loss :  0.9141 , classification_loss : 0.1118
Accuracy: 0.954 , AUC:  0.8416 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 31 , Total loss: 1.0368 , cos_loss :  0.9156 , classification_loss : 0.1212


 16%|█▌        | 31/200 [02:03<11:15,  4.00s/it]

Epoch: 31 , valid loss: 1.0336 , cos_loss :  0.9134 , classification_loss : 0.1202
Accuracy: 0.9527 , AUC:  0.8391 , F1:  0.5581 , Precision:  0.9231 , recall:  0.4
Epoch: 32 , Total loss: 1.0393 , cos_loss :  0.9152 , classification_loss : 0.1241


 16%|█▌        | 32/200 [02:08<11:19,  4.05s/it]

Epoch: 32 , valid loss: 1.0302 , cos_loss :  0.9136 , classification_loss : 0.1166
Accuracy: 0.954 , AUC:  0.8259 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 33 , Total loss: 1.0331 , cos_loss :  0.9156 , classification_loss : 0.1176


 16%|█▋        | 33/200 [02:12<11:20,  4.08s/it]

Epoch: 33 , valid loss: 1.0302 , cos_loss :  0.9141 , classification_loss : 0.1161
Accuracy: 0.954 , AUC:  0.8317 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 34 , Total loss: 1.03 , cos_loss :  0.9156 , classification_loss : 0.1145


 17%|█▋        | 34/200 [02:16<11:24,  4.12s/it]

Epoch: 34 , valid loss: 1.029 , cos_loss :  0.9137 , classification_loss : 0.1153
Accuracy: 0.9527 , AUC:  0.8317 , F1:  0.5476 , Precision:  0.9583 , recall:  0.3833
Epoch: 35 , Total loss: 1.0287 , cos_loss :  0.9154 , classification_loss : 0.1133


 18%|█▊        | 35/200 [02:20<11:21,  4.13s/it]

Epoch: 35 , valid loss: 1.0291 , cos_loss :  0.9138 , classification_loss : 0.1153
Accuracy: 0.9552 , AUC:  0.8277 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 36 , Total loss: 1.0278 , cos_loss :  0.9155 , classification_loss : 0.1123


 18%|█▊        | 36/200 [02:24<11:03,  4.04s/it]

Epoch: 36 , valid loss: 1.0343 , cos_loss :  0.9136 , classification_loss : 0.1208
Accuracy: 0.954 , AUC:  0.833 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 37 , Total loss: 1.0271 , cos_loss :  0.9156 , classification_loss : 0.1115


 18%|█▊        | 37/200 [02:28<10:40,  3.93s/it]

Epoch: 37 , valid loss: 1.0298 , cos_loss :  0.9134 , classification_loss : 0.1164
Accuracy: 0.9552 , AUC:  0.8366 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 38 , Total loss: 1.029 , cos_loss :  0.9153 , classification_loss : 0.1136


 19%|█▉        | 38/200 [02:31<10:24,  3.85s/it]

Epoch: 38 , valid loss: 1.0294 , cos_loss :  0.9137 , classification_loss : 0.1157
Accuracy: 0.9552 , AUC:  0.8372 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 39 , Total loss: 1.0303 , cos_loss :  0.9156 , classification_loss : 0.1147


 20%|█▉        | 39/200 [02:35<10:34,  3.94s/it]

Epoch: 39 , valid loss: 1.0275 , cos_loss :  0.9136 , classification_loss : 0.1138
Accuracy: 0.9552 , AUC:  0.8402 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 40 , Total loss: 1.0297 , cos_loss :  0.9154 , classification_loss : 0.1142


 20%|██        | 40/200 [02:39<10:18,  3.87s/it]

Epoch: 40 , valid loss: 1.034 , cos_loss :  0.9138 , classification_loss : 0.1201
Accuracy: 0.954 , AUC:  0.8383 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 41 , Total loss: 1.0265 , cos_loss :  0.9158 , classification_loss : 0.1106


 20%|██        | 41/200 [02:43<10:11,  3.85s/it]

Epoch: 41 , valid loss: 1.0296 , cos_loss :  0.9136 , classification_loss : 0.1161
Accuracy: 0.9552 , AUC:  0.8338 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 42 , Total loss: 1.0255 , cos_loss :  0.9157 , classification_loss : 0.1097


 21%|██        | 42/200 [02:47<10:22,  3.94s/it]

Epoch: 42 , valid loss: 1.0275 , cos_loss :  0.9137 , classification_loss : 0.1138
Accuracy: 0.9577 , AUC:  0.8444 , F1:  0.6136 , Precision:  0.9643 , recall:  0.45
Epoch: 43 , Total loss: 1.0209 , cos_loss :  0.9155 , classification_loss : 0.1054


 22%|██▏       | 43/200 [02:51<10:11,  3.89s/it]

Epoch: 43 , valid loss: 1.0278 , cos_loss :  0.9135 , classification_loss : 0.1143
Accuracy: 0.9552 , AUC:  0.8379 , F1:  0.5714 , Precision:  1.0 , recall:  0.4
Epoch: 44 , Total loss: 1.0197 , cos_loss :  0.9153 , classification_loss : 0.1044


 22%|██▏       | 44/200 [02:55<10:03,  3.87s/it]

Epoch: 44 , valid loss: 1.0249 , cos_loss :  0.9141 , classification_loss : 0.1108
Accuracy: 0.9577 , AUC:  0.8399 , F1:  0.6136 , Precision:  0.9643 , recall:  0.45
Epoch: 45 , Total loss: 1.0196 , cos_loss :  0.9158 , classification_loss : 0.1038


 22%|██▎       | 45/200 [02:59<10:12,  3.95s/it]

Epoch: 45 , valid loss: 1.0236 , cos_loss :  0.9139 , classification_loss : 0.1097
Accuracy: 0.959 , AUC:  0.8513 , F1:  0.6207 , Precision:  1.0 , recall:  0.45
Epoch: 46 , Total loss: 1.0145 , cos_loss :  0.9162 , classification_loss : 0.0983


 23%|██▎       | 46/200 [03:03<10:05,  3.93s/it]

Epoch: 46 , valid loss: 1.0248 , cos_loss :  0.9145 , classification_loss : 0.1103
Accuracy: 0.9577 , AUC:  0.8486 , F1:  0.6136 , Precision:  0.9643 , recall:  0.45
Epoch: 47 , Total loss: 1.0182 , cos_loss :  0.9162 , classification_loss : 0.102


 24%|██▎       | 47/200 [03:07<09:56,  3.90s/it]

Epoch: 47 , valid loss: 1.0299 , cos_loss :  0.9144 , classification_loss : 0.1155
Accuracy: 0.9577 , AUC:  0.8395 , F1:  0.6047 , Precision:  1.0 , recall:  0.4333
Epoch: 48 , Total loss: 1.0137 , cos_loss :  0.9161 , classification_loss : 0.0977


 24%|██▍       | 48/200 [03:11<10:02,  3.96s/it]

Epoch: 48 , valid loss: 1.0224 , cos_loss :  0.9145 , classification_loss : 0.1078
Accuracy: 0.959 , AUC:  0.8596 , F1:  0.6207 , Precision:  1.0 , recall:  0.45
Epoch: 49 , Total loss: 1.0137 , cos_loss :  0.9155 , classification_loss : 0.0982


 24%|██▍       | 49/200 [03:15<10:09,  4.04s/it]

Epoch: 49 , valid loss: 1.0209 , cos_loss :  0.914 , classification_loss : 0.107
Accuracy: 0.959 , AUC:  0.8605 , F1:  0.6292 , Precision:  0.9655 , recall:  0.4667
Epoch: 50 , Total loss: 1.0126 , cos_loss :  0.9159 , classification_loss : 0.0967


 25%|██▌       | 50/200 [03:19<10:07,  4.05s/it]

Epoch: 50 , valid loss: 1.0236 , cos_loss :  0.9143 , classification_loss : 0.1093
Accuracy: 0.959 , AUC:  0.847 , F1:  0.6207 , Precision:  1.0 , recall:  0.45
Epoch: 51 , Total loss: 1.0149 , cos_loss :  0.9159 , classification_loss : 0.099


 26%|██▌       | 51/200 [03:23<09:48,  3.95s/it]

Epoch: 51 , valid loss: 1.0216 , cos_loss :  0.9137 , classification_loss : 0.1079
Accuracy: 0.9614 , AUC:  0.8534 , F1:  0.6517 , Precision:  1.0 , recall:  0.4833
Epoch: 52 , Total loss: 1.0117 , cos_loss :  0.9155 , classification_loss : 0.0962


 26%|██▌       | 52/200 [03:26<09:32,  3.87s/it]

Epoch: 52 , valid loss: 1.0234 , cos_loss :  0.9139 , classification_loss : 0.1094
Accuracy: 0.959 , AUC:  0.851 , F1:  0.6292 , Precision:  0.9655 , recall:  0.4667
Epoch: 53 , Total loss: 1.0109 , cos_loss :  0.9158 , classification_loss : 0.0951


 26%|██▋       | 53/200 [03:30<09:29,  3.88s/it]

Epoch: 53 , valid loss: 1.0294 , cos_loss :  0.9138 , classification_loss : 0.1156
Accuracy: 0.9565 , AUC:  0.8508 , F1:  0.5977 , Precision:  0.963 , recall:  0.4333
Epoch: 54 , Total loss: 1.0085 , cos_loss :  0.9157 , classification_loss : 0.0928


 27%|██▋       | 54/200 [03:34<09:28,  3.89s/it]

Epoch: 54 , valid loss: 1.0262 , cos_loss :  0.9141 , classification_loss : 0.1121
Accuracy: 0.959 , AUC:  0.8562 , F1:  0.6207 , Precision:  1.0 , recall:  0.45
Epoch: 55 , Total loss: 1.0079 , cos_loss :  0.9155 , classification_loss : 0.0924


 28%|██▊       | 55/200 [03:38<09:28,  3.92s/it]

Epoch: 55 , valid loss: 1.024 , cos_loss :  0.9139 , classification_loss : 0.1101
Accuracy: 0.9627 , AUC:  0.8483 , F1:  0.6667 , Precision:  1.0 , recall:  0.5
Epoch: 56 , Total loss: 1.0075 , cos_loss :  0.9155 , classification_loss : 0.0919


 28%|██▊       | 56/200 [03:42<09:19,  3.89s/it]

Epoch: 56 , valid loss: 1.0242 , cos_loss :  0.9138 , classification_loss : 0.1104
Accuracy: 0.9614 , AUC:  0.8505 , F1:  0.6593 , Precision:  0.9677 , recall:  0.5
Epoch: 57 , Total loss: 1.0106 , cos_loss :  0.9156 , classification_loss : 0.095


 28%|██▊       | 57/200 [03:46<09:08,  3.83s/it]

Epoch: 57 , valid loss: 1.0294 , cos_loss :  0.9139 , classification_loss : 0.1156
Accuracy: 0.9614 , AUC:  0.847 , F1:  0.6517 , Precision:  1.0 , recall:  0.4833
Epoch: 58 , Total loss: 1.0078 , cos_loss :  0.9152 , classification_loss : 0.0925


 29%|██▉       | 58/200 [03:49<09:03,  3.83s/it]

Epoch: 58 , valid loss: 1.0231 , cos_loss :  0.9139 , classification_loss : 0.1092
Accuracy: 0.9614 , AUC:  0.8581 , F1:  0.6517 , Precision:  1.0 , recall:  0.4833
Epoch: 59 , Total loss: 1.0087 , cos_loss :  0.9154 , classification_loss : 0.0933


 30%|██▉       | 59/200 [03:54<09:15,  3.94s/it]

Epoch: 59 , valid loss: 1.0239 , cos_loss :  0.914 , classification_loss : 0.11
Accuracy: 0.959 , AUC:  0.8519 , F1:  0.6526 , Precision:  0.8857 , recall:  0.5167
Epoch: 60 , Total loss: 1.0079 , cos_loss :  0.9156 , classification_loss : 0.0923


 30%|███       | 60/200 [03:58<09:23,  4.02s/it]

Epoch: 60 , valid loss: 1.0261 , cos_loss :  0.9136 , classification_loss : 0.1125
Accuracy: 0.9602 , AUC:  0.8543 , F1:  0.6364 , Precision:  1.0 , recall:  0.4667
Epoch: 61 , Total loss: 1.0133 , cos_loss :  0.9155 , classification_loss : 0.0978


 30%|███       | 61/200 [04:02<09:32,  4.12s/it]

Epoch: 61 , valid loss: 1.0232 , cos_loss :  0.9142 , classification_loss : 0.109
Accuracy: 0.9627 , AUC:  0.8491 , F1:  0.6739 , Precision:  0.9688 , recall:  0.5167
Epoch: 62 , Total loss: 1.0149 , cos_loss :  0.9158 , classification_loss : 0.0991


 31%|███       | 62/200 [04:06<09:24,  4.09s/it]

Epoch: 62 , valid loss: 1.0262 , cos_loss :  0.9136 , classification_loss : 0.1126
Accuracy: 0.9614 , AUC:  0.8463 , F1:  0.6517 , Precision:  1.0 , recall:  0.4833
Epoch: 63 , Total loss: 1.0119 , cos_loss :  0.9155 , classification_loss : 0.0963


 32%|███▏      | 63/200 [04:10<09:03,  3.97s/it]

Epoch: 63 , valid loss: 1.0233 , cos_loss :  0.9142 , classification_loss : 0.1091
Accuracy: 0.9602 , AUC:  0.866 , F1:  0.6364 , Precision:  1.0 , recall:  0.4667
Epoch: 64 , Total loss: 1.0139 , cos_loss :  0.9158 , classification_loss : 0.0981


 32%|███▏      | 64/200 [04:14<08:48,  3.89s/it]

Epoch: 64 , valid loss: 1.022 , cos_loss :  0.9137 , classification_loss : 0.1083
Accuracy: 0.9577 , AUC:  0.8592 , F1:  0.6304 , Precision:  0.9062 , recall:  0.4833
Epoch: 65 , Total loss: 1.0076 , cos_loss :  0.9155 , classification_loss : 0.092


 32%|███▎      | 65/200 [04:17<08:32,  3.79s/it]

Epoch: 65 , valid loss: 1.0288 , cos_loss :  0.9137 , classification_loss : 0.1151
Accuracy: 0.959 , AUC:  0.8469 , F1:  0.6207 , Precision:  1.0 , recall:  0.45
Epoch: 66 , Total loss: 1.0063 , cos_loss :  0.9156 , classification_loss : 0.0907


 33%|███▎      | 66/200 [04:21<08:19,  3.73s/it]

Epoch: 66 , valid loss: 1.0234 , cos_loss :  0.9138 , classification_loss : 0.1096
Accuracy: 0.9627 , AUC:  0.8593 , F1:  0.6667 , Precision:  1.0 , recall:  0.5
Epoch: 67 , Total loss: 1.0065 , cos_loss :  0.9158 , classification_loss : 0.0907


 34%|███▎      | 67/200 [04:24<08:14,  3.72s/it]

Epoch: 67 , valid loss: 1.0261 , cos_loss :  0.9134 , classification_loss : 0.1127
Accuracy: 0.9614 , AUC:  0.8441 , F1:  0.6517 , Precision:  1.0 , recall:  0.4833
Epoch: 68 , Total loss: 1.0062 , cos_loss :  0.9155 , classification_loss : 0.0907


 34%|███▍      | 68/200 [04:28<08:07,  3.70s/it]

Epoch: 68 , valid loss: 1.0242 , cos_loss :  0.9141 , classification_loss : 0.1101
Accuracy: 0.9602 , AUC:  0.8623 , F1:  0.6364 , Precision:  1.0 , recall:  0.4667
Epoch: 69 , Total loss: 1.0089 , cos_loss :  0.9157 , classification_loss : 0.0932


 34%|███▍      | 69/200 [04:32<08:04,  3.70s/it]

Epoch: 69 , valid loss: 1.0265 , cos_loss :  0.9136 , classification_loss : 0.1129
Accuracy: 0.9602 , AUC:  0.8507 , F1:  0.6596 , Precision:  0.9118 , recall:  0.5167
Epoch: 70 , Total loss: 1.0051 , cos_loss :  0.9157 , classification_loss : 0.0893


 35%|███▌      | 70/200 [04:35<07:55,  3.66s/it]

Epoch: 70 , valid loss: 1.0214 , cos_loss :  0.9139 , classification_loss : 0.1075
Accuracy: 0.9627 , AUC:  0.8661 , F1:  0.6667 , Precision:  1.0 , recall:  0.5
Epoch: 71 , Total loss: 1.0066 , cos_loss :  0.916 , classification_loss : 0.0907


 36%|███▌      | 71/200 [04:39<07:49,  3.64s/it]

Epoch: 71 , valid loss: 1.0186 , cos_loss :  0.9142 , classification_loss : 0.1044
Accuracy: 0.9639 , AUC:  0.878 , F1:  0.6813 , Precision:  1.0 , recall:  0.5167
Epoch: 72 , Total loss: 1.0036 , cos_loss :  0.9159 , classification_loss : 0.0877


 36%|███▌      | 72/200 [04:43<07:48,  3.66s/it]

Epoch: 72 , valid loss: 1.0182 , cos_loss :  0.9141 , classification_loss : 0.1042
Accuracy: 0.9639 , AUC:  0.8796 , F1:  0.6813 , Precision:  1.0 , recall:  0.5167
Epoch: 73 , Total loss: 1.0037 , cos_loss :  0.9158 , classification_loss : 0.0879


 36%|███▋      | 73/200 [04:46<07:41,  3.64s/it]

Epoch: 73 , valid loss: 1.0243 , cos_loss :  0.9136 , classification_loss : 0.1107
Accuracy: 0.9565 , AUC:  0.8639 , F1:  0.6465 , Precision:  0.8205 , recall:  0.5333
Epoch: 74 , Total loss: 1.0068 , cos_loss :  0.9154 , classification_loss : 0.0914


 37%|███▋      | 74/200 [04:50<07:36,  3.62s/it]

Epoch: 74 , valid loss: 1.0284 , cos_loss :  0.9137 , classification_loss : 0.1147
Accuracy: 0.9602 , AUC:  0.8579 , F1:  0.6444 , Precision:  0.9667 , recall:  0.4833
Epoch: 75 , Total loss: 1.0059 , cos_loss :  0.9156 , classification_loss : 0.0903


 38%|███▊      | 75/200 [04:54<07:35,  3.64s/it]

Epoch: 75 , valid loss: 1.0255 , cos_loss :  0.9139 , classification_loss : 0.1115
Accuracy: 0.9602 , AUC:  0.8542 , F1:  0.6444 , Precision:  0.9667 , recall:  0.4833
Epoch: 76 , Total loss: 1.0044 , cos_loss :  0.9154 , classification_loss : 0.0891


 38%|███▊      | 76/200 [04:57<07:29,  3.62s/it]

Epoch: 76 , valid loss: 1.0223 , cos_loss :  0.9137 , classification_loss : 0.1086
Accuracy: 0.9527 , AUC:  0.8675 , F1:  0.6042 , Precision:  0.8056 , recall:  0.4833
Epoch: 77 , Total loss: 1.0025 , cos_loss :  0.9158 , classification_loss : 0.0867


 38%|███▊      | 77/200 [05:01<07:23,  3.61s/it]

Epoch: 77 , valid loss: 1.0292 , cos_loss :  0.9138 , classification_loss : 0.1154
Accuracy: 0.959 , AUC:  0.8552 , F1:  0.6374 , Precision:  0.9355 , recall:  0.4833
Epoch: 78 , Total loss: 1.0087 , cos_loss :  0.9154 , classification_loss : 0.0933


 39%|███▉      | 78/200 [05:04<07:22,  3.63s/it]

Epoch: 78 , valid loss: 1.0237 , cos_loss :  0.9133 , classification_loss : 0.1105
Accuracy: 0.9577 , AUC:  0.8573 , F1:  0.6304 , Precision:  0.9062 , recall:  0.4833
Epoch: 79 , Total loss: 1.0133 , cos_loss :  0.9157 , classification_loss : 0.0977


 40%|███▉      | 79/200 [05:08<07:18,  3.62s/it]

Epoch: 79 , valid loss: 1.027 , cos_loss :  0.9134 , classification_loss : 0.1136
Accuracy: 0.9552 , AUC:  0.8503 , F1:  0.625 , Precision:  0.8333 , recall:  0.5
Epoch: 80 , Total loss: 1.0196 , cos_loss :  0.9153 , classification_loss : 0.1043


 40%|████      | 80/200 [05:12<07:18,  3.65s/it]

Epoch: 80 , valid loss: 1.105 , cos_loss :  0.9126 , classification_loss : 0.1925
Accuracy: 0.9229 , AUC:  0.6864 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 81 , Total loss: 1.0494 , cos_loss :  0.9155 , classification_loss : 0.1339


 40%|████      | 81/200 [05:15<07:12,  3.63s/it]

Epoch: 81 , valid loss: 1.0415 , cos_loss :  0.9131 , classification_loss : 0.1284
Accuracy: 0.954 , AUC:  0.8322 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 82 , Total loss: 1.0348 , cos_loss :  0.9157 , classification_loss : 0.1191


 41%|████      | 82/200 [05:19<07:07,  3.62s/it]

Epoch: 82 , valid loss: 1.0319 , cos_loss :  0.9134 , classification_loss : 0.1184
Accuracy: 0.9502 , AUC:  0.847 , F1:  0.5349 , Precision:  0.8846 , recall:  0.3833
Epoch: 83 , Total loss: 1.0338 , cos_loss :  0.9156 , classification_loss : 0.1182


 42%|████▏     | 83/200 [05:23<07:07,  3.65s/it]

Epoch: 83 , valid loss: 1.0375 , cos_loss :  0.9128 , classification_loss : 0.1248
Accuracy: 0.954 , AUC:  0.8257 , F1:  0.5542 , Precision:  1.0 , recall:  0.3833
Epoch: 84 , Total loss: 1.0301 , cos_loss :  0.9151 , classification_loss : 0.115


 42%|████▏     | 84/200 [05:26<07:01,  3.64s/it]

Epoch: 84 , valid loss: 1.0344 , cos_loss :  0.9128 , classification_loss : 0.1216
Accuracy: 0.9515 , AUC:  0.8378 , F1:  0.5412 , Precision:  0.92 , recall:  0.3833
Epoch: 85 , Total loss: 1.0218 , cos_loss :  0.9151 , classification_loss : 0.1067


 42%|████▎     | 85/200 [05:30<06:56,  3.63s/it]

Epoch: 85 , valid loss: 1.0365 , cos_loss :  0.9127 , classification_loss : 0.1238
Accuracy: 0.9527 , AUC:  0.8357 , F1:  0.5476 , Precision:  0.9583 , recall:  0.3833
Epoch: 86 , Total loss: 1.0229 , cos_loss :  0.9153 , classification_loss : 0.1075


 43%|████▎     | 86/200 [05:34<06:57,  3.67s/it]

Epoch: 86 , valid loss: 1.035 , cos_loss :  0.9125 , classification_loss : 0.1225
Accuracy: 0.9478 , AUC:  0.839 , F1:  0.5227 , Precision:  0.8214 , recall:  0.3833
Epoch: 87 , Total loss: 1.0193 , cos_loss :  0.915 , classification_loss : 0.1044


 44%|████▎     | 87/200 [05:37<06:53,  3.66s/it]

Epoch: 87 , valid loss: 1.0316 , cos_loss :  0.9127 , classification_loss : 0.1189
Accuracy: 0.949 , AUC:  0.852 , F1:  0.5287 , Precision:  0.8519 , recall:  0.3833
Epoch: 88 , Total loss: 1.0159 , cos_loss :  0.9152 , classification_loss : 0.1007


 44%|████▍     | 88/200 [05:41<06:46,  3.63s/it]

Epoch: 88 , valid loss: 1.0343 , cos_loss :  0.9126 , classification_loss : 0.1217
Accuracy: 0.9515 , AUC:  0.8437 , F1:  0.5412 , Precision:  0.92 , recall:  0.3833
Epoch: 89 , Total loss: 1.0224 , cos_loss :  0.9151 , classification_loss : 0.1072


 44%|████▍     | 89/200 [05:44<06:46,  3.66s/it]

Epoch: 89 , valid loss: 1.0393 , cos_loss :  0.9127 , classification_loss : 0.1266
Accuracy: 0.9415 , AUC:  0.8474 , F1:  0.4946 , Precision:  0.697 , recall:  0.3833
Epoch: 90 , Total loss: 1.0208 , cos_loss :  0.9153 , classification_loss : 0.1055


 45%|████▌     | 90/200 [05:48<06:39,  3.64s/it]

Epoch: 90 , valid loss: 1.0341 , cos_loss :  0.9127 , classification_loss : 0.1214
Accuracy: 0.9527 , AUC:  0.8406 , F1:  0.5476 , Precision:  0.9583 , recall:  0.3833
Epoch: 91 , Total loss: 1.0217 , cos_loss :  0.9151 , classification_loss : 0.1066


 46%|████▌     | 91/200 [05:52<06:34,  3.62s/it]

Epoch: 91 , valid loss: 1.0342 , cos_loss :  0.9126 , classification_loss : 0.1216
Accuracy: 0.949 , AUC:  0.8466 , F1:  0.5393 , Precision:  0.8276 , recall:  0.4
Epoch: 92 , Total loss: 1.0189 , cos_loss :  0.9151 , classification_loss : 0.1039


 46%|████▌     | 92/200 [05:55<06:33,  3.64s/it]

Epoch: 92 , valid loss: 1.0322 , cos_loss :  0.9126 , classification_loss : 0.1196
Accuracy: 0.9502 , AUC:  0.8521 , F1:  0.5349 , Precision:  0.8846 , recall:  0.3833
Epoch: 93 , Total loss: 1.0132 , cos_loss :  0.9148 , classification_loss : 0.0983


 46%|████▋     | 93/200 [05:59<06:28,  3.63s/it]

Epoch: 93 , valid loss: 1.0393 , cos_loss :  0.9126 , classification_loss : 0.1267
Accuracy: 0.9502 , AUC:  0.8447 , F1:  0.5349 , Precision:  0.8846 , recall:  0.3833
Epoch: 94 , Total loss: 1.0227 , cos_loss :  0.9149 , classification_loss : 0.1077


 47%|████▋     | 94/200 [06:03<06:26,  3.65s/it]

Epoch: 94 , valid loss: 1.0399 , cos_loss :  0.9125 , classification_loss : 0.1274
Accuracy: 0.9527 , AUC:  0.8405 , F1:  0.5476 , Precision:  0.9583 , recall:  0.3833
Epoch: 95 , Total loss: 1.0203 , cos_loss :  0.9149 , classification_loss : 0.1054


 48%|████▊     | 95/200 [06:06<06:20,  3.63s/it]

Epoch: 95 , valid loss: 1.0347 , cos_loss :  0.9124 , classification_loss : 0.1222
Accuracy: 0.944 , AUC:  0.8576 , F1:  0.5055 , Precision:  0.7419 , recall:  0.3833
Epoch: 96 , Total loss: 1.0188 , cos_loss :  0.9149 , classification_loss : 0.1039


 48%|████▊     | 96/200 [06:10<06:16,  3.62s/it]

Epoch: 96 , valid loss: 1.038 , cos_loss :  0.9126 , classification_loss : 0.1254
Accuracy: 0.9478 , AUC:  0.8389 , F1:  0.5227 , Precision:  0.8214 , recall:  0.3833
Epoch: 97 , Total loss: 1.0185 , cos_loss :  0.9152 , classification_loss : 0.1033


 48%|████▊     | 97/200 [06:14<06:15,  3.64s/it]

Epoch: 97 , valid loss: 1.0363 , cos_loss :  0.9127 , classification_loss : 0.1236
Accuracy: 0.9478 , AUC:  0.8405 , F1:  0.5227 , Precision:  0.8214 , recall:  0.3833
Epoch: 98 , Total loss: 1.016 , cos_loss :  0.9151 , classification_loss : 0.1009


 49%|████▉     | 98/200 [06:17<06:10,  3.63s/it]

Epoch: 98 , valid loss: 1.0326 , cos_loss :  0.9127 , classification_loss : 0.12
Accuracy: 0.9515 , AUC:  0.8538 , F1:  0.5517 , Precision:  0.8889 , recall:  0.4
Epoch: 99 , Total loss: 1.0135 , cos_loss :  0.9151 , classification_loss : 0.0984


 50%|████▉     | 99/200 [06:21<06:05,  3.62s/it]

Epoch: 99 , valid loss: 1.031 , cos_loss :  0.913 , classification_loss : 0.1179
Accuracy: 0.9502 , AUC:  0.8576 , F1:  0.5455 , Precision:  0.8571 , recall:  0.4
Epoch: 100 , Total loss: 1.0111 , cos_loss :  0.9153 , classification_loss : 0.0959


 50%|█████     | 100/200 [06:24<06:04,  3.65s/it]

Epoch: 100 , valid loss: 1.0289 , cos_loss :  0.9129 , classification_loss : 0.116
Accuracy: 0.9527 , AUC:  0.845 , F1:  0.5778 , Precision:  0.8667 , recall:  0.4333
Epoch: 101 , Total loss: 1.0086 , cos_loss :  0.9153 , classification_loss : 0.0932


 50%|█████     | 100/200 [06:28<06:28,  3.89s/it]

Epoch: 101 , valid loss: 1.0236 , cos_loss :  0.9131 , classification_loss : 0.1105
Accuracy: 0.954 , AUC:  0.8698 , F1:  0.5747 , Precision:  0.9259 , recall:  0.4167
Early stopping triggered at epoch 70
Best Combined Score (AUC): 0.6813





In [43]:
print(model_1_path)

results/20240424/model_lr0.001_classi1_dim64_hid256_layer8_epoch78_no_center_total_label_time_feature_0.5_141837.pth


In [44]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_const_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path

model_1.load_state_dict(torch.load(model_1_path))

model_1.eval()

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)

        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        te_loss = (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1) 

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)

100%|██████████| 2/2 [00:00<00:00,  5.37it/s]


In [45]:
# 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), 
      # ", contrastive_loss : ", round(avg_te_const_loss,4)
     )
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

test loss: 1.026 , cos_loss :  0.914 , classification_loss : 0.112
Accuracy: 0.9602 , AUC:  0.8951 , F1:  0.6923 , Precision:  0.9474 , recall:  0.5455


(array([0, 1]), array([738,  66]))

In [46]:
confusion_matrix(te_labels_list, te_predictions_list)

array([[736,   2],
       [ 30,  36]])

In [81]:
print(model_1_path)

results/20240424/model_lr0.001_classi1_dim64_hid256_layer5_epoch35_no_center_total_label_time_feature_0.5_140528.pth


In [85]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_const_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path

model_1.load_state_dict(torch.load(model_1_path))

model_1.eval()

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)

        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        te_loss = (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1) 

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)

100%|██████████| 2/2 [00:00<00:00,  5.32it/s]


In [86]:
# 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), 
      # ", contrastive_loss : ", round(avg_te_const_loss,4)
     )
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

test loss: 1.0313 , cos_loss :  0.9141 , classification_loss : 0.1172
Accuracy: 0.9614 , AUC:  0.8807 , F1:  0.7156 , Precision:  1.0 , recall:  0.5571


(array([0, 1]), array([734,  70]))

In [87]:
log_path = './logs'
logs = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{model_time}.txt'
os.makedirs(os.path.join(log_path, date_dir), exist_ok=True)

results = [accuracy, auc, f1, precision, recall]
cm = confusion_matrix(te_labels_list, te_predictions_list)
with open(os.path.join(log_path, date_dir, logs.format(epoch=best_epoch, model=model_name, pe=pe)), 'w') as f:
    f.write(logs.format(epoch=best_epoch, model=model_name, pe=pe))
    f.write('\n')
    f.write(str(results))
    f.write('\n')
    f.write(str(cm))

In [88]:
confusion_matrix(te_labels_list, te_predictions_list)

array([[734,   0],
       [ 31,  39]])