In [1]:
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import HTML
from matplotlib.animation import FuncAnimation, PillowWriter
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

In [2]:
class CustomDataset(Dataset):
    """Map-style dataset backed by two ``.npy`` files: one of images, one of labels.

    :param images: path to the ``.npy`` file with the image array; axis 0 indexes samples
    :param labels: path to the ``.npy`` file with the label array; axis 0 indexes samples
    :param transforms: callable applied to every image in ``__getitem__``; a falsy
        value (e.g. ``None``) returns the raw array slice unchanged
    :raises ValueError: if the two arrays do not hold the same number of samples
    """

    def __init__(self, images: str, labels: str, transforms):
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
        # execute arbitrary code -- only point this at trusted files.
        self.images = np.load(images, allow_pickle=True)
        self.labels = np.load(labels, allow_pickle=True)
        if self.images.shape[0] != self.labels.shape[0]:
            raise ValueError(
                "images and labels must contain the same number of samples, got "
                f"{self.images.shape[0]} and {self.labels.shape[0]}"
            )
        # Print the whole shape tuple so arrays of any rank are reported; a fixed
        # number of %d placeholders would raise for e.g. grayscale images or 1-D labels.
        print(f"Images' shape: {tuple(self.images.shape)}")
        print(f"Labels' shape: {tuple(self.labels.shape)}")
        self.transforms = transforms

    def __len__(self):
        """Number of samples (length of the image array's first axis)."""
        return self.images.shape[0]

    def __getitem__(self, index):
        """Return ``(image, label)``; the label is always converted to an int64 tensor."""
        image = self.images[index]
        label = torch.tensor(self.labels[index], dtype=torch.int64)
        if self.transforms:
            image = self.transforms(image)
        return (image, label)

# Per-image preprocessing pipeline handed to CustomDataset.
# Normalize computes (x - 0.5) / 0.5; the single-element mean/std tuples are
# applied to every channel. denormalize_and_clip (x * 0.5 + 0.5) is the inverse.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [3]:
class ResidualConvBlock(nn.Module):
    """Two 3x3 conv -> BatchNorm -> GELU stages with an optional residual sum.

    Both convolutions use stride 1 and padding 1, so height and width are unchanged.

    :param in_channels: channels of the input tensor
    :param out_channels: channels produced by both conv stages
    :param skip: when True, ``forward`` adds a shortcut of the input to the output;
        the shortcut is a 1x1 convolution if the channel counts differ, otherwise
        the identity
    """

    def __init__(self, in_channels: int, out_channels: int, skip: bool = False) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip = skip

        # The shortcut is always built (even when skip is False) and is registered
        # before the conv stages.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        self.conv1 = self._conv_stage(in_channels, out_channels)
        self.conv2 = self._conv_stage(out_channels, out_channels)

    @staticmethod
    def _conv_stage(n_in: int, n_out: int) -> nn.Sequential:
        """Build one size-preserving 3x3 conv -> BatchNorm -> GELU stage."""
        return nn.Sequential(
            nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(n_out),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv stages and, if ``skip`` is set, add the shortcut of ``x``."""
        features = self.conv2(self.conv1(x))
        if not self.skip:
            return features
        return self.shortcut(x) + features


class UnetDown(nn.Module):
    """Encoder stage: two ResidualConvBlocks (no skip) followed by 2x2 max pooling.

    :param in_channels: channels of the incoming feature map
    :param out_channels: channels of the feature map this stage produces
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        stages = [
            ResidualConvBlock(in_channels=in_channels, out_channels=out_channels),
            ResidualConvBlock(in_channels=out_channels, out_channels=out_channels),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv blocks, then pool."""
        pooled = self.layers(x)
        return pooled


class UnetUp(nn.Module):
    """Decoder stage: concat with a skip tensor, 2x nearest upsample, 1x1 conv, two conv blocks.

    :param in_channels: channel count of the concatenated tensor, i.e. the channels
        of ``x`` plus the channels of ``skip_tensor`` passed to ``forward``
    :param out_channels: channels of the feature map this stage produces
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            # the 1x1 convolution maps the concatenated channels to out_channels
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            ResidualConvBlock(in_channels=out_channels, out_channels=out_channels),
            ResidualConvBlock(in_channels=out_channels, out_channels=out_channels),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor, skip_tensor: torch.Tensor) -> torch.Tensor:
        """Concatenate ``x`` and ``skip_tensor`` along the channel axis and decode."""
        merged = torch.cat((x, skip_tensor), dim=1)
        return self.layers(merged)


class EmbedFC(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) that embeds a flat feature vector.

    :param input_dim: number of features per sample; inputs are reshaped to
        ``(-1, input_dim)`` before the first linear layer
    :param embed_dim: width of the hidden layer and of the returned embedding
    """

    def __init__(self, input_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim

        self.layers = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and return a tensor of shape ``(-1, embed_dim)``.

        ``reshape`` (instead of ``view``) also accepts non-contiguous inputs, and the
        cast to the layer dtype lets integer inputs (such as the int64 labels that
        CustomDataset yields) pass through ``nn.Linear`` instead of raising.
        """
        x = x.reshape(-1, self.input_dim).to(self.layers[0].weight.dtype)
        return self.layers(x)

In [4]:
class ContextUnet(nn.Module):
    """Small U-Net that predicts noise from an image, a timestep and an optional context.

    Shapes for an input of ``(B, in_channels, size, size)`` with ``h = hidden_dim``:

    ========== =============================
    init_conv  ``(B, h, size, size)``
    down1      ``(B, h, size/2, size/2)``
    down2      ``(B, 2h, size/4, size/4)``
    to_vec     ``(B, 2h, 1, 1)``
    up0        ``(B, 2h, size/4, size/4)``
    up1        ``(B, h, size/2, size/2)``
    up2        ``(B, h, size, size)``
    out        ``(B, in_channels, size, size)``
    ========== =============================

    :param in_channels: channels of the input image (and of the predicted noise)
    :param hidden_dim: base feature width ``h``; ``2 * h`` and ``h`` must both be
        divisible by 8 because of the GroupNorm layers
    :param context_feature_dim: length of the context vector
    :param size: side length of the (square) input images; must be a positive
        multiple of 4 because the encoder halves the resolution twice
    :raises ValueError: if ``size`` is not a positive multiple of 4
    """

    def __init__(
        self,
        in_channels: int,
        hidden_dim: int = 256,
        context_feature_dim: int = 10,
        size: int = 28,
    ) -> None:
        super().__init__()
        if size <= 0 or size % 4 != 0:
            raise ValueError(f"size must be a positive multiple of 4, got {size}")

        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.context_feature_dim = context_feature_dim
        self.size = size  # assume h == w. must be divisible by 4, so 28,24,20,16...

        # Spatial side of the deepest feature map (after two 2x poolings). It drives
        # both the pooling kernel and the matching upsample factor below; hard-coding
        # 4 here only worked for size == 16.
        bottleneck_size = size // 4

        self.init_conv = ResidualConvBlock(
            in_channels=in_channels,
            out_channels=hidden_dim,
            skip=True
        )

        self.down1 = UnetDown(hidden_dim, hidden_dim)  # (B, h, size/2, size/2)
        self.down2 = UnetDown(hidden_dim, 2 * hidden_dim)  # (B, 2h, size/4, size/4)

        # collapse the deepest feature map to a (B, 2h, 1, 1) latent vector
        self.to_vec = nn.Sequential(
            nn.AvgPool2d(bottleneck_size),
            nn.GELU(),
        )

        # embed timestep
        self.time_embed_1 = EmbedFC(1, 2 * hidden_dim)
        self.time_embed_2 = EmbedFC(1, 1 * hidden_dim)

        # embed context label
        self.context_embed_1 = EmbedFC(context_feature_dim, 2 * hidden_dim)
        self.context_embed_2 = EmbedFC(context_feature_dim, hidden_dim)

        # expand the latent vector back to the spatial size of down2's output
        self.up0 = nn.Sequential(
            nn.Upsample(scale_factor=bottleneck_size),
            nn.Conv2d(
                in_channels=2 * hidden_dim,
                out_channels=2 * hidden_dim,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.GroupNorm(
                num_groups=8,
                num_channels=2 * hidden_dim
            ),
            nn.ReLU(),
        )

        # in_channels count the decoder features plus the concatenated encoder skip
        self.up1 = UnetUp(
            in_channels=4 * hidden_dim,
            out_channels=hidden_dim
        )
        self.up2 = UnetUp(
            in_channels=2 * hidden_dim,
            out_channels=hidden_dim
        )

        # the head consumes up2's output concatenated with the init_conv features
        self.out = nn.Sequential(
            nn.Conv2d(
                in_channels=2 * hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_dim
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=hidden_dim,
                out_channels=self.in_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
        )

    def forward(
        self,
        image: torch.Tensor,
        timesteps: torch.Tensor,
        context_label: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict the noise contained in ``image``.

        :param image: ``(B, in_channels, size, size)`` noisy images
        :param timesteps: one value per sample; reshaped to ``(B, 1)`` by the time
            embeddings (the notebook passes ``t / timesteps``)
        :param context_label: ``(B, context_feature_dim)`` context vectors; when
            ``None`` an all-zero context is used
        :return: tensor with the same shape as ``image``
        """
        x = self.init_conv(image)

        down1 = self.down1(x)
        down2 = self.down2(down1)

        latent = self.to_vec(down2)

        if context_label is None:
            context_label = torch.zeros(x.shape[0], self.context_feature_dim, device=x.device)

        # context embeddings scale the decoder features channel-wise ...
        context_embed_1 = self.context_embed_1(context_label).view(-1, self.hidden_dim * 2, 1, 1)
        context_embed_2 = self.context_embed_2(context_label).view(-1, self.hidden_dim, 1, 1)
        # ... and time embeddings shift them
        time_embed_1 = self.time_embed_1(timesteps).view(-1, self.hidden_dim * 2, 1, 1)
        time_embed_2 = self.time_embed_2(timesteps).view(-1, self.hidden_dim, 1, 1)

        up1 = self.up0(latent)
        up2 = self.up1(context_embed_1 * up1 + time_embed_1, down2)
        up3 = self.up2(context_embed_2 * up2 + time_embed_2, down1)

        out = self.out(torch.cat([up3, x], dim=1))
        return out


In [5]:
# --- diffusion schedule ---
timesteps = 500  # number of noise levels
beta_start, beta_end = 1e-4, 0.02  # endpoints of the linear beta schedule

# --- model / data ---
hidden_dim = 64  # base feature width of the U-Net
context_dim = 5  # length of the context vector fed to the model
size = 16  # side length of the square training images

# --- optimisation ---
batch_size = 100
n_epoch = 32
lr = 1e-3  # initial learning rate (decayed linearly in the training loop)

# --- runtime ---
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
def denormalize_and_clip(image: torch.Tensor) -> torch.Tensor:
    """Map values from the normalized range back to [0, 1] and clip to that range.

    Computes ``image * 0.5 + 0.5`` (the inverse of ``Normalize(0.5, 0.5)``) and clips
    the result to ``[0, 1]``.

    The input is NOT modified: a new object is returned. The previous in-place
    ``*=`` / ``+=`` silently altered the caller's data (including arrays reached
    through a view such as ``np.moveaxis``), so calling it twice on the same data
    gave different results. Works for any object supporting ``*``, ``+`` and
    ``.clip`` -- in this notebook it is called with a NumPy array.

    :param image: normalized image data
    :return: new array/tensor of the same shape with values in ``[0, 1]``
    """
    return (image * 0.5 + 0.5).clip(0, 1)


class NoiseScheduler(nn.Module):
    """Linear-beta DDPM noise schedule with forward noising and one reverse step.

    Registers three buffers of length ``timesteps`` (so they follow ``.to(device)``):
    ``beta`` (linearly spaced from ``beta_start`` to ``beta_end``),
    ``alpha = 1 - beta`` and ``alpha_bar = cumprod(alpha)``.

    :param timesteps: number of noise levels
    :param beta_start: first value of the beta schedule
    :param beta_end: last value of the beta schedule
    """

    def __init__(
        self,
        timesteps: int = 24,
        beta_start: float = 1e-4,
        beta_end: float = 0.6
    ):
        super().__init__()
        self.timesteps = timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        beta = torch.linspace(beta_start, beta_end, timesteps)
        alpha = 1. - beta
        alpha_bar = torch.cumprod(alpha, 0)

        self.register_buffer('beta', beta)
        self.register_buffer('alpha', alpha)
        self.register_buffer('alpha_bar', alpha_bar)

    @staticmethod
    def _per_sample(values: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Reshape per-sample coefficients to ``(-1, 1, ..., 1)`` so they broadcast over ``like``."""
        return values.reshape((-1,) + (1,) * (like.dim() - 1))

    def add_noise(self, x: torch.Tensor, t: torch.Tensor, noise: torch.Tensor):
        """
        Noises ``x`` directly to level ``t`` using the closed form
        ``sqrt(alpha_bar[t]) * x + sqrt(1 - alpha_bar[t]) * noise``
        (this is the cumulative noise up to ``t``, not a single incremental step).

        :param x: image we are adding noise to
        :param t: step number, 0 indexed (0 <= t < steps); one entry per sample
        :param noise: noise tensor with the same shape as ``x``
        :return: image with noise added
        """
        alpha_bar = self._per_sample(self.alpha_bar[t], x)
        return torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * noise

    def sample_prev_step(self, xt, t, pred_noise):
        """
        Draws the previous (less noisy) sample from ``xt`` given the predicted noise.

        Coefficients are gathered per sample and reshaped for broadcasting, exactly
        like ``add_noise`` does, so ``t`` may hold a single step shared by the whole
        batch or one step per sample.

        :param xt: noisy samples at step ``t``
        :param t: integer step index/indices, 0 indexed; one element or one per sample
        :param pred_noise: the model's noise prediction for ``xt``
        :return: tensor shaped like ``xt``
        """
        t = torch.as_tensor(t, dtype=torch.long, device=self.beta.device).reshape(-1)

        beta = self._per_sample(self.beta[t], xt)
        alpha = self._per_sample(self.alpha[t], xt)
        alpha_bar = self._per_sample(self.alpha_bar[t], xt)

        # At t == 0 there is no previous level: treat alpha_bar[-1] as 1 (variance 0)
        # instead of letting the index t - 1 == -1 wrap around to the last entry.
        is_first = self._per_sample(t == 0, xt)
        alpha_bar_prev = self._per_sample(self.alpha_bar[(t - 1).clamp(min=0)], xt)
        alpha_bar_prev = torch.where(is_first, torch.ones_like(alpha_bar_prev), alpha_bar_prev)

        # no fresh noise is injected for samples that are at step 0
        z = torch.randn_like(xt)
        z = torch.where(is_first, torch.zeros_like(z), z)

        mean = (1 / torch.sqrt(alpha)) * (xt - (beta / torch.sqrt(1 - alpha_bar)) * pred_noise)
        var = ((1 - alpha_bar_prev) / (1 - alpha_bar)) * beta
        sigma = torch.sqrt(var)

        x = mean + sigma * z
        return x

In [7]:
# Noise schedule; its beta/alpha buffers are moved to the compute device.
noise_scheduler = NoiseScheduler(
    timesteps=timesteps,
    beta_start=beta_start,
    beta_end=beta_end,
)
noise_scheduler = noise_scheduler.to(device)

# Noise-prediction network for 3-channel images.
nn_model = ContextUnet(
    in_channels=3, hidden_dim=hidden_dim, context_feature_dim=context_dim, size=size
)
nn_model = nn_model.to(device)

In [12]:
# Sprite images and their label vectors, preprocessed with `transform`.
dataset = CustomDataset(
    images="sprites_1788_16x16.npy",
    labels="sprite_labels_nc_1788_16x16.npy",
    transforms=transform,
)
# Batches are reshuffled on every pass; loading runs in one worker process.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
optim = torch.optim.AdamW(params=nn_model.parameters(), lr=lr)

Images' shape: (89400, 16, 16, 3)
Labels' shape: (89400, 5)


In [None]:
# Train the network to predict the noise that was mixed into each image.
criterion = nn.MSELoss()
nn_model.train()

for epoch in range(n_epoch):

    # Linear learning-rate decay: lr at epoch 0, shrinking towards 0 at the end.
    optim.param_groups[0]["lr"] = lr * (1 - epoch / n_epoch)

    losses = []
    pbar = tqdm(dataloader, mininterval=2)
    # Labels are discarded, so the model is called without a context and falls
    # back to its all-zero context vector.
    for x, _ in pbar:
        x = x.to(device)

        # Noise each image to a random level t in [1, timesteps).
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps, (x.shape[0],), device=device)
        x_pert = noise_scheduler.add_noise(x, t, noise)

        # The network receives the timestep scaled to (0, 1).
        optim.zero_grad()
        pred_noise = nn_model(x_pert, t / timesteps)
        loss = criterion(pred_noise, noise)
        loss.backward()
        optim.step()

        losses.append(loss.item())

    # Report every fourth epoch and always after the final one.
    if epoch == n_epoch - 1 or epoch % 4 == 0:
        print(f"Epoch {epoch} - Loss {np.mean(losses)}")

In [14]:
@torch.no_grad()
def sample_ddpm(n_sample: int, channels: int = 3, size: int = 16, save_rate: int = 20):
    """Generate images by running the reverse diffusion loop from pure noise.

    Uses the notebook-level ``nn_model``, ``noise_scheduler``, ``timesteps`` and
    ``device``.

    :param n_sample: number of images to generate
    :param channels: channels per image
    :param size: side length of the square images
    :param save_rate: keep every ``save_rate``-th step as an intermediate frame
        (the first step and the last seven steps are always kept)
    :return: ``(samples, intermediate)`` -- the final tensor of shape
        ``(n_sample, channels, size, size)`` on ``device`` and a NumPy array of the
        stored frames with shape ``(n_frames, n_sample, channels, size, size)``
    """
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, channels, size, size).to(device)

    # The loop runs from timesteps - 1 down to 1 (step 0 is never visited).
    # NOTE(review): NoiseScheduler.sample_prev_step only suppresses the added noise
    # at t == 0, so the last iteration (i == 1) still injects a small amount of
    # noise -- confirm whether that is intended.
    first_step = timesteps - 1

    # array to keep track of generated steps for plotting
    intermediate = []
    for i in range(first_step, 0, -1):
        print(f'Sampling timestep {i:3d}', end='\r')

        # the model takes the timestep scaled by the schedule length
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t)
        tensor_i = torch.tensor(i, device=device).view(1,)
        samples = noise_scheduler.sample_prev_step(samples, tensor_i, eps)

        # `i == timesteps` could never be true for this range, so the very first
        # denoising step was silently dropped; compare against the real first step.
        if i % save_rate == 0 or i == first_step or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    if intermediate:
        intermediate = np.stack(intermediate)
    else:
        # no denoising step ran (timesteps <= 1): return an empty frame stack
        intermediate = np.empty((0, n_sample, channels, size, size), dtype=np.float32)
    return samples, intermediate

In [15]:
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn):
    """Animate stored sampling frames as a grid, save the result as a GIF and return it.

    :param x_gen_store: NumPy array of frames laid out as
        ``(n_frames, n_images, channels, height, width)`` in normalized value range
    :param n_sample: number of images to show per frame
    :param nrows: number of grid rows; ``n_sample // nrows`` columns are used
    :param save_dir: directory the GIF is written to
    :param fn: file name of the GIF without the ``.gif`` extension
    :return: the ``FuncAnimation`` object
    """
    import os  # local import: keeps this cell independent of the notebook's import cell

    ncols = n_sample // nrows
    # Change to NumPy image format (h, w, channels) vs (channels, h, w). moveaxis
    # returns a view, so take a copy to guarantee the caller's array is never altered.
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4).copy()
    nsx_gen_store = denormalize_and_clip(sx_gen_store)

    # create gif of images evolving over time, based on x_gen_store
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols, nrows),
        squeeze=False,  # always a 2-D axes array, so axs[row, col] works for 1 row/col
    )

    def animate_diff(i, store):
        """Redraw every grid cell with the images of frame ``i``."""
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(
        fig,
        animate_diff,
        fargs=[nsx_gen_store],
        interval=200,
        blit=False,
        repeat=True,
        frames=nsx_gen_store.shape[0]
    )
    # close this figure explicitly (not whatever pyplot considers "current") so the
    # static grid is not also rendered inline
    plt.close(fig)
    # os.path.join works whether or not save_dir ends with a path separator
    ani.save(os.path.join(save_dir, f"{fn}.gif"), dpi=100, writer=PillowWriter(fps=5))
    return ani

In [None]:
# Put the network in evaluation mode before sampling.
nn_model.eval()
plt.clf()

# Directory the GIF is written to.
save_dir = "/content/"

# Generate 32 sprites and keep the intermediate frames for the animation.
samples, intermediate_ddpm = sample_ddpm(n_sample=32)
animation_ddpm = plot_sample(
    x_gen_store=intermediate_ddpm,
    n_sample=32,
    nrows=4,
    save_dir=save_dir,
    fn="ani_run",
)
HTML(animation_ddpm.to_jshtml())