In [None]:
#-----------------------------------------
# MINI STABLE DIFFUSION (LATENT DIFFUSION)
#-----------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# =====================================================
# CONFIG
# =====================================================
# --- image / latent geometry ---
IMG_SIZE = 32            # side length images are resized to
LATENT_SIZE = 8          # VAE compresses 32x32 -> 8x8
LATENT_CHANNELS = 4      # channel count of the latent tensor

# --- optimisation ---
BATCH = 128
EPOCHS = 100
TIMESTEPS = 1000         # number of diffusion steps in the schedule

# --- filesystem ---
DATA_DIR = r"D:\pk\dataset\real"
VAE_PATH = "sd_vae.pth"
UNET_PATH = "sd_unet.pth"

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# =====================================================
# DATASET
# =====================================================
class ImageDataset(Dataset):
    """Unlabelled image dataset backed by the files of a single folder.

    Every file in ``folder`` whose name ends in one of ``EXTENSIONS``
    (case-insensitive) becomes one sample. Sub-folders are not traversed.

    Args:
        folder: Directory that directly contains the image files.

    Each item is the image converted to RGB, resized to
    ``IMG_SIZE x IMG_SIZE``, turned into a tensor and normalised with
    mean 0.5 / std 0.5 per channel.
    """

    # Leading dots matter: without them a name such as "notajpg" would match.
    EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, folder):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.files = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(self.EXTENSIONS)
        )
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position ``idx``."""
        # The context manager releases the underlying file handle
        # deterministically once the pixel data has been read.
        with Image.open(self.files[idx]) as img:
            return self.transform(img.convert("RGB"))

# Shuffled mini-batches over every image found in DATA_DIR.
loader = DataLoader(
    dataset=ImageDataset(DATA_DIR),
    batch_size=BATCH,
    shuffle=True,
)

# =====================================================
# VAE (LATENT SPACE)
# =====================================================
class Encoder(nn.Module):
    """Convolutional encoder: RGB image -> latent feature map.

    Two stride-2 convolutions each halve the spatial resolution, then a
    stride-1 convolution projects to ``LATENT_CHANNELS`` channels. No
    activation is applied to the final output.
    """

    def __init__(self):
        super().__init__()
        stages = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, LATENT_CHANNELS, kernel_size=3, stride=1, padding=1),
        ]
        # Kept under the attribute name ``net`` so checkpoint keys stay stable.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Encode a batch of images ``x`` into latents."""
        latent = self.net(x)
        return latent

class Decoder(nn.Module):
    """Convolutional decoder: latent feature map -> RGB image.

    Two stride-2 transposed convolutions each double the spatial resolution;
    a final stride-1 convolution followed by ``Tanh`` produces three output
    channels bounded to (-1, 1).
    """

    def __init__(self):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(LATENT_CHANNELS, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        # Kept under the attribute name ``net`` so checkpoint keys stay stable.
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Decode a batch of latents ``z`` into images."""
        image = self.net(z)
        return image

class VAE(nn.Module):
    """Encoder/decoder pair used to move between pixel and latent space.

    NOTE: as written this is a deterministic autoencoder; the forward pass
    contains no sampling step and returns no distribution parameters.
    """

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.dec = Decoder()

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the image batch ``x``."""
        latent = self.enc(x)
        reconstruction = self.dec(latent)
        return reconstruction, latent

# Build the autoencoder, move it to the target device, then attach its optimiser.
vae = VAE()
vae.to(DEVICE)
vae_opt = torch.optim.Adam(params=vae.parameters(), lr=1e-3)

# =====================================================
# DIFFUSION SCHEDULE (LATENT SPACE)
# =====================================================
# Linear beta schedule with TIMESTEPS entries running from 1e-4 to 0.02.
beta = torch.linspace(start=1e-4, end=0.02, steps=TIMESTEPS).to(DEVICE)
# Per-step signal retention and its running product over the timestep axis.
alpha = 1 - beta
alpha_cum = alpha.cumprod(dim=0)

# =====================================================
# TIME EMBEDDING
# =====================================================
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of integer diffusion timesteps.

    Args:
        dim: Width of the embedding that ``forward`` returns.

    The first ``dim // 2`` columns hold sines and the next ``dim // 2`` hold
    cosines, using geometrically spaced frequencies from 1 down to 1/10000.
    When ``dim`` is odd, one trailing zero column is appended so the output
    width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: 1-D tensor of timesteps, one entry per batch element.

        Returns:
            Float tensor of shape ``(len(t), dim)``.
        """
        half = self.dim // 2
        # With a single frequency (dim 2 or 3) ``half - 1`` is zero; clamping
        # the divisor avoids a division by zero that would yield NaNs.
        step = np.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -step)
        angles = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.dim % 2:
            # sin/cos halves only cover an even width; pad to reach ``dim``.
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
        return emb

# =====================================================
# LATENT UNET (STABLE DIFFUSION CORE)
# =====================================================
class UNet(nn.Module):
    """Small time-conditioned convolutional noise predictor for latents.

    Despite the name, the visible network is three stride-1 convolutions with
    no down/up-sampling or skip connections. The timestep embedding is added
    to the activations after the first convolution.
    """

    def __init__(self):
        super().__init__()
        width = 128

        # Timestep conditioning: sinusoidal features followed by a linear map.
        self.time = TimeEmbedding(width)
        self.time_fc = nn.Linear(width, width)

        # Padding 1 with kernel 3 / stride 1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(LATENT_CHANNELS, width, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(width, LATENT_CHANNELS, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t):
        """Predict the noise in latent batch ``x`` at timesteps ``t``."""
        t_emb = self.time_fc(self.time(t))
        # Add two trailing singleton axes so the embedding broadcasts over H and W.
        t_emb = t_emb[:, :, None, None]

        hidden = F.relu(self.conv1(x)) + t_emb
        hidden = F.relu(self.conv2(hidden))
        return self.conv3(hidden)

# Build the noise predictor, move it to the target device, then attach its optimiser.
unet = UNet()
unet.to(DEVICE)
unet_opt = torch.optim.Adam(params=unet.parameters(), lr=1e-4)

# =====================================================
# MODEL SUMMARY
# =====================================================
# Print the layer structure of both networks under a heading each.
for _heading, _network in (("\nVAE SUMMARY\n", vae), ("\nUNET SUMMARY\n", unet)):
    print(_heading, _network)

# =====================================================
# TRAINING
# =====================================================
# Each batch takes one autoencoder step and one latent-diffusion step.
for epoch in range(EPOCHS):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for x in pbar:
        x = x.to(DEVICE)

        # ---- 1) autoencoder: pixel-space reconstruction loss ----
        recon, z = vae(x)
        vae_loss = F.mse_loss(input=recon, target=x)

        vae_opt.zero_grad()
        vae_loss.backward()
        vae_opt.step()

        # Cut the graph so the diffusion loss cannot back-propagate into the VAE.
        z = z.detach()

        # ---- 2) diffusion: noise the latents at random timesteps ----
        t = torch.randint(low=0, high=TIMESTEPS, size=(z.size(0),), device=DEVICE)
        noise = torch.randn_like(z)

        # Reshape the per-sample cumulative alpha so it broadcasts over C, H, W.
        a = alpha_cum[t].view(-1, 1, 1, 1)
        z_noisy = a.sqrt() * z + (1 - a).sqrt() * noise

        # The network is trained to recover the noise that was injected.
        pred = unet(z_noisy, t)
        diff_loss = F.mse_loss(input=pred, target=noise)

        unet_opt.zero_grad()
        diff_loss.backward()
        unet_opt.step()

        pbar.set_postfix(VAE=vae_loss.item(), Diff=diff_loss.item())


print("âœ” Training Complete")

# =====================================================
# SAVE MODELS
# =====================================================
# Persist only the weights (state dicts), not the full module objects.
torch.save(obj=vae.state_dict(), f=VAE_PATH)
torch.save(obj=unet.state_dict(), f=UNET_PATH)
print("âœ” Models Saved")

# =====================================================
# INFERENCE (LATENT DIFFUSION SAMPLING)
# =====================================================
@torch.no_grad()
def sample(n=16):
    """Generate ``n`` images by DDPM ancestral sampling in latent space.

    Starts from Gaussian noise shaped like a latent batch, walks the
    timesteps from ``TIMESTEPS - 1`` down to 0 with the trained ``unet``,
    then decodes the final latents with ``vae.dec``.

    Args:
        n: Number of images to generate.

    Returns:
        CPU tensor of ``n`` decoded images with values clamped to [0, 1].
    """
    z = torch.randn(n, LATENT_CHANNELS, LATENT_SIZE, LATENT_SIZE, device=DEVICE)

    for t in reversed(range(TIMESTEPS)):
        tt = torch.full((n,), t, device=DEVICE, dtype=torch.long)
        eps = unet(z, tt)

        # Posterior mean of z_{t-1} given z_t and the predicted noise.
        # Jumping straight to the x0 estimate, (z - sqrt(1-a_bar)*eps)/sqrt(a_bar),
        # and feeding it back as the next noisy latent is not a valid reverse
        # step: it divides by a near-zero sqrt(a_bar) at large t and compounds.
        coef = beta[t] / torch.sqrt(1 - alpha_cum[t])
        mean = (z - coef * eps) / torch.sqrt(alpha[t])

        if t > 0:
            # Re-inject noise with variance beta_t on every step but the last.
            z = mean + torch.sqrt(beta[t]) * torch.randn_like(z)
        else:
            z = mean

    imgs = vae.dec(z)
    # Decoder output ends in Tanh (range [-1, 1]); rescale to [0, 1] for display.
    return torch.clamp((imgs + 1) / 2, 0, 1).cpu()

# =====================================================
# SHOW GENERATED IMAGES
# =====================================================
samples = sample(16)

# Lay the 16 generated images out on a 4x4 grid without axes.
plt.figure(figsize=(5, 5))
for cell, image in enumerate(samples, start=1):
    plt.subplot(4, 4, cell)
    # imshow expects channels last, so move C from the first to the last axis.
    plt.imshow(image.permute(1, 2, 0))
    plt.axis("off")
plt.show()



VAE SUMMARY
 VAE(
  (enc): Encoder(
    (net): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(128, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (dec): Decoder(
    (net): Sequential(
      (0): ConvTranspose2d(4, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): Tanh()
    )
  )
)

UNET SUMMARY
 UNet(
  (time): TimeEmbedding()
  (time_fc): Linear(in_features=128, out_features=128, bias=True)
  (conv1): Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(128, 4, kernel_size=

Epoch 1/100:  49%|████▉     | 39/79 [00:28<00:25,  1.54it/s, Diff=0.633, VAE=0.0597]