## Imports

In [1]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
import pandas as pd
import numpy as np
import lightgbm as lgb
import shap
from params import (DATASET_PATH, DATASET_NAME, TARGET_COL, TEST_SIZE, VAL_SIZE, METRIC)

# Load and split data
df = pd.read_csv(DATASET_PATH, encoding='utf-8')
y = df[TARGET_COL]
X = df.drop(columns=[TARGET_COL])

X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=VAL_SIZE, random_state=42)
print(f"Dataset: {len(X)} samples, {len(X.columns)} features")

Dataset: 12259 samples, 67 features


## Fit initial model

In [2]:
params = {
    'n_estimators': 2000,
    'max_depth': 3,
    'learning_rate': 0.1,
    'reg_lambda': 20,
    'random_seed': 42,
    'early_stopping_round': 10,
    'subsample': 0.2,
    'use_quantized_grad': True,
    'force_col_wise': True,
    'n_jobs': -1,
    'verbose': -1
}
model = lgb.LGBMClassifier(**params)
model.fit(X_train, y_train, eval_set=[(X_val, y_val)])
val_pred = model.predict(X_val) if METRIC == 'accuracy' else model.predict_proba(X_val)[:, 1]
val_score = accuracy_score(y_val, val_pred) if METRIC == 'accuracy' else roc_auc_score(y_val, val_pred)
print(f"Validation Score ({METRIC.upper()}): {val_score:.4f}")

Validation Score (AUC): 0.8645


## Calculate SHAP feature contributions and feature interactions

In [3]:
feature_names = list(X_train.columns)
explainer = shap.TreeExplainer(model=model)
individual_interaction_values = shap.TreeExplainer(model).shap_interaction_values(X_train)
n_features = len(feature_names)
interaction_values = np.abs(individual_interaction_values).mean(axis=0)
interaction_pairs = []
for i in range(n_features):
    for j in range(i+1, n_features):
        interaction_strength = interaction_values[i, j]
        interaction_pairs.append((feature_names[i], feature_names[j], interaction_strength))

interaction_pairs.sort(key=lambda x: x[2], reverse=True)
top_interactions_str = "\n".join([f"- {f1} <-> {f2} (strength: {strength:.4f})" for f1, f2, strength in interaction_pairs[:100]])
print(top_interactions_str)

- Aniongap <-> GCS (strength: 0.0205)
- Sodiumserum <-> FoleymL_sum (strength: 0.0179)
- LacticAcid <-> Aniongap (strength: 0.0176)
- MotorResponse <-> anchor_age (strength: 0.0150)
- EyeOpening <-> FoleymL_sum (strength: 0.0148)
- MotorResponse <-> Aniongap (strength: 0.0141)
- NBPsmmHg <-> EyeOpening (strength: 0.0128)
- FiO2 <-> Potassiumserum (strength: 0.0122)
- FoleymL_sum <-> GCS (strength: 0.0120)
- Glucoseserum <-> GCS (strength: 0.0111)
- Aniongap <-> FoleymL_sum (strength: 0.0110)
- LacticAcid <-> FoleymL_sum (strength: 0.0104)
- anchor_age <-> GCS (strength: 0.0102)
- EyeOpening <-> LacticAcid (strength: 0.0102)
- Glucoseserum <-> Sodiumserum (strength: 0.0101)
- INR <-> GCS (strength: 0.0100)
- MotorResponse <-> FoleymL_sum (strength: 0.0098)
- MotorResponse <-> TotalPEEPLevelcmH2O (strength: 0.0098)
- FiO2 <-> FoleymL_sum (strength: 0.0090)
- FoleymL_sum <-> BUNCreatinine (strength: 0.0088)
- INR <-> eGFR (strength: 0.0087)
- Hematocritserum <-> EyeOpening (strength: 0.00

## Ask LLM for interaction constraints and feature removal based on SHAP and domain knowledge

In [4]:
import csv
from agno.agent import Agent
from typing import List, Optional
from pydantic import BaseModel, Field
from params import DATASET_PATH, DATASET_NAME, TARGET_COL, LLM_MODEL

class DiseaseMechanism(BaseModel):
    name: str = Field(..., description="Name of the disease mechanism")
    features: List[str] = Field(..., description="List of features that interact and contribute to the disease mechanism")
    rationale: str = Field(..., description="Rationale for why these features interact, based on disease mechanisms.")

class DiseaseMechanisms(BaseModel):
    disease_mechanisms: List[DiseaseMechanism] = Field(..., description="List of disease mechanisms")


with open(DATASET_PATH, "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    header = next(reader)
    feature_names = [col for col in header if col.lower() != TARGET_COL.lower()]
with open(f"dataset_info/{DATASET_NAME.lower()}_info.txt", 'r') as f:
    dataset_description = f.read().strip()

# Create agent
agent = Agent(
    model=LLM_MODEL, 
    response_model=DiseaseMechanisms
)

prompt = f"""You are optimizing an ML model for {TARGET_COL} prediction through suggesting shared disease mechanisms and underlying feature interactions.
Your task is to suggest a list of disease mechanisms and the features that interact to cause the disease. This should help in understanding underlying disease mechanisms and improve model performance.
Your decision should be based on the SHAP analysis results and domain knowledge.

Dataset Description: {dataset_description}
Available Features: {feature_names}
Top Feature Interactions (from SHAP interaction values): {top_interactions_str}

Instructions
- Create 1-10 disease mechanisms depending on the dataset, SHAP analysis results and domain knowledge.
- Assign 2-20 features to each disease mechanism based on the SHAP analysis results and domain knowledge.
- Make sure to use the exact feature names from the available features list.
- Provide a rationale for why you chose the features for each disease mechanism.
"""

response = agent.run(prompt)


In [5]:
response.content.disease_mechanisms

[DiseaseMechanism(name='Neurological Impairment and Brain Dysfunction', features=['GCS', 'EyeOpening', 'VerbalResponse', 'MotorResponse', 'anchor_age', 'LacticAcid', 'INR', 'BUNCreatinine', 'TotalBilirubin', 'Hematocritserum'], rationale='Neurological status indicators (GCS components, motor and eye responses) reflect brain function and neurological impairment. Age influences neurological resilience. Elevated LacticAcid indicates hypoxia or metabolic disturbances affecting brain function. INR, BUNCreatinine, TotalBilirubin, and Hematocritserum are markers of systemic organ function that can impact neurological outcomes.'),
 DiseaseMechanism(name='Hemodynamic Instability and Circulatory Failure', features=['HRbpm', 'ABPmmmHg', 'NBPsmmHg', 'Norepinephrinemg_avg', 'Epinephrinemg_avg', 'Dobutaminemg_avg', 'TotalPEEPLevelcmH2O', 'CompliancecmH2OLseconds', 'anchor_age', 'CVPmmHg'], rationale='Vital signs such as heart rate and blood pressure, along with vasopressor usage (Norepinephrinemg_av

## Graph optimization with LLM suggested disease mechanisms

In [6]:
def create_interaction_constraints(interaction_values, threshold):
    filtered_interactions = []
    for i in range(len(interaction_values)):
        for j in range(i+1, len(interaction_values)):
            if interaction_values[i, j] > threshold:
                filtered_interactions.append((i, j))
    return filtered_interactions


In [7]:
interactions = create_interaction_constraints(interaction_values, 0.001)

In [8]:
interactions

[(0, 1),
 (0, 4),
 (0, 7),
 (0, 10),
 (0, 13),
 (0, 15),
 (0, 18),
 (0, 21),
 (0, 23),
 (0, 24),
 (0, 26),
 (0, 29),
 (0, 30),
 (0, 33),
 (0, 39),
 (0, 40),
 (0, 43),
 (0, 48),
 (1, 3),
 (1, 4),
 (1, 6),
 (1, 13),
 (1, 17),
 (1, 18),
 (1, 19),
 (1, 21),
 (1, 30),
 (1, 31),
 (1, 33),
 (1, 34),
 (1, 39),
 (1, 43),
 (2, 4),
 (2, 6),
 (2, 16),
 (2, 18),
 (2, 20),
 (2, 21),
 (2, 23),
 (2, 24),
 (2, 30),
 (2, 33),
 (2, 34),
 (2, 39),
 (2, 61),
 (2, 62),
 (3, 4),
 (3, 10),
 (3, 11),
 (3, 12),
 (3, 14),
 (3, 15),
 (3, 17),
 (3, 18),
 (3, 19),
 (3, 23),
 (3, 24),
 (3, 27),
 (3, 28),
 (3, 34),
 (3, 39),
 (3, 40),
 (3, 41),
 (3, 43),
 (3, 44),
 (3, 61),
 (3, 62),
 (3, 63),
 (3, 66),
 (4, 6),
 (4, 7),
 (4, 9),
 (4, 10),
 (4, 11),
 (4, 13),
 (4, 15),
 (4, 16),
 (4, 17),
 (4, 18),
 (4, 20),
 (4, 23),
 (4, 24),
 (4, 26),
 (4, 35),
 (4, 40),
 (4, 41),
 (4, 43),
 (4, 44),
 (4, 48),
 (4, 61),
 (5, 9),
 (5, 11),
 (5, 16),
 (5, 18),
 (5, 20),
 (5, 32),
 (5, 42),
 (5, 48),
 (5, 54),
 (6, 7),
 (6, 10),
 (6,

In [9]:
features = set()
for interaction in interactions:
    features.add(interaction[0])
interaction_constraints = {feature: [] for feature in features}


In [13]:
interaction_constraints = {feature: [] for feature in features}
for interaction in interactions:
    if len(interaction_constraints[interaction[0]]) == 0:
        interaction_constraints[interaction[0]] = [interaction[1]]
    else:
        interaction_constraints[interaction[0]].append(interaction[1])

interaction_constraints = [[key] + value for key, value in interaction_constraints.items()]
print(interaction_constraints)

[[0, 1, 4, 7, 10, 13, 15, 18, 21, 23, 24, 26, 29, 30, 33, 39, 40, 43, 48], [1, 3, 4, 6, 13, 17, 18, 19, 21, 30, 31, 33, 34, 39, 43], [2, 4, 6, 16, 18, 20, 21, 23, 24, 30, 33, 34, 39, 61, 62], [3, 4, 10, 11, 12, 14, 15, 17, 18, 19, 23, 24, 27, 28, 34, 39, 40, 41, 43, 44, 61, 62, 63, 66], [4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 20, 23, 24, 26, 35, 40, 41, 43, 44, 48, 61], [5, 9, 11, 16, 18, 20, 32, 42, 48, 54], [6, 7, 10, 11, 18, 20, 21, 24, 26, 28, 29, 34, 39, 43, 49], [7, 8, 15, 17, 20, 30, 34, 43], [8, 10, 13, 14, 24, 27, 32, 34, 41, 43, 49, 60, 66], [9, 19, 39, 55], [10, 11, 15, 17, 18, 19, 20, 25, 28, 30, 31, 33, 39, 40, 43, 48, 51, 61, 62], [11, 14, 15, 16, 17, 18, 23, 24, 27, 29, 30, 33, 34, 39, 49, 53, 54, 62, 63, 66], [12, 16, 20, 21, 23, 24, 26, 33, 39, 40, 41, 43, 44, 53, 54, 62], [13, 14, 28, 35, 39, 49, 60], [14, 15, 17, 24, 35, 39, 40, 51, 63, 66], [15, 16, 17, 20, 23, 24, 32, 33, 40, 43, 48, 51, 62, 65], [16, 17, 18, 20, 23, 26, 32, 33, 34, 43, 48, 54, 61, 66], [17, 18, 1

In [15]:
num_edges = 0
for interaction in interactions:
    num_edges += 1

print(num_edges)

400


In [11]:
model.set_params(interaction_constraints=interaction_constraints)
model.fit(X_train, y_train, eval_set=[(X_val, y_val)])
val_pred = model.predict(X_val) if METRIC == 'accuracy' else model.predict_proba(X_val)[:, 1]
val_score = accuracy_score(y_val, val_pred) if METRIC == 'accuracy' else roc_auc_score(y_val, val_pred)
print(f"Validation Score ({METRIC.upper()}): {val_score:.4f}")

Validation Score (AUC): 0.8602


In [140]:
# Get SHAP feature importance for feature removal
shap_values = explainer.shap_values(X_train)
if isinstance(shap_values, list):
    shap_values = shap_values[1]
base_feature_importance = np.abs(shap_values).mean(axis=0)

# Apply disease mechanism weights to interaction values
disease_mechanism_weight = 0.001
adjusted_interaction_values = interaction_values.copy()

for mechanism in response.content.disease_mechanisms:
    mechanism_features = mechanism.features
    mechanism_indices = []
    
    # Get indices for features in this mechanism
    for feature_name in mechanism_features:
        if feature_name in feature_names:
            mechanism_indices.append(feature_names.index(feature_name))
    
    # Add weight to all pairwise interactions within this mechanism
    for i in range(len(mechanism_indices)):
        for j in range(i+1, len(mechanism_indices)):
            idx1, idx2 = mechanism_indices[i], mechanism_indices[j]
            adjusted_interaction_values[idx1, idx2] += disease_mechanism_weight
            adjusted_interaction_values[idx2, idx1] += disease_mechanism_weight

# Calculate adjusted feature importance by summing interaction strengths
adjusted_feature_importance = base_feature_importance.copy()
for i in range(len(feature_names)):
    interaction_boost = np.sum(adjusted_interaction_values[i, :]) - np.sum(interaction_values[i, :])
    adjusted_feature_importance[i] += interaction_boost

print(f"Base importance range: {base_feature_importance.min():.6f} - {base_feature_importance.max():.6f}")
print(f"Adjusted importance range: {adjusted_feature_importance.min():.6f} - {adjusted_feature_importance.max():.6f}")

# Identify low importance features to remove (threshold-based)
importance_threshold = 0.0001
features_to_remove = [i for i, imp in enumerate(adjusted_feature_importance) if imp < importance_threshold]
features_to_keep = [i for i in range(len(feature_names)) if i not in features_to_remove]

print(f"Removing {len(features_to_remove)} low-importance features")
print(f"Keeping {len(features_to_keep)} features")

# Create mapping from old indices to new indices
old_to_new_idx = {}
for new_idx, old_idx in enumerate(features_to_keep):
    old_to_new_idx[old_idx] = new_idx

# Create updated interaction constraints using adjusted values
adjusted_interactions = create_interaction_constraints(adjusted_interaction_values, 0.0)
adjusted_features = set()
for interaction in adjusted_interactions:
    adjusted_features.add(interaction[0])

adjusted_interaction_constraints = {feature: [] for feature in adjusted_features}
for interaction in adjusted_interactions:
    if len(adjusted_interaction_constraints[interaction[0]]) == 0:
        adjusted_interaction_constraints[interaction[0]] = [interaction[1]]
    else:
        adjusted_interaction_constraints[interaction[0]].append(interaction[1])

adjusted_interaction_constraints = [[key] + value for key, value in adjusted_interaction_constraints.items()]

# Update interaction constraints with new indices after feature removal
updated_interaction_constraints = []
for constraint_group in adjusted_interaction_constraints:
    # Filter out removed features and update indices
    updated_group = []
    for old_idx in constraint_group:
        if old_idx in old_to_new_idx:
            updated_group.append(old_to_new_idx[old_idx])
    
    # Only keep constraint groups with at least 2 features
    if len(updated_group) >= 2:
        updated_interaction_constraints.append(updated_group)

print(f"Updated interaction constraints: {updated_interaction_constraints}")

# Apply feature removal to datasets
X_train_filtered = X_train.iloc[:, features_to_keep]
X_val_filtered = X_val.iloc[:, features_to_keep]
X_test_filtered = X_test.iloc[:, features_to_keep]

print(f"Original features: {X_train.shape[1]}, Filtered features: {X_train_filtered.shape[1]}")


Base importance range: 0.000000 - 0.298347
Adjusted importance range: 0.000000 - 0.330347
Removing 6 low-importance features
Keeping 61 features
Updated interaction constraints: [[0, 1, 2, 3, 4, 7, 10, 12, 13, 14, 15, 16, 17, 18, 21, 23, 24, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 41, 43, 44, 45, 46, 54, 58], [1, 2, 3, 4, 6, 10, 11, 12, 13, 15, 17, 18, 19, 21, 26, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 41, 43, 44, 45, 46], [2, 3, 4, 6, 13, 14, 16, 17, 18, 20, 21, 23, 24, 26, 30, 33, 34, 36, 37, 43, 44, 45, 46, 48, 55, 56, 59, 60], [3, 4, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 41, 42, 43, 44, 45, 53, 55, 56, 57, 60], [4, 6, 7, 9, 10, 11, 13, 14, 15, 16, 17, 18, 20, 21, 23, 24, 26, 30, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 48, 55, 60], [5, 6, 7, 9, 11, 16, 17, 18, 19, 20, 21, 24, 30, 32, 34, 38, 40, 46, 52, 55, 57, 60], [6, 7, 8, 9, 10, 11, 15, 16, 18, 19, 20, 21, 24, 26, 28, 29, 30, 31, 34, 35, 37, 38, 39, 40, 41



In [141]:
# Train model with both feature removal and interaction constraints
model_filtered = lgb.LGBMClassifier(**params)
model_filtered.set_params(interaction_constraints=updated_interaction_constraints)
model_filtered.fit(X_train_filtered, y_train, eval_set=[(X_val_filtered, y_val)])

val_pred_filtered = model_filtered.predict(X_val_filtered) if METRIC == 'accuracy' else model_filtered.predict_proba(X_val_filtered)[:, 1]
val_score_filtered = accuracy_score(y_val, val_pred_filtered) if METRIC == 'accuracy' else roc_auc_score(y_val, val_pred_filtered)

print(f"Filtered Model Validation Score ({METRIC.upper()}): {val_score_filtered:.4f}")
print(f"Original Model Score: {val_score:.4f}")
print(f"Improvement: {val_score_filtered - val_score:.4f}")


Filtered Model Validation Score (AUC): 0.8645
Original Model Score: 0.8645
Improvement: 0.0000
