In [None]:
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch import nn, optim, autograd
import pdb
import copy
from tqdm import tqdm
import pickle
import math
import os
import sys

from sklearn.linear_model import LinearRegression
from itertools import chain, combinations
from scipy.stats import f as fdist
from scipy.stats import ttest_ind

from torch.autograd import grad

import scipy.optimize

import matplotlib

In [None]:
# Experiment configuration. parse_args() is given an explicit argument list so
# the Jupyter kernel's own sys.argv is never parsed as experiment flags.
parser = argparse.ArgumentParser(description='Invariant regression')
parser.add_argument('--dim', type=int, default=10)
parser.add_argument('--n_samples', type=int, default=1000)
parser.add_argument('--n_reps', type=int, default=3)
parser.add_argument('--skip_reps', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)  # Negative is random
parser.add_argument('--print_vectors', type=int, default=0)
parser.add_argument('--n_iterations', type=int, default=10000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--verbose', type=int, default=1)
# Comma-separated method names (or "all"); each name should appear once.
parser.add_argument('--methods', type=str, default="IRM,REIIL,EIIL")
parser.add_argument('--alpha', type=float, default=0.05)
parser.add_argument('--setup_sem', type=str, default="chain")
parser.add_argument('--setup_hidden', type=int, default=0)
parser.add_argument('--setup_hetero', type=int, default=2)
parser.add_argument('--setup_scramble', type=int, default=0)
parser.add_argument('--results_dir', type=str, default="./tmp/experiment_synthetic")
parser.add_argument('--eiil_ref_alpha', type=float, default=-1,
                    help=('Value between zero and one to hard code the reference '
                          'classifier propensity to use the spurious feature. Set '
                          'to value outside zero one interval to disable.'))
parser.add_argument('--reiil_iters', type=int, default=10)
# Use the full option name rather than relying on argparse prefix matching.
flags = dict(vars(parser.parse_args(['--n_reps', '3'])))

In [None]:
# Echo the resolved configuration, one flag per line, sorted by flag name.
print('Flags:')
for flag_name in sorted(flags):
  print("\t{}: {}".format(flag_name, flags[flag_name]))

In [None]:
class ChainEquationModel(object):
    """Chain-structured linear SEM (x -> y -> z) that samples regression environments.

    Calling the model draws ``n`` rows of a cause block ``x``, a target block ``y``
    and an effect block ``z`` (each ``dim // 2`` columns wide). It returns the
    features ``cat(x, z) @ scramble`` and the target ``y.sum(1, keepdim=True)``.

    Args:
        dim: requested feature width. ``x`` and ``z`` get ``dim // 2`` columns each,
            so an odd ``dim`` yields ``dim - 1`` features.
        scramble: if truthy, mix the features with a random orthogonal matrix
            (the Q factor of a Gaussian matrix); otherwise the identity is used.
        hetero: selects the noise model used in ``__call__`` (2, 1/True, or other).
        hidden: if truthy, a hidden variable ``h`` feeds x, y and z through random
            weights; otherwise ``h`` enters x only (identity weight).
    """

    def __init__(self, dim, scramble=False, hetero=True, hidden=False):
        self.hetero = hetero
        self.hidden = hidden
        self.dim = dim // 2
        # x and z each contribute self.dim columns; equals `dim` for even dim.
        n_features = 2 * self.dim
        ones = True

        if ones:
            self.wxy = torch.eye(self.dim)
            self.wyz = torch.eye(self.dim)
        else:
            self.wxy = torch.randn(self.dim, self.dim) / dim
            self.wyz = torch.randn(self.dim, self.dim) / dim

        if scramble:
            # torch.qr is deprecated; torch.linalg.qr (mode="reduced") is the replacement.
            self.scramble, _ = torch.linalg.qr(torch.randn(n_features, n_features))
        else:
            self.scramble = torch.eye(n_features)

        if hidden:
            self.whx = torch.randn(self.dim, self.dim) / dim
            self.why = torch.randn(self.dim, self.dim) / dim
            self.whz = torch.randn(self.dim, self.dim) / dim
        else:
            self.whx = torch.eye(self.dim, self.dim)
            self.why = torch.zeros(self.dim, self.dim)
            self.whz = torch.zeros(self.dim, self.dim)

    def solution(self):
        """Column vector of weights on the returned features that reproduces the
        noiseless part of the target from ``x`` (zero weight on the ``z`` block)."""
        w = torch.cat((self.wxy.sum(1), torch.zeros(self.dim))).view(-1, 1)
        return self.scramble.t() @ w

    def __call__(self, n, env, split=None):
        """Sample ``n`` rows from the environment parameterized by ``env``.

        Noise models:
          * ``hetero == 2``: ``x`` has fixed scale 5; ``y`` gets noise scaled by
            ``env``, except that the first ``int(n * split)`` rows (when ``split``
            is given) get noise scale 0.1; ``z`` gets unit noise.
          * ``hetero == 1``: ``env`` scales the noise on ``x`` and ``y``; ``z`` unit noise.
          * otherwise: ``env`` scales the noise on ``x`` and ``z``; ``y`` unit noise.

        Returns:
            ``(features, target)`` with shapes ``(n, 2 * (dim // 2))`` and ``(n, 1)``.
        """
        # h is always drawn (even when unused) so RNG consumption is identical across settings.
        h = torch.randn(n, self.dim) * env

        if self.hetero == 2:
            if split:
              num = int(n * split)
              x_low_noise = torch.randn(num, self.dim) * 5.
              x_rest = torch.randn(n - num, self.dim) * 5.
              x = torch.cat((x_low_noise, x_rest), 0)
              y_low_noise = x_low_noise @ self.wxy + torch.randn(num, self.dim) * 0.1
              y_rest = x_rest @ self.wxy + torch.randn(n - num, self.dim) * env
              y = torch.cat((y_low_noise, y_rest), 0)
            else:
              x = torch.randn(n, self.dim) * 5.
              y = x @ self.wxy + torch.randn(n, self.dim) * env
            z = y @ self.wyz + torch.randn(n, self.dim)
        elif self.hetero == 1:
            x = h @ self.whx + torch.randn(n, self.dim) * env
            y = x @ self.wxy + h @ self.why + torch.randn(n, self.dim) * env
            z = y @ self.wyz + h @ self.whz + torch.randn(n, self.dim)
        else:
            x = h @ self.whx + torch.randn(n, self.dim) * env
            y = x @ self.wxy + h @ self.why + torch.randn(n, self.dim)
            z = y @ self.wyz + h @ self.whz + torch.randn(n, self.dim) * env
        variances = dict(
          h=h.var().item(),
          x=x.var().item(),
          y=y.var().item(),
          z=z.var().item(),
          e=(torch.randn(n, self.dim) * env).var().item()  # any env dependent noise we might add
        )
        from pprint import pprint
        # %g so fractional env values (e.g. 3.5) are not truncated as with %d.
        print('in setting %d data in env %g have following variances' % (self.hetero, env))
        pprint(variances)
        return torch.cat((x, z), 1) @ self.scramble, y.sum(1, keepdim=True)

In [None]:
def pretty(vector):
    """Format a tensor as "[+a.aaaa, -b.bbbb, ...]" (flattened, signed, 4 decimals).

    Note: a later cell redefines ``pretty`` with 3 decimals, and that later
    definition is the one looked up at call time once it has been executed.
    """
    formatted = ["{:+.4f}".format(value) for value in vector.view(-1).tolist()]
    return "[{}]".format(", ".join(formatted))

#Models
class InvariantRiskMinimization(object):
    """Linear IRM with a learned representation ``phi`` and a fixed classifier ``w``.

    ``environments`` is indexed as: 0 and 1 for training, 2 for validation
    (model selection over the penalty weight), 3-5 for test (printed as
    env=0, env=5, env=10). Predictions are ``x @ phi @ w``.
    """

    def __init__(self, environments, args):
        x_val, y_val = environments[2]

        # best_phi stays None if no candidate beats +inf (e.g. every validation
        # error is NaN). In that case the last trained phi is kept instead of
        # failing with an unbound-variable error.
        best_err = float('inf')
        best_phi = None

        for reg in [0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]:
            reg = 1. - reg  # change of variables for consistency with old codebase
            self.train(environments[:2], args, reg=reg)
            err = (x_val @ self.solution() - y_val).pow(2).mean().item()

            if args["verbose"]:
                print("IRM (reg={:.6f}) has {:.3f} validation error.".format(
                    reg, err))

            if err < best_err:
                best_err = err
                best_phi = self.phi.clone()

        if best_phi is not None:
            self.phi = best_phi
        print ('IRM best err and phi:', best_err, self.phi)
        for env_idx, label in ((3, '0'), (4, '5'), (5, '10')):
            x_test, y_test = environments[env_idx]
            test_err = (x_test @ self.solution() - y_test).pow(2).mean().item()
            print('IRM test err [env={}]:'.format(label), test_err)

    def train(self, environments, args, reg=0, split_sizes=(100, 200)):
        """Optimize ``phi`` with Adam for ``args['n_iterations']`` steps.

        The two given environments are re-mixed. The first ``split_sizes[0]`` rows
        of env 0 plus the first ``split_sizes[1]`` rows of env 1 form one training
        environment; the remaining rows form the other. The defaults presumably
        match the low-noise prefixes produced by ``ChainEquationModel`` with split
        fractions 0.1 / 0.2 and 1000 samples (assumption; confirm if those change).

        Objective: ``(1 - reg) * sum_e MSE_e + reg * sum_e mean((dMSE_e/dw)^2)``.
        ``w`` is fixed at ones and only ``phi`` is passed to the optimizer.
        """
        print('learning representation with', self, 'and reg', reg)
        dim_x = environments[0][0].size(1)
        n_first, n_second = split_sizes
        (x_a, y_a), (x_b, y_b) = environments[0], environments[1]

        x_1 = torch.cat((x_a[:n_first], x_b[:n_second]), 0)
        y_1 = torch.cat((y_a[:n_first], y_b[:n_second]), 0)
        x_2 = torch.cat((x_a[n_first:], x_b[n_second:]), 0)
        y_2 = torch.cat((y_a[n_first:], y_b[n_second:]), 0)

        environments = [(x_1, y_1), (x_2, y_2)]

        self.phi = torch.nn.Parameter(torch.eye(dim_x, dim_x))
        self.w = torch.ones(dim_x, 1)
        self.w.requires_grad = True

        opt = torch.optim.Adam([self.phi], lr=args["lr"])
        loss = torch.nn.MSELoss()

        for iteration in range(args["n_iterations"]):
            penalty = 0
            error = 0
            for x_e, y_e in environments:
                error_e = loss(x_e @ self.phi @ self.w, y_e)
                # create_graph=True so the gradient penalty is itself differentiable w.r.t. phi.
                penalty += grad(error_e, self.w,
                                create_graph=True)[0].pow(2).mean()
                error += error_e

            opt.zero_grad()
            # reg close to 1 weights the penalty heavily (see change of variables in __init__).
            ((1 - reg) * error + reg * penalty).backward()
            opt.step()

            if args["verbose"] and iteration % 1000 == 0:
                w_str = pretty(self.solution())
                print("{:05d} | {:.5f} | {:.5f} | {:.5f} | {}".format(iteration,
                                                                      reg,
                                                                      error,
                                                                      penalty,
                                                                      w_str))

    def solution(self):
        """Effective linear weights ``phi @ w`` (column vector)."""
        return self.phi @ self.w

      
class LearnedEnvInvariantRiskMinimization(InvariantRiskMinimization):
    """EIIL: infer two training environments from a reference model, then run IRM.

    The reference classifier is hard-coded from ``args['eiil_ref_alpha']`` when
    that value lies in [0, 1]; otherwise it is the ERM solution. ``split``
    learns a soft assignment of the pooled samples of environments 0 and 1 that
    maximizes the reference model's IRM penalty. Thresholding that assignment
    defines the two environments passed to the inherited IRM ``train``.

    ``environments`` is indexed as: 0 and 1 training, 2 validation, 3-5 test
    (printed as env=0, env=5, env=10).
    """

    def __init__(self, environments, args, pretrain=False):
        # ``pretrain`` is accepted for interface compatibility but not used.
        x_val, y_val = environments[2]
        dim_x = environments[0][0].size(1)

        if 0 <= args['eiil_ref_alpha'] <= 1:
            print('Using hard-coded reference classifier with alpha={:.2f}'.format(
              args['eiil_ref_alpha']
            ))
            alpha = args['eiil_ref_alpha']
            # First half of the features weighted by (1 - alpha); second half by
            # alpha (the spurious contribution to the prediction).
            n_first = dim_x // 2
            w_causal = (1. - alpha) * np.ones((1, n_first))
            w_noncausal = alpha * np.ones((1, dim_x - n_first))
            w_ref = torch.tensor(np.hstack((w_causal, w_noncausal)), dtype=torch.float32)
        else:
            print('Using ERM soln as reference classifier.')
            w_ref = EmpiricalRiskMinimizer(environments, args).solution()

        # With w fixed at ones, phi = diag(w_ref) makes phi @ w equal the reference weights.
        self.phi = torch.nn.Parameter(torch.diag(w_ref.squeeze()))
        self.w = torch.ones(dim_x, 1)
        self.w.requires_grad = True

        if args["verbose"]:
            err = (x_val @ self.solution() - y_val).pow(2).mean().item()
            print("EIIL's reference classifier has {:.3f} validation error.".format(
                err))
            print("EIIL's reference classifier has the following solution:\n.",
                  pretty(self.solution()))

        self.phi = self.phi.clone()

        # split() pools env 0 first and env 1 second, so env 1 starts at this offset.
        n_env0 = len(environments[0][0])
        environments, env_w = self.split(environments, args)
        if args["verbose"]:
            print("EIIL+ERM ref clf still has the following solution after AED (sanity check):\n.", pretty(self.solution()))

        # Count how many of the low-noise rows landed in the smaller inferred
        # environment. Assumes those rows are the first 100 of env 0 and the first
        # 200 of env 1, i.e. the split fractions 0.1 / 0.2 with 1000 samples used
        # in run_experiment.
        idx = [(env_w.sigmoid() > .5), (env_w.sigmoid() <= .5)]
        maj = 1 if torch.count_nonzero(idx[1]) > torch.count_nonzero(idx[0]) else 0
        minority = idx[1 - maj]
        total_flipped = 100 + 200
        total_flipped_min = int(torch.count_nonzero(minority[:100])
                                + torch.count_nonzero(minority[n_env0:n_env0 + 200]))
        print (f'EIIL Total flipped in minority environment: {total_flipped_min} / {total_flipped} ({total_flipped_min / total_flipped * 100} %)')

        best_err = float('inf')
        best_phi = None
        for reg in [0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]:
            reg = 1. - reg  # change of variables for consistency with old codebase
            self.train(environments[:2], args, reg=reg)
            err = (x_val @ self.solution() - y_val).pow(2).mean().item()

            if args["verbose"]:
                print("EIIL+IRM (reg={:.6f}) has {:.3f} validation error.".format(
                    reg, err))

            if err < best_err:
                best_err = err
                best_phi = self.phi.clone()
        # If every validation error was NaN, keep the last trained phi.
        if best_phi is not None:
            self.phi = best_phi
        print ('EIIL best err and phi:', best_err, self.phi)
        for env_idx, label in ((3, '0'), (4, '5'), (5, '10')):
            x_test, y_test = environments[env_idx]
            test_err = (x_test @ self.solution() - y_test).pow(2).mean().item()
            print('EIIL test err [env={}]:'.format(label), test_err)

    def split(self, environments, args, n_samples=-1):
        """Learn a soft two-way assignment of the pooled training samples.

        Args:
            environments: list whose first two entries are pooled and split; the
                entries from index 2 on are passed through unchanged.
            args: flag dict; ``args['n_iterations']`` Adam steps are taken.
            n_samples: rows taken from each of the first two environments. Any
                negative value (the default) or None means all rows. Previously
                -1 sliced ``[:-1]`` and silently dropped the last row.

        Returns:
            ``(envs, env_w)``. ``envs`` holds the rows with sigmoid(env_w) > .5,
            then the rest, then the pass-through environments. ``env_w`` holds
            the learned logits, one per pooled row.
        """
        limit = None if n_samples is None or n_samples < 0 else n_samples
        test_envs = environments[2:]
        x = torch.cat((environments[0][0][:limit], environments[1][0][:limit]), 0)
        y = torch.cat((environments[0][1][:limit], environments[1][1][:limit]), 0)
        print('size of pooled envs: '+str(len(x)))

        loss = torch.nn.MSELoss(reduction='none')
        error = loss(x @ self.phi @ self.w, y)

        env_w = torch.randn(len(error)).requires_grad_()
        optimizer = torch.optim.Adam([env_w], lr=0.001)

        print('learning soft environment assignments')
        prev_penalty = 0
        ind = 0
        max_diff = -np.inf
        pbar = tqdm(range(args['n_iterations']))
        for i in pbar:
            # penalty for env a
            error_a = (error.squeeze() * env_w.sigmoid()).mean()
            penalty_a = grad(error_a, self.w, create_graph=True)[0].pow(2).mean()
            # penalty for env b
            error_b = (error.squeeze() * (1-env_w.sigmoid())).mean()
            penalty_b = grad(error_b, self.w, create_graph=True)[0].pow(2).mean()
            # negate: minimizing the negative penalty maximizes the IRM penalty
            npenalty = - torch.stack([penalty_a, penalty_b]).mean()
            if i > 0:
                diff = abs(npenalty.item() - prev_penalty)
            else:
                diff = -np.inf
            if diff > max_diff:
                max_diff = diff
                ind = i
            pbar.set_description_str(desc='Negative Penalty: '+str(npenalty.item())+', Diff: '+str(diff)+', Max Diff: '+str(max_diff)+'('+str(ind)+')')
            prev_penalty = npenalty.item()

            optimizer.zero_grad()
            # retain_graph: `error` (and its graph) is reused on every iteration.
            npenalty.backward(retain_graph=True)
            optimizer.step()

        idx0 = (env_w.sigmoid()>.5)
        idx1 = (env_w.sigmoid()<=.5)

        envs = []
        envs.append((x[idx0],y[idx0]))
        print('size of env 0: '+str(len(x[idx0])))
        envs.append((x[idx1],y[idx1]))
        print('size of env 1: '+str(len(x[idx1])))
        print('weights: '+str(env_w.sigmoid()))
        envs.extend(test_envs)
        return envs, env_w

class RepeatedEIIL(InvariantRiskMinimization):
    """Repeated EIIL: alternate between inferring environments and refitting a reference.

    Each of ``args['reiil_iters']`` rounds does the following:
      (i) builds a reference classifier, hard-coded from ``eiil_ref_alpha`` or
          fitted by ERM on the current environment list;
     (ii) re-learns a split of the pooled original training environments 0 and 1;
    (iii) runs the IRM regularization sweep on the two inferred environments;
     (iv) on every round but the last, replaces the environment list with
          [larger inferred environment] + the held-out environments.
    The phi with the lowest validation error across all rounds is kept.

    ``environments`` is indexed as: 0 and 1 training, 2 validation, 3-5 test.
    """

    def __init__(self, environments, args, pretrain=False):
        # ``pretrain`` is accepted for interface compatibility but not used.
        x_val, y_val = environments[2]
        self.reiters = args['reiil_iters']
        if self.reiters < 1:
            raise ValueError('reiil_iters must be >= 1, got {}'.format(self.reiters))
        dim_x = environments[0][0].size(1)
        # split() pools env 0 first and env 1 second, so env 1 starts at this offset.
        n_env0 = len(environments[0][0])

        new_envs = environments.copy()
        best_err = float('inf')
        best_phi = None

        for i in range(self.reiters):

            if 0 <= args['eiil_ref_alpha'] <= 1:
                print('Using hard-coded reference classifier with alpha={:.2f}'.format(
                  args['eiil_ref_alpha']
                ))
                alpha = args['eiil_ref_alpha']
                # First half of the features weighted by (1 - alpha); second half by
                # alpha (the spurious contribution to the prediction).
                n_first = dim_x // 2
                w_causal = (1. - alpha) * np.ones((1, n_first))
                w_noncausal = alpha * np.ones((1, dim_x - n_first))
                w_ref = torch.tensor(np.hstack((w_causal, w_noncausal)), dtype=torch.float32)
            else:
                print(str(i)+': Using ERM soln as reference classifier.')
                w_ref = EmpiricalRiskMinimizer(new_envs, args).solution()

            # With w fixed at ones, phi = diag(w_ref) makes phi @ w equal the reference weights.
            self.phi = torch.nn.Parameter(torch.diag(w_ref.squeeze()))
            self.w = torch.ones(dim_x, 1)
            self.w.requires_grad = True
            self.phi = self.phi.clone()

            if args["verbose"]:
                err = (x_val @ self.solution() - y_val).pow(2).mean().item()
                print(str(i)+": REIIL's reference classifier has {:.3f} validation error.".format(
                    err))
                print(str(i)+": REIIL's reference classifier has the following solution:\n.",
                      pretty(self.solution()))

            # NOTE(review): the split is always re-learned on the original pool
            # (`environments`), not on `new_envs`; only the reference changes
            # between rounds. Confirm this is intended.
            new_envs, env_w = self.split(environments, args)

            # Count how many low-noise rows (assumed: first 100 of env 0, first 200
            # of env 1, as produced in run_experiment) landed in the smaller
            # inferred environment.
            idx = [(env_w.sigmoid() > .5), (env_w.sigmoid() <= .5)]
            maj = 1 if torch.count_nonzero(idx[1]) > torch.count_nonzero(idx[0]) else 0
            minority = idx[1 - maj]
            total_flipped = 100 + 200
            total_flipped_min = int(torch.count_nonzero(minority[:100])
                                    + torch.count_nonzero(minority[n_env0:n_env0 + 200]))
            print (f'REIIL {i} Total flipped in minority environment: {total_flipped_min} / {total_flipped} ({total_flipped_min / total_flipped * 100} %)')

            rest_envs = new_envs.copy()
            if args["verbose"]:
                print("REIIL+ERM ref clf still has the following solution after AED (sanity check):\n.", pretty(self.solution()))

            if i < self.reiters - 1:
                maj_ind = 1 if len(new_envs[1][0]) > len(new_envs[0][0]) else 0
                # Keep only the larger inferred env plus the held-out envs.
                new_envs = [new_envs[maj_ind]] + new_envs[2:]

            for reg in [0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]:
                reg = 1. - reg  # change of variables for consistency with old codebase
                self.train(rest_envs[:2], args, reg=reg)
                err = (x_val @ self.solution() - y_val).pow(2).mean().item()

                if args["verbose"]:
                    print(str(i)+": REIIL+IRM (reg={:.6f}) has {:.3f} validation error.".format(
                        reg, err))

                if err < best_err:
                    best_err = err
                    best_phi = self.phi.clone()
            # If no candidate so far had a finite validation error, keep the last trained phi.
            if best_phi is not None:
                self.phi = best_phi
            print ('REIIL best err and phi:', best_err, self.phi)
            for env_idx, label in ((3, '0'), (4, '5'), (5, '10')):
                x_test, y_test = environments[env_idx]
                test_err = (x_test @ self.solution() - y_test).pow(2).mean().item()
                print('REIIL test err [env={}]:'.format(label), test_err)

    def split(self, environments, args, n_samples=-1):
        """Learn a soft two-way assignment of the pooled training samples.

        Args:
            environments: list whose first two entries are pooled and split; the
                entries from index 2 on are passed through unchanged.
            args: flag dict; ``args['n_iterations']`` Adam steps are taken.
            n_samples: rows taken from each of the first two environments. Any
                negative value (the default) or None means all rows. Previously
                -1 sliced ``[:-1]`` and silently dropped the last row.

        Returns:
            ``(envs, env_w)``. ``envs`` holds the rows with sigmoid(env_w) > .5,
            then the rest, then the pass-through environments. ``env_w`` holds
            the learned logits, one per pooled row.
        """
        limit = None if n_samples is None or n_samples < 0 else n_samples
        test_envs = environments[2:]
        x = torch.cat((environments[0][0][:limit], environments[1][0][:limit]), 0)
        y = torch.cat((environments[0][1][:limit], environments[1][1][:limit]), 0)
        print('size of pooled envs: '+str(len(x)))

        loss = torch.nn.MSELoss(reduction='none')
        error = loss(x @ self.phi @ self.w, y)

        env_w = torch.randn(len(error)).requires_grad_()
        optimizer = torch.optim.Adam([env_w], lr=0.001)

        print('learning soft environment assignments')
        prev_penalty = 0
        ind = 0
        max_diff = -np.inf
        pbar = tqdm(range(args['n_iterations']))
        for i in pbar:
            # penalty for env a
            error_a = (error.squeeze() * env_w.sigmoid()).mean()
            penalty_a = grad(error_a, self.w, create_graph=True)[0].pow(2).mean()
            # penalty for env b
            error_b = (error.squeeze() * (1-env_w.sigmoid())).mean()
            penalty_b = grad(error_b, self.w, create_graph=True)[0].pow(2).mean()
            # negate: minimizing the negative penalty maximizes the IRM penalty
            npenalty = - torch.stack([penalty_a, penalty_b]).mean()
            if i > 0:
                diff = abs(npenalty.item() - prev_penalty)
            else:
                diff = -np.inf
            if diff > max_diff:
                max_diff = diff
                ind = i
            pbar.set_description_str(desc='Negative Penalty: '+str(npenalty.item())+', Diff: '+str(diff)+', Max Diff: '+str(max_diff)+'('+str(ind)+')')
            prev_penalty = npenalty.item()

            optimizer.zero_grad()
            # retain_graph: `error` (and its graph) is reused on every iteration.
            npenalty.backward(retain_graph=True)
            optimizer.step()

        envs = []
        idx0 = (env_w.sigmoid()>.5)
        idx1 = (env_w.sigmoid()<=.5)
        envs.append((x[idx0],y[idx0]))
        print('size of env 0: '+str(len(x[idx0])))
        envs.append((x[idx1],y[idx1]))
        print('size of env 1: '+str(len(x[idx1])))
        print('weights: '+str(env_w.sigmoid()))
        envs.extend(test_envs)
        return envs, env_w

 
class InvariantCausalPrediction(object):
    """Invariant Causal Prediction (ICP) for linear regression.

    Each non-empty feature subset is regressed (no intercept) on the data pooled
    over all given environments. A subset is accepted when the Bonferroni-adjusted
    minimum, over environments, of a residual mean/variance equality p-value
    (rows inside vs. outside the environment) exceeds ``args["alpha"]``. The
    final coefficients are an OLS fit on the intersection of all accepted
    subsets, with zeros elsewhere; they are all zeros if nothing is accepted.
    """

    def __init__(self, environments, args):
        self.coefficients = None
        self.alpha = args["alpha"]

        # Pool every environment, tagging each row with its environment index.
        x_all = np.vstack([x.numpy() for x, _ in environments])
        y_all = np.vstack([y.numpy() for _, y in environments])
        e_all = np.hstack([np.full(x.shape[0], env_id)
                           for env_id, (x, _) in enumerate(environments)])

        dim = x_all.shape[1]
        n_envs = len(environments)

        accepted_subsets = []
        for subset in self.powerset(range(dim)):
            if not subset:
                continue  # the empty subset has no regressors to test

            x_s = x_all[:, subset]
            reg = LinearRegression(fit_intercept=False).fit(x_s, y_all)

            p_values = []
            for env_id in range(n_envs):
                in_env = e_all == env_id
                res_in = (y_all[in_env] - reg.predict(x_s[in_env, :])).ravel()
                res_out = (y_all[~in_env] - reg.predict(x_s[~in_env, :])).ravel()
                p_values.append(self.mean_var_test(res_in, res_out))

            # Bonferroni correction over environments.
            # TODO: Jonas uses "min(p_values) * len(environments) - 1"
            p_value = min(p_values) * n_envs

            if p_value > self.alpha:
                accepted_subsets.append(set(subset))
                if args["verbose"]:
                    print("Accepted subset:", subset)

        if not accepted_subsets:
            self.coefficients = torch.zeros(dim)
            return

        accepted_features = list(set.intersection(*accepted_subsets))
        if args["verbose"]:
            print("Intersection:", accepted_features)

        coefficients = np.zeros(dim)
        if accepted_features:
            x_s = x_all[:, list(accepted_features)]
            reg = LinearRegression(fit_intercept=False).fit(x_s, y_all)
            coefficients[list(accepted_features)] = reg.coef_
        self.coefficients = torch.Tensor(coefficients)

    def mean_var_test(self, x, y):
        """Combined p-value for equal means (Welch t-test) and equal variances (two-sided F)."""
        p_mean = ttest_ind(x, y, equal_var=False).pvalue
        variance_ratio = np.var(x, ddof=1) / np.var(y, ddof=1)
        p_upper = 1 - fdist.cdf(variance_ratio, x.shape[0] - 1, y.shape[0] - 1)
        p_var = 2 * min(p_upper, 1 - p_upper)
        # Bonferroni over the two tests.
        return 2 * min(p_mean, p_var)

    def powerset(self, s):
        """Lazily yield every subset of ``s`` as a tuple, from size 0 up to ``len(s)``."""
        sizes = range(len(s) + 1)
        return chain.from_iterable(combinations(s, size) for size in sizes)

    def solution(self):
        """Fitted coefficients as a 1-D tensor of length ``dim``."""
        return self.coefficients


class EmpiricalRiskMinimizer(object):
    """Ordinary least squares (no intercept) on the pooled training environments.

    The last four entries of ``environments`` are held out: validation, then
    three test environments (reported as env=0, env=5, env=10). Every earlier
    entry is pooled for training.
    """

    def __init__(self, environments, args):
        train_envs = environments[:-4]
        print ('ERM training environments: ', len(train_envs))
        x_all = torch.cat([x for (x, y) in train_envs]).numpy()
        y_all = torch.cat([y for (x, y) in train_envs]).numpy()

        w = LinearRegression(fit_intercept=False).fit(x_all, y_all).coef_
        # torch.Tensor stores the coefficients as float32.
        self.w = torch.Tensor(w)
        if args['verbose']:
            print('Done training ERM.')
            w_col = self.solution().numpy().T

            def mse(x, y):
                return np.mean((x.dot(w_col) - y) ** 2.).item()

            print("ERM has {:.3f} train error.".format(mse(x_all, y_all)))
            held_out_labels = ["val error", "test error [env=0]",
                               "test error [env=5]", "test error [env=10]"]
            for (x_e, y_e), label in zip(environments[-4:], held_out_labels):
                print("ERM has {:.3f} {}.".format(mse(x_e.numpy(), y_e.numpy()), label))
            print("ERM has the following solution:\n ", pretty(self.solution()))

    def solution(self):
        """Fitted weights as a ``(1, dim)`` float32 tensor (shape of sklearn ``coef_`` for 2-D y)."""
        return self.w

In [None]:
def pretty(vector):
    """Render a tensor as "[+a.aaa, -b.bbb, ...]" (flattened, signed, three decimals).

    This redefinition shadows the 4-decimal ``pretty`` from the models cell, so
    calls made after this cell runs use three decimals.
    """
    pieces = []
    for value in vector.view(-1).tolist():
        pieces.append("{:+.3f}".format(value))
    return "[%s]" % ", ".join(pieces)


def errors(w, w_hat):
    """Mean squared coefficient error, split into causal and non-causal coordinates.

    A coordinate counts as causal when the true weight ``w`` is non-zero there.
    Both tensors are flattened first. Returns ``(error_causal, error_noncausal)``
    as Python floats, using the integer ``0`` for a part that has no coordinates.
    """
    w_flat = w.view(-1)
    w_hat_flat = w_hat.view(-1)

    def _mse_over(mask):
        # Mean squared difference restricted to `mask`; 0 when the mask is empty.
        if not bool(mask.any()):
            return 0
        return (w_flat[mask] - w_hat_flat[mask]).pow(2).mean().item()

    causal_mask = w_flat != 0
    return _mse_over(causal_mask), _mse_over(~causal_mask)


def run_experiment(args):
    """Sample SEM environments, fit every selected method, and save the results.

    Args:
        args: flag dict (see the argparse cell). NOTE: ``args['results_dir']`` is
            rewritten in place to the setup-specific subdirectory, so a caller that
            reads it afterwards sees the directory the outputs were written to.

    Returns:
        List of formatted result lines: one for the true SEM solution and one per
        method, for each repetition.

    Side effects: seeds torch/numpy when ``args["seed"] >= 0``, and writes
    ``flags.p``, ``flags.txt`` and ``results.p`` into the results directory.
    """
    from collections import defaultdict

    if args["seed"] >= 0:  # a negative seed leaves the RNGs unseeded
        torch.manual_seed(args["seed"])
        np.random.seed(args["seed"])
        torch.set_num_threads(1)

    if args["setup_sem"] == "chain":
        setup_str = "chain_hidden={}_hetero={}_scramble={}".format(
            args["setup_hidden"],
            args["setup_hetero"],
            args["setup_scramble"])
    elif args["setup_sem"] == "icp":
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    args['results_dir'] = os.path.join(args['results_dir'], setup_str)
    if 0 <= args['eiil_ref_alpha'] <= 1:
        args['results_dir'] = '{results_dir}_alpha_{eiil_ref_alpha:.1f}'.format(**args)

    os.makedirs(args['results_dir'], exist_ok=True)
    with open(os.path.join(args['results_dir'], 'flags.p'), 'wb') as flags_pickle:
        pickle.dump(args, flags_pickle)
    with open(os.path.join(args['results_dir'], 'flags.txt'), 'w') as flags_txt:
        for f in (sys.stdout, flags_txt):
            print('Flags:', file=f)
            for k, v in sorted(args.items()):
                print("\t{}: {}".format(k, v), file=f)
    print('results will be found here:')
    print(args['results_dir'])

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
        "EIIL": LearnedEnvInvariantRiskMinimization,
        "REIIL": RepeatedEIIL
    }

    if args["methods"] == "all":
        methods = all_methods
    else:
        # A dict, so a name listed twice runs only once (first-seen order).
        methods = {m: all_methods[m] for m in args["methods"].split(',')}

    all_sems = []
    all_solutions = []
    all_environments = []
    all_err_causal = defaultdict(list)
    all_err_noncausal = defaultdict(list)

    for rep_i in range(args["n_reps"]):
        if args["setup_sem"] == "chain":
            sem = ChainEquationModel(args["dim"],
                                     hidden=args["setup_hidden"],
                                     scramble=args["setup_scramble"],
                                     hetero=args["setup_hetero"])
            # Layout the method classes index into: two training envs whose first
            # 10% / 20% of rows are low-noise, then validation (env=3.5), then
            # test envs printed as env=0, env=5, env=10.
            environments = [sem(args["n_samples"], 3., 0.1),
                            sem(args["n_samples"], 3., 0.2),
                            sem(args["n_samples"], 3.5),
                            sem(args["n_samples"], 0.),
                            sem(args["n_samples"], 5.),
                            sem(args["n_samples"], 10.)]
        else:
            raise NotImplementedError

        all_sems.append(sem)
        all_environments.append(environments)

    for sem, environments in zip(all_sems, all_environments):
        true_solution = sem.solution()
        solutions = [
            "{} {:<5} {} {:.5f} {:.5f}".format(setup_str,
                                             "SEM",
                                             pretty(true_solution), 0, 0)
        ]

        for method_name, method_constructor in methods.items():
            method = method_constructor(environments, args)
            msolution = method.solution()
            err_causal, err_noncausal = errors(true_solution, msolution)
            all_err_causal[method_name].append(err_causal)
            all_err_noncausal[method_name].append(err_noncausal)
            solutions.append("{} {:<5} {} {:.5f} {:.5f}".format(setup_str,
                                                             method_name,
                                                             pretty(msolution),
                                                             err_causal,
                                                             err_noncausal))

        all_solutions += solutions

    # save results
    results = dict(
        setup_str=setup_str,
        all_sems=all_sems,
        all_solutions=all_solutions,
        all_environments=all_environments,
        all_err_causal=all_err_causal,
        all_err_noncausal=all_err_noncausal,
    )
    with open(os.path.join(args['results_dir'], 'results.p'), 'wb') as f:
        pickle.dump(results, f)

    return all_solutions

In [None]:
# Sweep placeholders: the per-setup overrides below are currently disabled, so
# the single iteration runs with the setup_* values already in `flags`.
hidden = [0]
hetero = [1]
scramble = [1]
for i in range(1):
  # flags['setup_hidden'] = hidden[i]
  # flags['setup_hetero'] = hetero[i]
  # flags['setup_scramble'] = scramble[i]

  # run_experiment rewrites results_dir in place; running it on a copy keeps
  # `flags` unchanged, so re-running this cell does not nest directories again.
  run_flags = dict(flags)
  all_solutions = run_experiment(run_flags)
  print("\n".join(all_solutions))
  with open(os.path.join(run_flags['results_dir'], 'all_solutions.txt'), 'w') as solutions_file:
    print("\n".join(all_solutions), file=solutions_file)