In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from tqdm import tqdm

In [None]:
class CNNBlock(nn.Module):
    """Single convolutional stage: Conv2d, then LayerNorm, then an activation.

    The convolution is padded with ``kernel_size // 2`` so that, for odd
    kernel sizes, the spatial size of the input is preserved.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        expected_shape: Spatial shape ``(height, width)`` of the feature map;
            LayerNorm normalises over ``(out_channels, *expected_shape)``.
        act: Activation *class* (not instance); it is instantiated with no
            arguments.
        kernel_size: Side length of the convolution kernel.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        expected_shape,
        act=nn.GELU,
        kernel_size=7,
    ):
        super().__init__()

        same_padding = kernel_size // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)
        norm = nn.LayerNorm((out_channels, *expected_shape))
        activation = act()

        # Kept as a Sequential named `net` so state-dict keys stay `net.<idx>.*`.
        self.net = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.net(x)

In [None]:
import torch.nn.init as init


class CNN(nn.Module):
    """Fully convolutional network conditioned on a scalar time ``t``.

    The network is a stack of ``CNNBlock`` stages followed by a plain
    ``Conv2d`` that maps back to ``in_channels``. The time value is turned
    into sinusoidal features, passed through a small MLP and added to the
    output of the first block, broadcast over the spatial dimensions.

    Args:
        in_channels: Channels of the input (and of the output).
        expected_shape: Spatial shape ``(height, width)`` handed to each
            ``CNNBlock`` for its LayerNorm.
        n_hidden: Channel count of each ``CNNBlock``; must be non-empty
            because the time embedding is sized to ``n_hidden[0]``.
        kernel_size: Kernel size of the ``CNNBlock`` convolutions.
        last_kernel_size: Kernel size of the final plain convolution.
        time_embeddings: Number of frequencies; the MLP input has twice this
            many features (a sine and a cosine term per frequency).
        act: Activation class used in the blocks and in the time MLP.
    """

    def __init__(
        self,
        in_channels,
        expected_shape=(28, 28),
        n_hidden=(64, 128, 64),
        kernel_size=7,
        last_kernel_size=3,
        time_embeddings=16,
        act=nn.GELU,
    ) -> None:
        super().__init__()
        last = in_channels

        self.blocks = nn.ModuleList()
        for hidden in n_hidden:
            self.blocks.append(
                CNNBlock(
                    last,
                    hidden,
                    expected_shape=expected_shape,
                    kernel_size=kernel_size,
                    act=act,
                )
            )
            last = hidden

        # The final layer is a regular Conv2d so the output has the input's
        # channel count and is not passed through normalisation/activation.
        self.blocks.append(
            nn.Conv2d(
                last,
                in_channels,
                last_kernel_size,
                padding=last_kernel_size // 2,
            )
        )

        # MLP that lifts the sinusoidal encoding of the scalar `t` to
        # `n_hidden[0]` channels so it can be added to the first feature map.
        self.time_embed = nn.Sequential(
            nn.Linear(time_embeddings * 2, 128),
            act(),
            nn.Linear(128, 128),
            act(),
            nn.Linear(128, 128),
            act(),
            nn.Linear(128, n_hidden[0]),
        )
        # One zero frequency followed by a geometric progression.
        frequencies = torch.tensor(
            [0] + [2 * np.pi * 1.5**i for i in range(time_embeddings - 1)]
        )
        self.register_buffer("frequencies", frequencies)

    def time_encoding(self, t: torch.Tensor) -> torch.Tensor:
        """Encode ``t`` of shape ``(batch,)`` as ``(batch, n_hidden[0], 1, 1)``."""
        # (batch, time_embeddings); computed once and shared by sin and cos.
        angles = t[:, None] * self.frequencies[None, :]
        phases = torch.concat(
            (
                torch.sin(angles),
                torch.cos(angles) - 1,
            ),
            dim=1,
        )

        return self.time_embed(phases)[:, :, None, None]

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input of shape ``(batch, chan, height, width)``.
            t: Time values of shape ``(batch,)``.

        Returns:
            The output of the final convolution.
        """
        embed = self.blocks[0](x)
        # ^ (batch, n_hidden[0], height, width)

        # Add the time information in latent space. The encoding has shape
        # (batch, n_hidden[0], 1, 1) and broadcasts over the spatial domain.
        # This is deliberately out-of-place: an in-place `+=` would overwrite
        # the activation's output, which autograd needs intact for
        # activations that save their result for backward (e.g. nn.ReLU).
        embed = embed + self.time_encoding(t)

        for block in self.blocks[1:]:
            embed = block(embed)

        return embed

In [None]:
def create_sigma_schedule(n_T):
    """Return a sigmoid-shaped schedule with ``n_T`` entries.

    The schedule is ``sigmoid(s)`` evaluated at ``n_T`` evenly spaced points
    ``s`` in ``[-4, 6]``, so it increases monotonically from ``sigmoid(-4)``
    (about 0.018) to ``sigmoid(6)`` (about 0.998) rather than spanning the
    full ``[0, 1]`` interval.
    """
    logits = torch.linspace(-4, 6, steps=n_T)
    return logits.sigmoid()

In [None]:
import torch
import numpy as np
import random
import string
from PIL import Image


class DDPM(nn.Module):
    """Diffusion-style model whose degradation blends in random binary patterns.

    Instead of adding Gaussian noise, an image ``x`` is degraded at step ``t``
    to ``x * (1 - sigma_t) + pattern * sigma_t`` where ``pattern`` is a random
    0/1 "pseudo-QR" image. The wrapped network ``gt`` is trained to predict
    the clean image from the degraded one.

    Args:
        gt: Network called as ``gt(z, t)`` with ``z`` of shape
            ``(batch, chan, height, width)`` and ``t`` of shape ``(batch,)``.
        n_T: Number of degradation steps.
        criterion: Loss module called as ``criterion(prediction, target)``.
    """

    def __init__(
        self,
        gt,
        n_T: int,
        criterion: nn.Module = nn.MSELoss(),
    ) -> None:
        super().__init__()

        self.gt = gt
        self.n_T = n_T

        # Registered as a buffer so the schedule follows the module's device
        # and is stored in the state dict under the key "sigma_t".
        noise_schedule = create_sigma_schedule(n_T)
        self.register_buffer("sigma_t", noise_schedule)

        self.criterion = criterion

    def apply_qr_transformation(self, x, intensity, qr_codes=None):
        """Blend a binary pattern into a batch of images.

        Args:
            x: Images of shape ``(B, C, H, W)``.
            intensity: Blend weight(s); reshaped to ``(-1, 1, 1, 1)`` so one
                value per batch element (or a single scalar) broadcasts.
            qr_codes: Optional pattern broadcastable to ``x``. When omitted a
                fresh random pattern is drawn, which is the original
                behaviour. Passing it lets callers degrade the same image at
                two intensities with the *same* pattern.

        Returns:
            ``x * (1 - intensity) + qr_codes * intensity``.
        """
        batch_size, channels, height, width = x.size()
        if qr_codes is None:
            qr_codes = self.create_batch_random_qr(
                batch_size, (channels, height, width)
            )
        # Make sure intensity is broadcastable over the batch
        intensity = intensity.view(-1, 1, 1, 1).to(x.device)
        qr_codes = qr_codes.to(x.device)

        return x * (1 - intensity) + qr_codes * intensity

    def forward(self, x):
        """Return the training loss for a batch of clean images ``x``."""
        # Sample a random time step t for each batch element
        t = torch.randint(0, self.n_T, (x.shape[0],), device=x.device)

        # Blend weight at the sampled steps
        sigma_t = self.sigma_t[t]

        # Degrade every image with its own random pattern
        z_t = self.apply_qr_transformation(x, sigma_t)

        # The network predicts the original image from the degraded one;
        # time is passed as a fraction of n_T.
        preds = self.gt(z_t, t / self.n_T)

        # torch losses take (input, target). For the default MSELoss the order
        # does not change the value, but it matters for asymmetric criteria.
        return self.criterion(preds, x)

    def create_batch_random_qr(self, batch_size, image_size):
        """Draw random binary patterns.

        Args:
            batch_size: Number of patterns.
            image_size: ``(channels, height, width)``.

        Returns:
            Float tensor of shape ``(batch_size, channels, height, width)``
            with values in ``{0.0, 1.0}``. Each pattern is a single
            ``(height, width)`` grid replicated across all channels.
        """
        channels, height, width = image_size

        # One grayscale pattern per batch element, each pixel 0 or 255 with
        # equal probability, created directly with a channel axis.
        pattern = np.random.choice(
            [0, 255], size=(batch_size, 1, height, width), p=[0.5, 0.5]
        ).astype(np.uint8)

        # Replicate along the channel axis. The earlier implementation
        # repeated along the last axis and then inserted another axis, which
        # produced (B, 1, H, W, C) whenever channels > 1.
        if channels > 1:
            pattern = np.repeat(pattern, channels, axis=1)

        # Normalise to [0, 1]
        return torch.from_numpy(pattern).float() / 255.0

    def sample(self, n_sample, size, device):
        """Generate ``n_sample`` images by iteratively undoing the degradation.

        Args:
            n_sample: Number of images to generate.
            size: ``(channels, height, width)`` of each image.
            device: Device on which sampling runs.

        Returns:
            The network's prediction of the clean images from the final
            iterate, shape ``(n_sample, *size)``.
        """
        # Start from a fully degraded image, i.e. the pattern itself.
        qr_codes = self.create_batch_random_qr(n_sample, size).to(device)
        z_t = qr_codes
        _one = torch.ones(n_sample, device=device)

        for t in reversed(range(1, self.n_T)):
            x_0_pred = self.gt(z_t, (t / self.n_T) * _one)

            # Move from intensity sigma_t to sigma_{t-1}:
            #   z_{t-1} = z_t - D(x0_pred, t) + D(x0_pred, t-1)
            # Both degradations must use the same pattern (the one z_t was
            # built from). Drawing two independent patterns here would make
            # the pattern terms fail to cancel and add noise at every step.
            z_t = (
                z_t
                - self.apply_qr_transformation(x_0_pred, self.sigma_t[t], qr_codes)
                + self.apply_qr_transformation(
                    x_0_pred, self.sigma_t[t - 1], qr_codes
                )
            )

        # Final step: predict the clean image from the last iterate. Computing
        # it here (rather than reusing the prediction from the t=1 step) uses
        # the last update and also works when n_T == 1.
        return self.gt(z_t, 0.0 * _one)

In [None]:
# Convert PIL images to tensors, then subtract a mean of 0.5 with unit std
# (so the std division is a no-op and pixel values are only shifted).
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (1.0)),
    ]
)

# MNIST training split; downloaded into ./data on first use.
dataset = MNIST("./data", train=True, download=True, transform=tf)

# drop_last=True keeps every batch at exactly 128 images.
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

In [None]:
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau


# Backbone network that predicts the clean image.
# Other channel layouts that have been tried / suggested:
#   testing:        (16, 32, 32, 16)
#   more capacity:  (64, 128, 256, 128, 64)
gt = CNN(
    in_channels=1,
    expected_shape=(28, 28),
    n_hidden=(16, 32, 64, 32, 16),
    act=nn.GELU,
)

ddpm = DDPM(gt=gt, n_T=1000)

optim = torch.optim.Adam(ddpm.parameters(), lr=1e-4, weight_decay=1e-5)

# Cut the learning rate by 10x when the monitored loss has not improved for
# 5 scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# it is still accepted by the installed version.
scheduler = ReduceLROnPlateau(
    optim,
    mode="min",
    factor=0.1,
    patience=5,
    verbose=True,
)

In [None]:
accelerator = Accelerator()

# Wrap the model, optimizer and dataloader with `accelerator.prepare` so that
# HuggingFace Accelerate takes care of device placement for all three.
# The names are rebound to the prepared objects, so every later cell uses
# the prepared versions. The LR scheduler is not passed through `prepare`.
ddpm, optim, dataloader = accelerator.prepare(ddpm, optim, dataloader)

In [None]:
# Display the device Accelerate selected (cell output only; nothing is stored).
accelerator.device

In [None]:
from torchvision.utils import save_image, make_grid

# Smoke test of the untrained model: one forward pass, one degraded batch and
# one round of sampling, each written to disk for visual inspection.

# save_image does not create missing parent directories.
os.makedirs("./contents", exist_ok=True)

# A single batch is enough here.
x, _ = next(iter(dataloader))

with torch.no_grad():
    # Forward pass only to check that the loss can be computed; the value is
    # intentionally discarded.
    ddpm(x)

    # One random pattern shifted by -0.5, i.e. values in {-0.5, 0.5}.
    eps = ddpm.create_batch_random_qr(1, (1, 28, 28)) - 0.5
    save_image(eps, "random_qr.png")

    # Degrade every image in the batch at its own random step.
    t = torch.randint(0, ddpm.n_T, (x.shape[0],), device=x.device)
    sigma_t = ddpm.sigma_t[t]
    transformed_x = ddpm.apply_qr_transformation(x, sigma_t)
    save_image(transformed_x, "./contents/transformed_image.png", nrow=16)

    # Draw 16 samples and save them as a 4x4 grid.
    xh = ddpm.sample(16, (1, 28, 28), accelerator.device)
    save_image(xh, "./contents/grid_image.png", nrow=4)

In [None]:
def reformat(dataset):
    """Resize images to 299x299 and replicate them to three channels.

    Args:
        dataset: Iterable of image tensors. Assumes each item is a
            single-channel ``(1, H, W)`` tensor - TODO confirm against callers.

    Returns:
        All processed images concatenated along dim 0, moved to the CPU.
    """
    target_size = (299, 299)
    rgb_images = [
        # repeat(1, 3, 1, 1) has one more entry than a 3-D tensor has dims,
        # so it also adds a leading dimension of size 1 to concatenate along.
        TF.resize(img, size=target_size).repeat(1, 3, 1, 1)
        for img in dataset
    ]
    return torch.cat(rgb_images, dim=0).cpu()

In [None]:
from torchvision.transforms import functional as TF

# 'dataloader' is the DataLoader for MNIST
# Take one batch of real digits as the FID reference set and make sure it
# lives on the same device as the generated images.
real_mnist, _ = next(iter(dataloader))
real_mnist = real_mnist.to(accelerator.device)

# Resize each image to 299x299 and repeat the grayscale channel three times;
# the repeat also adds a leading dimension of size 1 to each image.
processed_mnist = [
    TF.resize(mnist, size=(299, 299)).repeat(1, 3, 1, 1) for mnist in real_mnist
]

# Join the per-image tensors into one batch.
real_images = torch.cat(processed_mnist, dim=0)

In [None]:
from torchvision.models import inception_v3, Inception_V3_Weights
from scipy.linalg import sqrtm


def calculate_fid_custom(real_images, gen_images):
    """Compute a Frechet distance between Inception-v3 features of two batches.

    Args:
        real_images: Batch of reference images, in the layout Inception-v3
            accepts (the callers in this notebook pass ``(N, 3, 299, 299)``).
        gen_images: Batch of generated images in the same layout.

    Returns:
        ``||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r S_g))`` where
        ``mu`` and ``S`` are the feature mean and covariance of each batch.

    NOTE(review): with few images the feature covariances are rank-deficient,
    so the value is a rough indicator rather than a reliable FID estimate.
    """
    # Build the feature extractor once and reuse it. Loading the pretrained
    # weights on every call is expensive, and the training loop calls this
    # function once per epoch.
    inception_model = getattr(calculate_fid_custom, "_feature_extractor", None)
    if inception_model is None:
        inception_model = inception_v3(weights=Inception_V3_Weights.DEFAULT)
        inception_model.eval()
        # Replace the classifier head so the model returns the features that
        # would normally be fed to it.
        inception_model.fc = torch.nn.Identity()
        calculate_fid_custom._feature_extractor = inception_model

    # Run on whatever device the generated images are on, so CPU and
    # accelerator inputs both work.
    device = gen_images.device
    inception_model.to(device)

    # Extract features
    with torch.no_grad():
        gen_features = inception_model(gen_images).cpu().numpy()
        real_features = inception_model(real_images.to(device)).cpu().numpy()

    # Mean and covariance of the features (rows are samples).
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    mu_gen = np.mean(gen_features, axis=0)
    sigma_gen = np.cov(gen_features, rowvar=False)

    # Calculate the FID score
    ssdiff = np.sum((mu_real - mu_gen) ** 2.0)
    covmean = sqrtm(sigma_real.dot(sigma_gen))

    # A singular product can yield inf/nan; retry with a small diagonal
    # offset so the result stays finite.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma_real.shape[0]) * 1e-6
        covmean = sqrtm((sigma_real + offset).dot(sigma_gen + offset))

    # Numerical error can leave a small imaginary part; drop it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma_real + sigma_gen - 2.0 * covmean)

    return fid

In [None]:
n_epoch = 100
losses = []
fid_scores = []

# save_image does not create missing parent directories.
os.makedirs("./contents", exist_ok=True)

for i in range(n_epoch):
    ddpm.train()

    pbar = tqdm(dataloader)  # Wrap our loop with a visual progress bar
    for x, _ in pbar:
        optim.zero_grad()

        loss = ddpm(x)

        # `accelerator.backward` instead of `loss.backward()` so Accelerate
        # can handle the backward pass for the configured setup.
        accelerator.backward(loss)
        optim.step()

        losses.append(loss.item())
        # Running average over the most recent 100 batches. The previous
        # slice start `min(len(losses) - 100, 0)` always selected the whole
        # history, so it was never a window.
        avg_loss = np.average(losses[-100:])
        pbar.set_description(
            f"loss: {avg_loss:.3g}"
        )  # Show running average of loss in progress bar

    # Step the plateau scheduler once per epoch on the running average.
    scheduler.step(avg_loss)

    ddpm.eval()
    with torch.no_grad():
        xh = ddpm.sample(
            16, (1, 28, 28), accelerator.device
        )  # Can get device explicitly with `accelerator.device`
        # Save samples to `./contents` directory
        save_image(xh, f"./contents/ddpm_sample_{i:04d}.png")

        # Compare the 16 samples against the first 16 reference images.
        fid = calculate_fid_custom(real_images[:16].cpu(), reformat(xh).cpu())
        fid_scores.append(fid)

        # Checkpoint the model and metric histories after every epoch; each
        # file is overwritten, so only the latest state is kept.
        torch.save(ddpm.state_dict(), f"./ddpm_mnist.pth")
        torch.save(losses, "./ddpm_losses.pt")
        torch.save(fid_scores, "./fid_scores.pt")

In [None]:
# Build a second DDPM wrapper and load a saved checkpoint into it on the CPU.
# NOTE(review): `ddpm1` is constructed with the same `gt` object that `ddpm`
# was built from, so loading this state dict also overwrites the weights used
# by `ddpm` - confirm this sharing is intended.
# NOTE(review): the path differs from the "./ddpm_mnist.pth" file written by
# the training loop - confirm which checkpoint should be loaded.
ddpm1 = DDPM(gt=gt, n_T=1000)
ddpm1.load_state_dict(torch.load("../pure_qr.pth", map_location=torch.device("cpu")))

In [None]:
import matplotlib.pyplot as plt

# save_image does not create missing parent directories.
os.makedirs("./contents", exist_ok=True)

# Sampling is inference only. Without no_grad, autograd would record every
# network call across all sampling steps and keep their activations alive.
with torch.no_grad():
    xh = ddpm.sample(16, (1, 28, 28), device=accelerator.device)

save_image(xh, "./contents/ddpm_sample_final.png")

# Show the first sample's single channel.
plt.imshow(xh[0, 0].cpu().detach().float())