<a href="https://colab.research.google.com/github/Yang-star-source/Latent_Diffusion_In_Image_Inpainting/blob/main/Image_Inpainting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Define VAE

In [None]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Pre-activation residual block.

    Two ``GroupNorm -> SiLU -> 3x3 conv`` stages (dropout sits just before the
    second convolution) are added to a skip path. The skip path is a 1x1
    convolution when the channel count changes and an identity otherwise.

    Args:
        in_channels: Channels of the input feature map. GroupNorm uses 32
            groups, so this must be divisible by 32.
        out_channels: Channels of the output feature map (also divisible by 32).
        dropout_prob: Dropout probability applied before the second conv.
    """

    def __init__(self,in_channels,out_channels,dropout_prob = 0.0):
        super().__init__()
        # Sub-modules are registered in a fixed order so that state_dict keys
        # (and seeded initialisation) stay stable.
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()

        # Only project the skip path when the widths differ.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if needs_projection
            else nn.Identity()
        )

    def forward(self,x):
        """Return ``residual_branch(x) + shortcut(x)``; spatial size is unchanged."""
        hidden = self.conv1(self.silu(self.norm1(x)))
        hidden = self.conv2(self.drop(self.silu(self.norm2(hidden))))
        return hidden + self.shortcut(x)

class AttentionBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The (B, C, H, W) input is group-normalised, flattened to H*W tokens of
    width C, passed through scaled dot-product attention, projected, reshaped
    back to (B, C, H, W) and added to the un-normalised input.

    Args:
        in_channels: Channel count C (must be divisible by the 32 GroupNorm groups).
    """

    def __init__(self,in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)

        # Query / key / value projections act on the channel axis of each token.
        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)

        self.to_out = nn.Linear(in_channels, in_channels)

    def forward(self,x):
        """Return ``x + attention(norm(x))`` with the same shape as ``x``."""
        batch, channels, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
        tokens = self.norm(x).view(batch, channels, -1).permute(0, 2, 1)

        queries = self.to_q(tokens)
        keys = self.to_k(tokens)
        values = self.to_v(tokens)

        # Scaled dot-product scores, scale = 1 / sqrt(C).
        scores = torch.bmm(queries, keys.permute(0, 2, 1))
        scores = scores * (channels ** (-0.5))
        weights = scores.softmax(dim=-1)
        attended = torch.bmm(weights, values)

        # Project and fold the tokens back into a feature map.
        projected = self.to_out(attended)
        projected = projected.permute(0, 2, 1).view(batch, channels, height, width)

        return projected + x

class MidBlock(nn.Module):
    """Bottleneck stage: ResNet block -> self-attention -> ResNet block.

    Channel count and spatial size are preserved.

    Args:
        in_channels: Channel count of the input (and output) feature map.
    """

    def __init__(self,in_channels):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, in_channels)
        self.attn1 = AttentionBlock(in_channels)
        self.res2 = ResNetBlock(in_channels, in_channels)

    def forward(self,x):
        """Apply the three stages in sequence."""
        for stage in (self.res1, self.attn1, self.res2):
            x = stage(x)
        return x

class DownBlock(nn.Module):
    """Encoder stage: two ResNet blocks, optional attention, then a stride-2 conv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the stage.
        has_attn: When True, a self-attention block runs before downsampling;
            otherwise that slot is an identity.
    """

    def __init__(self,in_channels,out_channels,has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)
        # 3x3 conv with stride 2 / padding 1 halves the spatial resolution.
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x):
        """Return the downsampled feature map."""
        features = self.res2(self.res1(x))
        features = self.attn(features)
        return self.down(features)

class Encoder(nn.Module):
    """VAE encoder: image -> sampled latent plus KL divergence.

    A stem convolution is followed by three ``DownBlock`` stages (each halves
    the spatial size), a ``MidBlock`` bottleneck and a 1x1 head that emits
    ``2 * out_channels`` maps: the mean and the log-variance of the posterior.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the latent ``z``.
    """

    def __init__(self,in_channels=3,out_channels=4):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.down_block = nn.Sequential(
            DownBlock(64, 128),
            DownBlock(128, 256),
            DownBlock(256, 512, has_attn=True),
        )
        self.bottle = MidBlock(512)
        self.out = nn.Sequential(
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, out_channels * 2, kernel_size=1),
        )

    def reparameterize(self,x):
        """Split ``x`` into (mean, log-variance), sample ``z`` and compute the KL term.

        Args:
            x: Tensor whose channel axis (dim 1) holds the mean in its first
                half and the log-variance in its second half.

        Returns:
            ``(z, D_kl)`` where ``z = mean + eps * std`` with ``eps ~ N(0, I)``
            and ``D_kl`` is the KL divergence to a standard normal, summed over
            dims 1-3 and averaged over the batch.
        """
        mean, log_var = x.chunk(2, dim=1)
        # Clamp so that exp() below stays finite.
        log_var = log_var.clamp(-30.0, 20.0)

        kl_elementwise = 0.5 * (log_var.exp() + mean ** 2 - log_var - 1)
        D_kl = kl_elementwise.sum(dim=[1, 2, 3]).mean()  # mean over the batch axis

        std = (0.5 * log_var).exp()
        z = mean + torch.randn_like(std) * std
        return z, D_kl

    def forward(self,x):
        """Encode ``x`` and return ``(z, D_kl)``."""
        features = self.bottle(self.down_block(self.inp(x)))
        return self.reparameterize(self.out(features))

class UpBlock(nn.Module):
    """Decoder stage: two ResNet blocks, optional attention, then 2x upsampling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the stage.
        has_attn: When True, a self-attention block runs before upsampling;
            otherwise that slot is an identity.
    """

    def __init__(self,in_channels,out_channels,has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)
        # nn.Upsample defaults to nearest-neighbour; the conv smooths the result.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x):
        """Return the upsampled feature map."""
        features = self.res2(self.res1(x))
        features = self.attn(features)
        return self.up(features)

class Decoder(nn.Module):
    """VAE decoder: latent -> image.

    A stem convolution lifts the latent to 512 channels, a ``MidBlock``
    bottleneck follows, then three ``UpBlock`` stages (each doubles the
    spatial size) and a GroupNorm/SiLU/conv head that emits the image.

    Args:
        in_channels: Channels of the latent.
        out_channels: Channels of the decoded image.
    """

    def __init__(self,in_channels=4,out_channels=3):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.bottle = MidBlock(512)
        self.up_block = nn.Sequential(
            UpBlock(512, 256, has_attn=True),
            UpBlock(256, 128),
            UpBlock(128, 64),
        )
        self.out = nn.Sequential(
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        )

    def forward(self,x):
        """Decode latent ``x`` and return the reconstructed image tensor."""
        for stage in (self.inp, self.bottle, self.up_block, self.out):
            x = stage(x)
        return x

class PatchGan(nn.Module):
    """Fully convolutional network that maps an image to a 1-channel score map.

    Four stride-2 convolutions (64, 128, 256, 512 channels) reduce the spatial
    size by 16; a final stride-1 convolution produces one score per remaining
    location. Presumably used as the adversarial critic when training the VAE;
    it is not referenced elsewhere in this notebook.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self,in_channels=3):
        super().__init__()
        # First stage has no normalisation.
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
        ]
        # Remaining downsampling stages: conv -> GroupNorm -> SiLU.
        for width_in, width_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(width_in, width_out, kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(32, width_out),
                nn.SiLU(),
            ])
        # Per-patch score head.
        layers.append(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self,x):
        """Return the score map for ``x``."""
        return self.model(x)

class VAE(nn.Module):
    """Variational auto-encoder built from ``Encoder`` and ``Decoder``.

    Args:
        in_channels: Channels of the image (encoder input and decoder output).
        out_channels: Channels of the latent.
    """

    def __init__(self,in_channels=3,out_channels=4):
        super().__init__()
        self.encoder = Encoder(in_channels, out_channels)
        # The decoder mirrors the encoder: latent channels in, image channels out.
        self.decoder = Decoder(out_channels, in_channels)

    def forward(self,x):
        """Return ``(reconstruction, D_kl)`` for the input batch ``x``."""
        latent, D_kl = self.encoder(x)
        return self.decoder(latent), D_kl



In [None]:
def raw_time_embedding(time , dim):
    """Build sinusoidal timestep features.

    Args:
        time: Timestep(s) as a Python number or a tensor. A 0-d value becomes
            shape (1, 1); any other tensor gets a new axis inserted at dim 1
            (so a (B,) tensor becomes (B, 1)).
        dim: Embedding width. ``dim // 2`` frequencies are used, so an even
            ``dim`` yields exactly ``dim`` features.

    Returns:
        Tensor with ``sin`` features followed by ``cos`` features concatenated
        along dim 1, created on the device of ``time``.
    """
    time = time if torch.is_tensor(time) else torch.tensor(time)
    device = time.device

    # Scalar -> (1, 1); batched (B,) -> (B, 1).
    if time.ndim == 0:
        column = time.unsqueeze(0).unsqueeze(1)
    else:
        column = time.unsqueeze(1)

    # Frequencies must live on the same device as the timesteps.
    freq_index = torch.arange(dim // 2, device=device).float()
    angles = column / (10000 ** (2 * freq_index / dim))
    return torch.cat([angles.sin(), angles.cos()], dim=1)

class time_embedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to timestep features.

    Args:
        dim: Width of the input, hidden and output features.
    """

    def __init__(self,dim):
        super().__init__()
        mlp_layers = [nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self,X):
        """Return the learned embedding for features ``X`` of shape (..., dim)."""
        embedded = self.net(X)
        return embedded

class DiffusionResNetBlock(nn.Module):
    """Residual block conditioned on a timestep embedding.

    Same layout as a pre-activation ResNet block, except that a linear
    projection of the (SiLU-activated) time embedding is added as a
    per-channel bias between the two convolutions.

    Args:
        in_channels: Channels of the input feature map (divisible by 32).
        out_channels: Channels of the output feature map (divisible by 32).
        time_emb_dim: Width of the time embedding passed to ``forward``.
        dropout_prob: Dropout probability applied before the second conv.
    """

    def __init__(self,in_channels,out_channels,time_emb_dim,dropout_prob = 0.0 ):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()

        # Maps the time embedding to one bias value per output channel.
        self.time_proj = nn.Linear(time_emb_dim, out_channels)

        # Project the skip path only when the channel count changes.
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self,x,time_emb):
        """Return the block output for feature map ``x`` and embedding ``time_emb``.

        Args:
            x: Feature map of shape (B, in_channels, H, W).
            time_emb: Time embedding of shape (B, time_emb_dim).
        """
        hidden = self.conv1(self.silu(self.norm1(x)))

        # (B, C) -> (B, C, 1, 1) so the bias broadcasts over H and W.
        time_bias = self.time_proj(self.silu(time_emb))
        hidden = hidden + time_bias.unsqueeze(-1).unsqueeze(-1)

        hidden = self.conv2(self.drop(self.silu(self.norm2(hidden))))
        return hidden + self.shortcut(x)

class DiffusionAttentionBlock(nn.Module):
    """Single-head spatial self-attention used inside the diffusion U-Net.

    Every spatial location of the group-normalised (B, C, H, W) input is a
    token of width C; the attention output is projected and added back to the
    original input.

    Args:
        in_channels: Channel count C (divisible by the 32 GroupNorm groups).
    """

    def __init__(self,in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)

        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)

        self.to_out = nn.Linear(in_channels, in_channels)

    def forward(self,x):
        """Return ``x + attention(norm(x))``; the shape is unchanged."""
        B, C, H, W = x.shape

        # Flatten the spatial grid into a token sequence: (B, H*W, C).
        seq = self.norm(x).view(B, C, -1).permute(0, 2, 1)
        q, k, v = self.to_q(seq), self.to_k(seq), self.to_v(seq)

        # Softmax over keys of the scaled similarity matrix (scale = C^-0.5).
        similarity = torch.bmm(q, k.permute(0, 2, 1)) * (C ** (-0.5))
        context = torch.bmm(similarity.softmax(dim=-1), v)

        # Back to (B, C, H, W) before the residual addition.
        update = self.to_out(context).permute(0, 2, 1).view(B, C, H, W)
        return update + x

class DiffusionDownBlock(nn.Module):
    """U-Net encoder stage with time conditioning.

    Runs two time-conditioned ResNet blocks and optional attention, keeps the
    result as the skip connection, then halves the resolution with a stride-2
    convolution.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the stage.
        time_emb_dim: Width of the time embedding.
        has_attn: Insert a self-attention block before downsampling.
    """

    def __init__(self,in_channels,out_channels,time_emb_dim,has_attn=False):
        super().__init__()
        self.res1 = DiffusionResNetBlock(in_channels, out_channels, time_emb_dim)
        self.res2 = DiffusionResNetBlock(out_channels, out_channels, time_emb_dim)
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.attn = DiffusionAttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x,time_emb):
        """Return ``(downsampled, skip_connection)``.

        ``skip_connection`` is the full-resolution feature map taken just
        before the stride-2 convolution.
        """
        features = self.res1(x, time_emb)
        features = self.res2(features, time_emb)
        skip_connection = self.attn(features)
        return self.down(skip_connection), skip_connection

class DiffusionMidBlock(nn.Module):
    """U-Net bottleneck: conditioned ResNet block -> attention -> conditioned ResNet block.

    Args:
        in_channels: Channel count of the input and output.
        time_emb_dim: Width of the time embedding.
    """

    def __init__(self,in_channels,time_emb_dim):
        super().__init__()

        self.res1 = DiffusionResNetBlock(in_channels, in_channels, time_emb_dim)
        self.attn = DiffusionAttentionBlock(in_channels)
        self.res2 = DiffusionResNetBlock(in_channels, in_channels, time_emb_dim)

    def forward(self,x,time_emb):
        """Return the bottleneck features; only the ResNet blocks see ``time_emb``."""
        attended = self.attn(self.res1(x, time_emb))
        return self.res2(attended, time_emb)

class DiffusionUpBlock(nn.Module):
    """U-Net decoder stage with time conditioning.

    Upsamples by 2, concatenates the matching encoder skip connection along
    the channel axis, then applies two time-conditioned ResNet blocks and
    optional attention.

    Args:
        in_channels: Channels of the incoming feature map; the skip connection
            is expected to carry the same number of channels.
        out_channels: Channels produced by the stage.
        time_emb_dim: Width of the time embedding.
        has_attn: Append a self-attention block after the ResNet blocks.
    """

    def __init__(self,in_channels,out_channels,time_emb_dim,has_attn=False):
        super().__init__()

        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        )
        # The first block sees twice the channels: upsampled features + skip.
        self.res1 = DiffusionResNetBlock(in_channels * 2, out_channels, time_emb_dim)
        self.res2 = DiffusionResNetBlock(out_channels, out_channels, time_emb_dim)
        self.attn = DiffusionAttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x,skip_connection,time_emb):
        """Return the upsampled, skip-fused feature map."""
        upsampled = self.up(x)

        # Fuse along the channel dimension.
        fused = torch.cat([upsampled, skip_connection], dim=1)
        fused = self.res1(fused, time_emb)
        fused = self.res2(fused, time_emb)
        return self.attn(fused)

class DiffusionUnet(nn.Module):
    """Time-conditioned U-Net that predicts the noise in a latent.

    Four down stages (each halves the resolution and stores a skip
    connection), a bottleneck, and four up stages that consume the skips in
    reverse order. The spatial size of the input should be divisible by 16 so
    that upsampled features line up with their skip connections.

    Args:
        in_channels: Channels of the noisy latent.
        out_channels: Channels of the predicted noise.
        time_dim: Width of the sinusoidal and learned time embeddings.
    """

    def __init__(self,in_channels=4,out_channels=4,time_dim=256):
        super().__init__()

        self.time_dim = time_dim
        self.time_embedding = time_embedding(time_dim)

        self.init_conv = nn.Conv2d(in_channels,64,kernel_size=3,padding=1)

        self.down1 = DiffusionDownBlock(64,64,time_dim)
        self.down2 = DiffusionDownBlock(64,128,time_dim)
        self.down3 = DiffusionDownBlock(128,128,time_dim,has_attn=True)
        self.down4 = DiffusionDownBlock(128,256,time_dim,has_attn=True)

        self.mid = DiffusionMidBlock(256,time_dim)

        # Each up stage's in_channels equals the channel count of the skip it receives.
        self.up1 = DiffusionUpBlock(256,128,time_dim,has_attn=True)
        self.up2 = DiffusionUpBlock(128,128,time_dim,has_attn=True)
        self.up3 = DiffusionUpBlock(128,64,time_dim)
        self.up4 = DiffusionUpBlock(64,64,time_dim)

        self.out = nn.Sequential(
            nn.GroupNorm(32,64),
            nn.SiLU(),
            nn.Conv2d(64,out_channels,kernel_size=3,padding=1)
        )

    def forward(self,x,t):
        """Predict the noise present in ``x`` at timestep ``t``.

        Args:
            x: Noisy latent of shape (B, in_channels, H, W).
            t: Timestep(s) as a Python number, a 0-d tensor, or a (B,) tensor.

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        # raw_time_embedding builds its features on t's device. A Python scalar
        # (or a tensor left on another device) would otherwise clash with the
        # model weights, so move t next to x first.
        if not torch.is_tensor(t):
            t = torch.tensor(t)
        t = t.to(x.device)

        t = raw_time_embedding(t,self.time_dim)
        emb = self.time_embedding(t)

        x = self.init_conv(x)

        x1 , skip1 = self.down1(x,emb)
        x2 , skip2 = self.down2(x1,emb)
        x3 , skip3 = self.down3(x2,emb)
        x4 , skip4 = self.down4(x3,emb)

        x = self.mid(x4,emb)

        # Skips are consumed deepest-first.
        x = self.up1(x,skip4,emb)
        x = self.up2(x,skip3,emb)
        x = self.up3(x,skip2,emb)
        x = self.up4(x,skip1,emb)

        x = self.out(x)
        return x




# Simple mask function for testing purposes

In [None]:
def generate_mask(batch_size, channels, height, width, device, top=12, left=12, hole_size=10):
    """Generate a box mask with a square hole at a fixed position.

    1 = keep the pixel
    0 = drop the pixel (the hole)

    Args:
        batch_size: Number of masks in the batch.
        channels: Unused; kept for call-site compatibility. The mask always has
            a single channel so it broadcasts over any number of channels.
        height: Mask height.
        width: Mask width.
        device: Device on which the mask is created.
        top: Row index where the hole starts.
        left: Column index where the hole starts.
        hole_size: Side length of the square hole.

    Returns:
        Float tensor of shape (batch_size, 1, height, width). The defaults
        zero out rows and columns 12..21, identical for every sample.
    """
    mask = torch.ones((batch_size, 1, height, width), device=device)

    # The hole is the same for every sample, so one slice assignment covers the batch.
    mask[:, :, top:top + hole_size, left:left + hole_size] = 0.0

    return mask

# Experiment 1: mask on the latent

In [None]:
from huggingface_hub import hf_hub_download
import os

import zipfile  # imported here so this cell stays self-contained

# Model weights: VAE checkpoint and the trained diffusion U-Net.
repo_id = "ziyang06315/latent_diffusion_from_scratch"
VAE_CHECKPOINT = hf_hub_download(repo_id, filename="checkpoint.pth")
UNET_PATH = hf_hub_download(repo_id, filename="unet_epoch_500.pth")

# Pre-computed latents, shipped as a zip archive in a dataset repo.
repo_id2 = "ziyang06315/cats_images_dataset"
target_dir = "/content/LATENTS"
os.makedirs(target_dir, exist_ok=True)
LATENTS_PATH = hf_hub_download(repo_id2,
                               filename = "LATENTS.zip",
                               repo_type = "dataset",
                               local_dir = target_dir)

# Extract with the standard library instead of shelling out to `unzip`:
# no shell-quoting issues with the path, re-running the cell overwrites
# without prompting, and a bad archive raises instead of being ignored.
with zipfile.ZipFile(LATENTS_PATH) as archive:
    archive.extractall(target_dir)
print("Finished Unzipped")

In [None]:
import torch
import torch.nn
from torchvision.utils import save_image
import os
from torch.utils.data import Dataset,DataLoader,Subset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

OUTPUT_FOLDER = "/content/VAE_Training/Diffusion_Model/image_inpainting"
# Directory of .pt latents; presumably the folder produced by unzipping LATENTS.zip.
LATENTS_PATH = "/content/LATENTS/LATENTS"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
vae = VAE().to(DEVICE) # Ensure your VAE class is defined above!
checkpoint = torch.load(VAE_CHECKPOINT, map_location=DEVICE)
vae.load_state_dict(checkpoint["vae_state_dict"])
vae.eval()

unet = DiffusionUnet().to(DEVICE)
# map_location (as for the VAE) lets a checkpoint that holds CUDA tensors
# load on a CPU-only runtime instead of raising.
unet.load_state_dict(torch.load(UNET_PATH, map_location=DEVICE))
unet.eval()

# Linear beta schedule and the derived DDPM coefficients, indexed by timestep.
beta_start = 0.0001
beta_end = 0.02
num_timesteps = 1000

betas = torch.linspace(beta_start, beta_end, num_timesteps,device=DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

class LatentDataset(Dataset):
    """Dataset of pre-computed latents stored as individual ``.pt`` files.

    Args:
        root_dir: Directory that contains the ``.pt`` files (not searched
            recursively).
    """

    def __init__(self,root_dir):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.latent_files = sorted(
            f for f in os.listdir(self.root_dir) if f.endswith('.pt')
        )

    def __len__(self):
        """Return the number of latent files found."""
        return len(self.latent_files)

    def __getitem__(self,idx):
        """Load and return the latent stored in the ``idx``-th file."""
        latent_path = os.path.join(self.root_dir,self.latent_files[idx])
        # Load onto the CPU regardless of the device the latent was saved
        # from, so this also works on CPU-only runtimes; the caller moves the
        # batch to its target device.
        # NOTE: torch.load unpickles the file; only load trusted data.
        latent = torch.load(latent_path, map_location="cpu")
        return latent

dataset = LatentDataset(LATENTS_PATH)
print(len(dataset))
# Use a single latent as the inpainting target.
dataset = Subset(dataset,range(1))
dataloader = DataLoader(dataset,batch_size=1,shuffle=True)
orig_latent = next(iter(dataloader)).to(DEVICE)

# 1 = known region (kept from orig_latent), 0 = hole to be generated.
mask = generate_mask(1,4,32,32,DEVICE)

with torch.no_grad():

    # Start from pure noise and denoise step by step.
    xt = torch.randn_like(orig_latent)

    # Iterate over the schedule length instead of a hard-coded step count so
    # the loop cannot drift from the beta schedule.
    for t in reversed(range(num_timesteps)):
        t_tensor = torch.ones(1,device=DEVICE).long() * t

        alpha = alphas[t]
        alpha_cumprod = alphas_cumprod[t]
        beta = betas[t]

        # noise1 drives the reverse step; noise2 re-noises the known region.
        # No noise is added on the final step (t == 0).
        if t > 0:
            noise1 = torch.randn_like(orig_latent)
            noise2 = torch.randn_like(orig_latent)
        else:
            noise1 = torch.zeros_like(orig_latent)
            noise2 = None

        noise_pred = unet(xt,t_tensor)

        # DDPM reverse step: posterior mean plus sqrt(beta) * noise.
        xt_minus_one = (1 / torch.sqrt(alpha)) * (xt - ((1 - alpha) / (torch.sqrt(1 - alpha_cumprod))) * noise_pred) + torch.sqrt(beta) * noise1

        if t > 0:
            # Known region: forward-diffuse the original latent to level t-1.
            known = sqrt_alphas_cumprod[t-1]*orig_latent + sqrt_one_minus_alphas_cumprod[t-1]*noise2
        else:
            # Last step: paste the clean original into the known region.
            known = orig_latent

        # Generated content in the hole, (noised) original everywhere else.
        xt = xt_minus_one * (1-mask) + known * (mask)

with torch.no_grad():
    original = vae.decoder(orig_latent)
    reconstructed = vae.decoder(xt)

    # Un-normalize (-1,1 -> 0,1)
    reconstructed = reconstructed * 0.5 + 0.5
    reconstructed = torch.clamp(reconstructed, 0, 1)

    original = original * 0.5 + 0.5
    original = torch.clamp(original, 0, 1)

    # Original and inpainted result side by side: (2,3,256,256)
    combined = torch.cat([original,reconstructed],dim=0)
    save_image(combined, f"{OUTPUT_FOLDER}/sample_inpainting4.png", nrow=2)
    print(f"Saved to {OUTPUT_FOLDER}/sample_inpainting4.png")




# Simple User Interface (UI)
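The result above is successful, but it has two limitations:

- The mask is applied to the latent, not to the image itself.
- The mask is a square box, not a free-form brush stroke.

Solution:
- We need a way to resize the mask from the 256×256 image resolution to the 32×32 latent resolution.
- We need a UI brush tool to draw the mask (ask an LLM for help building it).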

In [None]:
!pip install gradio==3.48.0

In [None]:
import gradio as gr
from torchvision import transforms
import numpy as np
from PIL import Image
import torch.nn.functional as F
import os
from huggingface_hub import hf_hub_download

repo_id = "ziyang06315/latent_diffusion_from_scratch"
UNET_PATH = hf_hub_download(repo_id, filename="unet_epoch_500.pth")
VAE_CHECKPOINT = hf_hub_download(repo_id, filename="checkpoint.pth")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
vae = VAE().to(DEVICE) # Ensure your VAE class is defined above!
checkpoint = torch.load(VAE_CHECKPOINT, map_location=DEVICE)
vae.load_state_dict(checkpoint["vae_state_dict"])
vae.eval()

unet = DiffusionUnet().to(DEVICE)
# map_location (as for the VAE) lets a checkpoint that holds CUDA tensors
# load on a CPU-only runtime instead of raising.
unet.load_state_dict(torch.load(UNET_PATH, map_location=DEVICE))
unet.eval()

# Linear beta schedule and the derived DDPM coefficients, indexed by timestep.
beta_start = 0.0001
beta_end = 0.02
num_timesteps = 1000

betas = torch.linspace(beta_start, beta_end, num_timesteps,device=DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

# Image: resize to 256x256, scale to [0, 1], then normalise to [-1, 1].
img_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

# Mask: resize to 256x256 and scale to [0, 1]; no normalisation.
mask_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor()
])


#                  |      SHAPE    |  TYPE    | RANGE
# -------------------------------------------------
#  Gradio Return   |   (256,256,3) |  uint8   | 0-255
#  To Tensor       |   (3,256,256) |  float32 | 0.0-1.0
def inpainting(input_dictionary):
    """Inpaint the brushed region of an uploaded image with latent diffusion.

    Args:
        input_dictionary: Dict supplied by the Gradio sketch input, with an
            ``"image"`` array and a ``"mask"`` array. Brushed pixels are
            treated as bright (>= 0.5 after scaling to [0, 1]).

    Returns:
        The inpainted image as a uint8 array of shape (256, 256, 3).

    Raises:
        gr.Error: If no image was uploaded.
    """
    if input_dictionary is None or input_dictionary.get("image") is None:
        raise gr.Error("Please upload an image before submitting.")

    image = input_dictionary["image"]
    mask = input_dictionary["mask"]

    img_pil = Image.fromarray(image).convert("RGB")
    mask_pil = Image.fromarray(mask)

    # Add the batch dimension.
    img_tensor = img_transform(img_pil).unsqueeze(0).to(DEVICE)
    mask_tensor = mask_transform(mask_pil).unsqueeze(0).to(DEVICE)

    # Reduce the 256x256 brush mask to the 32x32 latent grid. Max pooling marks
    # a latent cell as painted if ANY pixel in its 8x8 patch was brushed;
    # nearest-neighbour sampling looks at one pixel per patch and can miss
    # thin strokes entirely. Only the first channel is needed.
    painted = F.adaptive_max_pool2d(mask_tensor[:, :1], output_size=(32, 32))

    # Invert so that 1 = keep the original latent, 0 = hole to regenerate.
    # Shape stays (1, 1, 32, 32) so it broadcasts over the latent channels.
    mask = (painted < 0.5).float()
    print(f"Mask shape : {mask.shape}")

    # Everything below is inference only, so no autograd graph is needed
    # (this includes the encoder pass).
    with torch.no_grad():
        orig_latent = vae.encoder(img_tensor)[0]
        print(f"Latent shape : {orig_latent.shape}")

        print("Start Diffusion Process")

        # Start from pure noise and denoise step by step.
        xt = torch.randn_like(orig_latent)

        for t in reversed(range(num_timesteps)):
            t_tensor = torch.ones(1,device=DEVICE).long() * t

            alpha = alphas[t]
            alpha_cumprod = alphas_cumprod[t]
            beta = betas[t]

            # noise1 drives the reverse step; noise2 re-noises the known
            # region. No noise is added on the final step (t == 0).
            if t > 0:
                noise1 = torch.randn_like(orig_latent)
                noise2 = torch.randn_like(orig_latent)
            else:
                noise1 = torch.zeros_like(orig_latent)
                noise2 = None

            noise_pred = unet(xt,t_tensor)

            # DDPM reverse step: posterior mean plus sqrt(beta) * noise.
            xt_minus_one = (1 / torch.sqrt(alpha)) * (xt - ((1 - alpha) / (torch.sqrt(1 - alpha_cumprod))) * noise_pred) + torch.sqrt(beta) * noise1

            if t > 0:
                # Known region: forward-diffuse the original latent to level t-1.
                known = sqrt_alphas_cumprod[t-1]*orig_latent + sqrt_one_minus_alphas_cumprod[t-1]*noise2
            else:
                # Last step: paste the clean original into the known region.
                known = orig_latent

            # Generated content in the hole, (noised) original everywhere else.
            xt = xt_minus_one * (1-mask) + known * (mask)

        reconstructed = vae.decoder(xt)

        # Un-normalize (-1,1 -> 0,1)
        reconstructed = reconstructed * 0.5 + 0.5
        reconstructed = torch.clamp(reconstructed, 0, 1)

        # Convert Tensor [1, C, H, W] -> Numpy [H, W, C] [0, 255] for Gradio
        out_img = reconstructed.squeeze(0).permute(1,2,0).cpu().numpy()
        out_img = (out_img * 255).astype(np.uint8)

    return out_img


# Wire the inpainting function into a Gradio interface.
# inpainting() reads input_dictionary["image"] and input_dictionary["mask"],
# so the input component must deliver both -- presumably what tool="sketch"
# provides in the pinned gradio==3.48.0 (verify before upgrading Gradio).
demo = gr.Interface(
    fn=inpainting,
    inputs = gr.Image(tool="sketch" , type="numpy", label="Upload & Draw Mask"),
    outputs= gr.Image(label="Inpainting Result"),
    title = "Diffusion Inpainting Demo",
    description="Upload an image, draw over the area you want to remove, and click Submit."
)

# Start the app from this cell with Gradio's debug mode switched on.
demo.launch(debug=True)


