In [4]:
import sys
sys.path.append('../')

from utils import make_new_folder, plot_norm_losses, save_input_args, \
sample_z, class_loss_fn, plot_losses, corrupt, prep_data, plot_log_losses # one_hot


import torch
from torch import optim
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy as bce

from torchvision import transforms, datasets
from torchvision.utils import make_grid, save_image

import numpy as np

import os
from os.path import join

import argparse

from PIL import Image

import matplotlib 
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from time import time

EPSILON = 1e-6

In [5]:
class generator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(generator, self).__init__()
        self.deconv1_1 = nn.ConvTranspose2d(100, d*2, 4, 1, 0)
        self.deconv1_1_bn = nn.BatchNorm2d(d*2)
        self.deconv1_2 = nn.ConvTranspose2d(10, d*2, 4, 1, 0)
        self.deconv1_2_bn = nn.BatchNorm2d(d*2)
        self.deconv2 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d*2)
        self.deconv3 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d)
        self.deconv4 = nn.ConvTranspose2d(d, 1, 4, 2, 1)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input, label):
        x = F.relu(self.deconv1_1_bn(self.deconv1_1(input)))
        y = F.relu(self.deconv1_2_bn(self.deconv1_2(label)))
        x = torch.cat([x, y], 1)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.tanh(self.deconv4(x))
        # x = F.relu(self.deconv4_bn(self.deconv4(x)))
        # x = F.tanh(self.deconv5(x))

        return x

def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
        
def find_batch_z(gen, x, nz, lr, exDir, maxEpochs=100, alpha=1e-6, batchNo=0):

    #generator in eval mode
    gen.eval()

    #save the "original" images
    save_image(x.data, join(exDir, 'original_batch'+str(batchNo)+'.png'), normalize=True, nrow=10)

    #Assume the prior is Standard Normal
    pdf = torch.distributions.Normal(0, 1)

    Zinit = Variable(torch.randn(x.size(0), nz).view(-1, nz, 1, 1).cuda(), requires_grad=True)
    Yinit = torch.zeros(x.size(0), 10)
    fixed_y_ = torch.ones(x.size(0), 1)
    Yinit.scatter_(1, fixed_y_.type(torch.LongTensor), 1)
    Yinit = Yinit.view(-1, 10, 1, 1)
    Yinit = Variable(Yinit.cuda(), requires_grad=False)

    #optimizer
    optZ = torch.optim.RMSprop([Zinit], lr=lr)

    losses = {'rec': [], 'logProb': []}
    for e in range(maxEpochs):

        #reconstruction loss
        xHAT = gen.forward(Zinit, Yinit)
        recLoss = F.mse_loss(xHAT, x)

        #loss to make sure z's are Guassian
        logProb = pdf.log_prob(Zinit).mean(dim=1)  #each element of Z is independant, so likelihood is a sum of log of elements
        loss = recLoss - (alpha * logProb.mean())
        

        optZ.zero_grad()
        loss.backward()
        optZ.step()

        losses['rec'].append(recLoss.data)
        losses['logProb'].append(logProb.mean().data)

        if e%100==0:
            print('[%d] loss: %0.5f, recLoss: %0.5f, regMean: %0.5f' % (e, loss.data, recLoss.data, logProb.mean().data))
            # save_image(xHAT.data, join(exDir, 'rec'+str(e)+'.png'), normalize=True)

        #plot training losses
        if e>0:
            plot_losses(losses, exDir, e+1)
            #plot_norm_losses(losses, exDir, e+1)

    #visualise the final output
    xHAT = gen.forward(Zinit, Yinit)
    save_image(xHAT.data, join(exDir, 'rec_batch'+str(batchNo)+'.png'), normalize=True, nrow=10)

    return Zinit, recLoss.data, xHAT

In [6]:
data_dir = 'untilted_1'
batchSize = 128
maxEpochs = 200
nz = 100
imSize = 64
lr = 0.01
fSize = 64
alpha = 1e-2
#Create new subfolder for saving results and training params
exDir = 'InversionExperiments_untilted'
try:
    os.mkdir(exDir)
except:
    print('already exists')

print('Outputs will be saved to:',exDir)

im_size = 32

print('Prepare data loaders...')
transform = transforms.Compose([
        transforms.Scale(im_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])    

testLoader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, download=False, transform=transform),
        batch_size=batchSize, shuffle=False)

###### Create model and load parameters #####
G = generator(128)
G.weight_init(mean=0.0, std=0.02)
print('Setting cuda device')
torch.cuda.set_device(0)
G.cuda()
G.load_state_dict(torch.load('MNIST_cDCGAN_results/data/MNIST_cDCGAN_generator_param.pkl'))
print('params loaded')

#testLoader = torch.utils.data.DataLoader(testDataset, batch_size=opts.batchSize, shuffle=False)
print('Data loaders ready.')

#Find each z individually for each x
allRec = []
allX = []
sumLoss = 0
for i, data in enumerate(testLoader):
    x, y = prep_data(data, useCUDA=True)
    z, recLoss, xRec = find_batch_z(gen=G, x=x, nz=100, lr=lr, exDir=exDir, maxEpochs=maxEpochs, alpha=alpha, batchNo=i)

    allRec.append(xRec.cpu().data)
    allX.append(x.cpu().data) #incase the loader shuffles samples

allRec = np.concatenate(allRec)
allX = np.concatenate(allX)
print('allRec:', np.shape(allRec))
print('allX:', np.shape(allX))


mseLoss = np.mean((allRec - allX)**2, axis=(1,2,3))  # mean over colour channels and pixels
np.save(join(exDir, 'mseLosses_per_sample.npy'), mseLoss)
meanLoss = np.mean(mseLoss) # mean over samples
stdLoss = np.std(mseLoss)  #std over samples

f = open(join(exDir,'recError.txt'), 'w')
f.write('mean loss %0.5f' % (meanLoss))
f.write('std of loss %0.5f' % (stdLoss))
f.close()

already exists
Outputs will be saved to: InversionExperiments_untilted
Prepare data loaders...
Setting cuda device
params loaded
Data loaders ready.
[0] loss: 0.17246, recLoss: 0.15825, regMean: -1.42061


  "please use transforms.Resize instead.")


[100] loss: 0.02143, recLoss: 0.00791, regMean: -1.35170
[0] loss: 0.19403, recLoss: 0.17985, regMean: -1.41891
[100] loss: 0.02314, recLoss: 0.00949, regMean: -1.36482
[0] loss: 0.17711, recLoss: 0.16291, regMean: -1.42050
[100] loss: 0.02348, recLoss: 0.00989, regMean: -1.35870
[0] loss: 0.18917, recLoss: 0.17496, regMean: -1.42072
[100] loss: 0.02319, recLoss: 0.00960, regMean: -1.35907
[0] loss: 0.17182, recLoss: 0.15765, regMean: -1.41747
[100] loss: 0.02094, recLoss: 0.00749, regMean: -1.34496
[0] loss: 0.17878, recLoss: 0.16455, regMean: -1.42235
[100] loss: 0.02021, recLoss: 0.00670, regMean: -1.35109
[0] loss: 0.17033, recLoss: 0.15621, regMean: -1.41206
[100] loss: 0.02158, recLoss: 0.00811, regMean: -1.34724
[0] loss: 0.16644, recLoss: 0.15222, regMean: -1.42189
[100] loss: 0.02007, recLoss: 0.00657, regMean: -1.35052
allRec: (991, 1, 32, 32)
allX: (991, 1, 32, 32)
