## Instructions
1. Run all the cells sequentially. 
2. The experiments below describe how we can adapt C-LRG from the paper https://arxiv.org/pdf/2010.15234.pdf to non-linear models. We compare it with F-IRM from http://proceedings.mlr.press/v119/ahuja20a/ahuja20a.pdf
3. Comments next to each command guide the reader through the details of the implementation.






## Internal libraries summary
1. data_construct.py: this file defines two classes <br>
    assemble_data_mnist(): creates colored environments from MNIST digits <br>
    assemble_data_mnist_fashion(): creates colored environments from Fashion-MNIST <br>


2. IRM_methods.py: from this file we will use two classes <br>
    a) fixed_irm_game_model for model F-IRM from http://proceedings.mlr.press/v119/ahuja20a/ahuja20a.pdf  <br>
    b) fixed_irm_game_model_cons for adaptation of C-LRG for non-linear settings from https://arxiv.org/pdf/2010.15234.pdf <br>

    
    Each class is initialized with the hyperparameters of the corresponding model.
    Each class has a fit method, which takes the data from the training environments as input and trains the models. Finally, each class has an evaluate method, which takes the data from the test environment as input and computes the accuracy on the test data as well as on the training data used in fit.

### data_construct.py explained

The datasets used in Ahuja et al. and Arjovsky et al. modify standard datasets, such as MNIST digits and Fashion-MNIST, to create multiple environments with different degrees of spurious correlation between the color and the labels. Here we describe the classes that create these datasets.

    1. data_construct.py consists of two classes: assemble_data_mnist() and assemble_data_mnist_fashion()
        a) assemble_data_mnist()/assemble_data_mnist_fashion() have the following methods:
            i) create_training_data(n_e, p_color_list, p_label_list):
                n_e: number of environments, p_color_list: list of probabilities (one per environment) of switching the final label to obtain the color index, p_label_list: list of probabilities (one per environment) of switching the pre-label
            ii) create_testing_data(p_color_test, p_label_test, n_e):
                n_e: number of environments, p_color_test: probability of switching the final label to obtain the color index in the test environment, p_label_test: probability of switching the pre-label in the test environment
        b) assemble_data_mnist()/assemble_data_mnist_fashion() have the following attributes:
            i) data_tuple_list: list of length n_e; each element is a tuple with three elements (data, label, environment index)
            ii) data_tuple_test: tuple with three elements (data_test, label_test, test environment index)
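To make the roles of p_label and p_color concrete, here is a minimal NumPy sketch of how one colored environment could be built. It assumes the construction described in Arjovsky et al.: a binary pre-label (digits 0-4 versus 5-9) and two color channels, one of which is zeroed out. make_colored_environment is a hypothetical helper, not the code in data_construct.py, whose channel layout and preprocessing may differ; random arrays stand in for the MNIST images.

```python
import numpy as np


def make_colored_environment(images, digits, p_color, p_label, env_index, rng):
    """Build one colored environment as a (data, label, env_index) tuple.

    Hypothetical sketch assuming the Arjovsky et al. construction:
    images is an (n, H, W) grayscale array, digits holds digit classes 0-9,
    and rng is a numpy.random.Generator.
    """
    n = len(digits)
    # Pre-label from the digit shape: 0 for digits 0-4, 1 for digits 5-9 (assumed binarization)
    pre_label = (np.asarray(digits) >= 5).astype(int)
    # Final label: switch the pre-label with probability p_label
    label = np.where(rng.random(n) < p_label, 1 - pre_label, pre_label)
    # Color index: switch the final label with probability p_color
    color = np.where(rng.random(n) < p_color, 1 - label, label)
    # Two color channels (assumed layout); zero the channel that does not match the color index
    data = np.stack([images, images], axis=-1).astype(np.float32)
    data[np.arange(n), :, :, 1 - color] = 0.0
    return data, label, env_index


# Usage with placeholder images and the training probabilities used later in this notebook
rng = np.random.default_rng(0)
images = rng.random((1000, 28, 28))
digits = rng.integers(0, 10, size=1000)
data_tuple_list = [make_colored_environment(images, digits, p, 0.25, e, rng)
                   for e, p in enumerate([0.2, 0.1])]
```

Because the color is derived from the final label rather than from the digit, it is only spuriously correlated with the label, and that correlation can be reversed at test time by choosing a large p_color_test.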
    

### IRM_methods.py explained

    1. fixed_irm_game_model class: implements the F-IRM game from http://proceedings.mlr.press/v119/ahuja20a/ahuja20a.pdf.
    
        A) Initialization:
        fixed_irm_game_model(model_list, learning_rate, num_epochs, batch_size, termination_acc, warm_start)
            i) model_list: list of models, one per environment; use keras to construct the architectures
            ii) learning_rate: learning rate of the Adam optimizer used to train each environment's model
            iii) num_epochs: number of epochs, where each epoch is one full pass over the training data (number of training samples // batch_size gradient steps)
            iv) batch_size: size of the batch used for each gradient update
            v) termination_acc: training is terminated once the model accuracy falls below this threshold
            vi) warm_start: minimum number of training steps before training can be terminated for falling below termination_acc

        B) Methods:
            i) fit(data_tuple_list): trains the models
                   data_tuple_list: list of length n_e; each element is a tuple with three elements (data, label, environment index)
            ii) evaluate(data_tuple_test): computes the training and test accuracy
                   data_tuple_test: tuple with three elements (data_test, label_test, test environment index)

        C) Attributes:
            i) model_list: list of models, one per environment
            ii) train_acc: training accuracy (available after running the evaluate method)
            iii) test_acc: test accuracy (available after running the evaluate method)

    2. fixed_irm_game_model_cons class: implements C-LRG from https://arxiv.org/pdf/2010.15234.pdf, adapted for non-linear settings. This class has the same structure as the fixed_irm_game_model class.
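For intuition about what fit does in both classes, here is a minimal NumPy sketch of a fixed-representation IRM game with linear logistic players. It rests on simplifying assumptions: full-batch gradient steps instead of mini-batches of batch_size, plain gradient descent instead of Adam, and a round-robin update order. The ensemble averages the players' logits, and each environment updates only its own weights. The bound argument clips each player's weights to a box; this is an assumed stand-in for the constraint in C-LRG, and how fixed_irm_game_model_cons actually imposes it for non-linear models is not reproduced here. All names (fit_irm_game, ensemble_logits, train_accuracy, bound) are hypothetical.

```python
import numpy as np


def sigmoid(z):
    # Numerically stable logistic function
    return 0.5 * (1.0 + np.tanh(0.5 * z))


def ensemble_logits(weights, x):
    # The ensemble predictor averages the logits of all per-environment linear models
    return np.mean([x @ w for w in weights], axis=0)


def train_accuracy(weights, envs):
    # Accuracy of the ensemble on the pooled training environments
    correct = sum(np.sum((ensemble_logits(weights, x) > 0).astype(int) == y) for x, y in envs)
    return correct / sum(len(y) for _, y in envs)


def fit_irm_game(envs, lr=0.1, num_epochs=50, termination_acc=0.0, warm_start=0, bound=None):
    """Hypothetical sketch of a fixed-representation IRM game with linear players.

    envs: list of (x, y) pairs, one per environment, with labels y in {0, 1}.
    bound: if not None, each player's weights are clipped to [-bound, bound]
           after every update (assumed stand-in for the C-LRG constraint).
    Returns the per-environment weights and the number of updates taken.
    """
    n_env = len(envs)
    weights = [np.zeros(envs[0][0].shape[1]) for _ in envs]
    step = 0
    for _ in range(num_epochs):
        # Round robin: each environment updates only its own weights, on its own data,
        # while the other players' weights stay fixed
        for e, (x, y) in enumerate(envs):
            p = sigmoid(ensemble_logits(weights, x))
            # Gradient of the mean logistic loss of the ensemble w.r.t. weights[e];
            # the 1 / n_env factor comes from averaging the logits
            grad = x.T @ (p - y) / (len(y) * n_env)
            weights[e] = weights[e] - lr * grad
            if bound is not None:
                weights[e] = np.clip(weights[e], -bound, bound)
            step += 1
            # Early stopping: once past warm_start, stop if training accuracy drops below the threshold
            if step >= warm_start and train_accuracy(weights, envs) < termination_acc:
                return weights, step
    return weights, step
```

Calling fit_irm_game with bound=None gives the unconstrained game, while a finite bound keeps every player's weights inside a box. In this sketch, that projection step is the only assumed difference between the two variants.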
    
         
            



# Import external libraries

In [1]:
import tensorflow as tf
import numpy as np
import argparse
import IPython.display as display
import matplotlib.pyplot as plt
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import pandas as pd
import copy as cp
from sklearn.model_selection import KFold

# Eager execution is on by default in TensorFlow 2.x; this call keeps the notebook compatible with TensorFlow 1.x
tf.compat.v1.enable_eager_execution()

In [2]:
print(tf.__version__)
tf.executing_eagerly()

2.3.1


True

# Import internal libraries


In [4]:
from data_construct import *  ## classes for constructing the colored environments
from IRM_methods import *     ## IRM games methods (F-IRM and the C-LRG adaptation)

# MNIST digits: 2 environments

Below we illustrate how to use our IRM methods.
We first set up the data in the cell below. We set p_color_list = [0.2, 0.1] (following the experiments in Arjovsky et al.); note that the probabilities of switching the final label to obtain the color index differ only slightly between the two environments. This small difference helps IRM methods learn predictors that are invariant across environments, relying on the shape of the digits rather than on the color.
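A quick calculation shows why the color is a tempting but unreliable feature in this setup. Consider two idealized predictors. The first reads the digit shape perfectly and predicts the pre-label. The second predicts the color index. Assuming the color index is the final label switched with probability p_color, and the final label is the pre-label switched with probability p_label, the color-only predictor beats the shape-only predictor in both training environments but collapses in the test environment. The helper names below are hypothetical.

```python
p_label = 0.25
p_color = {"train env 1": 0.2, "train env 2": 0.1, "test env": 0.9}


def color_only_accuracy(p_color):
    # Predict label = color index: correct whenever the color was not switched
    return 1.0 - p_color


def shape_only_accuracy(p_label):
    # Predict the pre-label read perfectly from the shape: correct unless the label was switched
    return 1.0 - p_label


for name, p in p_color.items():
    print(f"{name}: color-only {color_only_accuracy(p):.2f}, shape-only {shape_only_accuracy(p_label):.2f}")
# train env 1: color-only 0.80, shape-only 0.75
# train env 2: color-only 0.90, shape-only 0.75
# test env: color-only 0.10, shape-only 0.75
```

A method that simply minimizes training error is therefore pulled toward the color, whereas a predictor that relies on shape keeps the same accuracy in every environment.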

## Compare F-IRM and C-LRG (adapted for non-linear models) on standard colored MNIST

In [None]:
# Compare C-LRG adapted for non-linear models (fixed_irm_game_model_cons) with F-IRM
# (fixed_irm_game_model) over n_trials trials, using the same hyper-parameters for both

n_trials = 10
av_jm  = 0  # running sum of the C-LRG test accuracies
av_jm1 = 0  # running sum of the F-IRM test accuracies

# Hyper-parameters shared by both methods
num_epochs       = 25
batch_size       = 256
termination_acc  = 0.53
warm_start       = 100
learning_rate    = 5e-4

def build_model_list(n_e, input_shape, num_classes):
    # One freshly initialized MLP per environment (same architecture for both methods)
    return [keras.Sequential([
                keras.layers.Flatten(input_shape=input_shape),
                keras.layers.Dense(390, activation='elu'),
                keras.layers.Dropout(0.75),
                keras.layers.Dense(390, activation='elu'),
                keras.layers.Dropout(0.75),
                keras.layers.Dense(num_classes)
            ]) for e in range(n_e)]

for i in range(n_trials):
    print("trial " + str(i))

    # Create data for each environment
    n_e = 2  # number of environments
    p_color_list = [0.2, 0.1]  # list of probabilities of switching the final label to obtain the color index
    p_label_list = [0.25]*n_e  # list of probabilities of switching the pre-label
    D = assemble_data_mnist()  # initialize MNIST digits data object
    D.create_training_data(n_e, p_color_list, p_label_list)  # creates the training environments

    p_label_test = 0.25  # probability of switching the pre-label in the test environment
    p_color_test = 0.9   # probability of switching the final label to obtain the color index in the test environment
    D.create_testing_data(p_color_test, p_label_test, n_e)  # sets up the testing environment

    (num_examples_environment, length, width, height) = D.data_tuple_list[0][0].shape  # attributes of the data
    num_classes = len(np.unique(D.data_tuple_list[0][1]))  # number of classes in the data

    # C-LRG adapted for non-linear models
    model_list = build_model_list(n_e, (length, width, height), num_classes)
    F_game = fixed_irm_game_model_cons(model_list, learning_rate, num_epochs, batch_size, termination_acc, warm_start)
    F_game.fit(D.data_tuple_list)       # train on the data from the training environments
    F_game.evaluate(D.data_tuple_test)  # compute train and test accuracy of the final model
    # print("C-LRG training accuracy " + str(F_game.train_acc))
    print("C-LRG testing accuracy " + str(F_game.test_acc))
    av_jm = av_jm + F_game.test_acc

    # F-IRM, trained from a fresh set of models on the same data
    model_list = build_model_list(n_e, (length, width, height), num_classes)
    F_game1 = fixed_irm_game_model(model_list, learning_rate, num_epochs, batch_size, termination_acc, warm_start)
    F_game1.fit(D.data_tuple_list)       # train on the data from the training environments
    F_game1.evaluate(D.data_tuple_test)  # compute train and test accuracy of the final model
    # print("F-IRM training accuracy " + str(F_game1.train_acc))
    print("F-IRM testing accuracy " + str(F_game1.test_acc))
    av_jm1 = av_jm1 + F_game1.test_acc

print("F-IRM average accuracy: " + str(av_jm1/n_trials))
print("C-LRG adapted for non-linear models average accuracy: " + str(av_jm/n_trials))
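The comparison loop keeps only a running sum of test accuracies for each method, so it reports an average but no spread. To also see how much the results vary from trial to trial, one option is to store each trial's test_acc in a list and summarize it with a helper like the one below. summarize_trials is hypothetical and not part of IRM_methods.py. It uses the sample standard deviation (ddof=1).

```python
import numpy as np


def summarize_trials(test_accs):
    """Return the mean and sample standard deviation of per-trial test accuracies."""
    accs = np.asarray(test_accs, dtype=float)
    return accs.mean(), accs.std(ddof=1)
```

Reporting the spread alongside the mean makes it easier to judge whether a gap between F-IRM and the C-LRG adaptation is larger than the trial-to-trial variation.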