# Utilities

In [1]:
import torch
from torchvision import datasets, transforms

# Images are only converted to tensors; no further normalisation is applied.
tensor_transform = transforms.Compose([transforms.ToTensor()])

batch_size = 512

# The MNIST train and test splits share the same root directory and transform,
# and are downloaded on first use.
# NOTE(review): the dataset root is an absolute, machine-specific path.
train_dataset, test_dataset = (
    datasets.MNIST(root="/home/zhh/data",
                   train=is_train,
                   download=True,
                   transform=tensor_transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


In [2]:
class Avger(list):
    """List of numbers whose string form is their mean.

    ``str(avger)`` is the arithmetic mean of the stored values formatted to
    four decimal places, or ``'N/A'`` while the list is empty.
    """

    def __str__(self):
        if not self:
            return 'N/A'
        mean = sum(self) / len(self)
        return f'{mean:.4f}'
    
from tqdm.notebook import tqdm,trange
import torchvision.utils as vutils
import torch.nn as nn
import torch.nn.functional as F

## DDPM Hyperparameters

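Noise schedule: a cosine schedule for the cumulative signal fractions (`alpha_bars`), from which the per-step `alphas` and `sigmas` are derived.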

In [3]:
# cosine schedule
T = 1000  # number of diffusion steps

# Evenly spaced angles that stop 0.1 rad short of both 0 and pi.
angles = torch.linspace(0.1, torch.pi - 0.1, T)

# Cumulative signal fraction: alpha_bar_t = (1 + cos(angle_t)) / 2.
alpha_bars = (torch.cos(angles) + 1) / 2

# Per-step factor alpha_t = alpha_bar_t / alpha_bar_{t-1}. The first entry is
# fixed to 1, which also makes sigmas[0] == 0.
alphas = torch.cat([
    torch.ones(1, dtype=torch.float),
    alpha_bars[1:] / alpha_bars[:-1],
])
sigmas = (1 - alphas).sqrt()

# Display every 100th cumulative value as a quick check of the schedule.
alpha_bars[::100]

tensor([0.9975, 0.9616, 0.8860, 0.7771, 0.6444, 0.4993, 0.3542, 0.2217, 0.1131,
        0.0378])

Model

In [4]:
# from context_unet import ModelCls
from vanilla_unet import UNet as ModelCls

# Full DDPM (sanity check)

In [5]:
# NOTE(review): always raises AssertionError; presumably a deliberate stop so
# that "Run All" halts before the Full DDPM section — confirm.
assert False

AssertionError: 

In [None]:
def prepare_batch_base(x):
    """Build one noise-prediction training batch from clean images ``x``.

    A timestep is drawn uniformly from ``[0, T)`` for every sample, the images
    are mixed with Gaussian noise according to ``alpha_bars`` at that
    timestep, and the noise itself is the regression target.

    Returns:
        ``(noisy_images, timesteps, noise)``, all moved to CUDA, with
        ``timesteps`` flattened to 1-D.
    """
    n_samples = x.shape[0]
    # Shape (n_samples, 1, 1, 1) so the looked-up schedule values broadcast
    # against the images (assumes x is 4-D, as the MNIST loader yields).
    ts = torch.randint(0, T, (n_samples, 1, 1, 1))
    noise = torch.randn_like(x)

    signal_scale = torch.sqrt(alpha_bars[ts])
    noise_scale = torch.sqrt(1 - alpha_bars[ts])
    inputs = signal_scale * x + noise_scale * noise

    return inputs.cuda(), ts.reshape(-1).cuda(), noise.cuda()

In [None]:
import os
import torchvision.utils as vutils
from tqdm.notebook import trange

# Output directories for sample grids and checkpoints; created if missing.
for _out_dir in ('samples', 'checkpoints'):
    os.makedirs(_out_dir, exist_ok=True)

@torch.no_grad()
def sample(model, num=64, ep=0):
    """Generate ``num`` images by running all ``T`` reverse steps from noise.

    Starting from standard Gaussian noise of shape ``(num, 1, 28, 28)``, each
    step ``t = T-1 .. 0`` applies
    ``x <- (x - w2 * model(x, t)) / sqrt(alphas[t]) + sigmas[t] * z`` with
    ``w2 = (1 - alphas[t]) / sqrt(1 - alpha_bars[t])`` and fresh noise ``z``.

    The final batch is written as an 8-column grid to
    ``samples/{ep}_sample.png`` and returned (on CUDA).
    """
    x = torch.randn(num, 1, 28, 28).cuda()
    for t in trange(T - 1, -1, -1, desc=f'Epoch {ep} sampling'):
        # Same timestep for every image in the batch.
        t_batch = torch.tensor([t]).cuda().repeat(x.shape[0],)
        eps_pred = model(x, t_batch)

        inv_sqrt_alpha = 1 / torch.sqrt(alphas[t]).cuda()
        eps_coef = (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]).cuda()
        step_noise = sigmas[t].cuda().reshape(-1, 1, 1, 1) * torch.randn_like(x)

        x = inv_sqrt_alpha * (x - eps_coef * eps_pred) + step_noise

    # make grid
    vutils.save_image(vutils.make_grid(x, nrow=8), f'samples/{ep}_sample.png')
    return x

In [None]:
import torch.nn.functional as F
from tqdm.notebook import tqdm
def train(model, epochs, sample_ep=4):
    """Train ``model`` to predict the added noise, with periodic sampling.

    Uses Adam (lr 1e-4) and an MSE loss between the model output and the noise
    returned by ``prepare_batch_base``. After epoch 0, after every
    ``sample_ep``-th epoch and after the last epoch, the weights are saved to
    ``checkpoints/{epoch}.pth`` and a sample grid is generated via ``sample``.
    """
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    final_epoch = epochs - 1

    for epoch in range(epochs):
        model.train()
        train_losses = Avger()
        with tqdm(train_loader) as bar:
            for images, _ in bar:
                noisy, steps, noise = prepare_batch_base(images)
                loss = F.mse_loss(model(noisy, steps), noise)

                opt.zero_grad()
                loss.backward()
                opt.step()

                # The bar shows the running mean loss of the current epoch.
                train_losses.append(loss.item())
                bar.set_description(f'Epoch {epoch} loss {train_losses}')

        at_checkpoint = epoch == 0 or (epoch + 1) % sample_ep == 0 or epoch == final_epoch
        if not at_checkpoint:
            continue

        # save model
        torch.save(model.state_dict(), f'checkpoints/{epoch}.pth')
        model.eval()
        sample(model, num=64, ep=epoch)
        print(f'Epoch {epoch} sample saved')

## run

In [None]:
# Build the network on the GPU, report its parameter count, then train the
# full-range DDPM for 50 epochs.
model = ModelCls().cuda()
print('model params:', sum(p.numel() for p in model.parameters()))
train(model, 50)

model params: 2282673


  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 0 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 0 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 3 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 7 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 11 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 11 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 15 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 15 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 19 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 19 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 23 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 23 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 27 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 27 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 31 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 31 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 35 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 35 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 39 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 39 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 43 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 43 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 47 sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 47 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

# Half DDPM

In [None]:
# NOTE(review): always raises AssertionError; presumably a deliberate stop so
# that "Run All" halts before the Half DDPM section — confirm.
assert False

In [None]:
def prepare_batch_halfddpm(x):
    """Build a noise-prediction batch using only the first half of the schedule.

    Timesteps are drawn uniformly from ``[0, T // 2)``. The images are mixed
    with Gaussian noise according to ``alpha_bars`` at those timesteps and the
    noise is the regression target.

    Returns:
        ``(noisy_images, timesteps, noise)``, all moved to CUDA, with
        ``timesteps`` flattened to 1-D.
    """
    n_samples = x.shape[0]
    # Shape (n_samples, 1, 1, 1) so the looked-up schedule values broadcast
    # against the images (assumes x is 4-D, as the MNIST loader yields).
    ts = torch.randint(0, T // 2, (n_samples, 1, 1, 1))
    noise = torch.randn_like(x)

    signal_scale = torch.sqrt(alpha_bars[ts])
    noise_scale = torch.sqrt(1 - alpha_bars[ts])
    inputs = signal_scale * x + noise_scale * noise

    return inputs.cuda(), ts.reshape(-1).cuda(), noise.cuda()

In [None]:
import os
import torchvision.utils as vutils
from tqdm.notebook import trange

# Output directories for sample grids and checkpoints; created if missing.
for _out_dir in ('samples', 'checkpoints'):
    os.makedirs(_out_dir, exist_ok=True)

@torch.no_grad()
def sample_halfddpm(model, num=64, ep=0):
    """Denoise half-noised test images with the reverse steps ``T//2-1 .. 0``.

    The first ``num`` images of the first test batch are noised to the
    schedule level ``alpha_bars[T // 2]``; that starting point is saved to
    ``samples/{ep}_init.png``. The reverse update used by the full sampler is
    then applied for the first half of the schedule, and the result is saved
    to ``samples/{ep}_sample.png`` and returned (on CUDA).

    NOTE(review): ``samples/{ep}_sample.png`` is the same file name the
    full-DDPM ``sample`` writes, so runs overwrite each other's grids.
    NOTE(review): the start is noised with index ``T // 2`` while the first
    reverse step uses ``t = T // 2 - 1`` — confirm the intended indexing.
    """
    start = T // 2
    first_batch = next(iter(test_loader))
    data = first_batch[0][:num].cuda()

    # Forward-noise the clean images to the mid-schedule level.
    keep = torch.sqrt(alpha_bars[start])
    spread = torch.sqrt(1 - alpha_bars[start])
    x = data * keep + torch.randn_like(data) * spread
    vutils.save_image(vutils.make_grid(x, nrow=8), f'samples/{ep}_init.png')

    for t in trange(start - 1, -1, -1, desc=f'Epoch {ep} sampling'):
        # Same timestep for every image in the batch.
        t_batch = torch.tensor([t]).cuda().repeat(x.shape[0],)
        eps_pred = model(x, t_batch)

        inv_sqrt_alpha = 1 / torch.sqrt(alphas[t]).cuda()
        eps_coef = (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]).cuda()
        step_noise = sigmas[t].cuda().reshape(-1, 1, 1, 1) * torch.randn_like(x)

        x = inv_sqrt_alpha * (x - eps_coef * eps_pred) + step_noise

    # make grid
    vutils.save_image(vutils.make_grid(x, nrow=8), f'samples/{ep}_sample.png')
    return x

In [None]:
import torch.nn.functional as F
from tqdm.notebook import tqdm
def train_halfddpm(model, epochs, sample_ep=4):
    """Train ``model`` on the first half of the schedule, with periodic sampling.

    Uses Adam (lr 1e-4) and an MSE loss between the model output and the noise
    returned by ``prepare_batch_halfddpm``. After epoch 0, after every
    ``sample_ep``-th epoch and after the last epoch, the weights are saved to
    ``checkpoints/{epoch}.pth`` and ``sample_halfddpm`` is run.

    NOTE(review): the checkpoint file name is the same one ``train`` uses.
    """
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    final_epoch = epochs - 1

    for epoch in range(epochs):
        model.train()
        train_losses = Avger()
        with tqdm(train_loader) as bar:
            for images, _ in bar:
                noisy, steps, noise = prepare_batch_halfddpm(images)
                loss = F.mse_loss(model(noisy, steps), noise)

                opt.zero_grad()
                loss.backward()
                opt.step()

                # The bar shows the running mean loss of the current epoch.
                train_losses.append(loss.item())
                bar.set_description(f'Epoch {epoch} loss {train_losses}')

        at_checkpoint = epoch == 0 or (epoch + 1) % sample_ep == 0 or epoch == final_epoch
        if not at_checkpoint:
            continue

        # save model
        torch.save(model.state_dict(), f'checkpoints/{epoch}.pth')
        model.eval()
        sample_halfddpm(model, num=64, ep=epoch)
        print(f'Epoch {epoch} sample saved')

## run

In [None]:
# Build a fresh network on the GPU, report its parameter count, then train the
# half-schedule DDPM for 50 epochs. This rebinds the global `model`.
model = ModelCls().cuda()
print('model params:', sum(p.numel() for p in model.parameters()))
train_halfddpm(model, 50)

model params: 2282673


  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 0 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 0 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 7 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 11 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 11 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 15 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 15 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 19 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 19 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 23 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 23 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 27 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 27 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 31 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 31 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 35 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 35 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 39 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 39 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 43 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 43 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 47 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 47 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 49 sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 49 sample saved


# PixelCNN Sanity Check

In [None]:
# NOTE(review): always raises AssertionError; presumably a deliberate stop so
# that "Run All" halts before the PixelCNN sanity check — confirm.
assert False

In [None]:
from pixelcnn import PixelCNNDiscrete

In [None]:
from tqdm.notebook import trange
import torchvision.utils as vutils

@torch.no_grad()
def sample_pixelcnn_discrete(model, num=64, ep=0):
    """Sample ``num`` 28x28 images pixel by pixel from a categorical PixelCNN.

    Pixels are filled in raster order. For each position the model is run on
    the current canvas, a class index is drawn from the softmax over the
    channel dimension at that position, and ``index / model.classes`` is
    written into the canvas.

    The result is saved as an 8-column grid to ``samples/{ep}_pixelcnn.png``
    and returned (on CUDA).
    """
    x = torch.zeros(num, 1, 28, 28).cuda()
    for row in trange(28, desc=f'Epoch {ep} sampling'):
        for col in range(28):
            # The whole canvas is re-evaluated; only this position is read.
            pixel_logits = model(x)[:, :, row, col]
            pixel_probs = pixel_logits.softmax(dim=1)
            drawn = torch.multinomial(pixel_probs, 1)
            x[:, :, row, col] = drawn.float() / model.classes

    vutils.save_image(vutils.make_grid(x, nrow=8), f'samples/{ep}_pixelcnn.png')
    return x

In [None]:
def train_pixelCNN_sanity(model, epochs, sample_ep=4):
    """Train a categorical PixelCNN on clean MNIST images.

    Each pixel is clamped to ``[0.001, 0.999]`` and binned into
    ``model.classes`` integer classes. The model output is trained with
    cross-entropy against those bins (Adam, lr 1e-3). After epoch 0, after
    every ``sample_ep``-th epoch and after the last epoch, the weights are
    saved to ``checkpoints/{epoch}_pixelcnn.pth`` and a sample grid is drawn.
    """
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    final_epoch = epochs - 1

    for epoch in range(epochs):
        model.train()
        train_losses = Avger()
        with tqdm(train_loader) as bar:
            for batch, _ in bar:
                batch = batch.cuda()
                # Clamping keeps the bin index inside [0, classes - 1].
                bins = (batch.clamp(0.001, 0.999) * model.classes).floor().long()

                logits = model(batch)
                # Move classes to the last axis and flatten to one row per pixel.
                flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, model.classes)
                flat_bins = bins.squeeze(1).reshape(-1)
                loss = F.cross_entropy(flat_logits, flat_bins)

                opt.zero_grad()
                loss.backward()
                opt.step()

                train_losses.append(loss.item())
                bar.set_description(f'Epoch {epoch} loss {train_losses}')

        at_checkpoint = epoch == 0 or (epoch + 1) % sample_ep == 0 or epoch == final_epoch
        if not at_checkpoint:
            continue

        # save model
        torch.save(model.state_dict(), f'checkpoints/{epoch}_pixelcnn.pth')
        model.eval()
        sample_pixelcnn_discrete(model, num=64, ep=epoch)
        print(f'Epoch {epoch} sample saved')

In [None]:
# Build the discrete PixelCNN on the GPU (256 is presumably the number of
# intensity classes — confirm against pixelcnn.py), report its parameter
# count, then run the clean-data sanity training for 30 epochs.
pixelcnn_model = PixelCNNDiscrete(256).cuda()
print('model params:', sum(p.numel() for p in pixelcnn_model.parameters()))
train_pixelCNN_sanity(pixelcnn_model, 30)

model params: 2519232


  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 0 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 0 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 3 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 7 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 11 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 11 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 15 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 15 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 19 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 19 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 23 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 23 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 27 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 27 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

KeyboardInterrupt: 

# PixelCNN Job

In [None]:
def get_mixture(data):
    """Noise ``data`` to the schedule level at index ``T // 2``; return it on CUDA.

    Computes ``sqrt(alpha_bar) * data + sqrt(1 - alpha_bar) * z`` with
    ``alpha_bar = alpha_bars[T // 2]`` and fresh standard Gaussian noise ``z``.
    """
    mid = T // 2
    signal_part = data * torch.sqrt(alpha_bars[mid])
    noise_part = torch.randn_like(data) * torch.sqrt(1 - alpha_bars[mid])
    return (signal_part + noise_part).cuda()

In [None]:
def train_pixelCNN(model, epochs, sample_ep=4):
    """Train a categorical PixelCNN on half-noised MNIST images.

    Each batch is passed through ``get_mixture``. The noised pixels are
    clamped to ``[0.001, 0.999]`` and binned into ``model.classes`` integer
    classes, and the model (fed the unclamped noised batch) is trained with
    cross-entropy against those bins (Adam, lr 1e-3). After epoch 0, after
    every ``sample_ep``-th epoch and after the last epoch, the weights are
    saved to ``checkpoints/{epoch}_pixelcnn.pth`` and a sample grid is drawn.

    NOTE(review): the checkpoint file name is the same one
    ``train_pixelCNN_sanity`` uses.
    """
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    final_epoch = epochs - 1

    for epoch in range(epochs):
        model.train()
        train_losses = Avger()
        with tqdm(train_loader) as bar:
            for batch, _ in bar:
                noised = get_mixture(batch)
                # Clamping keeps the bin index inside [0, classes - 1].
                bins = (noised.clamp(0.001, 0.999) * model.classes).floor().long()

                logits = model(noised)
                # Move classes to the last axis and flatten to one row per pixel.
                flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, model.classes)
                flat_bins = bins.squeeze(1).reshape(-1)
                loss = F.cross_entropy(flat_logits, flat_bins)

                opt.zero_grad()
                loss.backward()
                opt.step()

                train_losses.append(loss.item())
                bar.set_description(f'Epoch {epoch} loss {train_losses}')

        at_checkpoint = epoch == 0 or (epoch + 1) % sample_ep == 0 or epoch == final_epoch
        if not at_checkpoint:
            continue

        # save model
        torch.save(model.state_dict(), f'checkpoints/{epoch}_pixelcnn.pth')
        model.eval()
        sample_pixelcnn_discrete(model, num=64, ep=epoch)
        print(f'Epoch {epoch} sample saved')

In [None]:
# Build a fresh discrete PixelCNN on the GPU (256 is presumably the number of
# intensity classes — confirm against pixelcnn.py), report its parameter
# count, then train it on half-noised images for 30 epochs.
pixelcnn_model = PixelCNNDiscrete(256).cuda()
print('model params:', sum(p.numel() for p in pixelcnn_model.parameters()))
train_pixelCNN(pixelcnn_model, 30)

model params: 2519232


  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 0 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 0 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 3 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 7 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 11 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 11 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 15 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 15 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Continuous PixelCNN Sanity

In [None]:
# NOTE(review): always raises AssertionError; presumably a deliberate stop so
# that "Run All" halts before the continuous PixelCNN sanity check — confirm.
assert False

In [6]:
from pixelcnn import PixelCNNContinous

In [12]:
from tqdm.notebook import trange
import torchvision.utils as vutils

@torch.no_grad()
def sample_pixelcnn_continous(model, num=64,ep=0):
    """Sample ``num`` 28x28 images pixel by pixel from a Gaussian PixelCNN.

    Pixels are filled in raster order. For each position the model is run on
    the current canvas and the pixel is drawn from a normal distribution
    centred on the predicted mean, then clamped to ``[0, 1]``.

    The model output may be either

    * a single tensor of means, in which case the fixed log-variance ``-4``
      is used (the original behaviour), or
    * a tuple/list ``(mu, logvar)`` as read by ``train_pixelCNN_continous``,
      in which case the predicted log-variance sets the standard deviation.
      Previously such an output failed when indexed like a tensor.

    The result is saved as an 8-column grid to
    ``samples/{ep}_pixelcnn_c.png`` and returned (on CUDA).
    """
    x = torch.zeros(num, 1, 28, 28).cuda()
    # std = exp(logvar / 2) for the fixed log-variance of -4; loop-invariant.
    fixed_std = torch.exp(torch.tensor(-4.0/2).cuda())
    for i in trange(28,desc=f'Epoch {ep} sampling'):
        for j in range(28):
            out = model(x)
            if isinstance(out, (tuple, list)):
                mu = out[0]
                logvar = out[1] if len(out) > 1 else None
            else:
                mu, logvar = out, None

            if logvar is None:
                std = fixed_std
            else:
                std = torch.exp(logvar[:, :, i, j] / 2)
            x[:, :, i, j] = torch.normal(mu[:, :, i, j], std).clamp(0,1)
    grid = vutils.make_grid(x, nrow=8)
    vutils.save_image(grid, f'samples/{ep}_pixelcnn_c.png')
    return x

In [15]:
def gaussian_log_pdf(x, mu, logvar=torch.tensor(-4.0)):
    """Element-wise log-density of ``x`` under a normal with mean ``mu``.

    The variance is ``exp(logvar)``; ``logvar`` defaults to a fixed ``-4``.
    Returns ``-0.5 * (log(2*pi) + logvar + (x - mu)**2 / exp(logvar))``.
    """
    log_two_pi = torch.log(2 * torch.tensor(torch.pi))
    scaled_sq_err = (x - mu) ** 2 / torch.exp(logvar)
    return -0.5 * (log_two_pi + logvar + scaled_sq_err)

def train_pixelCNN_continous_sanity(model,epochs, sample_ep=4):
    """Train a Gaussian PixelCNN on clean MNIST images with a fixed variance.

    The loss is the mean negative log-likelihood of each pixel under a normal
    centred on the model's predicted mean, using the default log-variance of
    ``gaussian_log_pdf``. Adam with a constant learning rate of 1e-3 is used.
    After epoch 0, after every ``sample_ep``-th epoch and after the last
    epoch, the weights are saved to ``checkpoints/{epoch}_pixelcnn_c.pth``
    and ``sample_pixelcnn_continous`` is run.

    The model may return the mean tensor directly or a tuple/list whose first
    element is the mean. The previous code always took ``pred[0]``; for a
    plain tensor that selects the first batch element and broadcasts it
    against every target, so ``[0]`` is now applied only to tuple/list output.
    """
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        model.train()
        train_losses = Avger()
        with tqdm(train_loader) as bar:
            for x,_ in bar:
                x = x.cuda()
                pred = model(x)
                # Take the mean without slicing the batch axis of a tensor.
                mu = pred[0] if isinstance(pred, (tuple, list)) else pred
                # use log likelihood
                loss = -torch.mean(gaussian_log_pdf(x, mu))
                
                opt.zero_grad()
                loss.backward()
                opt.step()
                train_losses.append(loss.item())
                bar.set_description(f'Epoch {epoch} loss {train_losses}')
        
        if epoch == 0 or (epoch+1) % sample_ep == 0 or epoch == epochs-1:
            # save model
            torch.save(model.state_dict(), f'checkpoints/{epoch}_pixelcnn_c.pth')
            model.eval()
            sample_pixelcnn_continous(model,num=64,ep=epoch)
            print(f'Epoch {epoch} sample saved')

In [16]:
# Build the continuous PixelCNN on the GPU, report its parameter count, then
# run the clean-data sanity training for 10 epochs.
pixelcnn_model = PixelCNNContinous().cuda()
print('model params:', sum(p.numel() for p in pixelcnn_model.parameters()))
train_pixelCNN_continous_sanity(pixelcnn_model, 10)

model params: 2486337


  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 0 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 0 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 3 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 7 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 9 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 9 sample saved


# Continuous PixelCNN

In [None]:
# NOTE(review): always raises AssertionError; presumably a deliberate stop so
# that "Run All" halts before the continuous PixelCNN section — confirm.
assert False

In [None]:
def get_mixture(data):
    """Noise ``data`` to the schedule level at index ``T // 2``; return it on CUDA.

    Computes ``sqrt(alpha_bar) * data + sqrt(1 - alpha_bar) * z`` with
    ``alpha_bar = alpha_bars[T // 2]`` and fresh standard Gaussian noise ``z``.
    """
    mid = T // 2
    signal_part = data * torch.sqrt(alpha_bars[mid])
    noise_part = torch.randn_like(data) * torch.sqrt(1 - alpha_bars[mid])
    return (signal_part + noise_part).cuda()

In [None]:
def gaussian_log_pdf(x, mu, logvar=torch.tensor(-4.0)):
    """Element-wise log-density of ``x`` under a normal with mean ``mu``.

    The variance is ``exp(logvar)``. ``logvar`` defaults to a fixed ``-4``,
    matching the earlier definition of this function in the notebook, so the
    two-argument call in ``train_pixelCNN_continous_sanity`` still works after
    this cell has been run.

    Returns ``-0.5 * (log(2*pi) + logvar + (x - mu)**2 / exp(logvar))``.
    """
    var = torch.exp(logvar)
    return -0.5 * (torch.log(2 * torch.tensor(torch.pi)) + logvar + (x - mu) ** 2 / var)

def train_pixelCNN_continous(model, epochs, sample_ep=4):
    """Train a Gaussian PixelCNN with learned variance on half-noised images.

    Each batch is passed through ``get_mixture``. The model output is indexed
    as ``pred[0]`` (mean) and ``pred[1]`` (log-variance), and the loss is the
    mean negative Gaussian log-likelihood. Adam is used with a learning rate
    that is reset at the start of every epoch to decay linearly from 1e-3.
    After epoch 0, after every ``sample_ep``-th epoch and after the last
    epoch, the weights are saved to ``checkpoints/{epoch}_pixelcnn_c.pth``
    and ``sample_pixelcnn_continous`` is run.
    """
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    final_epoch = epochs - 1

    for epoch in range(epochs):
        # Linear decay: 1e-3 at epoch 0 down to 1e-3 / epochs at the last epoch.
        opt.param_groups[0]['lr'] = 1e-3 * (epochs - epoch) / epochs
        model.train()
        train_losses = Avger()
        with tqdm(train_loader) as bar:
            for batch, _ in bar:
                noised = get_mixture(batch)

                pred = model(noised)
                mu, logvar = pred[0], pred[1]
                # use log likelihood
                loss = -gaussian_log_pdf(noised, mu, logvar).mean()

                opt.zero_grad()
                loss.backward()
                opt.step()

                train_losses.append(loss.item())
                bar.set_description(f'Epoch {epoch} loss {train_losses}')

        at_checkpoint = epoch == 0 or (epoch + 1) % sample_ep == 0 or epoch == final_epoch
        if not at_checkpoint:
            continue

        # save model
        torch.save(model.state_dict(), f'checkpoints/{epoch}_pixelcnn_c.pth')
        model.eval()
        sample_pixelcnn_continous(model, num=64, ep=epoch)
        print(f'Epoch {epoch} sample saved')

In [None]:
# Build a fresh continuous PixelCNN on the GPU, report its parameter count,
# then train it on half-noised images for 10 epochs.
pixelcnn_model = PixelCNNContinous().cuda()
print('model params:', sum(p.numel() for p in pixelcnn_model.parameters()))
train_pixelCNN_continous(pixelcnn_model, 10)

model params: 2486466


  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 0 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 0 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 3 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 7 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 9 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 9 sample saved
