In [None]:
# `os` is imported here so this configuration cell runs on a fresh kernel
# (the only other `import os` in the notebook lives in a later cell).
import os

# Directory holding the per-chromosome baselineLD annotation files
# (read as `baselineLD.<chromosome>.annot.gz`).
BASELINELD_PATH = os.path.join("/well", "palamara", "projects", "S-LDSC_reference_files", "GRCh38", "baselineLD_v2.2")
# Directory holding the per-chromosome PLINK .bim files
# (read as `1000G.EUR.hg38.<chromosome>.bim`).
PLINK_PATH = os.path.join("/well", "palamara", "projects", "S-LDSC_reference_files", "GRCh38", "plink_files")
# Root directory of the TraitGym dataset parquet files.
TRAITGYM_PATH = os.path.join("/well", "palamara", "users", "nrw600", "contribution_prediction", "TraitGym", "results", "dataset")

In [None]:
import logging

def setup_logger(seed):
    """Configure root logging at INFO level and report the seed, if any.

    Args:
        seed: Value written to the log as ``Seed: <seed>``. Nothing is logged
            when it is ``None``. This function only logs the value; it does
            not seed any random number generator.

    Returns:
        None.
    """
    log_settings = {
        "format": "%(asctime)s %(message)s",
        "encoding": "utf-8",
        "level": logging.INFO,
    }
    logging.basicConfig(**log_settings)

    if seed is None:
        return
    logging.info(f"Seed: {seed}")

In [None]:
import logging
import os

import pandas as pd


def load_baselineLD_annotations(file_path):
    """
    Read a gzip-compressed, tab-separated baseline LD annotation file.

    The row count, column names, first rows and dtypes of the table are
    written to the log at INFO level.

    Args:
        file_path: Path of the gzip-compressed TSV file to read.

    Returns:
        pandas.DataFrame with the parsed file contents.
    """
    annotations = pd.read_csv(file_path, compression="gzip", sep="\t")

    summary_lines = (
        f"Loaded {len(annotations)} rows from {file_path}",
        f"Columns in the dataframe: {annotations.columns.tolist()}",
        f"Dataframe head:\n{annotations.head()}",
        f"Data types:\n{annotations.dtypes}",
    )
    for line in summary_lines:
        logging.info(line)

    return annotations


def load_bim_file(file_path):
    """
    Read a headerless, tab-separated .bim file into a DataFrame.

    The six columns are named, in file order: ``chrom``, ``SNP``,
    ``genetic_dist``, ``pos``, ``ref`` and ``alt``.

    Args:
        file_path: Path of the .bim file to read.

    Returns:
        pandas.DataFrame with one row per line of the file.
    """
    # NOTE(review): columns 5 and 6 are labelled ref/alt here; confirm that
    # this matches the allele order used when these PLINK files were written.
    return pd.read_csv(
        file_path,
        sep="\t",
        header=None,
        names=["chrom", "SNP", "genetic_dist", "pos", "ref", "alt"],
    )


def merge_ld_bim(ld_df, bim_df):
    """
    Inner-join the LD annotation table with the BIM table on ``SNP``.

    Args:
        ld_df: DataFrame of LD annotations; must contain a ``SNP`` column.
        bim_df: DataFrame of BIM records; must contain a ``SNP`` column.

    Returns:
        pandas.DataFrame holding only the SNPs present in both inputs, with
        the columns of ``ld_df`` followed by the remaining ``bim_df`` columns.
    """
    return ld_df.merge(bim_df, how="inner", on="SNP")


def load_traitgym_data(file_path, split):
    """
    Read one TraitGym dataset split from a parquet file.

    The row count, column names, first rows and dtypes of the table are
    written to the log at INFO level.

    Args:
        file_path: Path of the parquet file to read.
        split: Name of the split. Currently unused by this function; the
            file that is read is determined by ``file_path`` alone.

    Returns:
        pandas.DataFrame with the parquet file contents.
    """
    traitgym = pd.read_parquet(file_path)

    summary_lines = (
        f"Loaded {len(traitgym)} rows from {file_path}",
        f"Columns in the dataframe: {traitgym.columns.tolist()}",
        f"Dataframe head:\n{traitgym.head()}",
        f"Data types:\n{traitgym.dtypes}",
    )
    for line in summary_lines:
        logging.info(line)

    return traitgym


def merge_varient_features(traitgym_df, annotation_df):
    """
    Inner-join TraitGym variants with annotations on (chrom, pos, ref, alt).

    The four key columns are cast to a common dtype on both sides before the
    join (``chrom``, ``ref`` and ``alt`` to ``str``; ``pos`` to ``int``) so
    that, for example, an integer chromosome matches its string form. The
    casts are applied to copies: the caller's DataFrames are left unmodified.

    Args:
        traitgym_df: DataFrame with ``chrom``, ``pos``, ``ref``, ``alt``
            columns.
        annotation_df: DataFrame with ``chrom``, ``pos``, ``ref``, ``alt``
            columns.

    Returns:
        pandas.DataFrame with one row per key match between the two inputs.

    Raises:
        KeyError: If either input lacks one of the four key columns.
    """
    key_dtypes = {"chrom": str, "pos": int, "ref": str, "alt": str}

    # DataFrame.astype returns a new frame, so the inputs are never mutated
    # (the previous column-by-column assignment rewrote the caller's data).
    left = traitgym_df.astype(key_dtypes)
    right = annotation_df.astype(key_dtypes)

    return pd.merge(left, right, on=list(key_dtypes), how="inner")


def load_data(chromosome):
    """
    Build the TraitGym feature table for one chromosome.

    Steps, in order:
      1. Read ``baselineLD.<chromosome>.annot.gz`` from ``BASELINELD_PATH``.
      2. Read ``1000G.EUR.hg38.<chromosome>.bim`` from ``PLINK_PATH``.
      3. Join the two on ``SNP``.
      4. Read ``mendelian_traits_all/test.parquet`` from ``TRAITGYM_PATH``.
      5. Join the TraitGym variants with the annotations on
         (chrom, pos, ref, alt).
      6. Cast ``chrom`` to ``int`` and ``ref``, ``alt``, ``OMIM``,
         ``consequence``, ``SNP`` to ``str`` where those columns exist; a
         warning is logged for each one that is missing.

    Args:
        chromosome: Chromosome identifier interpolated into the annotation
            and .bim file names.

    Returns:
        pandas.DataFrame of TraitGym variants joined with their annotations.
    """
    annot_path = os.path.join(BASELINELD_PATH, f"baselineLD.{chromosome}.annot.gz")
    annotations = load_baselineLD_annotations(annot_path)

    bim_path = os.path.join(PLINK_PATH, f"1000G.EUR.hg38.{chromosome}.bim")
    bim_records = load_bim_file(bim_path)

    annotated_snps = merge_ld_bim(annotations, bim_records)

    parquet_path = os.path.join(TRAITGYM_PATH, "mendelian_traits_all", "test.parquet")
    variants = load_traitgym_data(parquet_path, split="test")

    features = merge_varient_features(variants, annotated_snps)

    target_dtypes = {
        "chrom": int,
        "ref": str,
        "alt": str,
        "OMIM": str,
        "consequence": str,
        "SNP": str,
    }
    for column, target in target_dtypes.items():
        if column not in features.columns:
            logging.warning(f"Column {column} not found in merged_traitgym_data")
            continue
        features[column] = features[column].astype(target)

    object_columns = features.select_dtypes(include="object").columns.tolist()
    logging.info(f"Object columns:\n{object_columns}")

    return features

In [None]:
import logging

import xgboost as xgb
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from utils import setup_logger

# Single seed value shared by the log record and the train/validation split.
SEED = 42

# setup_logger configures logging and records the seed in the log; it does
# not seed any random number generator itself.
setup_logger(seed=SEED)

# `load_data` is defined in an earlier cell of this notebook. No
# `data_loading` module is imported here, so it is called directly.
data = load_data(chromosome=11)  # Load data for chromosome 11

logging.info("Data loaded successfully.")
logging.info(f"Data shape: {data.shape}")
logging.info(f"Columns: {data.columns.tolist()}")
logging.info(f"Column types:\n{data.dtypes}")
logging.info(
    f"Object columns:\n{data.select_dtypes(include='object').columns.tolist()}"
)

features = data.drop(columns=["label"])
# XGBoost rejects object-dtype columns even with enable_categorical=True; it
# needs pandas "category" dtype. Converting before the split keeps the
# category levels identical in the training and validation frames.
object_columns = features.select_dtypes(include="object").columns
X = features.astype({column: "category" for column in object_columns})  # Features
y = data["label"]  # Target variable

# Split data (stratified on the label so both parts keep the class balance)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=SEED
)

# Train XGBoost model
# NOTE(review): `use_label_encoder` is deprecated in recent XGBoost releases;
# confirm whether it is still needed for the installed version.
model = xgb.XGBClassifier(use_label_encoder=False, eval_metric="logloss", enable_categorical=True)
model.fit(X_train, y_train)

# Predict the probability of the positive class
y_pred = model.predict_proba(X_val)[:, 1]

# Evaluate with average precision (area under the precision-recall curve)
ap = average_precision_score(y_val, y_pred)
print(f"AUPRC: {ap:.3f}")