In [1]:
import numpy as np
import pandas as pd
import os 
import pickle
import sys
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
import warnings

torch.set_printoptions(profile="full")
np.set_printoptions(threshold=sys.maxsize)
warnings.filterwarnings("ignore")

In [2]:
from src.attn import FixedPositionalEncoding, LearnablePositionalEncoding, TemporalEmbedding, MultiheadAttention, Decoder
from src.loss import ContrastiveLoss, FocalLoss

In [3]:
from datetime import datetime

In [4]:
seed = 777
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(seed)

In [5]:
path = '/data/notebook/shared/MIMIC-IV'
with open(os.path.join(path, 'dict_types_nomedi_mimic_240408_clinic_3_years.pkl'), 'rb') as f:
    dtype_dict = pickle.load(f)
f.close()

with open(os.path.join('./data', 'preprocessed_nomedi_240423_clinic_3_years.pkl'), 'rb') as f:
    data_dict_d = pickle.load(f)
f.close()

remain_year_clinical = pd.read_csv(os.path.join(path, "240408_remain_year_clinical_diag_seq_4.csv"))
del remain_year_clinical['clinical_label']

In [6]:
labels = []
code_labels = []
length_list = []
clinical_labels = []
code_length_list = []
for sample_id, visits in tqdm(data_dict_d.items()):
    # 레이블 추가
    label = visits['label']
    code_label = visits['code_label']
    clinical_label = visits['clinical_label']
    labels.append(label)
    code_labels.append(code_label)
    clinical_labels.append(clinical_labels)

100%|██████████| 8037/8037 [00:00<00:00, 524720.53it/s]


In [7]:
len(list(data_dict_d.keys()))

8037

In [8]:
len(labels)

8037

In [9]:
train_indices, test_indices, train_y, test_y = train_test_split(list(data_dict_d.keys()), labels, test_size=0.1, random_state=777, stratify=labels)
train_indices, valid_indices, valid_y, valid_y = train_test_split(train_indices, train_y, test_size=(len(test_indices)/len(train_indices)), random_state=777, stratify=train_y) 

In [10]:
train_data = {}
valid_data = {}
test_data = {}
for sample in tqdm(train_indices):
    train_data[sample] = data_dict_d[sample]

for sample in tqdm(valid_indices):
    valid_data[sample] = data_dict_d[sample]

for sample in tqdm(test_indices):
    test_data[sample] = data_dict_d[sample]

100%|██████████| 6429/6429 [00:00<00:00, 1365809.68it/s]
100%|██████████| 804/804 [00:00<00:00, 1129343.74it/s]
100%|██████████| 804/804 [00:00<00:00, 1112576.84it/s]


In [11]:
train_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(train_indices)].reset_index(drop=True)
valid_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(valid_indices)].reset_index(drop=True)
test_clinical = remain_year_clinical[remain_year_clinical['subject_id'].isin(test_indices)].reset_index(drop=True)

In [12]:
class CustomDataset(Dataset):
    def __init__(self, data, clinical_df, scaler, mode='train'):       
        self.keys = list(data.keys())  # 딕셔너리의 키 목록 저장
        self.data = data  # 딕셔너리에서 데이터만 추출하여 저장
        self.scaler = scaler
        if mode == 'train':
            scaled_data = self.scaler.fit_transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)
        else:
            scaled_data = self.scaler.transform(clinical_df.iloc[:, 2:])
            scaled_clinical_df = pd.DataFrame(scaled_data, columns = clinical_df.iloc[:, 2:].columns)
            self.scaled_clinical_df = pd.concat([clinical_df.iloc[:, :2], scaled_clinical_df], axis=1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # x는 현재 방문까지의 모든 방문 데이터, y는 다음 방문 데이터
        padding_temp = torch.zeros((1, len(self.data[self.keys[idx]]['code_index'][0])), dtype=torch.long)
        
        origin_visit = torch.tensor(self.data[self.keys[idx]]['code_index'], dtype=torch.long)
        origin_mask = torch.tensor(self.data[self.keys[idx]]['seq_mask'], dtype=torch.long)
        origin_mask_final = torch.tensor(self.data[self.keys[idx]]['seq_mask_final'], dtype=torch.long)
        origin_mask_code = torch.tensor(self.data[self.keys[idx]]['seq_mask_code'], dtype=torch.long)
        
        next_visit = torch.cat((origin_visit[1:], padding_temp), dim=0)
        next_mask =torch.cat((origin_mask[1:], torch.tensor([0], dtype=torch.long)), dim=0)
        next_mask_code = torch.cat((origin_mask_code[1:], padding_temp), dim=0)
        
        clinical_data = self.scaled_clinical_df.loc[self.scaled_clinical_df['subject_id'] == self.keys[idx], self.scaled_clinical_df.columns[2:]].values
        clinical_data = torch.tensor(clinical_data, dtype=torch.float)
        visit_index = torch.tensor(self.data[self.keys[idx]]['year_onehot'], dtype=torch.long)
        last_visit_index = torch.tensor(self.data[self.keys[idx]]['last_year_onehot'], dtype=torch.float)
        time_feature = torch.tensor(self.data[self.keys[idx]]['time_feature'], dtype=torch.long)
        label_per_sample = self.data[self.keys[idx]]['label']
        key_per_sample = self.keys[idx]  # 해당 샘플의 키
        # 키 값도 함께 반환
        return {'sample_id': key_per_sample, 'origin_visit': origin_visit, 'next_visit': next_visit,\
                'origin_mask': origin_mask, 'origin_mask_final': origin_mask_final, 'next_mask': next_mask, \
                'origin_mask_code': origin_mask_code, 'next_mask_code': next_mask_code, 'clinical_data': clinical_data, \
                'visit_index': visit_index, 'last_visit_index': last_visit_index, 'time_feature': time_feature, \
                'label': label_per_sample}

In [13]:
scaler = StandardScaler()
train_dataset = CustomDataset(train_data, train_clinical, scaler, mode='train')
valid_dataset = CustomDataset(valid_data, valid_clinical, scaler, mode='valid')
test_dataset = CustomDataset(test_data, test_clinical, scaler, mode='test')

In [14]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=512, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

In [15]:
class CenterLoss(nn.Module):
    def __init__(self, num_classes=2, feat_dim=128, device='cuda:0'):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.device = device
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).to(self.device))
        
    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()

        distmat.addmm_(x, self.centers.t(), beta = 1, alpha = -2,)

        classes = torch.arange(self.num_classes).long().to(self.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss

In [16]:
class CustomTransformerModel(nn.Module):
    def __init__(self, code_size, ninp, nhead, nhid, nlayers, dropout=0.5, device='cuda:0', pe='fixed'):
        super(CustomTransformerModel, self).__init__()
        self.ninp = ninp  # 축소된 차원 및 Transformer 입력 차원
        self.device = device
        # 차원 축소를 위한 선형 레이어
        self.pre_embedding = nn.Embedding(code_size, self.ninp)
        self.pe = pe
        if self.pe == 'fixed':
            self.pos_encoder = FixedPositionalEncoding(ninp, dropout)
        elif self.pe == 'time_feature':
            self.pos_encoder = TemporalEmbedding(ninp, embed_type='embed', dropout=dropout)
        else:
            self.pos_encoder = LearnablePositionalEncoding(ninp, dropout)
            
        self.transformer_decoder = Decoder(ninp, nhead, nhid, nlayers, dropout)
        self.decoder = nn.Linear(ninp, ninp, bias=False)  # 최종 출력 차원을 설정 (여기서는 ninp로 설정)
        self.classification_layer = nn.Linear(ninp * 2, 2)
        self.clinical_transform = nn.Linear(18, ninp) 
        self.cross_attn = MultiheadAttention(ninp, nhead, dropout=0.3)
        self.init_weights()
        
    def init_weights(self):
        nn.init.xavier_uniform_(self.pre_embedding.weight)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.xavier_uniform_(self.classification_layer.weight)
    
    def forward(self, batch_data):
        # 차원 축소
        origin_visit = batch_data['origin_visit'].to(self.device)
        next_visit = batch_data['next_visit'].to(self.device)
        clinical_tensor = batch_data['clinical_data'].to(self.device)
        mask_code = batch_data['origin_mask_code'].unsqueeze(3).to(self.device)
        next_mask_code = batch_data['next_mask_code'].unsqueeze(3).to(self.device)
        mask_final = batch_data['origin_mask_final'].unsqueeze(2).to(self.device)
        mask_final_year = batch_data['last_visit_index'].to(self.device)

                
        origin_visit_emb = (self.pre_embedding(origin_visit) * mask_code).sum(dim=2)
        next_visit_emb = (self.pre_embedding(next_visit) * next_mask_code).sum(dim=2)
        # 위치 인코딩 및 Transformer 디코더 적용
        if self.pe == 'time_feature':
            time_feature = batch_data['time_feature'].to(self.device)
            src = self.pos_encoder(origin_visit_emb, time_feature)
        else:
            src = self.pos_encoder(origin_visit_emb)
        
        seq_len = src.shape[1]
        src_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=self.device), diagonal=1)
        output = self.transformer_decoder(src, attention_mask=src_mask)
        output = self.decoder(output)
        
        output_batch, output_visit_num, output_dim = output.size()
        next_visit_output_batch, next_visit_output_num, next_visit_output_dim = next_visit_emb.size()
        
        # print(output[:, :torch.,:])
        temp_output = output.reshape(-1, output_dim)
        temp_next_visit = next_visit_emb.reshape(-1, next_visit_output_dim)
        final_visit = (output * mask_final).sum(dim=1)
        year_emb = torch.bmm(output.transpose(1,2), mask_final_year).transpose(1,2)[:,:2,:]
        
        transformed_clinical = self.clinical_transform(clinical_tensor)
        mixed_output, mixed_cross_attn = self.cross_attn(year_emb, transformed_clinical, transformed_clinical)
        mixed_output = mixed_output[:,-1,:]

        mixed_final_emb = torch.cat((final_visit, mixed_output), dim=-1)
        classification_result = self.classification_layer(mixed_final_emb)    
        return temp_output, temp_next_visit, classification_result, final_visit

In [17]:
lr = 0.001
ninp = 64
nhid = 256
nlayer = 6
model_name = 'with_center_total_label'
pe = 'time_feature'
gamma = 0.5

In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_1 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.1, device=device, pe=pe).to(device)
# model_2 = CustomTransformerModel(len(dtype_dict), ninp=ninp, nhead=8, nhid=nhid, nlayers=nlayer, dropout=0.6, device=device, pe='fixed').to(device)

optimizer = torch.optim.Adam(model_1.parameters(), lr=lr, weight_decay=1e-5)

# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(2, gamma=gamma)
# cosine_loss = CosineSimilarityLoss()
cosine_embedding_loss = nn.CosineEmbeddingLoss()
# const_loss = ContrastiveLoss(temperature=0.05)
center_loss = CenterLoss(num_classes=2, feat_dim=ninp, device=device)

In [None]:
# 얼리 스타핑 설정
num_epochs = 100
patience = 40
best_loss = float("inf")
counter = 0
epoch_temp = 0

best_f1 = 0.0
best_combined_score = 0.0
best_epoch = 0

view_total_loss = []
view_cos_loss = []
view_classi_loss = []

cos_lambda = 1
classi_lambda = 1
center_lambda = 5

for epoch in tqdm(range(num_epochs)):
    model_1.train()

    total_train_loss = 0
    total_cos_loss = 0
    total_classi_loss = 0
    total_center_loss = 0
    
    for batch_data in train_loader:
        tr_labels = batch_data['label'].to(device)
        optimizer.zero_grad()
        output_1, next_visit_output_1, final_visit_classification_1, final_visit_1 = model_1(batch_data)
        y = torch.ones(output_1.size(0), dtype=torch.float, device=device)
        cosine_loss_mean_1 = cosine_embedding_loss(output_1, next_visit_output_1, y.to(device))
        classification_loss_1 = criterion(final_visit_classification_1.squeeze(), tr_labels.long())
        center_loss_1 = center_loss(final_visit_1, tr_labels.long())
        loss = (cos_lambda * cosine_loss_mean_1) + (classi_lambda * classification_loss_1) + (center_lambda * center_loss_1)

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        total_cos_loss += (cosine_loss_mean_1.item())
        total_classi_loss += (classification_loss_1.item())
        total_center_loss += (center_loss_1.item())

    # 평균 손실 계산
    avg_train_loss = total_train_loss / len(train_loader)
    avg_cos_loss = total_cos_loss / len(train_loader)
    avg_classi_loss = total_classi_loss / len(train_loader)
    avg_center_loss = total_center_loss / len(train_loader)

    print("Epoch:", epoch+1, ", Total loss:", round(avg_train_loss, 4), ", cos_loss : ",\
          round(avg_cos_loss, 4), ", classification_loss :",round(avg_classi_loss, 4), 
           ", center_loss :",round(avg_center_loss, 4))

    model_1.eval()
    total_val_loss = 0
    total_val_cos_loss = 0
    total_val_classi_loss = 0
    total_val_center_loss = 0

    val_labels_list = []
    val_predictions_list = []
    val_probabilities_list = []


    with torch.no_grad():
        for batch_data in valid_loader:
            val_labels = batch_data['label'].to(device)
            val_output_1, val_next_visit_output_1, val_final_visit_classification_1, val_final_visit_1 = model_1(batch_data)
            y_val = torch.ones(val_output_1.size(0), dtype=torch.float, device=device) 
            val_cosine_loss_mean_1 = cosine_embedding_loss(val_output_1, val_next_visit_output_1, y_val.to(device))            
            classification_loss_val_1 = criterion(val_final_visit_classification_1.squeeze(), val_labels.long())
            center_loss_val_1 = center_loss(val_final_visit_1, val_labels.long())
            val_loss = (cos_lambda * val_cosine_loss_mean_1) + (classi_lambda * classification_loss_val_1) + (center_lambda * center_loss_val_1)

            total_val_loss += val_loss.item()
            total_val_cos_loss += (val_cosine_loss_mean_1.item())
            total_val_classi_loss += (classification_loss_val_1.item())
            total_val_center_loss += (center_loss_val_1.item())

            val_probs = F.softmax(val_final_visit_classification_1)
            val_predictions = torch.max(val_probs, 1)[1].view((len(val_labels),))

            val_labels_list.extend(val_labels.view(-1).cpu().numpy())
            val_predictions_list.extend(val_predictions.cpu().numpy())
            
            ########
            val_probabilities_list.extend(val_probs[:,1].cpu().numpy())
            ######## 
            
        avg_val_loss = total_val_loss / len(valid_loader)
        avg_val_cos_loss = total_val_cos_loss / len(valid_loader)
        avg_val_classi_loss = total_val_classi_loss / len(valid_loader)
        avg_val_center_loss = total_val_center_loss / len(valid_loader)

    # 성능 지표 계산
    accuracy = accuracy_score(val_labels_list, val_predictions_list)
    auc = roc_auc_score(val_labels_list, val_probabilities_list)
    f1 = f1_score(val_labels_list, val_predictions_list)
    precision = precision_score(val_labels_list,val_predictions_list)
    recall = recall_score(val_labels_list, val_predictions_list)

    print("Epoch:", epoch+1, ", valid loss:", round(avg_val_loss, 4), ", cos_loss : ",\
          round(avg_val_cos_loss, 4), ", classification_loss :",round(avg_val_classi_loss, 4),\
          ", center_loss :",round(avg_val_center_loss, 4))
    print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

    current_auc = roc_auc_score(val_labels_list, val_probabilities_list)
    current_f1 = f1_score(val_labels_list, val_predictions_list)
    current_combined_score = current_f1

    # 모델 저장 경로 설정
    model_save_path = f'results'
    date_dir = datetime.today().strftime("%Y%m%d")
    model_time =  datetime.today().strftime("%H%M%S")
    # 학습률과 분류 가중치를 파일명에 포함시키기 위한 문자열 포맷
    model_filename_format = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{{gamma}}_{model_time}.pth'


    # 모델 저장 폴더가 없으면 생성
    os.makedirs(os.path.join(model_save_path, date_dir), exist_ok=True)

    if current_combined_score > best_combined_score:
        best_combined_score = current_combined_score
        best_epoch = epoch
        counter = 0
        # 모델 저장 경로와 파일명을 결합하여 전체 파일 경로 생성
        model_1_save_path = os.path.join(model_save_path, date_dir, model_filename_format.format(epoch=best_epoch, model=model_name, pe=pe, gamma=gamma))
        # 모델 저장
        torch.save(model_1.state_dict(), f'{model_1_save_path}')
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered at epoch {}".format(best_epoch))
        print(f"Best Combined Score (AUC): {best_combined_score:.4f}")
        break

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1 , Total loss: 217.4407 , cos_loss :  0.999 , classification_loss : 0.3176 , center_loss : 43.2248


  1%|          | 1/100 [00:03<06:13,  3.78s/it]

Epoch: 1 , valid loss: 87.3244 , cos_loss :  0.9908 , classification_loss : 0.2171 , center_loss : 17.2233
Accuracy: 0.9129 , AUC:  0.5515 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 2 , Total loss: 79.9017 , cos_loss :  0.9831 , classification_loss : 0.2164 , center_loss : 15.7405


  2%|▏         | 2/100 [00:07<05:53,  3.61s/it]

Epoch: 2 , valid loss: 73.2042 , cos_loss :  0.9743 , classification_loss : 0.2054 , center_loss : 14.4049
Accuracy: 0.9129 , AUC:  0.5525 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 3 , Total loss: 76.5398 , cos_loss :  0.9671 , classification_loss : 0.2098 , center_loss : 15.0726


  3%|▎         | 3/100 [00:10<05:41,  3.52s/it]

Epoch: 3 , valid loss: 72.4654 , cos_loss :  0.9594 , classification_loss : 0.1965 , center_loss : 14.2619
Accuracy: 0.9129 , AUC:  0.6165 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 4 , Total loss: 75.0756 , cos_loss :  0.9545 , classification_loss : 0.2092 , center_loss : 14.7824


  4%|▍         | 4/100 [00:14<05:37,  3.51s/it]

Epoch: 4 , valid loss: 71.322 , cos_loss :  0.9493 , classification_loss : 0.2004 , center_loss : 14.0344
Accuracy: 0.9129 , AUC:  0.5925 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 5 , Total loss: 74.3799 , cos_loss :  0.9459 , classification_loss : 0.2053 , center_loss : 14.6457


  5%|▌         | 5/100 [00:17<05:29,  3.47s/it]

Epoch: 5 , valid loss: 71.3309 , cos_loss :  0.9422 , classification_loss : 0.1999 , center_loss : 14.0377
Accuracy: 0.9129 , AUC:  0.5855 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 6 , Total loss: 73.3742 , cos_loss :  0.9402 , classification_loss : 0.2023 , center_loss : 14.4463


  6%|▌         | 6/100 [00:20<05:23,  3.44s/it]

Epoch: 6 , valid loss: 71.1723 , cos_loss :  0.9374 , classification_loss : 0.199 , center_loss : 14.0072
Accuracy: 0.9129 , AUC:  0.6007 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 7 , Total loss: 75.156 , cos_loss :  0.936 , classification_loss : 0.2041 , center_loss : 14.8032


  7%|▋         | 7/100 [00:24<05:20,  3.45s/it]

Epoch: 7 , valid loss: 71.2328 , cos_loss :  0.9339 , classification_loss : 0.1962 , center_loss : 14.0205
Accuracy: 0.9129 , AUC:  0.6144 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 8 , Total loss: 74.6683 , cos_loss :  0.9331 , classification_loss : 0.2051 , center_loss : 14.706


  8%|▊         | 8/100 [00:27<05:16,  3.44s/it]

Epoch: 8 , valid loss: 71.2054 , cos_loss :  0.9312 , classification_loss : 0.197 , center_loss : 14.0154
Accuracy: 0.9129 , AUC:  0.6199 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 9 , Total loss: 73.506 , cos_loss :  0.9303 , classification_loss : 0.2017 , center_loss : 14.4748


  9%|▉         | 9/100 [00:31<05:13,  3.45s/it]

Epoch: 9 , valid loss: 71.0503 , cos_loss :  0.9291 , classification_loss : 0.1973 , center_loss : 13.9848
Accuracy: 0.9129 , AUC:  0.607 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 10 , Total loss: 73.4885 , cos_loss :  0.9288 , classification_loss : 0.1987 , center_loss : 14.4722


 10%|█         | 10/100 [00:34<05:08,  3.43s/it]

Epoch: 10 , valid loss: 70.9978 , cos_loss :  0.9274 , classification_loss : 0.197 , center_loss : 13.9747
Accuracy: 0.9129 , AUC:  0.63 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 11 , Total loss: 73.6971 , cos_loss :  0.9274 , classification_loss : 0.1988 , center_loss : 14.5142


 11%|█         | 11/100 [00:38<05:06,  3.44s/it]

Epoch: 11 , valid loss: 70.8228 , cos_loss :  0.9261 , classification_loss : 0.1937 , center_loss : 13.9406
Accuracy: 0.9129 , AUC:  0.6463 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 12 , Total loss: 74.31 , cos_loss :  0.9264 , classification_loss : 0.1974 , center_loss : 14.6373


 12%|█▏        | 12/100 [00:41<05:00,  3.42s/it]

Epoch: 12 , valid loss: 70.8109 , cos_loss :  0.9251 , classification_loss : 0.1951 , center_loss : 13.9381
Accuracy: 0.9129 , AUC:  0.6444 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 13 , Total loss: 73.3334 , cos_loss :  0.9253 , classification_loss : 0.1945 , center_loss : 14.4427


 13%|█▎        | 13/100 [00:45<05:07,  3.53s/it]

Epoch: 13 , valid loss: 70.3343 , cos_loss :  0.9242 , classification_loss : 0.1933 , center_loss : 13.8434
Accuracy: 0.9129 , AUC:  0.6471 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 14 , Total loss: 72.4574 , cos_loss :  0.9245 , classification_loss : 0.1881 , center_loss : 14.269


 14%|█▍        | 14/100 [00:48<05:06,  3.56s/it]

Epoch: 14 , valid loss: 67.4636 , cos_loss :  0.924 , classification_loss : 0.1798 , center_loss : 13.272
Accuracy: 0.9129 , AUC:  0.7131 , F1:  0.0 , Precision:  0.0 , recall:  0.0
Epoch: 15 , Total loss: 66.3395 , cos_loss :  0.9263 , classification_loss : 0.1649 , center_loss : 13.0496


 15%|█▌        | 15/100 [00:52<05:07,  3.62s/it]

Epoch: 15 , valid loss: 57.261 , cos_loss :  0.9332 , classification_loss : 0.1483 , center_loss : 11.2359
Accuracy: 0.9415 , AUC:  0.7415 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 16 , Total loss: 54.0156 , cos_loss :  0.9471 , classification_loss : 0.1426 , center_loss : 10.5852


 16%|█▌        | 16/100 [00:56<05:01,  3.59s/it]

Epoch: 16 , valid loss: 52.3937 , cos_loss :  0.9558 , classification_loss : 0.1477 , center_loss : 10.258
Accuracy: 0.9415 , AUC:  0.7634 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 17 , Total loss: 50.4412 , cos_loss :  0.9517 , classification_loss : 0.1419 , center_loss : 9.8695


 17%|█▋        | 17/100 [00:59<04:52,  3.52s/it]

Epoch: 17 , valid loss: 49.7297 , cos_loss :  0.9522 , classification_loss : 0.1457 , center_loss : 9.7264
Accuracy: 0.9415 , AUC:  0.7595 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 18 , Total loss: 48.5164 , cos_loss :  0.9503 , classification_loss : 0.1382 , center_loss : 9.4856


 18%|█▊        | 18/100 [01:03<04:47,  3.50s/it]

Epoch: 18 , valid loss: 49.2026 , cos_loss :  0.9505 , classification_loss : 0.1466 , center_loss : 9.6211
Accuracy: 0.9415 , AUC:  0.755 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 19 , Total loss: 49.0 , cos_loss :  0.9471 , classification_loss : 0.1392 , center_loss : 9.5827


 19%|█▉        | 19/100 [01:06<04:40,  3.47s/it]

Epoch: 19 , valid loss: 48.9401 , cos_loss :  0.9482 , classification_loss : 0.1456 , center_loss : 9.5693
Accuracy: 0.9415 , AUC:  0.7545 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 20 , Total loss: 47.9315 , cos_loss :  0.9449 , classification_loss : 0.1378 , center_loss : 9.3698


 20%|██        | 20/100 [01:09<04:37,  3.47s/it]

Epoch: 20 , valid loss: 48.9974 , cos_loss :  0.9472 , classification_loss : 0.1477 , center_loss : 9.5805
Accuracy: 0.9415 , AUC:  0.7474 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 21 , Total loss: 48.172 , cos_loss :  0.9432 , classification_loss : 0.1364 , center_loss : 9.4185


 21%|██        | 21/100 [01:13<04:31,  3.44s/it]

Epoch: 21 , valid loss: 48.7349 , cos_loss :  0.9453 , classification_loss : 0.1453 , center_loss : 9.5289
Accuracy: 0.9415 , AUC:  0.7511 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 22 , Total loss: 48.5327 , cos_loss :  0.9409 , classification_loss : 0.1367 , center_loss : 9.491


 22%|██▏       | 22/100 [01:16<04:27,  3.43s/it]

Epoch: 22 , valid loss: 48.7324 , cos_loss :  0.9434 , classification_loss : 0.1461 , center_loss : 9.5286
Accuracy: 0.9415 , AUC:  0.751 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 23 , Total loss: 47.9493 , cos_loss :  0.9394 , classification_loss : 0.1356 , center_loss : 9.3748


 23%|██▎       | 23/100 [01:20<04:26,  3.46s/it]

Epoch: 23 , valid loss: 48.8139 , cos_loss :  0.9422 , classification_loss : 0.1475 , center_loss : 9.5449
Accuracy: 0.9415 , AUC:  0.7542 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 24 , Total loss: 48.8917 , cos_loss :  0.938 , classification_loss : 0.1364 , center_loss : 9.5635


 24%|██▍       | 24/100 [01:23<04:21,  3.43s/it]

Epoch: 24 , valid loss: 48.7632 , cos_loss :  0.9404 , classification_loss : 0.1473 , center_loss : 9.5351
Accuracy: 0.9415 , AUC:  0.7507 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 25 , Total loss: 49.001 , cos_loss :  0.9364 , classification_loss : 0.1357 , center_loss : 9.5858


 25%|██▌       | 25/100 [01:27<04:18,  3.45s/it]

Epoch: 25 , valid loss: 48.7694 , cos_loss :  0.9393 , classification_loss : 0.149 , center_loss : 9.5362
Accuracy: 0.9415 , AUC:  0.7457 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 26 , Total loss: 47.7831 , cos_loss :  0.9354 , classification_loss : 0.1293 , center_loss : 9.3437


 26%|██▌       | 26/100 [01:30<04:13,  3.42s/it]

Epoch: 26 , valid loss: 48.6795 , cos_loss :  0.9388 , classification_loss : 0.1483 , center_loss : 9.5185
Accuracy: 0.9415 , AUC:  0.7506 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 27 , Total loss: 48.1959 , cos_loss :  0.9346 , classification_loss : 0.1333 , center_loss : 9.4256


 27%|██▋       | 27/100 [01:33<04:10,  3.44s/it]

Epoch: 27 , valid loss: 48.6957 , cos_loss :  0.9379 , classification_loss : 0.1494 , center_loss : 9.5217
Accuracy: 0.9415 , AUC:  0.7456 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 28 , Total loss: 47.9724 , cos_loss :  0.9338 , classification_loss : 0.132 , center_loss : 9.3813


 28%|██▊       | 28/100 [01:37<04:06,  3.42s/it]

Epoch: 28 , valid loss: 48.8422 , cos_loss :  0.9375 , classification_loss : 0.1501 , center_loss : 9.5509
Accuracy: 0.9415 , AUC:  0.7495 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 29 , Total loss: 47.9373 , cos_loss :  0.9332 , classification_loss : 0.131 , center_loss : 9.3746


 29%|██▉       | 29/100 [01:40<04:01,  3.40s/it]

Epoch: 29 , valid loss: 48.6811 , cos_loss :  0.9369 , classification_loss : 0.1507 , center_loss : 9.5187
Accuracy: 0.9415 , AUC:  0.74 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 30 , Total loss: 47.3734 , cos_loss :  0.9329 , classification_loss : 0.1284 , center_loss : 9.2624


 30%|███       | 30/100 [01:44<03:59,  3.42s/it]

Epoch: 30 , valid loss: 48.688 , cos_loss :  0.9363 , classification_loss : 0.1497 , center_loss : 9.5204
Accuracy: 0.9415 , AUC:  0.7521 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286
Epoch: 31 , Total loss: 48.737 , cos_loss :  0.9323 , classification_loss : 0.1294 , center_loss : 9.5351


 31%|███       | 31/100 [01:47<03:55,  3.41s/it]

Epoch: 31 , valid loss: 49.3047 , cos_loss :  0.9355 , classification_loss : 0.1508 , center_loss : 9.6437
Accuracy: 0.9415 , AUC:  0.7517 , F1:  0.4946 , Precision:  1.0 , recall:  0.3286


In [None]:
total_te_loss = 0
total_te_cos_loss = 0
total_te_classi_loss = 0
total_te_center_loss = 0

te_labels_list = []
te_predictions_list = []
te_probabilities_list = []

model_1_path = model_1_save_path
model_1.load_state_dict(torch.load(model_1_path))
model_1.eval()

with torch.no_grad():
    for batch_idx, batch_data in enumerate(tqdm(test_loader)):
        te_labels = batch_data['label'].to(device)
        te_output_1, te_next_visit_output_1, te_final_visit_classification_1, te_final_visit_1 = model_1(batch_data)
        y_te = torch.ones(te_output_1.size(0), dtype=torch.float, device=device)
        te_cosine_loss_mean_1 = cosine_embedding_loss(te_output_1, te_next_visit_output_1, y_te)
        classification_loss_te_1 = criterion(te_final_visit_classification_1.squeeze(), te_labels.long())
        center_loss_te_1 = center_loss(te_final_visit_1, te_labels.long())
        te_loss = (cos_lambda * te_cosine_loss_mean_1) + (classi_lambda * classification_loss_te_1) + (center_lambda * center_loss_te_1)

        total_te_loss += te_loss.item()
        total_te_cos_loss += (te_cosine_loss_mean_1.item())
        total_te_classi_loss += (classification_loss_te_1.item())
        total_te_center_loss += (center_loss_te_1.item())

        te_probs = F.softmax(te_final_visit_classification_1)
        te_predictions = torch.max(te_probs, 1)[1].view((len(te_labels),))
        
        te_labels_list.extend(te_labels.view(-1).cpu().numpy())
        te_predictions_list.extend(te_predictions.cpu().numpy())
        te_probabilities_list.extend(te_probs[:,1].cpu().numpy())
        
        if batch_idx == 0:
            test_visit_embedding = te_final_visit_1.detach().cpu().numpy()
            test_label_numpy = te_labels.detach().cpu().numpy()
        else:
            add_visit_embedding = te_final_visit_1.detach().cpu().numpy()
            add_label_numpy = te_labels.detach().cpu().numpy()
            test_visit_embedding = np.concatenate((test_visit_embedding, add_visit_embedding), axis=0)
            test_label_numpy = np.concatenate((test_label_numpy, add_label_numpy))

    avg_te_loss = total_te_loss / len(test_loader)
    avg_te_cos_loss = total_te_cos_loss / len(test_loader)
    avg_te_classi_loss = total_te_classi_loss / len(test_loader)
    avg_te_center_loss = total_te_center_loss / len(test_loader)

In [None]:
# 성능 지표 계산
accuracy = accuracy_score(te_labels_list, te_predictions_list)
auc = roc_auc_score(te_labels_list, te_probabilities_list)
f1 = f1_score(te_labels_list, te_predictions_list)
precision = precision_score(te_labels_list,te_predictions_list)
recall = recall_score(te_labels_list, te_predictions_list)

print("test loss:", round(avg_te_loss, 4), ", cos_loss : ", round(avg_te_cos_loss, 4),
      ", classification_loss :",round(avg_te_classi_loss, 4), 
      ", center_loss : ", round(avg_te_center_loss, 4)
     )
print("Accuracy:", round(accuracy,4), ", AUC: ", round(auc,4), ", F1: ", round(f1,4), ", Precision: ", round(precision,4), ", recall: ", round(recall,4))

np.unique(te_labels_list,  return_counts=True)

In [None]:
log_path = './logs'
logs = f'model_lr{lr}_classi{classi_lambda}_dim{ninp}_hid{nhid}_layer{nlayer}_epoch{{epoch}}_{{model}}_{{pe}}_{model_time}.txt'
os.makedirs(os.path.join(log_path, date_dir), exist_ok=True)

results = [accuracy, auc, f1, precision, recall]
cm = confusion_matrix(te_labels_list, te_predictions_list)
with open(os.path.join(log_path, date_dir, logs.format(epoch=best_epoch, model=model_name, pe=pe)), 'w') as f:
    f.write(logs)
    f.write('\n')
    f.write(str(results))
    f.write('\n')
    f.write(str(cm))

In [None]:
confusion_matrix(te_labels_list, te_predictions_list)

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

tsne_model = TSNE(n_components=2)
reduction_emb = tsne_model.fit_transform(test_visit_embedding)

In [None]:
fig, ax = plt.subplots()
for g in np.unique(test_label_numpy):
    ix = np.where(test_label_numpy == g)
    ax.scatter(reduction_emb[ix, 0], reduction_emb[ix, 1], label = g, s = 10)
ax.legend()
plt.show()