In [2]:
import os
import numpy as np
import pandas as pd
import glob

In [5]:
def load_dataset_full(dataset_path):
    """
    Load and concatenate N, C, and y data from train, val, and test splits.

    :param dataset_path: Path to dataset directory containing *_train.npy, *_val.npy, *_test.npy
    :return: Tuple of (N, C, y) as full concatenated arrays or None
    """
    def load_and_concat(prefix):
        parts = []
        for split in ['train', 'val', 'test']:
            path = os.path.join(dataset_path, f"{prefix}_{split}.npy")
            if os.path.exists(path):
                parts.append(np.load(path, allow_pickle=True))
        if parts:
            return np.concatenate(parts, axis=0)
        return None

    N_val = load_and_concat("N")
    C_val = load_and_concat("C")
    y_val = load_and_concat("y")

    return N_val, C_val, y_val

In [None]:

model_root = "./results_model"
dataset_root = "./data"
output_dir = "./merged_predictions"

selected_columns = ["true_label", "predicted_label", "logit_0", "logit_1", "row_index"]
dataset_names = [
    "INNHotelsGroup", "dabetes_130-us_hospitals", "Cardiovascular-Disease-dataset",
    "FOREX_audsgd-hour-High", "taiwanese_bankruptcy_prediction", "philippine",
    "naticusdroid+android+permissions+dataset", "default_of_credit_card_clients"
]

csv_files = glob.glob(f"{model_root}/**/*.csv", recursive=True)
print(csv_files)

for result_file in csv_files:
    for dataset_name in dataset_names:
        if dataset_name in result_file:
            print(dataset_name)
            dataset_path = os.path.join(dataset_root, dataset_name)

            # Extract model_name
            for part in result_file.split('/'):
                if dataset_name in part:
                    model_name = part.removeprefix(dataset_name).split('-')[1]

            print(dataset_path)
            print(model_name)

            # Load full dataset
            N_full, C_full, y_full = load_dataset_full(dataset_path)

            # Load predictions
            pred_df = pd.read_csv(result_file)
            if selected_columns is not None:
                pred_df = pred_df[selected_columns]

            # Ensure row_index is present
            if "row_index" not in pred_df.columns:
                raise ValueError(f"'row_index' not found in prediction file: {result_file}")

            row_index = pred_df["row_index"].values.astype(int)

            # Subset N, C, y using row_index
            df_parts = []

            if C_full is not None:
                C_selected = C_full[row_index]
                df_cat = pd.DataFrame(C_selected, columns=[f"cat_{i}" for i in range(C_selected.shape[1])])
                df_parts.append(df_cat)

            if N_full is not None:
                N_selected = N_full[row_index]
                df_num = pd.DataFrame(N_selected, columns=[f"num_{i}" for i in range(N_selected.shape[1])])
                df_parts.append(df_num)

            y_selected = y_full[row_index]
            df_y = pd.DataFrame(y_selected, columns=["label"])
            df_parts.append(df_y)

            # Add predictions last
            df_parts.append(pred_df)

            # Concatenate all parts
            df_all = pd.concat(df_parts, axis=1)
            df_all.drop(columns=["label"], inplace=True)
            # Save merged file
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, f"{dataset_name}__{model_name}.csv")
            df_all.to_csv(output_path, index=False)
            print(f"Saved: {output_path}")


['./results_99/INNHotelsGroup-danets-100/Epoch50BZ1024-Norm-standard-Nan-mean-new-Cat-ordinal/predictions/best-val/predictions_seed0.csv', './results_99/taiwanese_bankruptcy_prediction-danets-100/Epoch50BZ512-Norm-standard-Nan-mean-new-Cat-ordinal/predictions/best-val/predictions_seed0.csv', './results_99/INNHotelsGroup-amformer-100/Epoch50BZ1024-Norm-standard-Nan-mean-new-Cat-indices/predictions/best-val/predictions_seed0.csv', './results_99/naticusdroid+android+permissions+dataset-modernNCA-100/Epoch50BZ512-Norm-standard-Nan-mean-new-Cat-tabr_ohe/predictions/best-val/predictions_seed0.csv', './results_99/philippine-amformer-100/Epoch50BZ128-Norm-standard-Nan-mean-new-Cat-indices/predictions/best-val/predictions_seed0.csv', './results_99/dabetes_130-us_hospitals-danets-100/Epoch50BZ1024-Norm-standard-Nan-mean-new-Cat-ordinal/predictions/best-val/predictions_seed0.csv', './results_99/FOREX_audsgd-hour-High-amformer-100/Epoch50BZ1024-Norm-standard-Nan-mean-new-Cat-indices/predictions/be