In [1]:
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import numpy as np
import torch
from torchvision import datasets
from torch import nn, optim, autograd
import torch.nn.functional as f

# Hyperparameter flags for the Colored MNIST experiment.
parser = argparse.ArgumentParser(description='Colored MNIST')
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--l2_regularizer_weight', type=float,default=0.001)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--n_restarts', type=int, default=10)
parser.add_argument('--penalty_anneal_iters', type=int, default=100)
parser.add_argument('--penalty_weight', type=float, default=10000.0)
parser.add_argument('--steps', type=int, default=501)
parser.add_argument('--grayscale_model', action='store_true')
# An explicit empty argument list is parsed, so sys.argv is ignored and every
# flag above takes its default value.
flags = parser.parse_args(args=[])


# Hidden-layer width read by the MLP class in this file.
# NOTE(review): this duplicates the --hidden_dim default instead of reading
# flags.hidden_dim -- confirm whether the two are meant to stay in sync.
ERMhidden_dim= 256
  
# Load the MNIST training split (downloaded on first use) and carve it into
# the first 50000 samples for training and the remainder for validation.
mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
mnist_train = (mnist.data[:50000], mnist.targets[:50000])
mnist_val = (mnist.data[50000:], mnist.targets[50000:])


# Shuffle training images and labels with the same permutation: the RNG state
# is captured, the images are shuffled, the state is restored, and the labels
# are shuffled from that identical state.  Tensor.numpy() returns an array
# that shares memory with the tensor, so these in-place shuffles reorder the
# tensors held in mnist_train.  The order of these four statements matters.
rng_state = np.random.get_state()
np.random.shuffle(mnist_train[0].numpy())
np.random.set_state(rng_state)
np.random.shuffle(mnist_train[1].numpy())


def make_high_environment(images, labels, color):
    """Build a two-channel coloured environment with binary labels.

    Every image is duplicated into two channels (per the original author's
    note, channel 0 is red and channel 1 is green).  Digits 0-4 are mapped to
    label 0 and digits 5-9 to label 1, and one channel of each image is
    blanked according to that binary label, so the surviving channel encodes
    the label.

    Args:
        images: Tensor of images stacked along dim 0 (for MNIST: uint8,
            shape ``(N, 28, 28)``).  The input tensor is not modified.
        labels: Digit labels aligned with ``images`` (tensor, array or
            sequence of ints).
        color: ``0`` blanks channel ``1 - label`` (label 0 keeps channel 0,
            label 1 keeps channel 1); any other value blanks channel
            ``label`` (the opposite colour assignment).

    Returns:
        dict with
        ``'images'``: float tensor of shape ``(N, 2, ...)`` scaled by 1/255,
        ``'labels'``: int64 NumPy array of 0/1 labels.
    """
    images = torch.stack([images, images], dim=1)
    # Vectorized binarization instead of a Python loop over tensor elements.
    binary = (torch.as_tensor(labels) >= 5).long()
    # torch.arange always yields an integer index, even for N == 0, whereas
    # torch.tensor(range(0)) is a float tensor and cannot be used to index.
    rows = torch.arange(len(images))
    blanked_channel = 1 - binary if color == 0 else binary
    images[rows, blanked_channel] = 0
    return {
        'images': images.float() / 255.,
        # Kept as a NumPy array, matching what callers received before.
        'labels': binary.cpu().numpy(),
    }



# Two binary-label environments: samples 0, 3, 6, ... are built with colour
# rule 0 and samples 1, 4, 7, ... with colour rule 1.
henvs = [
    make_high_environment(mnist_train[0][start::3], mnist_train[1][start::3], color)
    for start, color in ((0, 0), (1, 1))
]





def make_low_environment(images, labels, color):
    """Build a two-channel coloured environment that keeps the digit labels.

    Every image is duplicated into two channels (per the original author's
    note, channel 0 is red and channel 1 is green).  The colour is chosen from
    a binarized copy of the label (digits 0-4 -> 0, digits 5-9 -> 1), but the
    returned labels are the original, un-binarized digits.

    Args:
        images: Tensor of images stacked along dim 0 (for MNIST: uint8,
            shape ``(N, 28, 28)``).  The input tensor is not modified.
        labels: Tensor of digit labels aligned with ``images``.  It is
            returned as-is and never modified.
        color: ``0`` blanks channel ``1 - b`` (where ``b`` is the binarized
            label); any other value blanks channel ``b``.

    Returns:
        dict with
        ``'images'``: float tensor of shape ``(N, 2, ...)`` scaled by 1/255,
        ``'labels'``: the ``labels`` argument itself.
    """
    images = torch.stack([images, images], dim=1)
    # Debug output retained from the original implementation.
    print(labels)
    # Vectorized binarization; forcing int64 guarantees the result is used as
    # an index tensor (a uint8/bool tensor would be treated as a mask).
    binary = (torch.as_tensor(labels) >= 5).long()
    # torch.arange always yields an integer index, even for N == 0, whereas
    # torch.tensor(range(0)) is a float tensor and cannot be used to index.
    rows = torch.arange(len(images))
    blanked_channel = 1 - binary if color == 0 else binary
    images[rows, blanked_channel] = 0
    return {
        'images': images.float() / 255.,
        'labels': labels,
    }


# Digit-label environment built from samples 2, 5, 8, ... with colour rule 0.
lenvs = make_low_environment(
    images=mnist_train[0][2::3],
    labels=mnist_train[1][2::3],
    color=0,
)


# Empirical risk minimization term

class MLP(nn.Module):
    """Three-layer perceptron over two-channel 28x28 images.

    The network is Linear -> ReLU -> Linear -> ReLU -> Linear with Xavier
    uniform weights and zero biases, followed by a softmax over the class
    dimension, so ``forward`` returns probabilities rather than logits.

    Args:
        hidden_dim: Width of both hidden layers.  ``None`` (the default)
            uses the module-level ``ERMhidden_dim``.
        n_classes: Number of outputs.  Defaults to 9, the value that was
            previously hard-coded.
            NOTE(review): MNIST has 10 digit classes (0-9); confirm that 9
            outputs is intended.
    """

    # Flattened size of one input sample: 2 colour channels of 28x28 pixels.
    IN_FEATURES = 2 * 28 * 28

    def __init__(self, hidden_dim=None, n_classes=9):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = ERMhidden_dim
        lin1 = nn.Linear(self.IN_FEATURES, hidden_dim)
        lin2 = nn.Linear(hidden_dim, hidden_dim)
        lin3 = nn.Linear(hidden_dim, n_classes)
        for lin in [lin1, lin2, lin3]:
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)
        self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3)

    def forward(self, input):
        """Return class probabilities of shape ``(batch, n_classes)``.

        ``input`` may be image-shaped ``(batch, 2, 28, 28)`` or already
        flattened ``(batch, 1568)``; it is reshaped to two dimensions before
        the first linear layer, so both forms give the same result.
        """
        flat = input.reshape(-1, self.IN_FEATURES)
        logits = self._main(flat)
        # Functional softmax avoids building a new nn.Softmax on every call.
        return torch.softmax(logits, dim=1)
    

    
#IP-extracting term

class IPE(nn.Module):
    """Placeholder module for the IP-extracting term.

    No layers and no ``forward`` are defined yet; constructing an instance
    only runs the ``nn.Module`` initializer.
    """

    def __init__(self):
        super().__init__()
        
        

  # Build environment





Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting /Users/Ryuta/datasets/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting /Users/Ryuta/datasets/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting /Users/Ryuta/datasets/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting /Users/Ryuta/datasets/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/Ryuta/datasets/mnist/MNIST/raw
Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!
tensor([5, 0, 8,  ..., 6, 6, 1])
