In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import AdamW
from diffusers import AutoencoderKL, UNet2DConditionModel
import numpy as np
import cv2
from PIL import Image
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel


## Train

### Load data

In [None]:
# -----------------------------
# Dataset: devuelve pares RGB / Gray
# -----------------------------
class ColorizationDataset(Dataset):
    """Dataset of paired colour / grayscale versions of the same image.

    Every image file found directly inside ``img_dir`` yields one sample: a
    dict with the keys ``"rgb"`` (the colour image) and ``"gray"`` (the same
    image converted to grayscale and replicated to three channels so it can
    be fed to the VAE). Both tensors are resized to ``size`` x ``size`` and
    normalised from [0, 1] to [-1, 1].

    Args:
        img_dir: Directory containing the training images.
        size: Side length (pixels) of the square the images are resized to.
    """

    # Accepted image extensions; compared against the lower-cased file name.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, img_dir, size=512):
        self.img_dir = img_dir
        self.files = self._list_images(img_dir)
        self.size = size
        self.to_tensor = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
        ])

    @classmethod
    def _list_images(cls, img_dir):
        """Return the sorted image file names found directly in ``img_dir``.

        The extension check is case-insensitive so files such as ``IMG.JPG``
        are not silently dropped, and the result is sorted because
        ``os.listdir`` returns entries in arbitrary order, which would make
        sample indices differ between runs and machines.
        """
        return sorted(
            name
            for name in os.listdir(img_dir)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Number of image files discovered in ``img_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``{"rgb": tensor, "gray": tensor}``."""
        path = os.path.join(self.img_dir, self.files[idx])
        # Image.open is lazy and keeps the file open; convert() loads the
        # pixels into a new image, so the handle can be closed right away.
        with Image.open(path) as img:
            img_rgb = img.convert("RGB")
        img_gray = img_rgb.convert("L").convert("RGB")  # 3 channels for the VAE

        return {
            "rgb": self.to_tensor(img_rgb),    # [3, size, size]
            "gray": self.to_tensor(img_gray),  # [3, size, size]
        }


### Cargar modelos e hiperparámetros

In [None]:
# -----------------------------
# Hyperparameters
# -----------------------------
NUM_EPOCHS = 100
EPOCHS_VISUALIZATION = 5  # show validation samples every N epochs
# Single source of truth for the compute device. It used to be hard-coded to
# "cpu" while the models were moved to a separately computed `device`, so on a
# CUDA machine the training loop fed CPU batches to CUDA weights and crashed.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
LR = 1e-5

# -----------------------------
# Models: frozen VAE and trainable UNet
# -----------------------------
device = DEVICE  # lower-case alias kept for code that refers to this name

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="vae"
).to(device)
vae.requires_grad_(False)  # the VAE is only used as a fixed encoder/decoder

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet"
).to(device)

# -----------------------------
# Datasets and dataloader
# -----------------------------
dataset = ColorizationDataset("/ruta/a/imagenes", size=512)
# Validation dataset (only indexed directly by the visualisation helper)
val_dataset = ColorizationDataset("/ruta/a/imagenes_validacion", size=512)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# -----------------------------
# Optimizer (only the UNet parameters are updated)
# -----------------------------
opt = AdamW(unet.parameters(), lr=LR, weight_decay=0.01)

### Bucle de entrenamiento

In [None]:
def visualize_validation(vae, unet, val_dataset, device=DEVICE, steps=50, color_scale=1.0, num_samples=5):
    """Plot grayscale input, colorized output and ground truth for random validation images.

    For each sampled image the grayscale version is encoded with the VAE, the
    latent is refined for ``steps`` iterations by adding ``1/steps`` of the
    UNet output at every step, and the result is decoded back to an image.

    Args:
        vae: Autoencoder used to encode the grayscale image and decode the result.
        unet: Model that predicts the latent update.
        val_dataset: Indexable dataset returning dicts with "rgb" and "gray" tensors.
        device: Device on which the tensors are created; it has to be the
            device the models live on.
        steps: Number of refinement iterations.
        color_scale: Factor applied to the latent change (refined minus gray)
            before decoding; 1.0 keeps the refined latent unchanged.
        num_samples: Number of validation images to show. Capped at
            ``len(val_dataset)``; nothing is plotted when that is zero.
    """

    def tensor_to_img(tensor):
        # [3, H, W] in [-1, 1]  ->  [H, W, 3] in [0, 1] for imshow
        arr = tensor.detach().cpu().permute(1, 2, 0).numpy()
        return (arr * 0.5 + 0.5).clip(0, 1)

    # Sampling without replacement raises if more samples are requested than exist.
    num_samples = min(num_samples, len(val_dataset))
    if num_samples <= 0:
        return

    indices = np.random.choice(len(val_dataset), num_samples, replace=False)
    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1,
    # so the axes[i, j] indexing below always works.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 3*num_samples), squeeze=False)

    # Loop-invariant tensors: dummy (all-zero) text conditioning and the schedule.
    text_emb = torch.zeros((1, 77, 768), device=device)
    ts = torch.linspace(1.0, 0.0, steps, device=device)
    # NOTE(review): the integer timestep runs from 999 down to 0 while z_t
    # starts at the grayscale latent — confirm this matches the convention
    # used when the UNet was trained.

    # Validate in eval mode and restore whatever mode the UNet was in, so a
    # call in the middle of training does not change the training behaviour.
    was_training = unet.training
    unet.eval()
    try:
        for i, idx in enumerate(indices):
            sample = val_dataset[idx]
            gray = sample["gray"].unsqueeze(0).to(device)

            with torch.no_grad():
                # Encode the grayscale image
                z_gray = vae.encode(gray).latent_dist.sample() * 0.18215
                z_t = z_gray.clone()

                # Step-by-step refinement
                for t in ts:
                    t_int = torch.tensor([int(t.item()*999)], device=device)
                    delta_t = unet(sample=z_t, timestep=t_int, encoder_hidden_states=text_emb).sample
                    z_t = z_t + (1.0/steps) * delta_t

                z_col = z_gray + color_scale * (z_t - z_gray)
                out_rgb = vae.decode(z_col / 0.18215).sample.squeeze(0)

            panels = (
                (tensor_to_img(sample["gray"]), "Grayscale input"),
                (tensor_to_img(out_rgb), "Colorized output"),
                (tensor_to_img(sample["rgb"]), "Original RGB"),
            )
            for col, (image, title) in enumerate(panels):
                axes[i, col].imshow(image)
                axes[i, col].set_title(title)
                axes[i, col].axis("off")
    finally:
        unet.train(was_training)

    plt.tight_layout()
    plt.show()


In [None]:
# Put every batch on the device that actually holds the UNet weights, instead
# of trusting a separately configured constant that may disagree with it.
train_device = next(unet.parameters()).device

for epoch in range(NUM_EPOCHS):
    # diffusers' from_pretrained() returns modules in eval mode, so training
    # mode has to be enabled explicitly (and re-enabled every epoch in case a
    # validation pass switched it off).
    unet.train()
    running_loss = 0.0
    num_batches = 0

    for batch in loader:
        rgb_batch = batch["rgb"].to(train_device)   # [B,3,H,W]
        gray_batch = batch["gray"].to(train_device) # [B,3,H,W]
        batch_size = rgb_batch.size(0)

        with torch.no_grad():
            # Encode both versions into scaled VAE latents
            z_rgb = vae.encode(rgb_batch).latent_dist.sample() * 0.18215
            z_gray = vae.encode(gray_batch).latent_dist.sample() * 0.18215

            # Random per-sample blend between the gray and the colour latent
            t = torch.rand(batch_size, device=train_device)
            mix = t.view(-1, 1, 1, 1)
            z_t = (1 - mix) * z_gray + mix * z_rgb
            target = z_rgb

        # The UNet expects integer timesteps (0-999)
        # NOTE(review): t_int is drawn independently of the blend factor `t`
        # above, so the timestep input carries no information about where z_t
        # lies between the two latents. Left unchanged because tying it to `t`
        # changes what the model learns and must be coordinated with the
        # inference schedule — confirm the intended convention.
        t_int = torch.randint(0, 1000, (batch_size,), device=train_device)

        # Dummy text conditioning (no prompt)
        text_emb = torch.zeros((batch_size, 77, 768), device=train_device)

        # Predict the residual that moves z_t to the colour latent
        delta_hat = unet(
            sample=z_t,
            timestep=t_int,
            encoder_hidden_states=text_emb
        ).sample

        # Reconstruction and loss
        recon = z_t + delta_hat
        loss = F.mse_loss(recon, target)

        opt.zero_grad()
        loss.backward()
        opt.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Previously this surfaced as a NameError on `loss` in the print below.
        raise RuntimeError("The training DataLoader yielded no batches; check the dataset directory.")

    # Report the mean loss over the epoch rather than only the last batch.
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {running_loss / num_batches:.4f}")

    if (epoch + 1) % EPOCHS_VISUALIZATION == 0:
        visualize_validation(vae, unet, val_dataset, device=train_device, steps=50, num_samples=5)

## Inferencia

In [None]:
class Colorizer:
    """Colorize a grayscale image with a VAE and a UNet working in latent space.

    Args:
        vae: Autoencoder used to encode the input and decode the result.
        unet: Model that predicts the latent update at every refinement step.
        device: Device both models are moved to and on which tensors are created.
    """

    def __init__(self, vae: AutoencoderKL, unet: UNet2DConditionModel, device='cuda'):
        self.vae = vae.eval().to(device)
        self.unet = unet.eval().to(device)
        self.device = device

    def _preprocess_gray(self, img_gray_pil, size=512):
        """Resize to ``size`` x ``size`` and return a [1,3,size,size] tensor in [-1, 1]."""
        img = img_gray_pil.resize((size, size), Image.BICUBIC)
        # Force a single luminance channel, then replicate it to 3 channels.
        arr = np.array(img.convert("L").convert("RGB")).astype(np.float32) / 255.0
        arr = (arr - 0.5) / 0.5
        return torch.from_numpy(arr).permute(2,0,1).unsqueeze(0).to(self.device)  # [1,3,H,W]

    def _replace_luma(self, out_rgb_norm, in_gray_norm):
        """Keep the colour of the model output but take lightness from the input.

        Args:
            out_rgb_norm: [3,H,W] decoded image, nominally in [-1, 1].
            in_gray_norm: [C,H,W] normalised grayscale input; channel 0 is used.

        Returns:
            A PIL RGB image whose Lab L channel is that of the grayscale input.
        """
        # The decoder output can overshoot [-1, 1]; without clipping the uint8
        # cast below wraps around (e.g. 260 -> 4) and produces speckles.
        out = (out_rgb_norm.permute(1,2,0).cpu().numpy()*0.5+0.5).clip(0.0, 1.0)
        gray = (in_gray_norm[0].cpu().numpy()*0.5+0.5).clip(0.0, 1.0)

        out_u8 = np.ascontiguousarray((out*255).astype(np.uint8))
        gray_u8 = np.ascontiguousarray((gray*255).astype(np.uint8))

        out_lab = cv2.cvtColor(out_u8, cv2.COLOR_RGB2Lab)
        # OpenCV's 8-bit L channel is L* scaled by 255/100, not the raw gray
        # intensity, so the input has to go through the same conversion
        # before its lightness can be copied over.
        gray_lab = cv2.cvtColor(cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB), cv2.COLOR_RGB2Lab)
        out_lab[:,:,0] = gray_lab[:,:,0]

        rgb = cv2.cvtColor(out_lab, cv2.COLOR_Lab2RGB)
        return Image.fromarray(rgb)

    @torch.no_grad()
    def colorize(self, gray_img_pil: Image.Image, steps=50, size=512, color_scale=1.0):
        """Return a colorized ``size`` x ``size`` PIL image for ``gray_img_pil``.

        Args:
            gray_img_pil: Input image; it is converted to grayscale internally.
            steps: Number of refinement iterations; each adds ``1/steps`` of
                the UNet output to the latent.
            size: Side length the input is resized to before encoding.
            color_scale: Factor applied to the latent change (refined minus
                gray) before decoding; 1.0 keeps the refined latent unchanged.
        """
        # 1) Preprocess and encode
        gray = self._preprocess_gray(gray_img_pil, size)  # [1,3,H,W]
        z_gray = self.vae.encode(gray).latent_dist.sample() * 0.18215
        z_t = z_gray.clone()

        # 2) Iterative refinement
        # The dummy (all-zero) text conditioning is constant, so build it once.
        text_emb = torch.zeros((1,77,768), device=self.device)
        ts = torch.linspace(1.0, 0.0, steps, device=self.device)
        for t in ts:
            t_int = torch.tensor([int(t.item()*999)], device=self.device)
            delta_t = self.unet(sample=z_t, timestep=t_int, encoder_hidden_states=text_emb).sample
            z_t = z_t + (1.0/steps) * delta_t

        # 3) Scale the colour change and decode
        z_col = z_gray + color_scale * (z_t - z_gray)
        out_rgb = self.vae.decode(z_col / 0.18215).sample.squeeze(0)

        # 4) Restore the lightness of the input
        final_img = self._replace_luma(out_rgb, gray.squeeze(0))
        return final_img


In [None]:
# Reuse the `vae` and `unet` already in memory instead of calling
# from_pretrained() again: reloading would replace the UNet that was just
# fine-tuned with the original pretrained weights, and the notebook never
# saves a checkpoint from which the trained weights could be restored.
inference_device = next(unet.parameters()).device  # run where the weights already are

colorizer = Colorizer(vae, unet, device=inference_device)

# Close the file handle as soon as the pixels have been loaded by convert().
with Image.open("foto_gris.png") as gray_file:
    gray_img = gray_file.convert("L")
colorized = colorizer.colorize(gray_img, steps=50, size=512, color_scale=1.0)
colorized.save("foto_color.png")
