In [9]:
!pip install streamlit pillow -q
!pip install --upgrade torch torchvision -q

os.makedirs("checkpoints", exist_ok=True)
os.makedirs("output", exist_ok=True)

def start_streamlit_server(file_path):
    """Run a Streamlit app in the background and expose it through an ngrok tunnel.

    The call blocks until it is interrupted (KeyboardInterrupt), then closes the
    tunnel and terminates the Streamlit process it started.

    Args:
        file_path: Path of the Streamlit script to run (e.g. "app.py").

    The ngrok auth token is read from the ``NGROK_AUTH_TOKEN`` environment
    variable, falling back to an interactive prompt. Never store the token in
    the notebook itself.
    """
    import getpass
    import os
    import subprocess
    import sys
    import time

    port = 8501  # pinned so the tunnel below always points at this server

    print("Starting Streamlit server")
    ngrok = None
    server = None
    log_file = open("nohup.out", "ab")
    try:
        # Argument list (no shell), so file_path is never interpreted by a shell.
        server = subprocess.Popen(
            [
                sys.executable, "-m", "streamlit", "run", file_path,
                "--server.port", str(port),
                "--server.headless", "true",
            ],
            stdin=subprocess.DEVNULL,
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )

        # Give Streamlit a moment to bind the port before opening the tunnel.
        time.sleep(5)

        print("\n--- Ngrok Tunnel Setup ---")
        subprocess.run([sys.executable, "-m", "pip", "install", "pyngrok", "-q"], check=True)
        from pyngrok import ngrok

        token = os.environ.get("NGROK_AUTH_TOKEN") or getpass.getpass("ngrok auth token: ")
        if not token:
            print("Error setting ngrok authtoken: no token provided")
            return
        try:
            ngrok.set_auth_token(token)
        except Exception as e:
            print(f"Error setting ngrok authtoken: {e}")
            return

        # Drop any tunnel left over from a previous run of this cell.
        ngrok.kill()

        url = ngrok.connect(port)
        print(f"Streamlit App is running on: {url}")
        print("\nClick the link above to view the interactive generator.")

        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nStreamlit server stopped.")
    finally:
        # Always release the tunnel and the child process, whatever path we exit on.
        if ngrok is not None:
            ngrok.kill()
        if server is not None and server.poll() is None:
            server.terminate()
            try:
                server.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server.kill()
        log_file.close()

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from copy import deepcopy
from PIL import Image
import os
import math
from tqdm.auto import tqdm
from google.colab import files
import numpy as np

class Config:
    """Static hyperparameters and checkpoint locations shared by the VAE and the latent denoiser."""

    # --- Checkpoint files ---
    VAE_PATH = "checkpoints/fashion_vae.pt"
    LATENT_EMA_PATH = "checkpoints/latent_model.ema.pt"

    # --- Image / latent geometry ---
    IMAGE_CHANNELS = 1
    NUM_CLASSES = 10
    LATENT_CHANNELS = 4
    LATENT_SIZE = 7

    # --- Diffusion schedule length ---
    N_STEPS = 1000

    # --- Denoiser width ---
    BASE_CHANNELS = 64
    CHANNEL_MULT = [1, 2, 2]
    NUM_RES_BLOCKS = 2
    TIME_EMBED_DIM = 4 * BASE_CHANNELS
    CONTEXT_DIM = TIME_EMBED_DIM


# Single shared instance read by the model definitions below.
config = Config()

def get_timestep_embedding(timesteps, dim):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: Tensor of timestep indices; a trailing embedding axis is
            appended (a 1-D batch of size B yields shape ``(B, dim)``).
        dim: Width of the embedding axis. Odd values are supported by
            zero-padding the last column.

    Returns:
        float32 tensor laid out as ``[sin | cos | zero pad if dim is odd]``.
    """
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 towards 1/10000.
    exponent = torch.exp(-math.log(10000) * torch.arange(half_dim, device=timesteps.device) / half_dim)
    timesteps = timesteps.float().unsqueeze(-1)
    emb = timesteps * exponent.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if dim % 2 == 1:
        # sin/cos only fill 2 * (dim // 2) columns; pad so the width is exactly `dim`.
        emb = torch.cat([emb, emb.new_zeros(*emb.shape[:-1], 1)], dim=-1)
    return emb.to(dtype=torch.float32)

def load_checkpoint(model, path, device="cpu"):
    """Load a state dict from ``path`` into ``model`` if the file exists.

    Args:
        model: Module whose ``load_state_dict`` receives the loaded weights.
        path: Checkpoint file location.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        True when the checkpoint was found and loaded; False when ``path``
        does not exist (``model`` is left untouched in that case).
    """
    if not os.path.exists(path):
        return False
    # weights_only=True restricts unpickling to tensors and plain containers, so an
    # uploaded checkpoint cannot execute arbitrary code while it is being loaded.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Checkpoint loaded from {path}")
    return True

class ChannelAttention(nn.Module):
    """Channel gating: rescale every channel by a sigmoid weight derived from its spatial mean.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channels`` for the hidden width of the gating MLP.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate; the shape is unchanged."""
        batch, chans, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, chans)
        gate = self.fc(pooled)
        return x * gate.view(batch, chans, 1, 1)

class Encoder(nn.Module):
    """Convolutional VAE encoder that outputs mean and log-variance feature maps.

    The two stride-2 convolutions (kernel 4, padding 1) each halve an
    even-sized spatial dimension.
    """

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()
        # Layers are grouped by stage but unpacked into one flat Sequential so
        # the positional indices stay 0..10.
        stem = [nn.Conv2d(config.IMAGE_CHANNELS, 32, 3, padding=1), nn.GroupNorm(8, 32), nn.SiLU()]
        stage_a = [nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.GroupNorm(8, 64), nn.SiLU(), ChannelAttention(64)]
        stage_b = [nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.GroupNorm(8, 128), nn.SiLU(), ChannelAttention(128)]
        self.net = nn.Sequential(*stem, *stage_a, *stage_b)
        # 1x1 heads mapping the 128-channel features to the latent statistics.
        self.to_mu = nn.Conv2d(128, latent_channels, 1)
        self.to_logvar = nn.Conv2d(128, latent_channels, 1)

    def forward(self, x):
        """Return the tuple ``(mu, logvar)`` computed from ``x``."""
        features = self.net(x)
        mu = self.to_mu(features)
        logvar = self.to_logvar(features)
        return mu, logvar

class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latents back to images in [-1, 1] (final Tanh).

    The two stride-2 transposed convolutions (kernel 4, padding 1) each double
    the spatial size.
    """

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()
        # Grouped by stage, then flattened so the Sequential indices stay 0..11.
        stem = [nn.Conv2d(latent_channels, 128, 3, padding=1), nn.GroupNorm(8, 128), nn.SiLU()]
        stage_a = [nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.GroupNorm(8, 64), nn.SiLU(), ChannelAttention(64)]
        stage_b = [nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.GroupNorm(8, 32), nn.SiLU()]
        head = [nn.Conv2d(32, config.IMAGE_CHANNELS, 3, padding=1), nn.Tanh()]
        self.net = nn.Sequential(*stem, *stage_a, *stage_b, *head)

    def forward(self, z):
        """Decode latent ``z`` into an image tensor."""
        image = self.net(z)
        return image

class VAE(nn.Module):
    """Encoder/decoder pair; ``forward`` reconstructs from the encoder mean without sampling."""

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)

    def encode(self, x):
        """Return whatever the encoder returns for ``x`` (``(mu, logvar)``)."""
        stats = self.encoder(x)
        return stats

    def decode(self, z):
        """Map latent ``z`` back to image space."""
        image = self.decoder(z)
        return image

    def forward(self, x):
        """Deterministic reconstruction: decode the first encoder output (the mean)."""
        mu = self.encode(x)[0]
        return self.decode(mu)

class CrossAttention(nn.Module):
    """Residual multi-head attention from a token sequence onto a context sequence.

    Args:
        query_dim: Feature width of the query tokens (also the attention embed dim).
        context_dim: Feature width of the context tokens.
        num_heads: Number of attention heads.
    """

    def __init__(self, query_dim, context_dim, num_heads=4):
        super().__init__()
        # Keys/values are projected from context width to query width so the
        # attention core works in a single embedding size.
        self.q_proj = nn.Linear(in_features=query_dim, out_features=query_dim)
        self.k_proj = nn.Linear(in_features=context_dim, out_features=query_dim)
        self.v_proj = nn.Linear(in_features=context_dim, out_features=query_dim)
        self.attn = nn.MultiheadAttention(query_dim, num_heads, batch_first=True)
        self.proj_out = nn.Linear(in_features=query_dim, out_features=query_dim)

    def forward(self, x, context):
        """Return ``x`` plus the projected attention output (same shape as ``x``)."""
        queries = self.q_proj(x)
        keys, values = self.k_proj(context), self.v_proj(context)
        attended = self.attn(queries, keys, values)[0]
        return x + self.proj_out(attended)

class LatentResBlock(nn.Module):
    """Residual conv block conditioned on a time embedding and, through cross-attention, on a context.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        time_emb_dim: Width of the time embedding fed to ``time_mlp``.
        context_dim: Feature width of the context tokens.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, context_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_ch))
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # A 1x1 conv aligns channels on the skip path only when they differ.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.shortcut = nn.Identity()
        self.norm_ca = nn.GroupNorm(8, out_ch)
        self.ca = CrossAttention(out_ch, context_dim)

    def forward(self, x, t_emb, context):
        """Apply the block to feature map ``x`` and return a map with ``out_ch`` channels."""
        batch, _, height, width = x.shape
        channels = self.conv1.out_channels

        # Parameter-free group norm (at most 8 groups) -> SiLU -> conv.
        hidden = self.conv1(F.silu(F.group_norm(x, min(8, x.shape[1]))))
        # Broadcast the per-sample time projection over the spatial grid.
        hidden = hidden + self.time_mlp(t_emb)[:, :, None, None]

        # Flatten the grid to tokens of shape (B, H*W, C) for cross-attention.
        tokens = F.silu(self.norm_ca(hidden))
        tokens = tokens.permute(0, 2, 3, 1).reshape(batch, int(height) * int(width), channels)
        tokens = self.ca(tokens, context)
        # Back to (B, C, H, W) and add onto the conv features.
        hidden = hidden + tokens.transpose(1, 2).reshape(batch, channels, height, width)

        hidden = F.group_norm(hidden, min(8, hidden.shape[1]))
        hidden = self.conv2(F.silu(hidden))
        return hidden + self.shortcut(x)

class LatentEpsModel(nn.Module):
    """U-Net-shaped network over latents, conditioned on a timestep and a class label.

    All sizes come from the module-level ``config``. The "down" and "up" layers
    are 3x3 (transposed) convolutions with stride 1 and padding 1, so the
    spatial resolution is the same at every stage; only the channel count
    changes. Attribute names and construction order determine the state-dict
    keys that checkpoints are loaded against.
    """

    def __init__(self):
        super().__init__()
        time_dim = config.TIME_EMBED_DIM
        context_dim = config.CONTEXT_DIM
        # MLP applied on top of the sinusoidal timestep embedding.
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, time_dim), nn.SiLU(), nn.Linear(time_dim, time_dim))
        # One learned context vector per class label.
        self.class_emb = nn.Embedding(config.NUM_CLASSES, context_dim)
        self.conv_in = nn.Conv2d(config.LATENT_CHANNELS, config.BASE_CHANNELS, 3, padding=1)

        # Encoder path: NUM_RES_BLOCKS residual blocks per stage, then a stride-1 conv.
        self.downs = nn.ModuleList()
        ch = config.BASE_CHANNELS
        for mult in config.CHANNEL_MULT:
            out = config.BASE_CHANNELS * mult
            blocks = nn.ModuleList()
            blocks.append(LatentResBlock(ch, out, time_dim, context_dim))
            for _ in range(config.NUM_RES_BLOCKS - 1):
                blocks.append(LatentResBlock(out, out, time_dim, context_dim))
            downsample = nn.Conv2d(out, out, 3, 1, 1)
            self.downs.append(nn.ModuleDict({"blocks": blocks, "down": downsample}))
            ch = out

        # Bottleneck: two residual blocks at the widest channel count.
        self.bot1 = LatentResBlock(ch, ch, time_dim, context_dim)
        self.bot2 = LatentResBlock(ch, ch, time_dim, context_dim)

        # Decoder path mirrors the encoder stages in reverse order.
        self.ups = nn.ModuleList()
        for mult in reversed(config.CHANNEL_MULT):
            skip_ch = config.BASE_CHANNELS * mult
            out = skip_ch
            # The first block of each stage sees the "up" output concatenated with the skip.
            in_ch_res = out + skip_ch
            blocks = nn.ModuleList()
            blocks.append(LatentResBlock(in_ch_res, out, time_dim, context_dim))
            for _ in range(config.NUM_RES_BLOCKS - 1):
                blocks.append(LatentResBlock(out, out, time_dim, context_dim))
            # Maps the previous stage's channel count (ch) to this stage's (out).
            upsample = nn.ConvTranspose2d(ch, out, 3, 1, 1)
            self.ups.append(nn.ModuleDict({"blocks": blocks, "up": upsample}))
            ch = out

        self.conv_out = nn.Conv2d(ch, config.LATENT_CHANNELS, 3, padding=1)

    def forward(self, x, t, y):
        """Run the network on latent ``x`` for timesteps ``t`` and class labels ``y``.

        Args:
            x: Latent tensor with ``config.LATENT_CHANNELS`` channels.
            t: Timestep indices, embedded with ``get_timestep_embedding``.
            y: Integer class labels used to index ``class_emb``
                (assumed shape (B,) -- TODO confirm with callers).

        Returns:
            Tensor with ``config.LATENT_CHANNELS`` channels (output of ``conv_out``).
        """
        t_emb = self.time_mlp(get_timestep_embedding(t, config.TIME_EMBED_DIM))
        # Add a sequence axis so the class embedding acts as a one-token context.
        context = self.class_emb(y).unsqueeze(1)
        hs = []
        h = self.conv_in(x)
        for module in self.downs:
            for block in module["blocks"]:
                h = block(h, t_emb, context)
            # Save the stage output (before the "down" conv) as the skip connection.
            hs.append(h)
            h = module["down"](h)

        h = self.bot1(h, t_emb, context)
        h = self.bot2(h, t_emb, context)

        for module in self.ups:
            # Skips are consumed last-in, first-out to pair with the mirrored stage.
            skip = hs.pop()
            h = module["up"](h)
            h = torch.cat([h, skip], dim=1)
            for block in module["blocks"]:
                h = block(h, t_emb, context)

        return self.conv_out(h)

class ConditionalDenoiseDiffusion():
    """DDPM ancestral sampler with a linear beta schedule around a conditional noise model.

    Args:
        eps_model: Callable ``eps_model(x_t, t, c)`` returning the predicted
            noise, shaped like ``x_t``.
        n_steps: Number of diffusion steps in the schedule.
        device: Device holding the schedule tensors; defaults to CPU.
    """

    def __init__(self, eps_model, n_steps=config.N_STEPS, device=None):
        self.eps_model = eps_model
        self.device = device if device is not None else torch.device("cpu")
        # Linear beta schedule from 1e-4 to 0.02.
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(self.device)
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar)
        # alpha_bar shifted right by one step, with alpha_bar_{-1} defined as 1.
        self.alphas_cumprod_prev = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0)
        self.post_variance = self.beta * (1. - self.alphas_cumprod_prev) / (1. - self.alpha_bar)

    def p_sample(self, xt, t, c=None):
        """Draw ``x_{t-1}`` given ``x_t``.

        Args:
            xt: Current noisy sample; indexing below assumes 4 dimensions (B, C, H, W).
            t: Timestep as a Python int, a 0-dim tensor, or a (B,) tensor of indices.
            c: Conditioning forwarded unchanged to ``eps_model``.

        Returns:
            Tensor shaped like ``xt``. Samples whose timestep is 0 receive the
            posterior mean without added noise.
        """
        batch = xt.shape[0]
        if not isinstance(t, torch.Tensor):
            t = torch.tensor([t] * batch, device=xt.device, dtype=torch.long)
        elif t.dim() == 0:
            # Broadcast a scalar tensor over the batch so every lookup below is per-sample.
            t = t.reshape(1).expand(batch)

        def at_t(schedule):
            """Pick the per-sample schedule value and make it broadcastable over (C, H, W)."""
            return schedule[t].reshape(-1, 1, 1, 1)

        eps_theta = self.eps_model(xt, t, c)

        alpha_t = at_t(self.alpha)
        alpha_bar_t = at_t(self.alpha_bar)
        alpha_bar_t_prev = at_t(self.alphas_cumprod_prev)

        # Invert the forward process to estimate x0 from the predicted noise.
        x0_pred = (xt - at_t(self.sqrt_one_minus_alpha_bar) * eps_theta) / at_t(self.sqrt_alpha_bar)
        # NOTE(review): x0 is clamped to [-1, 1]; when xt is a VAE latent this assumes the
        # latents live in that range -- confirm against how the model was trained.
        x0_pred = torch.clamp(x0_pred, -1., 1.)

        # Posterior mean of q(x_{t-1} | x_t, x0) as a blend of x0_pred and x_t.
        mean = (alpha_bar_t_prev.sqrt() * at_t(self.beta) / (1. - alpha_bar_t)) * x0_pred + \
               (alpha_t.sqrt() * (1. - alpha_bar_t_prev) / (1. - alpha_bar_t)) * xt

        variance = at_t(self.post_variance)

        # Decide per sample (not from t[0] alone) whether noise is added.
        needs_noise = t > 0
        if not bool(needs_noise.any()):
            return mean
        noise = torch.randn_like(xt)
        mask = needs_noise.reshape(-1, 1, 1, 1).to(xt.dtype)
        return mean + mask * torch.sqrt(variance) * noise

    def sample(self, shape, device, c=None):
        """Start from Gaussian noise of ``shape`` and run all ``n_steps`` reverse steps."""
        x = torch.randn(shape, device=device)
        # `total` is given explicitly because a reversed range exposes no len() to tqdm.
        for t in tqdm(reversed(range(self.n_steps)), total=self.n_steps, desc="Denoising"):
            x = self.p_sample(x, t, c)
        return x

In [4]:
# Pick the compute device once; every model below is moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Please upload the 'fashion_vae.pt' and 'latent_model.ema.pt' files:")
uploaded = files.upload()

# Uploads land in the working directory; relocate them to where Config expects them.
# The directory is created here so this cell does not depend on another cell having run.
os.makedirs("checkpoints", exist_ok=True)
for filename, missing_msg in (
    ("fashion_vae.pt", "VAE file not found. Please upload 'fashion_vae.pt'."),
    ("latent_model.ema.pt", "EMA DDPM file not found. Please upload 'latent_model.ema.pt'."),
):
    if filename in uploaded:
        os.replace(filename, os.path.join("checkpoints", filename))
        print(f"{filename} moved to checkpoints/")
    else:
        print(missing_msg)


try:
    vae = VAE().to(device)
    # load_checkpoint returns False for a missing file; without this check the
    # cell would report success while the model still holds random weights.
    if not load_checkpoint(vae, config.VAE_PATH, device=device):
        raise FileNotFoundError(f"VAE checkpoint not found at {config.VAE_PATH}")
    vae.eval()
    for p in vae.parameters(): p.requires_grad_(False)

    latent_model = LatentEpsModel().to(device)
    if not load_checkpoint(latent_model, config.LATENT_EMA_PATH, device=device):
        raise FileNotFoundError(f"DDPM EMA checkpoint not found at {config.LATENT_EMA_PATH}")
    latent_model.eval()
    for p in latent_model.parameters(): p.requires_grad_(False)

    ddpm_scheduler = ConditionalDenoiseDiffusion(latent_model, device=device)
    print("\nModels initialized and ready for sampling.")

except Exception as e:
    print(f"\nERROR initializing models: {e}")
    print("Please ensure your uploaded files are correct and match the architecture.")

Using device: cuda
Please upload the 'fashion_vae.pt' and 'latent_model.ema.pt' files:


Saving fashion_vae.pt to fashion_vae.pt
Saving latent_model.ema.pt to latent_model.ema.pt
fashion_vae.pt moved to checkpoints/
latent_model.ema.pt moved to checkpoints/
Checkpoint loaded from checkpoints/fashion_vae.pt
Checkpoint loaded from checkpoints/latent_model.ema.pt

Models initialized and ready for sampling.


In [19]:
%%writefile app.py
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
from PIL import Image
import io
import time
import os
import math
from tqdm.auto import tqdm

class Config:
    """Static hyperparameters and checkpoint locations used by the Streamlit app."""

    # --- Checkpoint files ---
    VAE_PATH = "checkpoints/fashion_vae.pt"
    LATENT_EMA_PATH = "checkpoints/latent_model.ema.pt"

    # --- Image / latent geometry ---
    IMAGE_CHANNELS = 1
    NUM_CLASSES = 10
    LATENT_CHANNELS = 4
    LATENT_SIZE = 7

    # --- Diffusion schedule length ---
    N_STEPS = 1000

    # --- Denoiser width ---
    BASE_CHANNELS = 64
    CHANNEL_MULT = [1, 2, 2]
    NUM_RES_BLOCKS = 2
    TIME_EMBED_DIM = 4 * BASE_CHANNELS
    CONTEXT_DIM = 4 * BASE_CHANNELS


config = Config()

# Class id -> display name; the position in the list is the class id.
CLASS_MAP = dict(enumerate([
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]))

def get_timestep_embedding(timesteps, dim):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: Tensor of timestep indices; a trailing embedding axis is
            appended (a 1-D batch of size B yields shape ``(B, dim)``).
        dim: Width of the embedding axis. Odd values are supported by
            zero-padding the last column.

    Returns:
        float32 tensor laid out as ``[sin | cos | zero pad if dim is odd]``.
    """
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 towards 1/10000.
    exponent = torch.exp(-math.log(10000) * torch.arange(half_dim, device=timesteps.device) / half_dim)

    timesteps = timesteps.float().unsqueeze(-1)
    emb = timesteps * exponent.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if dim % 2 == 1:
        # sin/cos only fill 2 * (dim // 2) columns; pad so the width is exactly `dim`.
        emb = torch.cat([emb, emb.new_zeros(*emb.shape[:-1], 1)], dim=-1)
    return emb.to(dtype=torch.float32)

def load_checkpoint(model, path, device="cpu"):
    """Load a state dict from ``path`` into ``model`` if the file exists.

    Args:
        model: Module whose ``load_state_dict`` receives the loaded weights.
        path: Checkpoint file location.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        True when the checkpoint was found and loaded; False when ``path``
        does not exist (``model`` is left untouched in that case).
    """
    if not os.path.exists(path):
        return False
    # weights_only=True restricts unpickling to tensors and plain containers, so an
    # uploaded checkpoint cannot execute arbitrary code while it is being loaded.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return True

class ChannelAttention(nn.Module):
    """Channel gating: rescale every channel by a sigmoid weight derived from its spatial mean.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channels`` for the hidden width of the gating MLP.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate; the shape is unchanged."""
        batch, chans, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, chans)
        gate = self.fc(pooled)
        return x * gate.view(batch, chans, 1, 1)

class Encoder(nn.Module):
    """Convolutional VAE encoder that outputs mean and log-variance feature maps.

    The two stride-2 convolutions (kernel 4, padding 1) each halve an
    even-sized spatial dimension.
    """

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()
        # Layers are grouped by stage but unpacked into one flat Sequential so
        # the positional indices stay 0..10.
        stem = [nn.Conv2d(config.IMAGE_CHANNELS, 32, 3, padding=1), nn.GroupNorm(8, 32), nn.SiLU()]
        stage_a = [nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.GroupNorm(8, 64), nn.SiLU(), ChannelAttention(64)]
        stage_b = [nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.GroupNorm(8, 128), nn.SiLU(), ChannelAttention(128)]
        self.net = nn.Sequential(*stem, *stage_a, *stage_b)
        # 1x1 heads mapping the 128-channel features to the latent statistics.
        self.to_mu = nn.Conv2d(128, latent_channels, 1)
        self.to_logvar = nn.Conv2d(128, latent_channels, 1)

    def forward(self, x):
        """Return the tuple ``(mu, logvar)`` computed from ``x``."""
        features = self.net(x)
        mu = self.to_mu(features)
        logvar = self.to_logvar(features)
        return mu, logvar

class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latents back to images in [-1, 1] (final Tanh).

    The two stride-2 transposed convolutions (kernel 4, padding 1) each double
    the spatial size.
    """

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()
        # Grouped by stage, then flattened so the Sequential indices stay 0..11.
        stem = [nn.Conv2d(latent_channels, 128, 3, padding=1), nn.GroupNorm(8, 128), nn.SiLU()]
        stage_a = [nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.GroupNorm(8, 64), nn.SiLU(), ChannelAttention(64)]
        stage_b = [nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.GroupNorm(8, 32), nn.SiLU()]
        head = [nn.Conv2d(32, config.IMAGE_CHANNELS, 3, padding=1), nn.Tanh()]
        self.net = nn.Sequential(*stem, *stage_a, *stage_b, *head)

    def forward(self, z):
        """Decode latent ``z`` into an image tensor."""
        image = self.net(z)
        return image

class VAE(nn.Module):
    """Decode-only VAE wrapper used by the app.

    The encoder is still constructed although no encode method is exposed;
    presumably so a full VAE state dict loads with matching keys -- confirm.
    """

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)

    def decode(self, z):
        """Map latent ``z`` back to image space."""
        image = self.decoder(z)
        return image

class CrossAttention(nn.Module):
    """Residual multi-head attention from a token sequence onto a context sequence.

    Args:
        query_dim: Feature width of the query tokens (also the attention embed dim).
        context_dim: Feature width of the context tokens.
        num_heads: Number of attention heads.
    """

    def __init__(self, query_dim, context_dim, num_heads=4):
        super().__init__()
        # Keys/values are projected from context width to query width so the
        # attention core works in a single embedding size.
        self.q_proj = nn.Linear(in_features=query_dim, out_features=query_dim)
        self.k_proj = nn.Linear(in_features=context_dim, out_features=query_dim)
        self.v_proj = nn.Linear(in_features=context_dim, out_features=query_dim)
        self.attn = nn.MultiheadAttention(query_dim, num_heads, batch_first=True)
        self.proj_out = nn.Linear(in_features=query_dim, out_features=query_dim)

    def forward(self, x, context):
        """Return ``x`` plus the projected attention output (same shape as ``x``)."""
        queries = self.q_proj(x)
        keys, values = self.k_proj(context), self.v_proj(context)
        attended = self.attn(queries, keys, values)[0]
        return x + self.proj_out(attended)

class LatentResBlock(nn.Module):
    """Residual conv block conditioned on a time embedding and, through cross-attention, on a context.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        time_emb_dim: Width of the time embedding fed to ``time_mlp``.
        context_dim: Feature width of the context tokens.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, context_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_ch))
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # A 1x1 conv aligns channels on the skip path only when they differ.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.shortcut = nn.Identity()
        self.norm_ca = nn.GroupNorm(8, out_ch)
        self.ca = CrossAttention(out_ch, context_dim)

    def forward(self, x, t_emb, context):
        """Apply the block to feature map ``x`` and return a map with ``out_ch`` channels."""
        batch, _, height, width = x.shape
        channels = self.conv1.out_channels

        # Parameter-free group norm (at most 8 groups) -> SiLU -> conv.
        hidden = self.conv1(F.silu(F.group_norm(x, min(8, x.shape[1]))))
        # Broadcast the per-sample time projection over the spatial grid.
        hidden = hidden + self.time_mlp(t_emb)[:, :, None, None]

        # Flatten the grid to tokens of shape (B, H*W, C) for cross-attention.
        tokens = F.silu(self.norm_ca(hidden))
        tokens = tokens.permute(0, 2, 3, 1).reshape(batch, int(height) * int(width), channels)
        tokens = self.ca(tokens, context)
        # Back to (B, C, H, W) and add onto the conv features.
        hidden = hidden + tokens.transpose(1, 2).reshape(batch, channels, height, width)

        hidden = F.group_norm(hidden, min(8, hidden.shape[1]))
        hidden = self.conv2(F.silu(hidden))
        return hidden + self.shortcut(x)

class LatentEpsModel(nn.Module):
    """U-Net-shaped network over latents, conditioned on a timestep and a class label.

    All sizes come from the module-level ``config``. The "down" and "up" layers
    are 3x3 (transposed) convolutions with stride 1 and padding 1, so the
    spatial resolution is the same at every stage; only the channel count
    changes. Attribute names and construction order determine the state-dict
    keys that checkpoints are loaded against.
    """

    def __init__(self):
        super().__init__()
        time_dim = config.TIME_EMBED_DIM
        context_dim = config.CONTEXT_DIM
        # MLP applied on top of the sinusoidal timestep embedding.
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, time_dim), nn.SiLU(), nn.Linear(time_dim, time_dim))
        # One learned context vector per class label.
        self.class_emb = nn.Embedding(config.NUM_CLASSES, context_dim)
        self.conv_in = nn.Conv2d(config.LATENT_CHANNELS, config.BASE_CHANNELS, 3, padding=1)

        # Encoder path: NUM_RES_BLOCKS residual blocks per stage, then a stride-1 conv.
        self.downs = nn.ModuleList()
        ch = config.BASE_CHANNELS
        for mult in config.CHANNEL_MULT:
            out = config.BASE_CHANNELS * mult
            blocks = nn.ModuleList()
            blocks.append(LatentResBlock(ch, out, time_dim, context_dim))
            for _ in range(config.NUM_RES_BLOCKS - 1):
                blocks.append(LatentResBlock(out, out, time_dim, context_dim))
            downsample = nn.Conv2d(out, out, 3, 1, 1)
            self.downs.append(nn.ModuleDict({"blocks": blocks, "down": downsample}))
            ch = out

        # Bottleneck: two residual blocks at the widest channel count.
        self.bot1 = LatentResBlock(ch, ch, time_dim, context_dim)
        self.bot2 = LatentResBlock(ch, ch, time_dim, context_dim)

        # Decoder path mirrors the encoder stages in reverse order.
        self.ups = nn.ModuleList()
        for mult in reversed(config.CHANNEL_MULT):
            skip_ch = config.BASE_CHANNELS * mult
            out = skip_ch
            # The first block of each stage sees the "up" output concatenated with the skip.
            in_ch_res = out + skip_ch
            blocks = nn.ModuleList()
            blocks.append(LatentResBlock(in_ch_res, out, time_dim, context_dim))
            for _ in range(config.NUM_RES_BLOCKS - 1):
                blocks.append(LatentResBlock(out, out, time_dim, context_dim))
            # Maps the previous stage's channel count (ch) to this stage's (out).
            upsample = nn.ConvTranspose2d(ch, out, 3, 1, 1)
            self.ups.append(nn.ModuleDict({"blocks": blocks, "up": upsample}))
            ch = out

        self.conv_out = nn.Conv2d(ch, config.LATENT_CHANNELS, 3, padding=1)

    def forward(self, x, t, y):
        """Run the network on latent ``x`` for timesteps ``t`` and class labels ``y``.

        Args:
            x: Latent tensor with ``config.LATENT_CHANNELS`` channels.
            t: Timestep indices, embedded with ``get_timestep_embedding``.
            y: Integer class labels used to index ``class_emb``
                (assumed shape (B,) -- TODO confirm with callers).

        Returns:
            Tensor with ``config.LATENT_CHANNELS`` channels (output of ``conv_out``).
        """
        t_emb = self.time_mlp(get_timestep_embedding(t, config.TIME_EMBED_DIM))
        # Add a sequence axis so the class embedding acts as a one-token context.
        context = self.class_emb(y).unsqueeze(1)

        hs = []
        h = self.conv_in(x)
        for module in self.downs:
            for block in module["blocks"]:
                h = block(h, t_emb, context)
            # Save the stage output (before the "down" conv) as the skip connection.
            hs.append(h)
            h = module["down"](h)

        h = self.bot1(h, t_emb, context)
        h = self.bot2(h, t_emb, context)

        for module in self.ups:
            # Skips are consumed last-in, first-out to pair with the mirrored stage.
            skip = hs.pop()
            h = module["up"](h)
            h = torch.cat([h, skip], dim=1)
            for block in module["blocks"]:
                h = block(h, t_emb, context)

        return self.conv_out(h)

class ConditionalDenoiseDiffusion():
    """DDPM ancestral sampler with a linear beta schedule around a conditional noise model.

    Args:
        eps_model: Callable ``eps_model(x_t, t, c)`` returning the predicted
            noise, shaped like ``x_t``.
        n_steps: Number of diffusion steps in the schedule.
        device: Device holding the schedule tensors; defaults to CPU.
    """

    def __init__(self, eps_model, n_steps=config.N_STEPS, device=None):
        self.eps_model = eps_model
        self.device = device if device is not None else torch.device("cpu")
        # Linear beta schedule from 1e-4 to 0.02.
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(self.device)
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar)
        # alpha_bar shifted right by one step, with alpha_bar_{-1} defined as 1.
        self.alphas_cumprod_prev = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0)
        self.post_variance = self.beta * (1. - self.alphas_cumprod_prev) / (1. - self.alpha_bar)

    def p_sample(self, xt, t, c=None):
        """Draw ``x_{t-1}`` given ``x_t``.

        Args:
            xt: Current noisy sample; indexing below assumes 4 dimensions (B, C, H, W).
            t: Timestep as a Python int, a 0-dim tensor, or a (B,) tensor of indices.
            c: Conditioning forwarded unchanged to ``eps_model``.

        Returns:
            Tensor shaped like ``xt``. Samples whose timestep is 0 receive the
            posterior mean without added noise.
        """
        batch = xt.shape[0]
        if not isinstance(t, torch.Tensor):
            t = torch.tensor([t] * batch, device=xt.device, dtype=torch.long)
        elif t.dim() == 0:
            # Broadcast a scalar tensor over the batch so every lookup below is per-sample.
            t = t.reshape(1).expand(batch)

        def at_t(schedule):
            """Pick the per-sample schedule value and make it broadcastable over (C, H, W)."""
            return schedule[t].reshape(-1, 1, 1, 1)

        eps_theta = self.eps_model(xt, t, c)

        alpha_t = at_t(self.alpha)
        alpha_bar_t = at_t(self.alpha_bar)
        alpha_bar_t_prev = at_t(self.alphas_cumprod_prev)

        # Invert the forward process to estimate x0 from the predicted noise.
        x0_pred = (xt - at_t(self.sqrt_one_minus_alpha_bar) * eps_theta) / at_t(self.sqrt_alpha_bar)
        # NOTE(review): x0 is clamped to [-1, 1]; when xt is a VAE latent this assumes the
        # latents live in that range -- confirm against how the model was trained.
        x0_pred = torch.clamp(x0_pred, -1., 1.)

        # Posterior mean of q(x_{t-1} | x_t, x0) as a blend of x0_pred and x_t.
        mean = (alpha_bar_t_prev.sqrt() * at_t(self.beta) / (1. - alpha_bar_t)) * x0_pred + \
               (alpha_t.sqrt() * (1. - alpha_bar_t_prev) / (1. - alpha_bar_t)) * xt

        variance = at_t(self.post_variance)

        # Decide per sample (not from t[0] alone) whether noise is added.
        needs_noise = t > 0
        if not bool(needs_noise.any()):
            return mean
        noise = torch.randn_like(xt)
        mask = needs_noise.reshape(-1, 1, 1, 1).to(xt.dtype)
        return mean + mask * torch.sqrt(variance) * noise

    def sample(self, shape, device, c=None):
        """Start from Gaussian noise of ``shape`` and run all ``n_steps`` reverse steps."""
        x = torch.randn(shape, device=device)
        # `total` is given explicitly because a reversed range exposes no len() to tqdm.
        for t in tqdm(reversed(range(self.n_steps)), total=self.n_steps, desc="Denoising", leave=False):
            x = self.p_sample(x, t, c)
        return x


@st.cache_resource
def load_models():
    """Build the VAE and the latent denoiser, load their checkpoints and wrap the denoiser in a sampler.

    On a missing checkpoint or any other exception an error is shown in the
    page and ``st.stop()`` is called.

    Returns:
        Tuple ``(vae, ddpm_scheduler, device)``.
    """

    def _freeze(module):
        """Switch ``module`` to eval mode and turn off gradients for all of its parameters."""
        module.eval()
        for param in module.parameters():
            param.requires_grad_(False)

    try:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        vae = VAE().to(target_device)
        vae_found = load_checkpoint(vae, config.VAE_PATH, device=target_device)
        if not vae_found:
            st.error(f"FATAL: VAE checkpoint not found at {config.VAE_PATH}. Did you upload it?")
            st.stop()
        _freeze(vae)

        latent_model = LatentEpsModel().to(target_device)
        denoiser_found = load_checkpoint(latent_model, config.LATENT_EMA_PATH, device=target_device)
        if not denoiser_found:
            st.error(f"FATAL: DDPM EMA checkpoint not found at {config.LATENT_EMA_PATH}. Did you upload it?")
            st.stop()
        _freeze(latent_model)

        ddpm_scheduler = ConditionalDenoiseDiffusion(latent_model, device=target_device)

        st.info(f"Models loaded successfully on device: {target_device}")
        return vae, ddpm_scheduler, target_device

    except Exception as e:
        st.error(f"Critical error during model loading: {e}")
        st.stop()

# Bind the cached model instances under names that do not shadow the `VAE` class,
# which load_models() still needs to instantiate on a cache miss.
VAE_MODEL, DDPM_SCHEDULER, DEVICE = load_models()


def generate_image_and_display(class_label_id, n_images=1):
    """Sample latents for one class with the DDPM sampler, decode them and return an image grid.

    Despite its name the function does not render the image; it reports the
    elapsed time via ``st.success`` and returns the grid for the caller to show.

    Args:
        class_label_id: Integer class id used as the conditioning label.
        n_images: Number of samples to generate; they are laid out in a single row.

    Returns:
        PIL image containing the PNG-encoded grid of decoded samples.
    """
    start_time = time.time()

    latent_shape = (n_images, config.LATENT_CHANNELS, config.LATENT_SIZE, config.LATENT_SIZE)
    target_labels = torch.full((n_images,), class_label_id, dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        z_samples = DDPM_SCHEDULER.sample(
            shape=latent_shape,
            device=DEVICE,
            c=target_labels
        )

        x_samples = VAE_MODEL.decode(z_samples).clamp(-1, 1)

    # Map the decoder output from [-1, 1] to [0, 1] before writing the PNG.
    x_samples = (x_samples + 1) * 0.5

    output_buffer = io.BytesIO()
    save_image(x_samples, output_buffer, nrow=n_images, format='png')
    output_buffer.seek(0)

    elapsed = time.time() - start_time
    st.success(f"Generation finished in {elapsed:.2f} seconds.")
    return Image.open(output_buffer)


# --- Page layout and controls (re-executed on every Streamlit rerun) ---
st.title("👗 Fashion CLDM Generator")
st.subheader("Conditional Latent Diffusion Model for Fashion-MNIST")
st.markdown("Select a class and click 'Generate' to create a new image.")

# Select by class id and let format_func render the name, so no reverse lookup is needed.
selected_class_id = st.selectbox(
    "Choose a Fashion Class:",
    options=list(CLASS_MAP.keys()),
    format_func=CLASS_MAP.get
)
selected_class_name = CLASS_MAP[selected_class_id]

n_images = st.slider("Number of Images to Generate (Batch Size):", min_value=1, max_value=8, value=4)

if st.button("✨ Generate Image(s)"):
    st.info(f"Generating {n_images} image(s) for **{selected_class_name}** (Class ID: {selected_class_id})...")
    with st.spinner('Denoising latent space...'):
        generated_image = generate_image_and_display(selected_class_id, n_images=n_images)
        st.image(generated_image, caption=f'Generated Image(s) for {selected_class_name}', width=200)

Overwriting app.py


In [None]:
# Serve app.py through the ngrok tunnel; this call blocks until the kernel is interrupted.
start_streamlit_server("app.py")

Starting Streamlit server...
nohup: appending output to 'nohup.out'

--- Ngrok Tunnel Setup ---
Streamlit App is running on: NgrokTunnel: "https://crew-unphosphatized-mirella.ngrok-free.dev" -> "http://localhost:8501"

Click the link above to view the interactive generator.
