In [204]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("using", device)

torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed(777)

using cpu


In [205]:
# Models.py

import torch
import math

class ContractingBlock(torch.nn.Module):
    def __init__(self, in_channels:int, out_channels:int, latent_size:list, embed_channels:int):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.max_pool = torch.nn.MaxPool2d(2)
        self.activation = torch.nn.ReLU(inplace=True)
        self.time_embed_fc = torch.nn.Linear(embed_channels, out_channels)
        self.layernorm1 = torch.nn.LayerNorm([in_channels] + latent_size, eps = 1e-5)
        self.layernorm2 = torch.nn.LayerNorm([out_channels] + latent_size, eps = 1e-5)

    def forward(self, x, time_embed):
        x = self.conv1(self.layernorm1(x)) + self.time_embed_fc(time_embed)[:, :, None, None]
        x = self.activation(x)
        x = self.conv2(self.layernorm2(x))
        x = self.activation(x)
        return self.max_pool(x), x

class ExpansiveBlock(torch.nn.Module):
    def __init__(self, in_channels:int, out_channels:int, latent_size:list, embed_channels:int):
        super().__init__()
        self.upconv = torch.nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size = 2, stride = 2)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.activation = torch.nn.ReLU(inplace=True)
        self.time_embed_fc = torch.nn.Linear(embed_channels, out_channels)
        self.layernorm1 = torch.nn.LayerNorm([in_channels] + latent_size, eps = 1e-5)
        self.layernorm2 = torch.nn.LayerNorm([out_channels] + latent_size, eps = 1e-5)

    def forward(self, x, x_skip, time_embed):
        x = self.upconv(x)
        # batch_size and channel_input should be same size
        assert x.size()[0] == x_skip.size()[0]
        assert x.size()[1] == x_skip.size()[1]
        if x.size() != x_skip.size():
            # size conflict -> pad to align size
            # this is only required if pad = 0 at Cont block and Exp block
            assert x.size()[2] < x_skip.size()[2]
            assert x.size()[3] < x_skip.size()[3]
            x_dif = x_skip.size()[2] - x.size()[2]
            y_dif = x_skip.size()[3] - x.size()[3]
            # size will be aligned to x_skip
            x = torch.nn.functional.pad(x, [x_dif // 2, x_dif - x_dif // 2, y_dif // 2, y_dif - y_dif // 2])
        x = torch.cat((x, x_skip), dim = 1)
        x = self.conv1(self.layernorm1(x)) + self.time_embed_fc(time_embed)[:, :, None, None]
        x = self.activation(x)
        x = self.conv2(self.layernorm2(x))
        x = self.activation(x)
        return x

class MiddleBlock(torch.nn.Module):
    def __init__(self, in_channels:int, out_channels:int, latent_size:list, embed_channels:int):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.activation = torch.nn.ReLU(inplace=True)
        self.time_embed_fc = torch.nn.Linear(embed_channels, out_channels)
        self.layernorm1 = torch.nn.LayerNorm([in_channels] + latent_size, eps = 1e-5)
        self.layernorm2 = torch.nn.LayerNorm([out_channels] + latent_size, eps = 1e-5)
    
    def forward(self, x, time_embed):
        x = self.conv1(self.layernorm1(x)) + self.time_embed_fc(time_embed)[:, :, None, None]
        x = self.activation(x)
        x = self.conv2(self.layernorm2(x))
        x = self.activation(x)
        return x

class PositionalEncoding(torch.nn.Module):
    # copy from transformer
    def __init__(self, embed_len, steps) -> None:
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(embed_len, steps)
        encoding.requires_grad = False
        position = torch.arange(0, embed_len).float().unsqueeze(1)
        div_term = torch.exp(torch.arange(0, steps, 2) * -(math.log(10000.0) / steps))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = torch.nn.Parameter(data = encoding, requires_grad=False)
    
    # input size : 
    # (batch_size) << this should be integer(time)
    # output size :
    # (batch_size, embed_size)
    def forward(self, x):
        pos_embed = self.encoding[x, :]
        return pos_embed

class UnetForDiffusion(torch.nn.Module):
    # Model of Unet with time embedding
    # this differs from the model described in the DDPM paper
    def __init__(self, in_channels:int, out_channels:int, latent_size:list, steps:int, embed_channels:int = 64, mid_start_channels:int = 64, path_len:int = 4):
        # channel size inc/decrease like
        # input image size should be larger than 2 ^ path_len
        # in_channels -> mid_start_channels -> mid_start_channels * 2 ... -> mid_start_channels * 2 ^ path_len -> ... mid_start_channels -> out_channels
        super().__init__()
        cont_blocks = []
        up_blocks = []
        self.path_len = path_len
        cont_blocks.append(ContractingBlock(in_channels, mid_start_channels, latent_size, embed_channels))
        up_blocks.append(ExpansiveBlock(mid_start_channels * 2, mid_start_channels, latent_size, embed_channels))
        self.time_encoding = PositionalEncoding(steps, embed_channels)
        self.classifier = torch.nn.Conv2d(mid_start_channels, out_channels, kernel_size=1)
        for _ in range(path_len - 1):
            latent_size = [s // 2 for s in latent_size]
            cont_blocks.append(ContractingBlock(mid_start_channels, mid_start_channels * 2, latent_size, embed_channels))
            up_blocks.append(ExpansiveBlock(mid_start_channels * 4, mid_start_channels * 2, latent_size, embed_channels))
            mid_start_channels *= 2
        self.cont_blocks = torch.nn.ModuleList(cont_blocks)
        self.exp_blocks = torch.nn.ModuleList(up_blocks)
        latent_size = [s // 2 for s in latent_size]
        self.mid_block = MiddleBlock(mid_start_channels, mid_start_channels * 2, latent_size, embed_channels)
    

    # input size : 
    # x : (batch_size, width, height)
    # time : (batch_size) << this should be integer(time)
    # output size :
    # (batch_size, width, height)
    def forward(self, x, time):
        skip = []
        time_embed = self.time_encoding(time)
        for i in range(self.path_len):
            x, x_skip = self.cont_blocks[i](x, time_embed)
            skip.append(x_skip)
        x = self.mid_block(x, time_embed)
        for i in range(self.path_len - 1, -1, -1):
            x = self.exp_blocks[i](x, skip[i], time_embed)
        return self.classifier(x)

In [206]:
# import torch

# class PixelCNNpp(torch.nn.Module):
#     # Model of PixelCNN++
#     def __init__(self, in_channels:int, out_channels:int):
#         pass

In [207]:
# # checking output size

# net = Unet(in_channels = 3, out_channels = 10, steps = 1000)
# inp = torch.randn(5, 3, 32, 32)
# time = torch.arange(5)
# out = net(inp, time)
# print(out.size())

In [208]:
# Config.py

class config:
    # class to handle configs in ddpm
    def __init__(self,
                 in_channels = 1,
                 out_channels = 1,
                 latent_size = [28, 28],
                 lr = 0.1, 
                 epoch = 1000, 
                 eval_per_epoch = 10, 
                 batch_size = 256,
                 criterion = torch.nn.MSELoss(), 
                 step = 1000, 
                 beta_1 = 0.0001,
                 beta_t = 0.02, 
                 device = device
                 ):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_size = latent_size
        self.lr = lr
        self.epoch = epoch
        self.eval_per_epoch = eval_per_epoch
        self.batch_size = batch_size
        self.criterion = criterion
        self.step = step
        self.beta = [0] + [i / (step - 1) * (beta_t - beta_1) + beta_1 for i in range(0, step)]
        assert len(self.beta) == step + 1

        self.alpha = []
        alpha = 1
        for b_t in self.beta:
            alpha *= (1 - b_t)
            self.alpha.append(alpha)
        self.device = device
        self.beta = torch.tensor(self.beta).to(self.device)     # beta_t
        self.alpha = torch.tensor(self.alpha).to(self.device)   # alpha_t bar
        self.beta.requires_grad = False                         # Do not update alpha & beta
        self.alpha.requires_grad = False

In [209]:
# DDPM.py

class DDPM:
    def __init__(self, model, train_data, eval_data, test_data, config:config):
        # model should be image to image model has input size == output size
        # data should be dataset, not dataloader
        self.model = model
        self.config = config

        # init dataloaders
        self.train_loader = torch.utils.data.DataLoader(
            dataset = train_data,
            shuffle = True,
            batch_size = self.config.batch_size,
            drop_last = True,
        )
        self.eval_loader = torch.utils.data.DataLoader(
            dataset = eval_data,
            shuffle = False,
            batch_size = self.config.batch_size,
            drop_last = True,
        )
        self.test_loader = torch.utils.data.DataLoader(
            dataset = test_data,
            shuffle = False,
            batch_size = self.config.batch_size,
            drop_last = True,
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.config.lr)
    
    def train(self):
        for i in range(1, self.config.epoch + 1):
            # train here
            self.train_one_epoch(i)
            
            # eval here
            if i % self.config.eval_per_epoch == 0:
                self.evaluate(i)

    def train_one_epoch(self, epoch):
        self.model.train()
        loss_sum = 0
        cnt = 0
        for x, _ in self.train_loader:
            x = x.to(self.config.device)
            sampled_steps = self.sample_steps()
            x_t = self.sample_forward_t(x, sampled_steps)         # sampled x_t
            x_tp1 = self.sample_forward_1(x_t, sampled_steps)     # sampled x_t+1
            pred_residual = self.model(x_tp1, sampled_steps)    # predict residual between x_t and x_t+1 using x_t+1
            loss = self.loss(pred_residual, x_t, x_tp1, sampled_steps)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_sum += loss.item()
            cnt += x.size(0)
        print('[EPOCH' + str(epoch) + '] TRAIN avg loss :', loss_sum / cnt)

    def evaluate(self, epoch):
        self.model.eval()
        loss_sum = 0
        cnt = 0
        for x, _ in self.train_loader:
            x = x.to(self.config.device)
            sampled_steps = self.sample_steps()
            x_t = self.sample_forward_t(x, sampled_steps)         # sampled x_t
            x_tp1 = self.sample_forward_1(x_t, sampled_steps)     # sampled x_t+1
            pred_residual = self.model(x_tp1, sampled_steps)    # predict residual between x_t and x_t+1 using x_t+1
            loss = self.loss(pred_residual, x_t, x_tp1, sampled_steps)
            loss_sum += loss.item()
            cnt += x.size(0)
        print('[EPOCH' + str(epoch) + '] EVAL avg loss :', loss_sum / cnt)
        return loss_sum / cnt

    def inference(self, latent):
        # inference from the latent
        # latent : tensor of size (1, channels, width, height)
        for step in range(self.config.step - 1, -1, -1):
            latent = self.sample_reverse_1(latent, step)
        
        return latent

    def sample_latent(self, size):
        return torch.randn(size).to(self.config.device)
    
    def sample_reverse_1(self, x_t, step):
        step_torch = torch.tensor([step])
        z = torch.randn(x_t.size()).to(self.config.device)
        sigma_t = torch.sqrt(self.config.beta[step_torch])      # page 3 of paper says that (1 - alpha_t-1) / (1 - alpha_t) * beta_t and beta_t had similar results
        return (x_t - self.model(x_t, step_torch) * ((self.config.beta[step_torch]) / torch.sqrt(1 - self.config.alpha[step_torch]))[:, None, None, None]) / torch.sqrt(1 - self.config.beta[step_torch])[:, None, None, None] + z * sigma_t[:, None, None, None]
    
    def sample_steps(self):
        return torch.randint(self.config.step, (self.config.batch_size, ))
    
    def sample_forward_t(self, x_0, step):
        # sample using q(x_t|x_0)
        # input : x_0, step(t)
        # step must be 0 ~ max_step - 1
        # output : sampled x_t
        return torch.randn(x_0.size()).to(self.config.device) * (1 - self.config.alpha[step])[:, None, None, None] + x_0 * torch.sqrt(self.config.alpha[step])[:, None, None, None]

    def sample_forward_1(self, x_t, step):
        # sample using q(x_t+1|x_t)
        # input : x_0, step(t)
        # step must be 0 ~ max_step - 1
        # output : sampled x_t+1
        return torch.randn(x_t.size()).to(self.config.device) * (self.config.beta[step + 1])[:, None, None, None] + x_t * torch.sqrt(1 - self.config.beta[step + 1])[:, None, None, None]

    def loss(self, pred_x_t, gt_x_t, gt_x_tp1, step):
        # mseloss
        # input : predicted x_t, gt x_t, gt x_t+1, step(t)
        # step must be 0 ~ max_step - 1
        # output : mseloss with prediction and gt, 
        gt_residual = (gt_x_t - gt_x_tp1) / self.config.beta[step + 1][:, None, None, None]
        return torch.sum(torch.mean((pred_x_t - gt_residual).pow(2), 0, True))


In [210]:
DDPMConfig = config()
model = UnetForDiffusion(in_channels = DDPMConfig.in_channels, out_channels = DDPMConfig.out_channels, latent_size = DDPMConfig.latent_size, steps = 1000).to(DDPMConfig.device)
mnist_train = torchvision.datasets.MNIST(
    root = '../MNIST_data',
    train = True, 
    transform = torchvision.transforms.ToTensor(), 
    download = True
)
mnist_eval = torchvision.datasets.MNIST(
    root = '../MNIST_data',
    train = False, 
    transform = torchvision.transforms.ToTensor(), 
    download = True
)
DDPM_obj = DDPM(model, mnist_train, mnist_eval, mnist_eval, DDPMConfig)

In [211]:
DDPM_obj.train()

[EPOCH1] TRAIN avg loss : 3.6297121047973633


KeyboardInterrupt: 

In [None]:
example_latent = DDPM_obj.sample_latent((1, 1, 32, 32))
inferenced_image = DDPM_obj.inference(example_latent)

In [None]:
# Utils.py
import matplotlib.pyplot as plt
import numpy

def print_image(img_torch):
    img = img_torch.numpy()
    plt.imshow(numpy.transpose(img, (1, 2, 0)))
    plt.show()

In [None]:
print_image(inferenced_image)