In [1]:
import sys
sys.path.append("..")

import torch
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange
from tqdm import tqdm, trange
from torch.nn import functional as F
from utils.download_mnist import mnist_dataloader_test
from assembler import get_config, get_config_ebm, make_energy_model
from utils.config import show 

# Parent of the current working directory; the project helpers below take it as
# `path`. Computed in pure Python instead of an IPython `!cd .. && pwd` shell
# escape: that needs a POSIX shell, and if the command fails `path[0]` silently
# becomes the error text instead of a directory.
import os

path = os.path.dirname(os.getcwd())

def plotable(img):
    """Turn a 4-D ``(b, c, h, w)`` tensor into a single 2-D NumPy image for ``imshow``.

    Batch, channel and height axes are merged into the row axis, so every
    image/channel is stacked vertically. The result is detached from autograd
    and copied to host memory.
    """
    stacked = rearrange(img, "b c h w -> (b c h) w ")
    host = stacked.detach().cpu()
    return host.numpy()

def get_model_config(model_name):
    """Load the EBM config for a ``"<dataset>/<model>/<sampling>/<task>"`` name.

    The first two segments select dataset and model. The last two are re-joined
    as the ``"<sampling>/<task>"`` config name. A name without exactly four
    ``/``-separated segments raises ValueError from the tuple unpacking.
    """
    dataset, model, sampling, task = model_name.split("/")
    return get_config(
        get_config_ebm, dataset, model, "/".join((sampling, task)), path=path
    )

def reconstruction_error(x_hat, x, reduction="mean"):
    """Mean-squared error between the estimate ``x_hat`` and the target ``x``.

    ``reduction`` is passed straight through to
    ``torch.nn.functional.mse_loss`` ("mean", "sum" or "none").
    """
    return F.mse_loss(input=x_hat, target=x, reduction=reduction)

def experiment(config, x):
    """Build the energy model from ``config``, estimate ``x`` back from its
    operator output, and show the original next to the estimate.

    Reads the notebook-global ``path``. Returns nothing; the figure is the result.
    """
    model = make_energy_model(config, path=path)

    # Apply the model's operator to x, then let the model estimate x from that output.
    measurement = model.operator(x)
    estimate = model(measurement)

    # Left panel: ground truth. Right panel: estimate.
    _, axes = plt.subplots(nrows=1, ncols=2)
    for ax, label, image in zip(axes, ("original", "estimation"), (x, estimate)):
        ax.set_title(label)
        ax.imshow(plotable(image))
    plt.show()

In [2]:
# Base model: VAE prior with Langevin sampling on the inpainting2 task.
model_name = "mnist/vae/langevin/inpainting2"
config = get_model_config(model_name=model_name)
dm = mnist_dataloader_test(config, path=path)

# Only the first test batch is used. `y` is never read again in this notebook
# (presumably the digit labels -- TODO confirm).
gen = iter(dm)
x, y = next(gen)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [34]:
# Measurement operator overrides (CenterOcclude operator).
for param, setting in {"operator": "CenterOcclude", "num_measurements": 200}.items():
    config["operator_params"][param] = setting

# Estimator overrides: MAP-posterior initialisation plus sampling knobs.
# NOTE(review): "initalisation" / "*_initaliser" look misspelt, but they are
# dict keys read by project code -- do not "fix" them without checking there.
for param, setting in {
    "initalisation": "map_posterior",
    "num_steps_map_initaliser": 1,
    "step_size_map_initaliser": 0.1,
    "noise_factor": 0.5,
    "lambda": 1,
}.items():
    config["estimator_params"][param] = setting

In [37]:
# Original vs. estimate for the base model with the overrides above.
experiment(config=config, x=x)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [38]:
# Repeat the steps inside `experiment` at notebook scope, so that `ebm`,
# `x_tilde` and `x_hat` can be inspected and plotted in the cells below.
ebm = make_energy_model(config, path=path)
x_tilde = ebm.operator(x)  # operator output for the test batch
x_hat = ebm(x_tilde)       # model's estimate of x from that output

In [39]:
x_tilde.size()

torch.Size([1, 1, 32, 32])

In [40]:
from numpy import exp,arange, meshgrid
from pylab import meshgrid,cm,imshow,contour,clabel,colorbar,axis,title,show

# Energy surface of the base model, evaluated one (x, y) point at a time.
@np.vectorize
def z_func(x, y):
    """Return the base model's energy at the 2-D input point ``(x, y)`` as a float.

    NOTE: looks up the notebook globals ``ebm`` and ``x_tilde`` at *call* time,
    so rebinding either one changes what this evaluates. A later cell does
    rebind ``x_tilde``.
    """
    point = torch.tensor([[x, y]], dtype=torch.float32)
    energy = ebm.energy_fn(x_tilde)(point)
    # .item() yields a plain Python float whether the energy comes back as a 0-d
    # or a 1-element tensor. np.vectorize then never has to coerce an ndim>0
    # array to a scalar, which is deprecated since NumPy 1.25.
    return energy.item()

In [41]:
# Evaluate the base-model energy on a square grid over [-5, 5) with step 0.1.
x_, y_ = arange(-5.0, 5.0, 0.1), arange(-5.0, 5.0, 0.1)
X, Y = meshgrid(x_, y_)  # "xy" indexing: X varies along columns, Y along rows
Z = z_func(X, Y)

In [42]:
# Interactive (ipympl) matplotlib backend, so the 3-D surfaces below can be rotated.
# NOTE(review): on a fresh kernel, figures from cells above this one render with
# the previous backend; consider moving this magic to the setup cell.
%matplotlib widget

In [43]:
# 3-D surface of the base-model energy Z over the (X, Y) grid.
# NOTE(review): the title says "Importance Weight", but Z comes from
# `ebm.energy_fn` -- confirm the intended label.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(1, 1, 1, projection="3d")
ax.plot_surface(X, Y, Z)
ax.set_title("Importance Weight")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [50]:
# Switch the estimator to the discriminator-weighted potential, with a DCGAN
# discriminator base model and its fine-tuning settings.
for param, setting in {
    "potential": "discriminator_weighted",
    "lambda": 1,
    "lambda_score": 0.5,
    "discriminator_base_model": "mnist/gan/dcgan",
    "num_steps_finetune": 100,
    "num_samples_for_finetune": 100,
}.items():
    config["estimator_params"][param] = setting

In [51]:
# Discriminator-weighted model: build it and fine-tune its discriminator, then
# run the same operator -> estimate steps as for the base model.
ebm2 = make_energy_model(config, path=path)
ebm2.finetune_discrimiantor()  # spelling kept as-is; presumably the project's method name
x_tilde = ebm2.operator(x)     # NOTE: rebinds the global `x_tilde` that z_func also reads
x_hat = ebm2(x_tilde)

In [52]:
from numpy import exp,arange, meshgrid
from pylab import meshgrid,cm,imshow,contour,clabel,colorbar,axis,title,show

# Energy surface of the discriminator-weighted model, one (x, y) point at a time.
@np.vectorize
def z2_func(x, y):
    """Return ``ebm2``'s energy at the 2-D input point ``(x, y)`` as a float.

    NOTE: looks up the notebook globals ``ebm2`` and ``x_tilde`` at call time.
    """
    point = torch.tensor([[x, y]], dtype=torch.float32)
    energy = ebm2.energy_fn(x_tilde)(point)
    # .item() yields a plain Python float whether the energy comes back as a 0-d
    # or a 1-element tensor. np.vectorize then never has to coerce an ndim>0
    # array to a scalar, which is deprecated since NumPy 1.25.
    return energy.item()

In [53]:
# Evaluate the discriminator-weighted energy on its own grid over [-5, 5), step 0.1.
x2_ = arange(-5.0, 5.0, 0.1)
y2_ = arange(-5.0, 5.0, 0.1)
# Build the grid from x2_/y2_ (previously x_/y_ from the base-model cell were
# used by mistake, so edits to this cell's range had no effect).
X2, Y2 = meshgrid(x2_, y2_)
Z2 = z2_func(X2, Y2)

In [54]:
# 3-D surface of the discriminator-weighted energy Z2 over the (X2, Y2) grid.
# NOTE(review): same "Importance Weight" title as the base-model plot, although
# Z2 comes from `ebm2.energy_fn` -- confirm the intended label.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(1, 1, 1, projection="3d")
ax.plot_surface(X2, Y2, Z2)
ax.set_title("Importance Weight")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [55]:
# Side-by-side heat maps of -energy on the (x, y) grid: base model vs. the
# discriminator-weighted model. imshow defaults to origin="upper", which would
# put the smallest y at the top and label the axes with pixel indices.
# origin="lower" plus an extent taken from each panel's own grid keeps the
# orientation of the 3-D surfaces above and labels the axes in grid coordinates.
fig, axs = plt.subplots(nrows=1, ncols=2)
axs[0].set_title("original")
axs[0].imshow(-Z, origin="lower", extent=(X.min(), X.max(), Y.min(), Y.max()))
axs[1].set_title("discriminator weights")
axs[1].imshow(-Z2, origin="lower", extent=(X2.min(), X2.max(), Y2.min(), Y2.max()))
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …