In [10]:
import os
import torch
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

from model import UNet
from diffusion_discrete import DiscreteDiffusion, generate_betas

In [2]:
# Use the first CUDA device when one is available; otherwise run on the CPU.
gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if gpu else torch.device("cpu")
print("device:", device)

device: cpu


In [3]:
def count_parameters(model):
    """Return the total element count of all trainable parameters of ``model``.

    Parameters whose ``requires_grad`` is False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def imshow(img, w_shape=False):
    """Display a single image tensor with matplotlib using a gray colormap.

    Parameters
    ----------
    img : torch.Tensor
        Image to show. When ``w_shape`` is True, the leading axis is moved to
        the end (``permute(1, 2, 0)``), i.e. a channel-first (C, H, W) tensor
        becomes (H, W, C). Otherwise it is passed to matplotlib unchanged.
    w_shape : bool, default False
        Whether ``img`` has a leading channel axis that must be moved last.
    """
    # .numpy() raises for tensors that require grad or live on a GPU;
    # detaching and moving to CPU first lets model outputs be shown directly.
    img = img.detach().cpu()
    if w_shape:
        img = img.permute(1, 2, 0)
    npimg = img.numpy()
    plt.imshow(npimg, cmap='gray')
    plt.show()

In [4]:
# Training hyper-parameters.
num_epochs = 10
batch_size = 128
lr = 2e-4

# Seed torch's global RNG for reproducibility; this must happen before the
# model is built so its random initialisation is reproducible too.
torch.manual_seed(50)

model = UNet(image_channels=1, model_output='logits')
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [5]:
# Report the trainable-parameter count, then the full module tree.
num_trainable_params = count_parameters(model)
print("Number of model parameters: ", num_trainable_params)
print(model)

Number of model parameters:  168897282
UNet(
  (image_proj): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_emb): TimeEmbedding(
    (lin1): Linear(in_features=64, out_features=256, bias=True)
    (act): Swish()
    (lin2): Linear(in_features=256, out_features=256, bias=True)
  )
  (down): ModuleList(
    (0-1): 2 x DownBlock(
      (res): ResidualBlock(
        (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (act1): Swish()
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (act2): Swish()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (shortcut): Identity()
        (time_emb): Linear(in_features=256, out_features=64, bias=True)
        (time_act): Swish()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (attn): Identity()
    )
    (2): Downsample(
      (conv): Conv2d(64, 64, kernel_size=(3, 

In [6]:
# MNIST digits are 28x28; upsample them to 32x32 so the side length is a power
# of 2 (presumably so repeated downsampling in the UNet divides evenly).
img_size = 32

transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ]
)

trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

In [7]:
# Beta schedule from the project's generate_betas helper ('cosine', 1000 steps),
# then placed on the training device.
betas = generate_betas(
    type='cosine',
    start=0.02,
    stop=1.,
    num_steps=1000,
)
betas = betas.to(device)

In [8]:
# NOTE(review): the UNet was constructed with model_output='logits', while the
# diffusion process is configured with model_output='logistic_pars'. Confirm
# both agree on what the network's output tensor represents.
# NOTE(review): the cell output shows DiscreteDiffusion re-wrapping `betas` via
# torch.tensor(betas, dtype=torch.float64); since `betas` is already a tensor
# here, that is what emits the copy-construct UserWarning.
diffusion = DiscreteDiffusion(
    betas=betas,
    transition_mat_type='uniform',
    num_bits=8,
    transition_bands=None,
    model_prediction='x_start',
    model_output='logistic_pars',
    loss_type='hybrid',
    hybrid_coeff=0.001,
    device=device,
)

  self.betas = betas = torch.tensor(betas, dtype=torch.float64)


In [9]:
model.train()
# Largest pixel value representable with the num_bits=8 given to DiscreteDiffusion.
max_pixel_value = 2 ** 8 - 1
for e in range(1, num_epochs + 1):
    train_loss = 0.0
    train_loss_vals = []  # per-batch loss values for this epoch
    train_prior_bpd = 0.0
    num_batches = 0
    for x, _ in tqdm(train_loader, total=len(train_loader)):
        # ToTensor() yields floats in [0, 1]; casting those straight to long
        # truncates almost every pixel to 0. Rescale to the integer pixel
        # range 0..max_pixel_value first, then cast.
        x = (x * max_pixel_value).round().to(device, dtype=torch.long)
        # framework expects data shape (B, H, W, C)
        x = x.permute(0, 2, 3, 1)

        optimizer.zero_grad()
        # NOTE(review): rng=25 is passed on every step; if training_losses uses it
        # to seed timestep/noise sampling, every batch draws identical randomness.
        loss = diffusion.training_losses(model, x_start=x, rng=25).mean()
        # prior_bpd is only logged, so no autograd graph is needed for it.
        with torch.no_grad():
            prior_bpd = diffusion.prior_bpd(x).mean()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        # .item() keeps the accumulator a plain float rather than a tensor.
        train_prior_bpd += prior_bpd.item()
        train_loss_vals.append(batch_loss)
        num_batches += 1

    # Average over the number of batches actually seen. The previous divisor
    # (the 0-based last batch index) was off by one and could be zero.
    denom = max(num_batches, 1)
    print("\tEpoch,", e, "complete!", "\tLoss: ", train_loss / denom,
          "\tPrior bpd: ", train_prior_bpd / denom)
        

  log_probs2 = F.log_softmax(logits)


old:  tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  .

  0%|          | 0/469 [00:47<?, ?it/s]

new:  tensor([[[[-5.8113e-03, -5.2292e-03, -1.3778e+01,  ..., -1.2163e+01,
           -1.1793e+01, -5.3368e-03],
          [-1.3477e+01, -5.8644e-03, -4.4800e-03,  ..., -1.1723e+01,
           -5.7536e-03, -5.3542e-03],
          [-5.8030e-03, -5.8776e-03, -3.8478e-03,  ..., -5.8398e-03,
           -5.7616e-03, -1.2496e+01],
          ...,
          [-5.8431e-03, -4.8669e-03, -3.8464e-03,  ..., -1.1689e+01,
           -5.8878e-03, -1.3262e+01],
          [-1.1718e+01, -1.1798e+01, -5.9075e-03,  ..., -3.4863e-03,
           -1.3765e+01, -5.6350e-03],
          [-1.3494e+01, -1.1921e+01, -1.2384e+01,  ..., -5.1816e-03,
           -5.6923e-03, -5.0018e-03]]],


        [[[-1.3981e+01, -6.5723e-03, -5.3988e-03,  ..., -5.2158e-03,
           -5.4280e-03, -5.8444e-03],
          [-6.3950e-03, -1.2953e+01, -1.4233e+01,  ..., -5.3330e-03,
           -1.2056e+01, -1.2090e+01],
          [-6.4549e-03, -4.9385e-03, -6.1179e-03,  ..., -4.1715e-03,
           -4.6118e-03, -6.3582e-03],
          ..




SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
