In [None]:
import pandas as pd
import numpy as np
import re

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_auc_score, average_precision_score, classification_report


import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback

In [None]:
# Load the training features and labels for interactive inspection.
# get_dataset() below re-reads the same CSV files when it builds tensors.
X_train = pd.read_csv("X_train.csv")
y_train = pd.read_csv("y_train.csv")

In [None]:
# Display the training feature table.
X_train

In [None]:
# Display the training label table.
y_train

In [None]:
# Column names of the feature CSV, in file order. The cells below split this
# list into groups purely by position.
columns = list(X_train.columns)
columns

In [None]:
# The first 25 columns form the "static" feature group.
# NOTE(review): positional split — assumes the CSV column order never changes.
static_columns = columns[:25]
static_columns

In [None]:
# Everything between the first 25 and the last 16 columns forms the "sequence" group.
seq_columns = columns[25:-16]
# Print the group in five index ranges for eyeballing.
# NOTE(review): presumably one range per measurement time-point — verify against the CSV header.
print(seq_columns[:12])
print(seq_columns[12:19])
print(seq_columns[19:31])
print(seq_columns[31:43])
print(seq_columns[43:])

In [None]:
# The last 16 columns form the third ("goutallier") feature group.
goutallier_columns = columns[-16:]
goutallier_columns

In [None]:
# Sanity check: the three positional groups together account for every column
# (expected output: True).
len(columns) == len(static_columns) + len(seq_columns) + len(goutallier_columns)

In [None]:
# Name of the binary classification target.
label_column = "POD 6M retear"
# Regression targets; get_dataset() takes them from the feature CSV and
# appends them to the label to form the target tensor.
output_columns = ["6M ASES", "6M CSS", "6M KSS", "6M VAS(activity)", "6M VAS(resting)"]

# Drop the regression targets from the sequence group itself, not only from
# `input_columns`. The model is fed X[seq_columns] (see get_dataset) and sized
# by len(seq_columns), so leaving the targets in this list would hand the
# network the very values it is asked to predict.
# The filter is idempotent, so re-running this cell is safe. The length check
# in the earlier cell refers to the unfiltered positional slice.
seq_columns = [column for column in seq_columns if column not in output_columns]
input_columns = static_columns + seq_columns + goutallier_columns

In [None]:
# Display the regression target column names.
output_columns

In [None]:
# Display the model input column names.
input_columns

In [None]:
# Display the input-feature view of the training table.
X_train[input_columns]

In [None]:
# Display the full target table: label column(s) from y_train followed by the
# regression targets taken from X_train (same layout get_dataset() uses).
pd.concat([y_train, X_train[output_columns]], axis=1)

In [None]:
def get_dataset(split):
  """Read `X_{split}.csv` / `y_{split}.csv` and wrap them in a TensorDataset.

  Args:
    split: either "train" or "test" (asserted).

  Returns:
    TensorDataset whose items are
    (all inputs, static group, sequence group, goutallier group, targets),
    every tensor in float32. The target tensor holds the columns of the
    y file followed by `output_columns` taken from the X file.

  NOTE(review): the sequence tensor is built from `seq_columns` as-is; if that
  list contains any of `output_columns`, the targets are also visible as inputs.
  """
  assert split in ["train", "test"]

  features = pd.read_csv(f"X_{split}.csv")
  labels = pd.read_csv(f"y_{split}.csv")

  def as_float_tensor(frame):
    # DataFrame -> float32 tensor via its NumPy representation.
    return torch.tensor(frame.to_numpy(), dtype=torch.float32)

  # Order matters: it defines the positions inside every dataset item.
  column_groups = (input_columns, static_columns, seq_columns, goutallier_columns)
  tensors = [as_float_tensor(features[group]) for group in column_groups]

  targets = pd.concat([labels, features[output_columns]], axis=1)
  tensors.append(as_float_tensor(targets))

  return TensorDataset(*tensors)

In [None]:
# Build the tensor datasets from the train/test CSV pairs.
trainset = get_dataset("train")
testset = get_dataset("test")

In [None]:
# Layer sizes derived from the column groups; passed to the MLP constructor.
in_features = len(input_columns)
static_features = len(static_columns)
seq_features = len(seq_columns)
goutallier_features = len(goutallier_columns)
# One classification label plus the regression targets.
out_features = len([label_column]) + len(output_columns)

in_features, static_features, seq_features, goutallier_features, out_features

In [None]:
class MLP(L.LightningModule):
  """Multi-task network: one binary-classification logit plus five regression outputs.

  The static, sequence and Goutallier feature groups are encoded by separate
  Linear/LayerNorm/LeakyReLU/Dropout blocks, concatenated, and passed to a
  classification head and a regression head. The loss is
  BCE-with-logits (classification) + smooth-L1 (regression).
  """

  def __init__(self, in_features, static_features, seq_features, goutallier_features, out_features):
    """Build encoders, heads and metric objects.

    Args:
      in_features: accepted for interface compatibility; not used by any layer.
      static_features: width of the static feature group.
      seq_features: width of the sequence feature group.
      goutallier_features: width of the Goutallier feature group.
      out_features: accepted for interface compatibility; the head widths are
        fixed at 1 (classification) and 5 (regression).
    """
    super().__init__()
    # Positive-class weight for the BCE loss, registered as a buffer so it
    # moves with the module. 1.0 means no re-weighting.
    self.register_buffer('pos_weight', torch.tensor([1.0]))

    # Dropout rate actually applied in every encoder.
    encoder_dropout = 0.2
    self.static_encoder = self._make_encoder(static_features, 64, encoder_dropout)
    self.seq_encoder = self._make_encoder(seq_features, 128, encoder_dropout)
    self.goutallier = self._make_encoder(goutallier_features, 64, encoder_dropout)

    # Width of the concatenated encoder outputs.
    feat_dim = 64 + 128 + 64

    self.clshead = nn.Sequential(
      nn.Linear(feat_dim, 128),
      nn.LayerNorm(128),
      nn.ReLU(),

      nn.Linear(128, 1)
    )

    self.reghead = nn.Sequential(
      nn.Linear(feat_dim, 256),
      nn.LayerNorm(256),
      nn.LeakyReLU(),

      nn.Linear(256, 5)
    )

    self.train_roc = BinaryAUROC()
    self.test_roc = BinaryAUROC()
    self.test_ap = BinaryAveragePrecision()
    self.val_roc = BinaryAUROC()
    self.val_ap = BinaryAveragePrecision()

    # Regression metrics are instantiated but not updated anywhere in this class.
    self.train_mse = MeanSquaredError()
    self.test_mse  = MeanSquaredError()
    self.train_mae = MeanAbsoluteError()
    self.test_mae  = MeanAbsoluteError()

  @staticmethod
  def _make_encoder(in_dim, out_dim, dropout):
    """Linear -> LayerNorm -> LeakyReLU -> Dropout block used for each feature group."""
    return nn.Sequential(
      nn.Linear(in_dim, out_dim),
      nn.LayerNorm(out_dim),
      nn.LeakyReLU(),
      nn.Dropout(dropout),
    )

  def forward(self, xb, xb_static, xb_seq, xb_goutallier):
    """Return (logits, regs) for one batch.

    `xb` (the flat input tensor) is accepted to match the dataset layout but is
    not used; only the three grouped tensors feed the network.
    """
    static_features = self.static_encoder(xb_static)
    seq_features = self.seq_encoder(xb_seq)
    goutallier_features = self.goutallier(xb_goutallier)

    combined_features = torch.cat([static_features, seq_features, goutallier_features], dim=1)

    logits = self.clshead(combined_features)
    regs = self.reghead(combined_features)

    return logits, regs

  def _shared_step(self, batch, metric=True):
    """Run the forward pass and compute the losses for one batch.

    Column 0 of the target tensor is the classification label; the remaining
    columns are the regression targets. `metric` is currently unused.
    """
    xb, xb_static, xb_seq, xb_goutallier, yb = batch
    clf_targets = yb[:, :1]
    reg_targets = yb[:, 1:]

    logits, regs = self.forward(xb, xb_static, xb_seq, xb_goutallier)

    clf_loss = F.binary_cross_entropy_with_logits(logits, clf_targets, pos_weight=self.pos_weight)
    reg_loss = F.smooth_l1_loss(regs, reg_targets)
    loss = clf_loss + reg_loss

    return {
      "loss": loss,
      "clf_loss": clf_loss,
      "reg_loss": reg_loss,
      "clf_logits": logits.detach(),
      "clf_targets": clf_targets.detach(),
    }

  @staticmethod
  def _probs_and_targets(out):
    """Convert a `_shared_step` result into flat (probabilities, int targets)."""
    probs = out["clf_logits"].sigmoid().flatten()
    targets = out["clf_targets"].flatten().to(torch.int)
    return probs, targets

  def training_step(self, batch, batch_idx):
    """Log the training losses, update the train AUROC, return the total loss."""
    out = self._shared_step(batch)

    self.log("train/loss", out["loss"], on_epoch=True, prog_bar=True)
    self.log("train/clf_loss", out["clf_loss"])
    self.log("train/reg_loss", out["reg_loss"])

    probs, targets = self._probs_and_targets(out)
    self.train_roc.update(probs, targets)

    return out["loss"]

  def test_step(self, batch, batch_idx):
    """Log the test losses and update the test AUROC / average precision."""
    out = self._shared_step(batch)

    self.log("test/loss", out["loss"], prog_bar=True)
    self.log("test/clf_loss", out["clf_loss"])
    self.log("test/reg_loss", out["reg_loss"])

    probs, targets = self._probs_and_targets(out)
    self.test_roc.update(probs, targets)
    self.test_ap.update(probs, targets)

    return out["loss"]

  def validation_step(self, batch, batch_idx):
    """Log the validation losses and update the validation AUROC / average precision."""
    out = self._shared_step(batch)

    self.log("val/loss", out["loss"], prog_bar=True, on_epoch=True)
    self.log("val/clf_loss", out["clf_loss"], on_epoch=True)
    self.log("val/reg_loss", out["reg_loss"], on_epoch=True)

    probs, targets = self._probs_and_targets(out)
    self.val_roc.update(probs, targets)
    self.val_ap.update(probs, targets)

    return out["loss"]

  # The epoch-end hooks log plain computed values rather than the metric
  # objects, so every metric has to be reset by hand afterwards.

  def on_train_epoch_end(self):
    self.log("train/roc", self.train_roc.compute())
    self.train_roc.reset()

  def on_test_epoch_end(self):
    self.log("test/roc", self.test_roc.compute())
    self.log("test/ap", self.test_ap.compute())
    self.test_roc.reset()
    # Without this reset a second trainer.test() call would mix in the
    # predictions accumulated by the previous run.
    self.test_ap.reset()

  def on_validation_epoch_end(self):
    self.log("val/roc", self.val_roc.compute())
    self.log("val/ap", self.val_ap.compute())
    self.val_roc.reset()
    self.val_ap.reset()

  def configure_optimizers(self):
    """AdamW over all parameters; no learning-rate scheduler."""
    optimizer = torch.optim.AdamW(self.parameters(), lr=5*1e-6, weight_decay=1e-4)

    return {
      "optimizer": optimizer,
    }

In [None]:
class LossHistoryCallback(Callback):
    """Record the epoch-level training and validation loss after every epoch.

    Attributes:
        train_losses: one `train/loss_epoch` value per finished training epoch.
        test_losses: one `val/loss` value per finished validation epoch.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.test_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        # On the first epoch, list the available keys to help debug metric names.
        if len(self.train_losses) == 0:
            print(f"[Train] Available metrics: {list(trainer.callback_metrics.keys())}")

        train_loss = trainer.callback_metrics.get('train/loss_epoch')
        if train_loss is not None:
            self.train_losses.append(train_loss.item())
        else:
            print(f"Warning: train/loss_epoch not found!")

    def on_validation_epoch_end(self, trainer, pl_module):
        # Lightning runs a short validation pass before training starts (the
        # sanity check) and fires this hook for it as well. Recording that pass
        # would put one extra entry at the front of `test_losses` and shift it
        # out of step with `train_losses`.
        if trainer.sanity_checking:
            return

        if len(self.test_losses) == 0:
            print(f"[Val] Available metrics: {list(trainer.callback_metrics.keys())}")

        val_loss = trainer.callback_metrics.get('val/loss')
        if val_loss is not None:
            self.test_losses.append(val_loss.item())
        else:
            print(f"Warning: val/loss not found!")


In [None]:
batch_size = 64
num_experiments = 1

# Per-experiment results, all indexed by experiment number.
test_logs = []
models = []
loss_histories = []

for experiment_idx in range(num_experiments):
    mlp = MLP(in_features, static_features, seq_features, goutallier_features, out_features)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
    testloader  = DataLoader(testset,  batch_size=batch_size)

    loss_history_callback = LossHistoryCallback()
    trainer = L.Trainer(
        max_epochs=24,
        callbacks=[
            # NOTE(review): the checkpoint is selected on a *training* metric.
            ModelCheckpoint(monitor='train/roc', mode='max', save_top_k=1),
            loss_history_callback
        ]
    )
    # NOTE(review): the test loader doubles as the validation loader here, so
    # the "val/*" curves are computed on the test set.
    trainer.fit(mlp, trainloader, testloader)
    # No ckpt_path is passed, so this presumably evaluates the in-memory model
    # as it stands after the last epoch rather than the saved checkpoint.
    test_result = trainer.test(mlp, testloader)
    test_logs.append(test_result)
    models.append(mlp)
    loss_histories.append(loss_history_callback)

In [None]:
# Per-experiment test metrics, pulled from the first result dict of each trainer.test() call.
print("===== 개별 모델 성능 확인 =====")
individual_rocs = np.array([test_log[0]["test/roc"] for test_log in test_logs])
individual_aps = np.array([test_log[0]["test/ap"] for test_log in test_logs])

for model_number, (roc, ap) in enumerate(zip(individual_rocs, individual_aps), start=1):
    print(f"모델 {model_number}: ROC AUC = {roc:.4f}, AP = {ap:.4f}")

# Mean and standard deviation across experiments.
print(f"\n개별 모델 ROC AUC: {individual_rocs.mean():.4f} ± {individual_rocs.std():.4f}")
print(f"개별 모델 AP: {individual_aps.mean():.4f} ± {individual_aps.std():.4f}")

# The best model is the one with the highest test ROC AUC.
best_model_idx = np.argmax(individual_rocs)
best_model = models[best_model_idx]

print(f"\n===== 최고 성능 모델 선택 =====")
print(f"최고 성능 모델: 모델 {best_model_idx + 1}")
print(f"ROC AUC: {individual_rocs[best_model_idx]:.4f}")
print(f"AP: {individual_aps[best_model_idx]:.4f}")

@torch.no_grad()
def predict_with_best_model(model, dataloader):
    """Run `model` in eval mode over `dataloader` and gather outputs and targets.

    Each batch must be a 5-tuple (xb, x_static, x_seq, x_goutallier, yb); the
    first column of `yb` is the classification target, the rest are the
    regression targets.

    Returns:
        (logits, regs, clf_targets, reg_targets), concatenated over batches.
        `logits` keeps the shape produced by the model; `clf_targets` is cast
        to int and flattened.
    """
    model.eval()
    collected = {"logits": [], "regs": [], "clf": [], "reg": []}

    for xb, x_static, x_seq, x_goutallier, yb in dataloader:
        batch_logits, batch_regs = model(xb, x_static, x_seq, x_goutallier)
        collected["logits"].append(batch_logits)
        collected["regs"].append(batch_regs)
        collected["clf"].append(yb[:, :1])
        collected["reg"].append(yb[:, 1:])

    return (
        torch.cat(collected["logits"]),
        torch.cat(collected["regs"]),
        torch.cat(collected["clf"]).to(torch.int).flatten(),
        torch.cat(collected["reg"]),
    )

# Score the test set with the selected model and turn logits into probabilities.
best_logits, best_regs, clf_targets, reg_targets = predict_with_best_model(best_model, testloader)
probs = best_logits.sigmoid()

# sklearn expects 1-D scores; flatten only when the tensor has extra dimensions.
probs_flat = probs.flatten() if probs.dim() > 1 else probs

In [None]:
# Threshold-free ranking metrics for the selected model.
best_roc = roc_auc_score(clf_targets, probs_flat)
best_ap = average_precision_score(clf_targets, probs_flat)

print(
    f"\n=== 최고 성능 모델 최종 성능 ===\n"
    f"ROC AUC: {best_roc:.4f}\n"
    f"AP: {best_ap:.4f}"
)

# Hard labels at a fixed decision threshold.
threshold = 0.3
predicted_labels = (probs_flat > threshold).to(torch.int)
print(f"\n=== 분류 성능 ===")
print(classification_report(clf_targets, predicted_labels, target_names=['Negative', 'Positive']))

# Regression errors averaged over all samples and all regression outputs.
mse = torch.nn.functional.mse_loss(best_regs, reg_targets).item()
mae = torch.nn.functional.l1_loss(best_regs, reg_targets).item()

print(
    f"\n=== 회귀 성능 ===\n"
    f"MSE: {mse:.4f}\n"
    f"MAE: {mae:.4f}"
)

In [None]:
# Summary statistics of the test ROC AUC / average precision across experiments.
test_aps = np.array([test_log[0]["test/ap"] for test_log in test_logs])
test_rocs = np.array([test_log[0]["test/roc"] for test_log in test_logs])
pd.DataFrame({"ROC AUC": test_rocs, "PR AUC": test_aps}).describe()

In [None]:
def show_set_stat(dataset):
  """Print the sample count and the negative/positive split of a dataset.

  Args:
    dataset: indexable with `[:]`, yielding a 5-tuple whose last element is
      the target tensor; column 0 of that tensor is the binary label.

  Raises:
    ValueError: if the label column contains values other than 0 and 1
      (more than two bins cannot be unpacked).
  """
  _, _, _, _, y = dataset[:]
  samples = len(dataset)

  # minlength=2 keeps both bins present even when a class is absent; without
  # it a dataset containing only negatives yields one bin and the unpack fails.
  negative, positive = torch.bincount(y[:, 0].to(torch.int), minlength=2).tolist()

  def share(count):
    # Percentage of the dataset; defined as 0 for an empty dataset.
    return count / samples * 100 if samples else 0.0

  print(f"tatal   : {samples}")
  print(f"negative: {negative:3} ({share(negative):5.2f}%)")
  print(f"positive: {positive:3} ({share(positive):5.2f}%)")

In [None]:
# Class balance of the training set.
# NOTE(review): the label says SMOTE, but no resampling happens in this notebook —
# presumably X_train.csv / y_train.csv were oversampled upstream; confirm.
print("trainset (SMOTE)")
show_set_stat(trainset)

In [None]:
# Class balance of the test set.
print("testset")
show_set_stat(testset)

In [None]:
@torch.no_grad()
def forward_loader(model, dataloader):
  """Evaluate `model` on every batch of `dataloader` and stack the results.

  Batches are 5-tuples (xb, x_static, x_seq, x_goutallier, yb). Column 0 of
  `yb` is the classification target; the remaining columns are regression
  targets.

  Returns:
    (logits, regs, clf_targets, reg_targets). `logits` is flattened to 1-D and
    `clf_targets` is cast to int and flattened; the other two keep their
    per-sample layout.
  """
  model.eval()

  logit_parts, reg_parts = [], []
  label_parts, score_parts = [], []
  for xb, x_static, x_seq, x_goutallier, yb in dataloader:
    batch_logits, batch_regs = model(xb, x_static, x_seq, x_goutallier)
    logit_parts += [batch_logits]
    reg_parts += [batch_regs]
    label_parts += [yb[:, :1]]
    score_parts += [yb[:, 1:]]

  return (
    torch.cat(logit_parts).flatten(),
    torch.cat(reg_parts),
    torch.cat(label_parts).to(torch.int).flatten(),
    torch.cat(score_parts),
  )

In [None]:
# Re-score the test set and print tensor shapes as a sanity check.
# NOTE(review): `mlp` is whatever model the training loop created last, which is
# not necessarily `best_model` once num_experiments > 1 — confirm which is intended.
logits, regs, clf_targets, reg_targets = forward_loader(mlp, testloader)
probs = logits.sigmoid()

print(f"logits.shape:      {logits.shape}")
print(f"probs.shape:      {probs.shape}")
print(f"regs.shape:        {regs.shape}")
print()
print(f"clf_targets.shape: {clf_targets.shape}")
print(f"reg_targets.shape: {reg_targets.shape}")

In [None]:
# Precision and recall as functions of the decision threshold.
precisions, recalls, thresholds = precision_recall_curve(clf_targets, probs)
# The curve returns one more precision/recall point than thresholds; pad with
# 1.0 so all three arrays can share an x-axis.
thresholds = np.append(thresholds, 1.0)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(thresholds, precisions, label='Precision', marker='o', markersize=3)
ax.plot(thresholds, recalls, label='Recall', marker='x', markersize=3)

ax.set_title("Precision & Recall vs Threshold")
ax.set_xlabel("Threshold")
ax.set_ylabel("Score")
ax.legend()
ax.grid(True, linestyle="--", alpha=0.6)
fig.tight_layout()
plt.show()

In [None]:
def plot_score_distributions(
  y_score, y_true, *,
  bins=40,
  title=None,
  density=False,
  th_lines=(0.5,),
):
  """Overlay histograms of predicted scores for the negative and positive class.

  Args:
    y_score: predicted scores, one per sample.
    y_true: ground-truth labels; cast to int, 1 = positive, 0 = negative.
    bins: number of equal-width bins between the smallest and largest score.
    title: figure title; a default is used when falsy.
    density: forwarded to the histogram call; also switches the y-axis label.
    th_lines: x positions at which dashed vertical lines are drawn.
  """
  labels = np.asarray(y_true).astype(int)
  scores = np.asarray(y_score)

  # Both classes share one set of bin edges so the bars line up.
  edges = np.linspace(np.min(scores), np.max(scores), bins+1)
  hist_style = dict(bins=edges, alpha=0.55, density=density,
                    edgecolor="white", linewidth=0.5)

  fig, ax = plt.subplots(figsize=(9, 5.5))
  # Negative first, then positive, so the colour assignment stays fixed.
  for class_value, class_name in ((0, "Negative"), (1, "Positive")):
    class_scores = scores[labels == class_value]
    ax.hist(class_scores, label=f"{class_name} (n={len(class_scores)})", **hist_style)

  if th_lines:
    for th in th_lines:
      ax.axvline(th, linestyle="--", linewidth=1.5)

  ax.set_xlabel("Predicted probability")
  ax.set_ylabel("Density" if density else "Count")
  ax.set_title(title or "Score distributions by class")
  ax.legend(loc="best")
  ax.grid(True, linestyle="--", alpha=0.4)

  fig.tight_layout()
  plt.show()

In [None]:
default_thresholds = np.linspace(0, 1, 11)[1:-1].tolist() # [0.1, 0.2, ... , 0.9]

def test_thresholds(y_score, y_true, thresholds=default_thresholds, verbose=True):
  """Tabulate accuracy, precision, recall and F1 at each decision threshold.

  Args:
    y_score: predicted probabilities.
    y_true: binary ground-truth labels.
    thresholds: thresholds to evaluate (0.1 ... 0.9 by default).
    verbose: print the resulting table when true.

  Returns:
    DataFrame indexed by "threshold" with columns
    "accuracy", "precison", "recall", "f1".
  """
  # Insertion order defines the column order of the result. The "precison"
  # spelling is the column name callers see, so it is kept as is.
  metric_factories = {
    "accuracy": BinaryAccuracy,
    "precison": BinaryPrecision,
    "recall": BinaryRecall,
    "f1": BinaryF1Score,
  }
  scores_by_metric = {name: [] for name in metric_factories}

  for threshold in thresholds:
    for name, factory in metric_factories.items():
      metric = factory(threshold)
      metric.update(y_score, y_true)
      scores_by_metric[name].append(metric.compute().item())

  result = pd.DataFrame({"threshold": thresholds, **scores_by_metric}).set_index("threshold")

  if verbose:
    print(result)

  return result

In [None]:
# Sweep decision thresholds 0.00 ... 1.00 and pick the best one per criterion.
# NOTE(review): `probs` / `clf_targets` come from the test loader, so these
# thresholds are tuned on the same data they are later reported on.
thresholds_range = np.arange(0.0, 1.01, 0.01)

accuracies = []
precisions = []
recalls = []
f1_scores = []
specificities = []
youden_indices = []

y_true_np = clf_targets.cpu().numpy()
probs_np = probs.cpu().numpy()

for th in thresholds_range:
    y_pred = (probs_np >= th).astype(int)

    # Confusion-matrix cells at this threshold.
    tn = np.sum((y_pred == 0) & (y_true_np == 0))
    fp = np.sum((y_pred == 1) & (y_true_np == 0))
    fn = np.sum((y_pred == 0) & (y_true_np == 1))
    tp = np.sum((y_pred == 1) & (y_true_np == 1))

    # Every ratio falls back to 0 when its denominator is empty.
    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    youden = recall + specificity - 1

    accuracies.append(accuracy)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)
    specificities.append(specificity)
    youden_indices.append(youden)

best_accuracy_th = thresholds_range[np.argmax(accuracies)]
best_f1_th = thresholds_range[np.argmax(f1_scores)]
best_youden_th = thresholds_range[np.argmax(youden_indices)]

# Precision/recall balance point. Thresholds without a single true positive
# have precision == recall == 0, i.e. a gap of exactly 0, and would always win
# the argmin. Exclude them unless no threshold has a true positive at all.
pr_gap = np.abs(np.array(precisions) - np.array(recalls))
has_true_positive = np.array(recalls) > 0
if has_true_positive.any():
    pr_gap = np.where(has_true_positive, pr_gap, np.inf)
balanced_th = thresholds_range[np.argmin(pr_gap)]

print(f"\n{'='*100}")
print(" 최적 Threshold 결과")
print(f"{'='*100}\n")

print(f"Accuracy 최대화:        Threshold = {best_accuracy_th:.3f}  (Accuracy = {max(accuracies):.4f})")
print(f"F1 Score 최대화:        Threshold = {best_f1_th:.3f}  (F1 = {max(f1_scores):.4f})")
print(f"Youden Index 최대화:    Threshold = {best_youden_th:.3f}  (J = {max(youden_indices):.4f})")
print(f"Precision-Recall 균형:  Threshold = {balanced_th:.3f}")

In [None]:
# Evaluate and visualise the accuracy-maximising threshold found by the sweep above.
thresholds = [best_accuracy_th]
test_thresholds(probs, clf_targets, thresholds)
plot_score_distributions(probs, clf_targets, bins=40, density=True, th_lines=thresholds)

In [None]:
def print_regression_summary(regs, reg_targets, output_columns):
    """Print and return per-target regression metrics.

    Args:
        regs: predictions, one column per regression target.
        reg_targets: ground truth with the same layout as `regs`.
        output_columns: target names; entry i labels column i.

    Returns:
        DataFrame with one row per target and the columns
        'Column', 'MSE', 'MAE', 'RMSE', 'R²'. R² is NaN for a target whose
        ground truth has zero variance, where the score is undefined.
    """
    results = []
    for i, col_name in enumerate(output_columns):
        pred = regs[:, i]
        true = reg_targets[:, i]
        residual = pred - true

        # Keep the mean squared error as a tensor so RMSE reuses it.
        mse_tensor = torch.mean(residual ** 2)
        mse = mse_tensor.item()
        mae = torch.mean(torch.abs(residual)).item()
        rmse = torch.sqrt(mse_tensor).item()

        ss_res = torch.sum((true - pred) ** 2)
        ss_tot = torch.sum((true - torch.mean(true)) ** 2)
        # A constant target gives ss_tot == 0; dividing by it would produce
        # -inf or NaN depending on ss_res, so report NaN explicitly.
        if ss_tot.item() > 0:
            r2 = 1 - (ss_res / ss_tot).item()
        else:
            r2 = float("nan")

        results.append({
            'Column': col_name,
            'MSE': mse,
            'MAE': mae,
            'RMSE': rmse,
            'R²': r2
        })

    df = pd.DataFrame(results)
    print(df.round(4))

    return df

# Per-target regression metrics on the test set.
# NOTE(review): uses models[0], which equals best_model only while
# num_experiments == 1 — confirm which model should be reported.
logits, regs, clf_targets, reg_targets = forward_loader(models[0], testloader)
regression_summary = print_regression_summary(regs, reg_targets, output_columns)

In [None]:
# print(testset)
# show_set_stat(testset)

In [None]:
# testloader = DataLoader(testset, batch_size=batch_size)
# test_logs = trainer.test(mlp, testloader)

In [None]:
# test_logits, test_regs, test_clf_targets, test_reg_targets = forward_loader(mlp, testloader)
# test_probs = test_logits.sigmoid()
# test_thresholds(test_probs, test_clf_targets, thresholds)
# plot_score_distributions(test_probs, test_clf_targets, bins=40, density=True, th_lines=thresholds)

In [None]:
# pre_columns = seq_columns[:12] + goutallier_columns[:4] + goutallier_columns[8:12]
# pre_columns

In [None]:
# mean_columns = [column for column in columns if column not in static_columns + pre_columns + output_columns]
# mean_columns

In [None]:
# mean_table = pd.read_csv("X_train.csv")
# mean_table["age_group"] = mean_table["나이"] // 10 * 10

# group_columns = ["성별 (M:1,F:2)", "age_group"]
# mean_table = mean_table.groupby(group_columns)[mean_columns].mean().reset_index()
# mean_table

In [None]:
# def get_pre_with_mean_dataset(split):
#   assert split in ["val", "test"]
#   X = pd.read_csv(f"X_{split}.csv")
#   y = pd.read_csv(f"y_{split}.csv")

#   indices = pd.concat([X["성별 (M:1,F:2)"], X["나이"] // 10 * 10], axis=1)
#   indices.columns = group_columns
#   mean_values = indices.merge(mean_table, on=group_columns, how="left")

#   X[mean_columns] = mean_values[mean_columns]
#   X_np = X[input_columns].to_numpy()
#   y_np = pd.concat([y, X[output_columns]], axis=1).to_numpy()

#   X_tensor = torch.tensor(X_np, dtype=torch.float32)
#   y_tensor = torch.tensor(y_np, dtype=torch.float32)

#   return TensorDataset(X_tensor, y_tensor)

In [None]:
# val_pre_with_mean_set = get_pre_with_mean_dataset("val")
# val_pre_with_mean_loader = DataLoader(val_pre_with_mean_set, batch_size=batch_size)
# val_pre_with_mean_logs = trainer.test(mlp, val_pre_with_mean_loader)

In [None]:
# logits, regs, clf_targets, reg_targets = forward_loader(mlp, val_pre_with_mean_loader)
# probs = logits.sigmoid()
# test_thresholds(probs, clf_targets, thresholds)
# plot_score_distributions(probs, clf_targets, bins=40, density=False, th_lines=thresholds)

In [None]:
# def fgsm_attack(data, data_grad, epsilon):
#   sign_data_grad = data_grad.sign()
#   perturbed_data = data + epsilon*sign_data_grad
#   return perturbed_data

In [None]:
# num_static_columns = len(static_columns)
# num_goutallier_columns= len(goutallier_columns)
# num_static_columns, num_goutallier_columns

In [None]:
# fgsm_target_start = num_static_columns+12
# fgsm_target_end = -num_goutallier_columns
# fgsm_target_columns = input_columns[fgsm_target_start:fgsm_target_end]
# fgsm_target_columns, len(fgsm_target_columns)

In [None]:
# # 0: '6M ASES'          -> maximize
# # 1: '6M CSS'           -> maximize
# # 2: '6M KSS'           -> maximize
# # 3: '6M VAS(activity)' -> minimize
# # 4: '6M VAS(resting)'  -> minimize
# maximize_indices = [0, 1, 2]
# minimize_indices = [3, 4]

# lambda_logits = 1.0
# lambda_reg = 0.3
# epsilon = 0.5

# all_logits = []
# all_regs = []
# all_perturbed_xb = []
# all_perturbed_logits = []
# all_perturbed_regs = []

# mlp.eval()
# for xb, yb in val_pre_with_mean_loader:
#   clf_targets = yb[:, :1]
#   reg_targets = yb[:, 1:]

#   xb.requires_grad = True
#   logits, regs = mlp(xb)
#   all_logits.append(logits.detach())
#   all_regs.append(regs.detach())

#   clf_loss = F.binary_cross_entropy_with_logits(logits, clf_targets)
#   logits_dir_loss = -logits.mean()

#   reg_inc_term = -regs[:, maximize_indices].mean()
#   reg_dec_term = regs[:, minimize_indices].mean()
#   reg_dir_loss = reg_inc_term + reg_dec_term

#   loss = clf_loss + lambda_logits * logits_dir_loss + lambda_reg * reg_dir_loss

#   mlp.zero_grad()
#   loss.backward()

#   xb_grad = xb.grad.data
#   perturbed_xb = fgsm_attack(xb, xb_grad, epsilon)
#   all_perturbed_xb.append(perturbed_xb.detach())

#   perturbed_logits, perturbed_regs = mlp(perturbed_xb)
#   all_perturbed_logits.append(perturbed_logits.detach())
#   all_perturbed_regs.append(perturbed_regs.detach())

In [None]:
# logits = torch.cat(all_logits).flatten()
# regs = torch.cat(all_regs)
# logits.shape, regs.shape

In [None]:
# perturbed_logits = torch.cat(all_perturbed_logits).flatten()
# perturbed_regs = torch.cat(all_perturbed_regs)
# perturbed_logits.shape, perturbed_regs.shape

In [None]:
# probs = logits.sigmoid()
# perturbed_probs = perturbed_logits.sigmoid()
# probs.shape, perturbed_probs.shape

In [None]:
# print(label_column)
# prob_results = pd.DataFrame({
#   "probability": probs.flatten(),
#   "after probability": perturbed_probs.flatten()
# })
# prob_results["delta"] = prob_results["after probability"] - prob_results["probability"]
# prob_results

In [None]:
# prob_results.describe()

In [None]:
# all_feature_results = []
# for feature_idx in range(regs.size(1)):
#   feature_column = output_columns[feature_idx]
#   after_column = f"after {feature_column}"

#   feature_results = pd.DataFrame({
#     feature_column: regs[:, feature_idx],
#     after_column: perturbed_regs[:, feature_idx]
#   })
#   feature_results["delta"] = feature_results[after_column] - feature_results[feature_column]
#   all_feature_results.append(feature_results)

#   direction = "maximize" if feature_idx in maximize_indices else "minimize"
#   print(f"{feature_column} ({direction})")
#   print(feature_results)
#   print()

In [None]:
# for feature_column, feature_results in zip(output_columns, all_feature_results):
#   direction = "maximize" if feature_idx in maximize_indices else "minimize"
#   print(f"{feature_column} ({direction})")
#   print(feature_results.describe())
#   print()