In [1]:
import numpy as np
import pandas as pd
import os 
import pickle
import sys
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
import warnings

torch.set_printoptions(profile="full")
np.set_printoptions(threshold=sys.maxsize)
warnings.filterwarnings("ignore")

In [2]:
from src.attn import FixedPositionalEncoding, LearnablePositionalEncoding, TemporalEmbedding, MultiheadAttention, Decoder
from src.loss import ContrastiveLoss, FocalLoss

In [3]:
from datetime import datetime

In [4]:
seed = 777
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(seed)

In [5]:
path = '/data/notebook/shared/MIMIC-IV'
with open(os.path.join(path, 'dict_types_nomedi_mimic_240408_clinic_3_years.pkl'), 'rb') as f:
    dtype_dict = pickle.load(f)
f.close()

with open(os.path.join('./data', 'preprocessed_nomedi_240423_clinic_3_years.pkl'), 'rb') as f:
    data_dict_d = pickle.load(f)
f.close()

remain_year_clinical = pd.read_csv(os.path.join(path, "240408_remain_year_clinical_diag_seq_4.csv"))
del remain_year_clinical['clinical_label']

In [6]:
labels = []
code_labels = []
length_list = []
clinical_labels = []
code_length_list = []
for sample_id, visits in tqdm(data_dict_d.items()):
    # 레이블 추가
    label = visits['label']
    code_label = visits['code_label']
    clinical_label = visits['clinical_label']
    labels.append(label)
    code_labels.append(code_label)
    clinical_labels.append(clinical_labels)

100%|██████████| 8037/8037 [00:00<00:00, 603055.95it/s]


In [7]:
len(list(data_dict_d.keys()))

8037

In [8]:
len(labels)

8037

In [9]:
train_indices, test_indices, train_y, test_y = train_test_split(list(data_dict_d.keys()), labels, test_size=0.1, random_state=777, stratify=labels)
train_indices, valid_indices, valid_y, valid_y = train_test_split(train_indices, train_y, test_size=(len(test_indices)/len(train_indices)), random_state=777, stratify=train_y) 

In [10]:
train_data = {}
valid_data = {}
test_data = {}
for sample in tqdm(train_indices):
    train_data[sample] = data_dict_d[sample]

for sample in tqdm(valid_indices):
    valid_data[sample] = data_dict_d[sample]

for sample in tqdm(test_indices):
    test_data[sample] = data_dict_d[sample]

100%|██████████| 6429/6429 [00:00<00:00, 1562383.71it/s]
100%|██████████| 804/804 [00:00<00:00, 1098085.45it/s]
100%|██████████| 804/804 [00:00<00:00, 757461.91it/s]


In [11]:
train_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(train_indices)].reset_index(drop=True)
valid_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(valid_indices)].reset_index(drop=True)
test_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(test_indices)].reset_index(drop=True)

In [12]:
class CustomDataset(Dataset):
    def __init__(self, data, clinical_df, scaler, mode='train'):       
        self.keys = list(data.keys())  # 딕셔너리의 키 목록 저장
        self.data = data  # 딕셔너리에서 데이터만 추출하여 저장
        self.scaler = scaler
        if mode == 'train':
            scaled_data = self.scaler.fit_transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)
        else:
            scaled_data = self.scaler.transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # x는 현재 방문까지의 모든 방문 데이터, y는 다음 방문 데이터
        padding_temp = torch.zeros((1, len(self.data[self.keys[idx]]['code_index'][0])), dtype=torch.long)
        
        origin_visit = torch.tensor(self.data[self.keys[idx]]['code_index'], dtype=torch.long)
        origin_mask = torch.tensor(self.data[self.keys[idx]]['seq_mask'], dtype=torch.long)
        origin_mask_final = torch.tensor(self.data[self.keys[idx]]['seq_mask_final'], dtype=torch.long)
        origin_mask_code = torch.tensor(self.data[self.keys[idx]]['seq_mask_code'], dtype=torch.long)
        
        next_visit = torch.cat((origin_visit[1:], padding_temp), dim=0)
        next_mask =torch.cat((origin_mask[1:], torch.tensor([0], dtype=torch.long)), dim=0)
        next_mask_code = torch.cat((origin_mask_code[1:], padding_temp), dim=0)
        
        clinical_data = self.scaled_clinical_df.loc[self.scaled_clinical_df['subject_id'] == self.keys[idx], self.scaled_clinical_df.columns[2:]].values
        clinical_data = torch.tensor(clinical_data, dtype=torch.float)
        visit_index = torch.tensor(self.data[self.keys[idx]]['year_onehot'], dtype=torch.long)
        last_visit_index = torch.tensor(self.data[self.keys[idx]]['last_year_onehot'], dtype=torch.float)
        time_feature = torch.tensor(self.data[self.keys[idx]]['time_feature'], dtype=torch.long)
        label_per_sample = self.data[self.keys[idx]]['label']
        key_per_sample = self.keys[idx]  # 해당 샘플의 키
        # 키 값도 함께 반환
        return {'sample_id': key_per_sample, 'origin_visit': origin_visit, 'next_visit': next_visit,\
                'origin_mask': origin_mask, 'origin_mask_final': origin_mask_final, 'next_mask': next_mask, \
                'origin_mask_code': origin_mask_code, 'next_mask_code': next_mask_code, 'clinical_data': clinical_data, \
                'visit_index': visit_index, 'last_visit_index': last_visit_index, 'time_feature': time_feature, \
                'label': label_per_sample}

In [13]:
scaler = StandardScaler()
train_dataset = CustomDataset(train_data, train_clinical, scaler, mode='train')
valid_dataset = CustomDataset(valid_data, valid_clinical, scaler, mode='valid')
test_dataset = CustomDataset(test_data, test_clinical, scaler, mode='test')

In [14]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=512, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

In [15]:
class CenterLoss(nn.Module):
    def __init__(self, num_classes=2, feat_dim=128, device='cuda:0'):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.device = device
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).to(self.device))
        
    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()

        distmat.addmm_(x, self.centers.t(), beta = 1, alpha = -2,)

        classes = torch.arange(self.num_classes).long().to(self.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss

In [32]:
class CustomTransformerModel(nn.Module):
    def __init__(self, code_size, ninp, nhead, nhid, nlayers, dropout=0.5, device='cuda:0', pe='fixed'):
        super(CustomTransformerModel, self).__init__()
        self.ninp = ninp  # 축소된 차원 및 Transformer 입력 차원
        self.device = device
        # 차원 축소를 위한 선형 레이어
        self.pre_embedding = nn.Embedding(code_size, self.ninp)
        self.pe = pe
        if self.pe == 'fixed':
            self.pos_encoder = FixedPositionalEncoding(ninp, dropout)
        elif self.pe == 'time_feature':
            self.pos_encoder = TemporalEmbedding(ninp, embed_type='embed', dropout=dropout)
        else:
            self.pos_encoder = LearnablePositionalEncoding(ninp, dropout)
            
        self.transformer_decoder = Decoder(ninp, nhead, nhid, nlayers, dropout)
        self.decoder = nn.Linear(ninp, ninp, bias=False)  # 최종 출력 차원을 설정 (여기서는 ninp로 설정)
        self.classification_layer = nn.Linear(ninp * 2, 2)
        self.clinical_transform = nn.Linear(18, ninp) 
        self.cross_attn = MultiheadAttention(ninp, nhead, dropout=0.3)
        self.init_weights()
        
    def init_weights(self):
        nn.init.xavier_uniform_(self.pre_embedding.weight)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.xavier_uniform_(self.classification_layer.weight)
    
    def forward(self, batch_data):
        # 차원 축소
        origin_visit = batch_data['origin_visit'].to(self.device)
        next_visit = batch_data['next_visit'].to(self.device)
        clinical_tensor = batch_data['clinical_data'].to(self.device)
        mask_code = batch_data['origin_mask_code'].unsqueeze(3).to(self.device)
        next_mask_code = batch_data['next_mask_code'].unsqueeze(3).to(self.device)
        mask_final = batch_data['origin_mask_final'].unsqueeze(2).to(self.device)
        mask_final_year = batch_data['last_visit_index'].to(self.device)

                
        origin_visit_emb = (self.pre_embedding(origin_visit) * mask_code).sum(dim=2)
        next_visit_emb = (self.pre_embedding(next_visit) * next_mask_code).sum(dim=2)
        # 위치 인코딩 및 Transformer 디코더 적용
        if self.pe == 'time_feature':
            time_feature = batch_data['time_feature'].to(self.device)
            src = self.pos_encoder(origin_visit_emb, time_feature)
        else:
            src = self.pos_encoder(origin_visit_emb)
        
        seq_len = src.shape[1]
        src_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=self.device), diagonal=1)
        output = self.transformer_decoder(src, attention_mask=src_mask)
        output = self.decoder(output)
        
        output_batch, output_visit_num, output_dim = output.size()
        next_visit_output_batch, next_visit_output_num, next_visit_output_dim = next_visit_emb.size()
        
        # print(output[:, :torch.,:])
        temp_output = output.reshape(-1, output_dim)
        temp_next_visit = next_visit_emb.reshape(-1, next_visit_output_dim)
        final_visit = (output * mask_final).sum(dim=1)
        year_emb = torch.bmm(output.transpose(1,2), mask_final_year).transpose(1,2)[:,:2,:]
        
        transformed_clinical = self.clinical_transform(clinical_tensor)
        mixed_output, mixed_cross_attn = self.cross_attn(year_emb, transformed_clinical, transformed_clinical)
        mixed_output = mixed_output[:,-1,:]

        mixed_final_emb = torch.cat((final_visit, mixed_output), dim=-1)
        classification_result = self.classification_layer(mixed_final_emb)    
        return temp_output, temp_next_visit, classification_result, final_visit

In [52]:
lr = 0.001
ninp = 64
nhid = 256
nlayer = 6
model_name = 'with_center_total_label'
pe = 'fixed'
gamma = 0.5

In [53]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_1 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.1, device=device, pe=pe).to(device)
# model_2 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.6, device=device, pe='fixed').to(device)

optimizer = torch.optim.Adam(model_1.parameters(), lr=lr)

# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(2, gamma=gamma)
# cosine_loss = CosineSimilarityLoss()
cosine_embedding_loss = nn.CosineEmbeddingLoss()
# const_loss = ContrastiveLoss(temperature=0.05)
center_loss = CenterLoss(num_classes=2, feat_dim=2, device=device)

In [54]:
# 얼리 스타핑 설정
num_epochs = 100
patience = 40
best_loss = float("inf")
counter = 0
epoch_temp = 0

best_f1 = 0.0
best_combined_score = 0.0
best_epoch = 0

view_total_loss = []
view_cos_loss = []
view_classi_loss = []

cos_lambda = 1
classi_lambda = 1
center_lambda = 0.3

for epoch in tqdm(range(num_epochs)):
    model_1.train()

    total_train_loss = 0
    total_cos_loss = 0
    total_classi_loss = 0
    total_center_loss = 0
    
    for batch_data in train_loader:
        tr_labels = batch_data['label'].to(device)
        optimizer.zero_grad()
        output_1, next_visit_output_1, final_visit_classification_1, final_visit_1 = model_1(batch_data)
        y = torch.ones(output_1.size(0), dtype=torch.float, device=device)
        cosine_loss_mean_1 = cosine_embedding_loss(output_1, next_visit_output_1, y.to(device))
        classification_loss_1 = criterion(final_visit_classification_1.squeeze(), tr_labels.long())
        center_loss_1 = center_loss(final_visit_classification_1, tr_labels.long())
        loss = (cos_lambda * cosine_loss_mean_1) + (classi_lambda * classification_loss_1) + (center_lambda * center_loss_1)

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        total_cos_loss += (cosine_loss_mean_1.item())
        total_classi_loss += (classification_loss_1.item())
        total_center_loss += (center_loss_1.item())

    # 평균 손실 계산
    avg_train_loss = total_train_loss / len(train_loader)
    avg_cos_loss = total_cos_loss / len(train_loader)
    avg_classi_loss = total_classi_loss / len(train_loader)
    avg_center_loss = total_center_loss / len(train_loader)

    print("Epoch:", epoch+1, ", Total loss:", round(avg_train_loss, 4), ", cos_loss : ",\
          round(avg_cos_loss, 4), ", classification_loss :",round(avg_classi_loss, 4), 
           ", center_loss :",round(avg_center_loss, 4))

    model_1.eval()
    total_val_loss = 0
    total_val_cos_loss = 0
    total_val_classi_loss = 0
    total_val_center_loss = 0

    val_labels_list = []
    val_predictions_list = []
    val_probabilities_list = []


    with torch.no_grad():
        for batch_data in valid_loader:
            val_labels = batch_data['label'].to(device)
            val_output_1, val_next_visit_output_1, val_final_visit_classification_1, val_final_visit_1 = model_1(batch_data)
            y_val = torch.ones(val_output_1.size(0), dtype=torch.float, device=device) 
            val_cosine_loss_mean_1 = cosine_embedding_loss(val_output_1, val_next_visit_output_1, y_val.to(device))            
            classification_loss_val_1 = criterion(val_final_visit_classification_1.squeeze(), val_labels.long())
            center_loss_val_1 = center_loss(val_final_visit_classification_1, val_labels.long())
            val_loss = (cos_lambda * val_cosine_loss_mean_1) + (classi_lambda * classification_loss_val_1) + (center_lambda * center_loss_val_1)

            total_val_loss += val_loss.item()
            total_val_cos_loss += (val_cosine_loss_mean_1.item())
            total_val_classi_loss += (classification_loss_val_1.item())
            total_val_center_loss += (center_loss_val_1.item())

            val_probs = F.softmax(val_final_visit_classification_1)
            val_predictions = torch.max(val_probs, 1)[1].view((len(val_labels),))

            val_labels_list.extend(val_labels.view(-1).cpu().numpy())
            val_predictions_list.extend(val_predictions.cpu().numpy())
            
            ########
            val_probabilities_list.extend(val_probs[:,1].cpu().numpy())
            ######## 
            
        avg_val_loss = total_val_loss / len(valid_loader)
        avg_val_cos_loss = total_val_cos_loss / len(valid_loader)
        avg_val_classi_loss = total_val_classi_loss / len(valid_loader)
        avg_val_center_loss = total_val_center_loss / len(valid_loader)

    # 성능 지표 계산
    accuracy = accuracy_score(val_labels_list, val_predictions_list)
    auc = roc_auc_score(val_labels_list, val_probabilities_list)
    f1 = f1_score(val_labels_list, val_predictions_list)
    precision = precision_score(val_labels_list,val_predictions_list)
    recall = recall_score(val_labels_list, val_predictions_list)

    print("Epoch:", epoch+1, ", valid loss:", round(avg_val_loss, 4), ", cos_loss : ",\
          round(avg_val_cos_loss, 4), ", classification_loss :",round(avg_val_classi_loss, 4),\
          ", center_loss :",round(avg_val_center_loss, 4))
    print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

    current_auc = roc_auc_score(val_labels_list, val_probabilities_list)
    current_f1 = f1_score(val_labels_list, val_predictions_list)
    current_combined_score = current_f1

    # 모델 저장 경로 설정
    model_save_path = f'results'
    date_dir = datetime.today().strftime("%Y%m%d")
    model_time =  datetime.today().strftime("%H%M%S")
    # 학습률과 분류 가중치를 파일명에 포함시키기 위한 문자열 포맷
    model_filename_format = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{{gamma}}_{model_time}.pth'


    # 모델 저장 폴더가 없으면 생성
    os.makedirs(os.path.join(model_save_path, date_dir), exist_ok=True)

    if current_combined_score > best_combined_score:
        best_combined_score = current_combined_score
        best_epoch = epoch
        counter = 0
        # 모델 저장 경로와 파일명을 결합하여 전체 파일 경로 생성
        model_1_save_path = os.path.join(model_save_path, date_dir, model_filename_format.format(epoch=best_epoch, model=model_name, pe=pe, gamma=gamma))
        # 모델 저장
        torch.save(model_1.state_dict(), f'{model_1_save_path}')
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered at epoch {}".format(best_epoch))
        print(f"Best Combined Score (AUC): {best_combined_score:.4f}")
        break

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1 , Total loss: 3.6848 , cos_loss :  0.9789 , classification_loss : 0.6514 , center_loss : 6.848


  1%|          | 1/100 [00:03<06:04,  3.68s/it]

Epoch: 1 , valid loss: 1.5886 , cos_loss :  0.9554 , classification_loss : 0.5614 , center_loss : 0.2391
Accuracy: 0.408 , AUC:  0.5097 , F1:  0.147 , Precision:  0.084 , recall:  0.5857
Epoch: 2 , Total loss: 1.496 , cos_loss :  0.9445 , classification_loss : 0.364 , center_loss : 0.6247


  2%|▏         | 2/100 [00:07<06:12,  3.80s/it]

Epoch: 2 , valid loss: 1.389 , cos_loss :  0.9324 , classification_loss : 0.3979 , center_loss : 0.1955
Accuracy: 0.8333 , AUC:  0.5434 , F1:  0.0822 , Precision:  0.0789 , recall:  0.0857
Epoch: 3 , Total loss: 1.405 , cos_loss :  0.9291 , classification_loss : 0.3312 , center_loss : 0.4823


  3%|▎         | 3/100 [00:11<06:03,  3.75s/it]

Epoch: 3 , valid loss: 1.3392 , cos_loss :  0.9234 , classification_loss : 0.3371 , center_loss : 0.2624
Accuracy: 0.9129 , AUC:  0.5738 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 4 , Total loss: 1.3748 , cos_loss :  0.9229 , classification_loss : 0.3241 , center_loss : 0.4262


  4%|▍         | 4/100 [00:14<05:53,  3.69s/it]

Epoch: 4 , valid loss: 1.3291 , cos_loss :  0.9195 , classification_loss : 0.3162 , center_loss : 0.3115
Accuracy: 0.9129 , AUC:  0.5174 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 5 , Total loss: 1.3628 , cos_loss :  0.9201 , classification_loss : 0.3194 , center_loss : 0.4109


  5%|▌         | 5/100 [00:18<05:40,  3.58s/it]

Epoch: 5 , valid loss: 1.3265 , cos_loss :  0.9176 , classification_loss : 0.3235 , center_loss : 0.2846
Accuracy: 0.9129 , AUC:  0.576 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 6 , Total loss: 1.3549 , cos_loss :  0.9187 , classification_loss : 0.3183 , center_loss : 0.3933


  6%|▌         | 6/100 [00:21<05:31,  3.53s/it]

Epoch: 6 , valid loss: 1.323 , cos_loss :  0.9165 , classification_loss : 0.3165 , center_loss : 0.3003
Accuracy: 0.9129 , AUC:  0.6008 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 7 , Total loss: 1.3519 , cos_loss :  0.9177 , classification_loss : 0.3201 , center_loss : 0.3802


  7%|▋         | 7/100 [00:25<05:24,  3.49s/it]

Epoch: 7 , valid loss: 1.3222 , cos_loss :  0.9157 , classification_loss : 0.3067 , center_loss : 0.3324
Accuracy: 0.9129 , AUC:  0.5866 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 8 , Total loss: 1.3474 , cos_loss :  0.9171 , classification_loss : 0.3156 , center_loss : 0.3822


  8%|▊         | 8/100 [00:28<05:18,  3.47s/it]

Epoch: 8 , valid loss: 1.3208 , cos_loss :  0.9152 , classification_loss : 0.3136 , center_loss : 0.3067
Accuracy: 0.9129 , AUC:  0.6249 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 9 , Total loss: 1.3433 , cos_loss :  0.9166 , classification_loss : 0.3164 , center_loss : 0.3677


  9%|▉         | 9/100 [00:32<05:19,  3.51s/it]

Epoch: 9 , valid loss: 1.3206 , cos_loss :  0.9148 , classification_loss : 0.3094 , center_loss : 0.3214
Accuracy: 0.9129 , AUC:  0.5914 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 10 , Total loss: 1.3411 , cos_loss :  0.9162 , classification_loss : 0.3154 , center_loss : 0.3649


 10%|█         | 10/100 [00:35<05:13,  3.48s/it]

Epoch: 10 , valid loss: 1.3197 , cos_loss :  0.9144 , classification_loss : 0.3062 , center_loss : 0.3301
Accuracy: 0.9129 , AUC:  0.6135 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 11 , Total loss: 1.3388 , cos_loss :  0.9161 , classification_loss : 0.3133 , center_loss : 0.3647


 11%|█         | 11/100 [00:38<05:07,  3.46s/it]

Epoch: 11 , valid loss: 1.3197 , cos_loss :  0.9142 , classification_loss : 0.3095 , center_loss : 0.3201
Accuracy: 0.9129 , AUC:  0.5912 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 12 , Total loss: 1.3368 , cos_loss :  0.9157 , classification_loss : 0.315 , center_loss : 0.3538


 12%|█▏        | 12/100 [00:42<05:02,  3.44s/it]

Epoch: 12 , valid loss: 1.3191 , cos_loss :  0.9139 , classification_loss : 0.3049 , center_loss : 0.3342
Accuracy: 0.9129 , AUC:  0.6137 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 13 , Total loss: 1.3356 , cos_loss :  0.9154 , classification_loss : 0.3135 , center_loss : 0.3559


 13%|█▎        | 13/100 [00:45<04:58,  3.43s/it]

Epoch: 13 , valid loss: 1.3184 , cos_loss :  0.9138 , classification_loss : 0.3073 , center_loss : 0.3246
Accuracy: 0.9129 , AUC:  0.6466 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 14 , Total loss: 1.3343 , cos_loss :  0.9155 , classification_loss : 0.3131 , center_loss : 0.3526


 14%|█▍        | 14/100 [00:49<04:59,  3.48s/it]

Epoch: 14 , valid loss: 1.3185 , cos_loss :  0.9136 , classification_loss : 0.3074 , center_loss : 0.3249
Accuracy: 0.9129 , AUC:  0.61 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 15 , Total loss: 1.3336 , cos_loss :  0.9154 , classification_loss : 0.3133 , center_loss : 0.3494


 15%|█▌        | 15/100 [00:52<04:53,  3.45s/it]

Epoch: 15 , valid loss: 1.3191 , cos_loss :  0.9135 , classification_loss : 0.3159 , center_loss : 0.299
Accuracy: 0.9129 , AUC:  0.6172 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 16 , Total loss: 1.3324 , cos_loss :  0.9152 , classification_loss : 0.3139 , center_loss : 0.3444


 16%|█▌        | 16/100 [00:56<04:49,  3.44s/it]

Epoch: 16 , valid loss: 1.3187 , cos_loss :  0.9133 , classification_loss : 0.297 , center_loss : 0.3611
Accuracy: 0.9129 , AUC:  0.6306 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 17 , Total loss: 1.3302 , cos_loss :  0.9153 , classification_loss : 0.3105 , center_loss : 0.3481


 17%|█▋        | 17/100 [00:59<04:44,  3.43s/it]

Epoch: 17 , valid loss: 1.3176 , cos_loss :  0.9132 , classification_loss : 0.3046 , center_loss : 0.3325
Accuracy: 0.9129 , AUC:  0.6662 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 18 , Total loss: 1.3306 , cos_loss :  0.9151 , classification_loss : 0.3113 , center_loss : 0.3473


 18%|█▊        | 18/100 [01:03<04:45,  3.49s/it]

Epoch: 18 , valid loss: 1.319 , cos_loss :  0.9131 , classification_loss : 0.3203 , center_loss : 0.2852
Accuracy: 0.9129 , AUC:  0.6321 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 19 , Total loss: 1.3304 , cos_loss :  0.9148 , classification_loss : 0.3119 , center_loss : 0.3457


 19%|█▉        | 19/100 [01:06<04:40,  3.46s/it]

Epoch: 19 , valid loss: 1.3178 , cos_loss :  0.9131 , classification_loss : 0.298 , center_loss : 0.356
Accuracy: 0.9129 , AUC:  0.6768 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 20 , Total loss: 1.3286 , cos_loss :  0.915 , classification_loss : 0.3125 , center_loss : 0.3371


 20%|██        | 20/100 [01:10<04:36,  3.46s/it]

Epoch: 20 , valid loss: 1.3169 , cos_loss :  0.913 , classification_loss : 0.3083 , center_loss : 0.3187
Accuracy: 0.9129 , AUC:  0.6715 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 21 , Total loss: 1.3286 , cos_loss :  0.915 , classification_loss : 0.31 , center_loss : 0.3452


 21%|██        | 21/100 [01:13<04:31,  3.44s/it]

Epoch: 21 , valid loss: 1.3181 , cos_loss :  0.9129 , classification_loss : 0.3195 , center_loss : 0.2857
Accuracy: 0.9129 , AUC:  0.6451 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 22 , Total loss: 1.3281 , cos_loss :  0.9149 , classification_loss : 0.3122 , center_loss : 0.3368


 22%|██▏       | 22/100 [01:16<04:27,  3.43s/it]

Epoch: 22 , valid loss: 1.317 , cos_loss :  0.9129 , classification_loss : 0.31 , center_loss : 0.3138
Accuracy: 0.9129 , AUC:  0.6794 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 23 , Total loss: 1.3267 , cos_loss :  0.9148 , classification_loss : 0.3112 , center_loss : 0.3355


 23%|██▎       | 23/100 [01:20<04:28,  3.48s/it]

Epoch: 23 , valid loss: 1.3162 , cos_loss :  0.9128 , classification_loss : 0.3052 , center_loss : 0.3272
Accuracy: 0.9129 , AUC:  0.6876 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 24 , Total loss: 1.3263 , cos_loss :  0.9149 , classification_loss : 0.3096 , center_loss : 0.3392


 24%|██▍       | 24/100 [01:23<04:22,  3.46s/it]

Epoch: 24 , valid loss: 1.3163 , cos_loss :  0.9128 , classification_loss : 0.3122 , center_loss : 0.3046
Accuracy: 0.9129 , AUC:  0.6962 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 25 , Total loss: 1.3261 , cos_loss :  0.9148 , classification_loss : 0.3108 , center_loss : 0.3347


 25%|██▌       | 25/100 [01:27<04:18,  3.44s/it]

Epoch: 25 , valid loss: 1.3157 , cos_loss :  0.9127 , classification_loss : 0.3036 , center_loss : 0.3314
Accuracy: 0.9129 , AUC:  0.7091 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 26 , Total loss: 1.3251 , cos_loss :  0.9147 , classification_loss : 0.3086 , center_loss : 0.3394


 26%|██▌       | 26/100 [01:30<04:13,  3.43s/it]

Epoch: 26 , valid loss: 1.3167 , cos_loss :  0.9127 , classification_loss : 0.2884 , center_loss : 0.3853
Accuracy: 0.9129 , AUC:  0.7201 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 27 , Total loss: 1.3262 , cos_loss :  0.9148 , classification_loss : 0.3087 , center_loss : 0.3422


 27%|██▋       | 27/100 [01:34<04:14,  3.48s/it]

Epoch: 27 , valid loss: 1.3183 , cos_loss :  0.9127 , classification_loss : 0.3291 , center_loss : 0.2552
Accuracy: 0.9129 , AUC:  0.7048 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 28 , Total loss: 1.3244 , cos_loss :  0.9149 , classification_loss : 0.3074 , center_loss : 0.3403


 28%|██▊       | 28/100 [01:37<04:08,  3.46s/it]

Epoch: 28 , valid loss: 1.3144 , cos_loss :  0.9126 , classification_loss : 0.3123 , center_loss : 0.2983
Accuracy: 0.9129 , AUC:  0.7155 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 29 , Total loss: 1.3239 , cos_loss :  0.9146 , classification_loss : 0.3081 , center_loss : 0.3371


 29%|██▉       | 29/100 [01:41<04:04,  3.44s/it]

Epoch: 29 , valid loss: 1.3124 , cos_loss :  0.9126 , classification_loss : 0.3073 , center_loss : 0.3082
Accuracy: 0.9129 , AUC:  0.6991 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 30 , Total loss: 1.3184 , cos_loss :  0.9148 , classification_loss : 0.299 , center_loss : 0.3488


 30%|███       | 30/100 [01:44<03:59,  3.43s/it]

Epoch: 30 , valid loss: 1.309 , cos_loss :  0.9127 , classification_loss : 0.2872 , center_loss : 0.3637
Accuracy: 0.9142 , AUC:  0.6879 , F1:  0.0282 , Precision:  1.0 , recall:  0.0143
Epoch: 31 , Total loss: 1.3147 , cos_loss :  0.9147 , classification_loss : 0.2958 , center_loss : 0.3472


 31%|███       | 31/100 [01:47<03:56,  3.42s/it]

Epoch: 31 , valid loss: 1.3055 , cos_loss :  0.9128 , classification_loss : 0.277 , center_loss : 0.3857
Accuracy: 0.9415 , AUC:  0.7112 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 32 , Total loss: 1.313 , cos_loss :  0.9147 , classification_loss : 0.2886 , center_loss : 0.3657


 32%|███▏      | 32/100 [01:51<03:56,  3.47s/it]

Epoch: 32 , valid loss: 1.3042 , cos_loss :  0.9129 , classification_loss : 0.2932 , center_loss : 0.327
Accuracy: 0.9415 , AUC:  0.7058 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 33 , Total loss: 1.3111 , cos_loss :  0.9148 , classification_loss : 0.2873 , center_loss : 0.3632


 33%|███▎      | 33/100 [01:54<03:51,  3.45s/it]

Epoch: 33 , valid loss: 1.3047 , cos_loss :  0.9128 , classification_loss : 0.3023 , center_loss : 0.2988
Accuracy: 0.9415 , AUC:  0.7297 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 34 , Total loss: 1.3114 , cos_loss :  0.915 , classification_loss : 0.2899 , center_loss : 0.3549


 34%|███▍      | 34/100 [01:58<03:46,  3.43s/it]

Epoch: 34 , valid loss: 1.3043 , cos_loss :  0.9128 , classification_loss : 0.2976 , center_loss : 0.3129
Accuracy: 0.9415 , AUC:  0.7077 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 35 , Total loss: 1.3105 , cos_loss :  0.9149 , classification_loss : 0.2881 , center_loss : 0.3582


 35%|███▌      | 35/100 [02:01<03:42,  3.42s/it]

Epoch: 35 , valid loss: 1.3029 , cos_loss :  0.9128 , classification_loss : 0.285 , center_loss : 0.3506
Accuracy: 0.9415 , AUC:  0.746 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 36 , Total loss: 1.3087 , cos_loss :  0.9148 , classification_loss : 0.2865 , center_loss : 0.3581


 36%|███▌      | 36/100 [02:05<03:38,  3.41s/it]

Epoch: 36 , valid loss: 1.3034 , cos_loss :  0.9128 , classification_loss : 0.2801 , center_loss : 0.3682
Accuracy: 0.9415 , AUC:  0.7217 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 37 , Total loss: 1.3088 , cos_loss :  0.9146 , classification_loss : 0.2866 , center_loss : 0.3585


 37%|███▋      | 37/100 [02:08<03:38,  3.47s/it]

Epoch: 37 , valid loss: 1.3029 , cos_loss :  0.9127 , classification_loss : 0.2854 , center_loss : 0.3491
Accuracy: 0.9415 , AUC:  0.7356 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 38 , Total loss: 1.3089 , cos_loss :  0.9147 , classification_loss : 0.2855 , center_loss : 0.3623


 38%|███▊      | 38/100 [02:12<03:33,  3.45s/it]

Epoch: 38 , valid loss: 1.3035 , cos_loss :  0.9127 , classification_loss : 0.2915 , center_loss : 0.3306
Accuracy: 0.9415 , AUC:  0.7257 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 39 , Total loss: 1.3083 , cos_loss :  0.9147 , classification_loss : 0.286 , center_loss : 0.3586


 39%|███▉      | 39/100 [02:15<03:29,  3.44s/it]

Epoch: 39 , valid loss: 1.3033 , cos_loss :  0.9127 , classification_loss : 0.2899 , center_loss : 0.3359
Accuracy: 0.9415 , AUC:  0.6921 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 40 , Total loss: 1.3083 , cos_loss :  0.9147 , classification_loss : 0.2866 , center_loss : 0.3567


 40%|████      | 40/100 [02:18<03:25,  3.43s/it]

Epoch: 40 , valid loss: 1.303 , cos_loss :  0.9127 , classification_loss : 0.2878 , center_loss : 0.3416
Accuracy: 0.9415 , AUC:  0.7469 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 41 , Total loss: 1.3087 , cos_loss :  0.9147 , classification_loss : 0.2872 , center_loss : 0.3561


 41%|████      | 41/100 [02:22<03:22,  3.42s/it]

Epoch: 41 , valid loss: 1.3031 , cos_loss :  0.9126 , classification_loss : 0.2879 , center_loss : 0.3417
Accuracy: 0.9415 , AUC:  0.7151 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 42 , Total loss: 1.308 , cos_loss :  0.9145 , classification_loss : 0.286 , center_loss : 0.358


 42%|████▏     | 42/100 [02:25<03:21,  3.48s/it]

Epoch: 42 , valid loss: 1.3031 , cos_loss :  0.9126 , classification_loss : 0.2839 , center_loss : 0.355
Accuracy: 0.9415 , AUC:  0.7298 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 43 , Total loss: 1.3082 , cos_loss :  0.9146 , classification_loss : 0.2858 , center_loss : 0.3592


 43%|████▎     | 43/100 [02:29<03:18,  3.48s/it]

Epoch: 43 , valid loss: 1.3032 , cos_loss :  0.9126 , classification_loss : 0.2828 , center_loss : 0.3593
Accuracy: 0.9415 , AUC:  0.7191 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 44 , Total loss: 1.308 , cos_loss :  0.9145 , classification_loss : 0.2858 , center_loss : 0.3588


 44%|████▍     | 44/100 [02:32<03:13,  3.46s/it]

Epoch: 44 , valid loss: 1.303 , cos_loss :  0.9126 , classification_loss : 0.2896 , center_loss : 0.336
Accuracy: 0.9415 , AUC:  0.7295 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 45 , Total loss: 1.3079 , cos_loss :  0.9146 , classification_loss : 0.286 , center_loss : 0.3574


 45%|████▌     | 45/100 [02:36<03:09,  3.44s/it]

Epoch: 45 , valid loss: 1.303 , cos_loss :  0.9126 , classification_loss : 0.2883 , center_loss : 0.3405
Accuracy: 0.9415 , AUC:  0.7195 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 46 , Total loss: 1.3078 , cos_loss :  0.9145 , classification_loss : 0.2868 , center_loss : 0.3551


 46%|████▌     | 46/100 [02:39<03:08,  3.49s/it]

Epoch: 46 , valid loss: 1.3029 , cos_loss :  0.9125 , classification_loss : 0.2868 , center_loss : 0.3451
Accuracy: 0.9415 , AUC:  0.7172 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 47 , Total loss: 1.3075 , cos_loss :  0.9147 , classification_loss : 0.2857 , center_loss : 0.3571


 47%|████▋     | 47/100 [02:43<03:03,  3.46s/it]

Epoch: 47 , valid loss: 1.3028 , cos_loss :  0.9125 , classification_loss : 0.2878 , center_loss : 0.3419
Accuracy: 0.9415 , AUC:  0.7294 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 48 , Total loss: 1.3071 , cos_loss :  0.9143 , classification_loss : 0.2867 , center_loss : 0.3537


 48%|████▊     | 48/100 [02:46<02:58,  3.44s/it]

Epoch: 48 , valid loss: 1.3028 , cos_loss :  0.9125 , classification_loss : 0.2816 , center_loss : 0.3624
Accuracy: 0.9415 , AUC:  0.739 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 49 , Total loss: 1.3078 , cos_loss :  0.9145 , classification_loss : 0.2855 , center_loss : 0.3594


 49%|████▉     | 49/100 [02:49<02:54,  3.42s/it]

Epoch: 49 , valid loss: 1.3031 , cos_loss :  0.9125 , classification_loss : 0.2838 , center_loss : 0.3563
Accuracy: 0.9415 , AUC:  0.6897 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 50 , Total loss: 1.3071 , cos_loss :  0.9145 , classification_loss : 0.2866 , center_loss : 0.3535


 50%|█████     | 50/100 [02:53<02:50,  3.41s/it]

Epoch: 50 , valid loss: 1.3025 , cos_loss :  0.9125 , classification_loss : 0.2836 , center_loss : 0.3549
Accuracy: 0.9415 , AUC:  0.7546 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 51 , Total loss: 1.3076 , cos_loss :  0.9143 , classification_loss : 0.2849 , center_loss : 0.3615


 51%|█████     | 51/100 [02:56<02:49,  3.46s/it]

Epoch: 51 , valid loss: 1.3038 , cos_loss :  0.9125 , classification_loss : 0.2739 , center_loss : 0.3914
Accuracy: 0.9415 , AUC:  0.7183 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 52 , Total loss: 1.3066 , cos_loss :  0.9145 , classification_loss : 0.2849 , center_loss : 0.3578


 52%|█████▏    | 52/100 [03:00<02:45,  3.45s/it]

Epoch: 52 , valid loss: 1.3027 , cos_loss :  0.9125 , classification_loss : 0.283 , center_loss : 0.3577
Accuracy: 0.9415 , AUC:  0.7278 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 53 , Total loss: 1.306 , cos_loss :  0.9144 , classification_loss : 0.2849 , center_loss : 0.3559


 53%|█████▎    | 53/100 [03:03<02:40,  3.42s/it]

Epoch: 53 , valid loss: 1.3028 , cos_loss :  0.9124 , classification_loss : 0.2806 , center_loss : 0.3656
Accuracy: 0.9415 , AUC:  0.7337 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 54 , Total loss: 1.3072 , cos_loss :  0.9145 , classification_loss : 0.2856 , center_loss : 0.3572


 54%|█████▍    | 54/100 [03:07<02:36,  3.41s/it]

Epoch: 54 , valid loss: 1.3032 , cos_loss :  0.9124 , classification_loss : 0.2958 , center_loss : 0.3168
Accuracy: 0.9415 , AUC:  0.7415 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 55 , Total loss: 1.3069 , cos_loss :  0.9144 , classification_loss : 0.2866 , center_loss : 0.353


 55%|█████▌    | 55/100 [03:10<02:35,  3.45s/it]

Epoch: 55 , valid loss: 1.3028 , cos_loss :  0.9124 , classification_loss : 0.2863 , center_loss : 0.3471
Accuracy: 0.9415 , AUC:  0.7186 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 56 , Total loss: 1.3069 , cos_loss :  0.9143 , classification_loss : 0.2852 , center_loss : 0.358


 56%|█████▌    | 56/100 [03:13<02:31,  3.44s/it]

Epoch: 56 , valid loss: 1.3059 , cos_loss :  0.9124 , classification_loss : 0.2641 , center_loss : 0.4315
Accuracy: 0.9415 , AUC:  0.7522 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 57 , Total loss: 1.3069 , cos_loss :  0.9145 , classification_loss : 0.2832 , center_loss : 0.3638


 57%|█████▋    | 57/100 [03:17<02:27,  3.42s/it]

Epoch: 57 , valid loss: 1.3031 , cos_loss :  0.9124 , classification_loss : 0.2826 , center_loss : 0.3602
Accuracy: 0.9415 , AUC:  0.6828 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 58 , Total loss: 1.3066 , cos_loss :  0.9144 , classification_loss : 0.2859 , center_loss : 0.354


 58%|█████▊    | 58/100 [03:20<02:23,  3.41s/it]

Epoch: 58 , valid loss: 1.3027 , cos_loss :  0.9124 , classification_loss : 0.2861 , center_loss : 0.3472
Accuracy: 0.9415 , AUC:  0.7149 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 59 , Total loss: 1.3055 , cos_loss :  0.9142 , classification_loss : 0.2845 , center_loss : 0.356


 59%|█████▉    | 59/100 [03:24<02:22,  3.47s/it]

Epoch: 59 , valid loss: 1.3028 , cos_loss :  0.9124 , classification_loss : 0.2789 , center_loss : 0.3717
Accuracy: 0.9415 , AUC:  0.7524 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286


 59%|█████▉    | 59/100 [03:27<02:24,  3.52s/it]


KeyboardInterrupt: 

In [None]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_center_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path
model_1.load_state_dict(torch.load(model_1_path))
model_1.eval()

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)
        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        center_loss_te_1 = center_loss(te_final_visit_classification_1, te_labels.long())
        te_loss = (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1) + (center_lambda * center_loss_te_1)

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())
        total_te_center_loss += (center_loss_te_1.item())

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)
    avg_te_center_loss = total_te_center_loss / len(test_loader)

In [None]:
# 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), 
      ", center_loss : ", round(avg_te_center_loss, 4)
     )
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

In [None]:
log_path = './logs'
logs = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{model_time}.txt'
os.makedirs(os.path.join(log_path, date_dir), exist_ok=True)

results = [accuracy, auc, f1, precision, recall]
cm = confusion_matrix(te_labels_list, te_predictions_list)
with open(os.path.join(log_path, date_dir, logs.format(epoch=best_epoch, model=model_name, pe=pe)), 'w') as f:
    f.write(logs)
    f.write('\n')
    f.write(str(results))
    f.write('\n')
    f.write(str(cm))

In [None]:
confusion_matrix(te_labels_list, te_predictions_list)