# Dataset

source: image with masked-out object

mask: binary mask where object was removed

ref: object crop, resized

target: original full image

Download COCO 2017 (val2017 images and the train/val annotation files)

In [None]:
import os
from torchvision.datasets import CocoDetection
from torchvision import transforms
from pycocotools.coco import COCO
import requests, zipfile, io

def _download_and_extract_zip(url, dest_dir, chunk_size=1 << 20, timeout=60):
    """Stream the zip archive at ``url`` to a temporary file and extract it into ``dest_dir``."""
    import tempfile  # local import: only needed when a download actually happens

    # stream=True avoids holding a multi-hundred-megabyte archive in memory.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to zipfile.
        response.raise_for_status()
        with tempfile.TemporaryFile() as tmp:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    tmp.write(chunk)
            tmp.seek(0)
            with zipfile.ZipFile(tmp) as archive:
                archive.extractall(dest_dir)


def download_coco_val(root="coco2017"):
    """Download the COCO val2017 images and the 2017 annotations into ``root``.

    Each archive is fetched only if its target directory (``root/val2017`` or
    ``root/annotations``) does not exist yet, so calling this again is cheap.

    Args:
        root: Directory that will contain ``val2017/`` and ``annotations/``.

    Raises:
        requests.HTTPError: If a download returns an error status code.
    """
    os.makedirs(root, exist_ok=True)

    url_images = "http://images.cocodataset.org/zips/val2017.zip"
    url_annotations = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"

    # Download val2017 images
    if not os.path.exists(os.path.join(root, "val2017")):
        print("Downloading val2017...")
        _download_and_extract_zip(url_images, root)

    # Download annotations (contains val annotations too)
    if not os.path.exists(os.path.join(root, "annotations")):
        print("Downloading annotations...")
        _download_and_extract_zip(url_annotations, root)

    print("COCO val2017 ready at", root)



In [None]:
# Fetch the dataset into ./coco2017; parts whose directory already exists are skipped.
download_coco_val()

Create Inpainting Triplets

In [None]:
def get_valid_mask_and_ref(coco, anns, img_size, min_ratio=0.05, max_ratio=0.5):
    """Return the first annotation whose mask covers an acceptable share of the image.

    Args:
        coco: Object exposing ``annToMask(ann)`` (a ``pycocotools`` ``COCO`` instance
            at the call site in this file).
        anns: Iterable of annotation dicts, tried in order.
        img_size: ``(H, W)`` of the image the mask must match.
        min_ratio: Smallest accepted fraction of image pixels covered by the mask.
        max_ratio: Largest accepted fraction of image pixels covered by the mask.

    Returns:
        ``(mask, ann)`` for the first annotation whose coverage lies in
        ``[min_ratio, max_ratio]``, or ``(None, None)`` if none qualifies.
    """
    H, W = img_size
    total_pixels = H * W
    if total_pixels <= 0:
        # Degenerate image size: no ratio can be computed.
        return None, None

    for ann in anns:
        mask = coco.annToMask(ann)
        # Only resample when the mask does not already have the requested size.
        if mask.shape[:2] != (H, W):
            mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)
        ratio = mask.sum() / total_pixels
        if min_ratio <= ratio <= max_ratio:
            return mask, ann
    return None, None

In [None]:
from PIL import Image, ImageDraw
import random
import numpy as np
import torch
import cv2

class CocoInpaintingDataset(torch.utils.data.Dataset):
    """COCO-backed dataset yielding reference-guided inpainting samples.

    Each item is a dict of tensors in ``[0, 1]`` (no normalization is applied):

    * ``"source"``: the image with the selected object blacked out,
    * ``"mask"``: single-channel mask of the removed object,
    * ``"ref"``: the object's bounding-box crop, resized to the output size,
    * ``"target"``: the unmodified image.

    Images with no annotation accepted by ``get_valid_mask_and_ref`` are
    skipped: the next usable image (wrapping around) is returned instead.
    """

    def __init__(self, coco_root="coco2017", split="val2017", image_size=256, max_objects=1):
        super().__init__()
        ann_file = os.path.join(coco_root, "annotations", f"instances_{split}.json")
        self.coco = COCO(ann_file)
        self.img_ids = list(self.coco.imgs.keys())
        self.root = os.path.join(coco_root, split)
        self.max_objects = max_objects  # stored; not used by the sampling logic below
        size = (image_size, image_size)
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            # transforms.Normalize([0.5], [0.5]),
        ])
        # Masks need nearest-neighbour resampling: the default Resize interpolation
        # blends edge pixels and the mask would no longer be binary.
        self.mask_transform = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, or the next usable one if ``idx`` must be skipped.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
            RuntimeError: If no image in the dataset yields a usable sample.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError("list index out of range")
        # Iterative scan instead of recursion: a long run of unusable images
        # cannot overflow the call stack, and a dataset with no usable image
        # terminates with an error instead of looping forever.
        for offset in range(n):
            sample = self._load_sample((idx + offset) % n)
            if sample is not None:
                return sample
        raise RuntimeError("No image in the dataset has a usable annotation")

    def _load_sample(self, idx):
        """Build the sample dict for image ``idx``; return ``None`` if it must be skipped."""
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        path = os.path.join(self.root, img_info["file_name"])
        image = Image.open(path).convert("RGB")

        # Load annotations
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        if not anns:
            return None  # skip empty

        # PIL reports (W, H); the helper expects (H, W).
        mask, ann = get_valid_mask_and_ref(self.coco, anns, image.size[::-1])
        if mask is None:
            # skip image only if ALL annotations invalid
            return None

        mask_img = Image.fromarray((mask * 255).astype(np.uint8))

        # Create source (masked image): paint black wherever the mask is set.
        source = image.copy()
        source.paste((0, 0, 0), mask=mask_img)

        # Create reference (cropped object)
        x, y, w, h = map(int, ann["bbox"])  # [x,y,w,h]
        if w <= 0 or h <= 0:
            return None  # degenerate box: nothing to crop
        ref = image.crop((x, y, x + w, y + h))

        # self.transform already resizes to the output size, so the crop is
        # passed straight through instead of being resized twice.
        return {
            "source": self.transform(source),       # masked image
            "mask": self.mask_transform(mask_img),  # binary mask
            "ref": self.transform(ref),             # reference image
            "target": self.transform(image),        # ground truth
        }


DataLoader

In [None]:
def get_coco_inpainting_dataloader(root="coco2017", image_size=256, batch_size=8,
                                   split="val2017", shuffle=True, num_workers=4):
    """Build a ``DataLoader`` over ``CocoInpaintingDataset``.

    Args:
        root: COCO root directory (contains ``annotations/`` and the split folder).
        image_size: Side length every tensor in a sample is resized to.
        batch_size: Number of samples per batch.
        split: COCO split folder / annotation suffix to read.
        shuffle: Whether to reshuffle every epoch (set ``False`` for repeatable evaluation).
        num_workers: Number of loader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts with the keys
        ``"source"``, ``"mask"``, ``"ref"`` and ``"target"``.
    """
    dataset = CocoInpaintingDataset(coco_root=root, split=split, image_size=image_size)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


In [None]:
# Loader consumed by the training loop: 128x128 samples, one sample per batch.
dataloader = get_coco_inpainting_dataloader(image_size=128, batch_size=1)

In [None]:
# Loader used for evaluation below: the same val2017 split, at 512x512.
test_dataloader = get_coco_inpainting_dataloader(image_size=512, batch_size=1)

In [None]:
# Pull one batch from the 128x128 loader and print the tensor shapes.
batch = next(iter(dataloader))
print(batch["source"].shape)  # (B, 3, 128, 128)
print(batch["mask"].shape)    # (B, 1, 128, 128)
print(batch["ref"].shape)     # (B, 3, 128, 128)
print(batch["target"].shape)  # (B, 3, 128, 128)

In [None]:
import matplotlib.pyplot as plt

# One panel per entry of the batch dict, titled with the entry's key.
fig, axs = plt.subplots(1, 4, figsize=(16, 4))

for ax, (key, tensor) in zip(axs, batch.items()):
    first_sample = tensor[0]                          # first sample of the batch, (C, H, W)
    hwc = first_sample.permute(1, 2, 0).cpu().numpy()  # channels-last for imshow
    ax.imshow(hwc)
    ax.set_title(key)
    ax.axis("off")

plt.show()


# Code

**Guided inpainting** on top of a pretrained Stable Diffusion Inpainting model, using an extra reference image as guidance (no class/text).

- Start from runwayml/stable-diffusion-inpainting (SD‑1.5 inpaint; UNet has in_channels=9: 4 noisy latents + 4 masked latents + 1 mask).

- Add +4 channels for the reference image latents → new in_channels=13.

- Initialize the new input conv by copying the 9 pretrained channels and zero‑init the extra 4.

- Keep SD’s text pathway fixed to a single null prompt embedding internally so you don’t have to deal with text at all.

In [None]:
!pip install xformers

In [None]:
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW

from diffusers import StableDiffusionInpaintPipeline
from diffusers import AutoencoderKL

In [None]:
def _expand_unet_in_channels_for_ref(unet, extra_latent_ch=4):
    """
    Expand inpainting UNet input from 9 → 9 + extra_latent_ch.
    Copies the pretrained weights for the original 9 channels,
    zero-inits the new channels.
    """
    old = unet.conv_in
    assert isinstance(old, nn.Conv2d)
    old_in, out_c, k, p, s, d = old.in_channels, old.out_channels, old.kernel_size, old.padding, old.stride, old.dilation
    new_in = old_in + extra_latent_ch
    new = nn.Conv2d(new_in, out_c, kernel_size=k, stride=s, padding=p, dilation=d, bias=old.bias is not None)

    with torch.no_grad():
        new.weight.zero_()
        new.weight[:, :old_in] = old.weight.clone()
        if old.bias is not None:
            new.bias.copy_(old.bias)

    unet.conv_in = new
    # Update config so schedulers/savers know the new channel count
    if hasattr(unet, "config"):
        unet.config.in_channels = new_in
    return unet

In [None]:
def _prep_latent_mask(mask, latent_h, latent_w, device, dtype):
    """
    mask: (B,1,H,W) in {0,1} where 1==region to be repainted.
    Downsample to latent size, keep single channel.
    """
    mask = F.interpolate(mask, size=(latent_h, latent_w), mode="nearest")
    return mask.to(device=device, dtype=dtype)

In [None]:
# Factor the VAE latents are multiplied by after encoding (_to_latents) and
# divided by before decoding (Unet.decode_latents_to_images).
LATENT_SCALE = 0.18215  # SD convention

In [None]:
def _to_latents(vae, images):
    """Encode ``images`` (expected in [-1,1]) with ``vae`` and return scaled latents.

    A latent is sampled from the encoder's posterior distribution and
    multiplied by LATENT_SCALE.
    """
    latent_dist = vae.encode(images).latent_dist
    return latent_dist.sample() * LATENT_SCALE

In [None]:
class Unet(nn.Module):
    """
    Guided inpainting model:
      - Loads SD inpainting pipeline
      - Expands UNet to accept reference latents (extra 4 channels)
      - Hides text: uses a fixed 'null' prompt embedding internally
      - forward() predicts noise ε given:
          * noisy image latents
          * masked-image latents
          * mask (latent size, 1ch)
          * reference-image latents
          * timestep t

    Attributes set by ``__init__``: ``pipe``, ``unet``, ``vae``, ``scheduler``,
    ``device`` (``"cuda"`` or ``"cpu"``) and ``dtype`` (the ``torch.dtype`` the
    pipeline was loaded with).
    """

    DEFAULT_MODEL_ID = "runwayml/stable-diffusion-inpainting"

    def __init__(self, im_channels, model_config):
        """
        Args:
            im_channels: Accepted for interface compatibility; not used.
            model_config: Optional dict; ``"pretrained_model_name_or_path"``
                selects the pipeline to load (defaults to DEFAULT_MODEL_ID).
        """
        super().__init__()

        model_id = self.DEFAULT_MODEL_ID
        if isinstance(model_config, dict):
            model_id = model_config.get("pretrained_model_name_or_path", model_id)

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # from_pretrained expects a torch.dtype, not the string "float16".
        # Half precision is only used when a GPU is available.
        torch_dtype = torch.float16 if device == "cuda" else torch.float32

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id, torch_dtype=torch_dtype, safety_checker=None, feature_extractor=None
        ).to(device)

        # Expand UNet to accept +4 ref-latent channels (9 → 13)
        _expand_unet_in_channels_for_ref(pipe.unet, extra_latent_ch=4)
        # The replacement conv is created after the pipeline was moved, so make
        # sure it lives on the same device/dtype as the rest of the UNet.
        pipe.unet.conv_in.to(device=device, dtype=torch_dtype)

        # Prepare a fixed null-text embedding
        with torch.no_grad():
            prompt = [""]  # single null prompt
            text_inputs = pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            null_emb = pipe.text_encoder(**text_inputs)[0]  # (1,77,768)
            # Non-persistent: recomputed on construction, not stored in state_dict.
            self.register_buffer("null_prompt_embeds", null_emb, persistent=False)

        self.pipe = pipe
        self.unet = pipe.unet
        self.vae = pipe.vae
        self.scheduler = pipe.scheduler
        self.device = device
        self.dtype = torch_dtype

    @torch.no_grad()
    def encode_images_to_latents(self, images):
        """
        images: (B,3,H,W) in [-1,1]
        returns latents: (B,4,h,w)
        """
        return _to_latents(self.vae, images.to(self.device, dtype=self.vae.dtype))

    @torch.no_grad()
    def decode_latents_to_images(self, latents):
        """
        latents: (B,4,h,w), returns images in [-1,1]
        """
        latents = latents.to(self.device, dtype=self.vae.dtype) / LATENT_SCALE
        return self.vae.decode(latents).sample

    # Prep mask to latent size
    def prepare_mask(self, mask, latent_h, latent_w):
        """Resize ``mask`` to (latent_h, latent_w) on the model device, in the UNet dtype."""
        return _prep_latent_mask(mask, latent_h, latent_w, self.device, self.unet.dtype)

    # Noise Prediction
    def forward(
        self,
        noisy_latents,            # (B,4,h,w)
        t,                        # (B,) or scalar timestep
        masked_image_latents,     # (B,4,h,w)
        mask_latent,              # (B,1,h,w) in {0,1}
        ref_latents,              # (B,4,h,w)
    ):
        """
        Returns ε prediction with no text conditioning (fixed null embed).
        UNet input is concatenation along channel dim:
            [ noisy_latents, masked_image_latents, mask_latent, ref_latents ]  -> (B, 13, h, w)
        """
        # Build model input like SD inpainting + ref latents
        latent_model_input = torch.cat([noisy_latents, masked_image_latents, mask_latent, ref_latents], dim=1)

        # Match the null embedding to the input so a float32 UNet still works
        # with an embedding that was computed in half precision.
        batch_size = latent_model_input.shape[0]
        null_embeds = self.null_prompt_embeds.to(
            device=latent_model_input.device, dtype=latent_model_input.dtype
        ).expand(batch_size, -1, -1)

        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=null_embeds,
        ).sample
        return noise_pred


In [None]:
# Pseudocode for the trainer loop (not the full file)
# Build the guided-inpainting model, freeze what is not trained, and create the optimizer.
model = Unet(im_channels=3, model_config={"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"})
scheduler = model.scheduler  # scheduler instance that ships with the pretrained pipeline

# Freeze VAE & text encoder (we don't train them)
model.vae.requires_grad_(False)
model.vae.eval()
if hasattr(model, "pipe") and hasattr(model.pipe, "text_encoder"):
    model.pipe.text_encoder.requires_grad_(False)

# The trainable UNet keeps float32 weights: torch.cuda.amp.GradScaler refuses
# to unscale float16 gradients ("Attempting to unscale FP16 gradients"), and
# autocast in the training loop already runs the forward pass in half precision.
model.unet.to(model.device, dtype=torch.float32)
# The frozen VAE is only used for inference, so half precision is fine there.
model.vae.to(model.device, dtype=torch.float16)

model.unet.enable_gradient_checkpointing()
model.unet.enable_xformers_memory_efficient_attention()

# Optimizer: train the UNet (including the new conv_in channels we added)
optimizer = AdamW(
    model.unet.parameters(),
    lr=1e-4,              # tweak as needed
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)

In [None]:
import torch
import torch.nn.functional as F

model.unet.train()
# GradScaler cannot unscale float16 gradients, so the trainable weights must be
# float32; autocast below still runs the forward pass in half precision.
model.unet.to(dtype=torch.float32)
scaler = torch.cuda.amp.GradScaler()
global_step = 0
num_epochs = 1
stop_training = False

for epoch in range(num_epochs):
    for batch in dataloader:
        # The dataset yields tensors in [0, 1]; the VAE helpers expect [-1, 1].
        images = batch["target"].to(model.device) * 2.0 - 1.0
        ref_images = batch["ref"].to(model.device) * 2.0 - 1.0
        source_images = batch["source"].to(model.device) * 2.0 - 1.0
        # Binarise: resizing can leave fractional values along the mask edge.
        masks = (batch["mask"].to(model.device) > 0.5).float()

        # -----------------------
        # Encode to latents
        # -----------------------
        with torch.no_grad():
            latents = model.encode_images_to_latents(images)            # (B,4,h,w)
            ref_latents = model.encode_images_to_latents(ref_images)    # (B,4,h,w)
            masked_image_latents = model.encode_images_to_latents(source_images)

            _, _, h, w = latents.shape
            mask_latent = F.interpolate(masks, size=(h, w), mode="nearest")

        # -----------------------
        # Noise + timestep
        # -----------------------
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,  # read via config, not a scheduler attribute
            (latents.size(0),),
            device=latents.device,
        )
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        unet_dtype = model.unet.dtype
        noisy_latents = noisy_latents.to(model.device, dtype=unet_dtype)
        timesteps = timesteps.to(model.device)
        masked_image_latents = masked_image_latents.to(model.device, dtype=unet_dtype)
        mask_latent = mask_latent.to(model.device, dtype=unet_dtype)
        ref_latents = ref_latents.to(model.device, dtype=unet_dtype)

        # -----------------------
        # Forward
        # -----------------------
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            noise_pred = model(
                noisy_latents=noisy_latents,
                t=timesteps,
                masked_image_latents=masked_image_latents,
                mask_latent=mask_latent,
                ref_latents=ref_latents,
            )
        # Compute the loss in float32 regardless of the forward precision.
        loss = F.mse_loss(noise_pred.float(), noise.float())

        # -----------------------
        # Safety checks
        # -----------------------
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"[Step {global_step}] 🚨 Loss is NaN/Inf, skipping update")
            optimizer.zero_grad(set_to_none=True)
            continue

        # Backward + step
        scaler.scale(loss).backward()

        # scaler.step() skips the optimizer step by itself when the gradients
        # contain NaN/Inf, and scaler.update() then lowers the loss scale.
        # Bailing out before update() would leave the scale unchanged and make
        # every following step overflow again.
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        if scaler.get_scale() < scale_before:
            print(f"[Step {global_step}] 🚨 NaN/Inf in gradients, optimizer step skipped")
            continue

        # Check weights after step
        weight_nan = False
        for n, p in model.unet.named_parameters():
            if not torch.isfinite(p).all():
                print(f"[Step {global_step}] 🚨 NaN/Inf in weights of {n}")
                weight_nan = True
                break
        if weight_nan:
            stop_training = True
            break  # stop training immediately

        # -----------------------
        # Logging
        # -----------------------
        if global_step % 50 == 0:
            print(f"[Epoch {epoch} | Step {global_step}] loss={loss.item():.6f}")

        global_step += 1

    if stop_training:
        break  # also leave the epoch loop, not just the batch loop


In [None]:
# Debug aid: report where the UNet weights live and their dtype ...
print("UNet device:", next(model.unet.parameters()).device)
print("UNet dtype:", next(model.unet.parameters()).dtype)

# ... and the same for the tensors left over from the last training iteration.
print("noisy_latents:", noisy_latents.device, noisy_latents.dtype)
print("masked_image_latents:", masked_image_latents.device, masked_image_latents.dtype)
print("mask_latent:", mask_latent.device, mask_latent.dtype)
print("ref_latents:", ref_latents.device, ref_latents.dtype)


# Save Checkpoints

In [None]:
import os

# Save the UNet weights to ./checkpoints.
# NOTE: `epoch` is reset to 0 here, so the file is always named
# unet_epoch1.pt regardless of how many epochs the training loop ran.
epoch = 0
os.makedirs("checkpoints", exist_ok=True)
save_path = f"checkpoints/unet_epoch{epoch+1}.pt"
torch.save(model.unet.state_dict(), save_path)
print(f"Saved checkpoint to {save_path}")


Mount Google Drive

In [None]:
# Colab only: the later cells read and write under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Save model weights

In [None]:
# Store only the UNet state dict (weights of the expanded UNet) on Drive.
save_path = "/content/drive/MyDrive/inpainting_unet.pt"
torch.save(model.unet.state_dict(), save_path)
print(f"UNet saved to {save_path}")

To reload later

In [None]:
# Alternative (kept for reference): load the weights into a freshly created pipeline.
# NOTE(review): a fresh pipeline's UNet has the stock 9-channel conv_in, while the
# saved state dict comes from the expanded UNet, so conv_in presumably has to be
# widened with _expand_unet_in_channels_for_ref(pipe.unet) before
# load_state_dict — confirm before enabling this path.
# from diffusers import StableDiffusionInpaintPipeline
# import torch

# model_id = "runwayml/stable-diffusion-inpainting"
# pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None).to("cuda")

# # Load trained weights
# pipe.unet.load_state_dict(torch.load("/content/drive/MyDrive/inpainting_unet.pt"))


# Restore the saved weights into the already-expanded UNet held by `model`.
model.unet.load_state_dict(torch.load("/content/drive/MyDrive/inpainting_unet.pt"))

Save the whole pipeline (recommended if you’ll use it later for inference)

In [None]:
# Save every component of the pipeline held by `model` (including the modified UNet).
# NOTE(review): verify that the saved UNet config records the expanded input
# channel count; otherwise reloading the pipeline will not match the weights.
save_dir = "/content/drive/MyDrive/inpainting_pipeline"
model.pipe.save_pretrained(save_dir)
print(f"Pipeline saved to {save_dir}")

To reload later

In [None]:
from diffusers import StableDiffusionInpaintPipeline

# Reload the saved pipeline from Drive. This binds a new `pipe` only; the
# `model` wrapper created earlier keeps referencing its own pipeline objects.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "/content/drive/MyDrive/inpainting_pipeline",
    safety_checker=None
).to("cuda")


Save UNet & optimizer state (optional, if you want to resume training; the noise scheduler has no trained state and is simply re-created on reload)

In [None]:
# Resume checkpoint: UNet weights plus optimizer state (no scheduler state is stored).
checkpoint = {
    "unet": model.unet.state_dict(),
    "optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, "/content/drive/MyDrive/inpainting_checkpoint.pt")


To reload later

In [None]:
# Restore UNet weights and optimizer state from the resume checkpoint.
ckpt = torch.load("/content/drive/MyDrive/inpainting_checkpoint.pt", map_location="cuda")
model.unet.load_state_dict(ckpt["unet"])
optimizer.load_state_dict(ckpt["optimizer"])

# Scheduler: build a DDIMScheduler from the base model's scheduler config.
# NOTE: this replaces model.scheduler only; the module-level `scheduler`
# variable used by the training loop still refers to the previous object.
from diffusers import DDIMScheduler

model.scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-inpainting", subfolder="scheduler")


# Test / Evaluation

In [None]:
def show_images(images, titles=None):
    """Display ``images`` side by side in a single row.

    Tensors are detached, moved to CPU and, when they are 3-D with 1 or 3
    leading channels, permuted to channels-last before being converted to
    NumPy. float16 arrays are cast to float32 before plotting. ``titles``,
    when given, is indexed in parallel with ``images``.
    """
    count = len(images)
    fig, axes = plt.subplots(1, count, figsize=(4*count, 4))
    # A single subplot comes back as a bare Axes; normalise to a sequence.
    axes = [axes] if count == 1 else axes

    for idx, (ax, img) in enumerate(zip(axes, images)):
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()
            is_chw = img.ndim == 3 and img.shape[0] in [1,3]
            if is_chw:
                img = img.permute(1,2,0)  # C,H,W -> H,W,C
            img = img.numpy()

        # --- cast to float32 for imshow ---
        if img.dtype == "float16":
            img = img.astype("float32")

        ax.imshow(img)
        ax.axis("off")
        if titles:
            ax.set_title(titles[idx])
    plt.show()


Quick sanity checks: decode a random latent, then round-trip a real image through the VAE

In [None]:
# Decode a random latent to confirm that the VAE decoder path runs.
z = torch.randn(1, 4, 64, 64).to(model.device)  # latent size for 512x512
out = model.decode_latents_to_images(z)
# The decoder returns images in [-1, 1]; map to [0, 1] for display.
out = (out.float() * 0.5 + 0.5).clamp(0, 1)
show_images([out[0]], ["Random Latent Decode"])


In [None]:
# VAE round trip on a real batch. Dataset tensors are in [0, 1] while the VAE
# helpers work in [-1, 1], so convert on the way in and on the way out.
test_latents = model.encode_images_to_latents(batch["target"].to(model.device) * 2.0 - 1.0)
test_images = model.decode_latents_to_images(test_latents)
test_images = (test_images.float() * 0.5 + 0.5).clamp(0, 1)
show_images([batch["target"][0], test_images[0]], ["Original", "Reconstruction"])

In [None]:
def evaluate_one_batch(batch, model, num_inference_steps=50):
    """Run reference-guided inpainting on one batch and display the results.

    Sampling starts from pure noise and is denoised for ``num_inference_steps``
    scheduler steps, conditioned on the masked source, the mask and the
    reference image. The decoded result is blended back into the source so
    that only the masked region is replaced.

    Args:
        batch: Dict with ``"target"``, ``"ref"``, ``"mask"`` and ``"source"``
            tensors in [0, 1], as produced by the dataset in this file.
        model: The ``Unet`` wrapper (provides the VAE helpers, UNet and scheduler).
        num_inference_steps: Number of denoising steps.

    Returns:
        Dict with ``"recon"`` (decoded sample in [0, 1]), ``"final_output"``
        (the blend with the source) and ``"target"`` (ground truth), all on
        the model device.
    """
    model.unet.eval()
    scheduler = model.scheduler
    with torch.no_grad():
        images = batch["target"].to(model.device)
        ref_images = batch["ref"].to(model.device)
        source_images = batch["source"].to(model.device)
        # Binarise: resizing can leave fractional values along the mask edge.
        masks = (batch["mask"].to(model.device) > 0.5).to(images.dtype)

        # The dataset yields [0, 1]; the VAE helpers expect [-1, 1].
        ref_latents = model.encode_images_to_latents(ref_images * 2.0 - 1.0)
        masked_image_latents = model.encode_images_to_latents(source_images * 2.0 - 1.0)

        _, _, h, w = masked_image_latents.shape
        mask_latent = F.interpolate(masks, size=(h, w), mode="nearest").to(dtype=masked_image_latents.dtype)

        # --- Start from scaled noise ---
        scheduler.set_timesteps(num_inference_steps, device=model.device)
        latents = torch.randn_like(masked_image_latents) * scheduler.init_noise_sigma

        # --- Denoising loop ---
        for t in scheduler.timesteps:
            model_input = scheduler.scale_model_input(latents, t)
            timesteps = torch.full((latents.shape[0],), int(t), device=model.device, dtype=torch.long)
            with torch.autocast("cuda", dtype=torch.float16):
                noise_pred = model(
                    model_input,
                    t=timesteps,
                    masked_image_latents=masked_image_latents,
                    mask_latent=mask_latent,
                    ref_latents=ref_latents,
                )

            # Step the sample that is being denoised (not the ground-truth
            # latents) so that each iteration refines the previous one.
            latents = scheduler.step(noise_pred.to(latents.dtype), t, latents).prev_sample

        # --- Decode (with scaling back) ---
        recon = model.decode_latents_to_images(latents)
        recon = (recon * 0.5 + 0.5).clamp(0, 1).to(torch.float32)

        # 1. Use the full-resolution mask for blending; the latent-size mask is
        #    coarser and would leave blacked-out source pixels along the edge.
        blend_mask = masks.to(torch.float32)
        if blend_mask.shape[2:] != source_images.shape[2:]:
            blend_mask = F.interpolate(blend_mask, size=source_images.shape[2:], mode="nearest")

        # 2. If mask has 1 channel but images have 3, repeat channels
        if blend_mask.shape[1] == 1 and source_images.shape[1] > 1:
            blend_mask = blend_mask.repeat(1, source_images.shape[1], 1, 1)

        # 3. Ensure recon matches source_images size (if decoder returns slightly different size)
        if recon.shape[2:] != source_images.shape[2:]:
            recon = F.interpolate(recon, size=source_images.shape[2:], mode="bilinear", align_corners=False)

        # 4. Blend
        final_output = recon * blend_mask + source_images.to(torch.float32) * (1 - blend_mask)

    show_images([
        source_images[0], ref_images[0], images[0], recon[0], final_output[0]
    ], ["Masked Source", "Reference", "Target", "Reconstruction", "Blended Output"])

    return {"recon": recon, "final_output": final_output, "target": images}


In [None]:
# Evaluate on one 512x512 batch (this rebinds the module-level `batch`).
batch = next(iter(test_dataloader))
evaluate_one_batch(batch, model)

# Quantitative Metrics

In [None]:
!pip install torchmetrics

In [None]:
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

# PSNR / SSIM between the reconstruction and the ground-truth images.
# NOTE(review): no visible cell defines `recon` at module scope (it is a local
# of evaluate_one_batch), and the module-level `images` is the last training
# batch. Bind both to the evaluation results before running this cell.
psnr = PeakSignalNoiseRatio().to(model.device)
ssim = StructuralSimilarityIndexMeasure().to(model.device)

with torch.no_grad():
    score_psnr = psnr(recon, images).item()
    score_ssim = ssim(recon, images).item()
print(f"PSNR: {score_psnr:.2f}, SSIM: {score_ssim:.3f}")