In [1]:
"""
This script:
1. Loads metadata (clinical or sample labels).
2. Loads CDR3 data from hundreds of parquet files generated by your CDR3 extraction algorithm (aho_6RF or ROAR).
3. Builds a feature matrix where:
each row = a sample
each column = one CDR3 sequence
values = counts of that sequence
4. Keeps only the "top N" most informative CDR3 sequences.
5. Runs machine-learning models (LogReg, RF, SVM, KNN, XGBoost).
6. Performs SMOTE to balance the classes.
7. Evaluates models (Balanced accuracy, Precision, Recall, F1, ROC-AUC).
8. Runs feature selection methods (ANOVA, Mutual Info, RFE).
9.  Finds the best k features and best selector.
10. Saes results to cache folders, so you don’t repeat expensive computations.

11. Uses logging to write everything to a log file
"""

'\nThis script:\n1. Loads metadata (clinical or sample labels).\n2. Loads CDR3 data from hundreds of parquet files generated by your CDR3 extraction algorithm (aho_6RF or ROAR).\n3. Builds a feature matrix where:\neach row = a sample\neach column = one CDR3 sequence\nvalues = counts of that sequence\n4. Keeps only the "top N" most informative CDR3 sequences.\n5. Runs machine-learning models (LogReg, RF, SVM, KNN, XGBoost).\n6. Performs SMOTE to balance the classes.\n7. Evaluates models (Balanced accuracy, Precision, Recall, F1, ROC-AUC).\n8. Runs feature selection methods (ANOVA, Mutual Info, RFE).\n9.  Finds the best k features and best selector.\n10. Saes results to cache folders, so you don’t repeat expensive computations.\n\n11. Uses logging to write everything to a log file\n'

In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ML pipeline for AIRRPORT CDR3 features with enhanced models and plots
Author: Linoy Menda
Date: 2025-12-02
"""

import pandas as pd
from pathlib import Path
import pickle
from concurrent.futures import ProcessPoolExecutor
from collections import Counter
from itertools import chain
from imblearn.over_sampling import SMOTE
import numpy as np
import logging
from datetime import datetime
from hashlib import sha1
import json
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (
    balanced_accuracy_score, roc_auc_score,
    precision_score, recall_score, f1_score,
    confusion_matrix
)
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from sklearn.feature_selection import SelectKBest, f_classif, mutual_info_classif, RFE

# --- Configuration ---
RESOURCE = "biopsy"  # biopsy or blood
ALGO_VERSION = 'aho_6RF'
RESOURCE = RESOURCE.lower()

BASE_PATH = Path("/dsi/efroni-lab/AIRRPORT/bulk_IBD")
RESOURCE_PATHS = {
    "biopsy": BASE_PATH / "biopsy/CDR3_match_results" / ALGO_VERSION,
    "blood": BASE_PATH / "blood/CDR3_match_results" / ALGO_VERSION,
}

if RESOURCE not in RESOURCE_PATHS:
    raise ValueError(f"Invalid RESOURCE: {RESOURCE}. Must be one of {list(RESOURCE_PATHS.keys())}.")
if ALGO_VERSION not in ['aho_6RF', 'roar']:
    raise ValueError(f"Invalid ALGO_VERSION: {ALGO_VERSION}. Must be 'aho_6RF' or 'roar'.")

ALGO_RESULTS_PATH = RESOURCE_PATHS[RESOURCE]
METADATA_FILE = ALGO_RESULTS_PATH / f"metadata_IBD{'_blood' if RESOURCE=='blood' else ''}.csv"

# --- Logging ---
logs_dir = Path.cwd() / "logs"
logs_dir.mkdir(parents=True, exist_ok=True)
log_file = logs_dir / f"ml_pipeline_{RESOURCE}_{ALGO_VERSION}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logger = logging.getLogger("airrport.ml_pipeline")
logger.setLevel(logging.INFO)
logger.propagate = False
if logger.handlers:
    for h in list(logger.handlers):
        logger.removeHandler(h)
file_handler = logging.FileHandler(log_file)
console_handler = logging.StreamHandler()
fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s","%Y-%m-%d %H:%M:%S")
file_handler.setFormatter(fmt)
console_handler.setFormatter(fmt)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
logger.info(f"Logging to: {log_file}")

# --- Cache ---
CACHE_ROOT = Path.cwd() / "cache"
CACHE_ROOT.mkdir(parents=True, exist_ok=True)

# --- Utility Functions ---
def _hash_bytes(b: bytes) -> str:
    return sha1(b).hexdigest()[:12]

def _dir_fingerprint(path: Path, pattern="*.parquet") -> str:
    files = sorted(path.glob(pattern))
    h = sha1()
    for f in files:
        try:
            stat = f.stat()
            h.update(f.name.encode())
            h.update(str(stat.st_size).encode())
            h.update(str(int(stat.st_mtime)).encode())
        except FileNotFoundError:
            continue
    return h.hexdigest()[:12]

def _file_fingerprint(f: Path) -> str:
    if not f.exists():
        return "no_meta"
    stat = f.stat()
    return _hash_bytes(f"{f.name}:{stat.st_size}:{int(stat.st_mtime)}".encode())

def data_fingerprint() -> str:
    p_fp = _dir_fingerprint(ALGO_RESULTS_PATH, "*.parquet")
    m_fp = _file_fingerprint(METADATA_FILE)
    return _hash_bytes(f"{RESOURCE}|{ALGO_VERSION}|{str(ALGO_RESULTS_PATH)}|{p_fp}|{m_fp}".encode())

def get_cache_dir(target_column: str) -> Path:
    fp = data_fingerprint()
    d = CACHE_ROOT / RESOURCE / ALGO_VERSION / target_column / fp
    d.mkdir(parents=True, exist_ok=True)
    return d

def save_pickle(p: Path, obj):
    with open(p, "wb") as f:
        pickle.dump(obj, f)

def load_pickle(p: Path):
    with open(p, "rb") as f:
        return pickle.load(f)

def save_json(p: Path, obj):
    with open(p, "w") as f:
        json.dump(obj, f, indent=2)

def try_load(p: Path):
    return p.exists(), (load_pickle(p) if p.suffix == ".pkl" else (pd.read_csv(p) if p.suffix == ".csv" else json.load(open(p)) if p.suffix == ".json" else None))

# --- Metadata / CDR3 Processing ---
def analysis_for_target(target_column):
    logger.info(f"Processing target column: {target_column}")
    metadata_df = pd.read_csv(METADATA_FILE)
    if target_column not in metadata_df.columns:
        logger.warning(f"Column '{target_column}' not found in metadata.")
        exit(1)
    if metadata_df[target_column].isna().sum() > len(metadata_df) * 0.5:
        logger.warning(f"Column '{target_column}' has >50% missing values.")
        exit(1)
    unique_vals = metadata_df[target_column].dropna().unique()
    if 1 < len(unique_vals) <= 4:
        logger.info(f"Target '{target_column}' has {len(unique_vals)} unique values: {unique_vals}")
    return metadata_df

def read_parquet_file(parquet_file_path):
    df = pd.read_parquet(parquet_file_path)
    sample_name = parquet_file_path.stem.replace('matched_', '').replace('.dedup', '')
    cdr3_counts = df.groupby('CDR3_match')['count'].sum().to_dict()
    return sample_name, cdr3_counts

def load_cdr3_data(target_column, use_cache=True, parallel=True):
    cache_dir = get_cache_dir(target_column)
    cache_file = cache_dir / "cdr3_data.pkl"
    if use_cache and cache_file.exists():
        logger.info(f"Loading cached CDR3 data from {cache_file}")
        return load_pickle(cache_file)

    parquet_files = list(ALGO_RESULTS_PATH.glob("*.parquet"))
    logger.info(f"Found {len(parquet_files)} parquet files")
    cdr3_data_from_all_parquets = {}

    if parallel:
        with ProcessPoolExecutor() as executor:
            for sample_name, cdr3_counts in executor.map(read_parquet_file, parquet_files):
                cdr3_data_from_all_parquets[sample_name] = cdr3_counts
    else:
        for i, parquet_file_path in enumerate(parquet_files, 1):
            sample_name, cdr3_counts = read_parquet_file(parquet_file_path)
            cdr3_data_from_all_parquets[sample_name] = cdr3_counts
            if i % 100 == 0:
                logger.info(f"Processed file {i}/{len(parquet_files)}")

    save_pickle(cache_file, cdr3_data_from_all_parquets)
    return cdr3_data_from_all_parquets

def create_feature_matrix_with_metdata(metadata, cdr3_data, top_n=600, target_column=None, use_cache=True):
    cache_dir = get_cache_dir(target_column)
    feat_cache = cache_dir / f"features_top{top_n}.pkl"
    if use_cache and feat_cache.exists():
        logger.info(f"Loading cached features from {feat_cache}")
        return load_pickle(feat_cache)

    rows = [(s, seq, c) for s, d in cdr3_data.items() for seq, c in d.items() if isinstance(seq,str) and seq.startswith('C')]
    df_long = pd.DataFrame(rows, columns=['sample','sequence','count'])
    seq_counts_across_samples = df_long.groupby('sequence')['sample'].nunique()
    top_sequences = seq_counts_across_samples.nlargest(top_n).index
    df_long = df_long[df_long['sequence'].isin(top_sequences)]
    x_matrix_feature = df_long.pivot_table(index='sample', columns='sequence', values='count', fill_value=0)

    metadata_filtered = metadata[metadata['Run'].isin(x_matrix_feature.index)].copy()
    metadata_filtered = metadata_filtered.set_index('Run').loc[x_matrix_feature.index]
    y_metadata_target = metadata_filtered[target_column].values

    save_pickle(feat_cache, (x_matrix_feature, y_metadata_target))
    logger.info(f"Saved feature cache to {feat_cache}")
    return x_matrix_feature, y_metadata_target

# --- Plotting helpers ---
def plot_model_bars(results_df: pd.DataFrame, cache_dir: Path, top_n: int):
    df = results_df.copy().sort_values(by="Balanced Accuracy", ascending=False)

    # --- Balanced Accuracy plot ---
    plt.figure(figsize=(8,4.5))
    palette = sns.color_palette("pastel", n_colors=len(df))
    bars = plt.bar(df['Model'], df['Balanced Accuracy'], color=palette)
    plt.xticks(rotation=45, ha='right')
    plt.ylim(0,1)
    plt.ylabel("Balanced Accuracy")
    plt.title(f"Balanced Accuracy by model (top{top_n} features)")

    # Add values on top of bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, height + 0.02, f"{height:.3f}",
                 ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig(cache_dir / f"balanced_accuracy_top{top_n}.png", dpi=150)
    plt.close()

    # --- ROC AUC plot ---
    if 'ROC AUC' in df.columns:
        plt.figure(figsize=(8,4.5))
        roc_values = df['ROC AUC'].fillna(0)
        bars = plt.bar(df['Model'], roc_values, color=palette)
        plt.xticks(rotation=45, ha='right')
        plt.ylim(0,1)
        plt.ylabel("ROC AUC (NaN=0)")
        plt.title(f"ROC AUC by model (top{top_n} features)")

        # Add ROC AUC values on top
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2, height + 0.02, f"{height:.3f}",
                     ha='center', va='bottom', fontsize=10)

        plt.tight_layout()
        plt.savefig(cache_dir / f"roc_auc_top{top_n}.png", dpi=150)
        plt.close()

def plot_confusion_matrix(y_true, y_pred, model_name, cache_dir: Path):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(5,4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(f"Confusion Matrix: {model_name}")
    plt.tight_layout()
    plt.savefig(cache_dir / f"confusion_matrix_{model_name.replace(' ','_')}.png", dpi=150)
    plt.close()

def plot_feature_importance(clf, feature_names, model_name, cache_dir: Path, top_n=30):
    if not hasattr(clf, 'feature_importances_'):
        return
    importances = clf.feature_importances_
    indices = np.argsort(importances)[::-1][:top_n]
    plt.figure(figsize=(8,5))
    plt.barh(range(len(indices)), importances[indices][::-1], color='skyblue')
    plt.yticks(range(len(indices)), [feature_names[i] for i in indices][::-1])
    plt.xlabel("Feature Importance")
    plt.title(f"Top {top_n} Feature Importances: {model_name}")
    plt.tight_layout()
    plt.savefig(cache_dir / f"feature_importance_{model_name.replace(' ','_')}.png", dpi=150)
    plt.close()

# --- Model Training & Evaluation ---
def model(x, y, target_column=None, top_n=None, save_results=True):
    if not np.issubdtype(np.array(y).dtype, np.number):
        le = LabelEncoder()
        y = le.fit_transform(y)

    smote = SMOTE(random_state=42)
    X_resampled, y_resampled = smote.fit_resample(x, y)

    X_train, X_test, y_train, y_test = train_test_split(
        X_resampled, y_resampled, test_size=0.2, random_state=42, stratify=y_resampled
    )

    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    models_dict = {
        "Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),
        "Random Forest": RandomForestClassifier(n_estimators=200, random_state=42),
        "SVM": SVC(probability=True, random_state=42),
        "KNN": KNeighborsClassifier(n_neighbors=5),
        "XGBoost": XGBClassifier(n_estimators=300, learning_rate=0.05, max_depth=4,
                                 subsample=0.8, colsample_bytree=0.8,
                                 eval_metric="logloss", random_state=42, use_label_encoder=False),
        "LightGBM": LGBMClassifier(n_estimators=300, learning_rate=0.05, random_state=42),
        "CatBoost": CatBoostClassifier(iterations=300, learning_rate=0.05,
                                       depth=4, verbose=0, random_seed=42)
    }

    results = []
    cache_dir = get_cache_dir(target_column)

    for name, clf in models_dict.items():
        clf.fit(X_train_scaled, y_train)
        y_pred = clf.predict(X_test_scaled)
        y_proba = clf.predict_proba(X_test_scaled)[:,1] if hasattr(clf,'predict_proba') else None
        avg = 'binary' if len(np.unique(y))==2 else 'weighted'
        metrics = {
            "Model": name,
            "Balanced Accuracy": balanced_accuracy_score(y_test, y_pred),
            "Precision": precision_score(y_test, y_pred, average=avg),
            "Recall": recall_score(y_test, y_pred, average=avg),
            "F1": f1_score(y_test, y_pred, average=avg),
            "ROC AUC": roc_auc_score(y_test, y_proba) if (y_proba is not None and len(np.unique(y))==2) else np.nan
        }
        results.append(metrics)

        # Extra plots
        plot_confusion_matrix(y_test, y_pred, name, cache_dir)
        if name in ["Random Forest","XGBoost","LightGBM","CatBoost"]:
            plot_feature_importance(clf, x.columns if hasattr(x,'columns') else np.arange(x.shape[1]), name, cache_dir, top_n=30)

    results_df = pd.DataFrame(results).sort_values(by="Balanced Accuracy", ascending=False)
    if save_results and target_column is not None and top_n is not None:
        results_df.to_csv(cache_dir / f"model_results_top{top_n}.csv", index=False)
        plot_model_bars(results_df, cache_dir, top_n)

    return X_train, X_test, y_train, y_test, results_df


2025-12-03 09:37:38 | INFO | Logging to: /home/dsi/linoym/airrport/Thesis/bulk_analysis/logs/ml_pipeline_biopsy_aho_6RF_20251203_093738.log


In [3]:
# --- Main execution ---
if __name__ == "__main__":
    # Define your target column and top_n features
    target_column = "ibd_clinicianmeasure_inactive_active"   # <-- change to the actual column in your metadata
    #demographics_gender
    #ibd_disease
    #ibd_clinicianmeasure_inactive_active
    top_n = 600

    # Load metadata
    metadata_df = analysis_for_target(target_column)

    # Load or compute CDR3 data
    cdr3_data = load_cdr3_data(target_column, use_cache=True, parallel=True)

    # Create feature matrix
    X, y = create_feature_matrix_with_metdata(metadata_df, cdr3_data, top_n=top_n, target_column=target_column)

    # Train models and generate plots
    X_train, X_test, y_train, y_test, results_df = model(X, y, target_column=target_column, top_n=top_n)

    logger.info("Pipeline completed successfully.")


2025-12-03 09:38:31 | INFO | Processing target column: ibd_clinicianmeasure_inactive_active
2025-12-03 09:38:31 | INFO | Target 'ibd_clinicianmeasure_inactive_active' has 2 unique values: ['Active' 'Inactive']
2025-12-03 09:38:32 | INFO | Found 2488 parquet files
2025-12-03 09:39:27 | INFO | Saved feature cache to /home/dsi/linoym/airrport/Thesis/bulk_analysis/cache/biopsy/aho_6RF/ibd_clinicianmeasure_inactive_active/355076c4ca2e/features_top600.pkl
Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)


[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.064965 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 153000
[LightGBM] [Info] Number of data points in the train set: 3662, number of used features: 600
[LightGBM] [Info] Start training from score -1.098339
[LightGBM] [Info] Start training from score -1.098339
[LightGBM] [Info] Start training from score -1.099159


2025-12-03 09:52:01 | INFO | Pipeline completed successfully.


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Improved ML pipeline for AIRRPORT CDR3 features
Author: Generated for Linoy Menda
Date: 2025-12-03

Key improvements included:
- Normalize CDR3 counts by sample depth
- Scale features before SMOTE (SMOTE after scaling)
- Stratified cross-validation (with CV metrics)
- Feature selection (SelectKBest with mutual information)
- Hyperparameter tuning (RandomizedSearchCV) for heavy models
- Ensemble VotingClassifier of tuned estimators
- SHAP explainability (optional)
- Better caching, logging and artifact saving

Usage: run as script. Edit CONFIG section to suit paths/targets.
"""

import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
from pathlib import Path
import pickle
import json
import logging
from datetime import datetime
from hashlib import sha1
from concurrent.futures import ProcessPoolExecutor
from collections import Counter
from itertools import chain

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import (
    train_test_split, StratifiedKFold, cross_val_score, RandomizedSearchCV
)
from sklearn.metrics import (
    balanced_accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
)
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from imblearn.over_sampling import SMOTE
from joblib import dump

import matplotlib.pyplot as plt
import seaborn as sns

# Optional: SHAP for explainability
try:
    import shap
    _HAS_SHAP = True
except Exception:
    _HAS_SHAP = False

# ---------------- CONFIG ----------------
RESOURCE = "biopsy"  # biopsy or blood
ALGO_VERSION = 'aho_6RF'
BASE_PATH = Path("/dsi/efroni-lab/AIRRPORT/bulk_IBD")
RESOURCE = RESOURCE.lower()
RESOURCE_PATHS = {
    "biopsy": BASE_PATH / "biopsy/CDR3_match_results" / ALGO_VERSION,
    "blood": BASE_PATH / "blood/CDR3_match_results" / ALGO_VERSION,
}
ALGO_RESULTS_PATH = RESOURCE_PATHS[RESOURCE]
METADATA_FILE = ALGO_RESULTS_PATH / f"metadata_IBD{'_blood' if RESOURCE=='blood' else ''}.csv"

# working dirs
WORKDIR = Path.cwd()
LOGS_DIR = WORKDIR / "logs"
CACHE_ROOT = WORKDIR / "cache"
ARTIFACTS_DIR = WORKDIR / "artifacts"
for d in (LOGS_DIR, CACHE_ROOT, ARTIFACTS_DIR):
    d.mkdir(parents=True, exist_ok=True)

# pipeline params (tweakable)
TOP_N = 600                 # number of top sequences to use as features
FEATURE_SELECT_K = 300      # number of features to keep with SelectKBest
CV_FOLDS = 5
N_JOBS = -1                 # for RandomizedSearchCV
RANDOM_STATE = 42

# ---------------- logging ----------------
log_file = LOGS_DIR / f"ml_pipeline_{RESOURCE}_{ALGO_VERSION}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logger = logging.getLogger("airrport.improved_pipeline")
logger.setLevel(logging.INFO)
if logger.handlers:
    for h in list(logger.handlers):
        logger.removeHandler(h)
fh = logging.FileHandler(log_file)
ch = logging.StreamHandler()
fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s","%Y-%m-%d %H:%M:%S")
fh.setFormatter(fmt); ch.setFormatter(fmt)
logger.addHandler(fh); logger.addHandler(ch)
logger.info(f"Logging to {log_file}")

# ---------------- utils ----------------

def _hash_bytes(b: bytes) -> str:
    return sha1(b).hexdigest()[:12]


def _dir_fingerprint(path: Path, pattern="*.parquet") -> str:
    files = sorted(path.glob(pattern))
    h = sha1()
    for f in files:
        try:
            stat = f.stat()
            h.update(f.name.encode())
            h.update(str(stat.st_size).encode())
            h.update(str(int(stat.st_mtime)).encode())
        except FileNotFoundError:
            continue
    return h.hexdigest()[:12]


def _file_fingerprint(f: Path) -> str:
    if not f.exists():
        return "no_meta"
    stat = f.stat()
    return _hash_bytes(f"{f.name}:{stat.st_size}:{int(stat.st_mtime)}".encode())


def data_fingerprint() -> str:
    p_fp = _dir_fingerprint(ALGO_RESULTS_PATH, "*.parquet")
    m_fp = _file_fingerprint(METADATA_FILE)
    return _hash_bytes(f"{RESOURCE}|{ALGO_VERSION}|{str(ALGO_RESULTS_PATH)}|{p_fp}|{m_fp}".encode())


def get_cache_dir(target_column: str) -> Path:
    fp = data_fingerprint()
    d = CACHE_ROOT / RESOURCE / ALGO_VERSION / target_column / fp
    d.mkdir(parents=True, exist_ok=True)
    return d


def save_pickle(p: Path, obj):
    with open(p, 'wb') as f:
        pickle.dump(obj, f)


def load_pickle(p: Path):
    with open(p, 'rb') as f:
        return pickle.load(f)

# ---------------- data loading ----------------

def analysis_for_target(target_column: str) -> pd.DataFrame:
    logger.info(f"Loading metadata from {METADATA_FILE}")
    metadata_df = pd.read_csv(METADATA_FILE)
    if target_column not in metadata_df.columns:
        logger.error(f"Column {target_column} not in metadata")
        raise KeyError(target_column)
    metadata_df = metadata_df.dropna(subset=[target_column])
    if metadata_df.shape[0] == 0:
        logger.error("No samples left after dropping missing target values")
        raise ValueError("No samples")
    logger.info(f"Metadata loaded: {metadata_df.shape[0]} samples with target {target_column}")
    return metadata_df


def read_parquet_file(parquet_file_path: Path):
    df = pd.read_parquet(parquet_file_path)
    sample_name = parquet_file_path.stem.replace('matched_', '').replace('.dedup', '')
    cdr3_counts = df.groupby('CDR3_match')['count'].sum().to_dict()
    return sample_name, cdr3_counts


def load_cdr3_data(target_column: str, use_cache=True, parallel=True) -> dict:
    cache_dir = get_cache_dir(target_column)
    cache_file = cache_dir / "cdr3_data.pkl"
    if use_cache and cache_file.exists():
        logger.info(f"Loading cached CDR3 data from {cache_file}")
        return load_pickle(cache_file)

    parquet_files = list(ALGO_RESULTS_PATH.glob("*.parquet"))
    logger.info(f"Found {len(parquet_files)} parquet files")
    cdr3_data = {}
    if parallel:
        with ProcessPoolExecutor() as ex:
            for sample_name, c in ex.map(read_parquet_file, parquet_files):
                cdr3_data[sample_name] = c
    else:
        for p in parquet_files:
            s, c = read_parquet_file(p)
            cdr3_data[s] = c
    save_pickle(cache_file, cdr3_data)
    logger.info(f"Saved cdr3 cache to {cache_file}")
    return cdr3_data

# ---------------- feature matrix ----------------

def create_feature_matrix_with_metadata(metadata: pd.DataFrame,
                                        cdr3_data: dict,
                                        top_n: int = TOP_N,
                                        target_column: str = None,
                                        use_cache=True):
    cache_dir = get_cache_dir(target_column)
    feat_cache = cache_dir / f"features_top{top_n}_norm.pkl"
    if use_cache and feat_cache.exists():
        logger.info(f"Loading features from {feat_cache}")
        return load_pickle(feat_cache)

    # long-form dataframe
    rows = [(s, seq, c) for s, d in cdr3_data.items() for seq, c in d.items() if isinstance(seq, str) and seq.startswith('C')]
    df_long = pd.DataFrame(rows, columns=['sample','sequence','count'])
    if df_long.empty:
        logger.error("No CDR3 rows found (maybe different column name)")
        raise ValueError("empty cdr3")

    # normalize counts by sample depth to remove sequencing depth bias
    total_per_sample = df_long.groupby('sample')['count'].sum().rename('total')
    df_long = df_long.join(total_per_sample, on='sample')
    df_long['count_norm'] = df_long['count'] / df_long['total']

    # pick top sequences by number of samples where they appear
    seq_counts_across_samples = df_long.groupby('sequence')['sample'].nunique()
    top_sequences = seq_counts_across_samples.nlargest(top_n).index
    df_long = df_long[df_long['sequence'].isin(top_sequences)]

    x_matrix = df_long.pivot_table(index='sample', columns='sequence', values='count_norm', fill_value=0)

    # align metadata
    metadata_filtered = metadata[metadata['Run'].isin(x_matrix.index)].copy()
    metadata_filtered = metadata_filtered.set_index('Run').loc[x_matrix.index]
    y = metadata_filtered[target_column].values

    save_pickle(feat_cache, (x_matrix, y))
    logger.info(f"Saved features to {feat_cache}")
    return x_matrix, y

# ---------------- plotting helpers ----------------

def plot_model_bars(results_df: pd.DataFrame, cache_dir: Path, top_n: int):
    df = results_df.copy().sort_values(by="Balanced Accuracy", ascending=False)
    plt.figure(figsize=(10,4.5))
    palette = sns.color_palette('pastel', n_colors=len(df))
    bars = plt.bar(df['Model'], df['Balanced Accuracy'], color=palette)
    plt.xticks(rotation=45, ha='right')
    plt.ylim(0,1)
    plt.ylabel('Balanced Accuracy')
    plt.title(f'Balanced Accuracy by model (top{top_n})')
    for bar in bars:
        h = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, h + 0.01, f"{h:.3f}", ha='center')
    plt.tight_layout(); plt.savefig(cache_dir / f"balanced_accuracy_top{top_n}.png", dpi=150); plt.close()


def plot_confusion_matrix(y_true, y_pred, model_name, cache_dir: Path):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(4,3))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted'); plt.ylabel('Actual'); plt.title(f'Confusion: {model_name}')
    plt.tight_layout(); plt.savefig(cache_dir / f"confusion_{model_name.replace(' ','_')}.png", dpi=150); plt.close()

# ---------------- modeling ----------------

def evaluate_model(clf, X, y, cv=CV_FOLDS, scoring='balanced_accuracy'):
    skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=RANDOM_STATE)
    scores = cross_val_score(clf, X, y, scoring=scoring, cv=skf, n_jobs=N_JOBS)
    return scores


def tune_model_randomized(clf, param_distributions, X, y, cv=CV_FOLDS, n_iter=25):
    rs = RandomizedSearchCV(
        clf,
        param_distributions=param_distributions,
        n_iter=n_iter,
        scoring='balanced_accuracy',
        cv=StratifiedKFold(n_splits=cv, shuffle=True, random_state=RANDOM_STATE),
        n_jobs=N_JOBS,
        random_state=RANDOM_STATE,
        verbose=0
    )
    rs.fit(X, y)
    return rs


def build_and_evaluate(X, y, target_column: str, feature_names=None, save_artifacts=True):
    cache_dir = get_cache_dir(target_column)

    # label encode target if necessary
    if not np.issubdtype(np.array(y).dtype, np.number):
        le = LabelEncoder(); y_enc = le.fit_transform(y)
    else:
        y_enc = y

    # split once for final hold-out evaluation (stratified)
    X_train_full, X_hold, y_train_full, y_hold = train_test_split(X, y_enc, test_size=0.2, stratify=y_enc, random_state=RANDOM_STATE)

    # scale BEFORE SMOTE
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_full)
    X_hold_scaled = scaler.transform(X_hold)

    # SMOTE on scaled training data
    sm = SMOTE(random_state=RANDOM_STATE)
    X_res, y_res = sm.fit_resample(X_train_scaled, y_train_full)
    logger.info(f"After SMOTE: {Counter(y_res)}")

    # Feature selection on resampled data
    selector = SelectKBest(mutual_info_classif, k=min(FEATURE_SELECT_K, X_res.shape[1]))
    X_res_sel = selector.fit_transform(X_res, y_res)
    X_hold_sel = selector.transform(X_hold_scaled)
    selected_features = None
    if feature_names is not None:
        selected_features = np.array(feature_names)[selector.get_support()]

    # simple candidate models
    candidates = {
        'logreg_l1': LogisticRegression(penalty='l1', solver='liblinear', max_iter=2000, random_state=RANDOM_STATE),
        'rf': RandomForestClassifier(n_estimators=200, random_state=RANDOM_STATE),
        'xgb': XGBClassifier(n_estimators=200, learning_rate=0.05, use_label_encoder=False, eval_metric='logloss', random_state=RANDOM_STATE),
        'lgb': LGBMClassifier(n_estimators=200, learning_rate=0.05, random_state=RANDOM_STATE)
    }

    # quick CV baseline
    results = []
    for name, clf in candidates.items():
        try:
            scores = evaluate_model(clf, X_res_sel, y_res, cv=CV_FOLDS)
            results.append({'Model': name, 'CV_BalAcc_Mean': scores.mean(), 'CV_BalAcc_STD': scores.std()})
            logger.info(f"{name}: CV balanced acc {scores.mean():.3f} ± {scores.std():.3f}")
        except Exception as e:
            logger.warning(f"Failed CV for {name}: {e}")

    results_df = pd.DataFrame(results).sort_values('CV_BalAcc_Mean', ascending=False)

    # Hyperparameter tuning for top candidates (rf and xgb)
    tuned = {}
    # Random Forest tuning
    rf_params = {
        'n_estimators': [200, 400, 600],
        'max_depth': [None, 5, 10],
        'min_samples_split': [2, 4, 8]
    }
    logger.info('Tuning RandomForest...')
    rf_search = tune_model_randomized(RandomForestClassifier(random_state=RANDOM_STATE), rf_params, X_res_sel, y_res, n_iter=8)
    tuned['rf'] = rf_search.best_estimator_
    logger.info(f"RF best: {rf_search.best_score_:.3f}, params: {rf_search.best_params_}")

    # XGBoost tuning
    xgb_params = {
        'n_estimators': [200, 400],
        'max_depth': [3, 4, 6],
        'learning_rate': [0.01, 0.05, 0.1],
        'subsample': [0.6, 0.8, 1.0]
    }
    logger.info('Tuning XGBoost...')
    xgb_search = tune_model_randomized(XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=RANDOM_STATE), xgb_params, X_res_sel, y_res, n_iter=8)
    tuned['xgb'] = xgb_search.best_estimator_
    logger.info(f"XGB best: {xgb_search.best_score_:.3f}, params: {xgb_search.best_params_}")

    # LightGBM quick tuning
    lgb_params = {
        'n_estimators': [200, 400],
        'num_leaves': [31, 63],
        'learning_rate': [0.01, 0.05]
    }
    logger.info('Tuning LightGBM...')
    lgb_search = tune_model_randomized(LGBMClassifier(random_state=RANDOM_STATE), lgb_params, X_res_sel, y_res, n_iter=6)
    tuned['lgb'] = lgb_search.best_estimator_
    logger.info(f"LGB best: {lgb_search.best_score_:.3f}, params: {lgb_search.best_params_}")

    # Build ensemble of tuned models
    estimators = [('rf', tuned['rf']), ('xgb', tuned['xgb']), ('lgb', tuned['lgb'])]
    ensemble = VotingClassifier(estimators=estimators, voting='soft')
    logger.info('Fitting ensemble on resampled & selected features...')
    ensemble.fit(X_res_sel, y_res)

    # Evaluate on holdout
    y_hold_pred = ensemble.predict(X_hold_sel)
    y_hold_proba = ensemble.predict_proba(X_hold_sel) if hasattr(ensemble, 'predict_proba') else None
    avg = 'binary' if len(np.unique(y_enc)) == 2 else 'weighted'
    metrics = {
        'Balanced Accuracy': balanced_accuracy_score(y_hold, y_hold_pred),
        'Precision': precision_score(y_hold, y_hold_pred, average=avg, zero_division=0),
        'Recall': recall_score(y_hold, y_hold_pred, average=avg, zero_division=0),
        'F1': f1_score(y_hold, y_hold_pred, average=avg, zero_division=0),
        'ROC AUC': roc_auc_score(y_hold, y_hold_proba[:,1]) if (y_hold_proba is not None and len(np.unique(y_enc))==2) else np.nan
    }
    logger.info(f"Holdout metrics: {metrics}")

    # save artifacts
    if save_artifacts:
        model_path = cache_dir / 'ensemble.joblib'
        dump(ensemble, model_path)
        dump(scaler, cache_dir / 'scaler.joblib')
        save_pickle(cache_dir / 'selected_features.pkl', selected_features.tolist() if selected_features is not None else None)
        results_df.to_csv(cache_dir / 'cv_baseline_results.csv', index=False)
        with open(cache_dir / 'holdout_metrics.json', 'w') as fh:
            json.dump(metrics, fh, indent=2)
        logger.info(f"Artifacts saved to {cache_dir}")

    # confusion plot
    plot_confusion_matrix(y_hold, y_hold_pred, 'ensemble_holdout', cache_dir)

    # SHAP explanation (tree-based explainer recommended)
    if _HAS_SHAP:
        try:
            logger.info('Computing SHAP values (this can take time)')
            explainer = shap.Explainer(tuned['xgb'])
            shap_values = explainer(X_hold_sel)
            shap.summary_plot(shap_values, features=X_hold_sel, feature_names=selected_features, show=False)
            plt.tight_layout(); plt.savefig(cache_dir / 'shap_summary.png', dpi=150); plt.close()
        except Exception as e:
            logger.warning(f"SHAP failed: {e}")

    # return objects useful for further analysis
    return {
        'ensemble': ensemble,
        'scaler': scaler,
        'selector': selector,
        'selected_features': selected_features,
        'holdout_metrics': metrics,
        'cv_baseline': results_df
    }

# ---------------- main ----------------
if __name__ == '__main__':
    target_column = 'demographics_gender'  # change as needed

    # load metadata
    metadata = analysis_for_target(target_column)

    # load cdr3 data
    cdr3_data = load_cdr3_data(target_column, use_cache=True, parallel=True)

    # create feature matrix (normalized counts)
    X_df, y = create_feature_matrix_with_metadata(metadata, cdr3_data, top_n=TOP_N, target_column=target_column, use_cache=True)

    logger.info(f"Feature matrix shape: {X_df.shape}")

    # run modeling
    out = build_and_evaluate(X_df.values, y, target_column=target_column, feature_names=X_df.columns)

    logger.info('Pipeline finished')
    logger.info(f"Holdout metrics: {out['holdout_metrics']}")
