In [1]:
# ============================================================
# 0. IMPORTS
# ============================================================
from kan import KAN
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pennylane as qml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler, RobustScaler,MinMaxScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, balanced_accuracy_score, average_precision_score,
    matthews_corrcoef, cohen_kappa_score, brier_score_loss, roc_auc_score
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import KFold
import random
import copy


from sklearn.decomposition import PCA, FactorAnalysis
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.utils import resample
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer, SimpleImputer

## 1. SEEDING (REPRODUCIBILITY)

In [2]:
batch_size = 32
random_state = 42
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## 2.Data Loading & Preprocessing

In [3]:
def log_transform_skewed(df, except_cols):
    skew = df.drop(columns=except_cols).skew()
    skew_cols = skew[abs(skew) > 0.1].index.tolist()

    for col in skew_cols:
        df[col] = np.log1p(df[col])
    return df


def balance_data(df, label_col="Sickness"):
    majority = df[df[label_col] == 1]
    minority = df[df[label_col] == 0]

    minority_up = resample(minority,
                           replace=True,
                           n_samples=len(majority),
                           random_state=42)

    df_bal = pd.concat([majority, minority_up], axis=0)
    return df_bal.sample(frac=1, random_state=random_state).reset_index(drop=True)

In [4]:
df = pd.read_csv("IndianLiverPatientDataset(ILPD).csv")

df['Gender'] = df['Gender'].map({'Male': 0, 'Female': 1})
df['Sickness'] = df['Sickness'].replace(2, 0)

In [5]:
# 18 rows bạn vừa gửi
rows_18 = [
    [18,0,1.8,0.7,178,35,36,6.8,3.6,1.10,1],
    [17,0,0.9,0.2,224,36,45,6.9,4.2,1.55,1],
    [24,0,1.0,0.2,189,52,31,8.0,4.8,1.50,1],
    [60,0,2.2,1.0,271,45,52,6.1,2.9,0.90,0],
    [60,0,0.8,0.2,215,24,17,6.3,3.0,0.90,0],
    [38,1,2.6,1.2,410,59,57,5.6,3.0,0.80,0],
    [35,0,2.0,1.1,226,33,135,6.0,2.7,0.80,0],
    [11,0,0.7,0.1,592,26,29,7.1,4.2,1.40,0],
    [65,0,0.7,0.2,265,30,28,5.2,1.8,0.52,0],
    [36,0,5.3,2.3,145,32,92,5.1,2.6,1.00,0],
    [48,0,0.7,0.2,208,15,30,4.6,2.1,0.80,0],
    [65,0,1.4,0.6,260,28,24,5.2,2.2,0.70,0],
    [62,0,0.6,0.1,160,42,110,4.9,2.6,1.10,0],
    [65,0,0.8,0.2,201,18,22,5.4,2.9,1.10,0],
    [17,1,0.7,0.2,145,18,36,7.2,3.9,1.18,0],
    [62,0,0.7,0.2,162,12,17,8.2,3.2,0.60,0],
    [65,0,1.9,0.8,170,36,43,3.8,1.4,0.58,0],
    [23,1,2.3,0.8,509,28,44,6.9,2.9,0.7,0]  
]

rows_18 = pd.DataFrame(rows_18, columns=df.columns)
rows_18_df = rows_18.copy()

# Tìm index trong df gốc
matching_indices = []

for i in range(len(rows_18)):
    mask = (df == rows_18.iloc[i]).all(axis=1)
    idx = df.index[mask].tolist()
    matching_indices.append((i, idx))

matching_indices

drop_idx = [134, 102, 496, 367, 33, 34, 474, 417, 493, 105, 106, 413, 414, 145, 36, 532, 182, 411]

df= df.drop([i for i in drop_idx if i in df.index], errors="ignore")
df=df.drop_duplicates().reset_index(drop=True)

In [6]:
# df = log_transform_skewed(df, except_cols=["Sickness", "Gender"])

### Train/test split

In [7]:
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["Sickness"],
    random_state=random_state
)
# # Ghép lại 18 dòng vào train
# train_df = pd.concat(
#     [train_df, rows_18_df],
#     axis=0,
#     ignore_index=True
# )

# Đoạn này Oversampling, thích thì giữ không thì chỉ cần comment lại là xong. 
# Tránh test được oversampling, val thì không sao.
train_df = balance_data(train_df)

train_split_df, val_split_df = train_test_split(
    train_df,
    test_size=0.25,                
    stratify=train_df["Sickness"],
    random_state=random_state 
)
X_train_df = train_split_df.drop("Sickness", axis=1)
y_train_df = train_split_df["Sickness"]

X_val_df = val_split_df.drop("Sickness", axis=1)
y_val_df = val_split_df["Sickness"]

X_test_df = test_df.drop("Sickness", axis=1)
y_test_df = test_df["Sickness"]

### IMPUTER, PCA, LDA 

In [None]:
# ==========================================================
# MISSING VALUE
# ==========================================================
# imputer = SimpleImputer(strategy="mean")
imputer = IterativeImputer(random_state=42)
imputer.fit(X_train_df)               

X_train_imp = imputer.transform(X_train_df)
X_val_imp   = imputer.transform(X_val_df)
X_test_imp  = imputer.transform(X_test_df)


# ==========================================================
#  FIT SCALER ONLY ON TRAIN_SPLIT
# ==========================================================
scaler = StandardScaler()
scaler.fit(X_train_imp)

X_train_scaled = scaler.transform(X_train_imp)
X_val_scaled   = scaler.transform(X_val_imp)
X_test_scaled  = scaler.transform(X_test_imp)

# Khi muốn bỏ PCA, LDA, FA
X_train_final = X_train_scaled 
X_val_final = X_val_scaled   
X_test_final = X_test_scaled  

# ==========================================================
#  FIT PCA, FA, LDA ONLY ON TRAIN_SPLIT
# ==========================================================
pca = PCA(n_components=7, random_state=42)
pca.fit(X_train_scaled)

fa = FactorAnalysis(n_components=7, random_state=42)
fa.fit(X_train_scaled)

lda = LinearDiscriminantAnalysis(n_components=1)
lda.fit(X_train_scaled, y_train_df)


# ==========================================================
#  Transform all sets
# ==========================================================
X_train_final = np.concatenate([
    pca.transform(X_train_scaled),
    fa.transform(X_train_scaled),
    lda.transform(X_train_scaled)
], axis=1)

X_val_final = np.concatenate([
    pca.transform(X_val_scaled),
    fa.transform(X_val_scaled),
    lda.transform(X_val_scaled)
], axis=1)

X_test_final = np.concatenate([
    pca.transform(X_test_scaled),
    fa.transform(X_test_scaled),
    lda.transform(X_test_scaled)
], axis=1)

###  Convert to tensors + π/2 scaling

In [9]:
X_train = torch.tensor(X_train_final, dtype=torch.float32)
X_val   = torch.tensor(X_val_final,   dtype=torch.float32)
X_test  = torch.tensor(X_test_final,  dtype=torch.float32)

# X_train *= torch.tanh(X_train)
# X_val   *= torch.tanh(X_val)
# X_test  *= torch.tanh(X_test)

X_train *= (np.pi/2)
X_val   *= (np.pi/2)
X_test  *= (np.pi/2)


y_train = torch.tensor(y_train_df.values, dtype=torch.float32).unsqueeze(1)
y_val   = torch.tensor(y_val_df.values,   dtype=torch.float32).unsqueeze(1)
y_test  = torch.tensor(y_test_df.values,  dtype=torch.float32).unsqueeze(1)


###  DataLoaders

In [10]:
train_loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=batch_size,
    shuffle=True
)

val_loader = DataLoader(
    TensorDataset(X_val, y_val),
    batch_size=batch_size,
    shuffle=False
)

test_loader = DataLoader(
    TensorDataset(X_test, y_test),
    batch_size=batch_size,
    shuffle=False
)

print(X_train.size())
print(y_train.size())
print(X_val.size())
print(y_val.size())
print(X_test.size())
print(y_test.size())

torch.Size([483, 15])
torch.Size([483, 1])
torch.Size([161, 15])
torch.Size([161, 1])
torch.Size([111, 15])
torch.Size([111, 1])


## 3. Quantum Layer

In [11]:
# ============================================================
# 3. QUANTUM LAYER
# ============================================================

n_qubits = 1
n_layers = 1

dev = qml.device("default.qubit", wires=n_qubits, seed=seed)

@qml.qnode(dev, interface="torch")
def qnode(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

weight_shapes = {"weights": (n_layers, n_qubits)}

## 4. BINARY FOCAL LOSS

In [12]:
import torch
import torch.nn as nn

class BinaryFocalLoss(nn.Module):
    def __init__(self, alpha=0.3, gamma=1.0, reduction="mean", eps=1e-7):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps

    def forward(self, y_pred, y_true):
        y_true = y_true.float()
        y_pred = torch.clamp(y_pred, self.eps, 1.0 - self.eps)

        pos_loss = - self.alpha * (1 - y_pred) ** self.gamma * y_true * torch.log(y_pred)
        neg_loss = - (1 - self.alpha) * y_pred ** self.gamma * (1 - y_true) * torch.log(1 - y_pred)

        loss = pos_loss + neg_loss

        if self.reduction == "mean":
            return loss.mean()
        return loss.sum()


## 5. MODEL ARCHITECTURE

In [13]:
class HybridKANQuantumModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        # ======================
        # KAN NETWORK
        # ======================
        self.kan = KAN(
            width=[input_dim,16,8, n_qubits],  # bạn có thể chỉnh
            grid=5,       # số điểm spline
            k=3,          # spline bậc 3 (cubic)
            seed=seed
        )

        # ======================
        # QUANTUM LAYER
        # ======================
        self.q_layer = qml.qnn.TorchLayer(qnode, weight_shapes)

        # ======================
        # OUTPUT LAYER
        # ======================
        self.fc_out = nn.Linear(n_qubits, 1)
        nn.init.xavier_uniform_(self.fc_out.weight)
        nn.init.zeros_(self.fc_out.bias)

    def forward(self, x):
        x = self.kan(x)        # ⬅️ KAN thay MLP
        x = self.q_layer(x)    # ⬅️ Quantum layer
        x = self.fc_out(x)
        return torch.sigmoid(x)


## 6. CLASS WEIGHTS Smoothing

In [14]:
# # ============================================================

# # ============================================================

# # Chuyển y_train về numpy
# y_train_np = y_train.numpy().flatten()

# # Balanced class weights
# class_weights = compute_class_weight(
#     class_weight="balanced",
#     classes=np.unique(y_train_np),
#     y=y_train_np
# )

# # Keras-style smoothing
# alpha_smooth = 0.7   # 0.7–0.9 đều hợp lý
# class_weight_dict = {
#     i: float(1 + alpha_smooth * (w - 1))
#     for i, w in enumerate(class_weights)
# }

# print("Class Weights Dict:", class_weight_dict)


In [15]:

# # Chuyển y_train về numpy
# y_train_np = y_train.numpy().flatten()
# def effective_num_weight(n, beta=0.99):
#     return (1 - beta) / (1 - beta ** n)

# n0 = (y_train_np == 0).sum()
# n1 = (y_train_np == 1).sum()

# w0 = effective_num_weight(n0, beta=0.99)
# w1 = effective_num_weight(n1, beta=0.99)

# gamma = 1.3   # ⭐ 0.3 – 0.7
# w0 = w0 ** gamma
# w1 = w1 ** (gamma)

# # Normalize
# mean_w = (w0 + w1) / 2
# class_weight_dict = {
#     0: float(w0 / mean_w),
#     1: float(w1 / mean_w)
# }


In [16]:
def effective_num_weight(n, beta=0.99):
    return (1 - beta) / (1 - beta ** n)

# Counts
y_train_np = y_train.numpy().flatten()
n0 = (y_train_np == 0).sum()
n1 = (y_train_np == 1).sum()

# ENS gốc
w0 = effective_num_weight(n0, beta=0.99)
w1 = effective_num_weight(n1, beta=0.99)

# Power scaling (làm mềm)
gamma = 2.5
w0 = w0 ** gamma
w1 = w1 ** gamma

# Keras-style smoothing (kéo về 1)
alpha_smooth = 1   # 0.5 – 0.8
w0 = 1 + alpha_smooth * (w0 - 1)
w1 = 1 + alpha_smooth * (w1 - 1)

# Normalize (mean = 1)
mean_w = (w0 + w1) / 2
class_weight_dict = {
    0: float(w0 / mean_w),
    1: float(w1 / mean_w)
}

print("ENS + smooth weights:", class_weight_dict)


ENS + smooth weights: {0: 1.0012165623709262, 1: 0.998783437629074}


In [17]:
# def effective_num_weight(n, beta=0.99):
#     return (1 - beta) / (1 - beta ** n)

# # Counts
# y_train_np = y_train.numpy().flatten()
# n0 = (y_train_np == 0).sum()
# n1 = (y_train_np == 1).sum()

# # ENS gốc
# w0 = effective_num_weight(n0, beta=0.99)
# w1 = effective_num_weight(n1, beta=0.99)

# w = np.array([w0, w1])
# T = 0.8  # 1.5 – 3.0

# w = np.exp(w / T) / np.sum(np.exp(w / T))
# w = w * 2   # giữ mean = 1

# class_weight_dict = {0: w[0], 1: w[1]}
# print("ENS + smooth weights:", class_weight_dict)


In [18]:


# def effective_num_weight(n, beta=0.99):
#     return (1 - beta) / (1 - beta ** n)

# # Counts
# y_train_np = y_train.numpy().flatten()
# n0 = (y_train_np == 0).sum()
# n1 = (y_train_np == 1).sum()

# # ENS gốc
# w0 = effective_num_weight(n0, beta=0.99)
# w1 = effective_num_weight(n1, beta=0.99)

# max_ratio = 1.5  # hoặc 2.0
# w0 = np.clip(w0, 1 / max_ratio, max_ratio)
# w1 = np.clip(w1, 1 / max_ratio, max_ratio)

# w = np.array([w0, w1])
# # T = 0.8  # 1.5 – 3.0

# # w = np.exp(w / T) / np.sum(np.exp(w / T))
# # w = w * 2   # giữ mean = 1

# class_weight_dict = {0: w[0], 1: w[1]}
# print("ENS + smooth weights:", class_weight_dict)

## 7. TRAINING SETUP

In [19]:
model = HybridKANQuantumModel(X_train.shape[1])

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = BinaryFocalLoss(alpha=0.3, gamma= 1.0)

# Early stopping
patience = 1000
best_loss = np.inf
counter = 0
best_state = None


checkpoint directory created: ./model
saving model version 0.0


## 8. TRAINING LOOP

In [20]:

# ============================================================
# 8. TRAINING LOOP
# ============================================================

epochs = 223
for epoch in range(epochs):
    # --- TRAINING PHASE ---
    model.train()
    train_loss = 0

    for x, y in train_loader:
        optimizer.zero_grad()
        out = model(x)
        
        # 1. Tính raw loss (dạng vector, vì reduction='none' trong criterion)
        raw_loss = criterion(out, y) 
        
        # 2. Lấy trọng số cho từng mẫu trong batch từ class_weight_dict
        batch_weights = torch.tensor([class_weight_dict[int(yi.item())] for yi in y], dtype=torch.float32)
        
        # 3. Nhân trọng số và tính trung bình để ra một số thực (Scalar)
        # unsqueeze(1) giúp batch_weights có kích thước [32, 1] khớp với raw_loss
        loss = (raw_loss * batch_weights.unsqueeze(1)).mean()
        
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)

    train_loss /= len(train_loader.dataset)

    # --- VALIDATION PHASE (ĐÃ SỬA LỖI) ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            out = model(x)
            
            # Tính raw loss (vector)
            raw_loss = criterion(out, y)
            
            # Lấy trọng số tương tự như training để tính loss công bằng
            batch_weights = torch.tensor([class_weight_dict[int(yi.item())] for yi in y], dtype=torch.float32)
            
            # Nhân trọng số và tính trung bình (Scalar)
            loss = (raw_loss * batch_weights.unsqueeze(1)).mean()
            
            # Cộng dồn loss (lúc này loss.item() hoạt động bình thường vì loss là scalar)
            val_loss += loss.item() * x.size(0)

    val_loss /= len(val_loader.dataset)
    print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")



    # # --- EARLY STOPPING ---
    # if val_loss < best_loss:
    #     best_loss = val_loss
    #     best_state = model.state_dict()
    #     counter = 0
    # else:
    #     counter += 1
    #     if counter >= patience:
    #         print("Early stopping")
    #         break

# Load lại trọng số tốt nhất sau khi training xong
# model.load_state_dict(best_state)

  self.subnode_actscale.append(torch.std(x, dim=0).detach())
  input_range = torch.std(preacts, dim=0) + 0.1
  output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part
  output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic


Epoch 1: train_loss=0.1647, val_loss=0.1609
Epoch 2: train_loss=0.1565, val_loss=0.1530
Epoch 3: train_loss=0.1493, val_loss=0.1452
Epoch 4: train_loss=0.1433, val_loss=0.1387
Epoch 5: train_loss=0.1394, val_loss=0.1345
Epoch 6: train_loss=0.1357, val_loss=0.1315
Epoch 7: train_loss=0.1331, val_loss=0.1283
Epoch 8: train_loss=0.1304, val_loss=0.1259
Epoch 9: train_loss=0.1283, val_loss=0.1247
Epoch 10: train_loss=0.1261, val_loss=0.1224
Epoch 11: train_loss=0.1229, val_loss=0.1195
Epoch 12: train_loss=0.1200, val_loss=0.1177
Epoch 13: train_loss=0.1169, val_loss=0.1158
Epoch 14: train_loss=0.1140, val_loss=0.1141
Epoch 15: train_loss=0.1106, val_loss=0.1122
Epoch 16: train_loss=0.1069, val_loss=0.1094
Epoch 17: train_loss=0.1030, val_loss=0.1079
Epoch 18: train_loss=0.0993, val_loss=0.1067
Epoch 19: train_loss=0.0953, val_loss=0.1047
Epoch 20: train_loss=0.0917, val_loss=0.1034
Epoch 21: train_loss=0.0878, val_loss=0.1027
Epoch 22: train_loss=0.0841, val_loss=0.1016
Epoch 23: train_los

## 9. EVALUATION

In [21]:
model.eval()
y_probs, y_true = [], []

with torch.no_grad():
    for x, y in test_loader:
        out = model(x)
        y_probs.extend(out.numpy())
        y_true.extend(y.numpy())

y_probs = np.array(y_probs)
y_true = np.array(y_true)
y_pred = (y_probs >= 0.5).astype(int)

tn, fp, fn, tp = confusion_matrix(y_true, y_pred) .ravel()

print("\n=== METRICS ===")
print(f"Accuracy: {accuracy_score(y_true, y_pred):.4f}")
print(f"Precision: {precision_score(y_true, y_pred):.4f}")
print(f"Recall: {recall_score(y_true, y_pred):.4f}")
print(f"Specificity: {tn / (tn + fp):.4f}")
print(f"F1-score: {f1_score(y_true, y_pred):.4f}")
print(f"Balanced Accuracy: {balanced_accuracy_score(y_true, y_pred):.4f}")
print(f"ROC AUC: {roc_auc_score(y_true, y_probs):.4f}")
print(f"PR AUC: {average_precision_score(y_true, y_probs):.4f}")
print(f"MCC: {matthews_corrcoef(y_true, y_pred):.4f}")
print(f"Cohen Kappa: {cohen_kappa_score(y_true, y_pred):.4f}")
print(f"Brier Score: {brier_score_loss(y_true, y_probs):.4f}")


=== METRICS ===
Accuracy: 0.7297
Precision: 0.7931
Recall: 0.8519
Specificity: 0.4000
F1-score: 0.8214
Balanced Accuracy: 0.6259
ROC AUC: 0.6543
PR AUC: 0.8352
MCC: 0.2717
Cohen Kappa: 0.2688
Brier Score: 0.2116


## 10. KAN eploration

### KAN VISUALIZATION

In [22]:
# model.kan.plot()


### Feature visualization

In [23]:

# def plot_learned_function(model, X, feature_idx, n_points=200):
#     model.eval()

#     x_min = X[:, feature_idx].min()
#     x_max = X[:, feature_idx].max()
#     x_vals = torch.linspace(x_min, x_max, n_points)

#     X_base = X.mean(dim=0).repeat(n_points, 1)
#     X_base[:, feature_idx] = x_vals

#     with torch.no_grad():
#         y_vals = model(X_base).squeeze()

#     return x_vals.cpu().numpy(), y_vals.cpu().numpy()


In [24]:
# x, y = plot_learned_function(model, X_train, feature_idx=2)

# plt.plot(x, y)
# plt.xlabel("Feature 2 (scaled)")
# plt.ylabel("Model output")
# plt.title("Learned function for feature 2")
# plt.show()


### Importance sensitivity features

In [25]:
# def kan_feature_importance_sensitivity(model, X, eps=1e-2):
#     model.eval()
#     base_output = model(X).detach()

#     input_dim = X.shape[1]
#     importance = torch.zeros(input_dim)

#     for i in range(input_dim):
#         X_perturbed = X.clone()
#         X_perturbed[:, i] += eps

#         pert_output = model(X_perturbed).detach()
#         importance[i] = torch.mean(torch.abs(pert_output - base_output))

#     importance = importance / importance.sum()
#     return importance
# importance = kan_feature_importance_sensitivity(
#     model,
#     X_train
# )

# for name, score in zip(X_train_df.columns, importance):
#     print(f"{name}: {score:.3f}")


In [26]:
# # ============================
# # SIMPLIFY KAN (PRUNING)
# # ============================
# pruned_model = HybridKANQuantumModel(X_train.shape[1])
# pruned_model.load_state_dict(best_state)
# pruned_model.eval()

# with torch.no_grad():
#     _ = pruned_model(X_train)
# pruned_model.kan.prune()
# pruned_model.kan.plot()

# plt.savefig(
#     "kan_structure.png",
#     dpi=1000,                 # chuẩn thesis / paper
#     bbox_inches="tight"
# )
# plt.close()

In [27]:
# x, y = plot_learned_function(model, X_train, feature_idx=2)

# plt.plot(x, y)
# plt.xlabel("Feature 2 (scaled)")
# plt.ylabel("Model output")
# plt.title("Learned function for feature 2")
# plt.show()
