## Unifying Causal Representation Learning with the Invariance Principle [[arXiv](https://www.arxiv.org/abs/2409.02772)]

### Experiments

In [None]:
import os
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# NOTE(review): this silences every warning (including library deprecation
# warnings); consider narrowing it to the specific noisy categories.
warnings.filterwarnings("ignore")

# Make the project's `src` directory importable. The entry is made absolute
# because the working directory changes right below: a relative '../src' entry
# could otherwise be resolved against the new directory (e.g. by worker
# processes started later, if any) and point to the wrong place.
sys.path.append(os.path.abspath('../src'))
from data import PPCI  # project-local module; must be imported after the path setup

# Move one directory up so the relative 'data/...' and 'results/...' paths used in
# the following cells resolve from there (presumably the repository root).
os.chdir('../')

In [None]:
# Experiment: Universal CRL. For every invariance weight (lambda_INV) and seed,
# train a model on the PPCI dataset, keep its per-epoch train/validation metrics
# and its final evaluation metrics, then save everything under
# results/istant_hq/<encoder>/<split_criteria>/.
encoder = "dino"
split_criteria = "position"

# Invariance weights: 0 (plain ERM) followed by 16 log-spaced values from 1e-1 to 1e16.
# NOTE(review): these steps are not whole decades (17 decades over 15 intervals),
# while the plotting cell's x-ticks assume 0.1, 1, ..., 1e4 — confirm this grid
# is intended.
ic_weights = [0] + list(np.logspace(-1, 16, num=16))
seeds = list(range(1, 21))  # seeds 1..20
dataset = PPCI(encoder=encoder,
               token="class",
               task="or",
               split_criteria=split_criteria,
               environment="supervised",
               batch_size=64,
               num_proc=4,
               verbose=True,
               data_dir='data/istant_hq',
               results_dir=f'results/istant_hq/{encoder}')

# One row per (ic_weight, seed) run.
metric_columns = ['ic_weight', 'seed', "inv_loss_val", "loss_val", "acc_val", "bacc_val",
                  "TEB_val", "acc", "bacc", "TEB", "TEB_bin", "EAD", "best_epoch"]
all_metrics = pd.DataFrame(columns=metric_columns)
num_epochs = 15
# Per-epoch curves indexed as [ic_weight, seed, epoch, metric]. The last axis has
# 4 entries, which the convergence plots label accuracy, balanced accuracy,
# precision, recall (TODO confirm this order against the model code).
train_metrics = np.zeros((len(ic_weights), len(seeds), num_epochs, 4))
val_metrics = np.zeros((len(ic_weights), len(seeds), num_epochs, 4))
for j, ic_weight in enumerate(ic_weights):
    print(f"IC weight: {ic_weight}")
    for k, seed in enumerate(seeds):
        print(f"Seed: {seed}")
        dataset.train(add_pred_env="supervised",
                      hidden_layers=1,
                      hidden_nodes=256,
                      batch_size=128,
                      lr=0.0005,
                      num_epochs=num_epochs,
                      verbose=False,
                      multidomain=True,
                      ic_weight=ic_weight,
                      seed=seed)
        # Assumes the model logs one entry per epoch for all num_epochs epochs;
        # otherwise these assignments fail with a shape mismatch.
        train_metrics[j, k] = np.array(dataset.model.train_metrics).squeeze()
        val_metrics[j, k] = np.array(dataset.model.val_metrics).squeeze()
        # Presumably a dict-like of metric name -> value; completed with the run's identifiers.
        run_metrics = dataset.evaluate(color=None, verbose=False)
        run_metrics['ic_weight'] = ic_weight
        run_metrics['seed'] = seed
        run_metrics['best_epoch'] = dataset.model.best_epoch
        # Append as the next row (setting with enlargement).
        all_metrics.loc[len(all_metrics)] = run_metrics

# TERB (%) = |TEB| / EAD * 100, computed per run.
all_metrics['TERB'] = all_metrics['TEB'].abs() / all_metrics['EAD'] * 100
results_dir = f'results/istant_hq/{encoder}/{split_criteria}'
os.makedirs(results_dir, exist_ok=True)
all_metrics.to_csv(f'{results_dir}/invariance.csv', index=False)
np.save(f'{results_dir}/train_metrics.npy', train_metrics)
np.save(f'{results_dir}/val_metrics.npy', val_metrics)
all_metrics

In [None]:
# Load the sweep results and plot TERB and balanced accuracy (mean +/- std over
# seeds) against the invariance weight lambda_INV. Also mark the runs chosen by
# three model-selection criteria computed from the `_val` columns.
encoder = "dino"
split_criteria = "position"
results_dir = f'results/istant_hq/{encoder}/{split_criteria}'

all_metrics = pd.read_csv(f"{results_dir}/invariance.csv")
all_metrics["univ_loss_val"] = all_metrics["loss_val"] + all_metrics["inv_loss_val"]
all_metrics["TEAB_val"] = all_metrics["TEB_val"].abs()

# lambda_INV = 0 (ERM) has no position on a log axis, so it is drawn at this
# stand-in x (assumes no real weight equals this value).
ERM_X = 0.01


def _plot_mean_std(df, column, color, band_color, **errorbar_kwargs):
    """Plot the per-ic_weight mean of `column` with std error bars, a dashed line and a std band.

    The x positions are taken from the groupby index itself, so every mean/std is
    drawn at its own ic_weight regardless of the row order in `df`.
    """
    stats = df.groupby('ic_weight')[column].agg(['mean', 'std'])
    x = stats.index.to_numpy()
    mean, std = stats['mean'], stats['std']
    plt.errorbar(x, mean, yerr=std, fmt='o', color=color, **errorbar_kwargs)
    plt.plot(x, mean, '--', color=color)
    plt.fill_between(x, mean - std, mean + std, alpha=0.2, color=band_color)


def _selected_run(df, criterion):
    """Return the row label of the run picked by `criterion` (lower is better).

    First pick the ic_weight with the lowest mean `criterion` over seeds, then the
    single seed with the lowest `criterion` at that weight.
    """
    best_weight = df.groupby('ic_weight')[criterion].mean().idxmin()
    return df.loc[df['ic_weight'] == best_weight, criterion].idxmin()


# TERB vs ic_weight, averaged over seeds.
plt.figure()
plt.xlabel(r"$\lambda_{INV}$")
plt.ylabel("TERB (%)", color='tab:blue')
plt.xscale('log')
all_metrics['ic_weight'] = all_metrics['ic_weight'].replace(0, ERM_X)
_plot_mean_std(all_metrics, 'TERB', color='tab:blue', band_color='skyblue')

# (criterion column, star color, legend label).
# NOTE(review): "Min TERB" actually minimizes |TEB_val| (TEAB_val), not a validation TERB.
selection_criteria = [
    ("loss_val", "orange", "Min ERM"),
    ("inv_loss_val", "purple", "Min Invariance"),
    ("TEAB_val", "green", "Min TERB"),
]
for criterion, star_color, label in selection_criteria:
    run = all_metrics.loc[_selected_run(all_metrics, criterion)]
    # Plain '*' marker: a color letter in the format string (e.g. 'y*') would
    # conflict with the explicit color= argument.
    plt.plot(run['ic_weight'], run['TERB'], '*', markersize=12, alpha=1, color=star_color,
             label=label, zorder=10, clip_on=False, markeredgecolor='tab:blue',
             markeredgewidth=1)

plt.ylim(0, 140)
plt.legend(loc='upper left', framealpha=1, title="Model Selection Criteria",
           title_fontsize="8.5", fontsize="8", alignment="left")
plt.xticks([ERM_X, 0.1, 1, 10, 100, 1000, 10000], ["0\n(ERM)", 0.1, 1, 10, 100, 1000, 10000])

# Second y-axis: balanced accuracy on the same lambda_INV axis.
plt.twinx()
plt.ylabel("Balanced Accuracy", color='tab:red')
_plot_mean_std(all_metrics, 'bacc', color='tab:red', band_color='pink', label="Balanced Accuracy")
plt.ylim(0.45, 1)
# Faint divider at the log-scale midpoint between the ERM stand-in and lambda_INV = 0.1.
plt.axvline(x=np.sqrt(ERM_X * 0.1), color='black', linestyle='--', label="separation", alpha=0.2)

# save the plot
plt.savefig(f'{results_dir}/invariance.pdf', bbox_inches='tight')

### Visualization of the training convergence

In [None]:
# Plot per-epoch training and validation curves (mean +/- std over seeds) for
# every invariance weight, using the arrays saved by the training cell.
encoder = "dino"
split_criteria = "position"
# Must match the grid used for training: one entry per first-axis slice of the saved arrays.
ic_weights = [0] + list(np.logspace(-1, 16, num=16))

results_dir = f'results/istant_hq/{encoder}/{split_criteria}'
train_metrics = np.load(f'{results_dir}/train_metrics.npy')  # ic x seed x epochs x metrics
val_metrics = np.load(f'{results_dir}/val_metrics.npy')  # ic x seed x epochs x metrics
if train_metrics.shape[0] != len(ic_weights):
    raise ValueError(
        f"ic_weights has {len(ic_weights)} entries but the saved metrics cover "
        f"{train_metrics.shape[0]}; keep this list in sync with the training cell."
    )
# Take the epoch count from the saved arrays, so this cell also runs after a
# kernel restart instead of relying on a `num_epochs` global from another cell.
num_epochs = train_metrics.shape[2]

# Two figures (Training, Validation), each with 4 panels: accuracy, balanced
# accuracy, precision, recall vs epochs, one curve per ic_weight.
metrics = ['accuracy', 'balanced accuracy', 'precision', 'recall']
# One distinct color per ic_weight. A fixed short color list would raise
# IndexError as soon as there are more weights than colors.
colors = plt.cm.viridis(np.linspace(0, 1, len(ic_weights)))
epochs = np.arange(num_epochs)

for title, curves in [('Training', train_metrics), ('Validation', val_metrics)]:
    plt.figure(figsize=(15, 5))
    for m, metric in enumerate(metrics):
        plt.subplot(1, len(metrics), m + 1)
        for j, ic_weight in enumerate(ic_weights):
            mean = curves[j, :, :, m].mean(axis=0)  # average over seeds
            std = curves[j, :, :, m].std(axis=0)
            plt.errorbar(epochs, mean, yerr=std, color=colors[j], label=f'{ic_weight:.3g}')
            plt.fill_between(epochs, mean - std, mean + std, color=colors[j], alpha=0.2)
        plt.xlabel('epochs')
        plt.ylabel(metric)
    plt.tight_layout()
    plt.suptitle(title, fontsize=16, y=1.05)
    # Shared legend, anchored below the panels via the last subplot.
    plt.legend(loc='upper right', bbox_to_anchor=(0, -0.1), fancybox=True, ncol=len(ic_weights),
               title=r"$\lambda_{INV}$", title_fontsize="12", fontsize="12")
    plt.show()