# In [None]:  (notebook cell marker, kept as a comment so the file parses as Python)
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# Text encoder (placeholder)
# ----------------------------
class TinyTextEncoder(nn.Module):
    """Small Transformer encoder that turns token ids into per-token embeddings.

    Token embeddings and learned positional embeddings are summed and passed
    through a stack of batch-first ``nn.TransformerEncoderLayer`` blocks.

    Args:
        vocab_size: Number of rows in the token embedding table.
        dim: Embedding width; also used as the Transformer ``d_model``.
        n_layers: Number of stacked encoder layers.
        n_heads: Attention heads per encoder layer.
        max_len: Number of rows in the positional embedding table.
    """

    def __init__(self, vocab_size=10000, dim=768, n_layers=4, n_heads=12, max_len=77):
        super().__init__()
        self.token = nn.Embedding(vocab_size, dim)
        self.pos = nn.Embedding(max_len, dim)
        layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=4 * dim,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.dim = dim
        self.max_len = max_len

    def forward(self, token_ids):
        """Encode ``token_ids`` of shape (B, L) into embeddings of shape (B, L, dim).

        The full sequence is returned (no pooling) so it can serve as the
        key/value input of a cross-attention layer.
        """
        batch, seq_len = token_ids.shape
        positions = torch.arange(seq_len, device=token_ids.device)
        positions = positions[None, :].expand(batch, seq_len)
        embedded = self.token(token_ids) + self.pos(positions)
        return self.encoder(embedded)


# ----------------------------
# VAE for latent space
# ----------------------------
class EncoderBlock(nn.Module):
    """Two conv -> GroupNorm -> SiLU stages, optionally followed by a stride-2 conv.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels produced by both 3x3 convolutions. GroupNorm is built
            with 8 groups, so this has to be divisible by 8.
        downsample: When true the block ends with a stride-2 3x3 convolution;
            otherwise the last stage is an identity.
    """

    def __init__(self, in_ch, out_ch, downsample=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        if downsample:
            self.down = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1)
        else:
            self.down = nn.Identity()

    def forward(self, x):
        """Apply both conv stages in order, then the (optional) downsampling conv."""
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in stages:
            x = F.silu(norm(conv(x)))
        return self.down(x)

class DecoderBlock(nn.Module):
    """Two conv -> GroupNorm -> SiLU stages, optionally followed by a 2x transposed conv.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels produced by both 3x3 convolutions. GroupNorm is built
            with 8 groups, so this has to be divisible by 8.
        upsample: When true the block ends with a kernel-4, stride-2 transposed
            convolution; otherwise the last stage is an identity.
    """

    def __init__(self, in_ch, out_ch, upsample=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        if upsample:
            self.up = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.Identity()

    def forward(self, x):
        """Run both conv stages, then the (optional) upsampling layer."""
        hidden = self.conv1(x)
        hidden = F.silu(self.norm1(hidden))
        hidden = self.conv2(hidden)
        hidden = F.silu(self.norm2(hidden))
        return self.up(hidden)

class SimpleVAE(nn.Module):
    """Convolutional VAE mapping images to a scaled latent and back.

    The encoder stacks three downsampling ``EncoderBlock``s and one
    non-downsampling block, then predicts ``mu`` and ``logvar`` with 1x1
    convolutions. The decoder mirrors it with one non-upsampling and three
    upsampling ``DecoderBlock``s followed by a 1x1 convolution to image channels.

    Args:
        in_ch: Image channels (input of the encoder, output of the decoder).
        latent_ch: Channels of the latent tensor.
    """

    def __init__(self, in_ch=3, latent_ch=4):
        super().__init__()
        # Image -> latent statistics
        self.enc1 = EncoderBlock(in_ch, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 256, downsample=False)
        self.to_mu = nn.Conv2d(256, latent_ch, kernel_size=1)
        self.to_logvar = nn.Conv2d(256, latent_ch, kernel_size=1)
        # Latent -> image
        self.dec1 = DecoderBlock(latent_ch, 256, upsample=False)
        self.dec2 = DecoderBlock(256, 256)
        self.dec3 = DecoderBlock(256, 128)
        self.dec4 = DecoderBlock(128, 64)
        self.to_img = nn.Conv2d(64, in_ch, kernel_size=1)
        # Latents are multiplied by this on encode and divided by it on decode.
        self.scaling = 0.18215  # typical SD latent scaling

    def encode(self, x):
        """Sample a latent for ``x`` with the reparameterisation trick.

        Returns:
            ``(z * scaling, mu, logvar)`` where ``z = mu + exp(0.5 * logvar) * eps``
            and ``eps`` is standard normal noise.
        """
        feats = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            feats = stage(feats)
        mu = self.to_mu(feats)
        logvar = self.to_logvar(feats)
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        z = mu + std * eps
        return z * self.scaling, mu, logvar

    def decode(self, z):
        """Undo the latent scaling and decode ``z`` back to image space."""
        hidden = z / self.scaling
        for stage in (self.dec1, self.dec2, self.dec3, self.dec4):
            hidden = stage(hidden)
        return self.to_img(hidden)

    def forward(self, x):
        """Encode then decode ``x``; returns ``(reconstruction, mu, logvar)``."""
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar


# ----------------------------
# Cross-attention block
# ----------------------------
class CrossAttention(nn.Module):
    """Multi-head attention where queries attend to a separate key/value sequence.

    Queries, keys and values are all projected to width ``dim_q`` and split
    into ``n_heads`` heads, so ``dim_q`` has to be divisible by ``n_heads``.

    Args:
        dim_q: Width of the query sequence (and of the output).
        dim_kv: Width of the key/value sequence.
        n_heads: Number of attention heads.
    """

    def __init__(self, dim_q, dim_kv, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.to_q = nn.Linear(dim_q, dim_q)
        self.to_k = nn.Linear(dim_kv, dim_q)
        self.to_v = nn.Linear(dim_kv, dim_q)
        self.proj = nn.Linear(dim_q, dim_q)

    def forward(self, q, kv):
        """Attend from ``q`` (B, Nq, Dq) to ``kv`` (B, Nk, Dkv); returns (B, Nq, Dq)."""
        batch, n_q, width = q.shape
        n_kv = kv.shape[1]
        heads = self.n_heads
        head_dim = width // heads

        def split_heads(tensor, length):
            # (B, N, Dq) -> (B, H, N, Dh)
            return tensor.view(batch, length, heads, head_dim).transpose(1, 2)

        queries = split_heads(self.to_q(q), n_q)
        keys = split_heads(self.to_k(kv), n_kv)
        values = split_heads(self.to_v(kv), n_kv)

        # Scaled dot-product scores, normalised over the key axis: (B, H, Nq, Nk)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (head_dim ** 0.5)
        weights = torch.softmax(scores, dim=-1)

        mixed = torch.matmul(weights, values)  # (B, H, Nq, Dh)
        merged = mixed.transpose(1, 2).contiguous().view(batch, n_q, width)
        return self.proj(merged)


# ----------------------------
# UNet block with cross-attention
# ----------------------------
class ResBlock(nn.Module):
    """Channel-preserving residual block: conv-norm-SiLU, conv-norm, add input, SiLU.

    Args:
        ch: Number of input and output channels. GroupNorm is built with
            8 groups, so this has to be divisible by 8.
    """

    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=ch)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=ch)

    def forward(self, x):
        """Return ``silu(branch(x) + x)`` with the same shape as ``x``."""
        residual = x
        branch = self.conv1(x)
        branch = F.silu(self.norm1(branch))
        branch = self.conv2(branch)
        branch = self.norm2(branch)
        return F.silu(branch + residual)

class UNetWithText(nn.Module):
    """Small UNet-style noise predictor conditioned on a timestep and text embeddings.

    Layout: input conv -> two (stride-2 conv + ResBlock) stages -> additive
    timestep embedding -> bottleneck ResBlock -> cross-attention over the text
    sequence -> two (transposed conv + ResBlock) stages -> output conv.
    There are no skip connections between the down and up paths.

    Args:
        in_ch: Channels of the noisy latent input and of the predicted noise.
        base_ch: Width of every internal feature map. ``ResBlock`` uses
            GroupNorm with 8 groups and the cross-attention uses 8 heads, so
            this has to be divisible by 8.
        text_dim: Width of the text embeddings consumed by cross-attention.
    """

    def __init__(self, in_ch=4, base_ch=256, text_dim=768):
        super().__init__()
        # Down path
        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        self.down1 = nn.Conv2d(base_ch, base_ch, 3, stride=2, padding=1)
        self.rb1 = ResBlock(base_ch)
        self.down2 = nn.Conv2d(base_ch, base_ch, 3, stride=2, padding=1)
        self.rb2 = ResBlock(base_ch)

        # Bottleneck + cross-attention
        self.rb_mid = ResBlock(base_ch)
        # NOTE: txt_proj is not used in forward(); CrossAttention projects the
        # text itself. It is kept so existing state_dict keys stay loadable.
        self.txt_proj = nn.Linear(text_dim, base_ch)
        self.cross = CrossAttention(dim_q=base_ch, dim_kv=text_dim, n_heads=8)

        # Up path
        self.up1 = nn.ConvTranspose2d(base_ch, base_ch, 4, stride=2, padding=1)
        self.rb3 = ResBlock(base_ch)
        self.up2 = nn.ConvTranspose2d(base_ch, base_ch, 4, stride=2, padding=1)
        self.rb4 = ResBlock(base_ch)

        # Output head (predict epsilon in latent space)
        self.out = nn.Conv2d(base_ch, in_ch, 3, padding=1)

        # Timestep embedding: scalar t -> per-channel additive bias
        self.t_embed = nn.Sequential(
            nn.Linear(1, base_ch), nn.SiLU(), nn.Linear(base_ch, base_ch)
        )

    def forward(self, x_t, t, text_emb):
        """Predict the noise contained in ``x_t``.

        Args:
            x_t: Noisy latents of shape (B, in_ch, H, W). H and W must be
                multiples of 4 so the two stride-2 stages and the two 2x
                transposed convolutions restore the input resolution.
            t: Timesteps of shape (B,); cast to the embedding dtype, so
                integer tensors are accepted as well.
            text_emb: Text embeddings of shape (B, L, text_dim).

        Returns:
            Predicted noise with the same shape as ``x_t``.

        Raises:
            ValueError: If H or W is not a multiple of 4.
        """
        B, _, H, W = x_t.shape
        if H % 4 or W % 4:
            raise ValueError(
                f"UNetWithText expects spatial dims divisible by 4, got {H}x{W}"
            )

        h = F.silu(self.in_conv(x_t))
        h = self.rb1(self.down1(h))
        h = self.rb2(self.down2(h))

        # Add timestep conditioning as a per-channel bias broadcast over space.
        t_inp = t.reshape(B, 1).to(self.t_embed[0].weight.dtype)
        h = h + self.t_embed(t_inp).view(B, -1, 1, 1)

        # Bottleneck + cross-attention. The sequence length is read from the
        # tensor itself; the former `H//4 * W//4` parsed as `((H//4) * W) // 4`.
        h = self.rb_mid(h)
        _, C, Hm, Wm = h.shape
        seq = h.flatten(2).transpose(1, 2)           # (B, Hm*Wm, C)
        seq = seq + self.cross(seq, text_emb)        # residual text conditioning
        h = seq.transpose(1, 2).reshape(B, C, Hm, Wm)

        # Up path
        h = self.rb3(self.up1(h))
        h = self.rb4(self.up2(h))

        return self.out(h)  # predicted noise in latent space


# ----------------------------
# Scheduler helper (DDPM-like)
# ----------------------------
class SimpleDDPMSchedule:
    """Linear-beta DDPM noise schedule queried at continuous times in [0, 1].

    ``alphas_cumprod[i]`` is the cumulative product of ``1 - beta`` up to
    discrete step ``i``; ``get_alpha_sigma`` returns its square root and the
    square root of its complement.

    Args:
        timesteps: Number of discrete diffusion steps in the table.
        beta_start: First value of the linearly spaced betas.
        beta_end: Last value of the linearly spaced betas.

    Raises:
        ValueError: If ``timesteps`` is smaller than 1.
    """

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        if timesteps < 1:
            raise ValueError(f"timesteps must be >= 1, got {timesteps}")
        betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        self.timesteps = timesteps

    def get_alpha_sigma(self, t_cont):
        """Look up the signal and noise scales for continuous times.

        Args:
            t_cont: Tensor of times; each value is mapped to the discrete index
                ``int(t * (timesteps - 1))`` and clamped into the table range,
                so values outside [0, 1] use the first/last entry.

        Returns:
            ``(alpha, sigma)``, each reshaped to (-1, 1, 1, 1) for broadcasting
            against (B, C, H, W) tensors, on the same device as ``t_cont``.
        """
        last = self.timesteps - 1
        idx = (t_cont * last).long().clamp(0, last)

        # The table is created on CPU; indexing it with an index tensor that
        # lives on another device raises, so move (and keep) it where the
        # queries are.
        if self.alphas_cumprod.device != idx.device:
            self.alphas_cumprod = self.alphas_cumprod.to(idx.device)

        acp = self.alphas_cumprod[idx]
        alpha = torch.sqrt(acp)
        sigma = torch.sqrt(1.0 - acp)
        return alpha.view(-1, 1, 1, 1), sigma.view(-1, 1, 1, 1)


# ----------------------------
# Training step (noise prediction loss with text conditioning)
# ----------------------------
def training_step(vae, unet, schedule, images, token_ids, text_encoder=None):
    """Compute the text-conditioned noise-prediction loss for one batch.

    Args:
        vae: Object whose ``encode(images)`` returns ``(latents, mu, logvar)``.
        unet: Callable ``unet(x_t, t, text_emb)`` returning predicted noise
            with the shape of ``x_t``.
        schedule: Object whose ``get_alpha_sigma(t)`` returns broadcastable
            ``(alpha, sigma)`` for times ``t`` in [0, 1].
        images: Image batch passed to ``vae.encode``.
        token_ids: Token ids passed to the text encoder.
        text_encoder: Callable mapping ``token_ids`` to text embeddings. Pass
            the persistent encoder that is being trained/used for sampling.
            When omitted, a new randomly initialised ``TinyTextEncoder`` is
            created for this call only (the previous behaviour); its weights
            are discarded afterwards, so the conditioning carries no learned
            meaning and a warning is emitted.

    Returns:
        Scalar tensor: MSE between true and predicted noise plus ``1e-4`` times
        the element-wise mean KL term of the VAE posterior.
    """
    # Encode images to latents
    latents, mu, logvar = vae.encode(images)

    # Sample one continuous timestep per example and diffuse the latents
    batch = latents.size(0)
    t = torch.rand(batch, device=latents.device)
    alpha, sigma = schedule.get_alpha_sigma(t)
    noise = torch.randn_like(latents)
    x_t = alpha * latents + sigma * noise

    # Encode text
    if text_encoder is None:
        import warnings

        warnings.warn(
            "training_step: no text_encoder supplied; using a freshly "
            "initialised TinyTextEncoder whose weights are discarded after "
            "this call. Pass text_encoder=... to train with a persistent one.",
            RuntimeWarning,
            stacklevel=2,
        )
        text_encoder = TinyTextEncoder().to(images.device)
    text_emb = text_encoder(token_ids)

    # Predict noise
    eps_pred = unet(x_t, t, text_emb)

    # Loss: MSE between true noise and predicted noise (ε-prediction)
    loss = F.mse_loss(eps_pred, noise)

    # VAE KL term (optional, if training VAE jointly)
    kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).mean()
    return loss + 1e-4 * kl


# ----------------------------
# Inference (classifier-free guidance, minimal)
# ----------------------------
@torch.no_grad()
def sample_images(vae, unet, schedule, token_ids, steps=50, guidance_scale=7.5, shape=(4, 64, 64), text_encoder=None):
    """Generate images with classifier-free guidance and a deterministic DDIM update.

    Args:
        vae: Object whose ``decode(latents)`` returns images.
        unet: ``nn.Module`` called as ``unet(x, t, text_emb)``; its first
            parameter determines the device used for sampling.
        schedule: Object whose ``get_alpha_sigma(t)`` returns broadcastable
            ``(alpha, sigma)`` for times ``t`` in [0, 1].
        token_ids: Prompt token ids of shape (B, L). The unconditional branch
            uses an all-zero id tensor of the same shape.
        steps: Number of denoising steps between t=1 and t=0.
        guidance_scale: Weight of the conditional direction in
            ``eps_uncond + scale * (eps_cond - eps_uncond)``.
        shape: Latent shape ``(C, H, W)`` of each sample.
        text_encoder: Callable mapping token ids to text embeddings; pass the
            trained encoder (in eval mode). When omitted, a randomly
            initialised ``TinyTextEncoder`` is created (the previous
            behaviour), put in eval mode, and a warning is emitted.

    Returns:
        Decoded images clamped to [-1, 1].
    """
    device = next(unet.parameters()).device
    B = token_ids.size(0)
    x = torch.randn((B, *shape), device=device)

    # Text embeddings are computed once and reused at every step.
    if text_encoder is None:
        import warnings

        warnings.warn(
            "sample_images: no text_encoder supplied; using a randomly "
            "initialised TinyTextEncoder, so the prompt has no learned "
            "meaning. Pass text_encoder=... to use a trained one.",
            RuntimeWarning,
            stacklevel=2,
        )
        # eval() disables dropout so the embeddings are deterministic.
        text_encoder = TinyTextEncoder().to(device).eval()
    text_emb = text_encoder(token_ids)
    empty_emb = text_encoder(torch.zeros_like(token_ids))  # simplistic "empty" prompt

    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t = ts[i].expand(B)
        alpha, sigma = schedule.get_alpha_sigma(t)

        # Classifier-free guidance: extrapolate from unconditional to conditional.
        eps_uncond = unet(x, t, empty_emb)
        eps_cond = unet(x, t, text_emb)
        eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # Clean-latent estimate implied by the predicted noise.
        x0_pred = (x - sigma * eps_guided) / (alpha + 1e-8)

        if i == steps - 1:
            x = x0_pred
        else:
            alpha_next, sigma_next = schedule.get_alpha_sigma(ts[i + 1].expand(B))
            # Deterministic DDIM step (eta = 0): re-use the predicted noise
            # rather than drawing fresh Gaussian noise, which would discard
            # everything in x except the x0 estimate at every step.
            x = alpha_next * x0_pred + sigma_next * eps_guided

    # Decode latents to images
    return vae.decode(x).clamp(-1, 1)


# ----------------------------
# Example wiring (toy run)
# ----------------------------
if __name__ == "__main__":
    # Toy smoke run: build the models with random weights, compute one
    # training loss and draw one batch of samples.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vae = SimpleVAE(in_ch=3, latent_ch=4).to(device)
    unet = UNetWithText(in_ch=4, base_ch=256, text_dim=768).to(device)
    schedule = SimpleDDPMSchedule(timesteps=1000)

    # Dummy batch: standard-normal tensors standing in for images, plus
    # random integer token ids of length 77.
    batch_size = 4
    images = torch.randn(batch_size, 3, 256, 256, device=device)
    token_ids = torch.randint(0, 10000, (batch_size, 77), device=device)

    # Single training step
    loss = training_step(vae, unet, schedule, images, token_ids)
    print("Training loss:", float(loss))

    # Single sampling run (few steps, small latent grid to keep it quick)
    imgs = sample_images(
        vae,
        unet,
        schedule,
        token_ids,
        steps=10,
        guidance_scale=5.0,
        shape=(4, 32, 32),
    )
    print("Generated images shape:", imgs.shape)
