In [11]:
import os
import torch
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# from torch.utils.tensorboard import summary
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

from model import UNet
from diffusion_discrete import DiscreteDiffusion, generate_betas

In [12]:
# Run on the first CUDA device when one is visible; otherwise fall back to CPU.
gpu = torch.cuda.is_available()
if gpu:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device:", device)

device: cpu


In [13]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted; each contributes
    its element count (``numel()``).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def imshow(img, w_shape=False):
    """Display an image tensor with matplotlib using a gray colormap.

    Args:
        img: ``torch.Tensor`` holding the image. It may live on any device and
            may be attached to an autograd graph; it is detached and copied to
            the CPU before being converted to a NumPy array.
        w_shape: If True, the tensor's axes are permuted with ``(1, 2, 0)``
            (axis 0 is moved to the end) before plotting -- presumably to turn
            a channels-first image into channels-last. If False, the tensor is
            plotted as-is.
    """
    # Tensor.numpy() raises for CUDA tensors and for tensors that require
    # grad, so normalise here instead of relying on every caller to do it.
    img = img.detach().cpu()
    if w_shape:
        img = img.permute(1, 2, 0)
    plt.imshow(img.numpy(), cmap='gray')
    plt.show()

In [14]:
# Seed torch's global RNG up front for reproducibility.
torch.manual_seed(50)

# Training hyperparameters.
num_epochs, batch_size, lr = 10, 128, 2e-4

# Build the network, move it to the selected device, and attach Adam to its
# parameters.
model = UNet(image_channels=1, model_output='logistic_pars')
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [15]:
# Report the trainable-parameter count, then dump the module tree.
n_trainable = count_parameters(model)
print("Number of model parameters: ", n_trainable)
print(model)

Number of model parameters:  168897282
UNet(
  (image_proj): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_emb): TimeEmbedding(
    (lin1): Linear(in_features=64, out_features=256, bias=True)
    (act): Swish()
    (lin2): Linear(in_features=256, out_features=256, bias=True)
  )
  (down): ModuleList(
    (0-1): 2 x DownBlock(
      (res): ResidualBlock(
        (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (act1): Swish()
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (act2): Swish()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (shortcut): Identity()
        (time_emb): Linear(in_features=256, out_features=64, bias=True)
        (time_act): Swish()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (attn): Identity()
    )
    (2): Downsample(
      (conv): Conv2d(64, 64, kernel_size=(3, 

In [16]:
# Upscale the MNIST digits from 28 to 32 pixels per side so that the side
# length is a power of 2.
img_size = 32

# Resize the PIL image, then convert it to a tensor. PILToTensor does not
# rescale pixel values, so they stay integer-valued.
transform = transforms.Compose(
    [transforms.Resize(img_size), transforms.PILToTensor()]
)

# Datasets: only the training split requests a download.
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, transform=transform)

# Loaders: shuffle training batches, keep the test order fixed.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [17]:
# Linear schedule of 1000 beta values running from 1e-4 to 0.02, moved to the
# selected device.
betas = generate_betas(
    type='linear',
    start=1.e-4,
    stop=0.02,
    num_steps=1000,
).to(device)

In [18]:
# Discrete diffusion process driven by the beta schedule above. The model is
# asked to predict x_start through logistic parameters and is trained with the
# 'hybrid' loss (coefficient 0.001).
# NOTE(review): num_bits=8 presumably matches the 0..255 pixel range of the
# uint8 inputs -- confirm against diffusion_discrete.
diffusion = DiscreteDiffusion(
    betas=betas,
    transition_mat_type='gaussian',
    num_bits=8,
    transition_bands=None,
    model_prediction='x_start',
    model_output='logistic_pars',
    loss_type='hybrid',
    hybrid_coeff=0.001,
    device=device,
)

  self.betas = betas = torch.tensor(betas, dtype=torch.float64)


In [20]:
# One entry per completed epoch; each entry is the list of per-batch losses
# for that epoch (the plotting cell averages each inner list).
train_losses = []
test_losses = []

# NOTE(review): the same constant `rng=25` is passed to every diffusion call
# below. If DiscreteDiffusion derives its timesteps/noise deterministically
# from this value, every batch sees identical randomness -- confirm against
# diffusion_discrete and pass a fresh value per step if so.

for e in range(1, num_epochs+1):
    # ---- training ----
    model.train()
    train_loss_vals = []
    for x, _ in tqdm(train_loader, total=len(train_loader)):
        x = x.to(device, dtype=torch.int32)
        # framework expects data shape (B, H, W, C)
        x = x.permute(0, 2, 3, 1)

        optimizer.zero_grad()
        loss = diffusion.training_losses(model, x_start=x, rng=25).mean()
        loss.backward()
        optimizer.step()

        # Read the scalar once per batch.
        train_loss_vals.append(loss.item())

    # Average over the number of batches actually processed. Dividing by the
    # last enumerate index (N - 1) would inflate the mean and fail with a
    # single-batch loader.
    train_loss = sum(train_loss_vals) / len(train_loss_vals)

    # ---- evaluation ----
    model.eval()
    test_loss_vals = []
    with torch.no_grad():
        for x_test, _ in test_loader:
            x_test = x_test.to(device, dtype=torch.int32)
            x_test = x_test.permute(0, 2, 3, 1)

            test_batch_loss = diffusion.training_losses(model, x_start=x_test, rng=25).mean()
            test_loss_vals.append(test_batch_loss.item())

            # Optional full bits-per-dim evaluation (disabled):
            # loss_dict = diffusion.calc_bpd_loop(model, x_start=x_test, rng=25)
            # total_bpd = torch.mean(loss_dict['total'], dim=0)
            # prior_bpd = torch.mean(loss_dict['prior'], axis=0)

    test_loss = sum(test_loss_vals) / len(test_loss_vals)

    train_losses.append(train_loss_vals)
    test_losses.append(test_loss_vals)

    # Draw one sample to eyeball progress. Sampling needs no gradients, so
    # disable autograd to avoid building a graph across the sampling steps.
    with torch.no_grad():
        samples = diffusion.p_sample_loop(model_fn=model, shape=(1, img_size, img_size, 1), rng=25)
    imshow(samples[0].detach().cpu())

    print("\tEpoch,", e, "complete!", "\tTrain Loss: ", train_loss,
          "\tTest Loss: ", test_loss)
        

  0%|          | 0/469 [00:00<?, ?it/s]

torch.Size([128, 1, 32, 32])
tensor(255, dtype=torch.uint8)
tensor(0, dtype=torch.uint8)
torch.Size([128, 32, 32, 1])


  0%|          | 0/469 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Plot the mean per-batch loss of every completed epoch. The x-axis is derived
# from the recorded history rather than from num_epochs, so the cell still
# works when training was interrupted before all epochs finished.
fig, ax = plt.subplots()
for label, per_epoch_losses in (('train', train_losses), ('test', test_losses)):
    epochs = np.arange(1, len(per_epoch_losses) + 1)
    epoch_means = [np.mean(ls) for ls in per_epoch_losses]
    ax.plot(epochs, epoch_means, lw=2.5, label=label)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
# ax.set_yscale('log')
ax.set_title('Average training/test loss per epoch')
ax.legend()
plt.show()