In [None]:
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch import nn, optim, autograd
import pdb
import copy
from tqdm import tqdm
import pickle

In [None]:
# Experiment hyperparameters. Each row is (flag, type, default); the rows are
# registered in the order listed.
parser = argparse.ArgumentParser(description='CIFAR MNIST')
for flag_name, flag_type, flag_default in (
    ('--hidden_dim', int, 390),
    ('--l2_regularizer_weight', float, 0.00110794568),
    ('--lr', float, 0.0004898536566546834),
    ('--n_restarts', int, 3),
    ('--penalty_anneal_iters', int, 200),
    ('--penalty_weight', float, 91257.18613115903),
    ('--steps', int, 701),
    ('--grayscale_model', int, 0),
    ('--eiil', int, 1),
    ('--reiil_iters', int, 10),
):
  parser.add_argument(flag_name, type=flag_type, default=flag_default)

# An explicit argv list is passed so the notebook kernel's own command line
# is never parsed.
flags = parser.parse_args(['--hidden_dim', '390'])

# Make GPU 0 the current CUDA device for every later `.cuda()` call.
torch.cuda.set_device(0)

In [None]:
# Echo every parsed hyperparameter, one per line, ordered by flag name.
print('Flags:')
for k in sorted(vars(flags)):
  v = getattr(flags, k)
  print("\t{}: {}".format(k, v))

In [None]:
def split_data_opt(envs, model, n_steps=5000, n_samples=-1):
  """Learn a soft two-way environment assignment and split the pooled data.

  The samples of ``envs[0]`` and ``envs[1]`` are pooled. One free logit per
  pooled sample (``env_w``) is optimised with Adam so that the mean of the two
  squared scale-gradients (one for the samples weighted by
  ``sigmoid(env_w)``, one for the complementary weights) is maximised. The
  pooled data is then partitioned by thresholding ``sigmoid(env_w)`` at 0.5.

  Args:
    envs: sequence of dicts with at least 'images' and 'labels' tensors.
      ``envs[0]`` and ``envs[1]`` are pooled; ``envs[-1]`` is passed through
      unchanged as the last returned environment.
    model: callable mapping the pooled images to logits. It is only
      evaluated, never updated, by this function.
    n_steps: number of Adam steps taken on ``env_w``.
    n_samples: unused; kept so existing callers keep working.

  Returns:
    ``(new_envs, env_w)`` where ``new_envs`` is
    ``[samples with sigmoid(env_w) > .5, samples with sigmoid(env_w) <= .5,
    pass-through of envs[-1]]`` and ``env_w`` is the learned logit tensor,
    ordered like the pooled samples (envs[0] first, then envs[1]).
  """
  images = torch.cat((envs[0]['images'],envs[1]['images']),0)
  labels = torch.cat((envs[0]['labels'],envs[1]['labels']),0)
  print('size of pooled envs: '+str(len(images)))

  # The reference model is frozen while the assignment is learned. Evaluating
  # it without autograd tracking keeps every backward pass below from also
  # back-propagating through the network and piling up unused gradients on it.
  with torch.no_grad():
    logits = model(images)
  device = logits.device

  # Dummy multiplier fixed at 1; the penalty is the squared gradient of the
  # weighted loss with respect to it.
  scale = torch.tensor(1., device=device).requires_grad_()
  # Per-sample loss; loop-invariant, so it is computed (and squeezed) once.
  loss = nll(logits * scale, labels, reduction='none').squeeze()

  # Sampled on the CPU and then moved, exactly as `torch.randn(n).cuda()` did.
  env_w = torch.randn(len(logits)).to(device).requires_grad_()
  optimizer = optim.Adam([env_w], lr=0.01)

  print('learning soft environment assignments')
  prev_penalty = 0
  ind = 0
  max_diff = -np.inf
  pbar = tqdm(range(n_steps))
  for i in pbar:
    # penalty for env a
    lossa = (loss * env_w.sigmoid()).mean()
    grada = autograd.grad(lossa, [scale], create_graph=True)[0]
    penaltya = torch.sum(grada**2)
    # penalty for env b
    lossb = (loss * (1-env_w.sigmoid())).mean()
    gradb = autograd.grad(lossb, [scale], create_graph=True)[0]
    penaltyb = torch.sum(gradb**2)
    # negate: minimising the negative penalty maximises the penalty
    npenalty = - torch.stack([penaltya, penaltyb]).mean()

    # Read the scalar back once per step instead of three times.
    npenalty_value = npenalty.item()
    if i > 0:
      diff = abs(npenalty_value - prev_penalty)
    else:
      diff = -np.inf
    # Track the largest step-to-step change (and where it happened) for the
    # progress-bar readout only.
    if diff > max_diff:
      max_diff = diff
      ind = i
    pbar.set_description_str(desc='Negative Penalty: '+str(npenalty_value)+', Diff: '+str(diff)+', Max Diff: '+str(max_diff)+'('+str(ind)+')')
    prev_penalty = npenalty_value
    optimizer.zero_grad()
    # The graph of `loss` is shared by all iterations, so it must be retained.
    npenalty.backward(retain_graph=True)
    optimizer.step()

  # split envs based on env_w threshold
  new_envs = []

  idx0 = (env_w.sigmoid()>.5)
  idx1 = (env_w.sigmoid()<=.5)

  # Agreement between the inferred split and the original one. The boundary
  # is the real size of envs[0], so unequal environment sizes are handled too.
  # The labelling of the two inferred groups is arbitrary, hence the flip.
  n_env0 = len(envs[0]['images'])
  c1 = torch.count_nonzero(idx0[:n_env0])
  c2 = torch.count_nonzero(idx1[n_env0:])
  acc = (c1 + c2) / len(idx0)

  if acc < 0.5:
    acc = 1 - acc

  print (f'Environment assignment accuracy: {acc}')

  # train envs
  for idx in (idx0, idx1):
    new_envs.append(dict(images=images[idx], labels=labels[idx]))
  # test env
  new_envs.append(dict(images=envs[-1]['images'],
                       labels=envs[-1]['labels']))
  print('size of env0: '+str(len(new_envs[0]['images'])))
  print('size of env1: '+str(len(new_envs[1]['images'])))
  print('size of env2: '+str(len(new_envs[2]['images'])))
  return new_envs, env_w

In [None]:
# Accumulators for final accuracies, one list per method. Every name gets its
# own fresh, independent list; re-running this cell resets all results.
(final_train_accs_erm,
 final_train_accs_erm2,
 final_train_accs_irm,
 final_train_accs_erm_rnd,
 final_train_accs_eiil_irm,
 final_train_accs_eiil_erm,
 final_train_accs_eiil_erm2,
 final_train_accs_ei_werm) = ([] for _ in range(8))

(final_test_accs_erm,
 final_test_accs_erm2,
 final_test_accs_irm,
 final_test_accs_erm_rnd,
 final_test_accs_eiil_irm,
 final_test_accs_eiil_erm,
 final_test_accs_eiil_erm2,
 final_test_accs_ei_werm) = ([] for _ in range(8))

In [None]:
# CIFAR-10 images become float tensors and are scaled by 0.8.
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambd=lambda img: 0.8 * img),
])

# MNIST digits become float tensors, are tiled three times along dim 0 and
# padded by 2 pixels on every side.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambd=lambda img: torch.cat([img, img, img], dim=0)),
    transforms.Pad(2),
])

# Train split first, then test split, for each dataset (same download order
# as before).
cifar_train, cifar_test = (
    datasets.CIFAR10(root='/content/datasets/', train=is_train, download=True, transform=cifar_transform)
    for is_train in (True, False)
)
mnist_train, mnist_test = (
    datasets.MNIST(root='/content/datasets/', train=is_train, download=True, transform=mnist_transform)
    for is_train in (True, False)
)

In [None]:
def _dataset_to_tensors(dataset):
  """Load every sample of `dataset` into memory.

  Returns a tuple `(samples, images, labels)`: the list of raw dataset items,
  the images stacked along a new leading dimension, and the labels as an
  int32 tensor.
  """
  samples = [dataset[idx] for idx in range(len(dataset))]
  images = torch.stack([sample[0] for sample in samples], dim=0)
  labels = torch.tensor([sample[1] for sample in samples], dtype=torch.int32)
  return samples, images, labels


cifar_train_data, cifar_train_images, cifar_train_labels = _dataset_to_tensors(cifar_train)
cifar_test_data, cifar_test_images, cifar_test_labels = _dataset_to_tensors(cifar_test)
mnist_train_data, mnist_train_images, mnist_train_labels = _dataset_to_tensors(mnist_train)
mnist_test_data, mnist_test_images, mnist_test_labels = _dataset_to_tensors(mnist_test)

# Quick sanity check of the tensor shapes and the first few labels.
print (cifar_train_images.shape, mnist_train_images.shape)
print (cifar_train_labels[:10])

In [None]:
# Pick one random CIFAR-10 training image per class label 0..9. The chosen
# indices are used later as the per-class background images.
# NOTE(review): no random seed is set anywhere in the visible notebook, so the
# selection differs between runs.
cifar_train_labels_np = cifar_train_labels.cpu().numpy()  # converted once, not per class
# Index range derived from the data instead of a hard-coded 50000.
all_inds = np.arange(len(cifar_train_labels_np))

random_cifar_images_inds = []
for i in range(10):
  i_inds = all_inds[cifar_train_labels_np == i]
  random_cifar_images_inds.append(np.random.choice(i_inds))

random_cifar_images_inds = np.array(random_cifar_images_inds)

# Preview the ten selected images side by side (channels moved last for imshow).
columns = 10
rows = 1
fig, axarr = plt.subplots(rows, columns, figsize=(21, 2))
for j in range(columns):
  img = torch.permute(cifar_train_images[random_cifar_images_inds[j]], (1, 2, 0)).cpu().numpy()
  axarr[j].imshow(img)
  axarr[j].axis('off')

In [None]:
# First CIFAR and first MNIST training image, channels moved last for imshow.
c = cifar_train_images[0].permute(1, 2, 0).cpu().numpy()
m = mnist_train_images[0].permute(1, 2, 0).cpu().numpy()

fig, axarr = plt.subplots(1, 3, figsize=(7, 2))
columns = 3
rows = 1

# Panels, left to right: CIFAR image, MNIST digit, their sum clipped to [0, 1].
for ax, panel in zip(axarr, (c, m, np.clip(c + m, 0., 1.))):
  ax.imshow(panel)
  ax.axis('off')

In [None]:
# Repeat the full experiment `flags.n_restarts` times. Everything indented
# below (environment construction, model definitions, all training loops)
# runs once per restart.
for restart in range(flags.n_restarts):
  print("Restart", restart)

  def make_environment_fixed(inds, mapping, flip, test_mapping=None):
    """Build one environment of MNIST digits summed with per-class CIFAR images.

    Each selected MNIST image has a CIFAR image added to it (result clipped to
    [0, 1]). The CIFAR image is `random_cifar_images_inds[mapping[label]]`
    unless the sample is "flipped", in which case a random class different
    from both `mapping[label]` and `test_mapping[label]` is used instead.

    Args:
      inds: indices into `mnist_train_images` / `mnist_train_labels`.
      mapping: array mapping a digit label to the class of its CIFAR image.
      flip: probability of flipping a sample; the value -1 disables flipping
        entirely (no random numbers are drawn and `test_mapping` is not read).
      test_mapping: mapping whose target class is excluded when flipping;
        required whenever `flip != -1`.

    Returns:
      dict with 'images' and 'labels' (both moved to the GPU),
      'class_wise_inds' (one array of sample positions per label present) and
      'flipped' (tensor of the positions that were flipped).
    """
    flipped = []
    classes = np.unique(mnist_train_labels)
    mnist_imgs = mnist_train_images[inds]
    mnist_labels = mnist_train_labels[inds].type(torch.LongTensor)
    flipping_enabled = flip != -1

    for i in range(len(mnist_imgs)):
      digit = mnist_labels[i].item()
      ind = mapping[digit]
      if flipping_enabled:
        test_ind = test_mapping[digit]
        if np.random.random() < flip:
          # Re-draw the background class, avoiding both the train-time and
          # the test-time class for this digit.
          ind = np.random.choice([c for c in classes if c != ind and c != test_ind])
          flipped.append(i)
      background = cifar_train_images[random_cifar_images_inds[ind]]
      mnist_imgs[i] = torch.clip(mnist_imgs[i] + background, 0., 1.)

    classes = torch.unique(mnist_labels)
    all_inds = np.arange(len(inds))
    class_wise_inds = [all_inds[mnist_labels == c] for c in classes]

    return dict(
        images=mnist_imgs.cuda(),
        labels=mnist_labels.cuda(),
        class_wise_inds=class_wise_inds,
        flipped=torch.tensor(flipped),
    )


  # Random mapping of class for training environments.
  random_mapping_train = np.array([8, 7, 5, 4, 0, 9, 1, 6, 3, 2])

  # For CBMNIST [II]
  # NOTE(review): not referenced anywhere else in the visible notebook.
  random_mapping_train_flip = np.array([5, 8, 3, 2, 4, 1, 7, 9, 6, 0])

  # Random mapping of class for test environment.
  random_mapping_test = np.array([4, 9, 6, 1, 2, 0, 8, 7, 5, 3])

  # Randomly split the training set into two environments.
  env_inds = np.random.choice(50000, 50000, replace=False)

  # Two training environments (flip probability 0.01 and 0.02) built from
  # alternating entries of the shuffled indices, plus a test environment built
  # from MNIST training indices 50000..59999 with the test mapping and no flips.
  envs = [
    make_environment_fixed(env_inds[::2], random_mapping_train, 0.01, random_mapping_test),
    make_environment_fixed(env_inds[1::2], random_mapping_train, 0.02, random_mapping_test),
    make_environment_fixed(np.arange(50000, 60000), random_mapping_test, -1)
  ]

  classes = torch.unique(envs[2]['labels'])

  flipped = []
  for env in envs:
    flipped.append(env['flipped'])

  # Positions of the flipped samples inside the pooled tensor
  # cat(envs[0], envs[1]) that split_data_opt builds. The offset for the second
  # environment is the actual size of the first one, so it stays correct if the
  # split size above is changed.
  flipped = torch.cat((flipped[0], flipped[1] + len(envs[0]['images']))).type(torch.LongTensor)

  # Define and instantiate the model

  class MLP(nn.Module):
    """Fully connected network: two hidden ReLU layers and 10 output logits.

    The input width is 32*32 when `flags.grayscale_model` is set (the three
    channels are summed in `forward`) and 3*32*32 otherwise. All linear layers
    use Xavier-uniform weights and zero biases.
    """

    def __init__(self):
      super().__init__()
      in_features = 32 * 32 if flags.grayscale_model else 3 * 32 * 32
      # Construct all layers first, then re-initialise them in the same order.
      linears = [
          nn.Linear(in_features, flags.hidden_dim),
          nn.Linear(flags.hidden_dim, flags.hidden_dim),
          nn.Linear(flags.hidden_dim, 10),
      ]
      for linear in linears:
        nn.init.xavier_uniform_(linear.weight)
        nn.init.zeros_(linear.bias)
      self._main = nn.Sequential(
          linears[0], nn.ReLU(True),
          linears[1], nn.ReLU(True),
          linears[2],
      )

    def forward(self, input):
      batch_size = input.shape[0]
      if flags.grayscale_model:
        # Collapse the three channels into one by summation.
        flat = torch.reshape(input, (batch_size, 3, 32 * 32)).sum(dim=1)
      else:
        flat = torch.reshape(input, (batch_size, 3 * 32 * 32))
      return self._main(flat)

  # Define loss function helpers

  def nll(logits, y, reduction='mean'):
    """Cross-entropy between `logits` and integer targets `y`.

    `reduction` is forwarded unchanged to `cross_entropy`
    ('mean', 'sum' or 'none').
    """
    cross_entropy = nn.functional.cross_entropy
    return cross_entropy(input=logits, target=y, reduction=reduction)

  def mean_accuracy(logits, y, reduce='sum'):
    """Count (or average) the samples whose arg-max prediction equals `y`.

    With `reduce='mean'` the fraction of correct predictions is returned; any
    other value (including the default 'sum') returns the number of correct
    predictions as a tensor.
    """
    n_correct = torch.count_nonzero(torch.argmax(logits, dim=1) == y)
    if reduce != 'mean':
      return n_correct
    return n_correct / len(logits)

  def class_wise_accuracy(logits, y):
    """Per-class accuracy of the arg-max predictions.

    Returns a list with one entry per element of the enclosing `classes`
    tensor; entry `c` is the fraction of samples with label `c` that are
    predicted correctly.
    """
    correct = torch.argmax(logits, dim=1) == y
    acc = [0] * len(classes)
    for c in classes:
      in_class = y == c
      n_hits = torch.count_nonzero(in_class & correct)
      n_in_class = torch.count_nonzero(in_class)
      acc[c] = (n_hits / n_in_class).item()
    return acc

  def penalty(logits, y):
    """Squared gradient of the mean loss w.r.t. a dummy logit scale of 1.

    The gradient is taken with `create_graph=True`, so the returned scalar can
    itself be back-propagated through.
    """
    dummy_scale = torch.tensor(1.).cuda().requires_grad_()
    scaled_loss = nll(logits * dummy_scale, y)
    (scale_grad,) = autograd.grad(scaled_loss, [dummy_scale], create_graph=True)
    return scale_grad.pow(2).sum()

  # Train loop

  def pretty_print(*values):
    """Print `values` as one row of left-aligned, fixed-width columns.

    Strings are used as-is; everything else is rendered with
    `np.array2string` using 5 fixed decimals. Each cell is padded to 13
    characters and cells are separated by three spaces.
    """
    col_width = 13
    cells = []
    for value in values:
      if isinstance(value, str):
        text = value
      else:
        text = np.array2string(value, precision=5, floatmode='fixed')
      cells.append(text.ljust(col_width))
    print("   ".join(cells))

  # Fresh models and Adam optimisers for this restart. The construction order
  # determines how random numbers are consumed for weight initialisation, so
  # it is left as-is.
  # `mlp_erm` / `optimizer_erm` are re-created before they are used further
  # down in the `flags.eiil` branch.
  # NOTE(review): `mlp_werm` / `optimizer_werm` are not referenced again in the
  # visible part of this notebook.
  mlp_erm = MLP().cuda()
  mlp_werm = MLP().cuda()
  mlp_irm = MLP().cuda()

  optimizer_erm = optim.Adam(mlp_erm.parameters(), lr=flags.lr)
  optimizer_werm = optim.Adam(mlp_werm.parameters(), lr=flags.lr)
  optimizer_irm = optim.Adam(mlp_irm.parameters(), lr=flags.lr)

  # Everything from here to the end of the restart loop only runs when
  # `flags.eiil` is truthy.
  if flags.eiil:
    
    print ('IRM')
    pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc')

    # IRM model: full-batch training of `mlp_irm` on the two hand-made training
    # environments (envs[0], envs[1]); envs[2] is only evaluated.
    for step in range(flags.steps):
      # Per-environment statistics are stored on the env dicts themselves.
      # 'acc' holds a count of correct predictions (mean_accuracy defaults to
      # reduce='sum'), not a fraction.
      for env in envs:
        logits = mlp_irm(env['images'])
        env['nll'] = nll(logits, env['labels'])
        env['acc'] = mean_accuracy(logits, env['labels'])
        env['cw_acc'] = class_wise_accuracy(logits, env['labels'])
        env['penalty'] = penalty(logits, env['labels'])

      # Training metrics use the first two environments only.
      tot_samples = len(envs[0]['images']) + len(envs[1]['images'])
      train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).mean()
      train_acc = torch.stack([envs[0]['acc'], envs[1]['acc']]).sum() / tot_samples
      train_penalty = torch.stack([envs[0]['penalty'], envs[1]['penalty']]).mean()

      # Sum of squared parameter norms for the L2 regulariser.
      weight_norm = torch.tensor(0.).cuda()
      for w in mlp_irm.parameters():
        weight_norm += w.norm().pow(2)

      loss = train_nll.clone()
      loss += flags.l2_regularizer_weight * weight_norm
      # NOTE: IRM penalties used for testing.
      # The penalty weight is 1.0 for the first `penalty_anneal_iters` steps
      # and `flags.penalty_weight` afterwards.
      penalty_weight = (flags.penalty_weight
       if step >= flags.penalty_anneal_iters else 1.0)
      loss += penalty_weight * train_penalty
      if penalty_weight > 1.0:
       # Rescale the entire loss to keep gradients in a reasonable range
       loss /= penalty_weight

      optimizer_irm.zero_grad()
      loss.backward()
      optimizer_irm.step()

      # Test accuracy is computed from the forward pass made before this
      # step's parameter update.
      test_acc = envs[2]['acc'] / len(envs[2]['images'])
      if step % 100 == 0:
        pretty_print(
          np.int32(step),
          train_nll.detach().cpu().numpy(),
          train_acc.detach().cpu().numpy(),
          train_penalty.detach().cpu().numpy(),
          test_acc.detach().cpu().numpy()
        )

    # Record the metrics of the last training step for this restart.
    final_train_accs_irm.append(train_acc.detach().cpu().numpy())
    final_test_accs_irm.append(test_acc.detach().cpu().numpy())

    # Per-class accuracies from the last step, one list per environment.
    for env in envs:
      print (env['cw_acc'])
        

    print ('Starting fresh train for REIIL..')
    # These two objects are replaced again at the top of the first loop
    # iteration below.
    mlp_erm = MLP().cuda()
    optimizer_erm = optim.Adam(mlp_erm.parameters(), lr=flags.lr)

    # Shallow copy: the list is new but the env dicts are shared with `envs`,
    # so the per-env statistics written below also appear on `envs`.
    new_envs = envs.copy()
    
    # Empty (falsy) until a minority environment has been inferred.
    min_env = []
    for eiil_ind in range(flags.reiil_iters):
      print_iters = 100
      steps = flags.steps
      # A fresh reference model is trained from scratch on every iteration.
      mlp_erm = MLP().cuda()
      print ('ERM')
      optimizer_erm = optim.Adam(mlp_erm.parameters(), lr=flags.lr)

      # First ERM is on the training environment itself.
      # On later iterations `new_envs` is [majority env, test env] (set after
      # the split further down) and an extra column is printed.
      if eiil_ind == 0:
        pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc')
      else:
        pretty_print('step', 'train nll', 'train acc (0)', 'train acc (1)', 'train penalty', 'test acc')

      for step in range(steps):
        # 'nll' is a per-environment SUM here (reduction='sum') and 'acc' a
        # count of correct predictions; both are normalised below.
        for env in new_envs:
          logits = mlp_erm(env['images'])
          env['nll'] = nll(logits, env['labels'], reduction='sum')
          env['acc'] = mean_accuracy(logits, env['labels'])
          env['cw_acc'] = class_wise_accuracy(logits, env['labels'])

        # Also evaluate on the held-out minority environment, once one exists.
        if min_env:
          mlp_erm.eval()
          logits = mlp_erm(min_env['images'])
          min_env['acc'] = mean_accuracy(logits, min_env['labels'])
          min_env['cw_acc'] = class_wise_accuracy(logits, min_env['labels'])
          mlp_erm.train()

        if eiil_ind == 0:
          # Pool both original training environments.
          tot_samples = len(new_envs[0]['images']) + len(new_envs[1]['images'])
          train_nll = torch.stack([new_envs[0]['nll'], new_envs[1]['nll']]).sum() / tot_samples
          train_acc = torch.stack([new_envs[0]['acc'], new_envs[1]['acc']]).sum() / tot_samples
        else:
          # Train on the majority environment only (new_envs[0]).
          tot_samples = len(new_envs[0]['images'])
          train_nll = new_envs[0]['nll'].sum() / tot_samples
          train_acc = new_envs[0]['acc'].sum() / tot_samples
          min_env_acc = min_env['acc'].sum() / len(min_env['images'])

        # Sum of squared parameter norms for the L2 regulariser.
        weight_norm = torch.tensor(0.).cuda()
        for w in mlp_erm.parameters():
          weight_norm += w.norm().pow(2)

        loss = train_nll.clone()
        loss += flags.l2_regularizer_weight * weight_norm
      
        optimizer_erm.zero_grad()
        loss.backward()
        optimizer_erm.step()

        # NOTE(review): `train_penalty` is never assigned in this loop; the
        # value printed below is left over from the most recent loop that did
        # compute it (the IRM loop, or the REI_WERM / EIIL loops of a previous
        # iteration), so that column does not describe this ERM model.
        if eiil_ind == 0:
          test_acc = new_envs[2]['acc'] / len(new_envs[2]['images'])
          if step % print_iters == 0:
            pretty_print(np.int32(step), train_nll.detach().cpu().numpy(), train_acc.detach().cpu().numpy(), train_penalty.detach().cpu().numpy(), test_acc.detach().cpu().numpy())
        else:
          # With two entries in `new_envs`, index 1 is the test environment.
          test_acc = new_envs[1]['acc'] / len(new_envs[1]['images'])
          if step % print_iters == 0:
            pretty_print(np.int32(step), train_nll.detach().cpu().numpy(), train_acc.detach().cpu().numpy(), min_env_acc.detach().cpu().numpy(), train_penalty.detach().cpu().numpy(), test_acc.detach().cpu().numpy())

      # Re-infer the environment split of the ORIGINAL pooled training data
      # using the reference model that was just trained.
      new_envs, env_w = split_data_opt(envs, mlp_erm)
      # Keep the full three-way result (inferred env 0, inferred env 1, test)
      # for the REI_WERM / EIIL training loops that follow.
      rest_envs = new_envs.copy()

      # Index (0 or 1) of the larger inferred environment. It is computed on
      # EVERY iteration, including the last one, so that the minority mask
      # below always describes the split that was just produced (previously it
      # was stale on the last iteration and undefined when reiil_iters == 1).
      maj_ind = 0
      if len(new_envs[1]['images']) > len(new_envs[0]['images']):
        maj_ind = 1

      # On all but the last iteration, prepare the next round: the minority
      # environment is held out and the next reference model is trained on
      # [majority env, test env].
      if eiil_ind < flags.reiil_iters - 1:
        min_env = new_envs[1 - maj_ind]
        if eiil_ind == 0:
          # Report how the first reference model does on each inferred
          # environment. Evaluation only, so no autograd graph is needed.
          mlp_erm.eval()
          with torch.no_grad():
            min_logits = mlp_erm(min_env['images'])
            min_env_acc = mean_accuracy(min_logits, min_env['labels'])
            min_acc = min_env_acc.sum() / len(min_env['images'])
            maj_logits = mlp_erm(new_envs[maj_ind]['images'])
            maj_env_acc = mean_accuracy(maj_logits, new_envs[maj_ind]['labels'])
            maj_acc = maj_env_acc.sum() / len(new_envs[maj_ind]['images'])
          print (f'EIIL refernce model maj acc: {maj_acc}, min acc: {min_acc}')
          mlp_erm.train()
        new_envs = [new_envs[maj_ind], new_envs[-1]]

      # split_data_opt places samples with sigmoid(env_w) > 0.5 in inferred
      # env 0 and the rest in inferred env 1; select whichever is the minority.
      all_min = env_w.sigmoid() <= 0.5
      if maj_ind == 1:
        all_min = env_w.sigmoid() > 0.5

      # How many of the deliberately flipped samples landed in the minority
      # environment; the value in parentheses is now a real percentage.
      count_min = torch.count_nonzero(all_min[flipped])
      print (f'Total flipped in minority environment: {count_min} / {len(flipped)} ({100 * count_min / len(flipped)})%')
      # Train a fresh model on the inferred environments in `rest_envs` with an
      # ERM objective: mean of the two environments' mean NLLs plus L2. The
      # penalty is computed and printed but NOT added to the loss here.
      mlp = MLP().cuda()
      optimizer = optim.Adam(mlp.parameters(), lr=flags.lr)
      print ('REI_WERM')
      pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc')
      for step in range(flags.steps):
        # rest_envs = [inferred env 0, inferred env 1, test env]; 'acc' is a
        # count of correct predictions.
        for env in rest_envs:
          logits = mlp(env['images'])
          env['nll'] = nll(logits, env['labels'])
          env['acc'] = mean_accuracy(logits, env['labels'])
          env['cw_acc'] = class_wise_accuracy(logits, env['labels'])
          env['penalty'] = penalty(logits, env['labels'])

        tot_samples = len(rest_envs[0]['images']) + len(rest_envs[1]['images'])
        # Unweighted mean over the two environments, regardless of their sizes.
        train_nll = torch.stack([rest_envs[0]['nll'], rest_envs[1]['nll']]).mean()
        train_acc = torch.stack([rest_envs[0]['acc'], rest_envs[1]['acc']]).sum() / tot_samples
        train_penalty = torch.stack([rest_envs[0]['penalty'], rest_envs[1]['penalty']]).mean()

        # Sum of squared parameter norms for the L2 regulariser.
        weight_norm = torch.tensor(0.).cuda()
        for w in mlp.parameters():
          weight_norm += w.norm().pow(2)

        loss = train_nll.clone()
        loss += flags.l2_regularizer_weight * weight_norm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Computed from the forward pass made before this step's update.
        test_acc = rest_envs[2]['acc'] / len(rest_envs[2]['images'])
        if step % 100 == 0:
          pretty_print(
            np.int32(step),
            train_nll.detach().cpu().numpy(),
            train_acc.detach().cpu().numpy(),
            train_penalty.detach().cpu().numpy(),
            test_acc.detach().cpu().numpy()
          )

      # One entry is appended per REIIL iteration, so the statistics printed
      # below aggregate over iterations as well as restarts.
      final_train_accs_ei_werm.append(train_acc.detach().cpu().numpy())
      final_test_accs_ei_werm.append(test_acc.detach().cpu().numpy())
      print('Final train acc (mean/std across restarts so far):')
      print(np.mean(final_train_accs_ei_werm), np.std(final_train_accs_ei_werm))
      print('Final test acc (mean/std across restarts so far):')
      print(np.mean(final_test_accs_ei_werm), np.std(final_test_accs_ei_werm))

      # Train another fresh model on the same inferred environments, this time
      # with the penalty term added to the loss (same annealing and rescaling
      # scheme as the IRM loop above).
      mlp = MLP().cuda()
      optimizer = optim.Adam(mlp.parameters(), lr=flags.lr)
      print ('EIIL')
      pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc')
      for step in range(flags.steps):
        # rest_envs = [inferred env 0, inferred env 1, test env]; 'acc' is a
        # count of correct predictions.
        for env in rest_envs:
          logits = mlp(env['images'])
          env['nll'] = nll(logits, env['labels'])
          env['acc'] = mean_accuracy(logits, env['labels'])
          env['cw_acc'] = class_wise_accuracy(logits, env['labels'])
          env['penalty'] = penalty(logits, env['labels'])

        tot_samples = len(rest_envs[0]['images']) + len(rest_envs[1]['images'])
        train_nll = torch.stack([rest_envs[0]['nll'], rest_envs[1]['nll']]).mean()
        train_acc = torch.stack([rest_envs[0]['acc'], rest_envs[1]['acc']]).sum() / tot_samples
        train_penalty = torch.stack([rest_envs[0]['penalty'], rest_envs[1]['penalty']]).mean()

        # Sum of squared parameter norms for the L2 regulariser.
        weight_norm = torch.tensor(0.).cuda()
        for w in mlp.parameters():
          weight_norm += w.norm().pow(2)

        loss = train_nll.clone()
        loss += flags.l2_regularizer_weight * weight_norm
        # Penalty weight is 1.0 for the first `penalty_anneal_iters` steps and
        # `flags.penalty_weight` afterwards.
        penalty_weight = (flags.penalty_weight
            if step >= flags.penalty_anneal_iters else 1.0)
        loss += penalty_weight * train_penalty
        if penalty_weight > 1.0:
          # Rescale the entire loss to keep gradients in a reasonable range
          loss /= penalty_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Computed from the forward pass made before this step's update.
        test_acc = rest_envs[2]['acc'] / len(rest_envs[2]['images'])
        if step % 100 == 0:
          pretty_print(
            np.int32(step),
            train_nll.detach().cpu().numpy(),
            train_acc.detach().cpu().numpy(),
            train_penalty.detach().cpu().numpy(),
            test_acc.detach().cpu().numpy()
          )

      # Results of this penalised run are stored in the `*_eiil_erm2` lists.
      # One entry is appended per REIIL iteration, so the statistics printed
      # below aggregate over iterations as well as restarts.
      final_train_accs_eiil_erm2.append(train_acc.detach().cpu().numpy())
      final_test_accs_eiil_erm2.append(test_acc.detach().cpu().numpy())
      print('Final train acc (mean/std across restarts so far):')
      print(np.mean(final_train_accs_eiil_erm2), np.std(final_train_accs_eiil_erm2))
      print('Final test acc (mean/std across restarts so far):')
      print(np.mean(final_test_accs_eiil_erm2), np.std(final_test_accs_eiil_erm2))

      # Per-class accuracies from the last step, one list per environment.
      for env in rest_envs:
        print (env['cw_acc'])