In [None]:
import os
import torch
import matplotlib.pyplot as plt
import repaint_lib
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from diffusers import DiffusionPipeline
from dataset import get_dataset
from tqdm import tqdm

# Expose only physical GPU 1 to this process; inside the process it is then
# addressed as "cuda:0". This only takes effect if CUDA has not already been
# initialized in this kernel, so it must run before any CUDA call.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Validation split, returning each item's index and feature vector as well.
# mu = 0 / std = 1 per channel is presumably an identity normalization
# (TODO confirm against dataset.get_dataset).
dataset = get_dataset(
    "val",
    mu=torch.zeros(3),
    std=torch.ones(3),
    return_feature=True,
    return_idx=True,
)
# One image per batch, in a fixed (unshuffled) order.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

# Pretrained score-SDE pipeline (VE variant, judging by the `sde_ve` name and the
# sigma_min / sigma_max scheduler config). Only its scheduler config and its UNet
# are used below.
model_id = "google/ncsnpp-celebahq-256"
sde_ve = DiffusionPipeline.from_pretrained(model_id)
scheduler = sde_ve.scheduler
unet = sde_ve.unet

# RePaint resampling schedule. Argument semantics live in repaint_lib.get_schedule;
# the leading 1 is presumably the batch size (TODO confirm).
sigma_min = scheduler.config.sigma_min
sigma_max = scheduler.config.sigma_max
repaint_steps = 2000
repaint_jump = 10
repaint_n_samples = 7
repaint_schedule = repaint_lib.get_schedule(
    1, repaint_steps, repaint_jump, repaint_n_samples,
    eps=scheduler.config.sampling_eps,
)

# Geometric VE noise scale: sigma(t) = sigma_min * (sigma_max / sigma_min) ** t,
# evaluated at every schedule entry and moved to the sampling device.
sigmas = (sigma_min * (sigma_max / sigma_min) ** repaint_schedule).to(device)

# Visual check of the noise schedule. Because of RePaint's jumps back, sigma is
# expected to rise again at points (the sampling loop treats any non-decrease
# as a re-noising step).
_, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
ax.plot(sigmas.cpu())
ax.set_title("RePaint noise schedule")
ax.set_ylabel(r"$\sigma$")
ax.set_xlabel("Step")
plt.show()

In [None]:
# Inference only: the UNet goes to the target device in eval mode (nn.Module.eval
# returns the module itself). Autograd is switched off globally, which stays in
# effect for every later cell.
unet = unet.to(device).eval()
torch.set_grad_enabled(False)

# First validation example; the loader is unshuffled, so this is always the same item.
data = next(iter(dataloader))
idx, x0, _, feature = data
idx = idx.item()  # batch_size=1, so this is a single Python scalar

# Mask derived from the feature subset `_s` (semantics in repaint_lib.get_mask).
# x = m * x0 zeroes every pixel where m == 0.
_s = [0, 1]
m = repaint_lib.get_mask(feature, _s)
x = m * x0

# Side-by-side preview: the clean image, and the known region handed to RePaint
# (m * x0; the sampler keeps pixels where m == 1 and inpaints the rest).
_, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 2))
for ax, img, title in zip(axes, (x0, x), ("Original", "Known region (m * x0)")):
    ax.imshow(img.squeeze().permute(1, 2, 0))
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
m = m.to(device)
x0 = x0.to(device)

# Start the reverse process from pure noise at the first (largest) schedule level:
# x = sigma0 * z with z ~ N(0, I). sigma0 holds one value per batch element and is
# reshaped to broadcast over (C, H, W).
sigma0 = sigmas[0]
z = torch.randn_like(x0)
x = sigma0[:, None, None, None] * z

# Snapshots go to ./test; create it so a fresh checkout does not fail in savefig.
os.makedirs("test", exist_ok=True)

fig, ax = plt.subplots(figsize=(5, 5))
ax.axis("off")
# imshow would clip float RGB to [0, 1] anyway, but it warns each time; clamp explicitly.
ax.imshow(x.cpu().squeeze().permute(1, 2, 0).clamp(0, 1))
fig.savefig(os.path.join("test", "0.jpg"), bbox_inches="tight")
plt.close(fig)


def _reverse_step(x, sigma, sigma_next):
    # add noise to known part
    z = torch.randn_like(x)
    x_known = x0 + sigma_next[:, None, None, None] * z

    # denoise unknown part
    model_output = unet(x, sigma.repeat(x.size(0))).sample

    diffusion = torch.sqrt(sigma**2 - sigma_next**2)
    diffusion = diffusion.flatten()
    drift = -(diffusion**2) * model_output

    z = torch.randn_like(x)
    x_unknown = x - drift + diffusion * z

    x = m * x_known + (1 - m) * x_unknown
    return x


def _forward_step(x, sigma, sigma_next):
    z = torch.randn_like(x)
    diffusion = torch.sqrt(sigma_next**2 - sigma**2)
    x = x + diffusion[:, None, None, None] * z
    return x


# Walk the RePaint schedule. A decrease in sigma is a denoising step; an increase
# (or no change) is a re-noising jump back.
SNAPSHOT_EVERY = 100  # save an image of the current sample every N steps
os.makedirs("test", exist_ok=True)  # savefig fails if the directory is missing

for i, sigma in enumerate(tqdm(sigmas[:-1])):
    sigma_next = sigmas[i + 1]
    if sigma_next < sigma:
        x = _reverse_step(x, sigma, sigma_next)
    else:
        x = _forward_step(x, sigma, sigma_next)

    if (i + 1) % SNAPSHOT_EVERY == 0:
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.axis("off")
        # Noisy samples fall outside [0, 1]; clamp instead of letting imshow clip and warn.
        ax.imshow(x.cpu().squeeze().permute(1, 2, 0).clamp(0, 1))
        fig.savefig(os.path.join("test", f"{i+1}.jpg"), bbox_inches="tight")
        plt.close(fig)  # close each figure to avoid accumulating them over the loop

In [None]:
# Gallery of saved RePaint runs. For each example directory 0..5 under
# `repaint_dir`, each grid row is one sample: original | masked input |
# checkpoints | final.
# NOTE(review): `repaint_dir` is not defined anywhere in this notebook. Define it
# (the root of the saved runs) in a config cell, or this cell raises NameError
# on a fresh kernel.
# Loop variables use fresh names so this cell does not clobber the earlier
# globals `idx`, `x0`, `m`, `x` and `data`.
s = "23"
steps = range(500, int(5e03), 500)  # checkpoint step numbers: 500, 1000, ..., 4500
n_rows = 4  # samples (grid rows) shown per example
for sample_idx in range(6):
    idx_dir = os.path.join(repaint_dir, f"{sample_idx}")
    s_dir = os.path.join(idx_dir, s)
    checkpoint_dir = os.path.join(s_dir, "checkpoints")

    # map_location="cpu": tensors saved from a GPU run must be on the CPU for
    # make_grid / imshow. torch.load unpickles, so only load trusted files.
    original = torch.load(os.path.join(idx_dir, "original.pt"), map_location="cpu")
    mask = torch.load(os.path.join(s_dir, "mask.pt"), map_location="cpu")
    masked = mask * original

    checkpoints = torch.stack(
        [
            torch.load(os.path.join(checkpoint_dir, f"{c}.pt"), map_location="cpu")
            for c in steps
        ],
        dim=1,
    )
    repainted = torch.load(os.path.join(s_dir, "repainted.pt"), map_location="cpu")

    # Tile the single original / masked image across all repainted samples, then
    # add a "column" axis so every panel kind can be concatenated along dim 1.
    n_samples = repainted.size(0)
    original = original.repeat(n_samples, 1, 1, 1).unsqueeze(1)
    masked = masked.repeat(n_samples, 1, 1, 1).unsqueeze(1)
    repainted = repainted.unsqueeze(1)

    panels = torch.cat([original, masked, checkpoints, repainted], dim=1)
    panels = torch.clamp(panels, 0, 1)
    n_cols = panels.size(1)
    panels = panels.flatten(0, 1)[: n_rows * n_cols]

    _, ax = plt.subplots(figsize=(16, 9))
    im = make_grid(panels, nrow=n_cols)
    ax.imshow(im.permute(1, 2, 0))
    ax.set_title(f"idx={sample_idx}, s={s}")
    ax.axis("off")
    plt.show()