# Importance sampling: toy example with parameters

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

import matplotlib.pyplot as plt

Functions for generation and visualization of the image batches

In [2]:
def visualize_img_batch(batch):
    '''Shows an image batch as a single grid, eight images per row.

    Args:
        batch: the images to tile; passed unchanged to ``make_grid``.
    '''
    # Only the non-default arguments are passed. The original also passed
    # ``range=None``; that keyword was renamed to ``value_range`` in
    # torchvision and is rejected by newer releases, while omitting it
    # keeps the same behaviour on every version.
    grid = make_grid(batch, nrow=8, padding=1)
    # make_grid returns channels-first; imshow expects channels-last.
    # Convert to a NumPy array explicitly so that tensors tracking
    # gradients or living on another device can still be shown.
    plt.imshow(grid.permute(1, 2, 0).detach().cpu().numpy())
    plt.show()

In [3]:
def random_image_data(size=(32768, 3, 1, 1), ratio=0.5):
    '''Makes a batch of uniformly gray images, each either dark or light.

    Args:
        size: shape of the returned tensor. The first dimension indexes the
            images; the default is 32768 images of 3 channels and 1x1 pixel.
        ratio: fraction of the images, in [0, 1], that are light gray
            (value 0.8). The remaining images are dark gray (value 0.2).
            Which images are light is chosen at random.

    Returns:
        A tensor of shape ``size`` holding the values 0.2 and 0.8.

    Raises:
        ValueError: if ``ratio`` lies outside [0, 1].
    '''
    # A negative ratio would turn the slice below into ``[:-k]`` and a ratio
    # above one would be silently clamped, so reject both explicitly.
    if not 0 <= ratio <= 1:
        raise ValueError('ratio must be in [0, 1], got {}'.format(ratio))

    num_images = size[0]
    # Random subset of image indices that become light.
    light_idx = torch.randperm(num_images)[:int(ratio * num_images)]
    image_batch = torch.zeros(size) + 0.2  # every image starts dark gray
    image_batch[light_idx] = 1 - 0.2  # the selected images become light gray
    return image_batch

The dataset class

In [4]:
class BlackWhiteDataset(Dataset):
    '''Dataset of dark/light images sampled by rejection with a weight network.

    The images come from ``random_image_data()``. ``__getitem__`` ignores the
    requested index: it proposes images in a random order and returns the
    first one the weight network accepts.
    '''
    def __init__(self, weight_network):
        '''Stores the image tensor and the network used to accept samples.

        Args:
            weight_network: callable mapping an image to a single acceptance
                probability in [0, 1].
        '''
        self.dataset = random_image_data()

        self.weight_network = weight_network

    def __len__(self):
        '''Returns the number of images in the underlying tensor.'''
        return len(self.dataset)

    def accept_sample(self, weight_network, img):
        '''Draws the accept/reject decision for one image.

        Args:
            weight_network: callable returning the acceptance probability.
            img: the image to evaluate.

        Returns:
            True with probability equal to the weight, otherwise False.
        '''
        # The decision is not differentiable, so do not build an autograd
        # graph for it.
        with torch.no_grad():
            weight = float(weight_network(img))
        # Bernoulli(weight) draw: torch.rand samples uniformly from [0, 1).
        return bool(torch.rand(()).item() < weight)

    def __getitem__(self, idx):
        '''Returns the first accepted image from a random proposal order.

        Args:
            idx: ignored; the returned image is chosen by rejection sampling.

        Raises:
            RuntimeError: if every image is rejected in one full pass.
        '''
        # Propose candidates in a uniformly random order. Use the instance's
        # own tensor rather than a notebook-level global named ``dataset``.
        all_idx = torch.randperm(len(self.dataset))

        # Loop through the samples and return once accepted
        for i in all_idx:
            candidate = self.dataset[i]
            if self.accept_sample(self.weight_network, candidate):
                return candidate

        # Falling through used to return None, which only failed later when
        # the DataLoader tried to collate the batch.
        raise RuntimeError(
            'weight network rejected all {} images; '
            'acceptance weights are too small'.format(len(self.dataset))
        )

The weight network with parameters

In [5]:
class Net(nn.Module):
    '''Two-layer network that maps a 3-channel pixel to a weight in (0, 1).'''

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 1)
        self.fc2 = nn.Linear(1, 1)

        # Start both weight matrices at 0.5: with weights that are too small
        # the resulting acceptance probabilities are so low that no images
        # are sampled. Biases keep their default initialization.
        for layer in (self.fc1, self.fc2):
            layer.weight.data.fill_(0.5)

    def forward(self, x):
        '''Returns one sigmoid output per row of ``x`` viewed as (-1, 3).'''
        pixels = x.view(-1, 3)
        hidden = torch.relu(self.fc1(pixels))
        return torch.sigmoid(self.fc2(hidden))

weight_network = Net()

In [6]:
dataset = BlackWhiteDataset(weight_network)
# Load in the main process (num_workers=0). The dataset calls weight_network
# inside __getitem__, so with worker processes the accept/reject decisions
# can be made by a copy of the network that does not see the optimizer's
# updates. Classes defined in a notebook also cannot be sent to spawned
# workers. shuffle has no effect on which images are returned because
# __getitem__ ignores the index.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

In [7]:
# Sanity check: evaluate the weight network on the first two images of the
# raw image tensor (bypassing the rejection sampling in __getitem__).
img = dataset.dataset[0:2]
# One sigmoid output per image, i.e. the acceptance weight of each image.
w = weight_network(img)
print(w)

tensor([[0.7085],
        [0.7921]], grad_fn=<SigmoidBackward>)


Training loop

In [8]:
import torch.optim as optim

# Mean-squared-error loss between the network output and its target.
criterion = nn.MSELoss()
# Adam over the weight network's parameters, with a learning rate of 0.01.
optimizer = optim.Adam(weight_network.parameters(), lr=0.01)

In [10]:
# One pass over the dataloader: regress the weight network's output onto the
# mean channel value of each sampled image.
for i, data in enumerate(dataloader):
    # Target per image: the mean over dimension 1, as a column vector.
    labels = data.mean(dim=1).view(-1, 1)

    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Forward pass and loss.
    outputs = weight_network(data)
    loss = criterion(outputs, labels)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    # Report the loss every 50 steps.
    if i % 50 != 0:
        continue
    print('step', i, 'loss: ', loss.item())

step 0 loss:  6.744604874597826e-15
step 50 loss:  6.744604874597826e-15
step 100 loss:  4.496403249731884e-15
step 150 loss:  0.0
step 200 loss:  0.0
step 250 loss:  4.496403249731884e-15
step 300 loss:  0.0
step 350 loss:  2.248201624865942e-15
step 400 loss:  8.992806499463768e-15
step 450 loss:  4.496403249731884e-15
step 500 loss:  2.248201624865942e-15
step 550 loss:  2.248201624865942e-15
step 600 loss:  8.992806499463768e-15
step 650 loss:  6.744604874597826e-15
step 700 loss:  2.248201624865942e-15
step 750 loss:  4.496403249731884e-15
step 800 loss:  4.496403249731884e-15
step 850 loss:  6.744604874597826e-15
step 900 loss:  2.248201624865942e-15
step 950 loss:  0.0
step 1000 loss:  2.248201624865942e-15
step 1050 loss:  0.0
step 1100 loss:  4.496403249731884e-15
step 1150 loss:  0.0
step 1200 loss:  2.248201624865942e-15
step 1250 loss:  2.248201624865942e-15
step 1300 loss:  6.744604874597826e-15
step 1350 loss:  4.08006961549745e-15
step 1400 loss:  1.3600232051658168e-15


KeyboardInterrupt: 