In [None]:
import os
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
%matplotlib inline

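The following cell assumes that this notebook is run, either locally or in Colab, from within a folder named
`Class_Distribution_Shifts_in_Zero_Shot_Learning_Learning_Robust_Representations`, so that the local modules `pairs`, `algorithm` and `synthetic_data` can be imported. When running in Colab, replace `YOUR_PATH` in the next cell with the path to this folder on your Google Drive.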




In [None]:
path = os.getcwd()

# Check if running in Colab
try:
  from google.colab import drive
  IN_COLAB=True
  print("Running in Colab")
  # Mount Google Drive
  drive.mount('/content/drive', force_remount=True)
  # Change directory
  %cd "YOUR_PATH"
except:
  IN_COLAB=False
  print("Running locally")

In [None]:
from pairs import distinct_pairs_func, make_pairs, contrastive_loss
from algorithm import *
from synthetic_data import *

### Generate Data

In [None]:
# Number of input features in each feature group of the synthetic data.
v0_dim = 5
vminus_dim = 10
vplus_dim = 10
noise_dim = 25

# Per-group parameters passed to generate_synthetic_data
# (exact semantics live in synthetic_data — TODO confirm).
v0 = 1.0
vminus = 0.1
vplus = 2.0

p_minor = 0.1  # also used below to derive n_sim_envs
Nc = 500
r = 30

vz = 1.0
vz_noise = 10.0

# Seed Python's `random`, NumPy's global RNG and the TensorFlow/Keras backend in
# one call, so environment sampling (np.random.choice) and weight initialisation
# are reproducible under Restart & Run All. Helpers from the local modules that
# draw from these global RNGs are presumably covered too — verify in their source.
SEED = 0
tf.keras.utils.set_random_seed(SEED)

In [None]:
# Width of the informative part (three signal groups) and of the full input.
signal_dim = sum((v0_dim, vminus_dim, vplus_dim))
total_dim = noise_dim + signal_dim

In [None]:
# Draw the train / validation / test splits: features z and class labels c.
(z_train, c_train,
 z_val, c_val,
 z_test, c_test) = generate_synthetic_data(
    Nc, r,
    v0, vminus, vplus,
    vz, vz_noise,
    v0_dim, vminus_dim, vplus_dim, noise_dim,
    p_minor,
)

In [None]:
pos_per_class_train = 5
pos_per_class_test = 5

# Convert each split into pairs with make_pairs. It yields the two pair members,
# the pair labels and a per-pair class record (Cs; see pairs.py for its exact form).
# The validation split uses the training setting for positives per class.
(train_z1, train_z2,
 train_y, train_Cs) = make_pairs(z_train, c_train, pos_per_class_train)
(val_z1, val_z2,
 val_y, val_Cs) = make_pairs(z_val, c_val, pos_per_class_train)
(test_z1, test_z2,
 test_y, test_Cs) = make_pairs(z_test, c_test, pos_per_class_test)

### Class sampling

In [None]:
# Number of distinct classes drawn into each environment.
classes_in_env = classes_in_env_test = 2

# Largest integer n with (1 - p_minor**2)**n >= 0.5, i.e.
# floor(log(0.5) / log(1 - p_minor**2)).
n_sim_envs = int(
    np.log(0.5) / np.log(1 - p_minor ** 2)
)

In [None]:
n_envs = 100_000  # number of training environments to sample

In [None]:
# Each training environment is a set of `classes_in_env` distinct training
# classes drawn uniformly without replacement; one row per environment.
unq_c_train = np.unique(c_train)
train_envs = np.array([
    np.random.choice(unq_c_train, classes_in_env, replace=False)
    for _ in range(n_envs)
])

### Models

In [None]:
def init_representation(add_dropout=False, dropout_rate=0.1, embedding_dim=16, input_dim=None):
    """Build the representation network used by every training method.

    The network is a single Dense layer with no activation, so the learned
    representation is linear in its input. `get_weights()[0]` is the Dense
    kernel of shape (input_dim, embedding_dim); the notebook uses it to compute
    per-input-dimension importance.

    Args:
        add_dropout: if True, insert a Dropout layer on the inputs before the
            Dense layer. Defaults to False (a plain linear map).
        dropout_rate: fraction of inputs dropped during training; only used
            when `add_dropout` is True.
        embedding_dim: size of the output representation.
        input_dim: number of input features; defaults to the global `total_dim`.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    if input_dim is None:
        input_dim = total_dim
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Input(shape=(input_dim,)))
    if add_dropout:
        # Dropout has no weights, so get_weights()[0] is still the Dense kernel.
        model.add(tf.keras.layers.Dropout(dropout_rate))
    model.add(tf.keras.layers.Dense(embedding_dim))
    return model

Parameters

In [None]:
# Optimisation settings shared by all methods.
lr = 1e-2
n_pairs = 700_000

# Penalty weight per method (ERM is trained with no penalty).
ERM_factor = 0.0
CLoVE_factor = 0.085
VarAUC_factor = 1.3
IRM_factor = 0.01
VarREx_factor = 3.0

Initialization

In [None]:
# Reference network: every method's network is later reset to these weights,
# so all methods start training from the same initialisation.
init_g = init_representation()

# One network per method, constructed in the order ERM, IRM, CLoVE, VarREx, VarAUC.
ERM_g, IRM_g, CLoVE_g, VarREx_g, VarAUC_g = (init_representation() for _ in range(5))

ERM

In [None]:
# Start ERM from the shared reference initialisation.
ERM_g.set_weights(weights=init_g.get_weights())

In [None]:
# Dedicated Adam optimizer for the ERM network.
ERM_optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr,
)

In [None]:
# Train ERM (no penalty term). The validation / test pairs are passed so that
# training can return AUC histories alongside the sample counts ERM_Ns.
ERM_g, ERM_losses, ERM_Ns, ERM_test_aucs, ERM_val_aucs = training(
    ERM_optimizer, ERM_g,
    z_train, c_train, train_envs,
    val_z1, val_z2, val_y,
    test_z1, test_z2, test_y,
    n_pairs, pos_per_class_train,
    ERM_factor, n_sim_envs,
    penalty_type=None,
)

In [None]:
# Normalised per-input-dimension importance: L1 norm of each row of the Dense
# kernel (rows correspond to input features), divided by the kernel's total L1 norm.
ERM_w = ERM_g.get_weights()
ERM_imp = np.abs(ERM_w[0]).sum(axis=1) / np.abs(ERM_w[0]).sum()

# Scores (AUC, per the variable names) on the training pairs, the
# in-distribution (validation) pairs and the distribution-shift (test) pairs.
ERM_auc_train = evaluate(ERM_g, train_z1, train_z2, train_y)
ERM_auc_val = evaluate(ERM_g, val_z1, val_z2, val_y)
ERM_auc_test = evaluate(ERM_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(ERM_auc_train, ERM_auc_val, ERM_auc_test))

IRM

In [None]:
# Start IRM from the shared reference initialisation.
IRM_g.set_weights(weights=init_g.get_weights())

In [None]:
# Dedicated Adam optimizer for the IRM network.
IRM_optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr,
)

In [None]:
# Train with the IRM penalty, weighted by IRM_factor.
IRM_g, IRM_losses, IRM_Ns, IRM_test_aucs, IRM_val_aucs = training(
    IRM_optimizer, IRM_g,
    z_train, c_train, train_envs,
    val_z1, val_z2, val_y,
    test_z1, test_z2, test_y,
    n_pairs, pos_per_class_train,
    IRM_factor, n_sim_envs,
    penalty_type='IRM',
)

In [None]:
# Normalised per-input-dimension importance: L1 norm of each row of the Dense
# kernel (rows correspond to input features), divided by the kernel's total L1 norm.
IRM_w = IRM_g.get_weights()
IRM_imp = np.abs(IRM_w[0]).sum(axis=1) / np.abs(IRM_w[0]).sum()

# Scores (AUC, per the variable names) on the training pairs, the
# in-distribution (validation) pairs and the distribution-shift (test) pairs.
IRM_auc_train = evaluate(IRM_g, train_z1, train_z2, train_y)
IRM_auc_val = evaluate(IRM_g, val_z1, val_z2, val_y)
IRM_auc_test = evaluate(IRM_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(IRM_auc_train, IRM_auc_val, IRM_auc_test))

CLoVE

In [None]:
# Start CLoVE from the shared reference initialisation.
CLoVE_g.set_weights(weights=init_g.get_weights())

In [None]:
# Dedicated Adam optimizer for the CLoVE network.
CLoVE_optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr,
)

In [None]:
# Train with the CLoVE penalty, weighted by CLoVE_factor.
CLoVE_g, CLoVE_losses, CLoVE_Ns, CLoVE_test_aucs, CLoVE_val_aucs = training(
    CLoVE_optimizer, CLoVE_g,
    z_train, c_train, train_envs,
    val_z1, val_z2, val_y,
    test_z1, test_z2, test_y,
    n_pairs, pos_per_class_train,
    CLoVE_factor, n_sim_envs,
    penalty_type='CLoVE',
)

In [None]:
# Normalised per-input-dimension importance: L1 norm of each row of the Dense
# kernel (rows correspond to input features), divided by the kernel's total L1 norm.
CLoVE_w = CLoVE_g.get_weights()
CLoVE_imp = np.abs(CLoVE_w[0]).sum(axis=1) / np.abs(CLoVE_w[0]).sum()

# Scores (AUC, per the variable names) on the training pairs, the
# in-distribution (validation) pairs and the distribution-shift (test) pairs.
CLoVE_auc_train = evaluate(CLoVE_g, train_z1, train_z2, train_y)
CLoVE_auc_val = evaluate(CLoVE_g, val_z1, val_z2, val_y)
CLoVE_auc_test = evaluate(CLoVE_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(CLoVE_auc_train, CLoVE_auc_val, CLoVE_auc_test))

VarREx

In [None]:
# Start VarREx from the shared reference initialisation.
VarREx_g.set_weights(weights=init_g.get_weights())

In [None]:
# Dedicated Adam optimizer for the VarREx network.
VarREx_optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr,
)

In [None]:
# Train with the VarREx penalty, weighted by VarREx_factor.
VarREx_g, VarREx_losses, VarREx_Ns, VarREx_test_aucs, VarREx_val_aucs = training(
    VarREx_optimizer, VarREx_g,
    z_train, c_train, train_envs,
    val_z1, val_z2, val_y,
    test_z1, test_z2, test_y,
    n_pairs, pos_per_class_train,
    VarREx_factor, n_sim_envs,
    penalty_type='VarREx',
)

In [None]:
# Normalised per-input-dimension importance: L1 norm of each row of the Dense
# kernel (rows correspond to input features), divided by the kernel's total L1 norm.
VarREx_w = VarREx_g.get_weights()
VarREx_imp = np.abs(VarREx_w[0]).sum(axis=1) / np.abs(VarREx_w[0]).sum()

# Scores (AUC, per the variable names) on the training pairs, the
# in-distribution (validation) pairs and the distribution-shift (test) pairs.
VarREx_auc_train = evaluate(VarREx_g, train_z1, train_z2, train_y)
VarREx_auc_val = evaluate(VarREx_g, val_z1, val_z2, val_y)
VarREx_auc_test = evaluate(VarREx_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(VarREx_auc_train, VarREx_auc_val, VarREx_auc_test))

VarAUC

In [None]:
# Start VarAUC from the shared reference initialisation.
VarAUC_g.set_weights(weights=init_g.get_weights())

In [None]:
# Dedicated Adam optimizer for the VarAUC network.
VarAUC_optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr,
)

In [None]:
# Train with the VarAUC penalty, weighted by VarAUC_factor.
VarAUC_g, VarAUC_losses, VarAUC_Ns, VarAUC_test_aucs, VarAUC_val_aucs = training(
    VarAUC_optimizer, VarAUC_g,
    z_train, c_train, train_envs,
    val_z1, val_z2, val_y,
    test_z1, test_z2, test_y,
    n_pairs, pos_per_class_train,
    VarAUC_factor, n_sim_envs,
    penalty_type='VarAUC',
)

In [None]:
# Normalised per-input-dimension importance: L1 norm of each row of the Dense
# kernel (rows correspond to input features), divided by the kernel's total L1 norm.
VarAUC_w = VarAUC_g.get_weights()
VarAUC_imp = np.abs(VarAUC_w[0]).sum(axis=1) / np.abs(VarAUC_w[0]).sum()

# Scores (AUC, per the variable names) on the training pairs, the
# in-distribution (validation) pairs and the distribution-shift (test) pairs.
VarAUC_auc_train = evaluate(VarAUC_g, train_z1, train_z2, train_y)
VarAUC_auc_val = evaluate(VarAUC_g, val_z1, val_z2, val_y)
VarAUC_auc_test = evaluate(VarAUC_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(VarAUC_auc_train, VarAUC_auc_val, VarAUC_auc_test))

### Comparing results

In [None]:
# Per-input-dimension importance for every method. Vertical lines mark the
# cumulative boundaries v0_dim, +vplus_dim, +vminus_dim; the x-ticks are placed
# at the same boundaries so both stay in sync if the group sizes change.
# NOTE(review): boundary order (v0, v+, v-) follows the original plot — confirm
# it matches the column layout produced by generate_synthetic_data.
group_boundaries = np.cumsum([v0_dim, vplus_dim, vminus_dim])

plt.plot(ERM_imp, '^', markersize=4, color='C3', label='ERM', alpha=0.65)
plt.plot(IRM_imp, '*', color='C1', label='IRM', alpha=0.65)
plt.plot(CLoVE_imp, 'X', color='C0', label='CLOvE', alpha=0.65, markersize=5)
plt.plot(VarREx_imp, 'p', color='C4', label='VarREx', alpha=0.65, markersize=5)
plt.plot(VarAUC_imp, '.', markersize=8, color='C2', label='VarAUC', alpha=0.65)

for boundary in group_boundaries:
    plt.axvline(boundary, color='k', linewidth=0.7)

plt.xticks(group_boundaries)
plt.xlabel('Dimension')
plt.ylabel('Importance')
plt.title('Per-dimension importance of the learned representation')
plt.legend();

In [None]:
# AUC over training for every method: dashed = in-distribution (validation
# pairs), solid = distribution shift (test pairs). Only solid lines get legend labels.
auc_curves = [
    # (label, Ns, validation AUCs, test AUCs, colour, alpha)
    ('ERM', ERM_Ns, ERM_val_aucs, ERM_test_aucs, 'C3', 1.0),
    ('IRM', IRM_Ns, IRM_val_aucs, IRM_test_aucs, 'C1', 0.8),
    ('CLOvE', CLoVE_Ns, CLoVE_val_aucs, CLoVE_test_aucs, 'C0', 0.8),
    ('VarREx', VarREx_Ns, VarREx_val_aucs, VarREx_test_aucs, 'C4', 0.8),
    ('VarAUC', VarAUC_Ns, VarAUC_val_aucs, VarAUC_test_aucs, 'C2', 1.0),
]

for label, ns, val_aucs, test_aucs, color, alpha in auc_curves:
    plt.plot(ns, val_aucs, '--', color=color, alpha=alpha)
    plt.plot(ns, test_aucs, label=label, color=color, alpha=alpha)

plt.xlabel('Data Points (pairs)')
plt.ylabel('AUC')
plt.title('AUC during training (solid: distribution shift, dashed: in-distribution)')
plt.legend(loc=4);