In [1]:
from models.models import DiffusionNet, LadderVAE
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

from typing import Tuple
import math

device: cpu


In [2]:
# Pick the first CUDA device when one is visible, otherwise fall back to CPU.
gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if gpu else torch.device("cpu")
print("device:", device)

device: cpu


In [3]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` contribute; frozen tensors
    are skipped. A module with no parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# LadderVAE configuration. These must match the architecture of the checkpoint
# restored from './trained_models/LadderVAE'.
input_dim = 784                        # flattened 28x28 MNIST image
hidden_dims = [512, 256, 128, 64, 32]  # hidden width of each LadderEncoder rung
latent_dims = [64, 32, 16, 8, 4]       # mu/var output size of each rung

# Training hyper-parameters; not used by the inference cells of this notebook.
num_epochs, lr = 200, 1e-3

model_lvae = LadderVAE(input_dim, hidden_dims, latent_dims)
model_lvae = model_lvae.to(device)
print("Number of model parameters: ", count_parameters(model_lvae))
print(model_lvae)

Number of model parameters:  1157560
LadderVAE(
  (encoder): ModuleList(
    (0): LadderEncoder(
      (linear): Linear(in_features=784, out_features=512, bias=True)
      (batchnorm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (mu): Linear(in_features=512, out_features=64, bias=True)
      (var): Linear(in_features=512, out_features=64, bias=True)
    )
    (1): LadderEncoder(
      (linear): Linear(in_features=512, out_features=256, bias=True)
      (batchnorm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (mu): Linear(in_features=256, out_features=32, bias=True)
      (var): Linear(in_features=256, out_features=32, bias=True)
    )
    (2): LadderEncoder(
      (linear): Linear(in_features=256, out_features=128, bias=True)
      (batchnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (mu): Linear(in_features=128, out_features=16, bias=True)
      (var): L

In [4]:
# DiffusionNet configuration: n_layers hidden widths, all equal to hidden_dim.
# NOTE: this rebinds `hidden_dims`, `num_epochs` and `lr` that the LadderVAE
# cell set earlier; re-run that cell before relying on its values again.
n_layers = 8
hidden_dim = 256
hidden_dims = [hidden_dim] * n_layers

# Training hyper-parameters; not used by the inference cells of this notebook.
num_epochs, lr = 200, 5e-5

model_diff = DiffusionNet(hidden_dims=hidden_dims)
model_diff = model_diff.to(device)
print("Number of model parameters: ", count_parameters(model_diff))

Number of model parameters:  4870913


In [5]:
# Restore the trained weights for both models.
# - map_location remaps the saved storages onto `device`, so checkpoints written
#   on a GPU also load on a CPU-only machine.
# - weights_only=True restricts unpickling to tensors and plain containers, so a
#   tampered checkpoint file cannot run arbitrary code. The files hold plain
#   state dicts, which this mode supports (requires PyTorch >= 1.13).
model_diff.load_state_dict(torch.load('./trained_models/diffusion_model', map_location=device, weights_only=True))
model_lvae.load_state_dict(torch.load('./trained_models/LadderVAE', map_location=device, weights_only=True))

<All keys matched successfully>

In [6]:
# MNIST train/test splits (downloaded into `root` on first use) plus an output
# directory for results.
root = './data'
os.makedirs(root, exist_ok=True)
os.makedirs('results', exist_ok=True)

batch_size = 128

# `device.type` is "cuda" for any CUDA device; the ordinal lives in
# `device.index`. Comparing against "cuda:0" was always False, so the
# worker/pinned-memory settings were never applied on GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if device.type == "cuda" else {}

trainset = datasets.MNIST(root=root, train=True,
                          download=True, transform=transforms.ToTensor())
train_loader = DataLoader(trainset, batch_size=batch_size,
                          shuffle=True, **kwargs)

testset = datasets.MNIST(root=root, train=False,
                         download=True, transform=transforms.ToTensor())
test_loader = DataLoader(testset, batch_size=batch_size,
                         shuffle=False, **kwargs)

In [7]:
model_lvae.eval()
model_diff.eval()

# train_loader shuffles, so fix the seed to make the chosen image (and the
# diffusion noise below) reproducible across runs.
torch.manual_seed(0)
x_batch, _ = next(iter(train_loader))  # first batch only

idx = 30
img = x_batch[idx].view(28, 28)

# Pure inference: no autograd graph is needed (and none should be kept alive
# across the T diffusion steps).
with torch.no_grad():
    # LadderVAE reconstruction of the selected image, fed as a (1, 784) row.
    model_lvae_img, _ = model_lvae(x_batch[idx].view(-1, 784).to(device))

    # Forward noising trajectory:
    # noised_imgs[t + 1] = zero_one(add_noise(noised_imgs[t], t)).
    noised_imgs = [img]
    for t in range(model_diff.T):
        # One timestep entry per leading row of `img` (28 here), all equal to t.
        # NOTE(review): confirm add_noise expects this shape for an unbatched image.
        ts = torch.full((img.shape[0],), t, dtype=torch.long, device=device)
        noisy = model_diff.add_noise(noised_imgs[t].to(device), ts)
        noised_imgs.append(model_diff.zero_one(noisy))

    # Reverse process started from the fully noised image; keep the result
    # instead of discarding it.
    # TODO: plot once the return type/shape of `sample` is confirmed.
    model_diff_sample = model_diff.sample(1, img=noised_imgs[-1])

fig, axes = plt.subplots(1, 2)
axes[0].imshow(img.detach().cpu().numpy(), cmap='gray')
axes[0].set_title('Ground-truth image')
axes[1].imshow(model_lvae_img.view(28, 28).detach().cpu().numpy(), cmap='gray')
axes[1].set_title('LadderVAE reconstruction')
plt.show()

: 