<a href="https://colab.research.google.com/github/MichelleAppel/Importance_sampling/blob/master/toy_examples/n_centered_moment_MNIST_AE_bastian.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

from itertools import chain

import numpy as np
import math
import itertools

from scipy import signal

import matplotlib.pyplot as plt

# Toy example: MNIST with autoencoder embeddings (one-sided)

In [2]:
def MNIST_data(distribution=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], one_hot_labels=False, root='/files/'):
    """Load the MNIST training set, subsampled to follow a requested class distribution.

    Args:
        distribution: relative frequency wanted for labels 0 .. len(distribution)-1.
            It is normalised to sum to 1; labels beyond its length are dropped.
        one_hot_labels: if True, the stored targets are replaced by one-hot vectors.
        root: directory the dataset is downloaded to / read from.

    Returns:
        The torchvision MNIST dataset object whose ``data`` and ``targets`` have been
        replaced by the selected (and shuffled) subset.

    Raises:
        ValueError: if ``distribution`` does not have a positive sum.
    """
    MNIST = torchvision.datasets.MNIST(root, train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

    # Explicit float copy: never touches the caller's list or the shared default argument.
    distribution = np.array(distribution, dtype=float)
    total_mass = distribution.sum()
    if not total_mass > 0:
        raise ValueError('distribution must have a positive sum')
    distribution = distribution / total_mass

    n_classes = len(distribution)

    # bool mask for each label
    class_masks = [MNIST.targets == label for label in range(n_classes)]

    # Tensor.sum() counts in one vectorised call; the builtin sum() would walk every element in Python.
    available = [int(mask.sum()) for mask in class_masks]
    total = sum(available)

    # list of indices for each label
    class_indices = [np.where(mask.numpy())[0] for mask in class_masks]

    # Smallest available/wanted ratio over the classes. Below 1 means at least one class cannot
    # supply its share, so every class is scaled down by that factor to keep the proportions.
    available_t = torch.Tensor(available)
    wanted_t = torch.Tensor([d * total for d in distribution])
    scale = (available_t / wanted_t).min().item()
    if scale < 1:
        distribution = distribution * scale

    # Keep the first floor(share * total) samples of each class (slicing caps at what exists).
    class_len = [math.floor(d * total) for d in distribution]
    valid_idx = np.concatenate(
        [indices[:length] for indices, length in zip(class_indices, class_len)]
    )
    # Shuffle with NumPy's global RNG so the classes are interleaved.
    np.random.shuffle(valid_idx)

    # assign the new data and labels to the dataset
    if one_hot_labels:
        # NOTE(review): the number of one-hot columns is inferred from the largest label kept.
        # Downstream code that treats targets as integer labels may not accept one-hot
        # targets -- confirm before enabling this flag.
        MNIST.targets = torch.nn.functional.one_hot(MNIST.targets[valid_idx])
    else:
        MNIST.targets = MNIST.targets[valid_idx]
    MNIST.data = MNIST.data[valid_idx]

    return MNIST

In [3]:
class MNISTDataset(Dataset):
    '''MNIST training data subsampled to a requested per-class distribution.

    Attributes:
        distribution: the requested relative class frequencies (a private copy).
        dataset: the dataset object returned by ``MNIST_data``.
        example_imgs: one image per kept label, stacked along dimension 0.
        unique_labels: the distinct labels, limited to ``len(distribution)`` entries.
    '''
    def __init__(self, distribution=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], one_hot_labels=False):

        # Copy so that later edits to the attribute cannot leak into the shared default list.
        self.distribution = list(distribution)
        self.dataset = MNIST_data(distribution=self.distribution, one_hot_labels=one_hot_labels)
        self.example_imgs = self.example()

        # to take out in real applications
        self.unique_labels = torch.unique(self.dataset.targets)[:len(self.distribution)]

    def example(self):
        '''
        Returns one example image for each digit in the domain.

        The first sample of every distinct label is taken (at most
        ``len(self.distribution)`` labels) and the images are concatenated
        along dimension 0, in ascending label order.
        '''
        # NOTE(review): the `labels == label` mask assumes integer class labels
        # (one_hot_labels=False); confirm before using this with one-hot targets.
        labels = self.dataset.targets
        data = self.dataset.data
        shown_labels = torch.unique(labels)[:len(self.distribution)]
        first_per_label = [data[labels == label][0].unsqueeze(0) for label in shown_labels]
        return torch.cat(first_per_label, 0)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Delegates to the wrapped dataset, which returns its own item for this index.
        return self.dataset[idx]

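In this example we have two domains, each containing all 10 digit classes. Domain B requests a uniform class distribution, while domain A requests relative frequencies of 0.05 for class 0 and 0.15 for class 1 (instead of 0.1); all other classes have the same probability in both domains.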

In [4]:
batch_size = 256

# Domain A: skewed class prior -- digit 0 requested at 0.05, digit 1 at 0.15, the rest at 0.1.
# (Two-class alternative: distribution=[0.2, 0.8])
dataset_A = MNISTDataset(distribution=[0.05, 0.15] + [0.1] * 8)
dataloader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)

# Domain B: uniform class prior over all ten digits.
# (Two-class alternative: distribution=[0.5, 0.5])
dataset_B = MNISTDataset(distribution=[0.1] * 10)
dataloader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)

In [5]:
class WeightNet(nn.Module):
    '''Small CNN that scores every sample and turns the scores into importances.

    The importances are a softmax over dimension 0, i.e. across the samples of
    the batch, so they sum to 1 within each forward call.
    '''

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=0)

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # 320 input features = 20 channels * 4 * 4, which is what 1x28x28 inputs produce.
        self.fc1 = nn.Linear(320, 40)
        self.fc2 = nn.Linear(40, 1)

    def forward(self, x):
        '''Return ``(importances, raw_scores)`` for the batch ``x``.'''
        features = torch.sigmoid(F.max_pool2d(self.conv1(x), 2))
        features = torch.sigmoid(F.max_pool2d(self.conv2(features), 2))
        hidden = torch.sigmoid(self.fc1(features.view(-1, 320)))
        scores = self.fc2(hidden)
        return self.softmax(scores), scores

In [6]:
class Encoder_MNIST(nn.Module):
  """Convolutional encoder mapping an image batch to 16-dimensional embeddings.

  Three stride-2 convolutions are followed by a linear layer that expects
  32*4*4 features (the size reached from 28x28 inputs: 28 -> 14 -> 7 -> 4).
  """

  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(1, 8, 3, stride=2, padding=1)
    self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
    self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
    self.fc = nn.Linear(32*4*4, 16)

  def forward(self, x):
    """Return the embedding of ``x``; no activation is applied after the linear layer."""
    for conv in (self.conv1, self.conv2, self.conv3):
      x = F.relu(conv(x))
    # Collapse everything but the batch dimension before the linear projection.
    return self.fc(torch.flatten(x, 1))

class Decoder_MNIST(nn.Module):
  """Transposed-convolution decoder mapping 16-dimensional embeddings to 1x28x28 images.

  The embedding is projected to a 32x4x4 feature map and upsampled
  4 -> 7 -> 13 -> 28. The last layer has no activation, so the output is unbounded.
  """

  def __init__(self):
    super().__init__()

    self.fc = nn.Linear(16, 32*4*4)
    self.conv1 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1)
    self.conv2 = nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1)
    self.conv3 = nn.ConvTranspose2d(8, 1, 3, stride=2, padding=0)

  def forward(self, x):
    """Return the reconstruction for the embedding batch ``x``."""
    feature_map = F.relu(self.fc(x)).reshape(-1, 32, 4, 4)
    feature_map = F.relu(self.conv1(feature_map))
    feature_map = F.relu(self.conv2(feature_map))
    # output_size picks the 28x28 result out of the sizes a stride-2 transposed conv can yield.
    return self.conv3(feature_map, output_size=(28, 28))

In [7]:
def n_centered_moment(x, w, n):
  """Weighted n-th moment of ``x`` over dimension 0.

  For ``n == 1`` this is the weighted sum ``(x * w).sum(0)``. For ``n > 1`` the
  values are first centred on that weighted sum and the weighted sum of the
  n-th powers of the deviations is returned.

  Args:
    x: values, reduced over dimension 0.
    w: weights that broadcast against ``x`` (a tensor or a plain number).
       Presumably they sum to 1 over dimension 0, so the result is a mean -- nothing enforces it.
    n: order of the moment.
  """
  center = (x * w).sum(0) if n > 1 else 0
  deviations = x - center
  return ((deviations ** n) * w).sum(0)

In [8]:
def raw_to_onehot(labels, n_classes):
  """Convert a 1-D tensor of integer class labels into a one-hot float matrix.

  Returns a ``(len(labels), n_classes)`` tensor of zeros with a single 1 per
  row, placed in the column given by that row's label.
  """
  encoded = torch.zeros(len(labels), n_classes)
  row_positions = torch.arange(len(labels))
  # One indexed assignment sets entry (row, label) for every row.
  encoded[row_positions, labels] = 1
  return encoded

In [10]:
# Run on the GPU when one is present; fall back to the CPU so the cell still executes elsewhere.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the networks
weight_network = WeightNet().to(device)
encoder = Encoder_MNIST().to(device)
decoder = Decoder_MNIST().to(device)

# Initialize the optimizers
lr = 0.01
optimizer_w = optim.Adam(weight_network.parameters(), lr=lr)
optimizer_ae = optim.Adam(chain(encoder.parameters(), decoder.parameters()), lr=lr)

criterion_w = nn.MSELoss()
# The decoder ends in a plain transposed convolution (unbounded output) and the targets are
# mean/std-normalised images, so neither is confined to [0, 1] as nn.BCELoss requires.
# Mean squared error is well defined for both.
criterion_ae = nn.MSELoss()

# For storing results
losses_w = []
losses_ae = []

means_A = []
means_B = []

vars_A = []
vars_B = []

moments_A = []
moments_B = []

example_importances_A = []

n = 2 # n-centered moment
n_classes = len(dataset_A.distribution)

# The per-class example images are raw 0-255 pixel values, whereas training batches go through
# ToTensor (scale to [0, 1]) and Normalize(0.1307, 0.3081). Apply the same preprocessing once,
# so the weight network scores them on the input scale it is trained on.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
example_batch_A = dataset_A.example_imgs.unsqueeze(1).float().div(255)
example_batch_A = ((example_batch_A - MNIST_MEAN) / MNIST_STD).to(device)

for epoch in range(1):
    # zip stops at the shorter loader, so each step sees one batch from each domain.
    for i, (real_A, real_B) in enumerate(zip(dataloader_A, dataloader_B)):

        img_A = real_A[0].to(device)
        img_B = real_B[0].to(device)

        label_A = raw_to_onehot(real_A[1], n_classes).to(device)
        label_B = raw_to_onehot(real_B[1], n_classes).to(device)

        # The embeddings
        e_A = encoder(img_A)
        e_B = encoder(img_B)

        reconstructed_A = decoder(e_A)
        reconstructed_B = decoder(e_B)

        # The weighting process: learned importances for A, uniform weights for B
        w_A = weight_network(img_A)[0]
        w_B = 1/len(img_B)

        # The loss function --------------------------------------------------------------------------------
        # Embeddings are detached so the moment-matching loss only trains the weight network.
        n_centered_moment_A = n_centered_moment(e_A.detach(), w_A, n)
        n_centered_moment_B = n_centered_moment(e_B.detach(), w_B, n)
        loss_w = criterion_w(n_centered_moment_A, n_centered_moment_B)

        loss_ae_A = criterion_ae(reconstructed_A, img_A)
        loss_ae_B = criterion_ae(reconstructed_B, img_B)
        loss_ae = loss_ae_A + loss_ae_B
        # ---------------------------------------------------------------------------------------------------

        # Backward
        optimizer_w.zero_grad()
        loss_w.backward()
        optimizer_w.step()

        optimizer_ae.zero_grad()
        loss_ae.backward()
        optimizer_ae.step()

        # Store values --------------------------------------------------------------------------------------
        moments_A += [n_centered_moment_A.cpu().detach().numpy()]
        moments_B += [n_centered_moment_B.cpu().detach().numpy()]

        means_A += [n_centered_moment(label_A, w_A, 1).detach().cpu().numpy()]
        means_B += [n_centered_moment(label_B, w_B, 1).detach().cpu().numpy()]

        vars_A += [n_centered_moment(label_A, w_A, 2).detach().cpu().numpy()]
        vars_B += [n_centered_moment(label_B, w_B, 2).detach().cpu().numpy()]

        losses_w += [loss_w.item()]
        losses_ae += [loss_ae.item()]

        # Logging only: no autograd graph is needed for the example importances.
        with torch.no_grad():
            w_a = weight_network(example_batch_A)
        example_importances_A += [[importance.item() for importance in w_a[0]]] # Store examples in a list

        # ---------------------------------------------------------------------------------------------------

        # Print statistics
        if i % 50 == 0:
            print('epoch', epoch, 'step', i, 'loss_w: ', loss_w.item(), 'loss_ae', loss_ae.item())

        # Safety cap on the number of steps per epoch
        if i % 10000 == 0 and i != 0:
            break

In [None]:
# Stack each per-step list into one NumPy array before handing it to torch: building a tensor
# straight from a Python list of ndarrays converts element by element and is very slow.
# The result is one float32 row per training step, as before.
moments_A = torch.tensor(np.array(moments_A), dtype=torch.float32)
moments_B = torch.tensor(np.array(moments_B), dtype=torch.float32)
means_A = torch.tensor(np.array(means_A), dtype=torch.float32)
means_B = torch.tensor(np.array(means_B), dtype=torch.float32)
vars_A = torch.tensor(np.array(vars_A), dtype=torch.float32)
vars_B = torch.tensor(np.array(vars_B), dtype=torch.float32)

## Results
In the plot below we see that the loss of W is going down.

In [None]:
# Smoothed training loss of the weight network.
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title('Losses over iterations')
ax.set_xlabel('Training iterations')
ax.set_ylabel('Loss')
# Savitzky-Golay smoothing: 101-sample window, cubic polynomial.
smoothed_losses_w = signal.savgol_filter(losses_w, 101, 3)
ax.plot(smoothed_losses_w)
ax.legend(['W'])
plt.show()

In [None]:
# Smoothed reconstruction loss of the autoencoder.
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title('Losses over iterations')
ax.set_xlabel('Training iterations')
ax.set_ylabel('Loss')
# Savitzky-Golay smoothing: 101-sample window, cubic polynomial.
smoothed_losses_ae = signal.savgol_filter(losses_ae, 101, 3)
ax.plot(smoothed_losses_ae)
ax.legend(['AE'])
plt.show()

The plot below shows that the classes 0 and 1 in domain A are weighted to match the uniform distribution in domain B.

In [None]:
# Smoothed n-th centered moments of the embeddings, one curve per embedding dimension.
smoothed_Lmin = signal.savgol_filter(moments_A, 101, 3, axis=0)
smoothed_Lplus = signal.savgol_filter(moments_B, 101, 3, axis=0)

# Both figures use the value range of domain A so the two domains can be compared directly.
moment_ylim = (float(moments_A.min()), float(moments_A.max()))

moment_panels = [
    ('Moments A - measure of domain A after weighting (smoothed)', smoothed_Lmin),
    ('Moments B - Measure of domain B (smoothed)', smoothed_Lplus),
]
for panel_title, panel_curves in moment_panels:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title(panel_title)
    ax.set_xlabel('Training iterations')
    # The curves are moment values, not a loss.
    ax.set_ylabel('Centered moment of the embedding')
    ax.set_ylim(*moment_ylim)
    ax.plot(panel_curves)
    plt.show()

In [None]:
# Smoothed weighted mean of the one-hot labels, one curve per class.
smoothed_means_A = signal.savgol_filter(means_A, 101, 3, axis=0)
smoothed_means_B = signal.savgol_filter(means_B, 101, 3, axis=0)

# Both figures use the value range of domain A so the two domains can be compared directly.
mean_ylim = (float(means_A.min()), float(means_A.max()))

mean_panels = [
    ('Mean of domain A after weighting (smoothed)', smoothed_means_A),
    ('Mean of domain B (smoothed)', smoothed_means_B),
]
for panel_title, panel_curves in mean_panels:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title(panel_title)
    ax.set_xlabel('Training iterations')
    # The curves are weighted class proportions, not a loss.
    ax.set_ylabel('Weighted class proportion')
    ax.set_ylim(*mean_ylim)
    ax.plot(panel_curves)
    plt.show()

In [None]:
# Smoothed weighted second centered moment of the one-hot labels, one curve per class.
smoothed_vars_A = signal.savgol_filter(vars_A, 101, 3, axis=0)
smoothed_vars_B = signal.savgol_filter(vars_B, 101, 3, axis=0)

# Both figures use the value range of domain A so the two domains can be compared directly.
var_ylim = (float(vars_A.min()), float(vars_A.max()))

var_panels = [
    ('Var of domain A after weighting (smoothed)', smoothed_vars_A),
    ('Var of domain B (smoothed)', smoothed_vars_B),
]
for panel_title, panel_curves in var_panels:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title(panel_title)
    ax.set_xlabel('Training iterations')
    # The curves are variances of the class indicators, not a loss.
    ax.set_ylabel('Weighted variance of the class indicator')
    ax.set_ylim(*var_ylim)
    ax.plot(panel_curves)
    ax.legend(np.arange(n_classes))
    plt.show()

In [None]:
# Smoothed class means with a +/- variance band, for domain A (weighted) and domain B.
# Both figures use the band range of domain A so the two domains can be compared directly.
band_ylim = (
    (smoothed_means_A - smoothed_vars_A).min(),
    (smoothed_means_A + smoothed_vars_A).max(),
)

band_panels = [
    ('Mean and var of domain A after weighting (smoothed)', smoothed_means_A, smoothed_vars_A),
    # This figure shows domain B, so it must not carry the domain A title.
    ('Mean and var of domain B (smoothed)', smoothed_means_B, smoothed_vars_B),
]
for panel_title, mean_curves, var_curves in band_panels:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title(panel_title)
    ax.set_xlabel('Training iterations')
    # The curves are weighted class proportions, not a loss.
    ax.set_ylabel('Weighted class proportion')
    ax.set_ylim(*band_ylim)
    ax.plot(mean_curves)

    steps = np.arange(len(mean_curves))
    lower_band = mean_curves - var_curves
    upper_band = mean_curves + var_curves
    for class_idx in range(mean_curves.shape[1]):
        ax.fill_between(steps, lower_band[:, class_idx], upper_band[:, class_idx], alpha=0.1)

    ax.legend(np.arange(n_classes))
    plt.show()

Here you see the weights assigned to the classes in domain A. As expected, 0 gets a large weight as it is underrepresented in domain A, and 1 gets a small weight as it is overrepresented in domain A.

In [None]:
# Smoothed importance given to the example image of each class, one curve per class.
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title('Assigned importances for the classes in domain A over the course of training')
# Savitzky-Golay smoothing along the training-step axis: 101-sample window, cubic polynomial.
smoothed_importances_A = signal.savgol_filter(example_importances_A, 101, 3, axis=0)
ax.plot(smoothed_importances_A)
ax.legend(np.arange(n_classes))
ax.set_ylabel('Assigned importance')
ax.set_xlabel('Training iterations')
plt.show()