In [None]:
from Diffusion import Diffusion
from UNet import UNet
from params import params
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from ImageDataset import ImageDataset

In [None]:
class DiffusionModel(pl.LightningModule):
    """LightningModule that trains a UNet to predict the noise added by ``Diffusion``.

    Each training step samples timesteps ``t``, noises the batch with
    ``Diffusion.apply_noise``, and regresses the UNet's output onto the returned
    noise with mean-squared error.

    Args:
        lr: Adam learning rate. Defaults to 1e-3, the value that was previously
            hard-coded in ``configure_optimizers``.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        # NOTE(review): the module has not been moved to an accelerator yet, so
        # ``self.device`` is presumably the default CPU device here. If UNet caches
        # this device for tensors it creates internally, they may land on the
        # wrong device when training on GPU. Confirm against UNet.
        self.model = UNet(device=self.device)
        self.diffusion = Diffusion()
        self.lr = lr

    def forward(self, x, t):
        """Return the UNet's noise prediction for inputs ``x`` at timesteps ``t``."""
        return self.model(x, t)

    @torch.no_grad()
    def sample_images(self, num_samples):
        """Generate ``num_samples`` samples via ``Diffusion.sample`` using the UNet.

        Runs with autograd disabled. Sampling is inference-only, and this method
        is also called from training callbacks. Tracking gradients would retain
        the graph of every forward pass made inside ``Diffusion.sample``.
        """
        return self.diffusion.sample(self.model, num_samples)

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction MSE loss for one batch and log it as ``train_loss``."""
        # The batch itself is used as the image tensor (the dataset yields no labels here).
        x = batch
        t = self.diffusion.sample_timesteps(x.shape[0]).to(self.device)
        noisy_x, noise = self.diffusion.apply_noise(x, t)
        # apply_noise's output device is not visible from here, so move explicitly.
        noisy_x = noisy_x.to(self.device)
        noise = noise.to(self.device)

        noise_hat = self.model(noisy_x, t)
        loss = F.mse_loss(noise_hat, noise)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all module parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
# Dataset configuration: source directory and the max_image_count passed to
# ImageDataset (presumably an upper bound on images loaded — confirm in ImageDataset).
PORTRAITS_DIR = '../portraits'
MAX_IMAGE_COUNT = 1000

ds = ImageDataset(
    PORTRAITS_DIR,
    max_image_count=MAX_IMAGE_COUNT,
    img_size=params.img_size,
)

In [None]:
# Shuffled mini-batches; num_workers=0 means data loading happens in the main process.
BATCH_SIZE = 4

loader = torch.utils.data.DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

In [None]:
import os

import matplotlib.pyplot as plt
import torch
from pytorch_lightning.callbacks import Callback


class DiffusionPlotCallback(Callback):
    """Periodically sample from the diffusion model and save a strip of images.

    On each plotted training batch, calls ``pl_module.sample_images(num_images)``
    and takes element ``[0]`` of the result. It maps every entry through
    ``inverse_transform`` and saves ``num_images`` evenly spaced entries side by
    side to ``<output_dir>/<epoch counter>_<batch_idx>.png``.

    Args:
        epochs_to_plot: Stored on the instance but not used by this callback.
            NOTE(review): intended semantics (interval vs. count) unclear; confirm.
        inverse_transform: Callable applied to each sampled image before ``imshow``.
        plot_every_n_batches: Plot on batches whose ``batch_idx`` is divisible by
            this value. The default of 1 keeps the original every-batch behaviour.
        num_images: Number of images in each saved strip (default 10, as before).
        output_dir: Directory for the saved PNGs; created if missing.

    Raises:
        ValueError: If ``plot_every_n_batches`` or ``num_images`` is less than 1.
    """

    def __init__(self, epochs_to_plot, inverse_transform,
                 plot_every_n_batches=1, num_images=10, output_dir="plots"):
        super().__init__()
        if plot_every_n_batches < 1:
            raise ValueError("plot_every_n_batches must be >= 1")
        if num_images < 1:
            raise ValueError("num_images must be >= 1")
        # Number of completed training epochs; used as the filename prefix.
        # batch_idx restarts at 0 each epoch, so later epochs would otherwise
        # overwrite earlier plots.
        self.cont = 0
        self.inverse_transform = inverse_transform
        self.epochs_to_plot = epochs_to_plot
        self.num_images = num_images
        self.plot_every_n_batches = plot_every_n_batches
        self.output_dir = output_dir

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Sample, render and save an image strip on every ``plot_every_n_batches``-th batch."""
        if batch_idx % self.plot_every_n_batches != 0:
            return

        # Sampling is inference-only; don't build an autograd graph for it.
        with torch.no_grad():
            # NOTE(review): assumes element [0] of sample_images' result is an
            # indexable sequence of images (e.g. a denoising trajectory) — confirm.
            samples = pl_module.sample_images(self.num_images)[0]
        imgs = [self.inverse_transform(img) for img in samples]

        # squeeze=False keeps axs 2-D, so indexing works even when num_images == 1.
        fig, axs = plt.subplots(1, self.num_images,
                                figsize=(int(self.num_images * 1.7), 3),
                                squeeze=False)
        # Plot images on intervals of len(imgs) / num_images.
        for i, ax in enumerate(axs[0]):
            ax.imshow(imgs[int(i * len(imgs) / self.num_images)])
            ax.axis("off")

        os.makedirs(self.output_dir, exist_ok=True)
        fig.savefig(os.path.join(self.output_dir, f"{self.cont}_{batch_idx}.png"))
        # Close this specific figure so repeated plotting doesn't accumulate memory.
        plt.close(fig)

    def on_train_epoch_end(self, trainer, pl_module):
        """Advance the epoch counter used in saved filenames."""
        self.cont += 1

In [None]:
# Single-GPU, single-epoch training run. `accelerator`/`devices` replace the
# `gpus` argument, which was deprecated in PyTorch Lightning 1.7 and removed in 2.0.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1,
    callbacks=[DiffusionPlotCallback(epochs_to_plot=1, inverse_transform=ds.inverse_transform)],
)
model = DiffusionModel()
trainer.fit(model, loader)