## Unifying Causal Representation Learning with the Invariance Principle [[arXiv](https://www.arxiv.org/abs/2409.02772)]

### Experiments

In [None]:
import matplotlib.pyplot as plt
import sys
import numpy as np
import pandas as pd
sys.path.append('../src')
import warnings
warnings.filterwarnings("ignore")

from data import PPCI

In [None]:
# Universal CRL experiment: train and evaluate one model per
# (invariance weight, seed) pair and collect the metrics in a single table.
exp = "pos14"
encoder = "dino"

# Sweep grid: 0 followed by 16 log-spaced weights between 1e-1 and 1e4.
ic_weights = [0, *np.logspace(-1, 4, num=16)]
seeds = list(range(1, 21))

dataset = PPCI(
    encoder=encoder,
    token="class",
    task="or",
    split_criteria="position",
    environment="supervised",
    batch_size=64,
    num_proc=4,
    verbose=False,
    data_dir='data/istant_hq',
    results_dir=f'results/istant_hq/{encoder}',
)

# One row per run; the column set is fixed up front.
metric_columns = [
    'ic_weight', 'seed',
    "inv_loss_val", "loss_val", "acc_val", "bacc_val", "TEB_val",
    "acc", "bacc", "TEB", "TEB_bin", "EAD",
    "best_epoch",
]
all_metrics = pd.DataFrame(columns=metric_columns)

i = 0  # label of the next row appended to all_metrics
for ic_weight in ic_weights:
    print(f"IC weight: {ic_weight}")
    for seed in seeds:
        print(f"Seed: {seed}")
        dataset.train(
            add_pred_env="supervised",
            hidden_layers=1,
            hidden_nodes=256,
            batch_size=128,
            lr=0.0005,
            num_epochs=15,
            verbose=False,
            multidomain=True,
            ic_weight=ic_weight,
            seed=seed,
        )
        # Tag the evaluation output with the run configuration before storing it.
        run_metrics = dataset.evaluate(color=None, verbose=False)
        run_metrics['ic_weight'] = ic_weight
        run_metrics['seed'] = seed
        run_metrics['best_epoch'] = dataset.model.best_epoch
        all_metrics.loc[i] = run_metrics
        i += 1

# TERB is |TEB| expressed as a percentage of EAD.
all_metrics['TERB'] = all_metrics['TEB'].abs() / all_metrics['EAD'] * 100
all_metrics.to_csv(f'results/istant_hq/{encoder}/invariance_{exp}.csv', index=False)
all_metrics

In [None]:
# Load the per-run results of the invariance sweep and plot, against the
# invariance weight, TERB (left axis) and balanced accuracy (right axis),
# each as mean +/- one standard deviation over seeds.
exp = "pos14"
encoder = "dino"

all_metrics = pd.read_csv(f"results/istant_hq/{encoder}/invariance_{exp}.csv")
all_metrics["univ_loss_val"] = all_metrics["loss_val"] + all_metrics["inv_loss_val"]
all_metrics["TEAB_val"] = abs(all_metrics["TEB_val"])

# A weight of 0 cannot be drawn on a log axis: move it to a placeholder
# position that the x tick labels below rename to "0 (ERM)".
erm_placeholder = 0.01
all_metrics['ic_weight'] = all_metrics['ic_weight'].replace(0, erm_placeholder)

# Aggregate once. The x positions are taken from the groupby index so that
# they always line up with the aggregated values, whatever the row order of
# the CSV (unique() follows appearance order, groupby sorts its keys).
by_weight = all_metrics.groupby('ic_weight')
terb_mean = by_weight['TERB'].mean()
terb_std = by_weight['TERB'].std()
bacc_mean = by_weight['bacc'].mean()
bacc_std = by_weight['bacc'].std()
weights = terb_mean.index.to_numpy()

# Left axis: TERB vs ic_weight, averaged over seeds.
plt.figure()
plt.xlabel(r"$\lambda_{INV}$")
plt.ylabel("TERB (%)", color='tab:blue')
plt.xscale('log')
plt.errorbar(weights, terb_mean, yerr=terb_std, fmt='o', color='tab:blue')
plt.plot(weights, terb_mean, '--', color='tab:blue')
plt.fill_between(weights, terb_mean - terb_std, terb_mean + terb_std, alpha=0.2, color='skyblue')

# Model-selection markers. For each criterion (a `_val` column), take the
# weight with the lowest mean over seeds, then the single run with the lowest
# value at that weight, and mark that run's TERB with a star.
selection_criteria = [
    ("loss_val", "orange", "Min ERM"),
    ("inv_loss_val", "purple", "Min Invariance"),
    ("TEAB_val", "green", "Min TERB"),
]
for column, star_color, label in selection_criteria:
    best_weight = by_weight[column].mean().idxmin()
    best_run = all_metrics.loc[all_metrics['ic_weight'] == best_weight, column].idxmin()
    # Marker-only format string: the colour is given once, via `color=`.
    plt.plot(all_metrics.loc[best_run, 'ic_weight'], all_metrics.loc[best_run, 'TERB'], '*',
             markersize=12, alpha=1, color=star_color, label=label, zorder=10, clip_on=False,
             markeredgecolor='tab:blue', markeredgewidth=1)

plt.ylim(0, 140)
plt.legend(loc='upper left', framealpha=1, title="Model Selection Criteria",
           title_fontsize="8.5", fontsize="8", alignment="left")
plt.xticks([erm_placeholder, 0.1, 1, 10, 100, 1000, 10000], ["0\n(ERM)", 0.1, 1, 10, 100, 1000, 10000])

# Right axis: balanced accuracy vs ic_weight, averaged over seeds.
plt.twinx()
plt.ylabel("Balanced Accuracy", color='tab:red')
plt.errorbar(weights, bacc_mean, yerr=bacc_std, fmt='o', color='tab:red', label="Accuracy")
plt.plot(weights, bacc_mean, '--', color='tab:red', label="Balanced Accuracy")
plt.fill_between(weights, bacc_mean - bacc_std, bacc_mean + bacc_std, alpha=0.2, color='pink')
plt.ylim(0.45, 1)
# Vertical separator at 10**-1.5, the log-scale midpoint between the ERM
# placeholder (0.01) and the smallest non-zero tick (0.1).
plt.axvline(x=0.031622776601683794, color='black', linestyle='--', label=r"separation", alpha=0.2)

# save the plot
plt.savefig(f"results/istant_hq/{encoder}/invariance_{exp}.pdf", bbox_inches='tight')

In [None]:
# Plot per-epoch learning curves (one subplot per metric, one line per
# ic_weight) for the training and the validation histories.
# NOTE(review): `train_metrics` and `val_metrics` are not defined in the cells
# visible here; they must be populated before this cell is run.

# convert the histories to numpy arrays
train_metrics = np.squeeze(np.array(train_metrics))  # ic x epochs x metrics
val_metrics = np.squeeze(np.array(val_metrics))  # ic x epochs x metrics

# NOTE(review): assumes the last axis is ordered as in `metrics` - confirm
# against the code that records the histories.
metrics = ['accuracy', 'precision', 'recall']
colors = ['red', 'green', 'blue', 'purple', 'orange']
if len(ic_weights) > len(colors):
    # The fixed palette only covers five weights; indexing it with a longer
    # sweep raises IndexError, so sample one distinct colour per weight instead.
    palette = plt.get_cmap('viridis')
    colors = [palette(position) for position in np.linspace(0, 1, len(ic_weights))]

# Same figure layout for both histories.
for curves, title in ((train_metrics, 'Training'), (val_metrics, 'Validation')):
    plt.figure(figsize=(15, 5))
    for i, metric in enumerate(metrics):
        plt.subplot(1, len(metrics), i + 1)
        for j, ic_weight in enumerate(ic_weights):
            plt.plot(curves[j, :, i], color=colors[j], label=f'ic_weight={ic_weight}')
        plt.xlabel('epochs')
        plt.ylabel(metric)
        plt.legend()
    plt.tight_layout()
    plt.suptitle(title, fontsize=16, y=1.05)
    plt.show()