In [None]:
# Imports
import torch
import numpy as np
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.parallel
import torchvision.utils as vutils
import torch.optim as optim

import matplotlib.pyplot as plt

import glob

# custom
from VanillaNet import VanillaNet

In [None]:
# constants and declarations
image_size = 64 # 64x64px
dimension = image_size ** 2
channels = 1 # grey scale

# Preferred GPU index is 3. Fall back to GPU 0 when the machine exposes fewer devices,
# so the notebook does not fail on hosts with fewer than four GPUs; use the CPU otherwise.
preferred_gpu_index = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if torch.cuda.device_count() > preferred_gpu_index else 0
    device = f"cuda:{gpu_index}"
else:
    device = "cpu"

path = "" # root directory of the MNIST dataset; needs to be specified

# 10% and 90% quantiles, respectively, of the pixel value range of samples created by Langevin sampling.
# They are used as lower and upper display limits, because Langevin samples tend to contain
# outliers which distort the screen output.
lower, upper = -1.0038, 0.2826

noise_scale = 1.5 # scale factor for uniform [-1,1] noise

learning_rate = 5e-3 # learning rate for reconstruction optimizer
weight_decay = 0.0 # weight decay for optimizer

iteration_count = 2500 # total count of iterations for optimization

output_count = 40 # output iteration step size

rows, cols = 1, 1 # image grid for output
width, height = 6, 6 # figure size
padding = 2 # image frame padding

In [None]:
# model loading
checkpoint_path = "model/mnist_vanilla_net.pth"
checkpoint_matches = glob.glob(checkpoint_path)
if not checkpoint_matches:
    # glob returns an empty list for a missing file; report the path instead of a bare IndexError
    raise FileNotFoundError(f"model checkpoint not found: {checkpoint_path}")
file = checkpoint_matches[0]

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
# map_location places the stored tensors directly on the selected device.
mnistModel = (file, torch.load(file, map_location=device))

# build the network and restore the weights stored under the 'model' key of the checkpoint
mnist = VanillaNet(1,image_size).to(device)
mnist.load_state_dict(mnistModel[1]['model'])
print(mnistModel[0])

In [None]:
# get original image
# MNIST test split, resized to image_size; Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] range to [-1, 1]
mnist_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST(
    root=path,
    train=False,
    download=False,
    transform=mnist_transform,
)
dataloaderMNIST = torch.utils.data.DataLoader(dataset, batch_size=rows*cols, shuffle=True, drop_last=True)

# draw one random batch; element 0 of the batch holds the images
dataMNIST = next(iter(dataloaderMNIST))
original = dataMNIST[0]

# show image
# make_grid's nrow is the number of images per grid row, i.e. the number of columns
showimg = vutils.make_grid(original, padding=padding, normalize=True, nrow=cols, value_range=(lower,upper), scale_each=False)
fig, ax = plt.subplots()
fig.set_size_inches(width, height)
ax.axis("off")
# make_grid returns (C, H, W); imshow expects (H, W, C)
ax.imshow(np.transpose(showimg,(1,2,0)))

In [None]:
# get mask
# 1 marks pixels that are kept from the original image, 0 marks the region to be reconstructed
mask = torch.ones_like(original, device=device)
# cut the same rectangle out of every sample and channel of the batch
# (identical to indexing [0, 0] for the single grey-scale image used here)
mask[:, :, 10:35, 15:55] = 0

# show mask
# make_grid's nrow is the number of images per grid row, i.e. the number of columns
showimg = vutils.make_grid(mask.cpu(), padding=padding, normalize=True, nrow=cols, value_range=(lower,upper), scale_each=False)
fig, ax = plt.subplots()
fig.set_size_inches(width, height)
ax.axis("off")
# make_grid returns (C, H, W); imshow expects (H, W, C)
ax.imshow(np.transpose(showimg,(1,2,0)))

In [None]:
# get masked
# `original` comes from the DataLoader and lives on the CPU, while `mask` was created on `device`;
# move the image to the mask's device first so the product does not mix devices
masked_image = original.to(device) * mask

# show masked
# make_grid's nrow is the number of images per grid row, i.e. the number of columns
showimg = vutils.make_grid(masked_image.cpu(), padding=padding, normalize=True, nrow=cols, value_range=(lower,upper), scale_each=False)
fig, ax = plt.subplots()
fig.set_size_inches(width, height)
ax.axis("off")
# make_grid returns (C, H, W); imshow expects (H, W, C)
ax.imshow(np.transpose(showimg,(1,2,0)))

In [None]:
# define restricted model: the input is only used where mask == 0, everywhere else fixed values are fed to the net
class Surrogate(nn.Module):
    """Wrap a network so that only the masked-out part of its input is free.

    The wrapped network receives ``(1 - mask) * x + mask * values``: where
    ``mask`` is 1 the fixed ``values`` are used, where ``mask`` is 0 the
    input ``x`` passes through.

    Args:
        net: module that is evaluated on the blended input.
        mask: tensor of blending weights (1 = take ``values``, 0 = take ``x``);
            must broadcast against the input.
        values: tensor providing the fixed content for the region where ``mask`` is 1.
    """

    def __init__(self, net: nn.Module, mask: torch.Tensor, values: torch.Tensor):
        super().__init__()

        self.net = net
        # Buffers (unlike plain tensor attributes) follow .to(device) / .cuda() / .cpu(), so the
        # blend in forward() never mixes devices. persistent=False keeps them out of state_dict.
        self.register_buffer("mask", mask, persistent=False)
        self.register_buffer("values", values, persistent=False)

    def forward(self, x):
        """Return the network output for ``x`` blended with the fixed values."""
        blended = (1 - self.mask) * x + self.mask * self.values
        return self.net(blended)

In [None]:
# preparations for optimization
# `original` comes from the DataLoader and lives on the CPU, while `mask` was created on `device`;
# work on a copy on `device` so none of the expressions below mixes devices
original_on_device = original.to(device)
mask.requires_grad=False
original.requires_grad=False

# start from the known pixels (mask == 1) and fill the region to reconstruct (mask == 0)
# with uniform noise drawn from [-noise_scale, noise_scale]
noise = torch.empty([cols*rows, 1, image_size, image_size], device=device).uniform_(-1, 1)
start_image = original_on_device * mask + noise_scale * noise * (1-mask)
x = torch.nn.Parameter(start_image, requires_grad=True)

# pass the on-device image explicitly so the surrogate's fixed values share the mask's device
surrogate = Surrogate(mnist, mask, original_on_device).to(device)

# freeze all network weights: only the image x is optimized
surrogate.requires_grad_(False)

# Adam updates x so that the network output decreases
optimizer = torch.optim.Adam([x], lr=learning_rate, weight_decay = weight_decay)
for k in range(iteration_count):
    y = surrogate(x)

    # Backpropagation
    optimizer.zero_grad()
    # sum() yields the same gradient for a single-element output and keeps backward() valid
    # when the output holds more than one element (batch size > 1)
    y.sum().backward()
    optimizer.step()

    if k % output_count == 0:
        print(y)