In [41]:
# Core Python + PyTorch dependencies for data loading/visualization.
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import random
import matplotlib.pyplot as plt

# Imaging utilities leveraged elsewhere in the notebook.
from diffusers import AutoencoderKL
from PIL import Image
import torch.nn.functional as F

# image pipiline

## image load

In [42]:
# Set dataset path
dataset_path = "/scratch/juriostegui/sun397-subset/test/"

# Image transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Normalize images to SD training resolution.
    transforms.ToTensor()  # Convert PIL images to torch tensors in [0, 1].
])

# Load dataset from folder
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
# Build an iterator that returns random single-image batches for masking.
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)


## image preprocessing

In [43]:
# Function to create random square mask
def random_square_mask(img_tensor):
    """Randomly removes a square region and returns the masked image plus binary mask."""
    _, H, W = img_tensor.shape  # Track spatial size to sample a crop.
    h = H // 2  # Use half the height to form a reasonably large hole.
    w = W // 2  # Mirror logic for width to keep the square aspect ratio.
    top = random.randint(0, H - h)  # Sample top-left corner uniformly.
    left = random.randint(0, W - w)
    mask = torch.zeros(1, H, W, device=img_tensor.device)  # Binary mask initialized to background.
    mask[:, top:top+h, left:left+w] = 1.0  # Mark the missing region with ones.

    masked_img = img_tensor.clone()  # Work on a copy to preserve the original tensor.
    masked_img[:, top:top+h, left:left+w] = 0.0  # Zero out the pixels we want the model to inpaint.
    return masked_img, mask

In [44]:
# Pick 5 random images and apply making
images = []  # Unmasked reference images for visualization.
masks = []  # Binary masks storing which pixels are missing.
masked_images = []  # Observations with central chunks removed.

for i, (img, _) in enumerate(data_loader):
    if i >= 5:
        break  # Limit to a small batch for interactive experimentation.
    img = img[0].cuda()  # Move the selected image to GPU for masking.
    masked_img, mask = random_square_mask(img)  # Apply the stochastic mask.
    images.append(img.cpu())  # Store CPU copies for plotting later.
    masked_images.append(masked_img.cpu())
    masks.append(mask.cpu())

## get latent mask

In [45]:
def _prepare_mask_latents(mask_batch: torch.Tensor, target_hw):
    """Resize binary masks to latent space so we can reinsert known pixels per Song et al."""
    if mask_batch.ndim == 3:
        mask_batch = mask_batch.unsqueeze(1)  # Ensure a channel dimension for interpolation.
    mask = mask_batch.to(device=device, dtype=sd_dtype)
    mask = F.interpolate(mask, size=target_hw, mode="nearest")  # Downsample mask without smoothing edges.
    known_mask = 1.0 - mask  # Known pixels are the complement of the missing region.
    return known_mask

## Stable Diffusion predictor-corrector inpainter

In [46]:
from diffusers import StableDiffusionPipeline, DDIMScheduler
from typing import Optional
import math

sd_model_id = "runwayml/stable-diffusion-v1-5"  # Use the base SD checkpoint for general content.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Respect accelerator availability.
sd_dtype = torch.float32  # Default precision keeps things simple across hardware.

# Instantiate the Stable Diffusion pipeline without relying on the specialized inpaint model.
sd_pipe = StableDiffusionPipeline.from_pretrained(
    sd_model_id,
    torch_dtype=sd_dtype
)
sd_pipe.safety_checker = None  # Disable safety checker to avoid extra overhead in research.
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)  # Use DDIM for Song-style updates.
sd_pipe = sd_pipe.to(device)  # Move UNet/VAE/text encoder to the working device.

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

## utils

In [47]:

@torch.no_grad()
def _encode_to_latents(images: torch.Tensor) -> torch.Tensor:
    """Project masked RGB images into the VAE latent space used during SD training."""
    images = images.to(device=device, dtype=sd_dtype)
    images = images * 2.0 - 1.0  # Map from [0, 1] to the VAE's expected [-1, 1] range.
    latents = sd_pipe.vae.encode(images).latent_dist.mean  # Use the posterior mean for determinism.
    latents = latents * sd_pipe.vae.config.get("scaling_factor", 0.18215)  # Match SD latent scaling.
    return latents



def _blend_with_observation(latents, observed_latents, known_mask, timestep, generator):
    """Inject noisy observations for known pixels, copying Song's conditioning strategy."""
    if known_mask is None:
        return latents
    noise = torch.randn_like(observed_latents, generator=generator, device=device, dtype=observed_latents.dtype)
    noised_obs = sd_pipe.scheduler.add_noise(observed_latents, noise, timestep)  # Forward-diffuse observations.
    return latents * (1.0 - known_mask) + noised_obs * known_mask  # Replace only the constrained regions.




## predictor/corrector

In [48]:
def _predict_noise(latents, timestep, text_embeddings, guidance_scale, do_cfg):
    """Run the SD UNet and optionally apply classifier-free guidance for text alignment."""
    latent_input = torch.cat([latents] * 2) if do_cfg else latents  # Duplicate for conditional/unconditional passes.
    latent_input = sd_pipe.scheduler.scale_model_input(latent_input, timestep)  # Match scheduler expected scaling.
    noise_pred = sd_pipe.unet(latent_input, timestep, encoder_hidden_states=text_embeddings).sample
    if do_cfg:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)  # CFG blend.
    return noise_pred


def _predictor_step(latents, timestep, text_embeddings, guidance_scale, do_cfg):
    """Single DDIM predictor step that integrates the reverse SDE/ODE."""
    noise_pred = _predict_noise(latents, timestep, text_embeddings, guidance_scale, do_cfg)
    latents = sd_pipe.scheduler.step(noise_pred, timestep, latents).prev_sample  # Deterministic DDIM update.
    return latents


def _corrector_step(latents, timestep, text_embeddings, guidance_scale, snr, generator, do_cfg):
    """Langevin corrector step that injects noise according to the desired SNR."""
    score = _predict_noise(latents, timestep, text_embeddings, guidance_scale, do_cfg)
    noise = torch.randn_like(latents, generator=generator, device=device, dtype=latents.dtype)
    score_norm = torch.norm(score.reshape(score.shape[0], -1), dim=-1).mean().item()
    noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean().item()
    step_size = (snr ** 2) * ((noise_norm / (score_norm + 1e-12)) ** 2)  # Follow Song's adaptive step sizing.
    latents = latents + step_size * score + math.sqrt(2.0 * step_size) * noise  # Euler-Maruyama update.
    return latents

## inpainting pipeline

In [49]:
def get_sd_pc_inpainter(pipe: StableDiffusionPipeline,
                        snr: float = 0.15,
                        n_corrector_steps: int = 1,
                        denoise: bool = True):
    """Construct a Song-style predictor-corrector inpainting callable for Stable Diffusion."""

    def pc_inpainter(masked_images: torch.Tensor,
                    masks: torch.Tensor,
                    prompt: str = "cat",
                    negative_prompt: Optional[str] = "cat",
                    num_inference_steps: int = 30,
                    guidance_scale: float = 0,
                    generator: Optional[torch.Generator] = None):
        """Run the predictor-corrector loop conditioned on text + observed pixels."""
        pipe.scheduler.set_timesteps(num_inference_steps, device=device)  # Reset DDIM schedule each call.
        masked_images = masked_images.to(device=device, dtype=sd_dtype)
        masks = masks.to(device=device, dtype=sd_dtype)
        batch_size, _, height, width = masked_images.shape
        latent_h, latent_w = height // 8, width // 8  # Latent resolution is 1/8th of pixel space.
        if generator is None:
            generator = torch.Generator(device=device)  # Allow deterministic seeding upstream if desired.

        latents = torch.randn((batch_size, pipe.unet.config.in_channels, latent_h, latent_w),
                              generator=generator,
                              device=device,
                              dtype=sd_dtype)
        latents = latents * pipe.scheduler.init_noise_sigma  # Match scheduler noise std.
        observed_latents = _encode_to_latents(masked_images)  # Encode partial observations once upfront.
        known_mask = _prepare_mask_latents(masks, (latent_h, latent_w))
        negative_prompt_text = negative_prompt or " cat"
        prompt_text = prompt or " cat"  # Fall back to empty strings when users skip text guidance.
        do_cfg = guidance_scale > 1.0 and bool(prompt_text.strip())  # Only enable CFG when we have a prompt.
        if not do_cfg:
            text_embeddings = pipe._encode_prompt(
                prompt_text,
                device=device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=negative_prompt_text
            )
        else:
            text_embeddings = pipe._encode_prompt(
                prompt_text,
                device=device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt_text
            )

        for timestep in pipe.scheduler.timesteps:
            for _ in range(n_corrector_steps):
                latents = _corrector_step(latents, timestep, text_embeddings, guidance_scale, snr, generator, do_cfg)
                latents = _blend_with_observation(latents, observed_latents, known_mask, timestep, generator)  # Enforce constraints.
            latents = _predictor_step(latents, timestep, text_embeddings, guidance_scale, do_cfg)
            latents = _blend_with_observation(latents, observed_latents, known_mask, timestep, generator)

        if denoise:
            latents = _predictor_step(latents, pipe.scheduler.timesteps[-1], text_embeddings, guidance_scale, do_cfg)  # Optional final clean-up.

        images = pipe.vae.decode(latents / pipe.vae.config.get("scaling_factor", 0.18215)).sample
        images = (images / 2 + 0.5).clamp(0, 1)  # Return images in displayable range.
        return images

    return pc_inpainter


pc_inpainter = get_sd_pc_inpainter(sd_pipe, snr=0.2, n_corrector_steps=0, denoise=True)  # Build the callable with default research settings.

# run

In [50]:
# Example usage (do not execute on shared cluster login nodes).
generator = torch.Generator(device=device).manual_seed(0)
masked_batch = torch.stack(masked_images).to(device)
mask_batch = torch.stack(masks).to(device)
recon = pc_inpainter(
    masked_images=masked_batch,
    masks=mask_batch,
    num_inference_steps=40,
    generator=generator
)


TypeError: expected Tensor as element 0 in argument 0, but got NoneType

## plot

In [None]:
fig, axes = plt.subplots(3, len(masked_images), figsize=(15, 6))
for idx in range(len(masked_images)):
    axes[0, idx].imshow(images[idx].permute(1, 2, 0))
    axes[0, idx].set_title("Original")
    axes[0, idx].axis("off")
    axes[1, idx].imshow(masked_images[idx].permute(1, 2, 0))
    axes[1, idx].set_title("Masked")
    axes[1, idx].axis("off")
    axes[2, idx].imshow(recon[idx].detach().cpu().permute(1, 2, 0))
    axes[2, idx].set_title("Inpainted")
    axes[2, idx].axis("off")
plt.tight_layout()