In [5]:
import sys
sys.path.append("..")

import torch
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange
from tqdm import tqdm, trange
from torch.nn import functional as F
from utils.download_mnist import mnist_dataloader_test
from assembler import get_config, get_config_ebm, make_energy_model
from utils.config import show 
from torchvision.utils import make_grid

path = !cd .. && pwd
path = path[0]

def plotable(img):
    return rearrange(img, "b c h w -> (b c h) w ").cpu().detach().numpy()

def get_model_config(model_name):
    dataset, model, sampling, task = model_name.split("/")
    name = f"{sampling}/{task}"
    config = get_config(get_config_ebm, dataset, model, name, path=path)
    return config

def reconstruction_error(x_hat, x, reduction="mean"):
    return F.mse_loss(x_hat, x, reduction=reduction)

In [6]:
model_name = "mnist/vae/langevin/inpainting"
config = get_model_config(model_name)

In [7]:
config["operator_params"]["operator"] = "CompressedSensing"
config["operator_params"]["num_measurements"] = 100
config["exp_params"]["batch_size"] = 100

In [8]:
dm = mnist_dataloader_test(config, path=path)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [19]:
from torch import nn, Tensor

class Discriminator(nn.Module):

    def __init__(self, config) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        feature_maps = 64
        image_channels = 1

        self.disc = nn.Sequential(
            self._make_disc_block(image_channels, feature_maps, batch_norm=False),
            self._make_disc_block(feature_maps, feature_maps * 2),
            self._make_disc_block(feature_maps * 2, feature_maps * 4),
            self._make_disc_block(feature_maps * 4, feature_maps * 8),
            self._make_disc_block(feature_maps * 8, 1, kernel_size=2, stride=1, padding=0, last_block=True),
        )

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        if not last_block:
            disc_block = nn.Sequential(
                torch.nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)),
                nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            disc_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
            )

        return disc_block

    def forward(self, x: Tensor) -> Tensor:
        return torch.sigmoid(self.disc(x)/2).view(-1, 1).squeeze(1)
    
    def logit(self, x: Tensor) -> Tensor:
        return self.disc(x).view(-1, 1).squeeze(1)

In [20]:
config['estimator_params']["initalisation"] = "random"
config['estimator_params']['potential'] = "mse"
config['estimator_params']['lambda'] = 1
config["estimator_params"]["noise_factor"] = 0.5
config["estimator_params"]["burn_in"] = 120
config['estimator_params']['initalisation'] = "map_posterior"
config['estimator_params']['num_steps_map_initaliser'] = 100
config['estimator_params']['step_size_map_initaliser'] = 0.1

In [21]:
ebm = make_energy_model(config, path=path)
ebm = ebm.to("cuda")

In [26]:
discriminator = Discriminator(config)
discriminator.to("cuda")

Discriminator(
  (disc): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): Identity()
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.

In [27]:
from torch import optim 

criterion = nn.BCELoss()
optimizer = optim.SGD(discriminator.parameters(), lr=0.01, momentum=0.9)
for real, y in tqdm(dm):
    optimizer.zero_grad()
    real = real.to("cuda")

    real_pred = discriminator(real)
    real_gt = torch.ones_like(real_pred)
    real_loss = criterion(real_pred, real_gt)

    # Train with fake
    fake_pred = discriminator(ebm.model.model.get_samples(100))
    fake_gt = torch.zeros_like(fake_pred)
    fake_loss = criterion(fake_pred, fake_gt)

    disc_loss = real_loss + fake_loss
    disc_loss.backward()
    optimizer.step()

100%|██████████| 100/100 [00:02<00:00, 39.87it/s]


In [28]:
discriminator(ebm.model.model.get_samples(100))

tensor([0.1417, 0.3926, 0.0368, 0.3020, 0.0080, 0.1095, 0.1031, 0.0538, 0.0308,
        0.0041, 0.0348, 0.2504, 0.1530, 0.0821, 0.0271, 0.0609, 0.0253, 0.1048,
        0.0735, 0.0243, 0.1014, 0.2693, 0.0857, 0.0310, 0.0235, 0.0479, 0.1118,
        0.1361, 0.0270, 0.0423, 0.0643, 0.0183, 0.1400, 0.2467, 0.0820, 0.1654,
        0.0231, 0.0880, 0.2092, 0.0033, 0.0810, 0.1615, 0.3496, 0.0727, 0.1876,
        0.0381, 0.2964, 0.0403, 0.1021, 0.1061, 0.0423, 0.2941, 0.1242, 0.0523,
        0.0120, 0.0899, 0.0772, 0.0680, 0.2059, 0.1652, 0.1421, 0.0964, 0.0092,
        0.0828, 0.0895, 0.1778, 0.3277, 0.1969, 0.1697, 0.1660, 0.0059, 0.0613,
        0.1269, 0.2362, 0.0867, 0.0210, 0.1762, 0.0082, 0.0957, 0.0427, 0.0429,
        0.0372, 0.1573, 0.0100, 0.0547, 0.0361, 0.0975, 0.0229, 0.0450, 0.1761,
        0.0260, 0.0714, 0.2770, 0.0647, 0.1170, 0.0757, 0.1751, 0.0325, 0.0396,
        0.2429], device='cuda:0', grad_fn=<SqueezeBackward1>)

In [29]:
discriminator(real)

tensor([0.9938, 0.9948, 0.9992, 0.9979, 0.9995, 0.9987, 0.9979, 0.9966, 0.9955,
        0.9968, 0.9987, 0.9975, 0.9989, 0.9987, 0.9938, 0.9978, 0.9979, 0.9946,
        0.9987, 0.9986, 0.9976, 0.9984, 0.9977, 0.9982, 0.9981, 0.9941, 0.9982,
        0.9992, 0.9990, 0.9991, 0.9814, 0.9973, 0.9990, 0.9991, 0.9935, 0.9978,
        0.9938, 0.9985, 0.9953, 0.9966, 0.9963, 0.9955, 0.9971, 0.9840, 0.9965,
        0.9984, 0.9979, 0.9991, 0.9835, 0.9944, 0.9989, 0.9977, 0.9728, 0.9916,
        0.9947, 0.9981, 0.9970, 0.9977, 0.9992, 0.9969, 0.9986, 0.9973, 0.9909,
        0.9962, 0.9944, 0.9946, 0.9983, 0.9927, 0.9987, 0.9987, 0.9884, 0.9990,
        0.9985, 0.9976, 0.9987, 0.9982, 0.9967, 0.9973, 0.9979, 0.9965, 0.9929,
        0.9944, 0.9816, 0.9702, 0.9977, 0.9930, 0.9983, 0.9976, 0.9916, 0.9948,
        0.9986, 0.9919, 0.9980, 0.9804, 0.9966, 0.9944, 0.9919, 0.9980, 0.9951,
        0.9979], device='cuda:0', grad_fn=<SqueezeBackward1>)

In [26]:
torch.save(discriminator.state_dict(), 'non_gan_discriminator.ckpt')