In [None]:
from tqdm import tqdm
from PIL import Image

import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import numpy as np

import torchvision
import itertools
import shutil
import torch
import time
import math
import os

In [None]:
T = 1_000  # number of discrete noise levels; also the output size of the CNN classifiers

class SimpleCNN(nn.Module):
    """Small 4-stage CNN classifier sized for 32x32 inputs.

    Each stage is conv(3x3, stride 1, padding 1) -> ReLU -> 2x2 average pool,
    so the spatial size goes 32 -> 16 -> 8 -> 4 -> 2. The resulting
    256 x 2 x 2 = 1024 features are mapped to ``num_classes`` logits by a
    linear layer (so inputs must be 32x32 for the shapes to line up).

    Args:
        num_classes: number of output logits. ``None`` (default) uses the
            module-level ``T``, read when the model is constructed.
        in_channels: number of input channels (default 3).
    """

    def __init__(self, num_classes=None, in_channels=3):
        super().__init__()
        if num_classes is None:
            num_classes = T

        self.act = nn.ReLU()
        self.pool = nn.AvgPool2d(2, 2)
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1, 1)  # 32 -> 16 after pooling
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)  # 16 -> 8
        self.conv3 = nn.Conv2d(64, 128, 3, 1, 1)  # 8 -> 4
        self.conv4 = nn.Conv2d(128, 256, 3, 1, 1)  # 4 -> 2
        self.fc = nn.Linear(256 * 2 * 2, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) logits."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(self.act(conv(x)))
        x = x.flatten(start_dim=1)
        return self.fc(x)

class SimpleCNN256(nn.Module):
    """4-stage CNN classifier sized for 256x256 inputs.

    Stages are conv(3x3, stride 1, padding 1) -> ReLU -> average pool. The first
    three pools downsample by 4 (256 -> 64 -> 16 -> 4) and the last by 2
    (4 -> 2). The resulting 256 x 2 x 2 = 1024 features are mapped to
    ``num_classes`` logits by a linear layer.

    Args:
        num_classes: number of output logits. ``None`` (default) uses the
            module-level ``T``, read when the model is constructed.
        in_channels: number of input channels (default 3).
    """

    def __init__(self, num_classes=None, in_channels=3):
        super().__init__()
        if num_classes is None:
            num_classes = T

        self.act = nn.ReLU()
        self.pool2 = nn.AvgPool2d(2, 2)
        # Kernel equals stride so every activation contributes. The previous
        # AvgPool2d(2, 4) averaged only the top-left 2x2 of each 4x4 window and
        # discarded the other 75%. Output sizes are unchanged.
        self.pool4 = nn.AvgPool2d(4, 4)
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1, 1)  # 256 -> 64 after pooling
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)  # 64 -> 16
        self.conv3 = nn.Conv2d(64, 128, 3, 1, 1)  # 16 -> 4
        self.conv4 = nn.Conv2d(128, 256, 3, 1, 1)  # 4 -> 2
        self.fc = nn.Linear(256 * 2 * 2, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, 256, 256) tensor to (batch, num_classes) logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool4(self.act(conv(x)))
        x = self.pool2(self.act(self.conv4(x)))
        x = x.flatten(start_dim=1)
        return self.fc(x)

def sample_joint(x, sigma_max, sigma_min=0.01, n_levels=None):
    """Perturb a batch with Gaussian noise at randomly chosen discrete noise levels.

    The noise scales form a geometric sequence of ``n_levels`` values,
    ``sigma_min * (sigma_max / sigma_min) ** t`` with ``t`` spaced linearly from
    1.0 down to 1e-3. Index 0 is therefore ``sigma_max``, and the last index is
    slightly above ``sigma_min`` when ``sigma_max > sigma_min``. Each sample
    gets an independent, uniformly drawn level index ``y`` and is perturbed with
    ``sigma[y] * N(0, I)`` noise.

    Typical values of ``sigma_max``: CIFAR10 -> 50, CELEBAHQ -> 348.

    Args:
        x: batch tensor whose first dimension is the batch; any rank >= 1.
        sigma_max: largest noise scale.
        sigma_min: smallest noise scale.
        n_levels: number of discrete levels; ``None`` uses the module-level ``T``.

    Returns:
        (x_noisy, y): the noisy batch (same shape as ``x``) and the int64 level
        index for each sample, in ``[0, n_levels)``.
    """
    if n_levels is None:
        n_levels = T
    device = x.device

    sigma_min = torch.tensor([sigma_min], device=device)
    sigma_max = torch.tensor([sigma_max], device=device)
    # Allocate directly on x's device instead of building on the CPU and copying.
    ts = torch.linspace(1.0, 1e-3, n_levels, device=device)
    sigmas = sigma_min * (sigma_max / sigma_min) ** ts

    # Labels are drawn before the noise (same RNG call order as before).
    y = torch.randint(0, n_levels, size=[x.shape[0]], device=device)
    # One scale per sample, broadcast over all remaining dims of x (any rank).
    scale_shape = (-1,) + (1,) * (x.dim() - 1)
    x = x + sigmas[y].reshape(scale_shape) * torch.randn_like(x)
    return x, y

In [None]:
load = 0            # epoch number of the checkpoint in ckpt_dir to resume from; 0 = start fresh
ckpt_dir = './logs'
s_epoch = 5         # save a checkpoint every s_epoch epochs
n_epoch = 100       # epochs to run in this session (continues the resumed epoch count)
bs = 100            # minibatch size
sigma_max = 50      # largest noise level for sample_joint (CIFAR10: 50, CELEBAHQ: 348)
seed = 0

# train_X = (code to load your dataset)
# NOTE(review): SimpleCNN's fc layer only lines up for 32x32, 3-channel inputs, so
# train_X is presumably a float tensor of shape (N, 3, 32, 32) — confirm against your loader.

torch.manual_seed(seed)
n_train = train_X.shape[0]
n_iter = math.ceil(n_train / bs)

with torch.cuda.device(0):

    net = SimpleCNN().cuda()
    opt = optim.Adam(net.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()

    if load:
        ckpt = torch.load(os.path.join(ckpt_dir, '{}.pt'.format(load)))
        epoch = ckpt['epoch']
        opt.load_state_dict(ckpt['opt'])
        net.load_state_dict(ckpt['net'])
    else:
        epoch = 0
        # Fresh run: drop stale checkpoints so old epochs are not mixed with new ones.
        # Only regular *.pt files are removed (os.remove fails on subdirectories and
        # unrelated files in the directory should survive).
        os.makedirs(ckpt_dir, exist_ok=True)
        for fname in os.listdir(ckpt_dir):
            path = os.path.join(ckpt_dir, fname)
            if fname.endswith('.pt') and os.path.isfile(path):
                os.remove(path)

    net.train()
    pbar = tqdm(range(n_epoch), desc='epochs')
    for _ in pbar:
        epoch += 1
        # Reshuffle every epoch; otherwise every epoch sees identical minibatches.
        perm = torch.randperm(n_train)
        loss_sum = torch.zeros((), device='cuda')
        for j in range(n_iter):
            batch = train_X[perm[j*bs:(j+1)*bs]].cuda()
            X, y = sample_joint(batch, sigma_max)
            loss = crit(net(X), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Accumulate on the GPU; a single .item() per epoch avoids per-step syncs.
            loss_sum += loss.detach() * X.shape[0]
        pbar.set_postfix(loss=loss_sum.item() / max(n_train, 1))

        if epoch % s_epoch == 0:
            ckpt = {
                'epoch': epoch,
                'opt': opt.state_dict(),
                'net': net.state_dict(),
            }
            torch.save(ckpt, os.path.join(ckpt_dir, '{}.pt'.format(epoch)))