In [3]:
import pandas as pd
import numpy as np
import re

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_auc_score, average_precision_score, classification_report


import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback

In [4]:
# Configuration: which preprocessed dataset variant to load, plus a global seed so a
# fresh "Restart & Run All" reproduces weight initialisation and DataLoader shuffling.
SEED = 42
L.seed_everything(SEED)

DATA_PATHS = {
    "jiseock": "../dataset/jiseock",
    "yeonseo": "../dataset/yeonseo",
}
DATA_MAKER = "yeonseo"
# An unknown DATA_MAKER used to fall through to the "yeonseo" data silently;
# fail loudly instead so a typo cannot load the wrong dataset.
if DATA_MAKER not in DATA_PATHS:
    raise ValueError(f"Unknown DATA_MAKER {DATA_MAKER!r}; expected one of {sorted(DATA_PATHS)}")
DATA_PATH = DATA_PATHS[DATA_MAKER]

X_train = pd.read_csv(f"{DATA_PATH}/X_train.csv")
y_train = pd.read_csv(f"{DATA_PATH}/y_train.csv")

In [5]:
# Display the training feature frame.
X_train

Unnamed: 0,나이,"성별 (M:1,F:2)","Rt:1,Lt:2",Height,Weight,"Tearsize (AP,cm)",Tearsize (ML),Tearsize (retraction),"흡연여부 (비흡연:1,흡연:2)","흡연여부 (비흡연:1,흡연:2) Missing flag",...,6M Goutallier (ISP),6M Goutallier (TM),Pre Goutallier (SSP) Missing flag,Pre Goutallier (SSC) Missing flag,Pre Goutallier (ISP) Missing flag,Pre Goutallier (TM) Missing flag,6M Goutallier (SSP) Missing flag,6M Goutallier (SSC) Missing flag,6M Goutallier (ISP) Missing flag,6M Goutallier (TM) Missing flag
0,83,2,2,-0.695229,-1.339841,-0.666718,-0.098987,-0.102130,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
1,55,2,1,-0.478311,-0.456936,-0.666718,-0.568482,-0.554538,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
2,70,2,2,0.268850,1.325693,-0.666718,-0.098987,-0.102130,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
3,75,2,1,-0.132850,0.678229,-0.666718,-0.685856,-0.667640,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
4,57,2,1,-0.454209,-0.498979,-0.666718,-0.685856,-0.667640,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8441,72,1,1,0.551094,0.699066,0.760231,1.485044,1.424252,1.0,1.0,...,2.588878,3.359851,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8442,59,2,1,0.354153,0.306110,1.107813,1.032095,0.987788,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8443,59,1,1,0.456311,0.204000,-0.184000,-0.218529,-0.217322,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8444,73,2,2,0.000185,0.144777,0.987286,1.074749,1.028890,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [6]:
# Display the training labels (the "POD 6M retear" column).
y_train

Unnamed: 0,POD 6M retear
0,0
1,0
2,0
3,0
4,0
...,...
8441,1
8442,1
8443,1
8444,1


In [7]:
# Every feature column name, in file order; the groupings below slice this list.
columns = X_train.columns.tolist()
columns

['나이',
 '성별 (M:1,F:2)',
 'Rt:1,Lt:2',
 'Height',
 'Weight',
 'Tearsize (AP,cm)',
 'Tearsize (ML)',
 'Tearsize (retraction)',
 '흡연여부 (비흡연:1,흡연:2)',
 '흡연여부 (비흡연:1,흡연:2) Missing flag',
 'Hospital 0',
 'Hospital 1',
 'Hospital 2',
 'Hospital 3',
 'Hospital 4',
 'Hospital 5',
 'Hospital 6',
 'Disease 0',
 'Disease 1',
 'Disease 2',
 'Disease 3',
 'Disease 4',
 'Disease 5',
 'Disease 6',
 'Disease 7',
 '0M ASES',
 '0M CSS',
 '0M ERabd',
 '0M ERside',
 '0M FF',
 '0M IR',
 '0M KSS',
 '0M MMTgrade',
 '0M MMTsec',
 '0M VAS(activity)',
 '0M VAS(resting)',
 '0M add',
 '2M ERabd',
 '2M ERside',
 '2M FF',
 '2M IR',
 '2M MMTgrade',
 '2M MMTsec',
 '2M add',
 '3M ASES',
 '3M CSS',
 '3M ERabd',
 '3M ERside',
 '3M FF',
 '3M IR',
 '3M KSS',
 '3M MMTgrade',
 '3M MMTsec',
 '3M VAS(activity)',
 '3M VAS(resting)',
 '3M add',
 '4M ASES',
 '4M CSS',
 '4M ERabd',
 '4M ERside',
 '4M FF',
 '4M IR',
 '4M KSS',
 '4M MMTgrade',
 '4M MMTsec',
 '4M VAS(activity)',
 '4M VAS(resting)',
 '4M add',
 '6M ASES',
 '6M CSS',
 

In [8]:
# The leading 25 columns are static (baseline) features: demographics, tear size,
# smoking status and the one-hot hospital / disease indicators.
N_STATIC_COLUMNS = 25
static_columns = columns[:N_STATIC_COLUMNS]
static_columns

['나이',
 '성별 (M:1,F:2)',
 'Rt:1,Lt:2',
 'Height',
 'Weight',
 'Tearsize (AP,cm)',
 'Tearsize (ML)',
 'Tearsize (retraction)',
 '흡연여부 (비흡연:1,흡연:2)',
 '흡연여부 (비흡연:1,흡연:2) Missing flag',
 'Hospital 0',
 'Hospital 1',
 'Hospital 2',
 'Hospital 3',
 'Hospital 4',
 'Hospital 5',
 'Hospital 6',
 'Disease 0',
 'Disease 1',
 'Disease 2',
 'Disease 3',
 'Disease 4',
 'Disease 5',
 'Disease 6',
 'Disease 7']

In [9]:
# Sequential (per-visit) columns sit between the 25 static columns and the
# 16 Goutallier columns at the end of the frame.
seq_columns = columns[25:-16]

# Group the sequential columns by visit using their "<visit> " name prefix rather
# than hard-coded slice offsets, so a change in one visit's feature count cannot
# silently shift columns into a neighbouring visit.
SEQ_VISITS = ["0M", "2M", "3M", "4M", "6M"]
seq_columns_by_visit = {
    visit: [col for col in seq_columns if col.startswith(f"{visit} ")] for visit in SEQ_VISITS
}
seq_columns_0M = seq_columns_by_visit["0M"]
seq_columns_2M = seq_columns_by_visit["2M"]
seq_columns_3M = seq_columns_by_visit["3M"]
seq_columns_4M = seq_columns_by_visit["4M"]
seq_columns_6M = seq_columns_by_visit["6M"]

seq_columns_all = [seq_columns_0M, seq_columns_2M, seq_columns_3M, seq_columns_4M, seq_columns_6M]
assert sum(map(len, seq_columns_all)) == len(seq_columns), \
    "some sequential columns do not match any visit prefix"

for seq_col in seq_columns_all:
    print(seq_col)

['0M ASES', '0M CSS', '0M ERabd', '0M ERside', '0M FF', '0M IR', '0M KSS', '0M MMTgrade', '0M MMTsec', '0M VAS(activity)', '0M VAS(resting)', '0M add']
['2M ERabd', '2M ERside', '2M FF', '2M IR', '2M MMTgrade', '2M MMTsec', '2M add']
['3M ASES', '3M CSS', '3M ERabd', '3M ERside', '3M FF', '3M IR', '3M KSS', '3M MMTgrade', '3M MMTsec', '3M VAS(activity)', '3M VAS(resting)', '3M add']
['4M ASES', '4M CSS', '4M ERabd', '4M ERside', '4M FF', '4M IR', '4M KSS', '4M MMTgrade', '4M MMTsec', '4M VAS(activity)', '4M VAS(resting)', '4M add']
['6M ASES', '6M CSS', '6M ERabd', '6M ERside', '6M FF', '6M IR', '6M KSS', '6M MMTgrade', '6M MMTsec', '6M VAS(activity)', '6M VAS(resting)', '6M add']


In [10]:
# The trailing 16 columns are Goutallier grades for four muscles (SSP, SSC, ISP, TM),
# in four consecutive groups of four: pre-op values, 6M values, pre-op missing
# flags, 6M missing flags.
goutallier_columns = columns[-16:]

(
    goutallier_columns_0M,
    goutallier_columns_6M,
    goutallier_columns_0M_missing,
    goutallier_columns_6M_missing,
) = (goutallier_columns[start:start + 4] for start in range(0, 16, 4))

for goutallier_group in (
    goutallier_columns_0M,
    goutallier_columns_6M,
    goutallier_columns_0M_missing,
    goutallier_columns_6M_missing,
):
    print(goutallier_group)

['Pre Goutallier (SSP)', 'Pre Goutallier (SSC)', 'Pre Goutallier (ISP)', 'Pre Goutallier (TM)']
['6M Goutallier (SSP)', '6M Goutallier (SSC)', '6M Goutallier (ISP)', '6M Goutallier (TM)']
['Pre Goutallier (SSP) Missing flag', 'Pre Goutallier (SSC) Missing flag', 'Pre Goutallier (ISP) Missing flag', 'Pre Goutallier (TM) Missing flag']
['6M Goutallier (SSP) Missing flag', '6M Goutallier (SSC) Missing flag', '6M Goutallier (ISP) Missing flag', '6M Goutallier (TM) Missing flag']


In [11]:
# Sanity checks that fail a "Run All" instead of merely displaying False:
# the three groups partition `columns` in order, and the Goutallier columns
# (and only they) carry "Goutallier" in their name.
assert static_columns + seq_columns + goutallier_columns == columns, "column groups do not partition `columns`"
assert all("Goutallier" in col for col in goutallier_columns), "non-Goutallier column in goutallier_columns"
assert not any("Goutallier" in col for col in static_columns + seq_columns), "Goutallier column outside goutallier_columns"

True

In [12]:
# Binary target and the 6M outcome scores predicted by the legacy all-in-one dataset.
label_column = "POD 6M retear"
output_columns = ["6M ASES", "6M CSS", "6M KSS", "6M VAS(activity)", "6M VAS(resting)"]

# Inputs = static + every sequential column not used as an output + all Goutallier columns.
# NOTE(review): this keeps the remaining 6M sequential columns and the 6M Goutallier
# grades as inputs — confirm they are available before the 6M retear is known.
input_columns = [
    *static_columns,
    *(col for col in seq_columns if col not in output_columns),
    *goutallier_columns,
]

In [13]:
# 6M outcome columns: excluded from input_columns and used as targets by get_dataset.
output_columns

['6M ASES', '6M CSS', '6M KSS', '6M VAS(activity)', '6M VAS(resting)']

In [14]:
# Input columns of the legacy all-in-one dataset.
input_columns

['나이',
 '성별 (M:1,F:2)',
 'Rt:1,Lt:2',
 'Height',
 'Weight',
 'Tearsize (AP,cm)',
 'Tearsize (ML)',
 'Tearsize (retraction)',
 '흡연여부 (비흡연:1,흡연:2)',
 '흡연여부 (비흡연:1,흡연:2) Missing flag',
 'Hospital 0',
 'Hospital 1',
 'Hospital 2',
 'Hospital 3',
 'Hospital 4',
 'Hospital 5',
 'Hospital 6',
 'Disease 0',
 'Disease 1',
 'Disease 2',
 'Disease 3',
 'Disease 4',
 'Disease 5',
 'Disease 6',
 'Disease 7',
 '0M ASES',
 '0M CSS',
 '0M ERabd',
 '0M ERside',
 '0M FF',
 '0M IR',
 '0M KSS',
 '0M MMTgrade',
 '0M MMTsec',
 '0M VAS(activity)',
 '0M VAS(resting)',
 '0M add',
 '2M ERabd',
 '2M ERside',
 '2M FF',
 '2M IR',
 '2M MMTgrade',
 '2M MMTsec',
 '2M add',
 '3M ASES',
 '3M CSS',
 '3M ERabd',
 '3M ERside',
 '3M FF',
 '3M IR',
 '3M KSS',
 '3M MMTgrade',
 '3M MMTsec',
 '3M VAS(activity)',
 '3M VAS(resting)',
 '3M add',
 '4M ASES',
 '4M CSS',
 '4M ERabd',
 '4M ERside',
 '4M FF',
 '4M IR',
 '4M KSS',
 '4M MMTgrade',
 '4M MMTsec',
 '4M VAS(activity)',
 '4M VAS(resting)',
 '4M add',
 '6M ERabd',
 '6M ERside

In [15]:
# Preview the training frame restricted to the input columns.
X_train[input_columns]

Unnamed: 0,나이,"성별 (M:1,F:2)","Rt:1,Lt:2",Height,Weight,"Tearsize (AP,cm)",Tearsize (ML),Tearsize (retraction),"흡연여부 (비흡연:1,흡연:2)","흡연여부 (비흡연:1,흡연:2) Missing flag",...,6M Goutallier (ISP),6M Goutallier (TM),Pre Goutallier (SSP) Missing flag,Pre Goutallier (SSC) Missing flag,Pre Goutallier (ISP) Missing flag,Pre Goutallier (TM) Missing flag,6M Goutallier (SSP) Missing flag,6M Goutallier (SSC) Missing flag,6M Goutallier (ISP) Missing flag,6M Goutallier (TM) Missing flag
0,83,2,2,-0.695229,-1.339841,-0.666718,-0.098987,-0.102130,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
1,55,2,1,-0.478311,-0.456936,-0.666718,-0.568482,-0.554538,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
2,70,2,2,0.268850,1.325693,-0.666718,-0.098987,-0.102130,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
3,75,2,1,-0.132850,0.678229,-0.666718,-0.685856,-0.667640,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
4,57,2,1,-0.454209,-0.498979,-0.666718,-0.685856,-0.667640,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8441,72,1,1,0.551094,0.699066,0.760231,1.485044,1.424252,1.0,1.0,...,2.588878,3.359851,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8442,59,2,1,0.354153,0.306110,1.107813,1.032095,0.987788,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8443,59,1,1,0.456311,0.204000,-0.184000,-0.218529,-0.217322,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8444,73,2,2,0.000185,0.144777,0.987286,1.074749,1.028890,1.0,1.0,...,-0.230515,-0.202217,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [16]:
# Preview the legacy targets: the label frame side by side with the 6M outcome columns.
pd.concat([y_train, X_train[output_columns]], axis=1)

Unnamed: 0,POD 6M retear,6M ASES,6M CSS,6M KSS,6M VAS(activity),6M VAS(resting)
0,0,-0.511552,-0.729534,-0.111492,0.184973,0.195548
1,0,0.358745,1.458271,-0.058845,0.184973,0.195548
2,0,-0.358204,-0.323929,-0.608714,0.184973,0.195548
3,0,-0.328331,-0.732607,-0.614563,0.184973,0.195548
4,0,-0.324348,-0.628133,-0.202746,0.184973,0.195548
...,...,...,...,...,...,...
8441,1,-1.860095,-2.156255,0.763181,1.654141,1.660600
8442,1,-0.555922,0.420936,-0.739752,2.078198,2.083469
8443,1,0.093389,0.519937,-0.697858,0.485109,0.494844
8444,1,-0.978093,-0.573777,-0.302896,1.882890,1.888709


In [19]:
def get_dataset(split):
    """Build the legacy all-in-one TensorDataset for `split` ("train" or "test").

    Tensors, in order: all `input_columns`; static; the 0M, 2M, 3M, 4M and 6M
    sequential groups; 0M Goutallier (+ missing flags); 6M Goutallier
    (+ missing flags); and the target, which is every column of
    `y_{split}.csv` followed by `output_columns`. All tensors are float32.
    """
    assert split in ["train", "test"]

    features = pd.read_csv(f"{DATA_PATH}/X_{split}.csv")
    labels = pd.read_csv(f"{DATA_PATH}/y_{split}.csv")

    def as_tensor(frame):
        return torch.tensor(frame.to_numpy(), dtype=torch.float32)

    static_t = as_tensor(features[static_columns])
    visit_ts = [
        as_tensor(features[cols])
        for cols in (seq_columns_0M, seq_columns_2M, seq_columns_3M, seq_columns_4M, seq_columns_6M)
    ]
    gout_0M_t = as_tensor(features[goutallier_columns_0M + goutallier_columns_0M_missing])
    gout_6M_t = as_tensor(features[goutallier_columns_6M + goutallier_columns_6M_missing])
    all_inputs_t = as_tensor(features[input_columns])
    target_t = as_tensor(pd.concat([labels, features[output_columns]], axis=1))

    return TensorDataset(all_inputs_t, static_t, *visit_ts, gout_0M_t, gout_6M_t, target_t)

In [None]:
# Legacy all-in-one datasets (kept for reference only).
trainset, testset = (get_dataset(split_name) for split_name in ("train", "test"))

print("기존 데이터셋 구조 확인")
for _name, _ds in (("Trainset", trainset), ("Testset", testset)):
    print(f"{_name} size: {len(_ds)}")

In [None]:
# Width of each input group; these size the per-group encoders of the stage models.
static_features = len(static_columns)
seq_features_0M, seq_features_2M, seq_features_3M, seq_features_4M, seq_features_6M = map(
    len, (seq_columns_0M, seq_columns_2M, seq_columns_3M, seq_columns_4M, seq_columns_6M)
)
# Goutallier inputs are the grade columns concatenated with their missing flags.
goutallier_features_0M = len(goutallier_columns_0M + goutallier_columns_0M_missing)
goutallier_features_6M = len(goutallier_columns_6M + goutallier_columns_6M_missing)

print(f"Static features: {static_features}")
print(
    f"0M features: {seq_features_0M}, 2M features: {seq_features_2M}, "
    f"3M features: {seq_features_3M}"
)
print(f"4M features: {seq_features_4M}, 6M features: {seq_features_6M}")
print(
    f"0M Goutallier features: {goutallier_features_0M}, "
    f"6M Goutallier features: {goutallier_features_6M}"
)

In [None]:
# Per-stage dataset builders. Each stage's model receives every earlier visit as
# input and predicts the next visit:
#   Model 1: static + 0M + 0M_goutallier                 → 2M
#   Model 2: static + 0M + 2M + 0M_goutallier            → 3M
#   Model 3: static + 0M + 2M + 3M + 0M_goutallier       → 4M
#   Model 4: static + 0M + 2M + 3M + 4M + 0M_goutallier  → y + 6M + 6M_goutallier


def _read_split(split, kind):
    """Read `{DATA_PATH}/{kind}_{split}.csv` ("X" or "y"); `split` must be "train" or "test"."""
    assert split in ["train", "test"]
    return pd.read_csv(f"{DATA_PATH}/{kind}_{split}.csv")


def _as_float_tensor(frame, cols):
    """Select `cols` from `frame` and return them as a float32 tensor (rows x len(cols))."""
    return torch.tensor(frame[cols].to_numpy(), dtype=torch.float32)


def get_dataset_model1(split):
    """Model 1: static + 0M + 0M_goutallier → 2M"""
    X = _read_split(split, "X")
    return TensorDataset(
        _as_float_tensor(X, static_columns),
        _as_float_tensor(X, seq_columns_0M),
        _as_float_tensor(X, goutallier_columns_0M + goutallier_columns_0M_missing),
        _as_float_tensor(X, seq_columns_2M),  # target
    )


def get_dataset_model2(split):
    """Model 2: static + 0M + 2M + 0M_goutallier → 3M"""
    X = _read_split(split, "X")
    return TensorDataset(
        _as_float_tensor(X, static_columns),
        _as_float_tensor(X, seq_columns_0M),
        _as_float_tensor(X, seq_columns_2M),
        _as_float_tensor(X, goutallier_columns_0M + goutallier_columns_0M_missing),
        _as_float_tensor(X, seq_columns_3M),  # target
    )


def get_dataset_model3(split):
    """Model 3: static + 0M + 2M + 3M + 0M_goutallier → 4M"""
    X = _read_split(split, "X")
    return TensorDataset(
        _as_float_tensor(X, static_columns),
        _as_float_tensor(X, seq_columns_0M),
        _as_float_tensor(X, seq_columns_2M),
        _as_float_tensor(X, seq_columns_3M),
        _as_float_tensor(X, goutallier_columns_0M + goutallier_columns_0M_missing),
        _as_float_tensor(X, seq_columns_4M),  # target
    )


def get_dataset_model4(split):
    """Model 4: static + 0M + 2M + 3M + 4M + 0M_goutallier → y + 6M + 6M_goutallier

    The target tensor is laid out as [y (1) | 6M features | 6M Goutallier + missing flags].
    SequentialMLP4 uses column 0 as the binary retear label and columns 1: as
    regression targets (6M first), so the label must come first.
    """
    X = _read_split(split, "X")
    y = _read_split(split, "y")

    inputs = (
        _as_float_tensor(X, static_columns),
        _as_float_tensor(X, seq_columns_0M),
        _as_float_tensor(X, seq_columns_2M),
        _as_float_tensor(X, seq_columns_3M),
        _as_float_tensor(X, seq_columns_4M),
        _as_float_tensor(X, goutallier_columns_0M + goutallier_columns_0M_missing),
    )

    label_values = y[label_column].to_numpy()
    # BCE-with-logits and the AUROC / AP metrics require a 0/1 label.
    assert np.isin(label_values, (0, 1)).all(), f"{label_column} must be binary (0/1)"
    y_label = torch.tensor(label_values, dtype=torch.float32).unsqueeze(1)
    y_6M = _as_float_tensor(X, seq_columns_6M)
    y_6M_goutallier = _as_float_tensor(X, goutallier_columns_6M + goutallier_columns_6M_missing)

    # Label FIRST. The former [6M, y, 6M_goutallier] order made SequentialMLP4 use
    # "6M ASES" as its classification label and regress on the real retear label.
    y_combined = torch.cat([y_label, y_6M, y_6M_goutallier], dim=1)

    return TensorDataset(*inputs, y_combined)

# Build train/test datasets for every stage.
trainset_model1 = get_dataset_model1("train")
testset_model1 = get_dataset_model1("test")
trainset_model2 = get_dataset_model2("train")
testset_model2 = get_dataset_model2("test")
trainset_model3 = get_dataset_model3("train")
testset_model3 = get_dataset_model3("test")
trainset_model4 = get_dataset_model4("train")
testset_model4 = get_dataset_model4("test")

print(f"Model 1 - Train: {len(trainset_model1)}, Test: {len(testset_model1)}")
print(f"Model 2 - Train: {len(trainset_model2)}, Test: {len(testset_model2)}")
print(f"Model 3 - Train: {len(trainset_model3)}, Test: {len(testset_model3)}")
print(f"Model 4 - Train: {len(trainset_model4)}, Test: {len(testset_model4)}")


In [None]:
# Model 1: static + 0M + 0M_goutallier → 2M
class SequentialMLP1(L.LightningModule):
    """Regress the 2M visit features from static, 0M and pre-op Goutallier inputs.

    Each input group has its own Linear → LayerNorm → LeakyReLU → Dropout encoder.
    The encodings are concatenated and passed through an MLP output head.
    Trained with MSE loss.

    Args:
        static_features: number of static input columns.
        seq_0M_features: number of 0M input columns.
        goutallier_0M_features: number of pre-op Goutallier columns (grades + missing flags).
        out_features_2M: number of 2M target columns.
        lr: AdamW learning rate (default keeps the original 5e-6).
        weight_decay: AdamW weight decay.
    """

    def __init__(self, static_features, seq_0M_features, goutallier_0M_features, out_features_2M,
                 lr=5e-6, weight_decay=1e-4):
        super().__init__()
        # Store constructor args in checkpoints so ModelCheckpoint files can be
        # restored with `SequentialMLP1.load_from_checkpoint(path)`.
        self.save_hyperparameters()

        # Per-group input encoders.
        self.static_encoder = self._encoder(static_features, 64)
        self.seq_0M_encoder = self._encoder(seq_0M_features, 64)
        self.goutallier_0M_encoder = self._encoder(goutallier_0M_features, 32)

        # Head over the concatenated encodings.
        feat_dim = 64 + 64 + 32
        self.output_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, out_features_2M),
        )

        self.train_mse = MeanSquaredError()
        self.val_mse = MeanSquaredError()

    @staticmethod
    def _encoder(in_features, out_features, dropout=0.2):
        """Linear → LayerNorm → LeakyReLU → Dropout block used for each input group."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x_static, x_0M, x_0M_goutallier):
        combined = torch.cat(
            [
                self.static_encoder(x_static),
                self.seq_0M_encoder(x_0M),
                self.goutallier_0M_encoder(x_0M_goutallier),
            ],
            dim=1,
        )
        return self.output_head(combined)

    def _shared_step(self, batch, stage, mse_metric):
        """Forward one batch, log `<stage>/loss` and accumulate the epoch MSE."""
        *inputs, target = batch  # (x_static, x_0M, x_0M_goutallier), y_2M
        pred = self(*inputs)
        loss = F.mse_loss(pred, target)
        self.log(f"{stage}/loss", loss, on_epoch=True, prog_bar=True)
        mse_metric.update(pred, target)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_mse)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val", self.val_mse)

    def on_train_epoch_end(self):
        self.log("train/mse", self.train_mse.compute())
        self.train_mse.reset()

    def on_validation_epoch_end(self):
        self.log("val/mse", self.val_mse.compute())
        self.val_mse.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr,
                                 weight_decay=self.hparams.weight_decay)


In [None]:
# Model 2: static + 0M + 2M + 0M_goutallier → 3M
class SequentialMLP2(L.LightningModule):
    """Regress the 3M visit features from static, 0M, 2M and pre-op Goutallier inputs.

    Each input group has its own Linear → LayerNorm → LeakyReLU → Dropout encoder.
    The encodings are concatenated and passed through an MLP output head.
    Trained with MSE loss.

    Args:
        static_features: number of static input columns.
        seq_0M_features: number of 0M input columns.
        seq_2M_features: number of 2M input columns.
        goutallier_0M_features: number of pre-op Goutallier columns (grades + missing flags).
        out_features_3M: number of 3M target columns.
        lr: AdamW learning rate (default keeps the original 5e-6).
        weight_decay: AdamW weight decay.
    """

    def __init__(self, static_features, seq_0M_features, seq_2M_features, goutallier_0M_features,
                 out_features_3M, lr=5e-6, weight_decay=1e-4):
        super().__init__()
        # Store constructor args in checkpoints so ModelCheckpoint files can be
        # restored with `SequentialMLP2.load_from_checkpoint(path)`.
        self.save_hyperparameters()

        # Per-group input encoders.
        self.static_encoder = self._encoder(static_features, 64)
        self.seq_0M_encoder = self._encoder(seq_0M_features, 64)
        self.seq_2M_encoder = self._encoder(seq_2M_features, 32)
        self.goutallier_0M_encoder = self._encoder(goutallier_0M_features, 32)

        # Head over the concatenated encodings.
        feat_dim = 64 + 64 + 32 + 32
        self.output_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, out_features_3M),
        )

        self.train_mse = MeanSquaredError()
        self.val_mse = MeanSquaredError()

    @staticmethod
    def _encoder(in_features, out_features, dropout=0.2):
        """Linear → LayerNorm → LeakyReLU → Dropout block used for each input group."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x_static, x_0M, x_2M, x_0M_goutallier):
        combined = torch.cat(
            [
                self.static_encoder(x_static),
                self.seq_0M_encoder(x_0M),
                self.seq_2M_encoder(x_2M),
                self.goutallier_0M_encoder(x_0M_goutallier),
            ],
            dim=1,
        )
        return self.output_head(combined)

    def _shared_step(self, batch, stage, mse_metric):
        """Forward one batch, log `<stage>/loss` and accumulate the epoch MSE."""
        *inputs, target = batch  # (x_static, x_0M, x_2M, x_0M_goutallier), y_3M
        pred = self(*inputs)
        loss = F.mse_loss(pred, target)
        self.log(f"{stage}/loss", loss, on_epoch=True, prog_bar=True)
        mse_metric.update(pred, target)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_mse)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val", self.val_mse)

    def on_train_epoch_end(self):
        self.log("train/mse", self.train_mse.compute())
        self.train_mse.reset()

    def on_validation_epoch_end(self):
        self.log("val/mse", self.val_mse.compute())
        self.val_mse.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr,
                                 weight_decay=self.hparams.weight_decay)


In [None]:
# Model 3: static + 0M + 2M + 3M + 0M_goutallier → 4M
class SequentialMLP3(L.LightningModule):
    """Regress the 4M visit features from static, 0M, 2M, 3M and pre-op Goutallier inputs.

    Each input group has its own Linear → LayerNorm → LeakyReLU → Dropout encoder.
    The encodings are concatenated and passed through an MLP output head.
    Trained with MSE loss.

    Args:
        static_features: number of static input columns.
        seq_0M_features / seq_2M_features / seq_3M_features: per-visit input widths.
        goutallier_0M_features: number of pre-op Goutallier columns (grades + missing flags).
        out_features_4M: number of 4M target columns.
        lr: AdamW learning rate (default keeps the original 5e-6).
        weight_decay: AdamW weight decay.
    """

    def __init__(self, static_features, seq_0M_features, seq_2M_features, seq_3M_features,
                 goutallier_0M_features, out_features_4M, lr=5e-6, weight_decay=1e-4):
        super().__init__()
        # Store constructor args in checkpoints so ModelCheckpoint files can be
        # restored with `SequentialMLP3.load_from_checkpoint(path)`.
        self.save_hyperparameters()

        # Per-group input encoders.
        self.static_encoder = self._encoder(static_features, 64)
        self.seq_0M_encoder = self._encoder(seq_0M_features, 64)
        self.seq_2M_encoder = self._encoder(seq_2M_features, 32)
        self.seq_3M_encoder = self._encoder(seq_3M_features, 64)
        self.goutallier_0M_encoder = self._encoder(goutallier_0M_features, 32)

        # Head over the concatenated encodings.
        feat_dim = 64 + 64 + 32 + 64 + 32
        self.output_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, out_features_4M),
        )

        self.train_mse = MeanSquaredError()
        self.val_mse = MeanSquaredError()

    @staticmethod
    def _encoder(in_features, out_features, dropout=0.2):
        """Linear → LayerNorm → LeakyReLU → Dropout block used for each input group."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x_static, x_0M, x_2M, x_3M, x_0M_goutallier):
        combined = torch.cat(
            [
                self.static_encoder(x_static),
                self.seq_0M_encoder(x_0M),
                self.seq_2M_encoder(x_2M),
                self.seq_3M_encoder(x_3M),
                self.goutallier_0M_encoder(x_0M_goutallier),
            ],
            dim=1,
        )
        return self.output_head(combined)

    def _shared_step(self, batch, stage, mse_metric):
        """Forward one batch, log `<stage>/loss` and accumulate the epoch MSE."""
        *inputs, target = batch  # (x_static, x_0M, x_2M, x_3M, x_0M_goutallier), y_4M
        pred = self(*inputs)
        loss = F.mse_loss(pred, target)
        self.log(f"{stage}/loss", loss, on_epoch=True, prog_bar=True)
        mse_metric.update(pred, target)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_mse)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val", self.val_mse)

    def on_train_epoch_end(self):
        self.log("train/mse", self.train_mse.compute())
        self.train_mse.reset()

    def on_validation_epoch_end(self):
        self.log("val/mse", self.val_mse.compute())
        self.val_mse.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr,
                                 weight_decay=self.hparams.weight_decay)


In [None]:
# Model 4: static + 0M + 2M + 3M + 4M + 0M_goutallier → 6M + y + 6M_goutallier
class SequentialMLP4(L.LightningModule):
    """Jointly classify 6M retear and regress the remaining 6M targets.

    Every input group has its own encoder. The concatenated encoding feeds a
    classification head (one retear logit) and a regression head producing
    `out_features_total - 1` values.

    Each batch's target tensor must be laid out as [binary label | regression targets]:
    column 0 is the BCE-with-logits target, and columns 1: are the smooth-L1 targets.

    Args:
        static_features, seq_0M_features, seq_2M_features, seq_3M_features,
        seq_4M_features, goutallier_0M_features: input widths of each group.
        out_features_total: 1 (label) + number of regression targets.
        lr: AdamW learning rate (default keeps the original 5e-6).
        weight_decay: AdamW weight decay.
        pos_weight: BCE positive-class weight (default 1.0 = unweighted, as before).
    """

    def __init__(self, static_features, seq_0M_features, seq_2M_features, seq_3M_features,
                 seq_4M_features, goutallier_0M_features, out_features_total,
                 lr=5e-6, weight_decay=1e-4, pos_weight=1.0):
        super().__init__()
        # Store constructor args in checkpoints so ModelCheckpoint files can be
        # restored with `SequentialMLP4.load_from_checkpoint(path)`.
        self.save_hyperparameters()
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer('pos_weight', torch.tensor([float(pos_weight)]))

        # Per-group input encoders.
        self.static_encoder = self._encoder(static_features, 64)
        self.seq_0M_encoder = self._encoder(seq_0M_features, 64)
        self.seq_2M_encoder = self._encoder(seq_2M_features, 32)
        self.seq_3M_encoder = self._encoder(seq_3M_features, 64)
        self.seq_4M_encoder = self._encoder(seq_4M_features, 64)
        self.goutallier_0M_encoder = self._encoder(goutallier_0M_features, 32)

        feat_dim = 64 + 64 + 32 + 64 + 64 + 32

        # Classification head: a single retear logit.
        self.clshead = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )

        # Regression head: every target column except the label.
        self.reghead = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.LayerNorm(256),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, out_features_total - 1),
        )

        self.train_roc = BinaryAUROC()
        self.val_roc = BinaryAUROC()
        self.val_ap = BinaryAveragePrecision()
        self.train_mse = MeanSquaredError()
        self.val_mse = MeanSquaredError()

    @staticmethod
    def _encoder(in_features, out_features, dropout=0.2):
        """Linear → LayerNorm → LeakyReLU → Dropout block used for each input group."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x_static, x_0M, x_2M, x_3M, x_4M, x_0M_goutallier):
        """Return (logits, regs, output), where output = [logits (1) | regs]."""
        combined = torch.cat(
            [
                self.static_encoder(x_static),
                self.seq_0M_encoder(x_0M),
                self.seq_2M_encoder(x_2M),
                self.seq_3M_encoder(x_3M),
                self.seq_4M_encoder(x_4M),
                self.goutallier_0M_encoder(x_0M_goutallier),
            ],
            dim=1,
        )
        logits = self.clshead(combined)
        regs = self.reghead(combined)
        output = torch.cat([logits, regs], dim=1)
        return logits, regs, output

    def _shared_step(self, batch, stage):
        """Compute the joint loss for one batch and log its parts under `<stage>/`.

        Returns (loss, probs, int targets, regs, regression targets) for metric updates.
        """
        *inputs, y_combined = batch
        y_label = y_combined[:, :1]  # column 0: binary retear label
        y_reg = y_combined[:, 1:]    # columns 1:: regression targets

        logits, regs, _ = self(*inputs)

        clf_loss = F.binary_cross_entropy_with_logits(logits, y_label, pos_weight=self.pos_weight)
        reg_loss = F.smooth_l1_loss(regs, y_reg)
        loss = clf_loss + reg_loss

        self.log(f"{stage}/loss", loss, on_epoch=True, prog_bar=True)
        self.log(f"{stage}/clf_loss", clf_loss)
        self.log(f"{stage}/reg_loss", reg_loss)

        probs = logits.sigmoid().flatten()
        targets = y_label.flatten().to(torch.int)
        return loss, probs, targets, regs, y_reg

    def training_step(self, batch, batch_idx):
        loss, probs, targets, regs, y_reg = self._shared_step(batch, "train")
        self.train_roc.update(probs, targets)
        self.train_mse.update(regs, y_reg)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, probs, targets, regs, y_reg = self._shared_step(batch, "val")
        self.val_roc.update(probs, targets)
        self.val_ap.update(probs, targets)
        self.val_mse.update(regs, y_reg)
        return loss

    def on_train_epoch_end(self):
        self.log("train/roc", self.train_roc.compute())
        self.log("train/mse", self.train_mse.compute())
        self.train_roc.reset()
        self.train_mse.reset()

    def on_validation_epoch_end(self):
        self.log("val/roc", self.val_roc.compute())
        self.log("val/ap", self.val_ap.compute())
        self.log("val/mse", self.val_mse.compute())
        self.val_roc.reset()
        self.val_ap.reset()
        self.val_mse.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr,
                                 weight_decay=self.hparams.weight_decay)


In [None]:
# Train each stage model separately. These two settings are also read by the
# Model 2-4 training cells below.
batch_size = 64
max_epochs = 24

# ---- Model 1: static + 0M + 0M_goutallier → 2M ----
print("\n".join(["=" * 80, "Model 1 학습 시작: static + 0M + 0M_goutallier → 2M", "=" * 80]))

model1 = SequentialMLP1(
    static_features=static_features,
    seq_0M_features=seq_features_0M,
    goutallier_0M_features=goutallier_features_0M,
    out_features_2M=seq_features_2M,
)

# NOTE(review): the "validation" loader wraps the *test* split, so val/* metrics
# and the best-checkpoint selection are computed on test data.
trainloader1 = DataLoader(trainset_model1, batch_size=batch_size, shuffle=True, pin_memory=True)
valloader1 = DataLoader(testset_model1, batch_size=batch_size)

best_ckpt1 = ModelCheckpoint(monitor='val/loss', mode='min', save_top_k=1, filename='model1-best')
trainer1 = L.Trainer(max_epochs=max_epochs, callbacks=[best_ckpt1])
trainer1.fit(model1, train_dataloaders=trainloader1, val_dataloaders=valloader1)
print("Model 1 학습 완료\n")


In [None]:
# ---- Model 2: static + 0M + 2M + 0M_goutallier → 3M ----
print("\n".join(["=" * 80, "Model 2 학습 시작: static + 0M + 2M + 0M_goutallier → 3M", "=" * 80]))

model2 = SequentialMLP2(
    static_features=static_features,
    seq_0M_features=seq_features_0M,
    seq_2M_features=seq_features_2M,
    goutallier_0M_features=goutallier_features_0M,
    out_features_3M=seq_features_3M,
)

# NOTE(review): validation runs on the test split (see the Model 1 cell).
trainloader2 = DataLoader(trainset_model2, batch_size=batch_size, shuffle=True, pin_memory=True)
valloader2 = DataLoader(testset_model2, batch_size=batch_size)

best_ckpt2 = ModelCheckpoint(monitor='val/loss', mode='min', save_top_k=1, filename='model2-best')
trainer2 = L.Trainer(max_epochs=max_epochs, callbacks=[best_ckpt2])
trainer2.fit(model2, train_dataloaders=trainloader2, val_dataloaders=valloader2)
print("Model 2 학습 완료\n")


In [None]:
# ---- Model 3: static + 0M + 2M + 3M + 0M_goutallier → 4M ----
print("\n".join(["=" * 80, "Model 3 학습 시작: static + 0M + 2M + 3M + 0M_goutallier → 4M", "=" * 80]))

model3 = SequentialMLP3(
    static_features=static_features,
    seq_0M_features=seq_features_0M,
    seq_2M_features=seq_features_2M,
    seq_3M_features=seq_features_3M,
    goutallier_0M_features=goutallier_features_0M,
    out_features_4M=seq_features_4M,
)

# NOTE(review): validation runs on the test split (see the Model 1 cell).
trainloader3 = DataLoader(trainset_model3, batch_size=batch_size, shuffle=True, pin_memory=True)
valloader3 = DataLoader(testset_model3, batch_size=batch_size)

best_ckpt3 = ModelCheckpoint(monitor='val/loss', mode='min', save_top_k=1, filename='model3-best')
trainer3 = L.Trainer(max_epochs=max_epochs, callbacks=[best_ckpt3])
trainer3.fit(model3, train_dataloaders=trainloader3, val_dataloaders=valloader3)
print("Model 3 학습 완료\n")


In [None]:
# ---- Model 4: static + 0M..4M + 0M_goutallier → 6M + y + 6M_goutallier ----
print("\n".join([
    "=" * 80,
    "Model 4 학습 시작: static + 0M + 2M + 3M + 4M + 0M_goutallier → 6M + y + 6M_goutallier",
    "=" * 80,
]))

# Output width: 1 retear logit + the 6M sequential columns + the 6M Goutallier columns.
model4 = SequentialMLP4(
    static_features=static_features,
    seq_0M_features=seq_features_0M,
    seq_2M_features=seq_features_2M,
    seq_3M_features=seq_features_3M,
    seq_4M_features=seq_features_4M,
    goutallier_0M_features=goutallier_features_0M,
    out_features_total=seq_features_6M + 1 + goutallier_features_6M,
)

# NOTE(review): validation runs on the test split, and the best checkpoint is
# picked by test-set AUROC.
trainloader4 = DataLoader(trainset_model4, batch_size=batch_size, shuffle=True, pin_memory=True)
valloader4 = DataLoader(testset_model4, batch_size=batch_size)

best_ckpt4 = ModelCheckpoint(monitor='val/roc', mode='max', save_top_k=1, filename='model4-best')
trainer4 = L.Trainer(max_epochs=max_epochs, callbacks=[best_ckpt4])
trainer4.fit(model4, train_dataloaders=trainloader4, val_dataloaders=valloader4)
print("Model 4 학습 완료\n")


In [None]:
# Chained inference model: runs the four stage models back to back.
class SequentialModel(nn.Module):
    """Predict 2M → 3M → 4M → 6M (+ retear logit) from 0M-time inputs only.

    Each stage's prediction replaces the observed visit data for all later
    stages. Only static, 0M and pre-op Goutallier inputs are required.
    Inference-only: forward() runs under `torch.no_grad()`, and every stage is
    put in eval mode (dropout disabled) at construction.

    Args:
        model1..model4: trained stage models (SequentialMLP1..4 or compatible modules).
            model4 must return (logits, regs, output).
        n_6M_features: number of leading 6M sequential columns in model4's
            regression output; the remaining columns are the 6M Goutallier
            predictions. Defaults to the notebook-level `seq_features_6M`
            for backward compatibility.
    """

    def __init__(self, model1, model2, model3, model4, n_6M_features=None):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.model4 = model4
        # Captured once so forward() no longer depends on a notebook global.
        self.n_6M_features = seq_features_6M if n_6M_features is None else int(n_6M_features)

        # Evaluation mode for the wrapper and every stage.
        self.eval()

    @torch.no_grad()
    def forward(self, x_static, x_0M, x_0M_goutallier):
        """
        Inputs: static, 0M, 0M_goutallier.
        Returns a dict with the 2M/3M/4M/6M predictions, the retear logits,
        the 6M Goutallier predictions and model4's raw concatenated output.
        """
        pred_2M = self.model1(x_static, x_0M, x_0M_goutallier)
        pred_3M = self.model2(x_static, x_0M, pred_2M, x_0M_goutallier)
        pred_4M = self.model3(x_static, x_0M, pred_2M, pred_3M, x_0M_goutallier)
        logits, regs, output = self.model4(x_static, x_0M, pred_2M, pred_3M, pred_4M, x_0M_goutallier)

        # model4's regression output is expected to be [6M features | 6M Goutallier];
        # the retear logit comes from its separate classification head.
        n_6M = self.n_6M_features
        return {
            'pred_2M': pred_2M,
            'pred_3M': pred_3M,
            'pred_4M': pred_4M,
            'pred_6M': regs[:, :n_6M],
            'pred_y_logits': logits,
            'pred_6M_goutallier': regs[:, n_6M:],
            'output': output,
        }

# 연결된 모델 생성
sequential_model = SequentialModel(model1, model2, model3, model4)
print("연결된 시계열 모델 생성 완료")


In [None]:
# Evaluation of the chained model (0M inputs only -> every later visit)
@torch.no_grad()
def evaluate_sequential_model(sequential_model, testset_model1,
                              testset_model2=None, testset_model3=None, testset_model4=None):
    """Evaluate the chained 0M -> 6M model using only 0M inputs.

    For every sample of ``testset_model1`` the chained model predicts 2M, 3M, 4M,
    6M features, the outcome logit and the 6M Goutallier grades. Ground-truth
    targets are read from the per-stage datasets at the SAME index, so all
    datasets must be row-aligned (same patients, same order). Only their lengths
    can be checked here.

    Args:
        sequential_model: module whose forward returns the dict produced by
            ``SequentialModel.forward``.
        testset_model1: dataset yielding ``(x_static, x_0M, x_0M_goutallier, y_2M)``.
        testset_model2: dataset whose 5th item is the 3M target. Defaults to the
            notebook-level ``testset_model2``.
        testset_model3: dataset whose 6th item is the 4M target. Defaults to the
            notebook-level ``testset_model3``.
        testset_model4: dataset whose 7th item is ``[6M | y | 6M Goutallier]``.
            Defaults to the notebook-level ``testset_model4``.

    Returns:
        dict with per-stage MSEs, ``roc_auc``, ``ap`` and the stacked
        ``predictions`` / ``targets`` tensors (on CPU).

    Raises:
        ValueError: if ``testset_model1`` is empty or the datasets differ in length.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    namespace = globals()
    if testset_model2 is None:
        testset_model2 = namespace["testset_model2"]
    if testset_model3 is None:
        testset_model3 = namespace["testset_model3"]
    if testset_model4 is None:
        testset_model4 = namespace["testset_model4"]

    n_samples = len(testset_model1)
    if n_samples == 0:
        raise ValueError("testset_model1 is empty; nothing to evaluate")
    # Targets are matched to predictions purely by row index.
    for ds_name, ds in (("testset_model2", testset_model2),
                        ("testset_model3", testset_model3),
                        ("testset_model4", testset_model4)):
        if len(ds) != n_samples:
            raise ValueError(
                f"{ds_name} has {len(ds)} samples but testset_model1 has {n_samples}; "
                "the datasets must be row-aligned"
            )

    sequential_model.eval()
    # Feed inputs on the model's device; gather every result on CPU so it can be
    # compared with the (CPU) dataset targets.
    first_param = next(iter(sequential_model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    all_pred_2M, all_pred_3M, all_pred_4M = [], [], []
    all_pred_6M, all_pred_y_logits, all_pred_6M_goutallier = [], [], []

    all_true_2M, all_true_3M, all_true_4M = [], [], []
    all_true_6M, all_true_y, all_true_6M_goutallier = [], [], []

    for i in range(n_samples):
        x_static, x_0M, x_0M_goutallier, y_2M = testset_model1[i]

        # Single-sample batch: add the leading batch dimension
        predictions = sequential_model(
            x_static.unsqueeze(0).to(device),
            x_0M.unsqueeze(0).to(device),
            x_0M_goutallier.unsqueeze(0).to(device),
        )

        all_pred_2M.append(predictions['pred_2M'].cpu())
        all_pred_3M.append(predictions['pred_3M'].cpu())
        all_pred_4M.append(predictions['pred_4M'].cpu())
        all_pred_6M.append(predictions['pred_6M'].cpu())
        all_pred_y_logits.append(predictions['pred_y_logits'].cpu())
        all_pred_6M_goutallier.append(predictions['pred_6M_goutallier'].cpu())

        # testset_model4 target row layout: [6M features | y label | 6M Goutallier]
        _, _, _, _, _, _, y_combined = testset_model4[i]
        y_6M = y_combined[:seq_features_6M]
        y_label = y_combined[seq_features_6M:seq_features_6M + 1]
        y_6M_goutallier = y_combined[seq_features_6M + 1:]

        # 3M and 4M ground truth come from the intermediate-stage datasets
        _, _, _, _, y_3M = testset_model2[i]
        _, _, _, _, _, y_4M = testset_model3[i]

        all_true_2M.append(y_2M)
        all_true_3M.append(y_3M)
        all_true_4M.append(y_4M)
        all_true_6M.append(y_6M)
        all_true_y.append(y_label)
        all_true_6M_goutallier.append(y_6M_goutallier)

    # Predictions already carry a batch dim of 1 -> cat; targets are per-sample -> stack
    pred_2M = torch.cat(all_pred_2M, dim=0)
    pred_3M = torch.cat(all_pred_3M, dim=0)
    pred_4M = torch.cat(all_pred_4M, dim=0)
    pred_6M = torch.cat(all_pred_6M, dim=0)
    pred_y_logits = torch.cat(all_pred_y_logits, dim=0)
    pred_6M_goutallier = torch.cat(all_pred_6M_goutallier, dim=0)

    true_2M = torch.stack(all_true_2M)
    true_3M = torch.stack(all_true_3M)
    true_4M = torch.stack(all_true_4M)
    true_6M = torch.stack(all_true_6M)
    true_y = torch.stack(all_true_y)
    true_6M_goutallier = torch.stack(all_true_6M_goutallier)

    # Regression quality per stage
    mse_2M = F.mse_loss(pred_2M, true_2M).item()
    mse_3M = F.mse_loss(pred_3M, true_3M).item()
    mse_4M = F.mse_loss(pred_4M, true_4M).item()
    mse_6M = F.mse_loss(pred_6M, true_6M).item()
    mse_6M_goutallier = F.mse_loss(pred_6M_goutallier, true_6M_goutallier).item()

    # Classification quality of the outcome logit
    pred_y_probs = pred_y_logits.sigmoid().flatten()
    true_y_int = true_y.flatten().to(torch.int)

    roc_auc = roc_auc_score(true_y_int.numpy(), pred_y_probs.numpy())
    ap = average_precision_score(true_y_int.numpy(), pred_y_probs.numpy())

    return {
        'mse_2M': mse_2M,
        'mse_3M': mse_3M,
        'mse_4M': mse_4M,
        'mse_6M': mse_6M,
        'mse_6M_goutallier': mse_6M_goutallier,
        'roc_auc': roc_auc,
        'ap': ap,
        'predictions': {
            'pred_2M': pred_2M,
            'pred_3M': pred_3M,
            'pred_4M': pred_4M,
            'pred_6M': pred_6M,
            'pred_y_logits': pred_y_logits,
            'pred_y_probs': pred_y_probs,
            'pred_6M_goutallier': pred_6M_goutallier
        },
        'targets': {
            'true_2M': true_2M,
            'true_3M': true_3M,
            'true_4M': true_4M,
            'true_6M': true_6M,
            'true_y': true_y_int,
            'true_6M_goutallier': true_6M_goutallier
        }
    }

# Run the chained-model evaluation and report the results
print("=" * 80)
print("연결된 시계열 모델 평가 시작")
print("=" * 80)

eval_results = evaluate_sequential_model(sequential_model, testset_model1)

print("\n=== 시계열 모델 평가 결과 ===")
# One MSE line per predicted stage
for _label, _key in (
    ("2M 예측 MSE", "mse_2M"),
    ("3M 예측 MSE", "mse_3M"),
    ("4M 예측 MSE", "mse_4M"),
    ("6M 예측 MSE", "mse_6M"),
    ("6M Goutallier 예측 MSE", "mse_6M_goutallier"),
):
    print(f"{_label}: {eval_results[_key]:.4f}")

print("\n분류 성능:")
print(f"ROC AUC: {eval_results['roc_auc']:.4f}")
print(f"AP (Average Precision): {eval_results['ap']:.4f}")


In [None]:
class MLP(L.LightningModule):
  """Multi-task MLP: one binary-classification logit plus 5 regression outputs.

  Static, sequential and Goutallier feature groups are encoded separately. The
  three embeddings are concatenated and read by two heads: ``clshead`` (1 logit)
  and ``reghead`` (5 values).

  Args:
    in_features: not used by this class (kept for call-site compatibility).
    static_features: width of the static-feature input.
    seq_features: width of the sequential-feature input.
    goutallier_features: width of the Goutallier-grade input.
    out_features: not used; the regression head width is hard-coded to 5.
      NOTE(review): presumably meant to size ``reghead`` — confirm before wiring it in.

  Batches are ``(xb, xb_static, xb_seq, xb_goutallier, yb)``. ``yb[:, :1]`` is the
  binary label and ``yb[:, 1:]`` holds the regression targets. ``xb`` is accepted
  by ``forward`` but not used.
  """
  def __init__(self, in_features, static_features, seq_features, goutallier_features, out_features):
    super().__init__()
    # Positive-class weight for the BCE loss (1.0 = unweighted). Registered as a
    # buffer so it moves with the module's device.
    self.register_buffer('pos_weight', torch.tensor([1.0]))

    # Dropout probability used by all three encoders
    dropout = 0.2
    self.static_encoder = nn.Sequential(
      nn.Linear(static_features, 64),
      nn.LayerNorm(64),
      nn.LeakyReLU(),
      nn.Dropout(dropout),
    )

    self.seq_encoder = nn.Sequential(
      nn.Linear(seq_features, 128),
      nn.LayerNorm(128),
      nn.LeakyReLU(),
      nn.Dropout(dropout),
    )

    self.goutallier = nn.Sequential(
      nn.Linear(goutallier_features, 64),
      nn.LayerNorm(64),
      nn.LeakyReLU(),
      nn.Dropout(dropout),
    )

    # Width of the concatenated embedding (static 64 + seq 128 + goutallier 64)
    feat_dim = 64 + 128 + 64

    self.clshead = nn.Sequential(
      nn.Linear(feat_dim, 128),
      nn.LayerNorm(128),
      nn.ReLU(),

      nn.Linear(128, 1)
    )

    self.reghead = nn.Sequential(
      nn.Linear(feat_dim, 256),
      nn.LayerNorm(256),
      nn.LeakyReLU(),

      nn.Linear(256, 5)
    )

    self.train_roc = BinaryAUROC()
    self.test_roc = BinaryAUROC()
    self.test_ap = BinaryAveragePrecision()
    self.val_roc = BinaryAUROC()
    self.val_ap = BinaryAveragePrecision()

    # Regression metrics: currently never updated or logged by this class
    self.train_mse = MeanSquaredError()
    self.test_mse  = MeanSquaredError()
    self.train_mae = MeanAbsoluteError()
    self.test_mae  = MeanAbsoluteError()

  def forward(self, xb, xb_static, xb_seq, xb_goutallier):
    """Return ``(logits, regs)`` from the shared embedding (``xb`` is ignored)."""
    static_features = self.static_encoder(xb_static)
    seq_features = self.seq_encoder(xb_seq)
    goutallier_features = self.goutallier(xb_goutallier)

    combined_features = torch.cat([static_features, seq_features, goutallier_features], dim=1)

    logits = self.clshead(combined_features)
    regs = self.reghead(combined_features)

    return logits, regs

  def _shared_step(self, batch, metric=True):
    """Compute the joint loss (BCE on the label + smooth-L1 on the regression targets).

    ``metric`` is currently unused.
    """
    xb, xb_static, xb_seq, xb_goutallier, yb = batch
    clf_targets = yb[:, :1]
    reg_targets = yb[:, 1:]

    logits, regs = self.forward(xb, xb_static, xb_seq, xb_goutallier)

    clf_loss = F.binary_cross_entropy_with_logits(logits, clf_targets, pos_weight=self.pos_weight)
    reg_loss = F.smooth_l1_loss(regs, reg_targets)
    loss = clf_loss + reg_loss

    return {
      "loss": loss,
      "clf_loss": clf_loss,
      "reg_loss": reg_loss,
      "clf_logits": logits.detach(),
      "clf_targets": clf_targets.detach(),
    }

  def training_step(self, batch, batch_idx):
    out = self._shared_step(batch)

    self.log("train/loss", out["loss"], on_epoch=True, prog_bar=True)
    self.log("train/clf_loss", out["clf_loss"])
    self.log("train/reg_loss", out["reg_loss"])

    probs = out["clf_logits"].sigmoid().flatten()
    targets = out["clf_targets"].flatten().to(torch.int)
    self.train_roc.update(probs, targets)

    return out["loss"]

  def test_step(self, batch, batch_idx):
    out = self._shared_step(batch)

    self.log("test/loss", out["loss"], prog_bar=True)
    self.log("test/clf_loss", out["clf_loss"])
    self.log("test/reg_loss", out["reg_loss"])

    probs = out["clf_logits"].sigmoid().flatten()
    targets = out["clf_targets"].flatten().to(torch.int)
    self.test_roc.update(probs, targets)
    self.test_ap.update(probs, targets)

    return out["loss"]

  def validation_step(self, batch, batch_idx):
    out = self._shared_step(batch)

    self.log("val/loss", out["loss"], prog_bar=True, on_epoch=True)
    self.log("val/clf_loss", out["clf_loss"], on_epoch=True)
    self.log("val/reg_loss", out["reg_loss"], on_epoch=True)

    probs = out["clf_logits"].sigmoid().flatten()
    targets = out["clf_targets"].flatten().to(torch.int)
    self.val_roc.update(probs, targets)
    self.val_ap.update(probs, targets)

    return out["loss"]

  def on_train_epoch_end(self):
    self.log("train/roc", self.train_roc.compute())
    self.train_roc.reset()

  def on_test_epoch_end(self):
    self.log("test/roc", self.test_roc.compute())
    self.log("test/ap", self.test_ap.compute())
    # Reset BOTH metrics so a later trainer.test() call does not accumulate
    # predictions from earlier test runs.
    self.test_roc.reset()
    self.test_ap.reset()

  def on_validation_epoch_end(self):
    self.log("val/roc", self.val_roc.compute())
    self.log("val/ap", self.val_ap.compute())
    self.val_roc.reset()
    self.val_ap.reset()

  def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=5*1e-6, weight_decay=1e-4)

    return {
      "optimizer": optimizer,
    }

In [None]:
class LossHistoryCallback(Callback):
    """Record per-epoch losses read from ``trainer.callback_metrics``.

    Attributes:
        train_losses: one float per training epoch, taken from ``train/loss_epoch``.
        test_losses: one float per validation epoch, taken from ``val/loss``.

    While a list is still empty, each hook prints the available metric keys to
    help debug key names. A warning is printed when the expected key is missing.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.test_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        if not self.train_losses:
            print(f"[Train] Available metrics: {list(metrics.keys())}")

        epoch_loss = metrics.get('train/loss_epoch')
        if epoch_loss is None:
            print("Warning: train/loss_epoch not found!")
            return
        self.train_losses.append(epoch_loss.item())

    def on_validation_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        if not self.test_losses:
            print(f"[Val] Available metrics: {list(metrics.keys())}")

        epoch_loss = metrics.get('val/loss')
        if epoch_loss is None:
            print("Warning: val/loss not found!")
            return
        self.test_losses.append(epoch_loss.item())


In [None]:
# Train `num_experiments` independent MLPs (one seed each) and test each one.
batch_size = 64
num_experiments = 1
MLP_MAX_EPOCHS = 24  # training epochs per experiment

test_logs = []
models = []
loss_histories = []

# The loaders do not depend on the experiment index, so build them once.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
testloader  = DataLoader(testset,  batch_size=batch_size)

for i in range(num_experiments):
    # Seed each run so results are reproducible and runs differ only by seed.
    L.seed_everything(i, workers=True)
    mlp = MLP(in_features, static_features, seq_features, goutallier_features, out_features)

    loss_history_callback = LossHistoryCallback()
    trainer = L.Trainer(
        max_epochs=MLP_MAX_EPOCHS,
        callbacks=[
            ModelCheckpoint(monitor='train/roc', mode='max', save_top_k=1),
            loss_history_callback,
        ],
    )
    # NOTE(review): testset doubles as the validation loader here, and the test
    # below evaluates the final-epoch weights held in `mlp` (the ModelCheckpoint
    # best is only saved to disk).
    trainer.fit(mlp, trainloader, testloader)
    test_result = trainer.test(mlp, testloader)

    test_logs.append(test_result)
    models.append(mlp)
    loss_histories.append(loss_history_callback)

In [None]:
# Per-run test metrics and selection of the run with the highest ROC AUC
print("===== 개별 모델 성능 확인 =====")
individual_rocs = []
individual_aps = []

for model_no, run_log in enumerate(test_logs, start=1):
    run_metrics = run_log[0]
    roc, ap = run_metrics["test/roc"], run_metrics["test/ap"]
    individual_rocs.append(roc)
    individual_aps.append(ap)
    print(f"모델 {model_no}: ROC AUC = {roc:.4f}, AP = {ap:.4f}")

individual_rocs = np.array(individual_rocs)
individual_aps = np.array(individual_aps)

print(f"\n개별 모델 ROC AUC: {individual_rocs.mean():.4f} ± {individual_rocs.std():.4f}")
print(f"개별 모델 AP: {individual_aps.mean():.4f} ± {individual_aps.std():.4f}")

# NOTE(review): the "best" run is chosen by its score on the test split itself.
best_model_idx = np.argmax(individual_rocs)
best_model = models[best_model_idx]

print("\n===== 최고 성능 모델 선택 =====")
print(f"최고 성능 모델: 모델 {best_model_idx + 1}")
print(f"ROC AUC: {individual_rocs[best_model_idx]:.4f}")
print(f"AP: {individual_aps[best_model_idx]:.4f}")

@torch.no_grad()
def predict_with_best_model(model, dataloader):
    """Run ``model`` over every batch and concatenate outputs and targets.

    Each batch must unpack as ``(xb, x_static, x_seq, x_goutallier, yb)``.
    Column 0 of ``yb`` is the classification target and the rest are the
    regression targets.

    Returns:
        (logits, regs, clf_targets, reg_targets). ``logits`` and ``regs`` are
        concatenated along dim 0 (not flattened). ``clf_targets`` is a flat int32
        tensor.
    """
    model.eval()
    logit_parts, reg_parts, label_parts, target_parts = [], [], [], []

    for xb, x_static, x_seq, x_goutallier, yb in dataloader:
        batch_logits, batch_regs = model(xb, x_static, x_seq, x_goutallier)
        logit_parts.append(batch_logits)
        reg_parts.append(batch_regs)
        label_parts.append(yb[:, :1])
        target_parts.append(yb[:, 1:])

    return (
        torch.cat(logit_parts),
        torch.cat(reg_parts),
        torch.cat(label_parts).to(torch.int).flatten(),
        torch.cat(target_parts),
    )

# Run the selected model over the test loader and turn logits into probabilities
best_logits, best_regs, clf_targets, reg_targets = predict_with_best_model(best_model, testloader)
probs = torch.sigmoid(best_logits)

# 1-D probabilities (one per sample) for the sklearn metrics
probs_flat = probs.flatten() if probs.dim() > 1 else probs

In [None]:
# Threshold-free ranking metrics of the selected model
best_roc = roc_auc_score(clf_targets, probs_flat)
best_ap = average_precision_score(clf_targets, probs_flat)

print(f"\n=== 최고 성능 모델 최종 성능 ===")
print(f"ROC AUC: {best_roc:.4f}")
print(f"AP: {best_ap:.4f}")

# Hard labels at a fixed decision threshold (strictly greater than)
threshold = 0.3
predicted_labels = probs_flat.gt(threshold).int()
print(f"\n=== 분류 성능 ===")
print(classification_report(clf_targets, predicted_labels, target_names=['Negative', 'Positive']))

# Quality of the auxiliary regression outputs
mse = F.mse_loss(best_regs, reg_targets).item()
mae = F.l1_loss(best_regs, reg_targets).item()

print(f"\n=== 회귀 성능 ===")
print(f"MSE: {mse:.4f}")
print(f"MAE: {mae:.4f}")

In [None]:
# Summary statistics of the test metrics across experiments
test_rocs = np.array([run_log[0]["test/roc"] for run_log in test_logs])
test_aps = np.array([run_log[0]["test/ap"] for run_log in test_logs])
pd.DataFrame({"ROC AUC": test_rocs, "PR AUC": test_aps}).describe()

In [None]:
def show_set_stat(dataset):
  """Print the size and class balance of a labelled dataset.

  ``dataset[:]`` must return 5 tensors. Column 0 of the last one holds the binary
  label (0 = negative, 1 = positive).

  Raises:
    ValueError: if the labels contain a value above 1.
  """
  _, _, _, _, y = dataset[:]
  # minlength=2 keeps the two-way unpacking valid when one class is absent
  # (e.g. a split with no positives); otherwise bincount returns a single count.
  counts = torch.bincount(y[:, 0].to(torch.int), minlength=2).tolist()
  if len(counts) != 2:
    raise ValueError(f"expected binary labels in y[:, 0], got {len(counts)} classes")
  negative, positive = counts
  samples = len(dataset)

  def pct(count):
    # Guard against an empty dataset (avoid division by zero)
    return count / samples * 100 if samples else 0.0

  print(f"total   : {samples}")
  print(f"negative: {negative:3} ({pct(negative):5.2f}%)")
  print(f"positive: {positive:3} ({pct(positive):5.2f}%)")

In [None]:
# Class balance of the training split (SMOTE-resampled per the label)
print("trainset (SMOTE)")
show_set_stat(dataset=trainset)

In [None]:
# Class balance of the held-out test split
print("testset")
show_set_stat(dataset=testset)

In [None]:
@torch.no_grad()
def forward_loader(model, dataloader):
  """Evaluate ``model`` on a whole dataloader and gather outputs and targets.

  Each batch must unpack as ``(xb, x_static, x_seq, x_goutallier, yb)``.
  Column 0 of ``yb`` is the classification target and the rest are the
  regression targets.

  Returns:
    (logits, regs, clf_targets, reg_targets). ``logits`` is flattened to 1-D,
    ``clf_targets`` is a flat int32 tensor, and ``regs`` / ``reg_targets`` are
    concatenated along dim 0.
  """
  model.eval()
  logit_parts, reg_parts, label_parts, target_parts = [], [], [], []

  for xb, x_static, x_seq, x_goutallier, yb in dataloader:
    batch_logits, batch_regs = model(xb, x_static, x_seq, x_goutallier)
    logit_parts.append(batch_logits)
    reg_parts.append(batch_regs)
    label_parts.append(yb[:, :1])
    target_parts.append(yb[:, 1:])

  return (
    torch.cat(logit_parts).flatten(),
    torch.cat(reg_parts),
    torch.cat(label_parts).to(torch.int).flatten(),
    torch.cat(target_parts),
  )

In [None]:
# Threshold analysis: use the same selected model as the "best model" section
# (previously the last trained `mlp`, which differs once num_experiments > 1).
logits, regs, clf_targets, reg_targets = forward_loader(best_model, testloader)
probs = logits.sigmoid()

print(f"logits.shape:      {logits.shape}")
print(f"probs.shape:      {probs.shape}")
print(f"regs.shape:        {regs.shape}")
print()
print(f"clf_targets.shape: {clf_targets.shape}")
print(f"reg_targets.shape: {reg_targets.shape}")

In [None]:
precisions, recalls, thresholds = precision_recall_curve(clf_targets, probs)
# precision_recall_curve returns one more precision/recall point than thresholds;
# append 1.0 so all three arrays have the same length for plotting.
thresholds = np.append(thresholds, 1.0)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(thresholds, precisions, label='Precision', marker='o', markersize=3)
ax.plot(thresholds, recalls, label='Recall', marker='x', markersize=3)

ax.set_title("Precision & Recall vs Threshold")
ax.set_xlabel("Threshold")
ax.set_ylabel("Score")
ax.legend()
ax.grid(True, linestyle="--", alpha=0.6)
fig.tight_layout()
plt.show()

In [None]:
def plot_score_distributions(
  y_score, y_true, *,
  bins=40,
  title=None,
  density=False,
  th_lines=(0.5,),
):
  """Overlay per-class histograms of predicted scores.

  Args:
    y_score: predicted probabilities (array-like or tensor), one per sample.
    y_true: labels aligned with ``y_score``; 1 = positive, 0 = negative.
    bins: number of equal-width bins spanning the observed score range.
    title: figure title; defaults to "Score distributions by class".
    density: plot normalized densities instead of raw counts.
    th_lines: thresholds drawn as dashed vertical lines (falsy to skip).

  Raises:
    ValueError: if ``y_score`` is empty.
  """
  y_true = np.asarray(y_true).astype(int)
  y_score = np.asarray(y_score)
  if y_score.size == 0:
    raise ValueError("y_score is empty; nothing to plot")

  pos = y_score[y_true == 1]
  neg = y_score[y_true == 0]

  xmin = np.min(y_score)
  xmax = np.max(y_score)
  if xmin == xmax:
    # All scores identical (e.g. a collapsed model): widen the range by 0.5 on
    # each side, as np.histogram does, so the bin edges strictly increase.
    xmin, xmax = xmin - 0.5, xmax + 0.5
  bin_edges = np.linspace(xmin, xmax, bins + 1)

  fig, ax = plt.subplots(figsize=(9, 5.5))
  ax.hist(neg, bins=bin_edges, alpha=0.55, density=density,
          label=f"Negative (n={len(neg)})", edgecolor="white", linewidth=0.5)
  ax.hist(pos, bins=bin_edges, alpha=0.55, density=density,
          label=f"Positive (n={len(pos)})", edgecolor="white", linewidth=0.5)

  for th in th_lines or ():
    ax.axvline(th, linestyle="--", linewidth=1.5)

  ax.set_xlabel("Predicted probability")
  ax.set_ylabel("Density" if density else "Count")
  ax.set_title(title or "Score distributions by class")
  ax.legend(loc="best")
  ax.grid(True, linestyle="--", alpha=0.4)

  fig.tight_layout()
  plt.show()

In [None]:
# Default grid of decision thresholds: 0.1, 0.2, ..., 0.9
default_thresholds = np.linspace(0, 1, 11)[1:-1].tolist()

def test_thresholds(y_score, y_true, thresholds=default_thresholds, verbose=True):
  """Tabulate accuracy / precision / recall / F1 at each decision threshold.

  Args:
    y_score: predicted scores (tensor), passed straight to torchmetrics' binary metrics.
    y_true: integer label tensor aligned with ``y_score``.
    thresholds: iterable of thresholds to evaluate.
    verbose: print the resulting table.

  Returns:
    DataFrame indexed by ``threshold`` with columns accuracy, precison, recall, f1.
    The "precison" column name is kept as-is for compatibility.
  """
  metric_specs = (
    ("accuracy", BinaryAccuracy),
    ("precison", BinaryPrecision),
    ("recall", BinaryRecall),
    ("f1", BinaryF1Score),
  )

  rows = []
  for th in thresholds:
    row = {"threshold": th}
    for column, metric_cls in metric_specs:
      metric = metric_cls(th)
      metric.update(y_score, y_true)
      row[column] = metric.compute().item()
    rows.append(row)

  columns = ["threshold"] + [column for column, _ in metric_specs]
  table = pd.DataFrame(rows, columns=columns).set_index("threshold")

  if verbose:
    print(table)

  return table

In [None]:
# Sweep decision thresholds over [0, 1] in steps of 0.01 and report the best one
# under several criteria. NumPy's docs recommend linspace over arange for
# non-integer steps; linspace also guarantees both endpoints are included.
thresholds_range = np.linspace(0.0, 1.0, 101)

accuracies = []
precisions = []
recalls = []
f1_scores = []
specificities = []
youden_indices = []

y_true_np = clf_targets.cpu().numpy()
probs_np = probs.cpu().numpy()

for th in thresholds_range:
    # A sample is predicted positive when its probability reaches the threshold
    y_pred = (probs_np >= th).astype(int)

    tn = np.sum((y_pred == 0) & (y_true_np == 0))
    fp = np.sum((y_pred == 1) & (y_true_np == 0))
    fn = np.sum((y_pred == 0) & (y_true_np == 1))
    tp = np.sum((y_pred == 1) & (y_true_np == 1))

    # Each ratio falls back to 0 when its denominator is empty
    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    youden = recall + specificity - 1  # Youden's J = sensitivity + specificity - 1

    accuracies.append(accuracy)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)
    specificities.append(specificity)
    youden_indices.append(youden)

best_accuracy_th = thresholds_range[np.argmax(accuracies)]
best_f1_th = thresholds_range[np.argmax(f1_scores)]
best_youden_th = thresholds_range[np.argmax(youden_indices)]
# Threshold where precision and recall are closest to each other
balanced_th = thresholds_range[np.argmin(np.abs(np.array(precisions) - np.array(recalls)))]

print(f"\n{'='*100}")
print(" 최적 Threshold 결과")
print(f"{'='*100}\n")

print(f"Accuracy 최대화:        Threshold = {best_accuracy_th:.3f}  (Accuracy = {max(accuracies):.4f})")
print(f"F1 Score 최대화:        Threshold = {best_f1_th:.3f}  (F1 = {max(f1_scores):.4f})")
print(f"Youden J 최대화:        Threshold = {best_youden_th:.3f}  (J = {max(youden_indices):.4f})")
print(f"Precision-Recall 균형:  Threshold = {balanced_th:.3f}")

In [None]:
# Detailed metrics and score histogram at the accuracy-optimal threshold
thresholds = [best_accuracy_th]
test_thresholds(probs, clf_targets, thresholds=thresholds)
plot_score_distributions(probs, clf_targets, bins=40, density=True, th_lines=thresholds)

In [None]:
def print_regression_summary(regs, reg_targets, output_columns):
    """Print and return per-column regression metrics (MSE, MAE, RMSE, R²).

    Args:
        regs: predictions; column ``i`` corresponds to ``output_columns[i]``.
        reg_targets: ground truth with the same column layout as ``regs``.
        output_columns: column names, in order; one result row per name.

    Returns:
        DataFrame with columns Column, MSE, MAE, RMSE, R². It is also printed,
        rounded to 4 decimals. R² is non-finite when a target column is constant.
    """
    rows = []
    for idx, name in enumerate(output_columns):
        y_hat = regs[:, idx]
        y = reg_targets[:, idx]

        residual = y_hat - y
        mean_sq = torch.mean(residual ** 2)

        # Coefficient of determination: 1 - SS_res / SS_tot
        ss_res = torch.sum((y - y_hat) ** 2)
        ss_tot = torch.sum((y - torch.mean(y)) ** 2)

        rows.append({
            'Column': name,
            'MSE': mean_sq.item(),
            'MAE': torch.mean(torch.abs(residual)).item(),
            'RMSE': torch.sqrt(mean_sq).item(),
            'R²': 1 - (ss_res / ss_tot).item(),
        })

    summary = pd.DataFrame(rows)
    print(summary.round(4))

    return summary

# Per-target regression metrics for the selected best model (previously models[0],
# which differs from the selected model once num_experiments > 1)
logits, regs, clf_targets, reg_targets = forward_loader(best_model, testloader)
regression_summary = print_regression_summary(regs, reg_targets, output_columns)

In [None]:
# print(testset)
# show_set_stat(testset)

In [None]:
# testloader = DataLoader(testset, batch_size=batch_size)
# test_logs = trainer.test(mlp, testloader)

In [None]:
# test_logits, test_regs, test_clf_targets, test_reg_targets = forward_loader(mlp, testloader)
# test_probs = test_logits.sigmoid()
# test_thresholds(test_probs, test_clf_targets, thresholds)
# plot_score_distributions(test_probs, test_clf_targets, bins=40, density=True, th_lines=thresholds)

In [None]:
# pre_columns = seq_columns[:12] + goutallier_columns[:4] + goutallier_columns[8:12]
# pre_columns

In [None]:
# mean_columns = [column for column in columns if column not in static_columns + pre_columns + output_columns]
# mean_columns

In [None]:
# mean_table = pd.read_csv("X_train.csv")
# mean_table["age_group"] = mean_table["나이"] // 10 * 10

# group_columns = ["성별 (M:1,F:2)", "age_group"]
# mean_table = mean_table.groupby(group_columns)[mean_columns].mean().reset_index()
# mean_table

In [None]:
# def get_pre_with_mean_dataset(split):
#   assert split in ["val", "test"]
#   X = pd.read_csv(f"X_{split}.csv")
#   y = pd.read_csv(f"y_{split}.csv")

#   indices = pd.concat([X["성별 (M:1,F:2)"], X["나이"] // 10 * 10], axis=1)
#   indices.columns = group_columns
#   mean_values = indices.merge(mean_table, on=group_columns, how="left")

#   X[mean_columns] = mean_values[mean_columns]
#   X_np = X[input_columns].to_numpy()
#   y_np = pd.concat([y, X[output_columns]], axis=1).to_numpy()

#   X_tensor = torch.tensor(X_np, dtype=torch.float32)
#   y_tensor = torch.tensor(y_np, dtype=torch.float32)

#   return TensorDataset(X_tensor, y_tensor)

In [None]:
# val_pre_with_mean_set = get_pre_with_mean_dataset("val")
# val_pre_with_mean_loader = DataLoader(val_pre_with_mean_set, batch_size=batch_size)
# val_pre_with_mean_logs = trainer.test(mlp, val_pre_with_mean_loader)

In [None]:
# logits, regs, clf_targets, reg_targets = forward_loader(mlp, val_pre_with_mean_loader)
# probs = logits.sigmoid()
# test_thresholds(probs, clf_targets, thresholds)
# plot_score_distributions(probs, clf_targets, bins=40, density=False, th_lines=thresholds)

In [None]:
# def fgsm_attack(data, data_grad, epsilon):
#   sign_data_grad = data_grad.sign()
#   perturbed_data = data + epsilon*sign_data_grad
#   return perturbed_data

In [None]:
# num_static_columns = len(static_columns)
# num_goutallier_columns= len(goutallier_columns)
# num_static_columns, num_goutallier_columns

In [None]:
# fgsm_target_start = num_static_columns+12
# fgsm_target_end = -num_goutallier_columns
# fgsm_target_columns = input_columns[fgsm_target_start:fgsm_target_end]
# fgsm_target_columns, len(fgsm_target_columns)

In [None]:
# # 0: '6M ASES'          -> maximize
# # 1: '6M CSS'           -> maximize
# # 2: '6M KSS'           -> maximize
# # 3: '6M VAS(activity)' -> minimize
# # 4: '6M VAS(resting)'  -> minimize
# maximize_indices = [0, 1, 2]
# minimize_indices = [3, 4]

# lambda_logits = 1.0
# lambda_reg = 0.3
# epsilon = 0.5

# all_logits = []
# all_regs = []
# all_perturbed_xb = []
# all_perturbed_logits = []
# all_perturbed_regs = []

# mlp.eval()
# for xb, yb in val_pre_with_mean_loader:
#   clf_targets = yb[:, :1]
#   reg_targets = yb[:, 1:]

#   xb.requires_grad = True
#   logits, regs = mlp(xb)
#   all_logits.append(logits.detach())
#   all_regs.append(regs.detach())

#   clf_loss = F.binary_cross_entropy_with_logits(logits, clf_targets)
#   logits_dir_loss = -logits.mean()

#   reg_inc_term = -regs[:, maximize_indices].mean()
#   reg_dec_term = regs[:, minimize_indices].mean()
#   reg_dir_loss = reg_inc_term + reg_dec_term

#   loss = clf_loss + lambda_logits * logits_dir_loss + lambda_reg * reg_dir_loss

#   mlp.zero_grad()
#   loss.backward()

#   xb_grad = xb.grad.data
#   perturbed_xb = fgsm_attack(xb, xb_grad, epsilon)
#   all_perturbed_xb.append(perturbed_xb.detach())

#   perturbed_logits, perturbed_regs = mlp(perturbed_xb)
#   all_perturbed_logits.append(perturbed_logits.detach())
#   all_perturbed_regs.append(perturbed_regs.detach())

In [None]:
# logits = torch.cat(all_logits).flatten()
# regs = torch.cat(all_regs)
# logits.shape, regs.shape

In [None]:
# perturbed_logits = torch.cat(all_perturbed_logits).flatten()
# perturbed_regs = torch.cat(all_perturbed_regs)
# perturbed_logits.shape, perturbed_regs.shape

In [None]:
# probs = logits.sigmoid()
# perturbed_probs = perturbed_logits.sigmoid()
# probs.shape, perturbed_probs.shape

In [None]:
# print(label_column)
# prob_results = pd.DataFrame({
#   "probability": probs.flatten(),
#   "after probability": perturbed_probs.flatten()
# })
# prob_results["delta"] = prob_results["after probability"] - prob_results["probability"]
# prob_results

In [None]:
# prob_results.describe()

In [None]:
# all_feature_results = []
# for feature_idx in range(regs.size(1)):
#   feature_column = output_columns[feature_idx]
#   after_column = f"after {feature_column}"

#   feature_results = pd.DataFrame({
#     feature_column: regs[:, feature_idx],
#     after_column: perturbed_regs[:, feature_idx]
#   })
#   feature_results["delta"] = feature_results[after_column] - feature_results[feature_column]
#   all_feature_results.append(feature_results)

#   direction = "maximize" if feature_idx in maximize_indices else "minimize"
#   print(f"{feature_column} ({direction})")
#   print(feature_results)
#   print()

In [None]:
# for feature_column, feature_results in zip(output_columns, all_feature_results):
#   direction = "maximize" if feature_idx in maximize_indices else "minimize"
#   print(f"{feature_column} ({direction})")
#   print(feature_results.describe())
#   print()