In [1]:
import numpy as np
import pandas as pd
import os 
import pickle
import sys
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
import warnings

torch.set_printoptions(profile="full")
np.set_printoptions(threshold=sys.maxsize)
warnings.filterwarnings("ignore")

In [2]:
from src.attn import FixedPositionalEncoding, LearnablePositionalEncoding, TemporalEmbedding, MultiheadAttention, Decoder
from src.loss import ContrastiveLoss, FocalLoss

In [3]:
from datetime import datetime

In [4]:
seed = 777
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(seed)

In [5]:
path = '/data/notebook/shared/MIMIC-IV'

In [6]:
with open(os.path.join(path, 'dict_types_nomedi_mimic_240408_clinic_3_years.pkl'), 'rb') as f:
    dtype_dict = pickle.load(f)
f.close()

with open(os.path.join('./data', 'preprocessed_nomedi_240423_clinic_3_years.pkl'), 'rb') as f:
    data_dict_d = pickle.load(f)
f.close()

In [7]:
remain_year_clinical = pd.read_csv(os.path.join(path, "240408_remain_year_clinical_diag_seq_4.csv"))
del remain_year_clinical['clinical_label']

In [8]:
len(dtype_dict)

12702

In [9]:
labels = []
code_labels = []
length_list = []
clinical_labels = []
code_length_list = []
for sample_id, visits in tqdm(data_dict_d.items()):
    # 레이블 추가
    label = visits['label']
    code_label = visits['code_label']
    clinical_label = visits['clinical_label']
    labels.append(label)
    code_labels.append(code_label)
    clinical_labels.append(clinical_labels)

100%|██████████| 8037/8037 [00:00<00:00, 581902.66it/s]


In [10]:
train_indices, test_indices, train_y, test_y = train_test_split(list(data_dict_d.keys()), labels, test_size=0.1, random_state=777, stratify=labels)
train_indices, valid_indices, valid_y, valid_y = train_test_split(train_indices, train_y, test_size=(len(test_indices)/len(train_indices)), random_state=777, stratify=train_y) 

In [11]:
train_data = {}
valid_data = {}
test_data = {}
for sample in tqdm(train_indices):
    train_data[sample] = data_dict_d[sample]

for sample in tqdm(valid_indices):
    valid_data[sample] = data_dict_d[sample]

for sample in tqdm(test_indices):
    test_data[sample] = data_dict_d[sample]

100%|██████████| 6429/6429 [00:00<00:00, 1554368.25it/s]
100%|██████████| 804/804 [00:00<00:00, 1025926.50it/s]
100%|██████████| 804/804 [00:00<00:00, 1079455.96it/s]


In [12]:
train_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(train_indices)].reset_index(drop=True)
valid_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(valid_indices)].reset_index(drop=True)
test_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(test_indices)].reset_index(drop=True)

In [13]:
class CustomDataset(Dataset):
    def __init__(self, data, clinical_df, scaler, mode='train'):       
        self.keys = list(data.keys())  # 딕셔너리의 키 목록 저장
        self.data = data  # 딕셔너리에서 데이터만 추출하여 저장
        self.scaler = scaler
        if mode == 'train':
            scaled_data = self.scaler.fit_transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)
        else:
            scaled_data = self.scaler.transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # x는 현재 방문까지의 모든 방문 데이터, y는 다음 방문 데이터
        padding_temp = torch.zeros((1, len(self.data[self.keys[idx]]['code_index'][0])), dtype=torch.long)
        
        origin_visit = torch.tensor(self.data[self.keys[idx]]['code_index'], dtype=torch.long)
        origin_mask = torch.tensor(self.data[self.keys[idx]]['seq_mask'], dtype=torch.long)
        origin_mask_final = torch.tensor(self.data[self.keys[idx]]['seq_mask_final'], dtype=torch.long)
        origin_mask_code = torch.tensor(self.data[self.keys[idx]]['seq_mask_code'], dtype=torch.long)
        
        next_visit = torch.cat((origin_visit[1:], padding_temp), dim=0)
        next_mask =torch.cat((origin_mask[1:], torch.tensor([0], dtype=torch.long)), dim=0)
        next_mask_code = torch.cat((origin_mask_code[1:], padding_temp), dim=0)
        
        clinical_data = self.scaled_clinical_df.loc[self.scaled_clinical_df['subject_id'] == self.keys[idx], self.scaled_clinical_df.columns[2:]].values
        clinical_data = torch.tensor(clinical_data, dtype=torch.float)
        visit_index = torch.tensor(self.data[self.keys[idx]]['year_onehot'], dtype=torch.long)
        last_visit_index = torch.tensor(self.data[self.keys[idx]]['last_year_onehot'], dtype=torch.float)
        time_feature = torch.tensor(self.data[self.keys[idx]]['time_feature'], dtype=torch.long)
        label_per_sample = self.data[self.keys[idx]]['label']
        key_per_sample = self.keys[idx]  # 해당 샘플의 키
        # 키 값도 함께 반환
        return {'sample_id': key_per_sample, 'origin_visit': origin_visit, 'next_visit': next_visit,\
                'origin_mask': origin_mask, 'origin_mask_final': origin_mask_final, 'next_mask': next_mask, \
                'origin_mask_code': origin_mask_code, 'next_mask_code': next_mask_code, 'time_feature': time_feature, \
                'clinical_data': clinical_data, 'visit_index': visit_index, 'last_visit_index': last_visit_index, \
                'label': label_per_sample}

In [14]:
scaler = StandardScaler()
train_dataset = CustomDataset(train_data, train_clinical, scaler, mode='train')
valid_dataset = CustomDataset(valid_data, valid_clinical, scaler, mode='valid')
test_dataset = CustomDataset(test_data, test_clinical, scaler, mode='test')

In [15]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=512, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

In [16]:
class CustomTransformerModel(nn.Module):
    def __init__(self, code_size, ninp, nhead, nhid, nlayers, dropout=0.5, device='cuda:0', pe='fixed'):
        super(CustomTransformerModel, self).__init__()
        self.ninp = ninp  # 축소된 차원 및 Transformer 입력 차원
        self.device = device
        # 차원 축소를 위한 선형 레이어
        self.pre_embedding = nn.Embedding(code_size, self.ninp)
        self.pe = pe
        if self.pe == 'fixed':
            self.pos_encoder = FixedPositionalEncoding(ninp, dropout)
        elif self.pe == 'time_feature':
            self.pos_encoder = TemporalEmbedding(ninp, embed_type='embed', dropout=dropout)
        else:
            self.pos_encoder = LearnablePositionalEncoding(ninp, dropout)
            
        self.transformer_decoder = Decoder(ninp, nhead, nhid, nlayers, dropout)
        self.decoder = nn.Linear(ninp, ninp, bias=False)  # 최종 출력 차원을 설정 (여기서는 ninp로 설정)
        self.classification_layer = nn.Linear(ninp, 2)
        self.clinical_transform = nn.Linear(18, ninp) 
        self.cross_attn = MultiheadAttention(ninp, nhead, dropout=0.3)
        self.init_weights()
        
    def init_weights(self):
        nn.init.xavier_uniform_(self.pre_embedding.weight)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.xavier_uniform_(self.classification_layer.weight)
    
    def forward(self, batch_data):
        # 차원 축소
        origin_visit = batch_data['origin_visit'].to(self.device)
        next_visit = batch_data['next_visit'].to(self.device)
        # clinical_tensor = batch_data['clinical_data'].to(self.device)
        mask_code = batch_data['origin_mask_code'].unsqueeze(3).to(self.device)
        next_mask_code = batch_data['next_mask_code'].unsqueeze(3).to(self.device)
        mask_final = batch_data['origin_mask_final'].unsqueeze(2).to(self.device)
        mask_final_year = batch_data['last_visit_index'].to(self.device)
                
        origin_visit_emb = (self.pre_embedding(origin_visit) * mask_code).sum(dim=2)
        next_visit_emb = (self.pre_embedding(next_visit) * next_mask_code).sum(dim=2)
        # 위치 인코딩 및 Transformer 디코더 적용
        if self.pe == 'time_feature':
            time_feature = batch_data['time_feature'].to(self.device)
            src = self.pos_encoder(origin_visit_emb, time_feature)
        else:
            src = self.pos_encoder(origin_visit_emb)
        
        seq_len = src.shape[1]
        src_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=self.device), diagonal=1)
        output = self.transformer_decoder(src, attention_mask=src_mask)
        output = self.decoder(output)
        
        output_batch, output_visit_num, output_dim = output.size()
        next_visit_output_batch, next_visit_output_num, next_visit_output_dim = next_visit_emb.size()
        
        # print(output[:, :torch.,:])
        temp_output = output.reshape(-1, output_dim)
        temp_next_visit = next_visit_emb.reshape(-1, next_visit_output_dim)
        final_visit = (output * mask_final).sum(dim=1)
        # year_emb = torch.bmm(output.transpose(1,2), mask_final_year).transpose(1,2)[:,:2,:]
        
        # transformed_clinical = self.clinical_transform(clinical_tensor)
        # mixed_output, mixed_cross_attn = self.cross_attn(year_emb, transformed_clinical, transformed_clinical)
        # mixed_output = mixed_output[:,-1,:]
        # mixed_final_emb = torch.cat((final_visit, mixed_output), dim=-1)
        classification_result = self.classification_layer(final_visit)    
        return temp_output, temp_next_visit, classification_result, final_visit

In [17]:
lr = 0.001
ninp = 128
nhid = 256
nlayer = 5
gamma = 0.5
model_name = 'no_center_total_label_no_clinical'
pe = 'fixed'

In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_1 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.1, device=device, pe=pe).to(device)
# model_2 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.6, device=device, pe='fixed').to(device)

optimizer = torch.optim.Adam(model_1.parameters(), lr=lr, weight_decay=1e-5)

# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(2, gamma=gamma)
# cosine_loss = CosineSimilarityLoss()
cosine_embedding_loss = nn.CosineEmbeddingLoss()
# const_loss = ContrastiveLoss(temperature=0.05)

In [19]:
# 얼리 스타핑 설정
num_epochs = 100
patience = 30
best_loss = float("inf")
counter = 0
epoch_temp = 0

best_f1 = 0.0
best_combined_score = 0.0
best_epoch = 0

view_total_loss = []
view_cos_loss = []
view_classi_loss = []

cos_lambda = 1
classi_lambda = 1
const_lambda = 1


for epoch in tqdm(range(num_epochs)):
    model_1.train()

    total_train_loss = 0
    total_cos_loss = 0
    total_classi_loss = 0
    total_const_loss = 0
    # key_per_sample, origin_visit, origin_mask, origin_mask_code, origin_mask_final, next_visit, next_mask, next_mask_code, next_mask_final, label_per_sample
    for batch_data in train_loader:
        tr_labels = batch_data['label'].to(device)
        optimizer.zero_grad()
        output_1, next_visit_output_1, final_visit_classification_1, final_visit_1 = model_1(batch_data)
        y = torch.ones(output_1.size(0), dtype=torch.float, device=device)
        y = y.to(device)   
        cosine_loss_mean_1 = cosine_embedding_loss(output_1, next_visit_output_1, y)
        classification_loss_1 = criterion(final_visit_classification_1.squeeze(), tr_labels.long())
        loss = (cos_lambda * cosine_loss_mean_1) + (classi_lambda * classification_loss_1)

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        total_cos_loss += (cosine_loss_mean_1.item())
        total_classi_loss += (classification_loss_1.item())
        # total_const_loss += train_const_loss.item()

        view_total_loss.append(total_train_loss)
        view_cos_loss.append(total_cos_loss)
        view_classi_loss.append(total_classi_loss)

    # 평균 손실 계산
    avg_train_loss = total_train_loss / len(train_loader)
    avg_cos_loss = total_cos_loss / len(train_loader)
    avg_classi_loss = total_classi_loss / len(train_loader)

    print("Epoch:", epoch+1, ", Total loss:", round(avg_train_loss, 4), ", cos_loss : ",\
          round(avg_cos_loss, 4), ", classification_loss :",round(avg_classi_loss, 4))

    model_1.eval()
    total_val_loss = 0
    total_val_cos_loss = 0
    total_val_classi_loss = 0
    total_val_const_loss = 0

    val_labels_list = []
    val_predictions_list = []
    val_probabilities_list = []


    with torch.no_grad():
        for batch_data in valid_loader:
            val_labels = batch_data['label'].to(device)
            val_output_1, val_next_visit_output_1, val_final_visit_classification_1, val_final_visit_1 = model_1(batch_data)
            
            y_val = torch.ones(val_output_1.size(0), dtype=torch.float, device=device) 
            val_cosine_loss_mean_1 = cosine_embedding_loss(val_output_1, val_next_visit_output_1, y_val)
            
            classification_loss_val_1 = criterion(val_final_visit_classification_1.squeeze(), val_labels.long())            

            val_loss = (cos_lambda * val_cosine_loss_mean_1) + (classi_lambda * classification_loss_val_1) 
            
            total_val_loss += val_loss.item()
            total_val_cos_loss += (val_cosine_loss_mean_1.item())
            total_val_classi_loss += (classification_loss_val_1.item())
            # total_val_const_loss += val_const_loss.item()

            val_probs = F.softmax(val_final_visit_classification_1)
            val_predictions = torch.max(val_probs, 1)[1].view((len(val_labels),))

            val_labels_list.extend(val_labels.view(-1).cpu().numpy())
            val_predictions_list.extend(val_predictions.cpu().numpy())
            
            ########
            val_probabilities_list.extend(val_probs[:,1].cpu().numpy())
            ######## 
            
            
        avg_val_loss = total_val_loss / len(valid_loader)
        avg_val_cos_loss = total_val_cos_loss / len(valid_loader)
        avg_val_classi_loss = total_val_classi_loss / len(valid_loader)

    # 성능 지표 계산
    accuracy = accuracy_score(val_labels_list, val_predictions_list)
    auc = roc_auc_score(val_labels_list, val_probabilities_list)
    f1 = f1_score(val_labels_list, val_predictions_list)
    precision = precision_score(val_labels_list,val_predictions_list)
    recall = recall_score(val_labels_list, val_predictions_list)

    print("Epoch:", epoch+1, ", valid loss:", round(avg_val_loss, 4), ", cos_loss : ",\
          round(avg_val_cos_loss, 4), ", classification_loss :",round(avg_val_classi_loss, 4))
    print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

    current_auc = roc_auc_score(val_labels_list, val_probabilities_list)
    current_f1 = f1_score(val_labels_list, val_predictions_list)
    current_combined_score = current_f1

    # 모델 저장 경로 설정
    model_save_path = f'results'
    date_dir = datetime.today().strftime("%Y%m%d")
    model_time =  datetime.today().strftime("%H%M%S")
    # 학습률과 분류 가중치를 파일명에 포함시키기 위한 문자열 포맷
    model_filename_format = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{{gamma}}_{model_time}.pth'

    # 모델 저장 폴더가 없으면 생성
    os.makedirs(os.path.join(model_save_path, date_dir), exist_ok=True)

    if current_combined_score > best_combined_score:
        best_combined_score = current_combined_score
        best_epoch = epoch
        counter = 0
        # 모델 저장 경로와 파일명을 결합하여 전체 파일 경로 생성
        model_1_save_path = os.path.join(model_save_path, date_dir, model_filename_format.format(epoch=best_epoch, model=model_name, pe=pe, gamma=gamma))
        # 모델 저장
        torch.save(model_1.state_dict(), f'{model_1_save_path}')
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered at epoch {}".format(best_epoch))
        print(f"Best Combined Score (AUC): {best_combined_score:.4f}")
        break

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1 , Total loss: 1.7998 , cos_loss :  0.9805 , classification_loss : 0.8193


  1%|          | 1/100 [00:03<05:38,  3.42s/it]

Epoch: 1 , valid loss: 1.4749 , cos_loss :  0.954 , classification_loss : 0.5209
Accuracy: 0.9129 , AUC:  0.6456 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 2 , Total loss: 1.3037 , cos_loss :  0.9443 , classification_loss : 0.3594


  2%|▏         | 2/100 [00:06<05:40,  3.47s/it]

Epoch: 2 , valid loss: 1.1369 , cos_loss :  0.9322 , classification_loss : 0.2047
Accuracy: 0.9129 , AUC:  0.5772 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 3 , Total loss: 1.1573 , cos_loss :  0.9287 , classification_loss : 0.2286


  3%|▎         | 3/100 [00:10<05:33,  3.43s/it]

Epoch: 3 , valid loss: 1.1473 , cos_loss :  0.9222 , classification_loss : 0.2251
Accuracy: 0.9129 , AUC:  0.6424 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 4 , Total loss: 1.1396 , cos_loss :  0.9221 , classification_loss : 0.2176


  4%|▍         | 4/100 [00:13<05:27,  3.41s/it]

Epoch: 4 , valid loss: 1.1263 , cos_loss :  0.9181 , classification_loss : 0.2082
Accuracy: 0.9129 , AUC:  0.6396 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 5 , Total loss: 1.1289 , cos_loss :  0.9195 , classification_loss : 0.2094


  5%|▌         | 5/100 [00:17<05:26,  3.44s/it]

Epoch: 5 , valid loss: 1.1195 , cos_loss :  0.9162 , classification_loss : 0.2033
Accuracy: 0.9129 , AUC:  0.673 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 6 , Total loss: 1.1275 , cos_loss :  0.918 , classification_loss : 0.2095


  6%|▌         | 6/100 [00:20<05:21,  3.42s/it]

Epoch: 6 , valid loss: 1.1183 , cos_loss :  0.9152 , classification_loss : 0.2031
Accuracy: 0.9129 , AUC:  0.6475 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 7 , Total loss: 1.1251 , cos_loss :  0.917 , classification_loss : 0.2081


  7%|▋         | 7/100 [00:24<05:20,  3.44s/it]

Epoch: 7 , valid loss: 1.1194 , cos_loss :  0.9145 , classification_loss : 0.2049
Accuracy: 0.9129 , AUC:  0.6871 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 8 , Total loss: 1.128 , cos_loss :  0.9167 , classification_loss : 0.2113


  8%|▊         | 8/100 [00:27<05:16,  3.44s/it]

Epoch: 8 , valid loss: 1.1181 , cos_loss :  0.914 , classification_loss : 0.204
Accuracy: 0.9129 , AUC:  0.6773 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 9 , Total loss: 1.1269 , cos_loss :  0.9158 , classification_loss : 0.2111


  9%|▉         | 9/100 [00:30<05:13,  3.45s/it]

Epoch: 9 , valid loss: 1.1164 , cos_loss :  0.9137 , classification_loss : 0.2028
Accuracy: 0.9129 , AUC:  0.6801 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 10 , Total loss: 1.1277 , cos_loss :  0.9161 , classification_loss : 0.2116


 10%|█         | 10/100 [00:34<05:07,  3.42s/it]

Epoch: 10 , valid loss: 1.1159 , cos_loss :  0.9134 , classification_loss : 0.2025
Accuracy: 0.9129 , AUC:  0.6822 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 11 , Total loss: 1.1283 , cos_loss :  0.9155 , classification_loss : 0.2128


 11%|█         | 11/100 [00:37<05:02,  3.40s/it]

Epoch: 11 , valid loss: 1.1187 , cos_loss :  0.9132 , classification_loss : 0.2055
Accuracy: 0.9129 , AUC:  0.6717 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 12 , Total loss: 1.1244 , cos_loss :  0.9155 , classification_loss : 0.2089


 12%|█▏        | 12/100 [00:41<05:02,  3.43s/it]

Epoch: 12 , valid loss: 1.1154 , cos_loss :  0.913 , classification_loss : 0.2025
Accuracy: 0.9129 , AUC:  0.669 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 13 , Total loss: 1.1224 , cos_loss :  0.9152 , classification_loss : 0.2072


 13%|█▎        | 13/100 [00:44<04:56,  3.41s/it]

Epoch: 13 , valid loss: 1.1161 , cos_loss :  0.9128 , classification_loss : 0.2033
Accuracy: 0.9129 , AUC:  0.6761 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 14 , Total loss: 1.1244 , cos_loss :  0.9152 , classification_loss : 0.2092


 14%|█▍        | 14/100 [00:48<04:55,  3.43s/it]

Epoch: 14 , valid loss: 1.1167 , cos_loss :  0.9127 , classification_loss : 0.2041
Accuracy: 0.9129 , AUC:  0.6719 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 15 , Total loss: 1.1238 , cos_loss :  0.9151 , classification_loss : 0.2088


 15%|█▌        | 15/100 [00:51<04:51,  3.43s/it]

Epoch: 15 , valid loss: 1.1142 , cos_loss :  0.9125 , classification_loss : 0.2017
Accuracy: 0.9129 , AUC:  0.6697 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 16 , Total loss: 1.1234 , cos_loss :  0.9149 , classification_loss : 0.2085


 16%|█▌        | 16/100 [00:54<04:49,  3.45s/it]

Epoch: 16 , valid loss: 1.1136 , cos_loss :  0.9124 , classification_loss : 0.2011
Accuracy: 0.9129 , AUC:  0.6758 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 17 , Total loss: 1.1235 , cos_loss :  0.9151 , classification_loss : 0.2084


 17%|█▋        | 17/100 [00:58<04:44,  3.42s/it]

Epoch: 17 , valid loss: 1.113 , cos_loss :  0.9123 , classification_loss : 0.2006
Accuracy: 0.9129 , AUC:  0.6767 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 18 , Total loss: 1.1228 , cos_loss :  0.9149 , classification_loss : 0.2079


 18%|█▊        | 18/100 [01:01<04:41,  3.44s/it]

Epoch: 18 , valid loss: 1.111 , cos_loss :  0.9123 , classification_loss : 0.1987
Accuracy: 0.9129 , AUC:  0.6731 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 19 , Total loss: 1.1142 , cos_loss :  0.9149 , classification_loss : 0.1993


 19%|█▉        | 19/100 [01:05<04:36,  3.42s/it]

Epoch: 19 , valid loss: 1.0944 , cos_loss :  0.9123 , classification_loss : 0.1821
Accuracy: 0.9129 , AUC:  0.6678 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 20 , Total loss: 1.0917 , cos_loss :  0.9149 , classification_loss : 0.1768


 20%|██        | 20/100 [01:08<04:35,  3.44s/it]

Epoch: 20 , valid loss: 1.0681 , cos_loss :  0.9131 , classification_loss : 0.155
Accuracy: 0.9415 , AUC:  0.6623 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 21 , Total loss: 1.0772 , cos_loss :  0.9159 , classification_loss : 0.1612


 21%|██        | 21/100 [01:11<04:30,  3.42s/it]

Epoch: 21 , valid loss: 1.0673 , cos_loss :  0.9135 , classification_loss : 0.1539
Accuracy: 0.9415 , AUC:  0.6559 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 22 , Total loss: 1.0755 , cos_loss :  0.9156 , classification_loss : 0.1599


 22%|██▏       | 22/100 [01:15<04:29,  3.45s/it]

Epoch: 22 , valid loss: 1.0633 , cos_loss :  0.9134 , classification_loss : 0.1498
Accuracy: 0.9415 , AUC:  0.673 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 23 , Total loss: 1.0652 , cos_loss :  0.9157 , classification_loss : 0.1495


 23%|██▎       | 23/100 [01:18<04:24,  3.43s/it]

Epoch: 23 , valid loss: 1.0683 , cos_loss :  0.9135 , classification_loss : 0.1548
Accuracy: 0.9415 , AUC:  0.6709 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 24 , Total loss: 1.0665 , cos_loss :  0.9154 , classification_loss : 0.1511


 24%|██▍       | 24/100 [01:22<04:19,  3.41s/it]

Epoch: 24 , valid loss: 1.0632 , cos_loss :  0.9133 , classification_loss : 0.1499
Accuracy: 0.9415 , AUC:  0.6771 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 25 , Total loss: 1.0647 , cos_loss :  0.9154 , classification_loss : 0.1493


 25%|██▌       | 25/100 [01:25<04:17,  3.43s/it]

Epoch: 25 , valid loss: 1.0633 , cos_loss :  0.9132 , classification_loss : 0.1501
Accuracy: 0.9415 , AUC:  0.6762 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 26 , Total loss: 1.0634 , cos_loss :  0.9156 , classification_loss : 0.1478


 26%|██▌       | 26/100 [01:29<04:12,  3.42s/it]

Epoch: 26 , valid loss: 1.064 , cos_loss :  0.9132 , classification_loss : 0.1508
Accuracy: 0.9415 , AUC:  0.672 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 27 , Total loss: 1.0634 , cos_loss :  0.9154 , classification_loss : 0.148


 27%|██▋       | 27/100 [01:32<04:10,  3.43s/it]

Epoch: 27 , valid loss: 1.0634 , cos_loss :  0.9131 , classification_loss : 0.1503
Accuracy: 0.9415 , AUC:  0.6717 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 28 , Total loss: 1.0634 , cos_loss :  0.9154 , classification_loss : 0.148


 28%|██▊       | 28/100 [01:35<04:05,  3.41s/it]

Epoch: 28 , valid loss: 1.0635 , cos_loss :  0.913 , classification_loss : 0.1505
Accuracy: 0.9415 , AUC:  0.6702 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 29 , Total loss: 1.0632 , cos_loss :  0.9152 , classification_loss : 0.1479


 29%|██▉       | 29/100 [01:39<04:09,  3.51s/it]

Epoch: 29 , valid loss: 1.0627 , cos_loss :  0.913 , classification_loss : 0.1497
Accuracy: 0.9415 , AUC:  0.6724 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 30 , Total loss: 1.0646 , cos_loss :  0.9153 , classification_loss : 0.1492


 30%|███       | 30/100 [01:43<04:02,  3.47s/it]

Epoch: 30 , valid loss: 1.0629 , cos_loss :  0.9129 , classification_loss : 0.1501
Accuracy: 0.9415 , AUC:  0.673 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 31 , Total loss: 1.0634 , cos_loss :  0.9152 , classification_loss : 0.1482


 31%|███       | 31/100 [01:46<03:57,  3.44s/it]

Epoch: 31 , valid loss: 1.0645 , cos_loss :  0.9128 , classification_loss : 0.1517
Accuracy: 0.9415 , AUC:  0.6743 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 32 , Total loss: 1.0665 , cos_loss :  0.9151 , classification_loss : 0.1514


 32%|███▏      | 32/100 [01:49<03:55,  3.46s/it]

Epoch: 32 , valid loss: 1.0646 , cos_loss :  0.9127 , classification_loss : 0.1519
Accuracy: 0.9415 , AUC:  0.6724 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 33 , Total loss: 1.0652 , cos_loss :  0.9151 , classification_loss : 0.1501


 33%|███▎      | 33/100 [01:53<03:49,  3.43s/it]

Epoch: 33 , valid loss: 1.0648 , cos_loss :  0.9127 , classification_loss : 0.1522
Accuracy: 0.9415 , AUC:  0.6604 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 34 , Total loss: 1.0699 , cos_loss :  0.9152 , classification_loss : 0.1547


 34%|███▍      | 34/100 [01:56<03:47,  3.45s/it]

Epoch: 34 , valid loss: 1.0622 , cos_loss :  0.9125 , classification_loss : 0.1497
Accuracy: 0.9415 , AUC:  0.674 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 35 , Total loss: 1.0664 , cos_loss :  0.915 , classification_loss : 0.1514


 35%|███▌      | 35/100 [02:00<03:46,  3.48s/it]

Epoch: 35 , valid loss: 1.0623 , cos_loss :  0.9127 , classification_loss : 0.1495
Accuracy: 0.9415 , AUC:  0.6705 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 36 , Total loss: 1.0647 , cos_loss :  0.9152 , classification_loss : 0.1495


 36%|███▌      | 36/100 [02:04<03:46,  3.55s/it]

Epoch: 36 , valid loss: 1.0625 , cos_loss :  0.9126 , classification_loss : 0.1499
Accuracy: 0.9415 , AUC:  0.6777 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 37 , Total loss: 1.0624 , cos_loss :  0.9151 , classification_loss : 0.1474


 37%|███▋      | 37/100 [02:07<03:39,  3.48s/it]

Epoch: 37 , valid loss: 1.0661 , cos_loss :  0.9126 , classification_loss : 0.1534
Accuracy: 0.9415 , AUC:  0.6659 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 38 , Total loss: 1.0651 , cos_loss :  0.915 , classification_loss : 0.1501


 38%|███▊      | 38/100 [02:10<03:34,  3.46s/it]

Epoch: 38 , valid loss: 1.0622 , cos_loss :  0.9124 , classification_loss : 0.1498
Accuracy: 0.9415 , AUC:  0.6805 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 39 , Total loss: 1.0626 , cos_loss :  0.9147 , classification_loss : 0.1479


 39%|███▉      | 39/100 [02:14<03:28,  3.42s/it]

Epoch: 39 , valid loss: 1.0632 , cos_loss :  0.9125 , classification_loss : 0.1507
Accuracy: 0.9415 , AUC:  0.6737 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 40 , Total loss: 1.0653 , cos_loss :  0.9149 , classification_loss : 0.1505


 40%|████      | 40/100 [02:17<03:23,  3.39s/it]

Epoch: 40 , valid loss: 1.0801 , cos_loss :  0.9126 , classification_loss : 0.1675
Accuracy: 0.9415 , AUC:  0.6704 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 41 , Total loss: 1.0677 , cos_loss :  0.9151 , classification_loss : 0.1527


 41%|████      | 41/100 [02:20<03:20,  3.40s/it]

Epoch: 41 , valid loss: 1.0622 , cos_loss :  0.9125 , classification_loss : 0.1497
Accuracy: 0.9415 , AUC:  0.6722 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 42 , Total loss: 1.0617 , cos_loss :  0.9149 , classification_loss : 0.1468


 42%|████▏     | 42/100 [02:24<03:15,  3.38s/it]

Epoch: 42 , valid loss: 1.0628 , cos_loss :  0.9124 , classification_loss : 0.1504
Accuracy: 0.9415 , AUC:  0.6783 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 43 , Total loss: 1.0634 , cos_loss :  0.915 , classification_loss : 0.1484


 43%|████▎     | 43/100 [02:27<03:13,  3.39s/it]

Epoch: 43 , valid loss: 1.0631 , cos_loss :  0.9124 , classification_loss : 0.1507
Accuracy: 0.9415 , AUC:  0.6802 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 44 , Total loss: 1.0621 , cos_loss :  0.9148 , classification_loss : 0.1473


 44%|████▍     | 44/100 [02:30<03:09,  3.38s/it]

Epoch: 44 , valid loss: 1.0623 , cos_loss :  0.9123 , classification_loss : 0.15
Accuracy: 0.9415 , AUC:  0.6601 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 45 , Total loss: 1.0637 , cos_loss :  0.9148 , classification_loss : 0.1489


 45%|████▌     | 45/100 [02:34<03:06,  3.40s/it]

Epoch: 45 , valid loss: 1.0623 , cos_loss :  0.9123 , classification_loss : 0.1499
Accuracy: 0.9415 , AUC:  0.6672 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 46 , Total loss: 1.0695 , cos_loss :  0.9148 , classification_loss : 0.1547


 46%|████▌     | 46/100 [02:37<03:06,  3.45s/it]

Epoch: 46 , valid loss: 1.0624 , cos_loss :  0.9122 , classification_loss : 0.1501
Accuracy: 0.9415 , AUC:  0.6801 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 47 , Total loss: 1.0651 , cos_loss :  0.9147 , classification_loss : 0.1504


 47%|████▋     | 47/100 [02:41<03:03,  3.47s/it]

Epoch: 47 , valid loss: 1.0619 , cos_loss :  0.9123 , classification_loss : 0.1496
Accuracy: 0.9415 , AUC:  0.6694 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 48 , Total loss: 1.063 , cos_loss :  0.9147 , classification_loss : 0.1482


 48%|████▊     | 48/100 [02:45<03:03,  3.53s/it]

Epoch: 48 , valid loss: 1.0644 , cos_loss :  0.9123 , classification_loss : 0.1521
Accuracy: 0.9415 , AUC:  0.6735 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 49 , Total loss: 1.0636 , cos_loss :  0.9148 , classification_loss : 0.1488


 49%|████▉     | 49/100 [02:48<02:58,  3.50s/it]

Epoch: 49 , valid loss: 1.0621 , cos_loss :  0.9123 , classification_loss : 0.1499
Accuracy: 0.9415 , AUC:  0.6745 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 50 , Total loss: 1.0631 , cos_loss :  0.915 , classification_loss : 0.1481


 49%|████▉     | 49/100 [02:52<02:59,  3.51s/it]

Epoch: 50 , valid loss: 1.0632 , cos_loss :  0.9122 , classification_loss : 0.151
Accuracy: 0.9415 , AUC:  0.6772 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Early stopping triggered at epoch 19
Best Combined Score (AUC): 0.4946





In [20]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_const_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path

model_1.load_state_dict(torch.load(model_1_path))

model_1.eval()

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)

        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        te_loss = (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1) 

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)

100%|██████████| 2/2 [00:00<00:00,  5.92it/s]


In [21]:
## 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), 
      # ", contrastive_loss : ", round(avg_te_const_loss,4)
     )
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

test loss: 1.0737 , cos_loss :  0.9129 , classification_loss : 0.1607
Accuracy: 0.9453 , AUC:  0.7131 , F1:  0.5417 , Precision:  1.0 , recall:  0.3714


(array([0, 1]), array([734,  70]))

In [22]:
model_time

'175256'

In [23]:
log_path = './logs'
logs = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{model_time}.txt'
os.makedirs(os.path.join(log_path, date_dir), exist_ok=True)

results = [accuracy, auc, f1, precision, recall]
cm = confusion_matrix(te_labels_list, te_predictions_list)
with open(os.path.join(log_path, date_dir, logs.format(epoch=best_epoch, model=model_name, pe=pe)), 'w') as f:
    f.write(logs.format(epoch=best_epoch, model=model_name, pe=pe))
    f.write('\n')
    f.write(str(results))
    f.write('\n')
    f.write(str(cm))

In [24]:
confusion_matrix(te_labels_list, te_predictions_list)

array([[734,   0],
       [ 44,  26]])