In [None]:
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, precision_recall_curve, roc_curve, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split, StratifiedKFold, ParameterGrid
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier

In [None]:
def fit_ctgan_synthesizer(data, epochs=300, verbose=False):
    """Fit a CTGAN synthesizer on ``data`` and return it.

    Args:
        data: DataFrame to learn from; its metadata is auto-detected.
        epochs: Number of training epochs passed to ``CTGANSynthesizer``.
        verbose: Forwarded to ``CTGANSynthesizer``.

    Returns:
        The fitted ``CTGANSynthesizer``.
    """
    table_metadata = SingleTableMetadata()
    table_metadata.detect_from_dataframe(data)

    ctgan = CTGANSynthesizer(
        metadata=table_metadata,
        epochs=epochs,
        verbose=verbose,
    )
    ctgan.fit(data)
    return ctgan

In [None]:
# Load the two raw spreadsheets, one per diagnosis (aniridia / albinism).
# The paths are relative, so both files must be in the working directory.
# NOTE(review): file names suggest ophthalmologist questionnaires - confirm provenance.
aniridia_df = pd.read_excel("2025-04-11 Aniridiia oftal'molog.xlsx")
albinism_df = pd.read_excel("2025-04-11 Al'binizm oftal'molog.xlsx")


In [None]:
def parse_age_to_years(age_str):
    """Convert a free-text age into a number of years.

    Recognised inputs (case-insensitive, ``,`` accepted as decimal separator):
      * a bare number, e.g. ``7`` or ``"7.5"`` -> years;
      * a value with a year unit (``год`` / ``года`` / ``лет``), optionally
        followed by a month count, e.g. ``"1 год 6 мес"`` -> 1.5;
      * a value with only a month unit (``мес`` / ``mes``), e.g. ``"6 мес"`` -> 0.5.

    Args:
        age_str: Raw cell value (string, number or missing).

    Returns:
        Age in years as ``float``, or ``np.nan`` when the value is missing or
        cannot be interpreted.
    """
    if pd.isna(age_str):
        return np.nan
    s = str(age_str).lower().replace(',', '.').strip()

    has_years = 'год' in s or 'лет' in s      # 'года' is covered by 'год'
    has_months = 'мес' in s or 'mes' in s
    if has_years or has_months:
        # Extract the numbers separately instead of concatenating all digits:
        # "1 год 6 мес" must yield [1, 6], not 16.
        numbers = [float(tok) for tok in re.findall(r'\d+(?:\.\d+)?', s)]
        if not numbers:
            return np.nan
        if has_years:
            # First number is years; an optional second number is months.
            years = numbers[0]
            months = numbers[1] if len(numbers) > 1 else 0.0
        else:
            years, months = 0.0, numbers[0]
        return years + months / 12.0

    try:
        return float(s)
    except ValueError:
        return np.nan

In [None]:
def _find_column(df, *needles, lowercase=True):
    """Return the first column whose name contains every needle, else None."""
    for col in df.columns:
        name = str(col).lower() if lowercase else str(col)
        if all(needle in name for needle in needles):
            return col
    return None


def _find_age_column(df):
    """Return the age column: exact known names first, then any header containing 'Возраст'."""
    for exact in ('Возраст пациента', 'Возраст'):
        if exact in df.columns:
            return exact
    return _find_column(df, 'Возраст', lowercase=False)


def _yes_no_flags(df, col):
    """Map a 'Да'/'Нет' column to 1/0 (other values become NaN); NaN when the column is absent."""
    if col is None:
        return np.nan
    return df[col].map({'Да': 1, 'Нет': 0})


def preprocess_dataset(df, disease_type):
    """
    Preprocess an aniridia or albinism dataframe into a common feature table.

    Columns are located by (mostly case-insensitive) substrings of their
    headers; a feature whose source column is missing is filled with NaN
    (Photophobia defaults to 0 for albinism).

    Args:
        df: Raw questionnaire dataframe.
        disease_type: ``'aniridia'`` or ``'albinism'``; selects which syndrome
            keywords are searched for and is stored in ``DiseaseType``.

    Returns:
        DataFrame indexed like ``df`` with columns Age, Sex_Male, Nystagmus,
        Photophobia, Cataract, GeneticTestDone, UsesDevice, Glaucoma,
        SyndromeLabel ('None' / 'WAGR' / 'HPS') and DiseaseType.
    """
    # Anchor the index up front so scalar defaults broadcast to every row.
    result = pd.DataFrame(index=df.index)

    # Age
    age_col = _find_age_column(df)
    if age_col is not None:
        result['Age'] = df[age_col].apply(parse_age_to_years)
    else:
        result['Age'] = np.nan

    # Sex (Male = 1, Female = 0)
    if 'Пол' in df.columns:
        result['Sex_Male'] = df['Пол'].map({'М': 1, 'Ж': 0})
    else:
        result['Sex_Male'] = np.nan

    # Nystagmus (Yes=1, No=0)
    result['Nystagmus'] = _yes_no_flags(df, _find_column(df, 'нистагм'))

    # Photophobia (light sensitivity); assumed absent for albinism when not recorded
    photo_col = _find_column(df, 'светобоязнь')
    if photo_col is None and disease_type == 'albinism':
        result['Photophobia'] = 0
    else:
        result['Photophobia'] = _yes_no_flags(df, photo_col)

    # Cataract (Yes=1, No=0)
    result['Cataract'] = _yes_no_flags(df, _find_column(df, 'катаракта', 'пациента'))

    # Genetic test done (Yes=1, anything else incl. missing = 0).
    # The header match is case-sensitive, as the header itself contains 'Да'.
    gen_col = _find_column(df, 'молекулярно', 'Да', lowercase=False)
    if gen_col is not None:
        result['GeneticTestDone'] = (df[gen_col] == 'Да').astype(int)
    else:
        result['GeneticTestDone'] = np.nan

    # Uses assistive device (Yes=1, No=0)
    result['UsesDevice'] = _yes_no_flags(df, _find_column(df, 'реабилитац'))

    # Glaucoma (Yes=1, No=0)
    result['Glaucoma'] = _yes_no_flags(df, _find_column(df, 'глаукома', 'пациента'))

    # Syndrome label, derived from keywords in the free-text notes column.
    # The column lookup is row-independent, so it is done once.
    notes_col = next(
        (col for col in df.columns
         if ('екомендации' in str(col)) or ('заметки' in str(col)) or ('рекомендации' in str(col).lower())),
        None,
    )
    syndrome_keywords = {
        'aniridia': ('WAGR', ('wagr', 'вагр')),
        'albinism': ('HPS', ('hps', 'hermansky', 'гепат', 'пудлак')),
    }
    positive_label, keywords = syndrome_keywords.get(disease_type, ('None', ()))
    if notes_col is None:
        notes = [''] * len(df)
    else:
        notes = [str(value).lower() for value in df[notes_col]]
    result['SyndromeLabel'] = [
        positive_label if any(keyword in text for keyword in keywords) else 'None'
        for text in notes
    ]

    result['DiseaseType'] = disease_type
    return result

In [None]:
# Bring both raw spreadsheets into the shared feature schema.
proc_aniridia, proc_albinism = (
    preprocess_dataset(raw_frame, disease_type=disease)
    for raw_frame, disease in ((aniridia_df, 'aniridia'), (albinism_df, 'albinism'))
)

In [None]:
# Stack both cohorts; the per-cohort frames are kept for the disease-specific tasks.
combined_df = pd.concat([proc_aniridia, proc_albinism], ignore_index=True)

# Impute by assigning the result back to the frame. The previous
# `frame[col].fillna(..., inplace=True)` form is chained assignment: it warns on
# pandas 2.x and silently leaves the frame unchanged under Copy-on-Write.
# Glaucoma / UsesDevice are deliberately left untouched: they are targets and
# rows without a label are masked out later.
zero_filled_features = ['Sex_Male', 'Nystagmus', 'Photophobia', 'Cataract', 'GeneticTestDone']
for frame in (combined_df, proc_aniridia, proc_albinism):
    # Each frame uses its own median age.
    frame['Age'] = frame['Age'].fillna(frame['Age'].median())
    for feature in zero_filled_features:
        frame[feature] = frame[feature].fillna(0)

In [None]:
# Feature sets per prediction task.
glaucoma_features = ['Age', 'Sex_Male', 'Nystagmus', 'Cataract', 'GeneticTestDone']
device_features   = ['Age', 'Sex_Male', 'Nystagmus', 'Photophobia', 'GeneticTestDone']
syndrome_features = ['Age', 'Sex_Male', 'Nystagmus', 'Photophobia', 'Cataract', 'GeneticTestDone', 'DiseaseType']

# Encode the cohort as 0/1. The identity entries keep this cell idempotent:
# re-running it on an already encoded column would otherwise map 0/1 to NaN.
combined_df['DiseaseType'] = combined_df['DiseaseType'].map(
    {'aniridia': 0, 'albinism': 1, 0: 0, 1: 1}
)

# Glaucoma task: aniridia patients with a recorded glaucoma status.
gl_mask = ~pd.isna(proc_aniridia['Glaucoma'])
X_glaucoma = proc_aniridia.loc[gl_mask, glaucoma_features].values
y_glaucoma = proc_aniridia.loc[gl_mask, 'Glaucoma'].astype(int).values

# Device task: albinism patients with a recorded device-use status.
dev_mask = ~pd.isna(proc_albinism['UsesDevice'])
X_device = proc_albinism.loc[dev_mask, device_features].values
y_device = proc_albinism.loc[dev_mask, 'UsesDevice'].astype(int).values

# Syndrome task: all patients, three classes (0=None, 1=WAGR, 2=HPS).
X_syndrome = combined_df[syndrome_features].values
y_syndrome = combined_df['SyndromeLabel'].map({'None':0, 'WAGR':1, 'HPS':2}).values

In [None]:
# Age distribution, one overlaid histogram per cohort.
fig_age, ax_age = plt.subplots(figsize=(8, 5))
for cohort_ages, cohort_name, cohort_colour in (
    (proc_aniridia['Age'], 'Aniridia', 'blue'),
    (proc_albinism['Age'], 'Albinism', 'orange'),
):
    sns.histplot(cohort_ages, label=cohort_name, color=cohort_colour,
                 kde=True, stat="density", ax=ax_age)
ax_age.legend()
ax_age.set_title("Age distribution by disease type")
ax_age.set_xlabel("Age (years)")
ax_age.set_ylabel("Density")
plt.show()

# Counts of each binary feature over the combined cohort (missing counted as 0).
binary_features = ['Nystagmus', 'Photophobia', 'Cataract', 'GeneticTestDone', 'UsesDevice']
fig, axes = plt.subplots(2, 3, figsize=(12,6))
axes = axes.flatten()
for panel, feat in zip(axes, binary_features):
    if feat in combined_df.columns:
        sns.countplot(x=combined_df[feat].fillna(0).astype(int), ax=panel)
        panel.set_title(f"{feat} (0=No, 1=Yes)")
        panel.set_xlabel("")
axes[-1].axis('off')  # five features on a 2x3 grid: hide the spare panel
plt.tight_layout()
plt.show()

# Class balance of the three prediction targets.
fig, axes = plt.subplots(1, 3, figsize=(15,4))

glaucoma_counts = [np.sum(y_glaucoma == cls) for cls in (0, 1)]
axes[0].bar(['No Glaucoma','Glaucoma'], glaucoma_counts)
axes[0].set_title("Glaucoma (Aniridia)")

device_counts = [np.sum(y_device == cls) for cls in (0, 1)]
axes[1].bar(['No Device','Uses Device'], device_counts)
axes[1].set_title("Device (Albinism)")

counts_synd = [np.sum(y_syndrome == cls) for cls in (0, 1, 2)]
axes[2].bar(['None','WAGR','HPS'], counts_synd)
axes[2].set_title("Syndrome classification")
plt.tight_layout()
plt.show()

# Hyperparameter grids consumed by train_and_evaluate (one list of candidates per parameter).
logreg_params = dict(
    C=[0.1, 1, 10],
    class_weight=[None, 'balanced'],
    max_iter=[1000],
)
rf_params = dict(
    n_estimators=[100],
    max_depth=[None, 5, 10],
    min_samples_leaf=[1, 2, 5],
)
xgb_params = dict(
    n_estimators=[100],
    max_depth=[3, 6],
    learning_rate=[0.1],
    use_label_encoder=[False],
    eval_metric=['logloss'],
)


In [None]:
# Augmentation settings applied to the training part of each CV fold.
_SYNDROME_MIN_CLASS_ROWS = 20   # top each rare syndrome class up to this many rows
_DEVICE_OVERSAMPLE_FACTOR = 2   # synthetic positives generated per real positive


def _append_synthetic_rows(train_df, label, n_synthetic_for, min_real_rows):
    """Append CTGAN samples of class ``label`` to ``train_df`` (which has a 'y' column).

    ``n_synthetic_for`` maps the number of real rows of that class to the number
    of synthetic rows wanted. Nothing is fitted when there are fewer than
    ``min_real_rows`` real rows or when no synthetic rows are needed.
    """
    real_rows = train_df[train_df['y'] == label].drop('y', axis=1)
    if len(real_rows) < min_real_rows:
        return train_df
    n_synthetic = n_synthetic_for(len(real_rows))
    if n_synthetic <= 0:
        # Class is already large enough: skip the expensive CTGAN fit entirely.
        return train_df
    synthetic = fit_ctgan_synthesizer(real_rows).sample(n_synthetic)
    if synthetic.empty:
        return train_df
    synthetic['y'] = label
    return pd.concat([train_df, synthetic], ignore_index=True)


def _augment_training_fold(X_train, y_train, model_type):
    """Return (X, y) for one training fold, oversampled with CTGAN where the task asks for it."""
    if model_type == 'multi':
        train_df = pd.DataFrame(X_train, columns=syndrome_features)
        train_df['y'] = y_train
        for label in (1, 2):  # 1 = WAGR, 2 = HPS
            train_df = _append_synthetic_rows(
                train_df, label,
                n_synthetic_for=lambda n_real: _SYNDROME_MIN_CLASS_ROWS - n_real,
                min_real_rows=2,
            )
    elif model_type == 'binary_imbalanced':
        train_df = pd.DataFrame(X_train, columns=device_features)
        train_df['y'] = y_train
        train_df = _append_synthetic_rows(
            train_df, 1,
            n_synthetic_for=lambda n_real: _DEVICE_OVERSAMPLE_FACTOR * n_real,
            min_real_rows=1,
        )
    else:
        return X_train, y_train
    return train_df.drop('y', axis=1).values, train_df['y'].astype(int).values


def _positive_class_scores(model, X_val):
    """Return a ranking score for the positive class from whatever the model exposes."""
    if hasattr(model, "predict_proba"):
        return model.predict_proba(X_val)[:, 1]
    try:
        return model.decision_function(X_val)
    except AttributeError:
        return model.predict(X_val)


def train_and_evaluate(X, y, model_type='binary'):
    """
    Train Logistic Regression, Random Forest, XGBoost with CV grid search.

    For every estimator each combination of its grid (``logreg_params``,
    ``rf_params``, ``xgb_params``) is scored with 5-fold stratified CV:
    macro-F1 for ``model_type='multi'``, ROC-AUC otherwise. Training folds are
    oversampled with CTGAN for ``'multi'`` (classes 1 and 2) and
    ``'binary_imbalanced'`` (class 1); validation folds are never augmented.

    Args:
        X: 2-D numpy feature array.
        y: 1-D numpy label array.
        model_type: ``'binary'``, ``'binary_imbalanced'`` or ``'multi'``.

    Returns:
        ``(best_models, results)``: fitted estimators keyed by name, and per
        estimator a dict with ``'best_params'`` and ``'cv_score'``.

    Note:
        The returned estimators are refitted on ALL of ``X, y`` without
        augmentation, so scoring them on any subset of ``X`` is optimistic.
    """
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    # Augmentation does not depend on the estimator or its hyperparameters, so
    # build the folds once. This avoids refitting CTGAN per grid point and
    # scores every candidate on identical data.
    folds = []
    for train_idx, val_idx in skf.split(X, y):
        X_train, y_train = _augment_training_fold(X[train_idx], y[train_idx], model_type)
        folds.append((X_train, y_train, X[val_idx], y[val_idx]))

    # The XGBoost objective must match the task; a 3-class softprob model on a
    # binary target would train a spurious third class.
    if model_type == 'multi':
        xgb_model = XGBClassifier(objective='multi:softprob', num_class=3, random_state=42,
                                  use_label_encoder=False, eval_metric='mlogloss')
    else:
        xgb_model = XGBClassifier(objective='binary:logistic', random_state=42,
                                  use_label_encoder=False, eval_metric='logloss')

    best_models = {}
    results = {}
    models = [
        ('LogReg', LogisticRegression(solver='liblinear'), logreg_params),
        ('RandomForest', RandomForestClassifier(random_state=42), rf_params),
        ('XGBoost', xgb_model, xgb_params)
    ]
    for name, model, param_grid in models:
        best_score = -np.inf
        best_params = None
        for params in ParameterGrid(param_grid):
            # Fresh copy per grid point; the base estimator keeps its fixed settings.
            candidate = clone(model).set_params(**params)
            scores = []
            for X_train, y_train, X_val, y_val in folds:
                if model_type != 'multi' and len(np.unique(y_val)) < 2:
                    # ROC-AUC is undefined on a single-class validation fold.
                    continue
                candidate.fit(X_train, y_train)
                if model_type == 'multi':
                    score = f1_score(y_val, candidate.predict(X_val), average='macro')
                else:
                    score = roc_auc_score(y_val, _positive_class_scores(candidate, X_val))
                scores.append(score)
            avg_score = np.mean(scores) if scores else -np.inf
            if avg_score > best_score:
                best_score = avg_score
                best_params = params
        # Keep the base configuration (solver, random_state, objective, ...) and
        # overlay the winning grid point; fall back to the base estimator when
        # no fold could be scored.
        best_model = clone(model).set_params(**(best_params or {}))
        best_model.fit(X, y)
        best_models[name] = best_model
        results[name] = {'best_params': best_params, 'cv_score': best_score}
        print(f"{name} best CV score: {best_score:.3f} with params {best_params}")
    return best_models, results

In [None]:
# Model selection for the three tasks; each call prints the best CV score per estimator.
print("Training models for Glaucoma Prediction...")
best_models_glaucoma, cv_results_glaucoma = train_and_evaluate(
    X=X_glaucoma, y=y_glaucoma, model_type='binary'
)

print("\nTraining models for Device Need Prediction...")
best_models_device, cv_results_device = train_and_evaluate(
    X=X_device, y=y_device, model_type='binary_imbalanced'
)

print("\nTraining models for Syndromic Classification...")
best_models_syndrome, cv_results_syndrome = train_and_evaluate(
    X=X_syndrome, y=y_syndrome, model_type='multi'
)

In [None]:
# Stratified 80/20 hold-out split per task.
Xg_train, Xg_test, yg_train, yg_test = train_test_split(X_glaucoma, y_glaucoma, test_size=0.2, stratify=y_glaucoma, random_state=1)
Xd_train, Xd_test, yd_train, yd_test = train_test_split(X_device, y_device, test_size=0.2, stratify=y_device, random_state=1)
Xs_train, Xs_test, ys_train, ys_test = train_test_split(X_syndrome, y_syndrome, test_size=0.2, stratify=y_syndrome, random_state=1)

# The estimators returned by train_and_evaluate were fitted on the FULL data,
# which contains the hold-out rows. Refit unfitted copies (same hyperparameters)
# on the training portion only, so the numbers below are genuine test metrics.
model_g = clone(best_models_glaucoma['XGBoost']).fit(Xg_train, yg_train)
model_d = clone(best_models_device['XGBoost']).fit(Xd_train, yd_train)
model_s = clone(best_models_syndrome['XGBoost']).fit(Xs_train, ys_train)

y_proba_g = model_g.predict_proba(Xg_test)[:, 1]
y_pred_g = model_g.predict(Xg_test)
y_proba_d = model_d.predict_proba(Xd_test)[:, 1]
y_pred_d = model_d.predict(Xd_test)

for task, y_true, y_proba, y_pred in (
    ("Glaucoma", yg_test, y_proba_g, y_pred_g),
    ("Device", yd_test, y_proba_d, y_pred_d),
):
    print(f"\n{task} Test ROC-AUC:", roc_auc_score(y_true, y_proba))
    print(f"{task} Test PR-AUC:", average_precision_score(y_true, y_proba))
    print(f"{task} Test Macro-F1:", f1_score(y_true, y_pred, average='macro'))

y_pred_s = model_s.predict(Xs_test)
print("\nSyndrome Test Macro-F1:", f1_score(ys_test, y_pred_s, average='macro'))
print("Test set classification report (Syndrome):")
# Explicit labels keep the report aligned with target_names even when a rare
# class is absent from the hold-out set or from the predictions.
print(classification_report(ys_test, y_pred_s, labels=[0, 1, 2],
                            target_names=['None', 'WAGR', 'HPS'], zero_division=0))

In [None]:
# ROC curves of all three glaucoma models on the hold-out set.
plt.figure()
for name, mdl in best_models_glaucoma.items():
    # best_models_glaucoma were fitted on the full data (hold-out rows included);
    # refit an unfitted copy on the training split so the curve is not leaked.
    holdout_mdl = clone(mdl).fit(Xg_train, yg_train)
    if hasattr(holdout_mdl, "predict_proba"):
        y_score = holdout_mdl.predict_proba(Xg_test)[:, 1]
    else:
        try:
            y_score = holdout_mdl.decision_function(Xg_test)
        except AttributeError:
            y_score = holdout_mdl.predict(Xg_test)
    fpr, tpr, _ = roc_curve(yg_test, y_score)
    plt.plot(fpr, tpr, label=name)
plt.plot([0, 1], [0, 1], 'k--')  # chance level
plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
plt.title("ROC Curve - Glaucoma Prediction (Test Set)")
plt.legend()
plt.show()

# Precision-recall curve of the XGBoost device-need model on the hold-out set.
precision, recall, _ = precision_recall_curve(yd_test, y_proba_d)
fig_pr, ax_pr = plt.subplots()
ax_pr.plot(recall, precision, label="XGBoost")
ax_pr.set_xlabel("Recall")
ax_pr.set_ylabel("Precision")
ax_pr.set_title("Precision-Recall Curve - Device Need Prediction (Test Set)")
ax_pr.legend()
plt.show()

# Confusion matrices for the three tasks. `labels=` pins the matrix shape so it
# always matches the tick labels, even if a class is missing from the hold-out
# set and the predictions.
cm_g = confusion_matrix(yg_test, y_pred_g, labels=[0, 1])
cm_d = confusion_matrix(yd_test, y_pred_d, labels=[0, 1])
cm_s = confusion_matrix(ys_test, y_pred_s, labels=[0, 1, 2])
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

confusion_panels = (
    (cm_g, ["No Glaucoma", "Glaucoma"], "Glaucoma Confusion Matrix"),
    (cm_d, ["No Device", "Uses Device"], "Device Need Confusion Matrix"),
    (cm_s, ["None", "WAGR", "HPS"], "Syndrome Classification Confusion Matrix"),
)
for ax, (cm, class_names, title) in zip(axes, confusion_panels):
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax, cbar=False,
                xticklabels=class_names, yticklabels=class_names)
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
plt.tight_layout()
plt.show()


# XGBoost feature importances, one bar chart per task, sorted descending.
importances_g = model_g.feature_importances_
importances_d = model_d.feature_importances_
importances_s = model_s.feature_importances_

importance_panels = (
    (importances_g, glaucoma_features, None, "Feature Importances - Glaucoma (XGBoost)"),
    (importances_d, device_features, 'orange', "Feature Importances - Device Need (XGBoost)"),
    (importances_s, syndrome_features, 'green', "Feature Importances - Syndrome Classification (XGBoost)"),
)
for importances, feature_names, bar_colour, panel_title in importance_panels:
    order = np.argsort(importances)[::-1]
    positions = range(len(importances))
    plt.figure()
    plt.bar(positions, importances[order], color=bar_colour)  # None -> default colour
    plt.xticks(positions, [feature_names[i] for i in order], rotation=45)
    plt.ylabel("Importance")
    plt.title(panel_title)
    plt.show()

def _shap_values_for_class(values, class_idx):
    """Select one class's (n_samples, n_features) SHAP matrix.

    Handles both return layouts of ``TreeExplainer.shap_values`` for
    multi-class models: a list with one matrix per class, or a single
    (n_samples, n_features, n_classes) array.
    """
    if isinstance(values, list):
        return values[class_idx]
    values = np.asarray(values)
    if values.ndim == 3:
        return values[:, :, class_idx]
    return values


# SHAP summary for the glaucoma model.
explainer_g = shap.TreeExplainer(model_g)
shap_values_g = explainer_g.shap_values(X_glaucoma)
shap.summary_plot(shap_values_g, X_glaucoma, feature_names=glaucoma_features)

# Mean |SHAP| per feature for the two syndrome classes (1 = WAGR, 2 = HPS).
explainer_s = shap.TreeExplainer(model_s)
shap_values_s = explainer_s.shap_values(X_syndrome)
for class_idx in (1, 2):
    shap.summary_plot(_shap_values_for_class(shap_values_s, class_idx), X_syndrome,
                      feature_names=syndrome_features, plot_type="bar")