In [1]:
import torch
from torchvision import datasets, transforms

# The only preprocessing step is torchvision's ToTensor conversion.
tensor_transform = transforms.Compose([transforms.ToTensor()])

batch_size = 512

# MNIST train / test splits; download=True asks torchvision to fetch the
# files into `root`.
train_dataset = datasets.MNIST(
    root="/home/zhh/data",
    train=True,
    download=True,
    transform=tensor_transform,
)
test_dataset = datasets.MNIST(
    root="/home/zhh/data",
    train=False,
    download=True,
    transform=tensor_transform,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


In [2]:
class Avger(list):
    """A list of numbers whose string form is the mean of its elements."""

    def __str__(self):
        """Return the mean formatted to four decimals, or 'N/A' when empty."""
        if len(self) == 0:
            return 'N/A'
        mean = sum(self) / len(self)
        return f'{mean:.4f}'
    
from tqdm.notebook import tqdm,trange
import torchvision.utils as vutils
import torch.nn as nn
import torch.nn.functional as F

# Continuous PixelCNN Sanity Check

In [3]:
from pixelcnn import PixelCNNContinous

In [4]:
from tqdm.notebook import trange
import torchvision.utils as vutils

@torch.no_grad()
def sample_pixelcnn_continous(model, num=64, ep=0):
    """Sample ``num`` images from ``model`` one pixel at a time and save a grid.

    Starting from an all-zero ``(num, 1, 28, 28)`` batch, every pixel position
    is visited in row-major order. At each position the model is run on the
    current batch, and the pixel is drawn from a normal distribution whose
    mean is ``mu[:, :, i, j]`` and whose standard deviation is
    ``exp(logvar[:, :, i, j] / 2)``, then clamped to ``[0, 1]``.

    Args:
        model: Callable returning a ``(mu, logvar)`` pair for an image batch.
        num: Number of images to sample.
        ep: Epoch index, used in the progress-bar label and the output name.

    Returns:
        The sampled batch, a tensor of shape ``(num, 1, 28, 28)``.

    Side effects:
        Creates the ``samples`` directory if needed and writes the image grid
        (8 images per row) to ``samples/{ep}_pixelcnn_c.png``.
    """
    import os

    # Sample on the device the model lives on rather than assuming the default
    # CUDA device; fall back to CUDA when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    x = torch.zeros(num, 1, 28, 28, device=device)
    for i in trange(28, desc=f'Epoch {ep} sampling'):
        for j in range(28):
            # A full forward pass is needed per pixel because each prediction
            # is made from the batch as filled in so far.
            mu, logvar = model(x)
            std = torch.exp(logvar[:, :, i, j] / 2)
            x[:, :, i, j] = torch.normal(mu[:, :, i, j], std).clamp(0, 1)

    # Make sure the output directory exists so the save cannot fail after the
    # expensive sampling loop has already finished.
    os.makedirs('samples', exist_ok=True)
    grid = vutils.make_grid(x, nrow=8)
    vutils.save_image(grid, f'samples/{ep}_pixelcnn_c.png')
    return x

In [5]:
def gaussian_log_pdf(x, mu, logvar=torch.tensor(-4.0)):
    """Element-wise log-density of ``x`` under a Gaussian N(mu, exp(logvar)).

    Args:
        x: Tensor of observed values.
        mu: Tensor of means.
        logvar: Tensor of log-variances; defaults to a scalar tensor of -4.0.

    Returns:
        ``-0.5 * (log(2*pi) + logvar + (x - mu)**2 / exp(logvar))``.
    """
    log_two_pi = torch.log(2 * torch.tensor(torch.pi))
    squared_error = (x - mu) ** 2
    scaled_error = squared_error / torch.exp(logvar)
    return -0.5 * (log_two_pi + logvar + scaled_error)

def train_pixelCNN_continous_sanity(model, epochs, sample_ep=4):
    """Train ``model`` on ``train_loader`` by Gaussian maximum likelihood.

    Every batch is passed through the model, which returns a ``(mu, logvar)``
    pair; the loss is the negative mean of ``gaussian_log_pdf(x, mu, logvar)``
    and is minimised with Adam (lr=1e-4). Labels from the loader are ignored.

    After the first epoch, after every ``sample_ep``-th epoch and after the
    last epoch, the model weights are saved to
    ``checkpoints/{epoch}_pixelcnn_c.pth`` and a batch of 64 samples is
    generated with ``sample_pixelcnn_continous``.

    Args:
        model: Module to train; called as ``model(x)``.
        epochs: Number of passes over ``train_loader``.
        sample_ep: Interval, in epochs, between checkpoints and samples.
    """
    import os

    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Move batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device

    for epoch in range(epochs):
        model.train()
        train_losses = Avger()
        with tqdm(train_loader) as bar:
            for x, _ in bar:
                x = x.to(device)
                mu, logvar = model(x)
                # Negative log-likelihood averaged over all pixels in the batch.
                loss = -torch.mean(gaussian_log_pdf(x, mu, logvar))

                opt.zero_grad()
                loss.backward()
                opt.step()
                train_losses.append(loss.item())
                bar.set_description(f'Epoch {epoch} loss {train_losses}')

        if epoch == 0 or (epoch+1) % sample_ep == 0 or epoch == epochs-1:
            # Create the checkpoint directory on demand so that the first save
            # does not abort the run after an epoch of training.
            os.makedirs('checkpoints', exist_ok=True)
            torch.save(model.state_dict(), f'checkpoints/{epoch}_pixelcnn_c.pth')
            # Switch to eval mode for sampling; train() is restored at the
            # start of the next epoch.
            model.eval()
            sample_pixelcnn_continous(model, num=64, ep=epoch)
            print(f'Epoch {epoch} sample saved')

In [6]:
# Instantiate the model on the GPU.
pixelcnn_model = PixelCNNContinous().cuda()

# Report the total number of parameter elements.
print('model params:', sum(param.numel() for param in pixelcnn_model.parameters()))

# Run the sanity training for 50 epochs.
train_pixelCNN_continous_sanity(pixelcnn_model, epochs=50)

model params: 2490402


  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 0 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 0 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 3 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 7 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 11 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 11 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 15 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 15 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 19 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 19 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 23 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 23 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 27 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 27 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 31 sampling:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 31 sample saved


  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

KeyboardInterrupt: 