In [None]:
import os
import pandas as pd
import gdown
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Check if running in Colab
try:
  from google.colab import drive
  IN_COLAB=True
  print("Running in Colab")
  # Mount Google Drive
  drive.mount('/content/drive', force_remount=True)
  # Change directory
  %cd "YOUR_PATH"
   #!!!
except:
  IN_COLAB=False
  print("Running locally")

In [None]:
from pairs import distinct_pairs_func, make_pairs, contrastive_loss
from algorithm import *
from synthetic_data import *

### 1. Load data

In [None]:
# Get the index of the dataset files
! wget https://raw.githubusercontent.com/tensorflow/datasets/master/tensorflow_datasets/datasets/celeb_a/checksums.tsv
urls = pd.read_csv('checksums.tsv', sep='\t', names=['url', 'size', 'checksum', 'filename'])

In [None]:
# Download every file listed in checksums.tsv that is not already present in
# the working directory. If a download fails with "access denied", open the
# link in a browser, download the file yourself and upload it to Colab
# (directly or via Google Drive).
for _, row in urls.iterrows():
    # os.path.exists checks one path instead of listing the directory per row
    if not os.path.exists(row.filename):
        gdown.download(row.url, row.filename, quiet=False)

In [None]:
# Download one additional file by Google Drive id (no output name is given, so
# gdown picks it). NOTE(review): presumably this is identity_CelebA.txt, which
# is moved into the TFDS manual folder below -- TODO confirm.
gdown.download('https://drive.google.com/uc?export=download&id=1roEIMXWh8rxneYlSGGSkit2-adkb0oxC')

In [None]:
# Move the downloaded files into ~/tensorflow_datasets/downloads/manual, assumed
# to be the manual-download directory the celeb_a builder below reads from.
! mkdir -p ~/tensorflow_datasets/downloads/manual
! mv list_eval_partition.txt ~/tensorflow_datasets/downloads/manual
! mv img_align_celeba.zip ~/tensorflow_datasets/downloads/manual
! mv list_attr_celeba.txt ~/tensorflow_datasets/downloads/manual
! mv identity_CelebA.txt ~/tensorflow_datasets/downloads/manual
! mv list_landmarks_align_celeba.txt ~/tensorflow_datasets/downloads/manual

In [None]:
# Prepare CelebA v2.1.0 locally (try_gcs=False: do not load a prepared copy
# from GCS); presumably from the files placed in the manual folder above.
celeb_a_builder = tfds.builder('celeb_a', version='2.1.0', try_gcs=False)
celeb_a_builder.download_and_prepare()
# NOTE(review): celeb_a_data is not used in the visible notebook; the splits are
# re-created with as_dataset(split=...) further down.
celeb_a_data = celeb_a_builder.as_dataset()

In [None]:
# Feature-dict keys of the celeb_a examples used in the preprocessing below.
ATTR_KEY = "attributes"
IMAGE_KEY = "image"
LABEL_KEY = "identity"     # identity record; get_data reads its 'Identity_No' field
GROUP_KEY = "Blond_Hair"   # binary attribute that defines the two groups
IMAGE_SIZE = 45            # images are resized to IMAGE_SIZE x IMAGE_SIZE

In [None]:
def preprocess_input_dict(feat_dict):
  """Resize/normalise the image and cast the group attribute to float32.

  Args:
    feat_dict: a (batched) celeb_a feature dict with IMAGE_KEY, LABEL_KEY and
      ATTR_KEY[GROUP_KEY] entries.

  Returns:
    The same dict, updated in place: the image is resized to
    IMAGE_SIZE x IMAGE_SIZE and divided by 255 (i.e. scaled to [0, 1] for
    8-bit pixels), and the group attribute is cast to float32. The identity
    entry is passed through unchanged: it is a nested record (read as
    ['Identity_No'] downstream), not a scalar that could be cast.
  """
  # Resize and normalize image.
  image = tf.cast(feat_dict[IMAGE_KEY], tf.float32)
  image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
  image /= 255.0

  # Only the group attribute is cast; the identity label is left as is.
  group = tf.cast(feat_dict[ATTR_KEY][GROUP_KEY], tf.float32)

  feat_dict[IMAGE_KEY] = image
  feat_dict[ATTR_KEY][GROUP_KEY] = group
  return feat_dict

In [None]:
def get_image_label_and_group(feat_dict):
  """Return the (image, identity label, group attribute) triple of a feature dict."""
  return feat_dict[IMAGE_KEY], feat_dict[LABEL_KEY], feat_dict[ATTR_KEY][GROUP_KEY]

In [None]:
# Training split as NumPy (image, identity, group) tuples. batch(1) runs before
# the maps, so every element keeps a leading batch axis of size 1 (get_data
# strips it). NOTE: presumably a single-pass iterator -- re-run this cell before
# re-running the get_data call.
train_data = celeb_a_builder.as_dataset(split='train').batch(1).map(preprocess_input_dict).map(get_image_label_and_group)
train_iterator = train_data.as_numpy_iterator()

In [None]:
# Test split, built exactly like the training split (leading batch axis of 1).
# NOTE: presumably a single-pass iterator -- re-run this cell before re-running
# the get_data call.
test_data = celeb_a_builder.as_dataset(split='test').batch(1).map(preprocess_input_dict).map(get_image_label_and_group)
test_iterator = test_data.as_numpy_iterator()

In [None]:
def get_data(iterator, IMAGE_SIZE):
  """Materialise an iterator of batch-size-1 samples into three NumPy arrays.

  Each element is indexed as (image, identity record, group); the leading
  batch axis of size 1 is dropped.

  Returns:
    (images, identities, groups): images reshaped to (IMAGE_SIZE, IMAGE_SIZE, 3),
    identities taken from the record's 'Identity_No' field, and group labels.
  """
  images, identities, groups = [], [], []
  target_shape = (IMAGE_SIZE, IMAGE_SIZE, 3)
  for sample in iterator:
    img, ident, grp = sample[0], sample[1], sample[2]
    images.append(img.reshape(target_shape))
    identities.append(ident['Identity_No'][0])
    groups.append(grp[0])
  return np.array(images), np.array(identities), np.array(groups)

In [None]:
# Z: images, C: identity ids, G: Blond_Hair group labels (training split)
Z_train, C_train, G_train = get_data(train_iterator, IMAGE_SIZE)

In [None]:
# Z: images, C: identity ids, G: Blond_Hair group labels (test split)
Z_test, C_test, G_test = get_data(test_iterator, IMAGE_SIZE)

In [None]:
# remove classes (identities) with fewer than min_images images -- applied to
# both the train and the test split below

min_images = 3

def filter_rare(Z, C, G, min_images):
  """Drop every sample whose class has fewer than `min_images` samples.

  Args:
    Z, C, G: arrays aligned on axis 0 (images, class ids, group labels).
    min_images: minimum number of samples a class needs to be kept.

  Returns:
    (Z, C, G) restricted to the samples of sufficiently frequent classes,
    in their original order.
  """
  # One counting pass with return_counts instead of an np.sum(C == c) scan per
  # class, which was O(#classes * #samples).
  classes, counts = np.unique(C, return_counts=True)
  keep_idx = np.isin(C, classes[counts >= min_images])
  return Z[keep_idx], C[keep_idx], G[keep_idx]

# Filtering is idempotent (kept classes keep their counts), so re-running is safe.
Z_train, C_train, G_train = filter_rare(Z_train, C_train, G_train, min_images)
Z_test, C_test, G_test = filter_rare(Z_test, C_test, G_test, min_images)

In [None]:
def divide_by_grp(C, G, min_images, minor_labl):
  unq_C = np.unique(C)
  w_grp, wo_grp = [], []

  for c in unq_C:
    c_idx = np.where(C==c)[0]
    if np.sum(G[c_idx]==minor_labl) >= min_images:
      w_grp.append(c)
    elif np.sum(G[c_idx]==1-minor_labl) >= min_images:
      wo_grp.append(c)
  return np.array(w_grp), np.array(wo_grp)

In [None]:
minor_labl = 1  # group label of the minority group (1 == Blond_Hair set)
# Classes with >= threshold blond images go to *_w_grp; otherwise classes with
# >= threshold non-blond images go to *_wo_grp.
people_w_grp_images_train, people_wo_grp_images_train = divide_by_grp(C_train, G_train, min_images, minor_labl)
# NOTE(review): the test split uses a hard-coded threshold of 2 instead of min_images.
people_w_grp_images_test, people_wo_grp_images_test = divide_by_grp(C_test, G_test, 2, minor_labl)

In [None]:
def select_instances(Z, C, G, unq_c_w_grp, unq_c_wo_grp, minor_labl):
  """Keep only group-consistent images for two sets of classes.

  Every class in `unq_c_w_grp` keeps only its images with group label
  `minor_labl`. Every class in `unq_c_wo_grp` keeps only its images with label
  `1 - minor_labl`. Output order: classes in the order given (first list, then
  second), and images of a class in their original order.

  Returns:
    (Z, C, G) as NumPy arrays; the returned G holds the matched label of each
    kept image.
  """
  kept_Z, kept_C, kept_G = [], [], []

  def _keep(classes, wanted_labl):
    # append every image of each class whose group label equals wanted_labl
    for cls in classes:
      for idx in np.where(C == cls)[0]:
        if G[idx] == wanted_labl:
          kept_Z.append(Z[idx])
          kept_C.append(cls)
          kept_G.append(wanted_labl)

  _keep(unq_c_w_grp, minor_labl)
  _keep(unq_c_wo_grp, 1 - minor_labl)

  return np.array(kept_Z), np.array(kept_C), np.array(kept_G)

In [None]:
def select_classes(Z, C, G, classes_in_major_grp, classes_in_minor_grp, p_minor_grp, minor_labl, max_Nc=500, rng=None):
  """Subsample classes so that roughly a fraction p_minor_grp come from the minor group.

  The numbers of major/minor classes are first reduced until
  N_minor / (N_major + N_minor) is approximately p_minor_grp. They are then
  capped so the total is at most about max_Nc. Classes are drawn without
  replacement. select_instances then keeps, for minor-group classes, only their
  images with group label `minor_labl`, and for major-group classes only those
  with label `1 - minor_labl`.

  Args:
    Z, C, G: images, class ids and binary group labels aligned on axis 0.
    classes_in_major_grp, classes_in_minor_grp: candidate class ids per group.
    p_minor_grp: target fraction of minor-group classes, strictly in (0, 1).
    minor_labl: group label (0 or 1) of the minor group.
    max_Nc: approximate upper bound on the number of selected classes.
    rng: optional np.random.Generator / RandomState for reproducible sampling;
      defaults to the global np.random state (previous behaviour).

  Returns:
    (Z, C, G) arrays restricted to the selected classes and images.

  Raises:
    ValueError: if p_minor_grp is not strictly between 0 and 1 (the ratio
      formulas below divide by p_minor_grp and by 1 - p_minor_grp).
  """
  if not 0 < p_minor_grp < 1:
    raise ValueError(f"p_minor_grp must be strictly between 0 and 1, got {p_minor_grp}")
  if rng is None:
    rng = np.random

  N_major = len(classes_in_major_grp)
  N_minor = len(classes_in_minor_grp)

  # shrink whichever group is over-represented relative to p_minor_grp
  N_major = min(N_major, int((1-p_minor_grp)*N_minor/p_minor_grp))
  N_minor = min(N_minor, int(p_minor_grp*N_major/(1-p_minor_grp)))

  # cap the total number of classes at roughly max_Nc
  N_major = min(N_major, int((1-p_minor_grp)*max_Nc))
  N_minor = min(N_minor, int(p_minor_grp*max_Nc))

  # draw the classes of each group (minor first, as before, to keep the RNG order)
  unq_c_minor_grp = rng.choice(classes_in_minor_grp, N_minor, replace=False)
  unq_c_major_grp = rng.choice(classes_in_major_grp, N_major, replace=False)

  # minor-group classes keep only minor-label images, the rest only the other label
  return select_instances(Z, C, G, unq_c_minor_grp, unq_c_major_grp, minor_labl)

In [None]:
p_minor = 0.05  # target fraction of selected classes that come from the minority group

In [None]:
# in train mostly non-blonde people: blond identities (minor_labl=1, only their
# blond images) make up about p_minor of the selected classes
z_train, c_train, a_train = select_classes(Z_train, C_train, G_train, people_wo_grp_images_train, people_w_grp_images_train, p_minor, minor_labl)

In [None]:
# in test mostly blonde people: the roles are swapped, so non-blond identities
# (label 1-minor_labl = 0, only their non-blond images) are the p_minor minority
z_test, c_test, a_test = select_classes(Z_test, C_test, G_test, people_w_grp_images_test, people_wo_grp_images_test, p_minor, 1-minor_labl)

In [None]:
# split train into train and validation by holding out 10% of the classes
# (identities), so validation pairs come from people never seen in training.
# NOTE: this cell overwrites z_train/c_train/a_train -- re-run the select_classes
# cell before re-running it.

val_frac = 0.1
unq_c_train = np.unique(c_train)

unique_c_val = np.random.choice(unq_c_train, int(val_frac*len(unq_c_train)), replace=False)
# setdiff1d returns the sorted remainder (unq_c_train is already sorted), in
# linear-ish time instead of a membership scan per class
unq_c_train = np.setdiff1d(unq_c_train, unique_c_val)

val_bool = np.isin(c_train, unique_c_val)
z_val, c_val, a_val = z_train[val_bool], c_train[val_bool], a_train[val_bool]

# every class is either held out or kept, so the training mask is the complement
train_bool = ~val_bool
z_train, c_train, a_train = z_train[train_bool], c_train[train_bool], a_train[train_bool]

In [None]:
pos_per_class_train, pos_per_class_test = 10, 10

# generate pairs for each split (the keyword is passed the same way in every call).
# NOTE(review): outputs are presumably (first images, second images, pair labels,
# class info) -- see pairs.make_pairs.
train_z1, train_z2, train_y, train_Cs = make_pairs(z_train, c_train, pos_per_class=pos_per_class_train)
val_z1, val_z2, val_y, val_Cs = make_pairs(z_val, c_val, pos_per_class=pos_per_class_train)
test_z1, test_z2, test_y, test_Cs = make_pairs(z_test, c_test, pos_per_class=pos_per_class_test)

### Class sampling

In [None]:
classes_in_env = 2       # identities sampled per training environment
classes_in_env_test = 2  # NOTE(review): not used in the visible notebook

n_sim_envs = 150         # passed to training(); meaning defined in algorithm -- TODO confirm

In [None]:
# number of pre-sampled training environments (one class subset each)
n_envs =  10**6

In [None]:
# Pre-sample n_envs environments; each row holds classes_in_env distinct
# training identities drawn without replacement.
train_envs = np.array([
    np.random.choice(unq_c_train, classes_in_env, replace=False)
    for _ in range(n_envs)
])

### Models

In [None]:
s1, s2, s3 = IMAGE_SIZE, IMAGE_SIZE, 3  # model input shape: height, width, 3 colour channels

In [None]:
def add_conv_block(model, filters=16, kernel_size=3):
  """Append a Conv2D -> BatchNorm -> ReLU -> 2x2 max-pool block to `model`.

  Args:
    model: a tf.keras.Sequential that is extended in place.
    filters: number of convolution filters (default 16, as before).
    kernel_size: convolution kernel size (default 3, as before).

  Returns:
    None. The stride-1 "same" convolution keeps the spatial size; the pooling
    halves it (rounding down).
  """
  model.add(tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=1, padding="same"))
  model.add(tf.keras.layers.BatchNormalization())
  model.add(tf.keras.layers.ReLU())
  model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

In [None]:
def init_representation(n_blocks=2, embedding_dim=32):
  """Build a fresh embedding network: conv blocks -> flatten -> linear projection.

  Args:
    n_blocks: number of add_conv_block blocks (default 2, as before).
    embedding_dim: size of the output embedding (default 32, as before).

  Returns:
    An uncompiled tf.keras.Sequential mapping (s1, s2, s3) inputs to
    `embedding_dim`-dimensional vectors.
  """
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Input(shape=(s1, s2, s3)))
  for _ in range(n_blocks):
    add_conv_block(model)
  model.add(tf.keras.layers.Flatten())
  # no activation: the embedding is a linear projection of the conv features
  model.add(tf.keras.layers.Dense(embedding_dim))
  return model

#### Parameters

In [None]:
lr = 1e-3            # Adam learning rate, shared by every method
n_pairs = 8 * 10**6  # passed to training(); presumably the budget in training pairs

# Penalty weights, one per method (passed to training() as its factor argument)
ERM_factor = 0.0
CLoVE_factor = 0.085
VarAUC_factor = 0.2
VarREx_factor = 0.1

IRM_factor = 0.01
# NOTE(review): l2_regularizer_weight is not used in the visible notebook
l2_regularizer_weight = tf.constant(0.01)

#### Initialization

In [None]:
# Shared random initialisation: each method copies these weights before
# training, so every method starts from the same point.
init_g = init_representation()

# One network per method; their own random weights are replaced by init_g's.
ERM_g = init_representation()
IRM_g = init_representation()
CLoVE_g = init_representation()
VarREx_g = init_representation()
VarAUC_g = init_representation()

#### ERM

In [None]:
# Start from the shared initialisation
ERM_g.set_weights(init_g.get_weights())

In [None]:
ERM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [None]:
# Baseline without an invariance penalty (penalty_type=None).
# NOTE(review): outputs are presumably (model, loss history, #pairs seen at each
# evaluation, test AUCs, validation AUCs) -- confirm against algorithm.training.
ERM_g, ERM_losses, ERM_Ns, ERM_test_aucs, ERM_val_aucs = training(ERM_optimizer, ERM_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,
                                                                  test_z1, test_z2, test_y, n_pairs, pos_per_class_train, ERM_factor, n_sim_envs, penalty_type=None)

In [None]:
# AUC on training pairs, held-out-identity validation pairs (in-distribution)
# and shifted test pairs (distribution shift)
ERM_auc_train = evaluate(ERM_g, train_z1, train_z2, train_y)
ERM_auc_val = evaluate(ERM_g, val_z1, val_z2, val_y)
ERM_auc_test = evaluate(ERM_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(ERM_auc_train, ERM_auc_val, ERM_auc_test))

#### IRM

In [None]:
# Start from the shared initialisation
IRM_g.set_weights(init_g.get_weights())

In [None]:
IRM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [None]:
# Train with the IRM penalty weighted by IRM_factor.
# NOTE(review): outputs are presumably (model, loss history, #pairs seen at each
# evaluation, test AUCs, validation AUCs) -- confirm against algorithm.training.
IRM_g, IRM_losses, IRM_Ns, IRM_test_aucs, IRM_val_aucs = training(IRM_optimizer, IRM_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,
                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, IRM_factor, n_sim_envs,
                                                                                 penalty_type='IRM')

In [None]:
# AUC on training pairs, held-out-identity validation pairs (in-distribution)
# and shifted test pairs (distribution shift)
IRM_auc_train = evaluate(IRM_g, train_z1, train_z2, train_y)
IRM_auc_val = evaluate(IRM_g, val_z1, val_z2, val_y)
IRM_auc_test = evaluate(IRM_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(IRM_auc_train, IRM_auc_val, IRM_auc_test))

#### CLoVE

In [None]:
# Start from the shared initialisation
CLoVE_g.set_weights(init_g.get_weights())

In [None]:
CLoVE_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [None]:
# Train with the CLoVE penalty weighted by CLoVE_factor.
# NOTE(review): outputs are presumably (model, loss history, #pairs seen at each
# evaluation, test AUCs, validation AUCs) -- confirm against algorithm.training.
CLoVE_g, CLoVE_losses, CLoVE_Ns, CLoVE_test_aucs, CLoVE_val_aucs = training(CLoVE_optimizer, CLoVE_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,
                                                                            test_z1, test_z2, test_y, n_pairs, pos_per_class_train, CLoVE_factor, n_sim_envs,
                                                                            penalty_type='CLoVE')

In [None]:
# AUC on training pairs, held-out-identity validation pairs (in-distribution)
# and shifted test pairs (distribution shift)
CLoVE_auc_train = evaluate(CLoVE_g, train_z1, train_z2, train_y)
CLoVE_auc_val = evaluate(CLoVE_g, val_z1, val_z2, val_y)
CLoVE_auc_test = evaluate(CLoVE_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(CLoVE_auc_train, CLoVE_auc_val, CLoVE_auc_test))

#### VarREx

In [None]:
# Start from the shared initialisation
VarREx_g.set_weights(init_g.get_weights())

In [None]:
VarREx_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [None]:
# Train with the VarREx penalty weighted by VarREx_factor.
# NOTE(review): outputs are presumably (model, loss history, #pairs seen at each
# evaluation, test AUCs, validation AUCs) -- confirm against algorithm.training.
VarREx_g, VarREx_losses, VarREx_Ns, VarREx_test_aucs, VarREx_val_aucs = training(VarREx_optimizer, VarREx_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,
                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, VarREx_factor, n_sim_envs,
                                                                                 penalty_type='VarREx')

In [None]:
# AUC on training pairs, held-out-identity validation pairs (in-distribution)
# and shifted test pairs (distribution shift)
VarREx_auc_train = evaluate(VarREx_g, train_z1, train_z2, train_y)
VarREx_auc_val = evaluate(VarREx_g, val_z1, val_z2, val_y)
VarREx_auc_test = evaluate(VarREx_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(VarREx_auc_train, VarREx_auc_val, VarREx_auc_test))

#### VarAUC

In [None]:
# Start from the shared initialisation
VarAUC_g.set_weights(init_g.get_weights())

In [None]:
VarAUC_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [None]:
# Train with the VarAUC penalty weighted by VarAUC_factor.
# NOTE(review): outputs are presumably (model, loss history, #pairs seen at each
# evaluation, test AUCs, validation AUCs) -- confirm against algorithm.training.
VarAUC_g, VarAUC_losses, VarAUC_Ns, VarAUC_test_aucs, VarAUC_val_aucs = training(VarAUC_optimizer, VarAUC_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,
                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, VarAUC_factor, n_sim_envs,
                                                                                 penalty_type='VarAUC')

In [None]:
# AUC on training pairs, held-out-identity validation pairs (in-distribution)
# and shifted test pairs (distribution shift)
VarAUC_auc_train = evaluate(VarAUC_g, train_z1, train_z2, train_y)
VarAUC_auc_val = evaluate(VarAUC_g, val_z1, val_z2, val_y)
VarAUC_auc_test = evaluate(VarAUC_g, test_z1, test_z2, test_y)

print("Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}".format(VarAUC_auc_train, VarAUC_auc_val, VarAUC_auc_test))

### Comparing results

In [None]:
# AUC over training for every method: dashed = in-distribution validation,
# solid = distribution-shift test. Only the solid lines get legend entries.
auc_curves = [
    # (label, Ns, val_aucs, test_aucs, color, alpha)
    ('ERM', ERM_Ns, ERM_val_aucs, ERM_test_aucs, 'C3', 1.0),
    ('IRM', IRM_Ns, IRM_val_aucs, IRM_test_aucs, 'C1', 0.8),
    ('CLoVE', CLoVE_Ns, CLoVE_val_aucs, CLoVE_test_aucs, 'C0', 0.8),
    ('VarREx', VarREx_Ns, VarREx_val_aucs, VarREx_test_aucs, 'C4', 0.8),
    ('VarAUC', VarAUC_Ns, VarAUC_val_aucs, VarAUC_test_aucs, 'C2', 1.0),
]

fig, ax = plt.subplots()
for name, Ns, val_aucs, test_aucs, color, alpha in auc_curves:
  ax.plot(Ns, val_aucs, '--', color=color, alpha=alpha)
  ax.plot(Ns, test_aucs, label=name, color=color, alpha=alpha)

ax.set_title('Solid: distribution shift (test), dashed: in-distribution (validation)', fontsize=10)
ax.set_xlabel('Data Points (pairs)')
ax.set_ylabel('AUC')
ax.legend(loc='lower right');

In [None]:
# split training data by attribute (a=1: blond hair, a=0: not blond)
z_a1, c_a1 = z_train[a_train==1], c_train[a_train==1]
z_a0, c_a0 = z_train[a_train==0], c_train[a_train==0]

# make pairs within each attribute group (same keyword as the other make_pairs
# calls); note that c_a1 / c_a0 are rebound to the class output of make_pairs
z1_a1, z2_a1, y_a1, c_a1 = make_pairs(z_a1, c_a1, pos_per_class=pos_per_class_train)
z1_a0, z2_a0, y_a0, c_a0 = make_pairs(z_a0, c_a0, pos_per_class=pos_per_class_train)

# ERM representations
z1_hat_a1_ERM, z2_hat_a1_ERM = ERM_g(z1_a1), ERM_g(z2_a1)
z1_hat_a0_ERM, z2_hat_a0_ERM = ERM_g(z1_a0), ERM_g(z2_a0)

# VarAUC representations
z1_hat_a1_VarAUC, z2_hat_a1_VarAUC = VarAUC_g(z1_a1), VarAUC_g(z2_a1)
z1_hat_a0_VarAUC, z2_hat_a0_VarAUC = VarAUC_g(z1_a0), VarAUC_g(z2_a0)

# unpenalized (base) losses, used to compare the ERM and VarAUC embeddings
def raw_loss(z1_hat, z2_hat, y_true, margin=0.5):
  """Contrastive loss of embedding pairs, without any invariance penalty.

  Args:
    z1_hat, z2_hat: embeddings of the first and second element of each pair.
    y_true: pair labels, forwarded to contrastive_loss.
    margin: contrastive margin, forwarded to contrastive_loss.

  Returns:
    (loss, dist): contrastive_loss(y_true, dist, margin) and the cosine
    distance `dist` it was computed from.
  """
  pair_dist = cosine_distance(z1_hat, z2_hat)
  return contrastive_loss(y_true, pair_dist, margin), pair_dist

# on a1 (blond identities): per-pair base losses and cosine distances
base_loss_a1_ERM, dist_a1_ERM = raw_loss(z1_hat_a1_ERM, z2_hat_a1_ERM, y_a1)
base_loss_a1_VarAUC, dist_a1_VarAUC = raw_loss(z1_hat_a1_VarAUC, z2_hat_a1_VarAUC, y_a1)
# NOTE(review): only the distances are converted with .numpy(); the base losses
# keep whatever type contrastive_loss returns.
dist_a1_ERM, dist_a1_VarAUC = dist_a1_ERM.numpy(), dist_a1_VarAUC.numpy()

# on a0 (non-blond identities)
base_loss_a0_ERM, dist_a0_ERM = raw_loss(z1_hat_a0_ERM, z2_hat_a0_ERM, y_a0)
base_loss_a0_VarAUC, dist_a0_VarAUC = raw_loss(z1_hat_a0_VarAUC, z2_hat_a0_VarAUC, y_a0)
dist_a0_ERM, dist_a0_VarAUC = dist_a0_ERM.numpy(), dist_a0_VarAUC.numpy()

In [None]:
def get_bins(diff, n_bins=19):
  """Histogram bin edges covering `diff`, with ends at odd multiples of w/2.

  The nominal bin width is w = ptp(diff) / n_bins. The right end is the
  smallest odd multiple of w/2 that is >= max(diff). The left end mirrors
  |min(diff)|, so all values are covered and the range extends to the
  negative side even when every value is positive.

  Args:
    diff: array-like of values (flattened).
    n_bins: controls the bin width; n_bins + 2 edges are returned.

  Returns:
    np.ndarray of n_bins + 2 finite bin edges.

  Raises:
    ValueError: if `diff` is empty.
  """
  diff = np.ravel(np.asarray(diff, dtype=float))
  if diff.size == 0:
    raise ValueError("get_bins needs at least one value")
  w = np.ptp(diff)/n_bins
  if w == 0:
    # all values identical: 0/0 would make every edge NaN and plt.hist fail,
    # so fall back to a unit width
    w = 1.0
  u = np.ceil((diff.max()-w/2)/w)*w + w/2
  l = np.ceil((abs(diff.min())-w/2)/w)*w + w/2
  return np.linspace(-l, u, n_bins+2)

In [None]:
fig, axs = plt.subplots(2,2, figsize=(8,8))

n_bins=19

def _plot_loss_diff(ax, loss_ERM, loss_VarAUC, y, y_value, color):
  """Draw a normalised histogram of (ERM loss - VarAUC loss) on one panel.

  Only pairs whose label equals `y_value` are used. Bar heights sum to 1 and
  the x-range is symmetric around the dotted zero line; mass to the right of 0
  means VarAUC has the lower loss on those pairs.

  Returns:
    (diff, bins): the plotted differences and their bin edges (from get_bins).
  """
  # np.asarray lets the boolean mask work whether the losses/labels are NumPy
  # arrays or eager tensors
  mask = np.asarray(y) == y_value
  diff = np.asarray(loss_ERM)[mask] - np.asarray(loss_VarAUC)[mask]
  bins = get_bins(diff, n_bins)
  ax.hist(diff, bins=bins, alpha=0.5, weights=np.ones(len(diff))/len(diff), color=color, ec=color)
  ax.axvline(0, c='k', linestyle=':')
  xl = max(abs(bins))*1.1
  ax.set_xlim(-xl, xl)
  return diff, bins

# columns: attribute group; rows: pair label (y=0 top, y=1 bottom)
diff_00, bins_00 = _plot_loss_diff(axs[0,0], base_loss_a1_ERM, base_loss_a1_VarAUC, y_a1, 0, 'C5')
diff_01, bins_01 = _plot_loss_diff(axs[0,1], base_loss_a0_ERM, base_loss_a0_VarAUC, y_a0, 0, 'C5')
diff_10, bins_10 = _plot_loss_diff(axs[1,0], base_loss_a1_ERM, base_loss_a1_VarAUC, y_a1, 1, 'C7')
diff_11, bins_11 = _plot_loss_diff(axs[1,1], base_loss_a0_ERM, base_loss_a0_VarAUC, y_a0, 1, 'C7')

# a=1 is Blond_Hair, the minority group in training (minor_labl=1)
axs[0,0].set_title('Blond hair (minority in training)', fontsize=11)
axs[0,1].set_title('Not blond (majority in training)', fontsize=11)
axs[0,0].set_ylabel(r'$y=0$', fontsize=11)
axs[1,0].set_ylabel(r'$y=1$', fontsize=11)
axs[1,0].set_xlabel('ERM loss - VarAUC loss')
axs[1,1].set_xlabel('ERM loss - VarAUC loss');