# Stable Diffusion MNIST By Andrew Huang

Load the required packages.

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
from diffusers.models import UNet2DModel
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

Define VAE Model.

In [None]:
class VAE(torch.nn.Module):
    """Small convolutional encoder/decoder pair for single-channel images.

    The encoder runs two ``Conv2d -> BatchNorm2d -> LeakyReLU -> MaxPool2d(2)``
    stages, so height and width are divided by 4 and the channel count
    becomes ``latent_dim``.  The decoder mirrors it with two 2x ``Upsample``
    stages and finishes with ``Tanh``, so reconstructions lie in [-1, 1].

    Note: ``forward`` is a plain deterministic encode-then-decode pass; there
    is no latent sampling step in this class.

    Args:
        latent_dim: Number of channels of the encoded representation.
    """

    def __init__(self, latent_dim=4) -> None:
        super().__init__()
        nn = torch.nn

        # The attribute names and the position of every layer inside each
        # Sequential determine the state_dict keys, so they must stay fixed
        # for saved checkpoints to keep loading.
        encoder_stages = [
            nn.Conv2d(1, latent_dim, 3, padding=1),
            nn.BatchNorm2d(latent_dim),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(latent_dim, latent_dim, 3, padding=1),
            nn.BatchNorm2d(latent_dim),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
        ]
        self.encoder_layer = nn.Sequential(*encoder_stages)

        decoder_stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(latent_dim, latent_dim, 3, padding=1),
            nn.BatchNorm2d(latent_dim),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(latent_dim, 1, 3, padding=1),
            nn.Tanh(),
        ]
        self.decoder_layer = nn.Sequential(*decoder_stages)

    def forward(self, inputs):
        """Encode ``inputs`` and immediately decode them again.

        Args:
            inputs: Batch of images with one channel.

        Returns:
            The reconstruction produced by the decoder (same spatial size as
            ``inputs`` when height and width are divisible by 4).
        """
        return self.decoder_layer(self.encoder_layer(inputs))

Define StableDiffusion Model.

In [None]:
class StableDiffusion(torch.nn.Module):
    """Latent diffusion model: a pre-trained VAE plus a noise-predicting UNet.

    Images are mapped to latents with the VAE encoder, the UNet is called on
    (noisy latent, timestep) pairs, and sampled latents are mapped back to
    images with the VAE decoder.

    The VAE weights are loaded from disk and the VAE is kept permanently in
    eval mode: its BatchNorm layers must use the statistics learned during
    VAE training.  Otherwise every ``sample()`` call made while this module
    is in train mode would overwrite the decoder's running statistics with
    statistics of generated latents.

    Args:
        t: Number of diffusion timesteps used by ``sample``.
        vae_model: Path to a ``state_dict`` saved from a ``VAE`` built with
            the same ``latent_dim``.
        latent_dim: Channel count of the VAE latents; also the UNet's input
            and output channel count.
    """

    def __init__(self, t = 50, vae_model = "./vae_model.pth", latent_dim = 4) -> None:
        super().__init__()
        self.t = t
        # Channel count and spatial size of the latent tensors.  The spatial
        # size of 8 matches what the VAE encoder (two 2x poolings) produces
        # for 32x32 inputs.
        self.latent_dim = latent_dim
        self.latent_size = 8

        self.vae_layer = VAE(latent_dim=latent_dim)
        # map_location keeps loading robust when the checkpoint was written
        # from a GPU and is read on a CPU-only machine; .cuda()/.to() on the
        # whole module still moves the weights afterwards.
        self.vae_layer.load_state_dict(torch.load(vae_model, map_location="cpu"))
        self.vae_layer.eval()

        # The UNet channel counts follow latent_dim instead of a hard-coded
        # 16, so any latent_dim (including the default) yields a consistent
        # model.  For latent_dim=16 the architecture, and therefore the
        # state_dict layout, is unchanged.
        self.unet_layer = UNet2DModel(
            sample_size=self.latent_size,
            in_channels=latent_dim,
            out_channels=latent_dim,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
            block_out_channels=(64, 64),
            downsample_type="conv",
            upsample_type="conv",
            norm_num_groups=32)

    def train(self, mode: bool = True):
        """Set the training mode of the UNet while keeping the VAE in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, as ``torch.nn.Module.train`` does.
        """
        super().train(mode)
        # The VAE is a fixed, pre-trained component: never let its BatchNorm
        # layers fall back to batch statistics or update running statistics.
        self.vae_layer.eval()
        return self

    def encoder(self, inputs):
        """Map a batch of images to VAE latents."""
        encoded = self.vae_layer.encoder_layer(inputs)
        return encoded

    def forward(self, inputs, t):
        """Run the UNet on latents ``inputs`` at timesteps ``t``.

        Args:
            inputs: Batch of (noisy) latents.
            t: Timestep index for each batch element.

        Returns:
            The ``.sample`` tensor of the UNet output, used by ``sample`` as
            the predicted noise.
        """
        outputs = self.unet_layer(inputs, t).sample
        return outputs

    def decoder(self, inputs):
        """Map a batch of VAE latents back to image space."""
        outputs = self.vae_layer.decoder_layer(inputs)
        return outputs

    @torch.no_grad()
    def sample(self, sample_size = 25, use_cuda = True):
        """Generate images by running the reverse diffusion process.

        Starts from Gaussian noise in latent space and applies, for
        ``i = t-1 .. 0``::

            x <- 1/sqrt(alpha_i) * (x - (1 - alpha_i)/sqrt(1 - alpha_hat_i) * eps(x, i))
                 + sqrt(beta_i) * z

        with ``z ~ N(0, I)`` for ``i > 0`` and ``z = 0`` on the last step,
        then decodes the result with the VAE decoder.  Gradients are never
        needed here, so the method runs under ``torch.no_grad``.

        Args:
            sample_size: Number of images to generate.
            use_cuda: If true, sampling tensors are created on the CUDA
                device (the module itself must already live there);
                otherwise on the CPU.

        Returns:
            The decoder output for the final latents, one image per sample.
        """
        device = torch.device("cuda" if use_cuda else "cpu")

        # Create everything directly on the target device instead of on the
        # CPU followed by a copy.
        xt = torch.randn(
            sample_size, self.latent_dim, self.latent_size, self.latent_size,
            device=device).float()
        # Linear beta schedule over self.t steps.
        betas = torch.linspace(0.01, 0.2, steps=self.t, device=device).float()

        alpha = 1 - betas
        alpha_hat = torch.cumprod(alpha, dim=0)
        sigma = betas.sqrt()
        for i in reversed(range(self.t)):
            # No noise is added on the final step (i == 0).
            z = torch.randn_like(xt) if i > 0 else torch.zeros_like(xt)
            t = torch.full((sample_size,), i, dtype=torch.long, device=device)

            lambdas = (1 - alpha[i]) / torch.sqrt(1 - alpha_hat[i])
            xt = (1 / torch.sqrt(alpha[i])) * (xt - (lambdas * self(xt, t))) + (z * sigma[i])

        return self.decoder(xt)

Define VAE Hyper-Parameters.

In [None]:
# --- VAE training configuration -------------------------------------------
epochs = 20          # passes over the training set
batch_size = 128     # images per optimisation step
lr = 2e-4            # Adam learning rate
sample_size = 25     # images shown in each 5x5 preview grid
iters = 0            # global step counter, advanced by the training loop

# Resize MNIST digits to 32x32 and rescale pixel values from [0, 1] to
# [-1, 1], matching the Tanh output range of the VAE decoder.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

dataset = torchvision.datasets.MNIST(
    root="./",
    train=True,
    transform=transform,
    download=True,
)
# drop_last=True keeps every batch at exactly batch_size images.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Reconstruction loss, model (on the GPU) and its optimiser.
mse_loss = torch.nn.MSELoss()
vae_model = VAE(latent_dim=16).cuda()
vae_optim = torch.optim.Adam(vae_model.parameters(), lr=lr)

# TensorBoard writer for the loss curve and reconstruction previews.
vae_summary = SummaryWriter("./vae")

Train the VAE model.

In [None]:
# Train the VAE to reconstruct its input, logging the loss and a 5x5 grid of
# reconstructions to TensorBoard every 20 optimisation steps.
for epoch in range(epochs):
    for inputs, _ in dataloader:
        inputs = inputs.float().cuda()

        # Reconstruction objective: decoded output vs. the input image.
        outputs = vae_model(inputs)
        loss = mse_loss(outputs, inputs)

        vae_optim.zero_grad()
        loss.backward()
        vae_optim.step()

        if iters % 20 == 0:
            print("[+] Epoch [%d/%d] Loss: %.4f" % (epoch+1, epochs, loss))
            with torch.no_grad():
                # Re-run the model on the first sample_size images only.
                inputs = inputs[:sample_size]
                outputs = vae_model(inputs)
                fig = plt.figure()
                for i in range(sample_size):
                    # Map [-1, 1] back to [0, 255] and drop the channel axis.
                    image = ((0.5 + outputs[i] * 0.5) * 255.0).squeeze().byte().cpu().numpy()
                    ax = plt.subplot(5, 5, i + 1)
                    ax.imshow(image, cmap="gray")
                    ax.axis("off")
                vae_summary.add_scalar("Loss", loss, iters)
                vae_summary.add_figure("Image", fig, iters)

        iters += 1

vae_summary.close()

# Persist the weights from a CPU copy of the model.
vae_model = vae_model.cpu()
torch.save(vae_model.state_dict(), "vae_model.pth")

Define StableDiffusion Hyper-Parameters.

In [None]:
# --- Diffusion training configuration --------------------------------------
epochs = 20          # passes over the training set
batch_size = 128     # images per optimisation step
lr = 1e-3            # AdamW learning rate
sample_size = 25     # images shown in each 5x5 preview grid
iters = 0            # global step counter, advanced by the training loop
t_step = 50          # number of diffusion timesteps

# Linear beta schedule, with alpha_hat[i] the cumulative product of
# (1 - beta) up to step i; used to noise latents in closed form.
betas = torch.linspace(0.01, 0.2, steps=t_step).float().cuda()
alpha = 1 - betas
alpha_hat = alpha.cumprod(dim=0)

# Resize MNIST digits to 32x32 and rescale pixel values from [0, 1] to [-1, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

dataset = torchvision.datasets.MNIST(
    root="./",
    train=True,
    transform=transform,
    download=True,
)
# drop_last=True keeps every batch at exactly batch_size images, which the
# training loop relies on when drawing one timestep per image.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# The model loads the saved VAE weights from its default path on construction.
model = StableDiffusion(t=t_step, latent_dim=16).cuda()

optim = torch.optim.AdamW(model.parameters(), lr=lr)
mse_loss = torch.nn.MSELoss()

# TensorBoard writer for the loss curve and sample previews.
summary = SummaryWriter("./stablediffusion")

Train the StableDiffusion model.

In [None]:
# Train the UNet to predict the noise that was mixed into VAE latents,
# logging the loss and a 5x5 grid of freshly sampled images every 20 steps.
for epoch in range(epochs):
    for inputs, _ in dataloader:
        inputs = inputs.float().cuda()

        # One random timestep per image in the batch.
        t = torch.randint(low=0, high=t_step, size=(batch_size,)).long().cuda()

        # Per-sample cumulative alpha, shaped to broadcast over (C, H, W).
        alpha_hats = alpha_hat[t].reshape(-1, 1, 1, 1)

        # Latents come from the VAE encoder; no gradient flows into it.
        with torch.no_grad():
            encoded = model.encoder(inputs)

        # Noise the latents in closed form:
        # sqrt(alpha_hat) * x0 + sqrt(1 - alpha_hat) * eps.
        eps = torch.randn_like(encoded)
        xt_encoded = encoded * alpha_hats.sqrt() + eps * (1 - alpha_hats).sqrt()

        # The network is trained to recover the injected noise.
        eps_outputs = model(xt_encoded, t)
        loss = mse_loss(eps_outputs, eps)

        optim.zero_grad()
        loss.backward()
        optim.step()

        if iters % 20 == 0:
            print("[+] Epoch [%d/%d] Loss: %.4f" % (epoch+1, epochs, loss))
            with torch.no_grad():
                outputs = model.sample(sample_size=sample_size, use_cuda=True)
                fig = plt.figure()
                for i in range(sample_size):
                    # Map [-1, 1] back to [0, 255] and drop the channel axis.
                    image = ((0.5 + outputs[i] * 0.5) * 255.0).squeeze().byte().cpu().numpy()
                    ax = plt.subplot(5, 5, i + 1)
                    ax.imshow(image, cmap="gray")
                    ax.axis("off")
                summary.add_scalar("Loss", loss, iters)
                summary.add_figure("Image", fig, iters)

        iters += 1

summary.close()

# Persist the weights from a CPU copy of the model.
model = model.cpu()
torch.save(model.state_dict(), "stablediffusion_model.pth")

Sample images from the trained model.

In [None]:
# Rebuild the model on the CPU, restore the trained weights and draw a 6x6
# grid of generated digits.
model = StableDiffusion(t=50, latent_dim=16)
model.load_state_dict(torch.load("./stablediffusion_model.pth"))
model.eval()
with torch.no_grad():
    images = model.sample(36, use_cuda=False)
    for i in range(36):
        # Clamp to [-1, 1], rescale to [0, 255] and drop the channel axis.
        image = (images[i].clamp(min=-1., max=1.) + 1) / 2
        image = (image * 255.0).byte().squeeze().cpu().numpy()
        ax = plt.subplot(6, 6, i + 1)
        ax.axis("off")
        ax.imshow(image, cmap="gray")
    plt.show()