In [None]:
!pip install lpips

# Define VAE

In [None]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Pre-activation residual block used by the VAE.

    Two (GroupNorm -> SiLU -> Conv3x3) stages, with dropout before the second
    convolution, plus a skip path. Spatial size is preserved (3x3, padding=1).

    Args:
        in_channels: channels of the input; GroupNorm(32, .) requires a multiple of 32.
        out_channels: channels of the output; also a multiple of 32.
        dropout_prob: dropout probability before the second convolution.
    """

    def __init__(self,in_channels,out_channels,dropout_prob = 0.0):
        super().__init__()
        # Attribute names / creation order are kept stable so saved
        # state_dicts ("norm1.weight", "shortcut.weight", ...) still load.
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()
        # 1x1 projection only when the channel count changes.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self,x):
        h = self.conv1(self.silu(self.norm1(x)))
        h = self.conv2(self.drop(self.silu(self.norm2(h))))
        return h + self.shortcut(x)

class AttentionBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    The (B, C, H, W) input is GroupNorm-ed, flattened into H*W tokens of width
    C, attended over with softmax(Q K^T / sqrt(C)) V, projected, reshaped back
    to (B, C, H, W) and added to the input.
    """

    def __init__(self,in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)
        # Per-token projections acting on the channel axis.
        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)
        self.to_out = nn.Linear(in_channels, in_channels)

    def forward(self,x):
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = self.norm(x).view(batch, channels, -1).transpose(1, 2)

        query = self.to_q(tokens)
        key = self.to_k(tokens)
        value = self.to_v(tokens)

        scores = torch.bmm(query, key.transpose(1, 2)) * (channels ** (-0.5))
        mixed = torch.bmm(scores.softmax(dim=-1), value)

        projected = self.to_out(mixed)
        # (B, H*W, C) -> (B, C, H, W), then residual add.
        return projected.transpose(1, 2).view(batch, channels, height, width) + x

class MidBlock(nn.Module):
    """VAE bottleneck: ResNetBlock -> AttentionBlock -> ResNetBlock, width unchanged."""

    def __init__(self,in_channels):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, in_channels)
        self.attn1 = AttentionBlock(in_channels)
        self.res2 = ResNetBlock(in_channels, in_channels)

    def forward(self,x):
        for layer in (self.res1, self.attn1, self.res2):
            x = layer(x)
        return x

class DownBlock(nn.Module):
    """Encoder stage: two ResNetBlocks, optional self-attention, then a stride-2 conv.

    Maps in_channels -> out_channels and reduces H, W to ceil(H/2), ceil(W/2)
    (3x3 conv, stride 2, padding 1).
    """

    def __init__(self,in_channels,out_channels,has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x):
        features = self.res2(self.res1(x))
        return self.down(self.attn(features))

class Encoder(nn.Module):
    """VAE encoder: image -> (sampled latent z, KL term).

    Three DownBlocks shrink H and W by 8 in total (64 -> 128 -> 256 -> 512
    channels), followed by a MidBlock and a head emitting 2*out_channels maps
    that are split into mean and log-variance.
    """

    def __init__(self,in_channels=3,out_channels=4):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.down_block = nn.Sequential(
            DownBlock(64, 128),
            DownBlock(128, 256),
            DownBlock(256, 512, has_attn=True),
        )
        self.bottle = MidBlock(512)
        self.out = nn.Sequential(
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, out_channels * 2, kernel_size=1),
        )

    def reparameterize(self,x):
        """Split x into (mean, log_var) on dim 1 and draw z = mean + eps * std.

        Returns:
            z: latent sample with half the channels of x.
            D_kl: KL(N(mean, var) || N(0, I)), summed over dims 1-3 and
                averaged over the batch dimension.
        """
        mean, log_var = torch.chunk(x, 2, dim=1)
        # Clamp keeps exp(log_var) finite.
        log_var = log_var.clamp(-30.0, 20.0)
        kl_elements = 0.5 * (torch.exp(log_var) + mean ** 2 - log_var - 1)
        D_kl = kl_elements.sum(dim=[1, 2, 3]).mean()
        std = torch.exp(0.5 * log_var)
        z = mean + torch.randn_like(std) * std
        return z, D_kl

    def forward(self,x):
        stats = self.out(self.bottle(self.down_block(self.inp(x))))
        return self.reparameterize(stats)

class UpBlock(nn.Module):
    """Decoder stage: two ResNetBlocks, optional attention, then 2x upsample + 3x3 conv."""

    def __init__(self,in_channels,out_channels,has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),  # default mode is "nearest"
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x):
        h = self.res2(self.res1(x))
        h = self.attn(h)
        return self.up(h)

class Decoder(nn.Module):
    """VAE decoder: latent (B, in_channels, h, w) -> image (B, out_channels, 8h, 8w).

    No output activation is applied; callers clamp after un-normalising.
    """

    def __init__(self,in_channels=4,out_channels=3):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.bottle = MidBlock(512)
        self.up_block = nn.Sequential(
            UpBlock(512, 256, has_attn=True),
            UpBlock(256, 128),
            UpBlock(128, 64),
        )
        # out[-1] (the final conv) is the layer the training loop uses for the
        # adaptive GAN-weight gradient comparison.
        self.out = nn.Sequential(
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        )

    def forward(self,x):
        for stage in (self.inp, self.bottle, self.up_block, self.out):
            x = stage(x)
        return x

class PatchGan(nn.Module):
    """Patch discriminator returning a map of real/fake logits (no sigmoid).

    Four stride-2 convs (64, 128, 256, 512 channels; GroupNorm on all but the
    first) reduce H, W by 16, then a 3x3 conv emits one logit per patch.
    """

    def __init__(self,in_channels=3):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(32, c_out),
                nn.SiLU(),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1))
        # Same Sequential indices as before (model.0 ... model.11).
        self.model = nn.Sequential(*layers)

    def forward(self,x):
        return self.model(x)

class VAE(nn.Module):
    """Encoder + decoder wrapper; forward returns (reconstruction, KL term).

    Args:
        in_channels: image channels (also the decoder's output channels).
        out_channels: latent channels produced by the encoder.
    """

    def __init__(self,in_channels=3,out_channels=4):
        super().__init__()
        self.encoder = Encoder(in_channels, out_channels)
        self.decoder = Decoder(out_channels, in_channels)

    def forward(self,x):
        latent, kl = self.encoder(x)
        return self.decoder(latent), kl



# Training VAE

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets , transforms
from torchvision.utils import save_image
import lpips
import os

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Perceptual loss (LPIPS on pretrained VGG features), used for sharper reconstructions.
lpips_loss_fn = lpips.LPIPS(net='vgg').to(DEVICE)

IMAGES_PATH = "/content/drive/MyDrive/dataset"
MODEL_PATH = "/content/drive/MyDrive/VAE_Training"
LATENTS_PATH = "/content/drive/MyDrive/VAE_Training/LATENTS"
os.makedirs(LATENTS_PATH, exist_ok=True)

# Resize to 256x256 and map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

dataset = datasets.ImageFolder(IMAGES_PATH, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

EPOCHS = 200

VAE_MODEL = VAE().to(DEVICE)
PATCH_GAN = PatchGan().to(DEVICE)

VAE_OPTIMIZER = torch.optim.Adam(VAE_MODEL.parameters(), lr=2e-4)
PATCH_GAN_OPTIMIZER = torch.optim.Adam(PATCH_GAN.parameters(), lr=1e-4)

L1_LOSS = nn.L1Loss()

# Consumes raw discriminator logits; the sigmoid is applied inside the loss.
BCE_LOSS = nn.BCEWithLogitsLoss()

# Loss weights taken from "High-Resolution Image Synthesis with Latent
# Diffusion Models", page 29.
weight_kl = 1e-6
weight_lpips = 1.0

# Epoch from which the adversarial (PatchGAN) loss is enabled.
DISC_START = 10

# ==============================================
# LLM Generated HELPER FUNCTION: ADAPTIVE WEIGHT
# ==============================================
def calculate_adaptive_weight(recon_loss, g_loss, last_layer, weight_limit=1.0):
    """Weight that balances the adversarial loss against the reconstruction loss.

    Computes ||d recon_loss / d last_layer|| / (||d g_loss / d last_layer|| + 1e-4),
    clamps it to [0, weight_limit] and detaches it so it acts as a constant factor.

    Args:
        recon_loss: scalar reconstruction loss depending on ``last_layer``.
        g_loss: scalar generator (adversarial) loss depending on ``last_layer``.
        last_layer: parameter tensor whose gradients are compared (the training
            loop passes the decoder's final conv weight).
        weight_limit: upper bound of the clamp (default 1.0).

    Returns:
        Detached scalar tensor in [0, weight_limit].
    """
    # torch.autograd.grad returns a tuple (one entry per input); [0] takes the
    # gradient w.r.t. last_layer. retain_graph=True keeps the graph alive for
    # the second grad call and for the caller's later VAE_LOSS.backward().
    recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

    # Ratio of L2 gradient norms; the 1e-4 guards against division by zero.
    d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, weight_limit).detach()

def sampling_images(model,real_image,epoch,RECON_IMAGES_PATH = "/content/drive/MyDrive/VAE_Training/RECON_IMAGES_SAMPLING"):
    """Save a side-by-side [input | reconstruction] PNG for a single image.

    Args:
        model: VAE whose forward returns (reconstruction, kl).
        real_image: one image tensor (C, H, W), assumed normalised to [-1, 1]
            by the training transform.
        epoch: used in the output filename.
        RECON_IMAGES_PATH: output directory (created if missing).

    The model's train/eval mode is restored afterwards, even on error.
    """
    os.makedirs(RECON_IMAGES_PATH,exist_ok=True)
    was_training = model.training
    # Use the model's own device instead of relying on the global DEVICE.
    device = next(model.parameters()).device
    model.eval()
    try:
        with torch.no_grad():
            real_image = real_image.unsqueeze(0).to(device)
            reconstructed_image, _ = model(real_image)

            # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1], clamped for saving.
            reconstructed_image = torch.clamp(reconstructed_image * 0.5 + 0.5, 0.0, 1.0)
            real_image = torch.clamp(real_image * 0.5 + 0.5, 0.0, 1.0)

            # Concatenate along width (dim=3): [original | reconstruction].
            comparison_image = torch.cat([real_image, reconstructed_image], dim=3)
            save_image(comparison_image, f"{RECON_IMAGES_PATH}/comparison_image{epoch}.png")
    finally:
        # Restore whatever mode the model was in before this call.
        model.train(was_training)

for epoch in range(EPOCHS):
    # The adversarial loss is only used after DISC_START warm-up epochs.
    use_gan = epoch >= DISC_START

    for idx, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(DEVICE)

        # ---- VAE (generator) step ----------------------------------------
        reconstructed_images, D_kl = VAE_MODEL(real_images)

        l1_loss = L1_LOSS(reconstructed_images, real_images)
        lpips_loss = lpips_loss_fn(reconstructed_images, real_images).mean()
        rec_loss = l1_loss + (weight_lpips * lpips_loss)

        VAE_LOSS = (weight_lpips * lpips_loss) + (weight_kl * D_kl) + l1_loss

        # During warm-up the adversarial term would be multiplied by 0, so the
        # discriminator forward/backward for the generator is skipped entirely.
        if use_gan:
            # The VAE wants the discriminator to score reconstructions as real
            # (target 1), i.e. to be fooled.
            gan_fake_images_pred = PATCH_GAN(reconstructed_images)
            all_ones = torch.ones_like(gan_fake_images_pred)
            gan_fake_images_pred_loss = BCE_LOSS(gan_fake_images_pred, all_ones)

            # Balance adversarial vs reconstruction gradients at the weight of
            # the decoder's final conv (decoder.out[-1]).
            last_layer = VAE_MODEL.decoder.out[-1].weight
            adaptive_weight = calculate_adaptive_weight(rec_loss,
                                                        gan_fake_images_pred_loss,
                                                        last_layer)
            VAE_LOSS = VAE_LOSS + (adaptive_weight * gan_fake_images_pred_loss)

        VAE_OPTIMIZER.zero_grad()
        VAE_LOSS.backward()
        VAE_OPTIMIZER.step()

        # ---- Discriminator step ------------------------------------------
        if use_gan:
            gan_real_images_pred = PATCH_GAN(real_images)
            all_ones = torch.ones_like(gan_real_images_pred)
            gan_real_images_pred_loss = BCE_LOSS(gan_real_images_pred, all_ones)

            # detach(): discriminator update must not backprop into the VAE.
            gan_fake_images_pred = PATCH_GAN(reconstructed_images.detach())
            all_zeros = torch.zeros_like(gan_fake_images_pred)
            gan_fake_images_pred_loss = BCE_LOSS(gan_fake_images_pred, all_zeros)

            total_gan_loss = (gan_real_images_pred_loss + gan_fake_images_pred_loss)

            # zero_grad also clears PATCH_GAN grads accumulated by VAE_LOSS.backward().
            PATCH_GAN_OPTIMIZER.zero_grad()
            total_gan_loss.backward()
            PATCH_GAN_OPTIMIZER.step()

    # Checkpoint every 5 epochs AND after the final epoch; otherwise the last
    # epochs are lost when EPOCHS - 1 is not a multiple of 5.
    if epoch % 5 == 0 or epoch == EPOCHS - 1:
        disc_loss_val = total_gan_loss.item() if use_gan else 0.0

        checkpoint = {
            "epoch": epoch,
            "vae_state_dict": VAE_MODEL.state_dict(),
            "gan_state_dict": PATCH_GAN.state_dict(),
            "vae_optimizer": VAE_OPTIMIZER.state_dict(),
            "gan_optimizer": PATCH_GAN_OPTIMIZER.state_dict(),
        }

        torch.save(checkpoint,os.path.join(MODEL_PATH,"checkpoint.pth"))

        # Losses reported are from the last batch of the epoch.
        print(f"Epoch {epoch} | VAE Loss: {VAE_LOSS.item():.4f} | Disc Loss: {disc_loss_val:.4f}")
        sampling_images(VAE_MODEL,real_images[0],epoch)


# Generate Latents

In [None]:
from torch._prims_common import check
import torch
import os
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import datasets

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = "/content/drive/MyDrive/VAE_Training/checkpoint.pth"
# NOTE(review): differs from the IMAGES_PATH used for VAE training
# ("/content/drive/MyDrive/dataset") — confirm this dataset is intended.
IMAGES_PATH = "/content/drive/MyDrive/VAE_Training/cat_dataset"
LATENTS_PATH = "/content/drive/MyDrive/VAE_Training/LATENTS"
os.makedirs(LATENTS_PATH,exist_ok=True)

# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
VAE_MODEL = VAE().to(DEVICE)
VAE_MODEL.load_state_dict(checkpoint["vae_state_dict"])

VAE_MODEL.eval()

# Must match the preprocessing used when training the VAE.
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

dataset = datasets.ImageFolder(IMAGES_PATH,transform=transform)
dataloader = DataLoader(dataset,batch_size=1,shuffle=False)

print("Generating Latents")
with torch.no_grad():
    for idx , (real_images,_) in enumerate(dataloader):
        real_images = real_images.to(DEVICE)
        # The encoder returns a *sampled* z (mean + eps * std), so re-running
        # this cell yields different latents.
        # NOTE(review): latents are stored unscaled; the LDM paper rescales
        # latents toward unit variance before diffusion — consider checking std.
        latent, _ = VAE_MODEL.encoder(real_images)

        save_path = os.path.join(LATENTS_PATH,f"latent_{idx}.pt")

        # Save a 3D tensor (batch dimension removed) on CPU.
        torch.save(latent.squeeze(0).cpu(),save_path)

        if idx % 100 == 0:
            print(f"Processed {idx} images")

print("All images processed")



# Define Diffusion Unet

In [None]:
def raw_time_embedding(time , dim):
    """Sinusoidal timestep embedding.

    Args:
        time: scalar or 1-D tensor / array-like of timesteps, shape () or (B,).
        dim: embedding width.

    Returns:
        float32 tensor of shape (B, dim) (B = 1 for a scalar), laid out as
        [sin(t * f_0 .. f_{k-1}), cos(t * f_0 .. f_{k-1})] with k = dim // 2 and
        f_i = 1 / 10000 ** (2i / dim). For odd ``dim`` a zero column is
        appended so the width always equals ``dim``.
    """
    if not torch.is_tensor(time):
        time = torch.tensor(time)

    device = time.device
    if time.ndim == 0:
        time = time.unsqueeze(0).unsqueeze(1)  # () -> (1, 1)
    else:
        time = time.unsqueeze(1)  # (B,) -> (B, 1); the training path
    # Explicit float32 so int/float64 inputs match the float32 Linear layers downstream.
    time = time.float()

    # Frequencies must be on the same device as `time`.
    half = dim // 2
    i = torch.arange(half, device=device).float()
    angles = time / (10000 ** (2 * i / dim))  # (B, 1) / (half,) -> (B, half)
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    if dim % 2 == 1:
        # Pad one zero column so the output width is exactly `dim`.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)
    return emb

class time_embedding(nn.Module):
    """MLP (Linear -> SiLU -> Linear) applied to the sinusoidal timestep embedding.

    Input and output are (B, dim).
    """

    def __init__(self,dim):
        super().__init__()
        layers = [nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self,X):
        hidden = self.net(X)
        return hidden

class DiffusionResNetBlock(nn.Module):
    """Time-conditioned pre-activation residual block for the diffusion U-Net.

    Args:
        in_channels / out_channels: feature widths (multiples of 32 for GroupNorm).
        time_emb_dim: width of the (B, time_emb_dim) timestep embedding.
        dropout_prob: dropout before the second convolution.
    """

    def __init__(self,in_channels,out_channels,time_emb_dim,dropout_prob = 0.0 ):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()

        # Projects the time embedding to out_channels so it can be added per channel.
        self.time_proj = nn.Linear(time_emb_dim, out_channels)

        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self,x,time_emb):
        h = self.conv1(self.silu(self.norm1(x)))

        # The time MLP ends in a plain Linear, so a SiLU is applied before the
        # projection (as in the DDPM/LDM reference ResNet blocks).
        # Broadcast (B, C) -> (B, C, 1, 1) over the spatial dims.
        t = self.time_proj(self.silu(time_emb))
        h = h + t[:, :, None, None]

        h = self.conv2(self.drop(self.silu(self.norm2(h))))
        return h + self.shortcut(x)

class DiffusionAttentionBlock(nn.Module):
    """Single-head spatial self-attention with residual, used inside the U-Net.

    Flattens (B, C, H, W) into H*W tokens of width C, computes
    softmax(Q K^T / sqrt(C)) V, projects, reshapes back and adds the input.
    """

    def __init__(self,in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)
        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)
        self.to_out = nn.Linear(in_channels, in_channels)

    def forward(self,x):
        b, c, h, w = x.shape
        seq = self.norm(x).view(b, c, -1).transpose(1, 2)  # (B, H*W, C)

        q, k, v = self.to_q(seq), self.to_k(seq), self.to_v(seq)
        weights = (q @ k.transpose(1, 2)) * (c ** (-0.5))
        attended = weights.softmax(dim=-1) @ v

        out = self.to_out(attended).transpose(1, 2).view(b, c, h, w)
        return out + x

class DiffusionDownBlock(nn.Module):
    """U-Net encoder stage.

    Two time-conditioned ResNet blocks, optional attention, then a stride-2
    conv. Returns (downsampled features, pre-downsample skip features).
    """

    def __init__(self,in_channels,out_channels,time_emb_dim,has_attn=False):
        super().__init__()
        self.res1 = DiffusionResNetBlock(in_channels, out_channels, time_emb_dim)
        self.res2 = DiffusionResNetBlock(out_channels, out_channels, time_emb_dim)
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.attn = DiffusionAttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x,time_emb):
        for res in (self.res1, self.res2):
            x = res(x, time_emb)
        skip_connection = self.attn(x)
        return self.down(skip_connection), skip_connection

class DiffusionMidBlock(nn.Module):
    """U-Net bottleneck: ResNet (time-conditioned) -> attention -> ResNet, width unchanged."""

    def __init__(self,in_channels,time_emb_dim):
        super().__init__()
        self.res1 = DiffusionResNetBlock(in_channels, in_channels, time_emb_dim)
        self.attn = DiffusionAttentionBlock(in_channels)
        self.res2 = DiffusionResNetBlock(in_channels, in_channels, time_emb_dim)

    def forward(self,x,time_emb):
        h = self.res1(x, time_emb)
        h = self.attn(h)
        return self.res2(h, time_emb)

class DiffusionUpBlock(nn.Module):
    """U-Net decoder stage.

    Upsamples x by 2 (nearest + 3x3 conv), concatenates the skip tensor on
    channels, then applies two time-conditioned ResNet blocks and optional
    attention. The skip tensor must have ``in_channels`` channels and the
    upsampled spatial size (res1 takes ``in_channels * 2``).
    """

    def __init__(self,in_channels,out_channels,time_emb_dim,has_attn=False):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        )
        # Doubled input width: upsampled features + skip features are concatenated.
        self.res1 = DiffusionResNetBlock(in_channels * 2, out_channels, time_emb_dim)
        self.res2 = DiffusionResNetBlock(out_channels, out_channels, time_emb_dim)
        self.attn = DiffusionAttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self,x,skip_connection,time_emb):
        merged = torch.cat([self.up(x), skip_connection], dim=1)
        merged = self.res1(merged, time_emb)
        merged = self.res2(merged, time_emb)
        return self.attn(merged)

class DiffusionUnet(nn.Module):
    """Noise-prediction U-Net operating on VAE latents.

    Args:
        in_channels / out_channels: latent channels in and predicted-noise channels out.
        time_dim: width of the sinusoidal timestep embedding and its MLP.

    forward(x, t): x is (B, in_channels, H, W), t is a scalar or (B,) timestep
    tensor. Four stride-2 downsamples followed by four 2x upsamples that must
    match the skip sizes mean H and W need to be divisible by 16.
    """

    def __init__(self,in_channels=4,out_channels=4,time_dim=256):
        super().__init__()

        self.time_dim = time_dim
        self.time_embedding = time_embedding(time_dim)

        self.init_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)

        # Encoder path (each stage halves H, W).
        self.down1 = DiffusionDownBlock(64, 64, time_dim)
        self.down2 = DiffusionDownBlock(64, 128, time_dim)
        self.down3 = DiffusionDownBlock(128, 128, time_dim, has_attn=True)
        self.down4 = DiffusionDownBlock(128, 256, time_dim, has_attn=True)

        self.mid = DiffusionMidBlock(256, time_dim)

        # Decoder path; in_channels of each stage equals its skip's channels.
        self.up1 = DiffusionUpBlock(256, 128, time_dim, has_attn=True)
        self.up2 = DiffusionUpBlock(128, 128, time_dim, has_attn=True)
        self.up3 = DiffusionUpBlock(128, 64, time_dim)
        self.up4 = DiffusionUpBlock(64, 64, time_dim)

        self.out = nn.Sequential(
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        )

    def forward(self,x,t):
        emb = self.time_embedding(raw_time_embedding(t, self.time_dim))

        h = self.init_conv(x)
        skips = []
        for down in (self.down1, self.down2, self.down3, self.down4):
            h, skip = down(h, emb)
            skips.append(skip)

        h = self.mid(h, emb)

        # Consume skips in reverse order (deepest first).
        for up in (self.up1, self.up2, self.up3, self.up4):
            h = up(h, skips.pop(), emb)

        return self.out(h)




# Training Diffusion Unet

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
import os
from tqdm import tqdm
import torch.nn.functional as F

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

LATENTS_PATH = "/content/drive/MyDrive/VAE_Training/LATENTS"
CHECKPOINT_PATH = "/content/drive/MyDrive/VAE_Training/checkpoint.pth"
MODEL_SAVE_PATH = "/content/drive/MyDrive/VAE_Training/Diffusion_Model/VERSION4"
OUTPUT_FOLDER = "/content/drive/MyDrive/VAE_Training/Diffusion_Model/VERSION4/ImageSampling"
# Experiment history:
#   version 1 - diffusion model used entirely from a pretrained checkpoint
#   version 2 - added dataset, trained until unet_epoch_249
#   version 3 - own noising (forward) and denoising (reverse) steps
#   version 4 - own diffusion U-Net plus own noising/denoising process
for folder in (MODEL_SAVE_PATH, OUTPUT_FOLDER):
    os.makedirs(folder, exist_ok=True)

BATCH_SIZE = 32
LEARNING_RATE = 1e-4
EPOCHS = 500

class LatentDataset(Dataset):
    """Dataset of pre-computed VAE latents stored as individual ``.pt`` files.

    Each file is expected to hold one latent tensor without a batch dimension
    (the latent-generation cell saves ``latent.squeeze(0)``).

    Args:
        root_dir: directory scanned (non-recursively) for ``*.pt`` files.
    """

    def __init__(self,root_dir):
        self.root_dir = root_dir
        # os.listdir order is arbitrary; sort for a reproducible index -> file mapping.
        self.latent_files = sorted(f for f in os.listdir(self.root_dir) if f.endswith('.pt'))

    def __len__(self):
        return len(self.latent_files)

    def __getitem__(self,idx):
        latent_path = os.path.join(self.root_dir,self.latent_files[idx])
        # Load on CPU; the training loop moves each batch to DEVICE itself.
        return torch.load(latent_path, map_location="cpu")

dataset = LatentDataset(LATENTS_PATH)
print(len(dataset))
if len(dataset) == 0:
    # DataLoader(shuffle=True) would otherwise fail with an opaque sampler error.
    raise RuntimeError(f"No .pt latent files found in {LATENTS_PATH}")
dataloader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)

model = DiffusionUnet().to(DEVICE)
# The VAE is only used (under no_grad) to decode sampled latents for previews.
vae = VAE().to(DEVICE)
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
vae.load_state_dict(checkpoint["vae_state_dict"])
vae.eval()

# Linear DDPM beta schedule and the cumulative products used by q(x_t | x_0).
beta_start = 0.0001
beta_end = 0.02
num_timesteps = 1000

betas = torch.linspace(beta_start, beta_end, num_timesteps,device=DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

model.train()

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")

    for step, latents in enumerate(progress_bar):
        latents = latents.to(DEVICE)

        # Forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_timesteps, (latents.shape[0],), device=DEVICE)

        sqrt_alpha_cumprod = sqrt_alphas_cumprod[timesteps][:, None, None, None]
        sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alphas_cumprod[timesteps][:, None, None, None]
        noisy_latents = sqrt_alpha_cumprod * latents + sqrt_one_minus_alpha_cumprod * noise

        # Epsilon-prediction objective.
        noise_pred = model(noisy_latents, timesteps)
        loss = F.mse_loss(noise_pred, noise)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()
        progress_bar.set_postfix({"Loss": loss.item()})

    # Log the mean loss every 5 epochs.
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch} | Loss: {epoch_loss / max(len(dataloader), 1)}")

    # Every 50 epochs: save weights and write a grid of preview samples.
    if (epoch + 1) % 50 == 0:
        torch_path = os.path.join(MODEL_SAVE_PATH,f"unet_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), torch_path)
        print(f"Saved model to {torch_path}")

        # Sample in eval mode; model.train() is re-applied at the next epoch.
        model.eval()
        with torch.no_grad():
            latents = torch.randn((8, 4, 32, 32), device=DEVICE)

            # DDPM ancestral sampling with sigma_t^2 = beta_t.
            for t in reversed(range(num_timesteps)):
                t_tensor = torch.full((8,), t, device=DEVICE, dtype=torch.long)
                noise_pred = model(latents, t_tensor)

                alpha = alphas[t]
                alpha_cumprod = alphas_cumprod[t]
                beta = betas[t]

                # No noise is added on the final step (t == 0).
                noise = torch.randn_like(latents) if t > 0 else torch.zeros_like(latents)

                latents = (1 / torch.sqrt(alpha)) * (latents - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * noise_pred) + torch.sqrt(beta) * noise

            reconstructed = vae.decoder(latents)

            # Un-normalize (-1,1 -> 0,1)
            reconstructed = torch.clamp(reconstructed * 0.5 + 0.5, 0, 1)

            save_image(reconstructed, f"{OUTPUT_FOLDER}/generated_cats_v4{epoch+1}.png", nrow=4)
            print(f"Saved to {OUTPUT_FOLDER}/generated_cats_v4{epoch+1}.png")

print("Training Complete!")




# Sampling Generated Images

In [None]:
import torch
import torch.nn
from torchvision.utils import save_image
import os

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

VAE_CHECKPOINT = "/content/drive/MyDrive/VAE_Training/checkpoint.pth"
UNET_PATH = "/content/drive/MyDrive/VAE_Training/Diffusion_Model/VERSION4/unet_epoch_500.pth"
OUTPUT_FOLDER = "/content/drive/MyDrive/VAE_Training/Diffusion_Model/VERSION4/ImageSampling"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

vae = VAE().to(DEVICE) # Ensure your VAE class is defined above!
checkpoint = torch.load(VAE_CHECKPOINT, map_location=DEVICE)
vae.load_state_dict(checkpoint["vae_state_dict"])
vae.eval()

unet = DiffusionUnet().to(DEVICE)
# map_location so a GPU-trained U-Net also loads on a CPU-only runtime.
unet.load_state_dict(torch.load(UNET_PATH, map_location=DEVICE))
unet.eval()

# Must match the noise schedule used during training.
beta_start = 0.0001
beta_end = 0.02
num_timesteps = 1000

betas = torch.linspace(beta_start, beta_end, num_timesteps,device=DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

with torch.no_grad():
    latents = torch.randn((8, 4, 32, 32), device=DEVICE)

    # DDPM ancestral sampling with sigma_t^2 = beta_t.
    for t in reversed(range(num_timesteps)):
        t_tensor = torch.full((8,), t, device=DEVICE, dtype=torch.long)
        noise_pred = unet(latents, t_tensor)

        alpha = alphas[t]
        alpha_cumprod = alphas_cumprod[t]
        beta = betas[t]

        # No noise is added on the final step (t == 0).
        noise = torch.randn_like(latents) if t > 0 else torch.zeros_like(latents)

        latents = (1 / torch.sqrt(alpha)) * (latents - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * noise_pred) + torch.sqrt(beta) * noise

    reconstructed = vae.decoder(latents)

    # Un-normalize (-1,1 -> 0,1)
    reconstructed = torch.clamp(reconstructed * 0.5 + 0.5, 0, 1)

    # save_image expects floats in [0, 1]; nrow is the number of images per
    # row, so 8 samples become a 2 x 4 grid.
    save_image(reconstructed, f"{OUTPUT_FOLDER}/sample_cat9.png", nrow=4)
    print(f"Saved to {OUTPUT_FOLDER}/sample_cat9.png")

