In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from utils import *
from modules import UNet
import logging
from torch.utils.tensorboard import SummaryWriter
import argparse

In [None]:
# force=True: Colab/Jupyter commonly attach a root handler before this cell runs, which makes a
# plain basicConfig() a silent no-op -- in this notebook's output none of the logging.info lines
# ("Starting epoch ...", "Sampling ...") ever appear. %H gives an unambiguous 24-hour timestamp
# (the previous %I had no AM/PM marker).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
    force=True,
)

In [None]:
class Diffusion:
    """DDPM helper: linear beta schedule, closed-form forward noising and ancestral sampling.

    Training (``sample_timesteps``) and sampling (``sample``) both use timestep indices
    1 .. noise_steps-1, so index 0 of the schedule is never used.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda", channels=3):
        """
        Args:
            noise_steps: length of the noise schedule.
            beta_start: first value of the linear beta schedule.
            beta_end: last value of the linear beta schedule.
            img_size: height and width of the square images produced by ``sample``.
            device: device holding the schedule tensors and the sampled images.
            channels: channel count of the images produced by ``sample`` (default 3, as before).
        """
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        self.channels = channels

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        # alpha_hat[t] is the running product alpha[0] * ... * alpha[t].
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return ``noise_steps`` betas evenly spaced from ``beta_start`` to ``beta_end`` (CPU tensor)."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise a batch of images to timesteps ``t`` in one step.

        Args:
            x: batch of images; the ``[:, None, None, None]`` broadcasting below assumes a
                4-D (N, C, H, W) tensor.
            t: integer timesteps, one per image, on the same device as the schedule.

        Returns:
            ``(x_t, noise)``: ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise`` and the
            standard-normal ``noise`` that was mixed in (the training target).
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n):
        """Draw ``n`` training timesteps uniformly from ``[1, noise_steps)`` (CPU int64 tensor)."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        """Generate ``n`` images by running the reverse process from pure noise.

        Args:
            model: noise-prediction network called as ``model(x, t)``; assumed to return a
                tensor shaped like ``x``.
            n: number of images to generate.

        Returns:
            uint8 tensor of shape (n, channels, img_size, img_size) with values in [0, 255].

        The model is switched to eval mode for sampling and afterwards restored to whatever
        mode it was in before (also if sampling raises).
        """
        logging.info(f"Sampling {n} new images....")
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, self.channels, self.img_size, self.img_size), device=self.device)
                for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)
                    predicted_noise = model(x, t)
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]
                    # No fresh noise on the last step, so the final image is the posterior mean.
                    noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise) + torch.sqrt(beta) * noise
        finally:
            model.train(was_training)
        # Map [-1, 1] to [0, 255].
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x


def train(args, model=None):
    """Train a noise-prediction model with the DDPM objective.

    Reads from ``args``: ``run_name``, ``device``, ``lr``, ``image_size``, ``epochs``,
    ``batch_size``, an optional ``num_samples`` (images sampled after each epoch; defaults to
    ``batch_size``), plus whatever ``get_data`` consumes (presumably ``batch_size`` and
    ``dataset_path`` -- it lives in ``utils``, not shown here).

    Each epoch makes one pass over the dataloader, minimising the MSE between the injected and
    the predicted noise. It then saves sampled images to ``results/<run_name>/<epoch>.jpg`` and
    overwrites ``models/<run_name>/ckpt.pt`` with the model's ``state_dict``. Per-step loss is
    logged to TensorBoard under ``runs/<run_name>``.

    NOTE: epochs are numbered from 0 on every call, so resuming from a checkpoint overwrites the
    earlier ``<epoch>.jpg`` files and restarts the TensorBoard step count.

    Args:
        args: namespace of settings (see above).
        model: optional pre-built or pre-loaded model. When None, a fresh ``UNet()`` is created.
            In both cases the model is moved to ``args.device``.
    """
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    if model is None:
        model = UNet()
    # Move a caller-supplied model too, so a checkpoint loaded elsewhere ends up on args.device.
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    # Fixed sample count. The old code reused the last batch's size: it leaked the loop variable,
    # shrank on a partial final batch, and raised NameError on an empty dataloader.
    num_samples = getattr(args, "num_samples", None) or args.batch_size
    ckpt_path = os.path.join("models", args.run_name, "ckpt.pt")
    results_dir = os.path.join("results", args.run_name)
    num_batches = len(dataloader)

    logger = SummaryWriter(os.path.join("runs", args.run_name))
    try:
        for epoch in range(args.epochs):
            logging.info(f"Starting epoch {epoch}:")
            pbar = tqdm(dataloader)
            for i, (images, _) in enumerate(pbar):
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)
                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once; .item() forces a device sync.
                loss_value = loss.item()
                pbar.set_postfix(MSE=loss_value)
                logger.add_scalar("MSE", loss_value, global_step=epoch * num_batches + i)

            sampled_images = diffusion.sample(model, n=num_samples)
            save_images(sampled_images, os.path.join(results_dir, f"{epoch}.jpg"))
            torch.save(model.state_dict(), ckpt_path)
            print(f'epoch {epoch} end')
    finally:
        # Close the TensorBoard writer (flushing pending events), even if training is interrupted.
        logger.close()


def launch(model=None, **overrides):
    """Build the default training configuration and hand it to ``train``.

    Args:
        model: optional model to continue training; passed unchanged to ``train``.
        **overrides: attribute values that replace the defaults below,
            e.g. ``launch(epochs=1, batch_size=4)``.
    """
    # No command-line options are declared, so the previous parse_known_args() always returned an
    # empty namespace (and exited on a stray -h). Build the namespace directly instead.
    args = argparse.Namespace(
        # Spelling must stay in sync with the models/ and results/ folders created in this notebook.
        run_name="DDPM_Uncondtional",
        epochs=30,
        batch_size=8,
        image_size=64,
        dataset_path=r"dataset",
        # Fall back to CPU instead of failing later when no GPU is attached.
        device="cuda" if torch.cuda.is_available() else "cpu",
        lr=3e-4,
    )
    for name, value in overrides.items():
        setattr(args, name, value)
    train(args, model)

In [None]:
!pip install kaggle



In [None]:
from google.colab import files

# Prompt for kaggle.json (Kaggle API credentials) from the local machine; Colab saves it into the
# working directory. Bind the result instead of leaving it as the cell's last expression: the
# returned dict holds the raw file bytes, and displaying it wrote the API key into the notebook.
uploaded = files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"ufo137","key":"<REDACTED>"}'}

In [None]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
# Fetch the landscape-pictures archive (uses the credentials installed in the previous cell).
# The zip is written to the current directory (/content here), which the extraction cell assumes.
!kaggle datasets download -d arnaud58/landscape-pictures

Downloading landscape-pictures.zip to /content
 98% 610M/620M [00:05<00:00, 223MB/s]
100% 620M/620M [00:05<00:00, 123MB/s]


In [None]:
# Create the extraction target up front; like `mkdir -p`, parents are created and an existing
# folder is not an error.
os.makedirs(os.path.join("dataset", "landscape-pictures"), exist_ok=True)

In [None]:
import os
from zipfile import ZipFile

# Both paths sit under Colab's working directory, where the Kaggle CLI saved the archive.
_content_root = '/content'
zip_path = os.path.join(_content_root, 'landscape-pictures.zip')              # downloaded archive
extract_path = os.path.join(_content_root, 'dataset', 'landscape-pictures')  # unpack target

In [None]:
# Unpack every member of the downloaded archive under extract_path ('r' is ZipFile's default mode).
with ZipFile(zip_path) as archive:
    archive.extractall(path=extract_path)

In [None]:
# Output folders written by train(): checkpoints in models/<run>, sample images in results/<run>.
# os.makedirs also creates the top-level folder when it creates the run sub-folder.
for _output_root in ("models", "results"):
    os.makedirs(os.path.join(_output_root, "DDPM_Uncondtional"), exist_ok=True)

In [None]:
LOAD = True  # True: continue training from the saved checkpoint; False: start from a fresh UNet
CKPT_PATH = os.path.join("models", "DDPM_Uncondtional", "ckpt.pt")

if __name__ == '__main__':
    if LOAD:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # map_location lets a checkpoint written on another device load here. weights_only
        # restricts unpickling to tensors and plain containers, which is all a state_dict needs
        # (older torch versions otherwise default to full, code-executing unpickling).
        ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=True)
        model = UNet().to(device)
        model.load_state_dict(ckpt)
        launch(model)
    else:
        launch()

100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.0249]
999it [01:34, 10.58it/s]


epoch 0 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.0119]
999it [01:34, 10.59it/s]


epoch 1 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.0268]
999it [01:34, 10.58it/s]


epoch 2 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.0121]
999it [01:34, 10.58it/s]


epoch 3 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.0172]
999it [01:34, 10.59it/s]


epoch 4 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.00738]
999it [01:34, 10.58it/s]


epoch 5 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.0277]
999it [01:34, 10.58it/s]


epoch 6 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.0141]
999it [01:34, 10.58it/s]


epoch 7 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.00808]
999it [01:34, 10.58it/s]


epoch 8 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.0178]
999it [01:34, 10.59it/s]


epoch 9 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.00176]
999it [01:34, 10.58it/s]


epoch 10 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.00386]
999it [01:34, 10.59it/s]


epoch 11 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.0119]
999it [01:34, 10.59it/s]


epoch 12 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.0103]
999it [01:34, 10.60it/s]


epoch 13 end


100%|██████████| 540/540 [03:28<00:00,  2.60it/s, MSE=0.00889]
999it [01:34, 10.59it/s]


epoch 14 end


100%|██████████| 540/540 [03:30<00:00,  2.57it/s, MSE=0.0148]
999it [01:34, 10.60it/s]


epoch 15 end


100%|██████████| 540/540 [03:28<00:00,  2.58it/s, MSE=0.0133]
999it [01:34, 10.59it/s]


epoch 16 end


100%|██████████| 540/540 [03:34<00:00,  2.52it/s, MSE=0.0214]
999it [01:34, 10.58it/s]


epoch 17 end


100%|██████████| 540/540 [03:29<00:00,  2.57it/s, MSE=0.0127]
999it [01:34, 10.59it/s]


epoch 18 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.00362]
999it [01:34, 10.58it/s]


epoch 19 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.00905]
999it [01:34, 10.58it/s]


epoch 20 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.0177]
999it [01:34, 10.59it/s]


epoch 21 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.0146]
999it [01:34, 10.58it/s]


epoch 22 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.00765]
999it [01:34, 10.59it/s]


epoch 23 end


100%|██████████| 540/540 [03:28<00:00,  2.58it/s, MSE=0.015]
999it [01:34, 10.59it/s]


epoch 24 end


100%|██████████| 540/540 [03:29<00:00,  2.58it/s, MSE=0.012]
999it [01:34, 10.60it/s]


epoch 25 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.0155]
999it [01:34, 10.58it/s]


epoch 26 end


100%|██████████| 540/540 [03:29<00:00,  2.57it/s, MSE=0.00778]
999it [01:34, 10.59it/s]


epoch 27 end


100%|██████████| 540/540 [03:30<00:00,  2.57it/s, MSE=0.0208]
999it [01:34, 10.58it/s]


epoch 28 end


100%|██████████| 540/540 [03:28<00:00,  2.59it/s, MSE=0.00364]
999it [01:34, 10.58it/s]


epoch 29 end


In [None]:
# Bundle every per-epoch sample image saved by train() into one archive for download.
# NOTE(review): zip records the paths as given, so entries presumably unpack under
# content/results/... -- `cd /content/results` first if relative paths are wanted.
!zip -r /content/results/DDPM_Uncondtional.zip /content/results/DDPM_Uncondtional

In [None]:
# Send the checkpoint that train() overwrites every epoch to the local machine.
files.download(os.path.join('/content', 'models', 'DDPM_Uncondtional', 'ckpt.pt'))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Send the zipped sample images (built by the zip cell) to the local machine.
files.download(os.path.join('/content', 'results', 'DDPM_Uncondtional.zip'))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>