In [1]:
import numpy as np
import pandas as pd
import os 
import pickle
import sys
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
import warnings

torch.set_printoptions(profile="full")
np.set_printoptions(threshold=sys.maxsize)
warnings.filterwarnings("ignore")

In [2]:
from src.attn import FixedPositionalEncoding, LearnablePositionalEncoding, MultiheadAttention, Decoder
from src.loss import ContrastiveLoss, FocalLoss

In [3]:
from datetime import datetime

In [4]:
path = '/data/notebook/shared/MIMIC-IV'

In [5]:
with open(os.path.join(path, 'dict_types_nomedi_mimic_240408_clinic_3_years.pkl'), 'rb') as f:
    dtype_dict = pickle.load(f)
f.close()

with open(os.path.join(path, 'total_data_dict_with_timedelta_nomedi_240421_clinic_3_years.pkl'), 'rb') as f:
    data_dict_d = pickle.load(f)
f.close()

In [6]:
remain_year_clinical = pd.read_csv(os.path.join(path, "240408_remain_year_clinical_diag_seq_4.csv"))
del remain_year_clinical['clinical_label']

In [7]:
len(dtype_dict)

12702

In [8]:
train_indices, test_indices = train_test_split(list(data_dict_d.keys()), test_size=0.1, random_state=777)
train_indices, valid_indices = train_test_split(train_indices, test_size=(len(test_indices)/len(train_indices)), random_state=777) 

In [9]:
train_data = {}
valid_data = {}
test_data = {}
for sample in tqdm(train_indices):
    train_data[sample] = data_dict_d[sample]

for sample in tqdm(valid_indices):
    valid_data[sample] = data_dict_d[sample]

for sample in tqdm(test_indices):
    test_data[sample] = data_dict_d[sample]

100%|██████████| 6429/6429 [00:00<00:00, 1519165.09it/s]
100%|██████████| 804/804 [00:00<00:00, 1134663.67it/s]
100%|██████████| 804/804 [00:00<00:00, 1010554.51it/s]


In [10]:
train_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(train_indices)].reset_index(drop=True)
valid_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(valid_indices)].reset_index(drop=True)
test_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(test_indices)].reset_index(drop=True)

In [11]:
class CustomDataset(Dataset):
    def __init__(self, data, clinical_df, scaler, mode='train'):       
        self.keys = list(data.keys())  # 딕셔너리의 키 목록 저장
        self.data = data  # 딕셔너리에서 데이터만 추출하여 저장
        self.scaler = scaler
        if mode == 'train':
            scaled_data = self.scaler.fit_transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)
        else:
            scaled_data = self.scaler.transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # x는 현재 방문까지의 모든 방문 데이터, y는 다음 방문 데이터
        padding_temp = torch.zeros((1, len(self.data[self.keys[idx]]['code_index'][0])), dtype=torch.long)
        
        origin_visit = torch.tensor(self.data[self.keys[idx]]['code_index'], dtype=torch.long)
        origin_mask = torch.tensor(self.data[self.keys[idx]]['seq_mask'], dtype=torch.long)
        origin_mask_final = torch.tensor(self.data[self.keys[idx]]['seq_mask_final'], dtype=torch.long)
        origin_mask_code = torch.tensor(self.data[self.keys[idx]]['seq_mask_code'], dtype=torch.long)
        
        next_visit = torch.cat((origin_visit[1:], padding_temp), dim=0)
        next_mask =torch.cat((origin_mask[1:], torch.tensor([0], dtype=torch.long)), dim=0)
        next_mask_code = torch.cat((origin_mask_code[1:], padding_temp), dim=0)
        
        clinical_data = self.scaled_clinical_df.loc[self.scaled_clinical_df['subject_id'] == self.keys[idx], self.scaled_clinical_df.columns[2:]].values
        clinical_data = torch.tensor(clinical_data, dtype=torch.float)
        visit_index = torch.tensor(self.data[self.keys[idx]]['year_onehot'], dtype=torch.long)
        last_visit_index = torch.tensor(self.data[self.keys[idx]]['last_year_onehot'], dtype=torch.float)
        label_per_sample = self.data[self.keys[idx]]['label']
        key_per_sample = self.keys[idx]  # 해당 샘플의 키
        # 키 값도 함께 반환
        return {'sample_id': key_per_sample, 'origin_visit': origin_visit, 'next_visit': next_visit,\
                'origin_mask': origin_mask, 'origin_mask_final': origin_mask_final, 'next_mask': next_mask, \
                'origin_mask_code': origin_mask_code, 'next_mask_code': next_mask_code, 'clinical_data': clinical_data, \
                'visit_index': visit_index, 'last_visit_index': last_visit_index,'label': label_per_sample}

In [12]:
scaler = StandardScaler()
train_dataset = CustomDataset(train_data, train_clinical, scaler, mode='train')
valid_dataset = CustomDataset(valid_data, valid_clinical, scaler, mode='valid')
test_dataset = CustomDataset(test_data, test_clinical, scaler, mode='test')

In [13]:
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [14]:
class CustomTransformerModel(nn.Module):
    def __init__(self, code_size, ninp, nhead, nhid, nlayers, dropout=0.5, device='cuda:0', pe='fixed'):
        super(CustomTransformerModel, self).__init__()
        self.ninp = ninp  # 축소된 차원 및 Transformer 입력 차원
        self.device = device
        # 차원 축소를 위한 선형 레이어
        self.pre_embedding = nn.Embedding(code_size, self.ninp)
        if pe == 'fixed':
            self.pos_encoder = FixedPositionalEncoding(ninp, dropout)
        else:
            self.pos_encoder = LearnablePositionalEncoding(ninp, dropout)
            
        self.transformer_decoder = Decoder(ninp, nhead, nhid, nlayers, dropout)
        self.decoder = nn.Linear(ninp, ninp, bias=False)  # 최종 출력 차원을 설정 (여기서는 ninp로 설정)
        self.classification_layer = nn.Linear(ninp*2, 2)
        self.clinical_transform = nn.Linear(18, ninp) 
        self.cross_attn = MultiheadAttention(ninp, nhead, dropout=0.3)
        self.init_weights()
        
    def init_weights(self):
        nn.init.xavier_uniform_(self.pre_embedding.weight)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.xavier_uniform_(self.classification_layer.weight)
    
    def forward(self, batch_data):
        # 차원 축소
        origin_visit = batch_data['origin_visit'].to(self.device)
        next_visit = batch_data['next_visit'].to(self.device)
        clinical_tensor = batch_data['clinical_data'].to(self.device)
        mask_code = batch_data['origin_mask_code'].unsqueeze(3).to(self.device)
        next_mask_code = batch_data['next_mask_code'].unsqueeze(3).to(self.device)
        mask_final = batch_data['origin_mask_final'].unsqueeze(2).to(self.device)
        mask_final_year = batch_data['last_visit_index'].to(self.device)
                
        origin_visit_emb = (self.pre_embedding(origin_visit) * mask_code).sum(dim=2)
        next_visit_emb = (self.pre_embedding(next_visit) * next_mask_code).sum(dim=2)
        # 위치 인코딩 및 Transformer 디코더 적용
        src = self.pos_encoder(origin_visit_emb)
        seq_len = src.shape[1]
        src_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=self.device), diagonal=1)
        output = self.transformer_decoder(src, attention_mask=src_mask)
        output = self.decoder(output)
        
        output_batch, output_visit_num, output_dim = output.size()
        next_visit_output_batch, next_visit_output_num, next_visit_output_dim = next_visit_emb.size()
        
        # print(output[:, :torch.,:])
        temp_output = output.reshape(-1, output_dim)
        temp_next_visit = next_visit_emb.reshape(-1, next_visit_output_dim)
        final_visit = (output * mask_final).sum(dim=1)
        
        year_emb = torch.bmm(output.transpose(1,2), mask_final_year).transpose(1,2)[:,:2,:]
        transformed_clinical = self.clinical_transform(clinical_tensor)
        mixed_output, mixed_cross_attn = self.cross_attn(year_emb, transformed_clinical, transformed_clinical)
        mixed_output = mixed_output[:,-1,:]
        # print(final_visit.shape)
        # print(mixed_output.shape)
        mixed_final_emb = torch.cat((final_visit, mixed_output), dim=-1)
        classification_result = self.classification_layer(mixed_final_emb)    
        return temp_output, temp_next_visit, classification_result, final_visit

In [32]:
lr = 0.001
ninp = 256
nhid = 512
nlayer = 4

In [33]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_1 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=4, nhid=nhid, nlayers=nlayer, dropout=0, device=device, pe='learnable').to(device)
model_2 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=4, nhid=nhid, nlayers=nlayer, dropout=0.4, device=device, pe='learnable').to(device)

optimizer = torch.optim.Adam(list(model_1.parameters()) + list(model_2.parameters()), lr=lr, weight_decay=1e-8)

# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(2, gamma=1.5)
# cosine_loss = CosineSimilarityLoss()
cosine_embedding_loss = nn.CosineEmbeddingLoss()
const_loss = ContrastiveLoss(temperature=0.07)

In [34]:
# 얼리 스타핑 설정
num_epochs = 100
patience = 30
best_loss = float("inf")
counter = 0
epoch_temp = 0

best_f1 = 0.0
best_combined_score = 0.0
best_epoch = 0

view_total_loss = []
view_cos_loss = []
view_classi_loss = []

cos_lambda = 1
classi_lambda = 1
const_lambda = 1

for epoch in tqdm(range(num_epochs)):
    model_1.train()
    model_2.train()

    total_train_loss = 0
    total_cos_loss = 0
    total_classi_loss = 0
    total_const_loss = 0
    # key_per_sample, origin_visit, origin_mask, origin_mask_code, origin_mask_final, next_visit, next_mask, next_mask_code, next_mask_final, label_per_sample
    for batch_data in train_loader:
        tr_labels = batch_data['label'].to(device)
        optimizer.zero_grad()
        output_1, next_visit_output_1, final_visit_classification_1, final_visit_1 = model_1(batch_data)
        y = torch.ones(output_1.size(0), dtype=torch.float, device=device)
        y = y.to(device)   
        cosine_loss_mean_1 = cosine_embedding_loss(output_1, next_visit_output_1, y)
        classification_loss_1 = criterion(final_visit_classification_1.squeeze(), tr_labels.long())
        
        output_2, next_visit_output_2, final_visit_classification_2, final_visit_2 = model_2(batch_data)
        cosine_loss_mean_2 = cosine_embedding_loss(output_2, next_visit_output_2, y)
        classification_loss_2 = criterion(final_visit_classification_2.squeeze(), tr_labels.long())        

        train_const_loss = const_loss(torch.cat([final_visit_1, final_visit_2], dim=0))
        # train_const_loss = const_loss(final_visit_1, final_visit_2)
        
        loss = ((cos_lambda * cosine_loss_mean_1) + (classi_lambda * classification_loss_1)
               + (cos_lambda * cosine_loss_mean_2) + (classi_lambda * classification_loss_2)
               + (const_lambda * train_const_loss))

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        total_cos_loss += (cosine_loss_mean_1.item())
        total_classi_loss += (classification_loss_1.item())
        total_const_loss += train_const_loss.item()

        view_total_loss.append(total_train_loss)
        view_cos_loss.append(total_cos_loss)
        view_classi_loss.append(total_classi_loss)

    # 평균 손실 계산
    avg_train_loss = total_train_loss / len(train_loader)
    avg_cos_loss = total_cos_loss / len(train_loader)
    avg_classi_loss = total_classi_loss / len(train_loader)
    avg_const_loss = total_const_loss / len(train_loader)


    print("Epoch:", epoch+1, ", Total loss:", round(avg_train_loss, 4), ", cos_loss : ", round(avg_cos_loss, 4), \
          ", classification_loss :",round(avg_classi_loss, 4), ", contrastive_loss : ", round(avg_const_loss,4))
    model_1.eval()
    model_2.eval()
    total_val_loss = 0
    total_val_cos_loss = 0
    total_val_classi_loss = 0
    total_val_const_loss = 0

    val_labels_list = []
    val_predictions_list = []
    val_probabilities_list = []


    with torch.no_grad():
        for batch_data in valid_loader:
            val_labels = batch_data['label'].to(device)
            val_output_1, val_next_visit_output_1, val_final_visit_classification_1, val_final_visit_1 = model_1(batch_data)
            
            y_val = torch.ones(val_output_1.size(0), dtype=torch.float, device=device) 
            val_cosine_loss_mean_1 = cosine_embedding_loss(val_output_1, val_next_visit_output_1, y_val)
            
            classification_loss_val_1 = criterion(val_final_visit_classification_1.squeeze(), val_labels.long())
            val_output_2, val_next_visit_output_2, val_final_visit_classification_2, val_final_visit_2 = model_2(batch_data)   
            val_cosine_loss_mean_2 = cosine_embedding_loss(val_output_2, val_next_visit_output_2, y_val)

            classification_loss_val_2 = criterion(val_final_visit_classification_2.squeeze(), val_labels.long())
            
            # val_const_loss = const_loss(val_final_visit_1, val_final_visit_2)
            val_const_loss = const_loss(torch.cat([val_final_visit_1, val_final_visit_2], dim=0))

            val_loss = ((cos_lambda * val_cosine_loss_mean_1) + (classi_lambda * classification_loss_val_1) 
                       +(cos_lambda * val_cosine_loss_mean_2) + (classi_lambda * classification_loss_val_2) 
                       + (const_lambda * val_const_loss))

            total_val_loss += val_loss.item()
            total_val_cos_loss += (val_cosine_loss_mean_1.item())
            total_val_classi_loss += (classification_loss_val_1.item())
            total_val_const_loss += val_const_loss.item()

            val_probs = F.softmax(val_final_visit_classification_1)
            val_predictions = torch.max(val_probs, 1)[1].view((len(val_labels),))

            val_labels_list.extend(val_labels.view(-1).cpu().numpy())
            val_predictions_list.extend(val_predictions.cpu().numpy())
            
            ########
            val_probabilities_list.extend(val_probs[:,1].cpu().numpy())
            ######## 
            
            
        avg_val_loss = total_val_loss / len(valid_loader)
        avg_val_cos_loss = total_val_cos_loss / len(valid_loader)
        avg_val_classi_loss = total_val_classi_loss / len(valid_loader)
        avg_val_const_loss = total_val_const_loss / len(valid_loader)

    # 성능 지표 계산
    accuracy = accuracy_score(val_labels_list, val_predictions_list)
    auc = roc_auc_score(val_labels_list, val_probabilities_list)
    f1 = f1_score(val_labels_list, val_predictions_list)
    precision = precision_score(val_labels_list,val_predictions_list)
    recall = recall_score(val_labels_list, val_predictions_list)

    print("Epoch:", epoch+1, ", valid loss:", round(avg_val_loss, 4), ", cos_loss : ", round(avg_val_cos_loss, 4), ", classification_loss :",round(avg_val_classi_loss, 4), ", contrastive_loss : ", round(avg_val_const_loss,4))
    print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

    current_auc = roc_auc_score(val_labels_list, val_probabilities_list)
    current_f1 = f1_score(val_labels_list, val_predictions_list)
    current_combined_score = current_auc

    # 모델 저장 경로 설정
    model_save_path = f'results'
    date_dir = datetime.today().strftime("%Y%m%d")
    # 학습률과 분류 가중치를 파일명에 포함시키기 위한 문자열 포맷
    model_filename_format = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}.pth'


    # 모델 저장 폴더가 없으면 생성
    os.makedirs(os.path.join(model_save_path, date_dir), exist_ok=True)

    if current_combined_score > best_combined_score:
        best_combined_score = current_combined_score
        best_epoch = epoch
        counter = 0
        # 모델 저장 경로와 파일명을 결합하여 전체 파일 경로 생성
        model_1_save_path = os.path.join(model_save_path, date_dir, model_filename_format.format(epoch=best_epoch, model='model1'))
        model_2_save_path = os.path.join(model_save_path, date_dir, model_filename_format.format(epoch=best_epoch, model='model2'))
        # 모델 저장
        torch.save(model_1.state_dict(), f'{model_1_save_path}')
        torch.save(model_2.state_dict(), f'{model_2_save_path}')
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered at epoch {}".format(best_epoch))
        print(f"Best Combined Score (AUC): {best_combined_score:.4f}")
        break

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1 , Total loss: 19.7822 , cos_loss :  0.9686 , classification_loss : 0.3584 , contrastive_loss :  17.0521


  1%|          | 1/100 [00:05<08:15,  5.00s/it]

Epoch: 1 , valid loss: 19.1199 , cos_loss :  0.9368 , classification_loss : 0.1248 , contrastive_loss :  16.985
Accuracy: 0.9129 , AUC:  0.6065 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 2 , Total loss: 19.2555 , cos_loss :  0.9311 , classification_loss : 0.1297 , contrastive_loss :  17.1338


  2%|▏         | 2/100 [00:09<08:00,  4.90s/it]

Epoch: 2 , valid loss: 19.0906 , cos_loss :  0.9229 , classification_loss : 0.1386 , contrastive_loss :  16.9849
Accuracy: 0.9092 , AUC:  0.5633 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 3 , Total loss: 17.7759 , cos_loss :  0.9232 , classification_loss : 0.1382 , contrastive_loss :  15.6195


  3%|▎         | 3/100 [00:14<07:41,  4.76s/it]

Epoch: 3 , valid loss: 19.1222 , cos_loss :  0.9187 , classification_loss : 0.1446 , contrastive_loss :  16.972
Accuracy: 0.9129 , AUC:  0.5663 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 4 , Total loss: 16.5499 , cos_loss :  0.9203 , classification_loss : 0.1375 , contrastive_loss :  14.3921


  4%|▍         | 4/100 [00:19<07:30,  4.69s/it]

Epoch: 4 , valid loss: 19.1184 , cos_loss :  0.9165 , classification_loss : 0.1309 , contrastive_loss :  16.968
Accuracy: 0.9129 , AUC:  0.5732 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 5 , Total loss: 16.5196 , cos_loss :  0.9186 , classification_loss : 0.1252 , contrastive_loss :  14.3911


  5%|▌         | 5/100 [00:23<07:23,  4.67s/it]

Epoch: 5 , valid loss: 19.2408 , cos_loss :  0.9153 , classification_loss : 0.1598 , contrastive_loss :  16.9676
Accuracy: 0.9129 , AUC:  0.5719 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 6 , Total loss: 16.5154 , cos_loss :  0.9179 , classification_loss : 0.1239 , contrastive_loss :  14.3902


  6%|▌         | 6/100 [00:28<07:18,  4.66s/it]

Epoch: 6 , valid loss: 19.1607 , cos_loss :  0.9144 , classification_loss : 0.1488 , contrastive_loss :  16.9628
Accuracy: 0.9129 , AUC:  0.583 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 7 , Total loss: 16.5024 , cos_loss :  0.9171 , classification_loss : 0.1193 , contrastive_loss :  14.3899


  7%|▋         | 7/100 [00:32<07:14,  4.67s/it]

Epoch: 7 , valid loss: 19.0969 , cos_loss :  0.9136 , classification_loss : 0.1192 , contrastive_loss :  16.9613
Accuracy: 0.9129 , AUC:  0.6158 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 8 , Total loss: 16.4925 , cos_loss :  0.9166 , classification_loss : 0.1157 , contrastive_loss :  14.3896


  8%|▊         | 8/100 [00:37<07:07,  4.64s/it]

Epoch: 8 , valid loss: 19.1598 , cos_loss :  0.9132 , classification_loss : 0.1512 , contrastive_loss :  16.9599
Accuracy: 0.9129 , AUC:  0.5885 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 9 , Total loss: 16.4947 , cos_loss :  0.9167 , classification_loss : 0.1174 , contrastive_loss :  14.3895


  9%|▉         | 9/100 [00:42<07:00,  4.62s/it]

Epoch: 9 , valid loss: 19.12 , cos_loss :  0.9127 , classification_loss : 0.1235 , contrastive_loss :  16.9595
Accuracy: 0.9129 , AUC:  0.5904 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 10 , Total loss: 16.4853 , cos_loss :  0.9164 , classification_loss : 0.1156 , contrastive_loss :  14.3872


 10%|█         | 10/100 [00:46<06:54,  4.60s/it]

Epoch: 10 , valid loss: 19.0982 , cos_loss :  0.9134 , classification_loss : 0.1392 , contrastive_loss :  16.9373
Accuracy: 0.9129 , AUC:  0.5958 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 11 , Total loss: 16.4677 , cos_loss :  0.9168 , classification_loss : 0.1158 , contrastive_loss :  14.3669


 11%|█         | 11/100 [00:51<06:49,  4.60s/it]

Epoch: 11 , valid loss: 19.0546 , cos_loss :  0.9126 , classification_loss : 0.1191 , contrastive_loss :  16.9337
Accuracy: 0.9129 , AUC:  0.6114 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 12 , Total loss: 16.4961 , cos_loss :  0.9233 , classification_loss : 0.1197 , contrastive_loss :  14.3836


 12%|█▏        | 12/100 [00:56<06:49,  4.65s/it]

Epoch: 12 , valid loss: 19.0622 , cos_loss :  0.9172 , classification_loss : 0.136 , contrastive_loss :  16.9294
Accuracy: 0.9129 , AUC:  0.6694 , F1:  0.1026 , Precision:  0.5 , recall:  0.0571
Epoch: 13 , Total loss: 16.439 , cos_loss :  0.9312 , classification_loss : 0.1244 , contrastive_loss :  14.3163


 13%|█▎        | 13/100 [01:00<06:42,  4.63s/it]

Epoch: 13 , valid loss: 19.1174 , cos_loss :  0.9244 , classification_loss : 0.2336 , contrastive_loss :  16.884
Accuracy: 0.9129 , AUC:  0.4581 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 14 , Total loss: 16.4832 , cos_loss :  0.9263 , classification_loss : 0.1191 , contrastive_loss :  14.3644


 14%|█▍        | 14/100 [01:05<06:37,  4.63s/it]

Epoch: 14 , valid loss: 19.0953 , cos_loss :  0.948 , classification_loss : 0.1391 , contrastive_loss :  16.8851
Accuracy: 0.9179 , AUC:  0.797 , F1:  0.1081 , Precision:  1.0 , recall:  0.0571
Epoch: 15 , Total loss: 16.3752 , cos_loss :  0.942 , classification_loss : 0.1091 , contrastive_loss :  14.2578


 15%|█▌        | 15/100 [01:09<06:32,  4.62s/it]

Epoch: 15 , valid loss: 18.9749 , cos_loss :  0.9406 , classification_loss : 0.082 , contrastive_loss :  16.8788
Accuracy: 0.9515 , AUC:  0.7946 , F1:  0.6139 , Precision:  1.0 , recall:  0.4429
Epoch: 16 , Total loss: 16.3442 , cos_loss :  0.9553 , classification_loss : 0.1258 , contrastive_loss :  14.1883


 16%|█▌        | 16/100 [01:14<06:27,  4.61s/it]

Epoch: 16 , valid loss: 18.7506 , cos_loss :  0.9668 , classification_loss : 0.1341 , contrastive_loss :  16.5316
Accuracy: 0.8657 , AUC:  0.7628 , F1:  0.3571 , Precision:  0.3061 , recall:  0.4286
Epoch: 17 , Total loss: 16.0747 , cos_loss :  0.9735 , classification_loss : 0.1257 , contrastive_loss :  13.8922


 17%|█▋        | 17/100 [01:19<06:21,  4.60s/it]

Epoch: 17 , valid loss: 18.7216 , cos_loss :  0.9579 , classification_loss : 0.1641 , contrastive_loss :  16.4888
Accuracy: 0.9129 , AUC:  0.6432 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 18 , Total loss: 15.8954 , cos_loss :  0.9736 , classification_loss : 0.126 , contrastive_loss :  13.726


 18%|█▊        | 18/100 [01:23<06:18,  4.61s/it]

Epoch: 18 , valid loss: 18.3549 , cos_loss :  0.9717 , classification_loss : 0.116 , contrastive_loss :  16.1728
Accuracy: 0.9117 , AUC:  0.6109 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 19 , Total loss: 16.0472 , cos_loss :  0.9669 , classification_loss : 0.1122 , contrastive_loss :  13.9007


 19%|█▉        | 19/100 [01:28<06:13,  4.61s/it]

Epoch: 19 , valid loss: 18.6737 , cos_loss :  0.9789 , classification_loss : 0.1206 , contrastive_loss :  16.4912
Accuracy: 0.9129 , AUC:  0.5924 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 20 , Total loss: 16.3389 , cos_loss :  0.9747 , classification_loss : 0.1517 , contrastive_loss :  14.1449


 20%|██        | 20/100 [01:32<06:11,  4.64s/it]

Epoch: 20 , valid loss: 19.0818 , cos_loss :  0.9477 , classification_loss : 0.1447 , contrastive_loss :  16.9077
Accuracy: 0.8731 , AUC:  0.5697 , F1:  0.1207 , Precision:  0.1522 , recall:  0.1
Epoch: 21 , Total loss: 16.4931 , cos_loss :  0.9485 , classification_loss : 0.1132 , contrastive_loss :  14.3662


 21%|██        | 21/100 [01:37<06:06,  4.64s/it]

Epoch: 21 , valid loss: 18.9057 , cos_loss :  0.943 , classification_loss : 0.1273 , contrastive_loss :  16.759
Accuracy: 0.9129 , AUC:  0.5896 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 22 , Total loss: 16.475 , cos_loss :  0.9623 , classification_loss : 0.1105 , contrastive_loss :  14.3331


 22%|██▏       | 22/100 [01:42<06:00,  4.62s/it]

Epoch: 22 , valid loss: 18.8496 , cos_loss :  0.963 , classification_loss : 0.1179 , contrastive_loss :  16.7019
Accuracy: 0.9129 , AUC:  0.6306 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 23 , Total loss: 16.49 , cos_loss :  0.992 , classification_loss : 0.11 , contrastive_loss :  14.3213


 23%|██▎       | 23/100 [01:46<05:55,  4.62s/it]

Epoch: 23 , valid loss: 19.0308 , cos_loss :  0.9742 , classification_loss : 0.1782 , contrastive_loss :  16.8018
Accuracy: 0.9154 , AUC:  0.7855 , F1:  0.1053 , Precision:  0.6667 , recall:  0.0571
Epoch: 24 , Total loss: 16.3013 , cos_loss :  0.9836 , classification_loss : 0.1354 , contrastive_loss :  14.118


 24%|██▍       | 24/100 [01:51<05:50,  4.61s/it]

Epoch: 24 , valid loss: 19.1855 , cos_loss :  0.992 , classification_loss : 0.2306 , contrastive_loss :  16.8908
Accuracy: 0.9117 , AUC:  0.556 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 25 , Total loss: 16.522 , cos_loss :  0.9735 , classification_loss : 0.1268 , contrastive_loss :  14.3576


 25%|██▌       | 25/100 [01:55<05:44,  4.60s/it]

Epoch: 25 , valid loss: 19.0396 , cos_loss :  0.9507 , classification_loss : 0.1357 , contrastive_loss :  16.8505
Accuracy: 0.9129 , AUC:  0.5803 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 26 , Total loss: 16.3419 , cos_loss :  0.9838 , classification_loss : 0.1053 , contrastive_loss :  14.1839


 26%|██▌       | 26/100 [02:00<05:39,  4.59s/it]

Epoch: 26 , valid loss: 18.876 , cos_loss :  0.9825 , classification_loss : 0.0883 , contrastive_loss :  16.7429
Accuracy: 0.9515 , AUC:  0.7622 , F1:  0.6214 , Precision:  0.9697 , recall:  0.4571
Epoch: 27 , Total loss: 16.429 , cos_loss :  0.9835 , classification_loss : 0.0927 , contrastive_loss :  14.2885


 27%|██▋       | 27/100 [02:05<05:35,  4.59s/it]

Epoch: 27 , valid loss: 18.7957 , cos_loss :  0.9792 , classification_loss : 0.0847 , contrastive_loss :  16.667
Accuracy: 0.9502 , AUC:  0.7898 , F1:  0.6154 , Precision:  0.9412 , recall:  0.4571
Epoch: 28 , Total loss: 16.4434 , cos_loss :  0.9777 , classification_loss : 0.0902 , contrastive_loss :  14.3041


 28%|██▊       | 28/100 [02:09<05:30,  4.59s/it]

Epoch: 28 , valid loss: 18.5538 , cos_loss :  0.9732 , classification_loss : 0.0924 , contrastive_loss :  16.3989
Accuracy: 0.9515 , AUC:  0.7935 , F1:  0.6214 , Precision:  0.9697 , recall:  0.4571
Epoch: 29 , Total loss: 16.4001 , cos_loss :  0.9728 , classification_loss : 0.0907 , contrastive_loss :  14.2687


 29%|██▉       | 29/100 [02:14<05:26,  4.61s/it]

Epoch: 29 , valid loss: 18.6362 , cos_loss :  0.9664 , classification_loss : 0.0928 , contrastive_loss :  16.4938
Accuracy: 0.9465 , AUC:  0.7948 , F1:  0.5657 , Precision:  0.9655 , recall:  0.4
Epoch: 30 , Total loss: 16.378 , cos_loss :  0.9668 , classification_loss : 0.0809 , contrastive_loss :  14.263


 30%|███       | 30/100 [02:19<05:23,  4.62s/it]

Epoch: 30 , valid loss: 18.7968 , cos_loss :  0.9605 , classification_loss : 0.0803 , contrastive_loss :  16.68
Accuracy: 0.9502 , AUC:  0.8101 , F1:  0.6154 , Precision:  0.9412 , recall:  0.4571
Epoch: 31 , Total loss: 16.3638 , cos_loss :  0.9732 , classification_loss : 0.0818 , contrastive_loss :  14.2452


 31%|███       | 31/100 [02:23<05:18,  4.62s/it]

Epoch: 31 , valid loss: 18.8903 , cos_loss :  0.9708 , classification_loss : 0.0817 , contrastive_loss :  16.7421
Accuracy: 0.9515 , AUC:  0.7915 , F1:  0.6214 , Precision:  0.9697 , recall:  0.4571
Epoch: 32 , Total loss: 16.3764 , cos_loss :  0.9729 , classification_loss : 0.0865 , contrastive_loss :  14.2559


 32%|███▏      | 32/100 [02:28<05:16,  4.65s/it]

Epoch: 32 , valid loss: 19.0081 , cos_loss :  0.9811 , classification_loss : 0.0821 , contrastive_loss :  16.8257
Accuracy: 0.9465 , AUC:  0.7959 , F1:  0.5981 , Precision:  0.8649 , recall:  0.4571
Epoch: 33 , Total loss: 16.3234 , cos_loss :  0.9843 , classification_loss : 0.0902 , contrastive_loss :  14.1855


 33%|███▎      | 33/100 [02:32<05:10,  4.64s/it]

Epoch: 33 , valid loss: 18.7107 , cos_loss :  0.9699 , classification_loss : 0.0928 , contrastive_loss :  16.5656
Accuracy: 0.944 , AUC:  0.7915 , F1:  0.5872 , Precision:  0.8205 , recall:  0.4571
Epoch: 34 , Total loss: 16.1281 , cos_loss :  0.982 , classification_loss : 0.1015 , contrastive_loss :  13.9885


 34%|███▍      | 34/100 [02:37<05:05,  4.63s/it]

Epoch: 34 , valid loss: 19.2354 , cos_loss :  0.984 , classification_loss : 0.3411 , contrastive_loss :  16.8247
Accuracy: 0.9129 , AUC:  0.5078 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 35 , Total loss: 16.3612 , cos_loss :  0.9848 , classification_loss : 0.1207 , contrastive_loss :  14.2066


 35%|███▌      | 35/100 [02:42<05:00,  4.62s/it]

Epoch: 35 , valid loss: 18.3356 , cos_loss :  0.9831 , classification_loss : 0.1014 , contrastive_loss :  16.2169
Accuracy: 0.9167 , AUC:  0.7562 , F1:  0.1928 , Precision:  0.6154 , recall:  0.1143
Epoch: 36 , Total loss: 15.7634 , cos_loss :  0.9787 , classification_loss : 0.1154 , contrastive_loss :  13.6211


 36%|███▌      | 36/100 [02:46<04:54,  4.61s/it]

Epoch: 36 , valid loss: 18.3126 , cos_loss :  0.9792 , classification_loss : 0.13 , contrastive_loss :  16.1598
Accuracy: 0.9129 , AUC:  0.5891 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 37 , Total loss: 15.7098 , cos_loss :  0.9873 , classification_loss : 0.1129 , contrastive_loss :  13.5584


 37%|███▋      | 37/100 [02:51<04:49,  4.60s/it]

Epoch: 37 , valid loss: 18.1507 , cos_loss :  0.9967 , classification_loss : 0.1188 , contrastive_loss :  15.9938
Accuracy: 0.9092 , AUC:  0.7637 , F1:  0.0759 , Precision:  0.3333 , recall:  0.0429
Epoch: 38 , Total loss: 15.3891 , cos_loss :  0.9887 , classification_loss : 0.1039 , contrastive_loss :  13.2405


 38%|███▊      | 38/100 [02:55<04:46,  4.62s/it]

Epoch: 38 , valid loss: 17.2549 , cos_loss :  0.9959 , classification_loss : 0.1061 , contrastive_loss :  15.0622
Accuracy: 0.9266 , AUC:  0.75 , F1:  0.4272 , Precision:  0.6667 , recall:  0.3143
Epoch: 39 , Total loss: 15.1158 , cos_loss :  0.9947 , classification_loss : 0.0977 , contrastive_loss :  12.9543


 39%|███▉      | 39/100 [03:00<04:41,  4.62s/it]

Epoch: 39 , valid loss: 17.5124 , cos_loss :  0.9971 , classification_loss : 0.1077 , contrastive_loss :  15.3086
Accuracy: 0.9129 , AUC:  0.7374 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 40 , Total loss: 15.2906 , cos_loss :  0.9937 , classification_loss : 0.1146 , contrastive_loss :  13.1175


 40%|████      | 40/100 [03:05<04:36,  4.61s/it]

Epoch: 40 , valid loss: 17.6756 , cos_loss :  0.9929 , classification_loss : 0.1273 , contrastive_loss :  15.4884
Accuracy: 0.9129 , AUC:  0.5927 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 41 , Total loss: 15.018 , cos_loss :  0.9922 , classification_loss : 0.1115 , contrastive_loss :  12.8526


 41%|████      | 41/100 [03:09<04:31,  4.60s/it]

Epoch: 41 , valid loss: 17.3807 , cos_loss :  0.993 , classification_loss : 0.132 , contrastive_loss :  15.1536
Accuracy: 0.9117 , AUC:  0.5885 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 42 , Total loss: 14.7982 , cos_loss :  0.9923 , classification_loss : 0.1118 , contrastive_loss :  12.6333


 42%|████▏     | 42/100 [03:14<04:26,  4.60s/it]

Epoch: 42 , valid loss: 17.0686 , cos_loss :  0.9887 , classification_loss : 0.1223 , contrastive_loss :  14.9008
Accuracy: 0.9129 , AUC:  0.5934 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 43 , Total loss: 14.7262 , cos_loss :  0.992 , classification_loss : 0.1108 , contrastive_loss :  12.5629


 43%|████▎     | 43/100 [03:18<04:21,  4.59s/it]

Epoch: 43 , valid loss: 17.1721 , cos_loss :  0.9891 , classification_loss : 0.1271 , contrastive_loss :  14.9851
Accuracy: 0.9129 , AUC:  0.5847 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 44 , Total loss: 14.7317 , cos_loss :  0.9916 , classification_loss : 0.1096 , contrastive_loss :  12.5732


 44%|████▍     | 44/100 [03:23<04:17,  4.59s/it]

Epoch: 44 , valid loss: 17.1737 , cos_loss :  0.9938 , classification_loss : 0.1272 , contrastive_loss :  15.0036
Accuracy: 0.9129 , AUC:  0.5803 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 45 , Total loss: 14.7059 , cos_loss :  0.9919 , classification_loss : 0.1103 , contrastive_loss :  12.5476


 45%|████▌     | 45/100 [03:28<04:14,  4.63s/it]

Epoch: 45 , valid loss: 17.2741 , cos_loss :  0.9985 , classification_loss : 0.1235 , contrastive_loss :  15.1044
Accuracy: 0.9129 , AUC:  0.5647 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 46 , Total loss: 14.6679 , cos_loss :  0.9932 , classification_loss : 0.1124 , contrastive_loss :  12.5121


 46%|████▌     | 46/100 [03:32<04:09,  4.62s/it]

Epoch: 46 , valid loss: 17.0536 , cos_loss :  0.9933 , classification_loss : 0.1263 , contrastive_loss :  14.8812
Accuracy: 0.9129 , AUC:  0.5433 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 47 , Total loss: 14.5639 , cos_loss :  0.9942 , classification_loss : 0.1136 , contrastive_loss :  12.4076


 47%|████▋     | 47/100 [03:37<04:04,  4.61s/it]

Epoch: 47 , valid loss: 17.0383 , cos_loss :  0.9948 , classification_loss : 0.1249 , contrastive_loss :  14.8649
Accuracy: 0.9129 , AUC:  0.558 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 48 , Total loss: 14.5829 , cos_loss :  0.9935 , classification_loss : 0.1104 , contrastive_loss :  12.4231


 48%|████▊     | 48/100 [03:42<03:59,  4.60s/it]

Epoch: 48 , valid loss: 17.3642 , cos_loss :  0.9879 , classification_loss : 0.134 , contrastive_loss :  15.1708
Accuracy: 0.9129 , AUC:  0.5619 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 49 , Total loss: 14.5671 , cos_loss :  0.9947 , classification_loss : 0.113 , contrastive_loss :  12.4081


 49%|████▉     | 49/100 [03:46<03:54,  4.60s/it]

Epoch: 49 , valid loss: 16.846 , cos_loss :  0.9941 , classification_loss : 0.1238 , contrastive_loss :  14.683
Accuracy: 0.9129 , AUC:  0.6048 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 50 , Total loss: 14.8163 , cos_loss :  0.9932 , classification_loss : 0.109 , contrastive_loss :  12.6654


 50%|█████     | 50/100 [03:51<03:49,  4.60s/it]

Epoch: 50 , valid loss: 17.66 , cos_loss :  0.9949 , classification_loss : 0.1235 , contrastive_loss :  15.4969
Accuracy: 0.9117 , AUC:  0.629 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 51 , Total loss: 15.1339 , cos_loss :  0.9969 , classification_loss : 0.1116 , contrastive_loss :  12.9803


 51%|█████     | 51/100 [03:55<03:44,  4.59s/it]

Epoch: 51 , valid loss: 17.334 , cos_loss :  0.9997 , classification_loss : 0.1131 , contrastive_loss :  15.1727
Accuracy: 0.908 , AUC:  0.7308 , F1:  0.1591 , Precision:  0.3889 , recall:  0.1
Epoch: 52 , Total loss: 14.6769 , cos_loss :  0.9968 , classification_loss : 0.1041 , contrastive_loss :  12.5321


 52%|█████▏    | 52/100 [04:00<03:40,  4.59s/it]

Epoch: 52 , valid loss: 16.5093 , cos_loss :  0.9963 , classification_loss : 0.1259 , contrastive_loss :  14.3362
Accuracy: 0.903 , AUC:  0.6969 , F1:  0.093 , Precision:  0.25 , recall:  0.0571
Epoch: 53 , Total loss: 14.5886 , cos_loss :  0.9959 , classification_loss : 0.1016 , contrastive_loss :  12.4333


 53%|█████▎    | 53/100 [04:04<03:35,  4.59s/it]

Epoch: 53 , valid loss: 17.0316 , cos_loss :  0.998 , classification_loss : 0.1175 , contrastive_loss :  14.8532
Accuracy: 0.9129 , AUC:  0.7301 , F1:  0.125 , Precision:  0.5 , recall:  0.0714
Epoch: 54 , Total loss: 14.5259 , cos_loss :  0.9967 , classification_loss : 0.0998 , contrastive_loss :  12.3742


 54%|█████▍    | 54/100 [04:09<03:31,  4.59s/it]

Epoch: 54 , valid loss: 17.8224 , cos_loss :  0.9971 , classification_loss : 0.1414 , contrastive_loss :  15.6234
Accuracy: 0.9129 , AUC:  0.593 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 55 , Total loss: 15.0606 , cos_loss :  0.9946 , classification_loss : 0.1155 , contrastive_loss :  12.8938


 55%|█████▌    | 55/100 [04:14<03:26,  4.59s/it]

Epoch: 55 , valid loss: 17.4352 , cos_loss :  0.9955 , classification_loss : 0.1179 , contrastive_loss :  15.2373
Accuracy: 0.9117 , AUC:  0.654 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 56 , Total loss: 14.8514 , cos_loss :  0.9954 , classification_loss : 0.1076 , contrastive_loss :  12.6846


 56%|█████▌    | 56/100 [04:19<03:25,  4.68s/it]

Epoch: 56 , valid loss: 17.5326 , cos_loss :  0.9948 , classification_loss : 0.1225 , contrastive_loss :  15.3053
Accuracy: 0.9092 , AUC:  0.6371 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 57 , Total loss: 14.8807 , cos_loss :  0.9955 , classification_loss : 0.1075 , contrastive_loss :  12.7093


 57%|█████▋    | 57/100 [04:23<03:20,  4.65s/it]

Epoch: 57 , valid loss: 17.273 , cos_loss :  0.9954 , classification_loss : 0.1247 , contrastive_loss :  15.0787
Accuracy: 0.9104 , AUC:  0.652 , F1:  0.2 , Precision:  0.45 , recall:  0.1286
Epoch: 58 , Total loss: 14.7104 , cos_loss :  0.9953 , classification_loss : 0.1071 , contrastive_loss :  12.5438


 58%|█████▊    | 58/100 [04:28<03:15,  4.65s/it]

Epoch: 58 , valid loss: 16.8572 , cos_loss :  0.9944 , classification_loss : 0.1244 , contrastive_loss :  14.6684
Accuracy: 0.9129 , AUC:  0.6671 , F1:  0.125 , Precision:  0.5 , recall:  0.0714
Epoch: 59 , Total loss: 14.5693 , cos_loss :  0.996 , classification_loss : 0.1102 , contrastive_loss :  12.3971


 59%|█████▉    | 59/100 [04:33<03:12,  4.69s/it]

Epoch: 59 , valid loss: 16.515 , cos_loss :  0.9956 , classification_loss : 0.1325 , contrastive_loss :  14.3161
Accuracy: 0.9129 , AUC:  0.6058 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 60 , Total loss: 14.933 , cos_loss :  0.997 , classification_loss : 0.1065 , contrastive_loss :  12.7708


 59%|█████▉    | 59/100 [04:37<03:12,  4.71s/it]

Epoch: 60 , valid loss: 16.2369 , cos_loss :  0.997 , classification_loss : 0.1163 , contrastive_loss :  14.0472
Accuracy: 0.9192 , AUC:  0.7064 , F1:  0.2857 , Precision:  0.619 , recall:  0.1857
Early stopping triggered at epoch 29
Best Combined Score (AUC): 0.8101





In [35]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_const_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path
model_2_path = model_2_save_path

model_1.load_state_dict(torch.load(model_1_path))
model_2.load_state_dict(torch.load(model_2_path))

# model_1.eval()
# model_2.eval()

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)

        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        te_output_2, te_next_visit_output_2, te_final_visit_classification_2, te_final_visit_2 = model_2(batch_data)
  
        te_cosine_loss_mean_2 = cosine_embedding_loss(te_output_2, te_next_visit_output_2, y_te)

        classification_loss_te_2 = criterion(te_final_visit_classification_2.squeeze(), te_labels.long())

        # te_const_loss = const_loss(te_final_visit_1, te_final_visit_2)
        te_const_loss = const_loss(torch.cat([final_visit_1, final_visit_2],dim=0))

        te_loss = ( (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1) 
                   +(cos_lambda * te_cosine_loss_mean_2) + (classi_lambda * classification_loss_te_2) 
                   + te_const_loss)

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())
        total_te_const_loss += te_const_loss.item()

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)
    avg_te_const_loss = total_te_const_loss / len(test_loader)

100%|██████████| 7/7 [00:00<00:00, 19.22it/s]


In [36]:
# 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), ", contrastive_loss : ", round(avg_te_const_loss,4))
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

test loss: 14.3071 , cos_loss :  0.9602 , classification_loss : 0.0869 , contrastive_loss :  12.1966
Accuracy: 0.9502 , AUC:  0.7544 , F1:  0.5 , Precision:  1.0 , recall:  0.3333


(array([0, 1]), array([744,  60]))

In [37]:
confusion_matrix(te_labels_list, te_predictions_list)

array([[744,   0],
       [ 40,  20]])