In [None]:
import matplotlib.pyplot as plt
import sys
import numpy as np
import pandas as pd
sys.path.append('./src')
import warnings
warnings.filterwarnings("ignore")

from data import PPCI

In [None]:
# Build the PPCI dataset: encoder "dino" with its "class" token, task "or",
# the "position_easy" split, and the supervised environment.
encoder = "dino"
ppci_config = dict(
    encoder=encoder,
    token="class",
    task="or",
    split_criteria="position_easy",
    environment="supervised",
    batch_size=64,
    num_proc=4,
    verbose=True,
    data_dir="data/istant_hq",
    results_dir=f"results/istant_hq/{encoder}",
)
dataset = PPCI(**ppci_config)

In [None]:
# Single example run: train without the invariance term (ic_weight=0, seed 1),
# then evaluate; the value returned by `evaluate` is shown as the cell output.
train_config = {
    "add_pred_env": "supervised",
    "hidden_layers": 1,
    "hidden_nodes": 256,
    "batch_size": 128,
    "lr": 0.0005,
    "num_epochs": 15,
    "save": False,
    "verbose": True,
    "multidomain": True,
    "ic_weight": 0,
    "seed": 1,
}
dataset.train(**train_config)
dataset.evaluate(color=None, verbose=False)

In [None]:
# Experiment (Universal CRL): sweep the invariance weight λ_INV (`ic_weight`) over
# several seeds, evaluate every run, and save one row of metrics per (ic_weight, seed).
exp = "pos14"
encoder = "dino"

# λ_INV = 0 (plain ERM) followed by 16 log-spaced weights from 1e-1 to 1e4.
ic_weights = [0] + list(np.logspace(-1, 4, num=16))
seeds = list(range(1, 21))
dataset = PPCI(encoder = encoder,
               token = "class",
               task = "or",
               split_criteria = "position",
               environment = "supervised",
               batch_size = 64,
               num_proc = 4,
               verbose = False,
               data_dir = 'data/istant_hq',
               results_dir = f'results/istant_hq/{encoder}')

# Columns (and their order) written to the results CSV; TERB is appended after the sweep.
METRIC_COLUMNS = ['ic_weight', 'seed', "inv_loss_val", "loss_val", "acc_val", "bacc_val",
                  "TEB_val", "acc", "bacc", "TEB", "TEB_bin", "EAD", "best_epoch"]

# Collect one dict per run and build the DataFrame once at the end: growing a frame
# with `.loc[i] = ...` is quadratic and, starting from an empty frame, leaves every
# column with object dtype.
records = []
for ic_weight in ic_weights:
    print(f"IC weight: {ic_weight}")
    for seed in seeds:
        print(f"Seed: {seed}")
        dataset.train(add_pred_env="supervised",
                      hidden_layers = 1,
                      hidden_nodes = 256,
                      batch_size = 128,
                      lr = 0.0005,
                      num_epochs=15,
                      verbose=False,
                      multidomain=True,
                      ic_weight=ic_weight,
                      seed=seed)
        # `evaluate` returns a mapping of metric name -> value (it is indexed by key above);
        # copy it so the run bookkeeping below does not mutate the returned object.
        run_metrics = dict(dataset.evaluate(color=None, verbose=False))
        run_metrics['ic_weight'] = ic_weight
        run_metrics['seed'] = seed
        run_metrics['best_epoch'] = dataset.model.best_epoch
        records.append(run_metrics)

# `columns=` selects/orders the keys exactly as the original fixed column list did.
all_metrics = pd.DataFrame(records, columns=METRIC_COLUMNS)
# TERB in percent: |TEB| relative to EAD.
all_metrics['TERB'] = all_metrics['TEB'].abs() / all_metrics['EAD'] * 100
all_metrics.to_csv(f'results/istant_hq/{encoder}/invariance_{exp}.csv', index=False)
all_metrics

In [None]:
# Load the saved λ_INV sweep and plot, against λ_INV (averaged over seeds):
#   left axis  - TERB (%) with mean ± std,
#   right axis - balanced accuracy with mean ± std,
# highlighting the run chosen by each of three validation-based model-selection criteria.
exp = "pos14"
encoder = "dino"

# λ_INV = 0 cannot be placed on a log axis, so ERM runs are drawn at ERM_X and the tick
# is relabelled "0 (ERM)".
ERM_X = 0.01
# Faint vertical separator at the log-midpoint of ERM_X (1e-2) and the first positive
# weight (1e-1), i.e. 10**-1.5.
SEPARATOR_X = 10 ** -1.5

all_metrics = pd.read_csv(f"results/istant_hq/{encoder}/invariance_{exp}.csv")
all_metrics["univ_loss_val"] = all_metrics["loss_val"] + all_metrics["inv_loss_val"]
all_metrics["TEAB_val"] = all_metrics["TEB_val"].abs()

# Plotting copy with λ_INV = 0 moved to ERM_X; `all_metrics` keeps the true weights.
plot_df = all_metrics.assign(ic_weight=all_metrics["ic_weight"].replace(0, ERM_X))
# Mean/std over seeds per λ_INV. The groupby index (sorted weights) is used as the x-axis,
# so x and y stay aligned regardless of the row order in the CSV.
summary = plot_df.groupby("ic_weight")[["TERB", "bacc"]].agg(["mean", "std"])
ic_axis = summary.index.to_numpy()


def plot_band(ax, x, mean, std, color, fill_color):
    """Draw mean ± std on `ax` as error bars, a dashed mean line and a shaded band."""
    ax.errorbar(x, mean, yerr=std, fmt='o', color=color)
    ax.plot(x, mean, '--', color=color)
    ax.fill_between(x, mean - std, mean + std, alpha=0.2, color=fill_color)


def best_run(df, criterion):
    """Return the row with the lowest `criterion` among the seeds of the λ_INV whose
    seed-averaged `criterion` is lowest."""
    best_weight = df.groupby("ic_weight")[criterion].mean().idxmin()
    candidates = df[df["ic_weight"] == best_weight]
    return df.loc[candidates[criterion].idxmin()]


fig, ax = plt.subplots()
ax.set_xlabel(r"$\lambda_{INV}$")
ax.set_ylabel("TERB (%)", color='tab:blue')
ax.set_xscale('log')
plot_band(ax, ic_axis, summary[("TERB", "mean")], summary[("TERB", "std")], 'tab:blue', 'skyblue')

# (criterion column, marker colour, legend label)
selection_criteria = [
    ("loss_val", "orange", "Min ERM"),
    ("inv_loss_val", "purple", "Min Invariance"),
    ("TEAB_val", "green", "Min TERB"),
]
for criterion, color, label in selection_criteria:
    run = best_run(plot_df, criterion)
    # Plain '*' marker: a colour in the fmt string (e.g. 'y*') conflicts with `color=`.
    ax.plot(run["ic_weight"], run["TERB"], '*', markersize=12, color=color, label=label,
            zorder=10, clip_on=False, markeredgecolor='tab:blue', markeredgewidth=1)

ax.set_ylim(0, 140)
ax.legend(loc='upper left', framealpha=1, title="Model Selection Criteria",
          title_fontsize="8.5", fontsize="8", alignment="left")
ax.set_xticks([ERM_X, 0.1, 1, 10, 100, 1000, 10000],
              labels=["0\n(ERM)", "0.1", "1", "10", "100", "1000", "10000"])

# Balanced accuracy on a second y-axis sharing the same λ_INV x-axis.
ax_acc = ax.twinx()
ax_acc.set_ylabel("Balanced Accuracy", color='tab:red')
plot_band(ax_acc, ic_axis, summary[("bacc", "mean")], summary[("bacc", "std")], 'tab:red', 'pink')
ax_acc.set_ylim(0.45, 1)
ax_acc.axvline(x=SEPARATOR_X, color='black', linestyle='--', label="separation", alpha=0.2)

# save the plot
fig.savefig(f"results/istant_hq/{encoder}/invariance_{exp}.pdf", bbox_inches='tight')

In [None]:
# Per-epoch accuracy / precision / recall curves for every λ_INV, for training and validation.
# NOTE(review): `train_metrics` and `val_metrics` are not created anywhere in this notebook;
# they presumably come from an earlier per-epoch logging run and must stack (after squeeze)
# to an array indexed [ic_weight, epoch, metric] — TODO define/regenerate before running.
train_metrics = np.squeeze(np.array(train_metrics))  # ic x epochs x metrics
val_metrics = np.squeeze(np.array(val_metrics))      # ic x epochs x metrics

metrics = ['accuracy', 'precision', 'recall']


def plot_metric_curves(curves, title):
    """Draw one subplot per entry of `metrics`, with one line per weight in `ic_weights`.

    `curves` is indexed as curves[ic_weight_index, epoch, metric_index].
    """
    # One colour per λ_INV sampled from a colormap; a fixed 5-colour list raised
    # IndexError as soon as the sweep had more than 5 weights (it has 17).
    colors = plt.cm.viridis(np.linspace(0, 1, len(ic_weights)))
    fig, axes = plt.subplots(1, len(metrics), figsize=(15, 5))
    for m, (ax, metric) in enumerate(zip(np.atleast_1d(axes), metrics)):
        for j, ic_weight in enumerate(ic_weights):
            ax.plot(curves[j, :, m], color=colors[j], label=f'ic_weight={ic_weight:g}')
        ax.set_xlabel('epochs')
        ax.set_ylabel(metric)
        ax.legend()
    fig.tight_layout()
    fig.suptitle(title, fontsize=16, y=1.05)
    plt.show()


plot_metric_curves(train_metrics, 'Training')
plot_metric_curves(val_metrics, 'Validation')

In [None]:
# Seed-averaged metrics, one row per λ_INV.
# `pd.DataFrame(all_metrics, index=ic_weights)` looked the λ values up as *row labels*
# of the run-level frame (a RangeIndex), returning NaN or unrelated rows; grouping by
# the `ic_weight` column gives the intended per-weight table. `seed` is dropped because
# its mean is meaningless.
df = (
    all_metrics
    .drop(columns="seed", errors="ignore")
    .groupby("ic_weight")
    .mean(numeric_only=True)
)
df

In [None]:

# Build pairs of supervised embeddings labelled by whether both samples share an environment.
# NOTE(review): `env_id` is not defined anywhere in this notebook; it must be a per-sample
# environment id aligned with dataset.supervised["X"] — TODO define before running.
import torch
from torch.utils.data import DataLoader, TensorDataset

X = dataset.supervised["X"]
n = X.shape[0]
# Loop-invariant: average environment length and the ~1.5-environment jump used to pick
# a partner that is probably in a different environment.
env_avg_len = n // len(np.unique(env_id))
jump = env_avg_len * 1.5

pairs = []
labels = []
for i in range(n):
    # Consecutive samples in the same environment form a positive pair; the bounds check
    # avoids reading env_id[n] on the last sample.
    if i + 1 < n and env_id[i] == env_id[i + 1]:
        pairs.append((X[i], X[i + 1]))
        labels.append(1)
    k = round(i + jump) % n
    pairs.append((X[i], X[k]))
    # The distant partner is not guaranteed to be in another environment, so label it by
    # the actual environment match instead of always 1 (which made labels uninformative).
    labels.append(int(env_id[i] == env_id[k]))

# Assumes the rows of X are torch tensors of equal shape — TODO confirm.
# The original referenced an undefined `batch_size`; 128 matches the training batch size above.
PAIR_BATCH_SIZE = 128
pair_loader = DataLoader(
    TensorDataset(
        torch.stack([a for a, _ in pairs]),
        torch.stack([b for _, b in pairs]),
        torch.tensor(labels),
    ),
    batch_size=PAIR_BATCH_SIZE,
    shuffle=True,
)
pair_loader

In [None]:
# Number of samples with treatment T == 2 among those selected by `split`
# (assumes `split` is a boolean mask — TODO confirm). `len()` of the comparison returned
# the size of the whole split rather than the number of T == 2 entries.
int((dataset.supervised["T"][dataset.supervised["split"]] == 2).sum())

In [None]:
# Locate one frame by (experiment, position, frame number), print its W and Y entries
# and display the image.
exp = 0
pos = 6
frame = 720

# Row index in source_data (and in the aligned supervised arrays) of the requested frame.
frame_id = ((dataset.supervised["source_data"]["experiment"] == exp) & (dataset.supervised["source_data"]["position"] == pos) & (dataset.supervised["source_data"]["frame"] == frame)).nonzero(as_tuple=True)[0][0].item()
img = dataset.supervised["source_data"][frame_id]["image"] # shape 3, 770, 770
# Index W and Y by the row `frame_id`; the raw frame number `frame` points at an unrelated row.
print(dataset.supervised["W"][frame_id])
print(dataset.supervised["Y"][frame_id])
# remove ticks
plt.axis('off')
plt.imshow(img.permute(1, 2, 0));

In [None]:
# Show frames 450 and 720 of experiment 0, position 6 side by side, printing the Y entry
# of each beforehand.
exp = 0
pos = 6
source_data = dataset.supervised["source_data"]

frames = []
for frame in (450, 720):
    frame_id = ((source_data["experiment"] == exp) & (source_data["position"] == pos) & (source_data["frame"] == frame)).nonzero(as_tuple=True)[0][0].item()
    frames.append(source_data[frame_id]["image"])  # (3, 770, 770) per the original note
    print(dataset.supervised["Y"][frame_id])
img1, img2 = frames

# Channels-first images are permuted to (H, W, C) for imshow; axis ticks are hidden.
fig, ax = plt.subplots(1, 2)
for a, image in zip(ax, (img1, img2)):
    a.imshow(image.permute(1, 2, 0))
    a.set_xticks([])
    a.set_yticks([])
plt.show()

In [None]:
# Evaluate the current model (color="blue", non-verbose); the returned value is the cell output.
dataset.evaluate(verbose=False, color="blue")

In [None]:
from causal import compute_ate

# Treatment-effect estimate from the model predictions Y_hat, treatment T and
# covariates/grouping W, using the "ead" method.
supervised = dataset.supervised
compute_ate(
    supervised["Y_hat"],
    supervised["T"],
    supervised["W"],
    method="ead",
    color="blue",
)

In [None]:
# Reference sketch (not executed) of the PPCI workflow with default arguments.
# Kept commented out: running it would rebind `dataset` and discard the configured one above.
# dataset = PPCI()
# dataset.plot_out_distribution()
# dataset.train()
# dataset.visualize()
# dataset.evaluate()

## Post-Processing

In [None]:
# For experiment 4 / position 1, overlay per-sample predicted probabilities, rounded
# predictions and ground-truth labels (first outcome column). The binary series are
# nudged outward from 0/1 so the three do not overlap.
exp = (dataset.supervised["source_data"]["experiment"]==4)
pos = (dataset.supervised["source_data"]["position"]==1)
# 1-D row indices. nonzero(as_tuple=True) stays 1-D even for a single match, whereas
# nonzero().squeeze() collapsed to a 0-d tensor and broke len(). Also avoids shadowing
# the builtin `filter`.
rows = (exp & pos).nonzero(as_tuple=True)[0]
y = dataset.supervised["Y"][rows][:, 0].detach()
y_hat = dataset.supervised["Y_hat"][rows][:, 0].detach()
y_pred = y_hat.round()

t = range(len(rows))
fig, ax = plt.subplots()
ax.scatter(t, y_hat, s=1, c="blue", alpha=0.5, label="y_probs")
# v - (-1)**v * d maps 0 -> -d and 1 -> 1 + d, pushing each binary series outward.
ax.scatter(t, y_pred - (-1) ** y_pred * 0.04, s=1, c="red", alpha=0.5, label="y_pred")
ax.scatter(t, y - (-1) ** y * 0.02, s=1, c="green", alpha=0.5, label="y")
ax.legend()
plt.show()

In [None]:
# Display frame 2220 of the experiment/position selected by the `exp` and `pos` masks of
# the previous cell, titled with its two outcome values (Y2F, B2F).
frame = dataset.supervised["source_data"]["frame"] == 2220
idx = (exp & pos & frame).nonzero().item()
sample = dataset.supervised["source_data"][idx]
img = sample["image"].permute(1, 2, 0)  # channels-first -> (H, W, C) for imshow
outcome = sample["outcome"]

y2f, b2f = int(outcome[0]), int(outcome[1])
plt.title(f"Y2F: {y2f}, B2F: {b2f}")
plt.imshow(img);