In [1]:
import os
import uuid

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Subset

from diffusion_model import DiffusionModel
from forward_diffusion import ForwardDiffusionModule
from reverse_diffusion import UNet
from utils import Transforms

#### Learner Module

In [2]:
class Learner(pl.LightningModule):
    """Lightning module that trains the reverse network of a ``DiffusionModel``.

    Each training step noises a batch with the forward process at random
    timesteps, asks the reverse process for a prediction, and minimises the
    MSE between the noise returned by the forward process and that prediction.
    At the end of every training epoch a small batch of images is sampled and
    written to disk, and the diffusion model is saved.

    Args:
        diffusion_model: Object exposing ``forward_diffusion``,
            ``reverse_diffusion``, ``sample``, ``save`` and the sub-modules
            ``forward_diffusion_model`` / ``reverse_diffusion_model``.
        lr: Learning rate passed to AdamW.
        model_filepath: Path handed to ``diffusion_model.save`` after each epoch.
        sampled_dir: Directory that receives ``sampled_<epoch>.jpg`` image grids.
    """

    # Shape of the noise batch used for the end-of-epoch preview samples.
    # NOTE(review): the spatial size must match what the reverse network was
    # trained on — confirm that the input transform produces 64x64 images.
    sample_shape = (4, 3, 64, 64)

    def __init__(self, diffusion_model: DiffusionModel, lr: float, model_filepath: str, sampled_dir: str):
        super().__init__()

        # The model object itself is excluded from the stored hyperparameters;
        # lr / model_filepath / sampled_dir become available via self.hparams.
        self.save_hyperparameters(ignore=["diffusion_model"])

        self.diffusion_model = diffusion_model
        self.num_steps = self.diffusion_model.forward_diffusion_model.num_steps

        self.transforms = Transforms()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Run the reverse diffusion network on ``x`` at timesteps ``t``."""
        return self.diffusion_model.reverse_diffusion_model(x, t)

    def configure_optimizers(self):
        """Optimise only the reverse network's parameters with AdamW."""
        return optim.AdamW(
            self.diffusion_model.reverse_diffusion_model.parameters(),
            lr=self.hparams["lr"],
        )

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction MSE loss for one batch.

        Args:
            batch: ``(inputs, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The scalar loss tensor, also logged under the key ``"loss"``.
        """
        x, _ = batch
        batch_size = x.shape[0]

        # randint's upper bound is exclusive, so t lies in [1, num_steps - 1].
        # NOTE(review): confirm this matches the indexing convention of the
        # forward diffusion module (whether step 0 / num_steps are valid).
        t = torch.randint(1, self.num_steps, size=(batch_size,), device=self.device, dtype=torch.long)

        x_noised, noise = self.diffusion_model.forward_diffusion(x, t)
        predicted_noise = self.diffusion_model.reverse_diffusion(x_noised, t)

        # Conventional (input, target) order; MSE is symmetric so the value is unchanged.
        loss = F.mse_loss(predicted_noise, noise)

        self.log_dict({"loss": loss})

        return loss

    def training_epoch_end(self, outputs):
        """Save a grid of freshly sampled images and the model after each epoch.

        Sampling runs in eval mode under ``torch.no_grad()`` so that no autograd
        state is recorded and train-only layer behaviour is disabled; the
        previous train/eval mode is restored afterwards.
        """
        sampled_dir = self.hparams["sampled_dir"]
        model_filepath = self.hparams["model_filepath"]

        # Make sure the output locations exist before writing to them.
        os.makedirs(sampled_dir, exist_ok=True)
        model_dir = os.path.dirname(model_filepath)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        # save sampled images
        was_training = self.diffusion_model.training
        self.diffusion_model.eval()
        try:
            with torch.no_grad():
                noise = torch.randn(self.sample_shape, device=self.device)
                sampled_images = self.diffusion_model.sample(noise)
        finally:
            self.diffusion_model.train(was_training)

        grid = torchvision.utils.make_grid(sampled_images)

        # NOTE(review): the grid is written as-is; if sample() returns values
        # outside [0, 1] a tensor-to-image transform may be needed here — confirm.
        torchvision.utils.save_image(grid, fp=f"{sampled_dir}/sampled_{self.current_epoch}.jpg")

        # save model
        self.diffusion_model.save(model_filepath)

#### Hyperparameters

In [3]:
# Unique run name: "diffusion_" followed by the first 8 hex digits of a UUID1.
run_suffix = uuid.uuid1().hex[:8]
model_name = "diffusion_" + run_suffix

# Everything a reader might want to tune lives in this one dictionary.
hparams = {
    "lr": 3e-4,                                # learning rate for the optimizer
    "model_filepath": f"models/{model_name}",  # where the model is saved
    "sampled_dir": "sampled",                  # where sampled image grids are written
    "betas": [0.0001, 0.2],                    # [beta_start, beta_end] of the noise schedule
    "num_steps": 256,                          # number of diffusion steps
}

#### Initializing Models

In [4]:
# Forward (noising) process, configured from the hyperparameter dictionary.
forward_diffusion_model = ForwardDiffusionModule(
    num_steps=hparams["num_steps"],
    beta_start=hparams["betas"][0],
    beta_end=hparams["betas"][1],
    schedule_type="cosine",
)

# Reverse (denoising) network, built with its default arguments.
reverse_diffusion_model = UNet()

# Combine both halves and wrap them in the Lightning training module.
diffusion_model = DiffusionModel(forward_diffusion_model, reverse_diffusion_model)
learner = Learner(
    diffusion_model,
    lr=hparams["lr"],
    model_filepath=hparams["model_filepath"],
    sampled_dir=hparams["sampled_dir"],
)

#### Dataset

In [5]:
transforms = Transforms()

# Restrict the dataset to its first ten samples.
subset_idx = range(10)

cifar = torchvision.datasets.CIFAR10(
    root="~/Projekty/datasets/cifar",
    transform=transforms.i2t,
    download=True,
)
dataset = Subset(dataset=cifar, indices=subset_idx)

# Shuffled mini-batches of four images, loaded by eight worker processes.
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=8,
)

Files already downloaded and verified


In [8]:
# Single-epoch run, forced onto the CPU.
trainer = pl.Trainer(accelerator="cpu", max_epochs=1)
trainer.fit(model=learner, train_dataloaders=loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name            | Type           | Params
---------------------------------------------------
0 | diffusion_model | DiffusionModel | 23.3 M
---------------------------------------------------
23.3 M    Trainable params
0         Non-trainable params
23.3 M    Total params
93.331    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]