In [3]:
# Imports
import torch
import numpy as np
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.parallel
import torchvision.utils as vutils
import torch.optim as optim

import matplotlib.pyplot as plt

import glob

# custom
from VanillaNet import VanillaNet

In [4]:
# constants and declarations
image_size = 64  # images are resized to 64x64 px
dimension = image_size ** 2  # number of pixels per image
channels = 1  # grey scale

# Preferred GPU index. Fall back to the first GPU when the machine exposes fewer
# devices than that, and to the CPU when CUDA is not available at all, so that
# the first `.to(device)` does not fail with an invalid device ordinal.
preferred_gpu = 3
if torch.cuda.is_available():
    device = f"cuda:{preferred_gpu}" if torch.cuda.device_count() > preferred_gpu else "cuda:0"
else:
    device = "cpu"

path = ""  # root directory of the MNIST dataset; needs to be specified

# These are the 10% and 90% quantiles of the pixel values of samples created by
# Langevin sampling. They are used as lower and upper display limits because
# Langevin samples tend to contain outliers which distort the screen output.
lower, upper = -1.0038, 0.2826

noise_scale = 1.5  # scale factor for the uniform [-1, 1] noise

learning_rate = 6e-4  # learning rate of the reconstruction optimizer
regularizer = 0.005  # weight of the regularization term
# Additional factor in front of the regularization term: the inverse of the
# typical energy of samples from MNIST and Langevin sampling. It renormalizes
# the regularization term to roughly 1 so that `regularizer` becomes comparable
# to the data discrepancy.
scale = 1e-8
iteration_count = 4000  # total number of optimization iterations

output_count = 40  # the loss is printed every `output_count` iterations

rows, cols = 1, 1  # image grid for output
width, height = 6, 6  # figure size in inches
padding = 2  # image frame padding in pixels

In [8]:
# model loading
checkpoint_pattern = "model/mnist_vanilla_net.pth"
matches = glob.glob(checkpoint_pattern)
if not matches:
    # fail with an explicit message instead of an IndexError on `[0]`
    raise FileNotFoundError(f"no checkpoint found matching '{checkpoint_pattern}'")
file = matches[0]

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
mnistModel = (file, torch.load(file, map_location=device))  # (path, checkpoint)
mnist = VanillaNet(1,image_size).to(device)
mnist.load_state_dict(mnistModel[1]['model'])
# The network is only evaluated in this notebook, never trained. Evaluation mode
# keeps dropout / batch-norm layers (if VanillaNet has any) deterministic and
# stops running statistics from being updated by forward passes.
mnist.eval()
print(mnistModel[0])

model/mnist_vanilla_net.pth


In [None]:
# get original image
# MNIST test split, resized to image_size; ToTensor followed by
# Normalize(0.5, 0.5) maps the pixel values to [-1, 1].
dataset = datasets.MNIST(
    root=path,
    train=False,
    download=False,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
)
dataloaderMNIST = torch.utils.data.DataLoader(dataset, batch_size=rows*cols, shuffle=True, drop_last=True)

# draw one random batch; element 0 holds the images, element 1 the labels
dataMNIST = next(iter(dataloaderMNIST))
original = dataMNIST[0]

# show image
# make_grid's `nrow` is the number of images per grid row, i.e. the number of
# columns, so it has to be `cols` to obtain a rows x cols grid.
showimg = vutils.make_grid(original, padding=padding, normalize=True, nrow=cols, value_range=(lower,upper), scale_each=False)
fig, ax = plt.subplots()
fig.set_size_inches(width, height)
ax.axis("off")
# channels-first (C, H, W) -> channels-last (H, W, C) as expected by imshow
ax.imshow(np.transpose(showimg,(1,2,0)))

In [None]:
# get noise
# Uniform noise in [-1, 1], one sample per image of the batch. The buffer is
# allocated uninitialized and filled in place; a preceding normal draw would be
# overwritten by uniform_ anyway.
noise = torch.empty([cols*rows, channels, image_size, image_size], device=device).uniform_(-1, 1)
# scale the noise and move it to the CPU, where the original images live
noise_sample = noise_scale * noise.cpu()

# show noise
# make_grid's `nrow` is the number of images per grid row, i.e. the number of
# columns, so it has to be `cols` to obtain a rows x cols grid.
showimg = vutils.make_grid(noise_sample, padding=padding, normalize=True, nrow=cols, value_range=(lower,upper), scale_each=False)
fig, ax = plt.subplots()
fig.set_size_inches(width, height)
ax.axis("off")
# channels-first (C, H, W) -> channels-last (H, W, C) as expected by imshow
ax.imshow(np.transpose(showimg,(1,2,0)))

In [None]:
# get corrupted
# additive corruption: clean image plus the scaled uniform noise
corrupted = original+noise_sample

# show corrupted
# make_grid's `nrow` is the number of images per grid row, i.e. the number of
# columns, so it has to be `cols` to obtain a rows x cols grid.
showimg = vutils.make_grid(corrupted, padding=padding, normalize=True, nrow=cols, value_range=(lower,upper), scale_each=False)
fig, ax = plt.subplots()
fig.set_size_inches(width, height)
ax.axis("off")
# channels-first (C, H, W) -> channels-last (H, W, C) as expected by imshow
ax.imshow(np.transpose(showimg,(1,2,0)))

In [None]:
# preparations for optimization
# The reconstruction starts from the corrupted image and is the only tensor
# that is optimized.
corrupted_copy = corrupted.clone().to(device)
x = torch.nn.Parameter(corrupted_copy, requires_grad=True)

# reference image for the data term; a constant that needs no gradient
x_orig = original.clone().to(device)
x_orig.requires_grad = False

# the network only acts as a fixed regularizer, so freeze its weights
for p in mnist.parameters():
    p.requires_grad = False

# define loss and optimizer
loss = nn.MSELoss()
optimizer = torch.optim.Adam([x], lr=learning_rate)

# optimization
for k in range(iteration_count):
    # actual loss consists of data discrepancy and regularization term
    # NOTE(review): the data term compares against the clean original rather
    # than the corrupted observation, i.e. the ground truth enters the
    # objective. Confirm that this is intended.
    # The network output is averaged so that the objective stays a scalar for
    # any batch size (presumably the network returns one value per image --
    # TODO confirm); for a single image this equals the unreduced value.
    y = loss(x_orig, x) + regularizer * scale * mnist(x).mean()

    # Backpropagation
    optimizer.zero_grad()
    y.backward()
    optimizer.step()

    # output
    if k % output_count == 0:
        # print a plain number instead of a tensor repr carrying grad_fn
        print(f"iteration {k}: loss = {y.item():.6e}")

In [None]:
# final results
# Stack original, corrupted and reconstructed images along the batch axis.
# Concatenating avoids hard-coding the image size and the batch size; the
# reconstruction is detached because it is a parameter that tracks gradients.
data = torch.cat([original, corrupted, x.detach().cpu()], dim=0)
# all three groups are shown side by side in a single grid row
showimg = vutils.make_grid(data, padding=padding, normalize=True, nrow=3*rows*cols, value_range=(lower,upper), scale_each=False)
fig, ax = plt.subplots()
fig.set_size_inches(width*3, height)
ax.axis("off")
ax.set_title("Original | corrupted | reconstructed")
# channels-first (C, H, W) -> channels-last (H, W, C) as expected by imshow
ax.imshow(np.transpose(showimg,(1,2,0)))