# In [None]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image

# -------- Model stubs (replace with actual implementations / loaded weights) --------
class SDVAE(nn.Module):
    """Placeholder for the Stable Diffusion VAE; swap in a real model.

    Both ``encode`` and ``decode`` raise ``NotImplementedError`` until an
    actual implementation (or loaded weights) replaces this stub.
    """

    def __init__(self):
        super().__init__()

    def encode(self, x):
        """Map images to latents.

        Intended contract (not enforced here): ``x`` is ``[B, 3, H, W]`` with
        values in ``[-1, 1]``; the result is a latent ``[B, C, H', W']``.
        """
        raise NotImplementedError

    def decode(self, z):
        """Map latents back to images.

        Intended contract (not enforced here): ``z`` is ``[B, C, H', W']``;
        the result is ``[B, 3, H, W]`` with values in ``[-1, 1]``.
        """
        raise NotImplementedError

class SDUNet(nn.Module):
    """Placeholder for the fine-tuned UNet; swap in a real model.

    ``forward`` is intended to predict the color residual for latent ``z_t``
    at the timestep encoded by ``t_emb``. As a stub it only raises
    ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z_t, t_emb):
        raise NotImplementedError

# -------- Utilities --------
def preprocess_gray(img_gray_pil, size=512):
    """Resize a PIL image to a square grayscale tensor normalized for the VAE.

    Args:
        img_gray_pil: Input PIL image. Non-``'L'`` images (RGB, RGBA, LA,
            palette, ...) are converted to single-channel ``'L'`` first.
            Without this, ``np.array`` would yield a multi-channel
            ``[H, W, K]`` array and the replication below would produce a
            wrongly shaped tensor.
        size: Output side length. The image is resized to ``size x size``
            with bicubic resampling; the aspect ratio is not preserved.

    Returns:
        float32 tensor of shape ``[3, size, size]`` with values in ``[-1, 1]``,
        the gray channel replicated three times to feed the VAE encoder.
    """
    if img_gray_pil.mode != 'L':
        img_gray_pil = img_gray_pil.convert('L')
    img = img_gray_pil.resize((size, size), Image.BICUBIC)
    arr = np.array(img, dtype=np.float32) / 255.0  # [H, W] in [0, 1]
    arr = (arr - 0.5) / 0.5                        # -> [-1, 1]
    # replicate to 3 channels to feed VAE encoder
    return torch.from_numpy(arr)[None, ...].repeat(3, 1, 1)  # [3, H, W]

def timestep_embedding(t, dim=128, device='cpu'):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        t: Timesteps, intended to lie in ``[0, 1]``. Normally a 1-D tensor of
            shape ``[B]``; a Python number or 0-d tensor is treated as a batch
            of one. Integer inputs are promoted to float32. ``t`` is moved to
            ``device`` so it always matches the frequency table.
        dim: Embedding width. The first ``dim // 2`` columns are ``sin`` and
            the next ``dim // 2`` are ``cos``. When ``dim`` is odd a zero
            column is appended, so the output is always exactly ``dim`` wide.
        device: Device on which the embedding is built.

    Returns:
        Tensor of shape ``[B, dim]``.
    """
    t = torch.as_tensor(t, device=device)
    if not torch.is_floating_point(t):
        t = t.float()
    if t.ndim == 0:
        t = t[None]
    half = dim // 2
    freq = torch.linspace(1.0, 1000.0, half, device=device)
    angles = t[:, None] * freq[None, :]  # [B, half]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if dim % 2:
        # sin/cos pairs only fill an even width; pad so callers get `dim` columns.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
    return emb

def replace_luma(out_rgb_norm, in_gray_norm):
    """Impose the input image's lightness on a colorized RGB result.

    Both images are converted to OpenCV 8-bit Lab. The colorized image's
    ``L`` channel is overwritten with the grayscale image's ``L`` channel, so
    only the chroma (``a``/``b``) of the model output survives.

    Args:
        out_rgb_norm: Colorized image tensor ``[3, H, W]`` in ``[-1, 1]``.
            Values outside that range (possible for a VAE decoder) are
            clipped rather than wrapped.
        in_gray_norm: Grayscale tensor ``[C, H, W]`` in ``[-1, 1]``. Only
            channel 0 is read; any other channels are assumed to be copies.

    Returns:
        RGB ``PIL.Image.Image`` of size ``H x W``.
    """
    import cv2

    def to_uint8(x):
        # Map [-1, 1] -> [0, 255]. Clip before casting so out-of-range values
        # saturate instead of wrapping around in the uint8 conversion, and
        # round instead of truncating.
        return np.clip(np.rint((x * 0.5 + 0.5) * 255.0), 0, 255).astype(np.uint8)

    out_rgb = to_uint8(out_rgb_norm.detach().float().permute(1, 2, 0).cpu().numpy())  # [H,W,3]
    gray = to_uint8(in_gray_norm[0].detach().float().cpu().numpy())                   # [H,W]

    # OpenCV's 8-bit L channel holds L* * 255/100 of the gamma-decoded sRGB
    # pixel, not the raw gray value (gray 128 maps to about 137). Derive the
    # target L by converting the gray image itself, so a neutral output
    # round-trips to the input gray level instead of darkening midtones.
    gray_rgb = np.ascontiguousarray(np.repeat(gray[:, :, None], 3, axis=2))
    target_l = cv2.cvtColor(gray_rgb, cv2.COLOR_RGB2Lab)[:, :, 0]

    lab = cv2.cvtColor(out_rgb, cv2.COLOR_RGB2Lab)
    lab[:, :, 0] = target_l  # replace L, keep the model's a/b
    return Image.fromarray(cv2.cvtColor(lab, cv2.COLOR_Lab2RGB))

# -------- Inference pipeline (no text) --------
class Colorizer:
    """Text-free grayscale-to-color pipeline built from a VAE and a UNet.

    The grayscale image is encoded to a latent and iteratively refined by the
    UNet over a fixed sequence of timesteps. The displacement from the
    grayscale latent is then scaled by ``color_scale``, the result is decoded,
    and ``replace_luma`` is applied against the preprocessed input.
    """

    def __init__(self, vae: SDVAE, unet: SDUNet, device='cuda'):
        # Both networks are switched to eval mode and moved to `device`.
        self.vae = vae.eval().to(device)
        self.unet = unet.eval().to(device)
        self.device = device

    @torch.no_grad()
    def colorize(self, gray_img_pil: Image.Image, steps=50, size=512, color_scale=1.0):
        """Colorize a grayscale PIL image.

        Args:
            gray_img_pil: Source image, passed to ``preprocess_gray``.
            steps: Number of UNet refinement iterations. Each iteration adds
                ``(1.0 / steps)`` times the UNet output to the latent.
            size: Square working resolution used for preprocessing.
            color_scale: Multiplier on the refined latent's displacement from
                the grayscale latent; 1.0 uses the refined latent unchanged.

        Returns:
            The image returned by ``replace_luma``.
        """
        dev = self.device

        # Preprocess to a [1, 3, H, W] batch and encode it.
        gray_batch = preprocess_gray(gray_img_pil, size).unsqueeze(0).to(dev)
        z_gray = self.vae.encode(gray_batch)

        # Cold-diffusion style refinement, starting from the grayscale latent;
        # timesteps go from 1.0 toward 0.0.
        z = z_gray.clone()
        for t in torch.linspace(1.0, 0.0, steps, device=dev):
            emb = timestep_embedding(t[None], device=dev)
            residual = self.unet(z, emb)
            z = z + (1.0 / steps) * residual

        # Saturation control: interpolate/extrapolate away from the gray latent.
        z_col = z_gray + color_scale * (z - z_gray)
        decoded = self.vae.decode(z_col).squeeze(0)

        # Luma replacement to reduce artifacts.
        return replace_luma(decoded, gray_batch.squeeze(0))

# -------- Usage example --------
# vae = load_sd15_vae()         # implement loading
# unet = load_finetuned_unet()  # implement loading (trained to predict color residuals)
# colorizer = Colorizer(vae, unet, device='cuda')
# gray = Image.open('input_grayscale.png').convert('L')
# colorized = colorizer.colorize(gray, steps=50, size=512, color_scale=1.0)
# colorized.save('colorized.png')
