In [None]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from model import FairPFNModel
from data_generator import DataGenerator
from datasets import save_generated_data, SyntheticDataset
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns



In [None]:
# Load the pre-generated synthetic dataset and print a quick overview of it.
DATASET_FILE = Path("data") / "pre_generated_data.parquet"
if not DATASET_FILE.exists():
    raise FileNotFoundError(f"File {DATASET_FILE} does not exist.")
df = pd.read_parquet(DATASET_FILE)
print(f"Loaded {len(df)} samples from {DATASET_FILE}.")

# Column layout used throughout this cell: columns 1..-3 are reported as the
# biased features, column -2 as the biased label and column -1 as the fair label.
print("Number of biased features:", df.shape[1] - 3)  # all columns except the first one and the two label columns
biased_label_counts = df.iloc[:, -2].value_counts().to_dict()
fair_label_counts = df.iloc[:, -1].value_counts().to_dict()
print("Unique biased labels: ", biased_label_counts)
print("Unique fair labels: ", fair_label_counts)

# Flatten the feature block once and reuse it for all value-level statistics.
biased_values = df.iloc[:, 1:-2].values.flatten()
print(f"Number of unique biased features: {len(set(biased_values))}, number of total biased features: {len(biased_values)}")
print("Biggest biased feature value:", biased_values.max())
print("Smallest biased feature value:", biased_values.min())
print(df.head())

In [None]:
# Rows where the biased label (second-to-last column, cast to int) disagrees
# with the fair label (last column).
labels_different = df[(df.iloc[:, -2].astype(int) != df.iloc[:, -1])]

# Sample up to five random rows from the filtered DataFrame. Capping n avoids
# the ValueError that DataFrame.sample raises when fewer than five rows exist.
n_examples = min(5, len(labels_different))
sampled_rows = labels_different.sample(n=n_examples)
print("Sampled rows with different biased and fair labels:")
print(sampled_rows.to_string(index=False))
print(sampled_rows.to_latex(index=False))
# rows = df[(df.iloc[:, -2].astype(int) != df.iloc[:, -1]) & (df.iloc[:, 1].map(df.iloc[:, 1].value_counts()) == 2)]
# print((df.iloc[:, 0].map(df.iloc[:, 0].value_counts()) > 1).sum(), "rows with unique element count in f0 > 1")
# print("Rows with unique element count in f0 > 1 and f4 not equal to y_fair:")
# print(rows.head(5).to_string(index=False))
# print(df.iloc[15:20, :].to_string(index=False))
# print(df.describe().to_latex())

In [None]:
# Generate LaTeX table: one (statistic, value) pair per table row.
biased_values = df.iloc[:, 1:-2].values.flatten()
summary_rows = [
    ("Total samples", df.shape[0]),
    ("Number of biased features", df.shape[1] - 3),
    ("Unique biased labels", df.iloc[:, -2].value_counts().to_dict()),
    ("Unique fair labels", df.iloc[:, -1].value_counts().to_dict()),
    ("Number of unique biased features", len(set(biased_values))),
    ("Total biased feature values", len(biased_values)),
    ("Max biased feature value", biased_values.max()),
    ("Min biased feature value", biased_values.min()),
]

table_header = [
    r"\begin{table}[h]",
    r"\centering",
    r"\begin{tabular}{ll}",
    r"\hline",
    r"\textbf{Statistic} & \textbf{Value} \\",
    r"\hline",
]
# Every data row ends with the LaTeX row terminator " \\".
table_body = [f"{statistic} & {value}" + r" \\" for statistic, value in summary_rows]
table_footer = [
    r"\hline",
    r"\end{tabular}",
    r"\caption{Summary statistics of synthetic dataset}",
    r"\end{table}",
]
latex_table = "\n".join(table_header + table_body + table_footer) + "\n"

# Output the LaTeX table
print(latex_table)

In [None]:
# Plot one histogram per biased feature (columns 1..-3 of df).
# matplotlib.pyplot is already imported as plt in the notebook's import cell.
df.iloc[:, 1:-2].hist(bins=30, figsize=(15, 10))
fig = plt.gcf()  # the figure DataFrame.hist has just created
fig.suptitle("Distribution of Biased Features")
# DataFrame.hist draws a grid of subplots, so plt.xlabel/plt.ylabel would only
# label whichever subplot is current. Figure-level labels apply to the whole grid.
fig.supxlabel("Feature Value")
fig.supylabel("Frequency")
# Leave the top 4% of the figure free for the suptitle.
fig.tight_layout(rect=[0, 0, 1, 0.96])
# show the plot
plt.show()

In [None]:
# Bar charts of the label frequencies: fair labels (last column) first,
# then biased labels (second-to-last column). Each chart gets its own figure.
for column_index, label_kind in ((-1, 'Fair'), (-2, 'Biased')):
    label_counts = df.iloc[:, column_index].value_counts()
    label_counts.plot(kind='bar', figsize=(10, 5), title=f'Distribution of {label_kind} Labels')
    plt.xlabel(f'{label_kind} Label')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.show()

In [None]:
# Toy structural causal model A -> Xb -> Y with additive Gaussian noise,
# used to illustrate observational vs. counterfactual outcomes.

# Parameters
SEED = 42  # fixed seed so that Restart & Run All reproduces the same numbers
n = 1000
mu, sigma = 0, 1
w_A = 2.0
w_Xb = 1.5

# Dedicated, seeded generator instead of the unseeded global NumPy RNG.
rng = np.random.default_rng(SEED)

# Generate noise
eps_Xb = rng.normal(mu, sigma, n)
eps_Y = rng.normal(mu, sigma, n)

# Generate protected attribute A
A = rng.integers(0, 2, size=n)  # A ∈ {0,1}

# Generate features and outcome (observational)
Xb = w_A * A**2 + eps_Xb
Y_cont = w_Xb * Xb**2 + eps_Y
# Binarise the continuous outcome at its observational median.
Y_threshold = np.median(Y_cont)
Y = (Y_cont >= Y_threshold).astype(int)

# Observational data
observational_df = pd.DataFrame({'A': A, 'Xb': Xb, 'Y': Y})

# --- Counterfactual Generation ---

def generate_counterfactual(A_new, eps_Xb, eps_Y):
    """Recompute Xb and Y under an intervened protected attribute.

    The same structural equations as the observational data are applied to
    ``A_new`` while the noise terms are held fixed. The outcome is binarised
    with the observational ``Y_threshold`` (read from the enclosing scope,
    like ``w_A`` and ``w_Xb``).

    Args:
        A_new (np.ndarray): Intervened values of the protected attribute.
        eps_Xb (np.ndarray): Noise term of Xb.
        eps_Y (np.ndarray): Noise term of Y.

    Returns:
        tuple: ``(Xb_cf, Y_cf)``, the counterfactual feature and binary outcome.
    """
    Xb_cf = w_A * A_new**2 + eps_Xb
    Y_cont_cf = w_Xb * Xb_cf**2 + eps_Y
    Y_cf = (Y_cont_cf >= Y_threshold).astype(int)
    return Xb_cf, Y_cf

# Counterfactuals for do(A=0) and do(A=1)
Xb_do0, Y_do0 = generate_counterfactual(np.zeros(n), eps_Xb, eps_Y)
Xb_do1, Y_do1 = generate_counterfactual(np.ones(n), eps_Xb, eps_Y)

counterfactual_df = pd.DataFrame({
    'A_obs': A,
    'Y_obs': Y,
    'Y_do0': Y_do0,
    'Y_do1': Y_do1
})

# ATE here is E[Y | do(A=1)] - E[Y | do(A=0)].
average_treatment_effect = np.mean(Y_do1) - np.mean(Y_do0)
print(f"Average Treatment Effect (ATE): {average_treatment_effect}")
print("Percentage of individuals with Y_do1 == Y_do0:", np.mean(Y_do1 == Y_do0) * 100, "%")


In [None]:
def prepare_data_for_model(Xb, A, Y, split_ratio=0.8):
    """
    Builds the in-context (train) / query (test) tensors for a model call.

    The first ``split_ratio`` fraction of the samples becomes the training
    part, the remainder the test part. No shuffling is performed.

    Args:
        Xb (np.ndarray): Feature array (the notebook passes a 1-D array).
        A (np.ndarray): Protected attribute. Accepted for interface
            compatibility but currently NOT part of the model input.
        Y (np.ndarray): Labels, same length as ``Xb``.
        split_ratio (float): Fraction of samples used as training data.

    Returns:
        tuple: ``(forward_kwargs, test_labels, num_classes)`` where
            ``forward_kwargs`` is a dict with ``train_x``, ``train_y``,
            ``test_x`` (each with two singleton axes appended after the sample
            axis) and ``categorical_inds=None``; ``test_labels`` is the
            squeezed label tensor of the test part; ``num_classes`` is the
            number of distinct values in ``Y``.
    """
    Xb_tensor = torch.tensor(Xb, dtype=torch.float32).unsqueeze(1)
    Y_tensor = torch.tensor(Y, dtype=torch.float32).unsqueeze(1)
    # NOTE(review): the original code concatenated A and Xb into an unused
    # `X_biased` tensor. Possibly the model was meant to receive A as well;
    # confirm before changing the model input. Only Xb is fed here, as before.
    split = int(split_ratio * len(Xb_tensor))
    forward_kwargs = dict(
        train_x=Xb_tensor[:split, :].unsqueeze(1),
        train_y=Y_tensor[:split, :].unsqueeze(1),
        test_x=Xb_tensor[split:, :].unsqueeze(1),
        categorical_inds=None,
    )
    return forward_kwargs, Y_tensor[split:, :].squeeze(), len(torch.unique(Y_tensor))
    

In [None]:
# Load the FairPFN model on CPU. eval_mode=True mirrors how the same checkpoint
# is loaded in the "FairPFN" section further down, so both evaluations run
# under the same settings (assumes eval_mode only switches the model to
# inference behaviour; TODO confirm against FairPFNModel.load_model).
model = FairPFNModel(device='cpu')
model.load_model('models/fairpfn_model_epoch_100.pth', eval_mode=True)
print("Unique count values in Y: ", np.unique(Y, return_counts=True))

# Predict using the FairPFN model
split_ratio = 0.75
forward_kwargs, test_labels, num_classes = prepare_data_for_model(
    Xb=Xb,
    A=A,
    Y=Y,
    split_ratio=split_ratio
)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    pred_fair_logits = model(**forward_kwargs)
# Keep the first num_classes logit columns and flatten to (samples, classes).
pred_fair_logits = pred_fair_logits[:, :, :num_classes]
pred_fair_logits = pred_fair_logits.reshape(-1, pred_fair_logits.shape[-1])
print("Predicted labels:", pred_fair_logits.shape)
accuracy = (pred_fair_logits.argmax(dim=1) == test_labels).float().mean().item()
print(f"FairPFN Model Accuracy: {accuracy * 100:.2f}%")

In [None]:
# Value counts of the observed attribute/outcome and of both counterfactual
# outcomes, followed by a preview of the two frames.
for shown_name, column in (("A", "A_obs"), ("Y_obs", "Y_obs"), ("Y_do0", "Y_do0"), ("Y_do1", "Y_do1")):
    column_counts = counterfactual_df[column].value_counts().to_dict()
    print(f"Unique values in {shown_name} and the counts: ", column_counts)

for frame in (observational_df, counterfactual_df):
    print(frame.head())

In [None]:
# Build a generator (U: exogenous variables, H: MLP depth, M: features,
# N: samples) on the GPU when one is available, otherwise on the CPU.
generator_device = "cuda" if torch.cuda.is_available() else "cpu"
data_generator = DataGenerator(U=16, H=3, M=16, N=10000, device=generator_device)

# Call generate_dataset twice on the same generator and compare the results.
dataset_biased, y_fair = data_generator.generate_dataset()
dataset_biased_2, y_fair_2 = data_generator.generate_dataset()

# check if the two torch tensors are equal (nothing is printed when they differ)
if torch.equal(dataset_biased, dataset_biased_2):
    if torch.equal(y_fair, y_fair_2):
        print("The two datasets are equal.")
        print("The two y_fair tensors are equal.")


In [None]:
# Results of data_generator.do_A for A=0 and A=1, in that order.
do_A0, do_A1 = (data_generator.do_A(A=value) for value in (0, 1))

# Only report when both calls returned identical tensors.
interventions_identical = torch.equal(do_A0, do_A1)
if interventions_identical:
    print("The two do_A tensors are equal.")

In [None]:
# Distinct values (with counts) in the last column of do_A0.
last_column_counts = torch.unique(do_A0[:, -1], return_counts=True)
print("Unique count values in do_A0:", last_column_counts)

# Difference of the last-column means: do_A0 minus do_A1.
mean_last_A0 = do_A0[:, -1].float().mean()
mean_last_A1 = do_A1[:, -1].float().mean()
base_causal_effect = mean_last_A0 - mean_last_A1
print(f"Base causal effect: {base_causal_effect}")

In [None]:
# Compare the model's predictions on the two interventional datasets.
# The first 75% of the rows serve as in-context examples, the rest is predicted.
split = int(0.75 * len(do_A0))

# The number of classes is taken from the do_A0 in-context labels and reused
# for do_A1 (as in the original cell).
num_classes = len(torch.unique(do_A0[:split, -1]))

predicted_labels = {}
with torch.no_grad():  # inference only: skip building the autograd graph
    for intervention, do_data in (("A0", do_A0), ("A1", do_A1)):
        incontext_biased_features = do_data[:split, :-1].unsqueeze(1)
        incontext_biased_labels = do_data[:split, -1].unsqueeze(1)
        val_biased_features = do_data[split:, :-1].unsqueeze(1)
        forward_kwargs = dict(
            train_x=incontext_biased_features,
            train_y=incontext_biased_labels,
            test_x=val_biased_features,
            categorical_inds=None,
        )
        pred_fair_logits = model(**forward_kwargs)
        # Keep the first num_classes logit columns and flatten to (samples, classes).
        pred_fair_logits = pred_fair_logits[:, :, :num_classes]
        pred_fair_logits = pred_fair_logits.reshape(-1, pred_fair_logits.shape[-1])
        predicted_labels[intervention] = pred_fair_logits.argmax(dim=1)

predicted_labels_A0 = predicted_labels["A0"]
predicted_labels_A1 = predicted_labels["A1"]

# Same sign convention as base_causal_effect (do_A0 minus do_A1), so the two
# numbers are directly comparable. The original computed A1 - A0 here.
average_treatment_effect = predicted_labels_A0.float().mean() - predicted_labels_A1.float().mean()
print(f"Average Treatment Effect (ATE) from model predictions: {average_treatment_effect}")

## Generate datasets whose absolute base causal effect exceeds 0.3

In [None]:
# Keep drawing fresh generators until five datasets whose absolute base causal
# effect exceeds 0.3 have been passed to save_generated_data.
# NOTE(review): every accepted dataset is saved under the same three filenames;
# presumably save_generated_data appends to an existing file (TODO confirm).
generator_device = "cuda" if torch.cuda.is_available() else "cpu"
datasets_counter = 0
while datasets_counter < 5:
    # U: exogenous variables, H: MLP depth, M: features, N: samples
    data_generator = DataGenerator(U=16, H=3, M=16, N=256, device=generator_device)
    dataset_biased, y_fair = data_generator.generate_dataset()
    do_A0 = data_generator.do_A(A=0)
    do_A1 = data_generator.do_A(A=1)
    base_causal_effect = do_A0[:, -1].float().mean() - do_A1[:, -1].float().mean()
    if not abs(base_causal_effect) > 0.3:
        continue  # effect too small: discard and draw a new generator

    print(f"Generated dataset with base causal effect: {base_causal_effect}")
    # Save the observational data and both interventional versions,
    # each together with the same y_fair.
    for file_tag, tensor_to_save in (
        ("observational", dataset_biased),
        ("do_A0", do_A0),
        ("do_A1", do_A1),
    ):
        save_generated_data(
            Dbias=tensor_to_save,
            y_fair=y_fair,
            filename=f"data/generated_data_{file_tag}.parquet",
        )
    datasets_counter += 1
print(f"Generated {datasets_counter} datasets with base causal effect greater than 0.3.")


### Load datasets

In [None]:
# DATA_DIR = Path("data")
# DATASET_OBSERVATIONAL = Path(DATA_DIR, "generated_data_observational.parquet")
# if not os.path.exists(DATASET_OBSERVATIONAL):
#     raise FileNotFoundError(f"File {DATASET_OBSERVATIONAL} does not exist.")
# df_observational = pd.read_parquet(DATASET_OBSERVATIONAL)
# print(f"Loaded {len(df_observational)} samples from {DATASET_OBSERVATIONAL}.")

# DATASET_DO_A0 = Path(DATA_DIR, "generated_data_do_A0.parquet")
# if not os.path.exists(DATASET_DO_A0):
#     raise FileNotFoundError(f"File {DATASET_DO_A0} does not exist.")
# df_do_A0 = pd.read_parquet(DATASET_DO_A0)
# print(f"Loaded {len(df_do_A0)} samples from {DATASET_DO_A0}.")

# DATASET_DO_A1 = Path(DATA_DIR, "generated_data_do_A1.parquet")
# if not os.path.exists(DATASET_DO_A1):
#     raise FileNotFoundError(f"File {DATASET_DO_A1} does not exist.")
# df_do_A1 = pd.read_parquet(DATASET_DO_A1)
# print(f"Loaded {len(df_do_A1)} samples from {DATASET_DO_A1}.")

In [None]:
DATA_DIR = Path("data")
DATASET_DO_A0 = Path(DATA_DIR, "generated_data_do_A0.parquet")
DATASET_DO_A1 = Path(DATA_DIR, "generated_data_do_A1.parquet")

datasets_do_A0 = SyntheticDataset(
    filename=DATASET_DO_A0,
)
datasets_do_A1 = SyntheticDataset(
    filename=DATASET_DO_A1,
)

# Mean of the last column of every stored dataset, per intervention.
# .item() converts each 0-d tensor to a Python float, matching the FairPFN and
# TabPFN cells below, and keeps np.array() working when the tensors live on a
# GPU or carry gradients.
expectations = {
    'A0': [dataset[:, -1].float().mean().item() for dataset, _ in datasets_do_A0],
    'A1': [dataset[:, -1].float().mean().item() for dataset, _ in datasets_do_A1],
}

# Per-dataset base effect, do_A0 minus do_A1.
average_treatment_effect_base = np.array(expectations['A0']) - np.array(expectations['A1'])

In [None]:
# Show the per-dataset base ATE values, then store them as a .npy file in
# DATA_DIR (creating the directory first when it is missing).
print(f"Base Average Treatment Effect (ATE): {average_treatment_effect_base}")
output_file = DATA_DIR / "average_treatment_effect_base.npy"
if not DATA_DIR.exists():
    DATA_DIR.mkdir(parents=True)
np.save(output_file, average_treatment_effect_base)

### FairPFN

In [None]:
# Fresh FairPFN instance (default constructor arguments) with the epoch-100
# checkpoint loaded via load_model(..., eval_mode=True).
fairpfn_checkpoint = 'models/fairpfn_model_epoch_100.pth'
model = FairPFNModel()
model.load_model(fairpfn_checkpoint, eval_mode=True)
print("Model loaded successfully.")

In [None]:
datasets_do_A0 = SyntheticDataset(
    filename=DATASET_DO_A0,
)
datasets_do_A1 = SyntheticDataset(
    filename=DATASET_DO_A1,
)

# For every stored dataset: use the first 75% of the rows as in-context
# examples, predict the remaining rows and record the mean predicted label.
# Both interventions go through the same code path instead of two copies.
expectations = {'A0': [], 'A1': []}
with torch.no_grad():  # inference only: skip building the autograd graph
    for intervention, datasets in (('A0', datasets_do_A0), ('A1', datasets_do_A1)):
        for dataset, _ in datasets:
            split = int(0.75 * len(dataset))
            forward_kwargs = dict(
                train_x=dataset[:split, :-1].unsqueeze(1),
                train_y=dataset[:split, -1].unsqueeze(1),
                test_x=dataset[split:, :-1].unsqueeze(1),
                categorical_inds=None,
            )
            # Number of distinct in-context labels; only that many leading
            # logit columns are kept before taking the argmax.
            num_classes = len(torch.unique(forward_kwargs['train_y']))
            pred_fair_logits = model(**forward_kwargs)
            pred_fair_logits = pred_fair_logits[:, :, :num_classes]
            pred_fair_logits = pred_fair_logits.reshape(-1, pred_fair_logits.shape[-1])
            predicted_labels = pred_fair_logits.argmax(dim=1)
            expectations[intervention].append(predicted_labels.float().mean().item())

# Per-dataset effect, do_A0 minus do_A1 (same convention as the base ATE).
average_treatment_effect_fairpfn = np.array(expectations['A0']) - np.array(expectations['A1'])



In [None]:
# Report the FairPFN ATE values and how many of them are below 0.3 in magnitude.
print(f"Average Treatment Effect (ATE) of FairPFN: {average_treatment_effect_fairpfn}")
fairpfn_small_effect_mask = np.abs(average_treatment_effect_fairpfn) < 0.3
print("Number of ate < 0.3:", fairpfn_small_effect_mask.sum())

In [None]:
# Store the FairPFN ATE values as a .npy file in DATA_DIR, creating the
# directory first when it is missing.
output_file = DATA_DIR / "average_treatment_effect_fairpfn.npy"
if not DATA_DIR.exists():
    DATA_DIR.mkdir(parents=True)
np.save(output_file, average_treatment_effect_fairpfn)

### TabPFN

In [None]:
# NOTE(review): the "TabPFN" baseline is a FairPFNModel constructed without any
# load_model(...) call. Presumably the constructor falls back to pretrained
# TabPFN weights; TODO confirm. Unlike the FairPFN section, no eval_mode is
# requested here.
model = FairPFNModel()

In [None]:
datasets_do_A0 = SyntheticDataset(
    filename=DATASET_DO_A0,
)
datasets_do_A1 = SyntheticDataset(
    filename=DATASET_DO_A1,
)

# Same evaluation protocol as for FairPFN: for every stored dataset, use the
# first 75% of the rows as in-context examples, predict the remaining rows and
# record the mean predicted label. One code path serves both interventions.
expectations = {'A0': [], 'A1': []}
with torch.no_grad():  # inference only: skip building the autograd graph
    for intervention, datasets in (('A0', datasets_do_A0), ('A1', datasets_do_A1)):
        for dataset, _ in datasets:
            split = int(0.75 * len(dataset))
            forward_kwargs = dict(
                train_x=dataset[:split, :-1].unsqueeze(1),
                train_y=dataset[:split, -1].unsqueeze(1),
                test_x=dataset[split:, :-1].unsqueeze(1),
                categorical_inds=None,
            )
            # Number of distinct in-context labels; only that many leading
            # logit columns are kept before taking the argmax.
            num_classes = len(torch.unique(forward_kwargs['train_y']))
            pred_fair_logits = model(**forward_kwargs)
            pred_fair_logits = pred_fair_logits[:, :, :num_classes]
            pred_fair_logits = pred_fair_logits.reshape(-1, pred_fair_logits.shape[-1])
            predicted_labels = pred_fair_logits.argmax(dim=1)
            expectations[intervention].append(predicted_labels.float().mean().item())

# Per-dataset effect, do_A0 minus do_A1 (same convention as the base ATE).
average_treatment_effect_tabpfn = np.array(expectations['A0']) - np.array(expectations['A1'])

In [None]:
# Report the TabPFN ATE values and how many of them are below 0.3 in magnitude.
print(f"Average Treatment Effect (ATE) of TabPFN: {average_treatment_effect_tabpfn}")
tabpfn_small_effect_mask = np.abs(average_treatment_effect_tabpfn) < 0.3
print("Number of ate < 0.3:", tabpfn_small_effect_mask.sum())

In [None]:
# Store the TabPFN ATE values as a .npy file in DATA_DIR, creating the
# directory first when it is missing.
output_file = DATA_DIR / "average_treatment_effect_tabpfn.npy"
if not DATA_DIR.exists():
    DATA_DIR.mkdir(parents=True)
np.save(output_file, average_treatment_effect_tabpfn)

# Load the generated ATEs

In [None]:
# Reload the three saved ATE arrays from DATA_DIR so the plotting cells can
# run without recomputing them.
DATA_DIR = Path("data")
average_treatment_effect_base = np.load(DATA_DIR / "average_treatment_effect_base.npy")
average_treatment_effect_fairpfn = np.load(DATA_DIR / "average_treatment_effect_fairpfn.npy")
average_treatment_effect_tabpfn = np.load(DATA_DIR / "average_treatment_effect_tabpfn.npy")

### Box and whiskers plots

In [None]:
# Box-and-whiskers plot comparing the three ATE distributions.
ate_by_model = {
    'Base ATE': average_treatment_effect_base,
    'FairPFN ATE': average_treatment_effect_fairpfn,
    'TabPFN ATE': average_treatment_effect_tabpfn,
}
plt.figure(figsize=(10, 6))
plt.boxplot(list(ate_by_model.values()))
# Name the boxes through the x ticks instead of boxplot(labels=...): that
# keyword is deprecated since Matplotlib 3.9 (renamed to tick_labels), while
# xticks works on old and new versions alike. Boxes sit at positions 1..N.
plt.xticks(range(1, len(ate_by_model) + 1), list(ate_by_model.keys()))
plt.title('Box and Whiskers Plot of Average Treatment Effects')
plt.xlabel('Model')
plt.ylabel('Average Treatment Effect (ATE)')
plt.tight_layout()
plt.show()

### KDE plots

In [None]:
# Overlay one filled kernel density estimate per ATE array on a single figure.
plt.figure(figsize=(10, 6))

for ate_values, legend_label in (
    (average_treatment_effect_base, 'Base ATE'),
    (average_treatment_effect_fairpfn, 'FairPFN ATE'),
    (average_treatment_effect_tabpfn, 'TabPFN ATE'),
):
    sns.kdeplot(ate_values, label=legend_label, fill=True, alpha=0.5)

plt.title('Kernel Density Estimation of Average Treatment Effects')
plt.xlabel('Average Treatment Effect (ATE)')
plt.ylabel('Density')
plt.legend()
plt.tight_layout()
plt.show()

