# In [None]:
# Imports
from dataclasses import dataclass
import torch
from diffusers import UNet2DModel, DDPMScheduler
from torchvision import transforms
import PIL.Image
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from typing import List

# Custom Scheduler
class CustomScheduler:
    """Precomputed noise-schedule tables derived from a ``betas`` tensor.

    Every derived tensor has the same length as ``betas`` and is meant to be
    indexed by timestep.

    Not a ``@dataclass``: the class declares no fields, so the decorator would
    only add a field-less ``__eq__`` (making every instance compare equal) and
    disable hashing, while leaving the hand-written ``__init__`` in place.

    Args:
        timesteps: Timesteps to iterate over during sampling; stored as-is.
        betas: Per-timestep noise variances.

    Raises:
        ValueError: If ``timesteps`` and ``betas`` differ in length.
    """

    def __init__(self, timesteps: torch.Tensor, betas: torch.Tensor):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(timesteps) != len(betas):
            raise ValueError(
                f"timesteps and betas must have the same length, "
                f"got {len(timesteps)} and {len(betas)}"
            )
        self.timesteps = timesteps
        self.betas = betas
        self.alphas = 1.0 - self.betas
        # Running product of alphas along the timestep axis.
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Same table shifted by one step; the wrapped-around first entry is
        # overwritten with 1.0 (torch.roll returns a new tensor, so
        # alphas_cumprod itself is not modified).
        self.alphas_cumprod_prev = torch.roll(self.alphas_cumprod, 1)
        self.alphas_cumprod_prev[0] = 1.0
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # Evaluates to 0 at index 0 because alphas_cumprod_prev[0] == 1.
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

    @classmethod
    def from_DDPMScheduler(cls, ddpm_scheduler: "DDPMScheduler"):
        """Build a scheduler from an object exposing ``timesteps`` and ``betas``."""
        return cls(ddpm_scheduler.timesteps, ddpm_scheduler.betas)

# Model wrapper
class Model:
    """Callable wrapper that returns only the ``"sample"`` entry of a UNet's output."""

    def __init__(self, model: UNet2DModel):
        # Wrapped network; called as ``model(x, t)`` in ``__call__``.
        self.model = model

    def to(self, device: torch.device):
        """Move the wrapped network to ``device`` and return this wrapper for chaining."""
        self.model.to(device)
        return self

    def __call__(self, x, t):
        """Run the wrapped network on ``(x, t)`` and return ``output["sample"]``."""
        output = self.model(x, t)
        return output["sample"]

# Transform for outputs
sample_to_pil = transforms.Compose([
    # Drop the leading batch axis (only removed if it has size 1).
    transforms.Lambda(lambda batch: torch.squeeze(batch, 0)),
    # Move axis 0 to the end: (0, 1, 2) -> (1, 2, 0).
    transforms.Lambda(lambda chw: chw.permute(1, 2, 0)),
    # Affine map taking -1 -> 0 and 1 -> 255.
    transforms.Lambda(lambda hwc: (hwc + 1) * 127.5),
    # Anything outside [0, 255] is clipped before the integer cast.
    transforms.Lambda(lambda hwc: hwc.clamp(0, 255)),
    # Cast to uint8 on the host; the cast truncates rather than rounds.
    transforms.Lambda(lambda hwc: hwc.detach().cpu().numpy().astype(np.uint8)),
    transforms.ToPILImage(),
])

# Core diffusion functions
@torch.no_grad()
def single_reverse_step(model: Model, x: torch.Tensor, t: int, S: CustomScheduler) -> torch.Tensor:
    """Apply one reverse (denoising) step to ``x`` at timestep ``t``.

    The model output is plugged in as the noise estimate of the posterior
    mean. For ``t != 0`` Gaussian noise scaled by
    ``sqrt(S.posterior_variance[t])`` is added; at ``t == 0`` the mean is
    returned unchanged.

    Args:
        model: Callable invoked as ``model(x, t)``.
        x: Current sample.
        t: Timestep used to index the schedule tables in ``S``.
        S: Schedule providing ``betas``, ``sqrt_recip_alphas``,
            ``sqrt_one_minus_alphas_cumprod`` and ``posterior_variance``.

    Returns:
        The sample after one reverse step.
    """
    predicted_noise = model(x, t)
    residual = x - S.betas[t] * predicted_noise / S.sqrt_one_minus_alphas_cumprod[t]
    mean = S.sqrt_recip_alphas[t] * residual

    # Last step: return the mean without adding noise.
    if t == 0:
        return mean

    std = torch.sqrt(S.posterior_variance[t])
    return mean + torch.randn_like(x) * std

@torch.no_grad()
def zero_to_t(x_0: torch.Tensor, t: int, S: CustomScheduler) -> torch.Tensor:
    """Noise ``x_0`` directly to timestep ``t`` in a single step.

    Returns ``x_0`` itself when ``t == 0``; otherwise mixes ``x_0`` with fresh
    Gaussian noise using weights ``sqrt(S.alphas_cumprod[t])`` and
    ``sqrt(1 - S.alphas_cumprod[t])``.
    """
    if t == 0:
        return x_0

    signal_scale = torch.sqrt(S.alphas_cumprod[t])
    noise_scale = torch.sqrt(1.0 - S.alphas_cumprod[t])
    return signal_scale * x_0 + noise_scale * torch.randn_like(x_0)

@torch.no_grad()
def forward_j_steps(x_t: torch.Tensor, t: int, j: int, S: CustomScheduler) -> torch.Tensor:
    """Re-noise ``x_t`` from timestep ``t`` to timestep ``t + j`` in one shot.

    Uses the ratio ``S.alphas_cumprod[t + j] / S.alphas_cumprod[t]`` as the
    retained-signal factor. ``t + j`` must be a valid index into
    ``S.alphas_cumprod``.
    """
    ratio = S.alphas_cumprod[t + j] / S.alphas_cumprod[t]
    signal_scale = torch.sqrt(ratio)
    noise_scale = torch.sqrt(1.0 - ratio)
    return signal_scale * x_t + noise_scale * torch.randn_like(x_t)

def get_jumps(timesteps, jumps_every: int = 100, r: int = 5) -> List[int]:
    """Return the timesteps at which to resample, largest first.

    Candidate timesteps are the multiples of ``jumps_every`` in
    ``[0, max(timesteps))``. Each one appears ``r`` consecutive times in the
    result.

    Args:
        timesteps: Integer tensor of timesteps; only its maximum is used.
        jumps_every: Spacing between candidate timesteps.
        r: Number of repetitions of each candidate.

    Returns:
        Descending list of timesteps, e.g. ``[200, 200, 100, 100, 0, 0]``.
    """
    jump_points = range(0, torch.max(timesteps), jumps_every)
    return [point for point in reversed(jump_points) for _ in range(r)]

# RePaint function
@torch.no_grad()
def repaint(original_data: torch.Tensor, keep_mask: torch.Tensor,
            model: Model, scheduler: CustomScheduler, j:int=10, r:int=5) -> torch.Tensor:
    """Inpaint ``original_data`` by reverse diffusion with periodic resampling.

    Starts from Gaussian noise shaped like ``original_data`` and walks through
    ``scheduler.timesteps``. At every timestep listed by ``get_jumps`` the
    sample is re-noised ``j`` steps forward and denoised back, once per
    occurrence in that list. Each timestep then ends by blending a noised copy
    of the original with the model's denoised sample.

    Args:
        original_data: Reference tensor supplying the kept content.
        keep_mask: Blend weights; 1 keeps the noised original, 0 keeps the
            model's output.
        model: Callable invoked as ``model(x, t)``.
        scheduler: Schedule tables and the timesteps to iterate over.
        j: Number of forward steps in each resampling jump.
        r: Number of resampling jumps per jump timestep.

    Returns:
        The sample after the last timestep.
    """
    pending_jumps = get_jumps(scheduler.timesteps, r=r)
    next_jump = 0  # cursor into pending_jumps; entries before it are consumed
    target_device = original_data.device
    sample = torch.randn_like(original_data).to(target_device)
    print("beginning repaint")

    for t in tqdm(scheduler.timesteps):
        # Resample: jump j steps forward, then denoise back down to t.
        # NOTE(review): the mask is not applied inside this inner loop —
        # confirm that is intended.
        while next_jump < len(pending_jumps) and pending_jumps[next_jump] == t:
            next_jump += 1
            sample = forward_j_steps(sample, t, j, scheduler)
            for override_t in range(t + j, t, -1):
                sample = single_reverse_step(model, sample, override_t, scheduler)

        # NOTE(review): x_known is noised to level t while x_unknown is the
        # result of a reverse step taken from t — confirm this pairing.
        x_known = zero_to_t(original_data, t, scheduler)
        x_unknown = single_reverse_step(model, sample, t, scheduler)
        sample = keep_mask * x_known + (1 - keep_mask) * x_unknown

    return sample

# Setup
# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_id = "google/ddpm-celebahq-256"

# Load the pretrained UNet first, then the matching noise schedule, and wrap
# both in the local helper classes.
model = Model(UNet2DModel.from_pretrained(model_id))
scheduler = CustomScheduler.from_DDPMScheduler(DDPMScheduler.from_pretrained(model_id))
model = model.to(device)

# Transforms for image and mask
# Image pipeline: resize, convert to tensor, apply x -> 2x - 1, add a batch axis.
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img * 2 - 1),
    transforms.Lambda(lambda img: torch.unsqueeze(img, 0)),
])

# Mask pipeline: same resize and batching, but values are left unscaled.
mask_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda msk: torch.unsqueeze(msk, 0)),
])

# Image & mask paths
image_paths = ["/content/celeba_00.jpg", "/content/celeba_01.jpg"]
mask_paths = ["/content/mask_3.png", "/content/mask_4.png"]
results = []

# Loop over images & masks: one inpainting run per (image, mask) pair,
# appended to `results` in row-major order (image outer, mask inner).
for img_path in image_paths:
    # Context manager closes the file once the pixels are in a tensor.
    # Forcing RGB normalises grayscale/CMYK/RGBA inputs to three channels
    # (presumably what the pretrained UNet expects — confirm).
    with PIL.Image.open(img_path) as image:
        image_tensor = data_transform(image.convert("RGB")).to(device)

    for mask_path in mask_paths:
        # Single-channel "L" masks broadcast across the image channels in the
        # blend inside repaint; RGBA or palette PNGs would otherwise yield a
        # 4-channel or index-valued tensor.
        with PIL.Image.open(mask_path) as mask:
            mask_tensor = mask_transform(mask.convert("L")).to(device)

        out = repaint(image_tensor, mask_tensor, model, scheduler)
        results.append(sample_to_pil(out))

# Display results
n_rows, n_cols = len(image_paths), len(mask_paths)
# squeeze=False keeps `axes` a 2-D array even with a single image or mask;
# by default a single row or column is returned as a 1-D array (or a bare Axes).
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 6), squeeze=False)

# `results` is stored row-major (image outer, mask inner), which matches the
# iteration order of `axes.flat`.
for ax, result in zip(axes.flat, results):
    ax.imshow(result)
    ax.axis('off')

plt.tight_layout()
plt.show()
