In [None]:
import os
import gc
import subprocess
import warnings
import numpy as np
import pandas as pd
import torch
import statsmodels.api as sm

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

from tabpfn import TabPFNRegressor
from tensorflow.keras import backend as K


# Silences ALL warnings for the whole notebook.
# NOTE(review): this also hides pandas SettingWithCopyWarning / boolean-reindex
# warnings from later cells; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

In [None]:
def clean_up_cuda(model):
    """Best-effort release of framework and GPU memory after a model was used.

    Clears the Keras/TensorFlow session, drops this function's reference to
    ``model``, runs the garbage collector and, when CUDA is available, empties
    PyTorch's CUDA cache and collects CUDA IPC handles.

    NOTE(review): ``del`` only removes the *local* name. The caller's variable
    still references the model, so its memory is not released until the
    caller drops it too.
    """
    K.clear_session()
    del model  # local reference only
    gc.collect()
    if not torch.cuda.is_available():
        print("CUDA memory cleared and model deleted.")
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print("CUDA memory cleared and model deleted.")


In [None]:
def feature_extraction_best_corr_with_target(X, X_val, X_control, y, threshold=0.6, df_columns=None, number_of_features=40):
    """Keep the ``number_of_features`` columns most correlated (|Pearson r|) with the target.

    Correlations are computed on the training matrix ``X`` only. The same
    columns are then selected from ``X_val`` and ``X_control``.

    Args:
        X, X_val, X_control: feature matrices (numpy arrays or DataFrames).
        y: training target, array-like with one entry per row of ``X``.
            It is paired with ``X`` by position, not by index label.
        threshold: unused; kept for signature compatibility with the other
            feature-extraction helpers.
        df_columns: optional column names applied to numpy inputs.
        number_of_features: how many top-correlated columns to keep.

    Returns:
        Tuple of numpy arrays (train, val, control) restricted to the kept columns.
    """
    def _as_frame(data):
        # numpy inputs become DataFrames, optionally with the caller's column names
        if isinstance(data, np.ndarray):
            data = pd.DataFrame(data)
            if df_columns is not None:
                data.columns = df_columns
        return data

    X = _as_frame(X)
    X_val = _as_frame(X_val)
    X_control = _as_frame(X_control)
    # corrwith() aligns on the index. Callers pass a y slice that keeps its
    # original (non-contiguous) index, while X rebuilt from numpy has a fresh
    # RangeIndex, so without this every correlation would be NaN.
    y = pd.Series(np.asarray(y).ravel(), index=X.index)

    abs_corr = X.corrwith(y).abs()
    to_keep = abs_corr.sort_values(ascending=False).head(number_of_features).index
    return (
        X[to_keep].to_numpy().copy(),
        X_val[to_keep].to_numpy().copy(),
        X_control[to_keep].to_numpy().copy(),
    )

def feature_extraction_with_Pearson(X, X_val, X_control, y, threshold=0.6, df_columns=None):
    """Remove features that are redundant with an earlier feature.

    For every pair of columns whose absolute Pearson correlation on ``X``
    exceeds ``threshold``, the later column (in column order) is dropped. The
    same columns are dropped from ``X_val`` and ``X_control``.

    ``y`` is accepted for signature compatibility but not used.

    Returns:
        Tuple of numpy arrays (train, val, control) without the dropped columns.
    """
    frames = []
    for data in (X, X_val, X_control):
        if isinstance(data, np.ndarray):
            data = pd.DataFrame(data)
            if df_columns is not None:
                data.columns = df_columns
        frames.append(data)
    train_df, val_df, control_df = frames

    abs_corr = train_df.corr().abs()
    # strict upper triangle: each pair considered once, diagonal excluded
    upper_mask = np.triu(np.ones(abs_corr.shape), k=1).astype(bool)
    upper_triangle = abs_corr.where(upper_mask)
    redundant = [col for col in upper_triangle.columns
                 if (upper_triangle[col] > threshold).any()]

    return tuple(
        frame.drop(columns=redundant).to_numpy().copy()
        for frame in (train_df, val_df, control_df)
    )

def feature_extration_with_PCA(X, X_val, X_control, n_components):
    """Project the three matrices onto the principal components fitted on ``X``.

    PCA is fitted on ``X`` only and then applied to ``X_val`` and ``X_control``.

    An integer ``n_components`` larger than ``min(n_samples, n_features)`` of
    ``X`` is clamped to that maximum; sklearn's PCA would otherwise raise. This
    happens e.g. when a caller asks for a fixed 50 components on a small
    subset. Non-integer values (variance fraction, ``'mle'``, ``None``) are
    passed through unchanged.

    Returns:
        Tuple (train, val, control) of transformed arrays.
    """
    if isinstance(n_components, (int, np.integer)) and not isinstance(n_components, bool):
        max_components = min(np.shape(X)[0], np.shape(X)[1])
        if n_components > max_components:
            print(f"PCA: reducing n_components from {n_components} to {max_components}")
            n_components = max_components
    pca = PCA(n_components=n_components)
    return pca.fit_transform(X), pca.transform(X_val), pca.transform(X_control)

def feature_extration_with_BE(X, X_val, X_control, y, significance_level=0.05, df_columns=None):
    """Backward elimination of features using OLS p-values.

    Repeatedly fits ``y ~ const + X`` with statsmodels OLS and removes the
    feature with the largest p-value while it exceeds ``significance_level``.
    The same features are removed from ``X_val`` and ``X_control``.

    The intercept column added by ``sm.add_constant`` is never a removal
    candidate. It does not exist in ``X_val``/``X_control``, so dropping it
    there raised a KeyError.

    Args:
        X, X_val, X_control: feature matrices (numpy arrays or DataFrames).
        y: training target (Series or array-like), paired with ``X`` by position.
        significance_level: p-value above which a feature is removed.
        df_columns: optional column names applied to numpy inputs.

    Returns:
        Tuple of numpy arrays (train, val, control) with the surviving
        features, without the intercept.
    """
    def _as_frame(data):
        # numpy inputs become DataFrames, optionally with the caller's column names
        if isinstance(data, np.ndarray):
            data = pd.DataFrame(data)
            if df_columns is not None:
                data.columns = df_columns
        return data

    X = _as_frame(X).reset_index(drop=True)
    X_val = _as_frame(X_val)
    X_control = _as_frame(X_control)
    # RangeIndex to match X; also accepts numpy targets (the old .reset_index call required pandas)
    y = pd.Series(np.asarray(y).ravel())
    X = sm.add_constant(X)
    while True:
        model = sm.OLS(y, X).fit()
        p_values = model.pvalues.drop('const', errors='ignore')
        max_p_value = p_values.max()
        # `not >` also stops on an empty or all-NaN p-value series
        if not max_p_value > significance_level:
            break
        feature_to_remove = p_values.idxmax()
        print(f"Removing {feature_to_remove} with p-value {max_p_value:.4f}")
        X = X.drop(columns=[feature_to_remove])
        X_val = X_val.drop(columns=[feature_to_remove])
        X_control = X_control.drop(columns=[feature_to_remove])
    # errors='ignore': add_constant skips 'const' if X already had a constant column
    X_ret = X.drop(columns=['const'], errors='ignore')
    print("Final Feature length: ", len(X_ret.columns))
    return X_ret.to_numpy().copy(), X_val.to_numpy().copy(), X_control.to_numpy().copy()


In [None]:
def pearson_correlation_coefficient(y_true, y_pred):
    """Pearson correlation between true and predicted values.

    Args:
        y_true, y_pred: equal-length array-likes of numbers, paired by position.

    Returns:
        The correlation coefficient, or 0.0 when it is undefined (e.g. one
        input is constant).

    Raises:
        ValueError: on fewer than two points, a length mismatch, or NaN /
            infinite values.
    """
    if len(y_true) <= 1 or len(y_pred) <= 1:
        raise ValueError("Pearson correlation requires at least two points in each array.")
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length.")

    # Cast to float, not int. Casting to int truncated continuous predictions
    # (distorting r) and raised before the NaN/inf checks below could run.
    # np.asarray drops any pandas index so corr() pairs values by position.
    y_true = pd.Series(np.asarray(y_true, dtype=float))
    y_pred = pd.Series(np.asarray(y_pred, dtype=float))

    if y_true.isna().any() or y_pred.isna().any():
        raise ValueError("Input contains NaN values.")
    if not np.isfinite(y_true).all() or not np.isfinite(y_pred).all():
        raise ValueError("Input contains infinite values.")

    result = y_true.corr(y_pred)
    if np.isnan(result):
        return 0.0
    return result

def evaluate_regression_performance(y_true, y_pred, title="", round_predictions=True):
    """Score predictions with MSE, MAE, R^2 and Pearson r, and print them.

    When ``round_predictions`` is true, both targets and predictions are
    rounded to the nearest integer before scoring; otherwise both are cast
    to float.

    Returns:
        dict with keys 'mse', 'mae', 'r2', 'pearson'.
    """
    if round_predictions:
        preds = np.round(y_pred).astype(int)
        truth = np.round(np.array(y_true)).astype(int)
    else:
        preds = np.array(y_pred).astype(float)
        truth = np.array(y_true).astype(float)

    scores = {
        'mse': mean_squared_error(truth, preds),
        'mae': mean_absolute_error(truth, preds),
        'r2': r2_score(truth, preds),
        'pearson': pearson_correlation_coefficient(truth, preds),
    }
    print(f"\n {title} Regressor Performance:")
    print("MSE: {mse:.4f}, MAE: {mae:.4f}, R2: {r2:.4f}, Pearson: {pearson:.4f}".format(**scores))
    return scores

def aggregate_cv_metrics_and_print(all_results, model_name, tag="Validation"):
    """Summarise per-fold metric dicts as mean/std, print and return the summary.

    Args:
        all_results: iterable of dicts with keys 'mse', 'mae', 'r2', 'pearson'
            (as returned by evaluate_regression_performance).
        model_name: label used in the printed header.
        tag: suffix for the printed header (e.g. "Validation" / "Control").

    Returns:
        dict with 'mean_<metric>' and 'std_<metric>' (population std) for each
        metric. 'std_pearson' is included for consistency with the other
        metrics (it was previously missing); existing keys keep their order.
        An empty ``all_results`` yields NaN values.
    """
    metric_names = ('mse', 'mae', 'r2', 'pearson')
    all_results = list(all_results)  # allow generators; iterated once per metric
    summary = {}
    for name in metric_names:
        values = [result[name] for result in all_results]
        summary[f'mean_{name}'] = np.mean(values)
        summary[f'std_{name}'] = np.std(values)
    print(f"\n {model_name} Regressor Performance {tag}:")
    for k, v in summary.items():
        print(f"{k}: {v}")
    return summary

In [None]:
# --- Configuration -----------------------------------------------------------
label_col = "age_at_assessment"   # regression target
straticify_col = "NONE"           # column to stratify the training sample on; "NONE" = plain random sample
straticify_col_test = "NONE"      # same switch for the control sample (used in a later cell)
filter_for_sex = 0                # keep only subjects whose "sex" equals this value
n_training_pool = 10000           # size of the training pool drawn from the merged data
n_splits = 5

os.makedirs("/opt/notebooks/TABPFN/02_UKB/00_data/age_label", exist_ok=True)
os.makedirs("../00_data/confounded/", exist_ok=True)
mri_table = "aparc.thickness_aparc.volume_aseg.volume.csv"

# Earlier data sources, kept for provenance (not executed):
# Load the age data
# command = "dx download file-GyGfBQ8J34gPK8XXxbjYGbg4 --output /opt/notebooks/TABPFN/02_UKB/00_data/age_label/all_ages_all_ids_healthy.csv --overwrite"
# load mri data
# command = f"dx download file-GyGf9vjJ34g2g9QbJQ7P1qZG --output '/opt/notebooks/TABPFN/02_UKB/00_data/deconfounded_but_age/{mri_table}' --overwrite"
# Load the age data middle
# command = "dx download file-GyJp51jJ34g246Y7bZ6j7yK4 --output /opt/notebooks/TABPFN/02_UKB/00_data/age_label/all_ages_all_ids_healthy.csv --overwrite"
# load mri data cleand and renamed but age
# command = f"dx download file-GyJp6B0J34g8xpf6Q6jz12xJ --output '/opt/notebooks/TABPFN/02_UKB/00_data/deconfounded_but_age/{mri_table}' --overwrite"

# Load the mri data
command = f"dx download file-GyQXZf0J34g7bJK06XB2vZQx --output '../00_data/confounded/{mri_table}' --overwrite"
subprocess.run(command, shell=True, check=True)

# Load the age data that is just healthy subjects from UKB
command = "dx download file-GyQY91QJ34gFY8Fv28Pqzp0v --output /opt/notebooks/TABPFN/02_UKB/00_data/age_label/healthy_subjects_train.csv --overwrite"
subprocess.run(command, shell=True, check=True)
df = pd.read_csv(f"../00_data/confounded/{mri_table}")
df["ID"] = df["ID"].astype(str)
label_df = pd.read_csv("/opt/notebooks/TABPFN/02_UKB/00_data/age_label/healthy_subjects_train.csv")
label_df["ID"] = label_df["ID"].astype(str)

# Keep only subjects of the selected sex (filter_for_sex)
print(label_df["sex"].value_counts())
label_df = label_df[label_df["sex"] == filter_for_sex]
print(label_df["sex"].value_counts())

if straticify_col != "NONE":
    label_df = label_df[['ID', label_col, straticify_col]]
    merged_df = pd.merge(df, label_df, on='ID', how='inner').dropna()
    label_counts = merged_df[straticify_col].value_counts()

    # Groups with at most `threshold` rows are kept completely
    threshold = 1000  # You can adjust this threshold as needed
    small_groups = label_counts[label_counts <= threshold].index
    small_groups_df = merged_df[merged_df[straticify_col].isin(small_groups)]

    # The remaining budget up to n_training_pool is filled by a stratified
    # sample of the larger groups
    remaining_needed = n_training_pool - len(small_groups_df)
    large_groups = label_counts[label_counts > threshold].index
    large_groups_df = merged_df[merged_df[straticify_col].isin(large_groups)]

    if remaining_needed <= 0 or large_groups_df.empty:
        # train_test_split rejects a non-positive train_size
        proportional_sampled_df = large_groups_df.iloc[0:0]
    elif remaining_needed >= len(large_groups_df):
        # ...and a train_size that leaves no rows over; just take all of them
        proportional_sampled_df = large_groups_df
    else:
        proportional_sampled_df, _ = train_test_split(
            large_groups_df,
            train_size=remaining_needed,
            stratify=large_groups_df[straticify_col],
            random_state=42
        )

    # Combine the small groups and the proportional sample
    final_sampled_df = pd.concat([small_groups_df, proportional_sampled_df])

    # Verify the result
    print(final_sampled_df[straticify_col].value_counts())
    print(f"Total samples: {len(final_sampled_df)}")
else:
    label_df = label_df[['ID', label_col]]
    merged_df = pd.merge(df, label_df, on='ID', how='inner').dropna()
    # DataFrame.sample raises if n exceeds the number of available rows
    final_sampled_df = merged_df.sample(n=min(n_training_pool, len(merged_df)), random_state=42)
if label_col != straticify_col and straticify_col != "NONE":
    final_sampled_df = final_sampled_df.drop(columns=[straticify_col])
final_sampled_df[label_col].hist()
print(len(final_sampled_df))
df_sampled = final_sampled_df

In [None]:
# Load the matched sick/healthy validation cohort and split it by sex and disease status.
command = "dx download file-GyQXkxQJ34gB2qXXQyPGZv7g --output /opt/notebooks/TABPFN/02_UKB/00_data/age_label/Matched_validation_sick_healthy.csv --overwrite"
subprocess.run(command, shell=True, check=True)
df_sick_healthy = pd.read_csv("/opt/notebooks/TABPFN/02_UKB/00_data/age_label/Matched_validation_sick_healthy.csv")
# rename the column sick to label_sick
df_sick_healthy = df_sick_healthy.rename(columns={"sick": "label_sick"})
df_sick_healthy_sex = df_sick_healthy[df_sick_healthy["sex"] == filter_for_sex]
# Mask with the subset's own column. The previous code used the full frame's
# column and relied on silent boolean reindexing. .copy() makes the ID
# reassignment below write to an independent frame, not a slice.
label_df_control = df_sick_healthy_sex[df_sick_healthy_sex["label_sick"] == 0].copy()
label_test_sick = df_sick_healthy_sex[df_sick_healthy_sex["label_sick"] == 1].copy()
label_df_control["ID"] = label_df_control["ID"].astype(int).astype(str)
label_test_sick["ID"] = label_test_sick["ID"].astype(int).astype(str)
print(label_df_control["sex"].value_counts())
print(label_test_sick["sex"].value_counts())

In [None]:
os.makedirs("/opt/notebooks/TABPFN/02_UKB/00_data/validation_data/00_National_Cohort/", exist_ok=True)
# Alternative middle-age control data. The commands are built but NOT executed:
#load middle age control data
command = "dx download file-GyK09JQJ34g95zyvV9vFxQFv --output /opt/notebooks/TABPFN/02_UKB/00_data/validation_data/00_National_Cohort/all_ages_all_ids_subset_middle_age.csv --overwrite"
#subprocess.run(command, shell=True)
#load mri data
command = "dx download file-GyK08xjJ34g95zyvV9vFxQFf --output /opt/notebooks/TABPFN/02_UKB/00_data/validation_data/00_National_Cohort/aparc.thickness_aseg.volume_aparc.volume_deconfounded_but_age.csv --overwrite"
#subprocess.run(command, shell=True)

n_control_samples = 1000  # size of the random control sample (non-stratified branch)

# Control features come from the same MRI table as the training data.
df_control = pd.read_csv(f"../00_data/confounded/{mri_table}")
df_control["ID"] = df_control["ID"].astype(str).str.replace("sub-", "", regex=False)
if straticify_col_test != "NONE":
    label_df_control = label_df_control[['ID', straticify_col_test, label_col]]
    merged_df_control = pd.merge(df_control, label_df_control, on='ID', how='inner').dropna()
    # Balanced sample: up to 25 rows per group (the target of 400 rows assumes 16 groups)
    target_samples_per_group = 25
    grouped_df = merged_df_control.groupby(straticify_col_test)

    # Sample 25 from each group if possible, otherwise sample all available
    sampled_control = grouped_df.apply(lambda x: x.sample(n=min(target_samples_per_group, len(x)), random_state=42))

    # Reset the index after sampling
    sampled_control.reset_index(drop=True, inplace=True)

    # Check if we reached the desired total number of 400 samples
    if len(sampled_control) < 400:
        print(f"Only {len(sampled_control)} samples available after balanced sampling.")
    else:
        print(f"Sampled {len(sampled_control)} rows with balanced distribution across groups.")
else:
    label_df_control = label_df_control[['ID', label_col]]
    merged_df_control = pd.merge(df_control, label_df_control, on='ID', how='inner').dropna()
    # Random sample of up to n_control_samples subjects (sample() raises if n exceeds the rows)
    sampled_control = merged_df_control.sample(n=min(n_control_samples, len(merged_df_control)), random_state=42)
X_control_source = sampled_control.drop(["ID", label_col], axis=1)
y_control_source = sampled_control[label_col]
control_ids = sampled_control["ID"]
y_control_source.hist()



In [None]:
# Restrict the control features to the training feature columns (all but label and ID)
# and record the range of the control labels.
column_control = df_sampled.columns.drop([label_col, "ID"])
X_control = X_control_source.loc[:, column_control]
y_control_min, y_control_max = y_control_source.min(), y_control_source.max()

In [None]:
# Range of the control-set labels: (max, min).
print(y_control_max, y_control_min)

In [None]:
# Size of the training pool (printed) and of the control set (displayed).
print(len(df_sampled))
len(X_control)

In [None]:
# Number of feature columns shared by training and control data.
len(X_control.columns)

In [None]:
n_splits = 5
best_pearson = 0.0      # best control-set Pearson so far; a model must beat it to be saved
best_model_path = None
best_model_type = None

# Fractions of df_sampled used for training. Full sweep:
# [0.05, 0.2, 0.5, 0.6, 0.8, 1.0]
percentage_of_the_data = [0.05]
# Feature-extraction strategies. Full sweep:
# ["Nothing", "BE", "PCA", "Correlation_in_Feature", "Correlation_with_target"]
FE_strategy = ["Nothing"]

# Seed NumPy's global RNG. The random baseline draws with np.random.uniform
# and was previously not reproducible across runs.
random_seed = 42
np.random.seed(random_seed)

# Dictionary to store aggregated CV metrics for each percentage and deconfounding strategy
percentage_dict = {}

# Per-subject predictions (CV test folds and control set), written to CSV at the end
test_predictions_records = []
control_predictions_records = []

# NOTE(review): currently unused; per-strategy random-baseline metrics live in percentage_dict.
random_results = []
random_results_eval = []

# For each percentage of training data and each feature-extraction strategy:
# run K-fold CV with a random baseline and TabPFN, scoring each fold on the
# held-out fold ("train" rows in the log) and on the control set ("eval" rows).


def _prediction_records(record_ids, truths, predictions, fold, model_name, percentage, fe_name):
    """Per-subject rows for the prediction CSVs (one dict per subject)."""
    return [
        {
            'ID': idx,
            'fold': fold,
            'model': model_name,
            'real_outcome': true_val,
            'predicted_outcome': pred_val,
            'percentage': percentage,
            'deconfounding': fe_name,
        }
        for idx, true_val, pred_val in zip(record_ids, truths, predictions)
    ]


for percentage in percentage_of_the_data:
    percentage_dict[percentage] = {}

    # Subsample the training data as needed (1 = use everything)
    if percentage == 1:
        df_sampled_subset = df_sampled.copy()
    else:
        df_sampled_subset, _ = train_test_split(
            df_sampled,
            train_size=percentage,
            random_state=42
        )
    print(f"\n #### TRAINING WITH {percentage} OF THE DATA ####")
    # Separate IDs, outcomes, and features
    ids = df_sampled_subset["ID"]
    y = df_sampled_subset[label_col]
    X = df_sampled_subset.drop(["ID", label_col], axis=1)

    print(f"Training data shape: {X.shape}, number of samples: {len(y)}")

    for FE_ext in FE_strategy:
        print(f"\n=== FE-Ext: {FE_ext} ===")

        # Per-fold metrics: held-out CV fold (*_results) and control set (*_results_eval)
        tabpfn_results = []
        random_results_model = []
        tabpfn_results_eval = []
        random_results_eval_model = []

        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        for fold_counter, (train_index, val_index) in enumerate(kf.split(X), start=1):
            print(f"\nFold {fold_counter}")
            X_train, X_test = X.iloc[train_index].copy(), X.iloc[val_index].copy()
            y_train, y_test = y.iloc[train_index].copy(), y.iloc[val_index].copy()
            test_ids = ids.iloc[val_index].copy()

            # Use the entire control set (IDs are in control_ids) in training
            # column order. Fail loudly on missing columns; previously the
            # error was printed and swallowed, and scaler.transform failed later.
            missing_columns = X_train.columns.difference(X_control_source.columns)
            if len(missing_columns) > 0:
                raise KeyError(f"Columns mismatch between training and control: {list(missing_columns)}")
            X_control = X_control_source[X_train.columns].copy()
            y_control = y_control_source.copy()

            # Scale with statistics of the training fold only
            df_columns = X_train.columns
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)
            X_control_scaled = scaler.transform(X_control)

            # Apply deconfounding / feature extraction strategy
            if FE_ext == "BE":
                X_train_proc, X_test_proc, X_control_proc = feature_extration_with_BE(
                    X_train_scaled, X_test_scaled, X_control_scaled, y_train, df_columns=df_columns)
            elif FE_ext == "PCA":
                X_train_proc, X_test_proc, X_control_proc = feature_extration_with_PCA(
                    X_train_scaled, X_test_scaled, X_control_scaled, n_components=50)
            elif FE_ext == "Correlation_in_Feature":
                X_train_proc, X_test_proc, X_control_proc = feature_extraction_with_Pearson(
                    X_train_scaled, X_test_scaled, X_control_scaled, y_train, threshold=0.6, df_columns=df_columns)
            elif FE_ext == "Correlation_with_target":
                X_train_proc, X_test_proc, X_control_proc = feature_extraction_best_corr_with_target(
                    X_train_scaled, X_test_scaled, X_control_scaled, y_train, threshold=0.6, df_columns=df_columns)
            elif FE_ext == "Nothing":
                X_train_proc, X_test_proc, X_control_proc = X_train_scaled, X_test_scaled, X_control_scaled
            else:
                # Otherwise stale matrices from a previous iteration would be reused (or NameError)
                raise ValueError(f"Unknown feature-extraction strategy: {FE_ext}")

            #############################
            # RANDOM BASELINE
            #############################
            n_samples_test = len(y_test)
            # Random prediction between the test fold's min and max label
            random_pred = np.random.uniform(y_test.min(), y_test.max(), n_samples_test)
            random_pred = np.round(random_pred)
            random_metrics = evaluate_regression_performance(y_test, random_pred, title="Random - Test")
            random_results_model.append(random_metrics)

            # Control set: drawn from the same (test-fold) label range.
            # NOTE(review): the control cohort's own range may differ — confirm this is intended.
            n_samples_control = len(y_control)
            random_control_pred = np.random.uniform(y_test.min(), y_test.max(), n_samples_control)
            random_metrics_eval = evaluate_regression_performance(y_control, random_control_pred, title="Random - Control")
            random_results_eval_model.append(random_metrics_eval)

            test_predictions_records.extend(_prediction_records(
                test_ids, y_test, random_pred, fold_counter, 'Random', percentage, FE_ext))
            control_predictions_records.extend(_prediction_records(
                control_ids, y_control, random_control_pred, fold_counter, 'Random', percentage, FE_ext))

            #############################
            # TABPFN Regressor
            #############################
            tabpfn_model = TabPFNRegressor()
            tabpfn_model.fit(X_train_proc, y_train)
            tabpfn_pred = tabpfn_model.predict(X_test_proc)
            tabpfn_metrics = evaluate_regression_performance(y_test, tabpfn_pred, title="tabpfn - Test")
            tabpfn_results.append(tabpfn_metrics)
            test_predictions_records.extend(_prediction_records(
                test_ids, y_test, tabpfn_pred, fold_counter, 'tabpfn', percentage, FE_ext))

            # Evaluate on control set
            tabpfn_control_pred = tabpfn_model.predict(X_control_proc).flatten()
            tabpfn_control_metrics = evaluate_regression_performance(y_control, tabpfn_control_pred, title="tabpfn - Control")
            tabpfn_results_eval.append(tabpfn_control_metrics)
            control_predictions_records.extend(_prediction_records(
                control_ids, y_control, tabpfn_control_pred, fold_counter, 'tabpfn', percentage, FE_ext))

            # Frees framework/GPU caches. The name `tabpfn_model` stays bound here,
            # and the last fold's model is saved below.
            clean_up_cuda(tabpfn_model)

        # Keep the LAST fold's model if its control-set Pearson beats the best so far.
        # NOTE(review): earlier folds' models are never considered for saving.
        last_fold_pearson = tabpfn_control_metrics["pearson"]
        if last_fold_pearson > best_pearson:
            best_pearson = last_fold_pearson
            best_model_path = f'/opt/notebooks/{percentage}_{FE_ext}_tabpfn_{last_fold_pearson}'
            best_model_type = "tabpfn"
            tabpfn_model.save_model(best_model_path)
            print(f"Best model saved with pearson {best_pearson} to {best_model_path}")

        # Aggregate and print performance for each model (test set)
        random_summary = aggregate_cv_metrics_and_print(random_results_model, "Random")
        tabpfn_summary = aggregate_cv_metrics_and_print(tabpfn_results, "tabpfn")

        # Aggregate for control set evaluations
        random_eval_summary = aggregate_cv_metrics_and_print(random_results_eval_model, "Random", tag="Control")
        tabpfn_eval_summary = aggregate_cv_metrics_and_print(tabpfn_results_eval, "tabpfn", tag="Control")

        # Save results in the dictionary
        percentage_dict[percentage][FE_ext] = {
            "Random": {
                "results": random_summary,
                "results_eval": random_eval_summary,
                "cv_results": random_results_model,
                "cv_results_eval": random_results_eval_model
            },
            "tabpfn": {
                "results": tabpfn_summary,
                "results_eval": tabpfn_eval_summary,
                "cv_results": tabpfn_results,
                "cv_results_eval": tabpfn_results_eval
            },
        }


# --- Write and upload results once, after all percentages --------------------
# (This section used to be indented inside the percentage loop, so it rewrote
# and re-uploaded everything on every iteration.)
Feature_extraction_applied = False
Pretraining_applied = False


def _cv_log_row(percentage, feat_ext, model_name, fold_idx, metrics, split):
    """One results-CSV row for a single fold; ``split`` is "train" (CV fold) or "eval" (control set)."""
    return {
        "label_col": label_col,
        "mri_table": mri_table,
        "test_set_size": f"{(1 - percentage):.2%} (approx. of data left for test)",
        "Feature_extraction_applied": Feature_extraction_applied,
        "Pretraining_applied": Pretraining_applied,
        "model_type": model_name,
        "mse": metrics.get("mse", None),
        "mae": metrics.get("mae", None),
        "r2": metrics.get("r2", None),
        "pearson": metrics.get("pearson", None),
        "number_of_cross_validations": n_splits,
        "cross_validation_count": fold_idx,
        "search_term": f"{percentage}_{feat_ext}_{model_name}_{split}",
        "percentage_of_data": percentage,
        "eval_or_train": split,
    }


all_rows = []
log_file = "/opt/notebooks/results_regression.csv"
for percentage, models in percentage_dict.items():
    for feat_ext, feature_summary_dict in models.items():
        for model_name, summary_dict in feature_summary_dict.items():
            for i, (cv_result, cv_result_eval) in enumerate(zip(summary_dict["cv_results"], summary_dict["cv_results_eval"])):
                all_rows.append(_cv_log_row(percentage, feat_ext, model_name, i, cv_result, "train"))
                all_rows.append(_cv_log_row(percentage, feat_ext, model_name, i, cv_result_eval, "eval"))

# Convert to DataFrame and save CSV (written once; it was previously written twice)
df_results = pd.DataFrame(all_rows)
df_results.to_csv(log_file, index=False)
logs_path = "project-GqzxkVQJ34g6ygFJ4ZbvqBYF:/Esra/00_CLIP/01_training_logs/"
# DX_JOB_ID is unset outside a dx job; os.path.join(logs_path, None) raised TypeError.
label = os.environ.get("DX_JOB_ID") or "local_run"
logs_path_label = os.path.join(logs_path, label)
dx_mkdir_command = f"dx mkdir '{logs_path_label}'"
subprocess.run(dx_mkdir_command, shell=True)
time_tag = pd.Timestamp.now().strftime("%Y-%m-%d_%H-%M-%S")
command_csv = f"dx upload '{log_file}' --path '{logs_path_label}/{time_tag}_result_baseline.csv'"
subprocess.run(command_csv, shell=True)

df_test_predictions = pd.DataFrame(test_predictions_records)
df_control_predictions = pd.DataFrame(control_predictions_records)

test_csv_path = "/opt/notebooks/regression_test_predictions.csv"
control_csv_path = "/opt/notebooks/regression_control_predictions.csv"

df_test_predictions.to_csv(test_csv_path)
df_control_predictions.to_csv(control_csv_path)

print(f"Saved test predictions to {test_csv_path}")
print(f"Saved control predictions to {control_csv_path}")
command_csv = f"dx upload '{test_csv_path}' --path '{logs_path_label}/{time_tag}_test_predictions.csv'"
subprocess.run(command_csv, shell=True)
command_csv = f"dx upload '{control_csv_path}' --path '{logs_path_label}/{time_tag}_control_predictions.csv'"
subprocess.run(command_csv, shell=True)
# Upload the best checkpoint, if any model beat the initial best_pearson.
# best_model_path stays None otherwise, and .split() would raise AttributeError.
if best_model_path is not None:
    name = best_model_path.split("/")[-1]
    command_csv = f"dx upload '{best_model_path}' --path '{logs_path_label}/{time_tag}_{name}'"
    subprocess.run(command_csv, shell=True)
else:
    print("No best model was saved; skipping model upload.")


In [None]:
def terminate_instance():
    """Terminate the current `dx` job, if the notebook runs inside one.

    The job id is read from the ``DX_JOB_ID`` environment variable. When it is
    unset or empty, nothing happens. Otherwise ``dx terminate <job_id>`` is run,
    and CalledProcessError is raised if that command fails.
    """
    job_id = os.environ.get("DX_JOB_ID")
    if not job_id:
        return
    print(f"Terminating job: {job_id}")
    subprocess.run(["dx", "terminate", job_id], check=True)

In [None]:
def match_participant(target_row, df_candidates, relax_order, used_ids,thresholds):
    """Find a control candidate for ``target_row``, relaxing criteria level by level.

    Candidates whose 'ID' is in ``used_ids`` are excluded. Each entry of
    ``relax_order`` maps a criterion (column name) to a flag. When the flag is
    true, the criterion is enforced:

    * 'assessment_centre' — candidate value must equal the target's;
    * 'sex'               — candidate value must DIFFER from the target's;
    * any other column    — |candidate - target| <= thresholds[column].

    Levels are tried in order. The first row of the first non-empty filtered
    set is returned; this is not necessarily the closest match. If no level
    yields a candidate, a message is printed and None is returned.
    """
    pool = df_candidates[~df_candidates['ID'].isin(used_ids)].copy()
    for level in relax_order:
        candidates = pool.copy()
        for column, enforce in level.items():
            if enforce:
                if column == 'assessment_centre':
                    candidates = candidates[candidates[column] == target_row[column]]
                elif column == 'sex':
                    # opposite sex
                    candidates = candidates[candidates[column] != target_row[column]]
                else:
                    gap = (candidates[column] - target_row[column]).abs()
                    candidates = candidates[gap <= thresholds[column]]
            if candidates.empty:
                break
        if not candidates.empty:
            return candidates.iloc[0]
    print("NO candidate found")
    return None
# Matching levels for match_participant, from strictest to loosest.
# True = enforce that criterion at this level. The last level enforces
# nothing, so any still-unused candidate is accepted.
relax_order = [
    {'assessment_centre': True, 'deprivation_index': True, 'bmi': True, 'age_at_assessment': True},
    {'assessment_centre': False, 'deprivation_index': True, 'bmi': True, 'age_at_assessment': True},
    {'assessment_centre': False, 'deprivation_index': False, 'bmi': True, 'age_at_assessment': True},
    {'assessment_centre': False, 'deprivation_index': False, 'bmi': False, 'age_at_assessment': True},
    {'assessment_centre': False, 'deprivation_index': False, 'bmi': False, 'age_at_assessment': False},
]
# Maximum absolute difference allowed for each numeric criterion.
# NOTE(review): units are those of the source columns — confirm (e.g. years for age).
thresholds = {
    'age_at_assessment': 2,
    'bmi': 3,
    'deprivation_index': 1
}

In [None]:
# Download and load the cohort table. The matching cells below use its
# label_Bad_WM_Memory, label_sick, sex and matching-criteria columns.
command = "dx download file-GyQZgzQJ34g4xxYbB6BJJYPB --output /opt/notebooks/merged_multitarget_df_with_demographics_wm.csv --overwrite"
subprocess.run(command, shell=True, check=True)
all_df = pd.read_csv("/opt/notebooks/merged_multitarget_df_with_demographics_wm.csv")

In [None]:
# Binarise the working-memory label: exactly 1 -> 1, anything else (including NaN) -> 0.
all_df["label_Bad_WM_Memory"] = (all_df["label_Bad_WM_Memory"] == 1).astype("int64")

In [None]:
# Label balance after binarisation.
all_df["label_Bad_WM_Memory"].value_counts()

In [None]:
# Compare IDs as strings, like the other frames in this notebook.
# NOTE(review): if this column was parsed as float (e.g. due to missing values),
# IDs become "1234.0" and will not match the other frames' IDs — confirm dtype.
all_df["ID"] = all_df["ID"].astype(str)

In [None]:
#drop all IDs from all_df that are in df_sampled
# (keeps subjects from the training pool out of this evaluation cohort)
all_df = all_df[~all_df["ID"].isin(df_sampled["ID"])]

In [None]:
# Label balance after removing training subjects.
all_df["label_Bad_WM_Memory"].value_counts()

In [None]:
# Candidate controls: label_sick == 0. Positives: label_Bad_WM_Memory == 1.
# A subject can satisfy both masks. That row then appeared twice in full_df
# (and twice in the positive group used for matching), so keep each original
# row (index label of all_df) only once.
all_df_neg = all_df[all_df["label_sick"] == 0]
all_df_wm_pos = all_df[all_df["label_Bad_WM_Memory"] == 1]
full_df = pd.concat([all_df_neg, all_df_wm_pos])
full_df = full_df[~full_df.index.duplicated(keep="first")]

In [None]:
# Restrict the pooled cohort to the selected sex, then show the label balance.
# (value_counts was previously not the last expression, so it was never displayed.)
full_df = full_df[full_df["sex"] == filter_for_sex]
full_df.value_counts("label_Bad_WM_Memory")

In [None]:
from joblib import Parallel, delayed
import os
def match_for_label(label, merged_df, relax_order, thresholds):
    """Greedily pair every positive subject for ``label`` with one unused negative.

    Rows with ``merged_df[label] == 1`` form the positive group. Rows with
    ``== 0`` form the candidate pool. Each positive is matched via
    ``match_participant``, and the chosen control's ID is marked as used so it
    is never matched twice. Positives without a match are skipped.

    Returns:
        (label, DataFrame of matched controls, positive-group DataFrame)
    """
    positives = merged_df[merged_df[label] == 1]
    negatives = merged_df[merged_df[label] == 0]

    taken = set()
    matched_rows = []
    for _, subject in positives.iterrows():
        control = match_participant(subject, negatives, relax_order, taken, thresholds)
        if control is None:
            continue
        matched_rows.append(control)
        taken.add(control['ID'])
    return label, pd.DataFrame(matched_rows), positives

#label_cols = [col for col in df.columns if "label" in col]
label_cols = ["label_Bad_WM_Memory"]

# One joblib job per label (n_jobs=-1: use all available cores). Each result is
# a (label, matched_df, positive_group) tuple from match_for_label.
results = Parallel(n_jobs=-1)(delayed(match_for_label)(label, full_df, relax_order, thresholds) for label in label_cols)

In [None]:
# Combine the positive group with its matched controls.
# NOTE(review): `df` is reassigned on every iteration, so with several labels
# only the last one survives; fine while label_cols has a single entry.
for label, matched_df, positive_group in results:
    print(f"Matched {len(matched_df)} participants for label {label}")
    print(f"Positive group size: {len(positive_group)}")
    print(f"Negative group size: {len(matched_df)}")
    print(f"Matched group size: {len(matched_df)}")
    df = pd.concat([positive_group, matched_df])
df["ID"] = df["ID"].astype(int).astype(str)
print(df["sex"].value_counts())
print(df["label_Bad_WM_Memory"].value_counts())

# Independent copies: a later cell adds a predicted_age column to them.
bad_wm = df[df["label_Bad_WM_Memory"] == 1].copy()
good_wm = df[df["label_Bad_WM_Memory"] == 0].copy()

# Overlaid age distributions of both groups (age is one of the matching criteria).
ax = bad_wm["age_at_assessment"].hist(alpha=0.5, label="Bad WM (label = 1)")
good_wm["age_at_assessment"].hist(ax=ax, alpha=0.5, label="Matched controls (label = 0)")
ax.set_xlabel("age_at_assessment")
ax.legend();

In [None]:
# Predict the age of the bad-WM group and its matched controls with the best saved model.
mir_df = pd.read_csv(f"../00_data/confounded/{mri_table}")
# astype(str) first: the .str accessor fails on a numeric ID column.
mir_df["ID"] = mir_df["ID"].astype(str).str.replace("sub-", "", regex=False)

# Exactly the feature columns (and order) the model was trained on.
feature_cols = df_sampled.drop([label_col, "ID"], axis=1).columns

# Merge only ID + label, so other cohort columns cannot collide with (and
# suffix) MRI feature columns.
merged_bad_wm = pd.merge(mir_df, bad_wm[["ID", label_col]], on='ID', how='inner')
merged_good_wm = pd.merge(mir_df, good_wm[["ID", label_col]], on='ID', how='inner')
y_bad_wm = merged_bad_wm[label_col]
X_bad_wm = merged_bad_wm[feature_cols]
y_good_wm = merged_good_wm[label_col]
X_good_wm = merged_good_wm[feature_cols]

# Load the model. Previously a missing model surfaced later as a NameError.
if best_model_type != "tabpfn" or best_model_path is None:
    raise RuntimeError("No saved TabPFN model available (best_model_type / best_model_path not set).")
model = TabPFNRegressor()
model.load_model(best_model_path)

# Scale with training-pool statistics, NOT with the evaluation groups. Fitting
# on bad_wm centred that group on itself and applied its statistics to the
# controls, hiding the very group difference under study.
# NOTE(review): the model was trained on one CV fold scaled by that fold's own
# scaler; persisting that scaler with the model would be exact. Refitting on
# the whole training pool is an approximation.
scaler = StandardScaler()
scaler.fit(df_sampled[feature_cols])
X_bad_wm_scaled = scaler.transform(X_bad_wm)
X_good_wm_scaled = scaler.transform(X_good_wm)

# Apply deconfounding / feature extraction strategy
# skipped for now — only valid while the best model was trained with FE "Nothing"

# Predict the age
y_pred_bad_wm = model.predict(X_bad_wm_scaled)
y_pred_good_wm = model.predict(X_good_wm_scaled)


def _attach_predictions(cohort_df, merged_df, predictions):
    """Left-join predictions (row-aligned with ``merged_df``) onto ``cohort_df`` by ID.

    The inner merge above drops subjects without MRI data and reorders rows to
    mir_df order. Positional assignment therefore mislabelled subjects (or
    raised on a length mismatch). Subjects without MRI get NaN.
    """
    preds = pd.DataFrame({"ID": merged_df["ID"].to_numpy(), "predicted_age": np.ravel(predictions)})
    return cohort_df.drop(columns="predicted_age", errors="ignore").merge(
        preds.drop_duplicates("ID"), on="ID", how="left")


bad_wm = _attach_predictions(bad_wm, merged_bad_wm, y_pred_bad_wm)
good_wm = _attach_predictions(good_wm, merged_good_wm, y_pred_good_wm)

# save to csv
bad_wm.to_csv("/opt/notebooks/bad_wm_predicted.csv", index=False)
good_wm.to_csv("/opt/notebooks/good_wm_predicted.csv", index=False)

#upload the files
command = f"dx upload '/opt/notebooks/bad_wm_predicted.csv' --path '{logs_path_label}/{time_tag}_bad_wm_predicted.csv'"
subprocess.run(command, shell=True)
command = f"dx upload '/opt/notebooks/good_wm_predicted.csv' --path '{logs_path_label}/{time_tag}_good_wm_predicted.csv'"
subprocess.run(command, shell=True)


In [None]:
# Distribution of predicted vs. recorded age in the bad-WM group.
print(bad_wm["predicted_age"].describe())
bad_wm["age_at_assessment"].describe()

In [None]:
#terminate_instance()