# Fine tuning

In [1]:
# !pip install lpips
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import AdamW

from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel

from datetime import datetime

import lpips
from skimage.metrics import structural_similarity
from skimage.color import rgb2lab, deltaE_ciede2000

# Carpeta donde guardar
SAVE_DIR = "checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

def save_model(unet, opt, name="unet"):
    # Fecha y hora actual
    timestamp = datetime.now().strftime("%m%d_%H%M")
    filename = f"{name}_base_{timestamp}.pt"
    path = os.path.join(SAVE_DIR, filename)

    # Guardar estado del modelo y optimizador
    torch.save({
        "model_state_dict": unet.state_dict(),
        "optimizer_state_dict": opt.state_dict(),
    }, path)
    print(f"Modelo guardado en {path}")

# -----------------------------
# Cargar modelo y optimizador
# -----------------------------
def load_model(unet, opt, checkpoint_path):
    # Cargar el diccionario guardado
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    # Restaurar pesos del modelo
    unet.load_state_dict(checkpoint["model_state_dict"])

    # Restaurar estado del optimizador
    opt.load_state_dict(checkpoint["optimizer_state_dict"])

    print(f"Modelo y optimizador cargados desde {checkpoint_path}")
    return unet, opt


def visualize_validation(
    vae,
    unet,
    val_dataset,
    unet_device="cuda",
    vae_device="cpu",
    step_list=[15,30,50],
    num_samples=2,
    color_scale=1.0,
    needs_text=False
):
    indices = np.random.choice(len(val_dataset), num_samples, replace=False)
    fig, axes = plt.subplots(num_samples, len(step_list)+2, figsize=(15, 5*num_samples), squeeze=False)
    
    TEXT_LEN = 77
    TEXT_DIM = 768
    
    for i, idx in enumerate(indices):
        sample = val_dataset[idx]
        # Entrada al VAE en vae_device
        gray = sample["gray"].unsqueeze(0).to(vae_device)
        rgb_img = (sample["rgb"].detach().cpu().permute(1,2,0).numpy()*0.5+0.5).clip(0,1)
        gray_img = (sample["gray"].detach().cpu().permute(1,2,0).numpy()*0.5+0.5).clip(0,1)

        axes[i,0].imshow(gray_img)
        axes[i,0].set_title("Grayscale input")
        axes[i,0].axis("off")

        with torch.no_grad():
            # Codificación con VAE en vae_device
            z_gray_cpu = vae.encode(gray).latent_dist.sample() * 0.18215
            # Mover latente al dispositivo del UNet
            z_gray = z_gray_cpu.to(unet_device)
            
            if needs_text:
                text_zeros = torch.zeros((1, TEXT_LEN, TEXT_DIM), device=unet_device)

            for j, steps in enumerate(step_list):
                z_t = z_gray.clone()
                ts = torch.linspace(1.0, 0.0, steps, device=unet_device)
                for t in ts:
                    t_int = torch.tensor([int(t.item()*999)], device=unet_device)
                    if needs_text:
                        delta_t = unet(sample=z_t, timestep=t_int, encoder_hidden_states=text_zeros).sample
                    else:
                        delta_t = unet(sample=z_t, timestep=t_int).sample
                    z_t = z_t + (1.0/steps) * delta_t

                # Volver al vae_device para decodificar
                z_col_cpu = (z_gray + color_scale * (z_t - z_gray)).to(vae_device)
                out_rgb = vae.decode(z_col_cpu / 0.18215).sample.squeeze(0)
                out_img = (out_rgb.detach().cpu().permute(1,2,0).numpy()*0.5+0.5).clip(0,1)

                axes[i,j+1].imshow(out_img)
                axes[i,j+1].set_title(f"Colorized ({steps} steps)")
                axes[i,j+1].axis("off")

        axes[i,len(step_list)+1].imshow(rgb_img)
        axes[i,len(step_list)+1].set_title("Original RGB")
        axes[i,len(step_list)+1].axis("off")

    plt.tight_layout()
    plt.show()



2025-12-12 14:13:15.390232: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765548795.411122     490 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765548795.417466     490 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [5]:
def train_unet(
    vae,
    unet,
    loader,
    val_dataset,
    opt,
    device="cuda",
    num_epochs=1000,
    epochs_visualization=30,
    needs_text=False
):
    # Stable Diffusion CLIP text dims
    TEXT_LEN = 77
    TEXT_DIM = 768

    unet.train()
    vae.eval()

    scaler = torch.cuda.amp.GradScaler(enabled=(device.startswith("cuda") and torch.cuda.is_available()))

    for epoch in range(num_epochs):
        for batch in loader:
            z_rgb = batch["rgb"].to(device)
            z_gray = batch["gray"].to(device)

            # Mezcla temporal
            t = torch.rand(z_rgb.size(0), device=device)
            z_t = (1 - t.view(-1,1,1,1)) * z_gray + t.view(-1,1,1,1) * z_rgb
            target = z_rgb

            # Timesteps enteros
            t_int = torch.randint(0, 1000, (z_rgb.size(0),), device=device)

            with torch.cuda.amp.autocast(enabled=(device.startswith("cuda") and torch.cuda.is_available())):
                # Llamada al UNet
                if needs_text:
                    text_zeros = torch.zeros((z_rgb.size(0), TEXT_LEN, TEXT_DIM), device=device)
                    delta_hat = unet(sample=z_t, timestep=t_int, encoder_hidden_states=text_zeros).sample
                else:
                    delta_hat = unet(sample=z_t, timestep=t_int).sample

                # Reconstrucción y loss
                recon = z_t + delta_hat
                loss = F.mse_loss(recon, target)

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

        # Logs parciales
        if (epoch + 1) % max(1, (epochs_visualization // 2)) == 0:
            print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}")

        # Validación cada cierto número de épocas
        if (epoch + 1) % epochs_visualization == 0:
            for cs in [1.2, 1.4, 1.6, 1.8]:
                print(f"color_scale={cs}")
                visualize_validation(
                    vae=vae,
                    unet=unet,
                    val_dataset=val_dataset,
                    unet_device=device, vae_device="cpu",
                    step_list=[2, 15, 50, 100],
                    num_samples=1,
                    color_scale=cs,
                    needs_text=needs_text
                )


## Train

### Load data

In [None]:
# -----------------------------
# Dataset: devuelve pares RGB / Gray
# -----------------------------
class ColorizationDataset(Dataset):
    def __init__(self, img_dir, size=128):
        self.img_dir = img_dir
        self.files = [f for f in os.listdir(img_dir) if f.endswith((".jpg",".png"))]
        self.size = size
        self.to_tensor = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # [-1,1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = os.path.join(self.img_dir, self.files[idx])
        img_rgb = Image.open(path).convert("RGB")
        img_gray = img_rgb.convert("L").convert("RGB")  # 3 canales para VAE

        rgb_tensor = self.to_tensor(img_rgb)   # [3,H,W]
        gray_tensor = self.to_tensor(img_gray) # [3,H,W]

        return {"rgb": rgb_tensor, "gray": gray_tensor}

class LatentDataset(Dataset):
    def __init__(self, latent_file):
        data = torch.load(latent_file)
        self.rgb = data["rgb"]
        self.gray = data["gray"]

    def __len__(self):
        return self.rgb.size(0)

    def __getitem__(self, idx):
        return {"rgb": self.rgb[idx], "gray": self.gray[idx]}


In [None]:
# Dataset original con imágenes
# Limitar a un máximo de imágenes
def generar_latente(max_images = 100, path = None):
    dataset = ColorizationDataset("/kaggle/input/stl10/unlabeled_images", size=128)
    if len(dataset) > max_images:
        dataset.files = dataset.files[:max_images]
    
    loader = DataLoader(dataset, batch_size=16, shuffle=False)
    
    
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="vae"
    ).to("cuda")
    vae.requires_grad_(False)
    
    latents_rgb, latents_gray = [], []
    
    with torch.no_grad():
        for batch in loader:
            rgb = batch["rgb"].to("cuda")
            gray = batch["gray"].to("cuda")
    
            z_rgb = vae.encode(rgb).latent_dist.sample() * 0.18215
            z_gray = vae.encode(gray).latent_dist.sample() * 0.18215
    
            latents_rgb.append(z_rgb.cpu())
            latents_gray.append(z_gray.cpu())
    
    latents_rgb = torch.cat(latents_rgb)
    latents_gray = torch.cat(latents_gray)
    if not path:
        path = f"stl10_latents_{max_images}.pt"
    torch.save({"rgb": latents_rgb, "gray": latents_gray}, path)


## Cargar hiperparámetros

In [None]:
# -----------------------------
# Hiperparámetros
# -----------------------------
NUM_EPOCHS = 200
EPOCHS_VISUALIZATION = 20
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
LR = 1e-5
SIZE = 128

# -----------------------------
# Modelos: VAE congelado y UNet entrenable
# -----------------------------
device = DEVICE

vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="vae"
    ).to("cpu")

## Cargamos modelo y datos

In [None]:
stableDifussion = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # o el checkpoint que quieras
    subfolder="unet"
).to(DEVICE)
for name, param in stableDifussion.named_parameters():
    if "up_blocks.2" in name or "mid_block" in name:
        param.requires_grad = True   # entrenar últimas capas
    else:
        param.requires_grad = False  # congelar resto


# -----------------------------
# Dataloader
# -----------------------------
dataset = LatentDataset("/kaggle/working/stl10_latents_10000.pt")

# Dataset de validación (aún con imágenes, para visualizar resultados)
val_dataset = ColorizationDataset("/kaggle/input/stl10/train_images", size=SIZE)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# -----------------------------
# Optimizador
# -----------------------------
opt = AdamW(stableDifussion.parameters(), lr=LR, weight_decay=0.01)

# 2. Cargar el checkpoint
checkpoint_path = "/kaggle/working/stableDifussion_base_1212_1047.pt"  # tu archivo guardado
stableDifussion, opt = load_model(stableDifussion, opt, checkpoint_path)

In [None]:
EPOCHS_VISUALIZATION = 4
train_unet(vae, stableDifussion, loader, val_dataset, opt, device=DEVICE, num_epochs=NUM_EPOCHS, needs_text=True, epochs_visualization=EPOCHS_VISUALIZATION)

## Guardado

In [None]:
visualize_validation(vae, stableDifussion, val_dataset, step_list=[2,100, 200, 500], num_samples=1, color_scale=1.5, needs_text=True)
save_model(stableDifussion, opt, name="stableDifussion")

## Inferencia

In [2]:
class Colorizer:
    def __init__(self, vae: AutoencoderKL, unet: UNet2DConditionModel, device='cuda'):
        self.vae = vae.eval().to(device)
        self.unet = unet.eval().to(device)
        self.device = device

    def _preprocess_gray(self, img_gray_pil, size=512):
        img = img_gray_pil.resize((size, size), Image.BICUBIC)
        arr = np.array(img.convert("L").convert("RGB")).astype(np.float32) / 255.0
        arr = (arr - 0.5) / 0.5
        return torch.from_numpy(arr).permute(2,0,1).unsqueeze(0).to(self.device)  # [1,3,H,W]

    def _replace_luma(self, out_rgb_norm, in_gray_norm):
        out = out_rgb_norm.permute(1,2,0).cpu().numpy()*0.5+0.5
        gray = in_gray_norm[0].cpu().numpy()*0.5+0.5
        out_bgr = cv2.cvtColor((out*255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        lab = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2Lab)
        lab[:,:,0] = (gray*255).astype(np.uint8)
        bgr = cv2.cvtColor(lab, cv2.COLOR_Lab2BGR)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)

    @torch.no_grad()
    def colorize(self, gray_img_pil: Image.Image, steps=50, size=512, color_scale=1.0):
        # 1) Preprocesar y codificar
        gray = self._preprocess_gray(gray_img_pil, size)  # [1,3,H,W]
        z_gray = self.vae.encode(gray).latent_dist.sample() * 0.18215
        z_t = z_gray.clone()

        # 2) Iteración de refinamiento
        ts = torch.linspace(1.0, 0.0, steps, device=self.device)
        for t in ts:
            t_int = torch.tensor([int(t.item()*999)], device=self.device)
            text_emb = torch.zeros((1,77,768), device=self.device)  # dummy text conditioning
            delta_t = self.unet(sample=z_t, timestep=t_int, encoder_hidden_states=text_emb).sample
            z_t = z_t + (1.0/steps) * delta_t

        # 3) Escalado y decodificación
        z_col = z_gray + color_scale * (z_t - z_gray)
        out_rgb = self.vae.decode(z_col / 0.18215).sample.squeeze(0)

        # 4) Reemplazo de luminancia
        final_img = self._replace_luma(out_rgb, gray.squeeze(0))
        return final_img


In [3]:
import os
os.listdir("/kaggle/working/checkpoints")

['unet_base_1210_1625.pt', 'stableDifussion_base_1212_1328.pt']

In [7]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Cargar VAE y UNet entrenado ---
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="vae"
).to("cpu")
vae.requires_grad_(False)

stableDifussion = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # o el checkpoint que quieras
    subfolder="unet"
).to(DEVICE)
opt = AdamW(stableDifussion.parameters(), lr=1e-5, weight_decay=0.01)

checkpoint_path = "/kaggle/working/checkpoints/stableDifussion_base_1212_1328.pt"  # tu archivo guardado
stableDifussion, opt = load_model(stableDifussion, opt, checkpoint_path)

colorizer = Colorizer(vae, stableDifussion, device=DEVICE)

# --- Transformaciones ---
transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

# --- Métricas ---
lpips_fn = lpips.LPIPS(net='alex').to(DEVICE)

def evaluate_folder(folder_path, num_samples=100):
    mse_scores, ssim_scores, lpips_scores, ciede_scores = [], [], [], []
    files = os.listdir(folder_path)[:num_samples]

    for fname in files:
        img = Image.open(os.path.join(folder_path, fname)).convert("RGB")
        gray = img.convert("L")

        # Colorizar (usa UNet en GPU y VAE en CPU)
        colorized = colorizer.colorize(gray, steps=10, size=256, color_scale=1.5)

        # Tensores para LPIPS (en GPU)
        img_t = transform(img).unsqueeze(0).to(DEVICE)
        col_t = transform(colorized).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            lp = lpips_fn(img_t, col_t).item()
        lpips_scores.append(lp)

        # Convertir a numpy para métricas en CPU
        img_np = np.array(img.resize((256,256))).astype(np.float32)
        col_np = np.array(colorized.resize((256,256))).astype(np.float32)

        # --- MSE ---
        mse = np.mean((img_np - col_np) ** 2)
        mse_scores.append(mse)

        # --- SSIM ---
        ssim_scores.append(structural_similarity(img_np, col_np, channel_axis=2, data_range=255))

        # --- CIEDE2000 ---
        img_lab = rgb2lab(img_np / 255.0)
        col_lab = rgb2lab(col_np / 255.0)
        delta_e = deltaE_ciede2000(img_lab, col_lab)
        ciede_scores.append(np.mean(delta_e))

        # Liberar memoria GPU en cada iteración
        del img_t, col_t
        torch.cuda.empty_cache()

    print(f"LPIPS: {np.mean(lpips_scores):.4f}")
    print(f"MSE: {np.mean(mse_scores):.2f}")
    print(f"SSIM: {np.mean(ssim_scores):.4f}")
    print(f"CIEDE2000: {np.mean(ciede_scores):.2f}")



# --- Ejecutar evaluación ---
evaluate_folder("/kaggle/input/stl10/test_images", num_samples=4000)


Modelo y optimizador cargados desde /kaggle/working/checkpoints/stableDifussion_base_1212_1328.pt
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.11/dist-packages/lpips/weights/v0.1/alex.pth
LPIPS: 0.1416
MSE: 554.04
SSIM: 0.8636
CIEDE2000: 13.76


In [8]:
import shutil
#shutil.copy("/kaggle/working/checkpoints/stableDifussion_base_1212_1328.pt","stableDifussion_base_1212_1328.pt")

'stableDifussion_base_1212_1328.pt'