In [None]:
import os
import argparse

import numpy as np
import torch as th
import torch.nn.functional as F
import time
from RePaint import conf_mgt
from RePaint.utils import yamlread
from RePaint.guided_diffusion import dist_util

In [None]:
# Work from the parent of the directory the notebook was started in, so the
# relative paths used later resolve from there. The starting directory is
# recorded only the first time this cell runs, which makes the cell idempotent:
# re-running it no longer climbs one more directory level each time.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
print(NOTEBOOK_START_DIR)
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))
print(os.getcwd())

In [None]:
from RePaint.guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    select_args,
)  # noqa: E402

In [None]:
def toU8(sample):
    """Turn a channels-first tensor in [-1, 1] into a channels-last uint8 array.

    Values are mapped with ``(v + 1) * 127.5``, clamped to [0, 255] and cast to
    ``uint8``; axes are reordered from (N, C, H, W) to (N, H, W, C).

    Args:
        sample: A 4-D tensor, or ``None``.

    Returns:
        A NumPy ``uint8`` array on the CPU, or ``None`` when ``sample`` is ``None``.
    """
    if sample is not None:
        scaled = th.clamp((sample + 1) * 127.5, 0, 255)
        as_bytes = scaled.to(th.uint8)
        channels_last = as_bytes.permute(0, 2, 3, 1).contiguous()
        sample = channels_last.detach().cpu().numpy()
    return sample


## conf

In [None]:
# Start from the default configuration and overlay the settings read from the
# YAML file. The directory name is spelled "RePaint", matching the package name
# used by this notebook's imports; the previous "Repaint" spelling only
# resolves on case-insensitive file systems.
conf_path = "./RePaint/confs/test.yml"
conf = conf_mgt.conf_base.Default_Conf()
conf.update(yamlread(conf_path))

## main

In [None]:
# Resolve the compute device through dist_util, passing the 'device' entry
# read from conf, and show which one was picked.
device = dist_util.dev(conf.get('device'))
print(device)

In [None]:
# Build the network and the diffusion object from the conf entries whose names
# appear among the model/diffusion default keys.
model_diffusion_args = select_args(conf, model_and_diffusion_defaults().keys())
model, diffusion = create_model_and_diffusion(**model_diffusion_args, conf=conf)

# Read the checkpoint with map_location="cpu", then move the model to the device.
checkpoint_path = os.path.expanduser(conf.model_path)
model.load_state_dict(
    dist_util.load_state_dict(checkpoint_path, map_location="cpu")
)
model.to(device)

# Optional half-precision conversion, then switch to evaluation mode.
print(conf.use_fp16)
if conf.use_fp16:
    model.convert_to_fp16()
model.eval()

show_progress = conf.show_progress

In [None]:
def model_fn(x, t, y=None, gt=None, **kwargs):
    """Call the global ``model``, forwarding the label only for class-conditional runs.

    Args:
        x: Input passed to the model as first positional argument.
        t: Timestep argument passed as second positional argument.
        y: Labels; must not be ``None``. Forwarded only when
            ``conf.class_cond`` is truthy, otherwise ``None`` is passed instead.
        gt: Forwarded to the model as the ``gt`` keyword.
        **kwargs: Accepted and ignored.

    Returns:
        Whatever ``model`` returns.

    Raises:
        AssertionError: If ``y`` is ``None``.
    """
    assert y is not None
    labels = y if conf.class_cond else None
    return model(x, t, labels, gt=gt)


In [None]:
# Use the 'eval' split together with the default evaluation-set name from conf.
dset = 'eval'
eval_name = conf.get_default_eval_name()
print(f"eval_name = {eval_name}")

# Data loader for that split / evaluation set, as provided by conf.
dl = conf.get_dataloader(dset=dset, dsName=eval_name)

### one loop

In [None]:
# Pull a single batch from the loader and list the entries it provides.
batch = next(iter(dl))
print(batch.keys())

In [None]:
# Shape and current device of the batch's 'GT' tensor.
print(batch['GT'].shape)
print(batch['GT'].device)

In [None]:
from PIL import Image
import numpy as np

def tensor_to_image(tensor: th.Tensor):
    """Render an image tensor with values nominally in [-1, 1] as a PIL image.

    Values are remapped from [-1, 1] to [0, 1] and clipped to that range before
    the uint8 conversion, so out-of-range inputs (e.g. Gaussian noise) saturate
    instead of wrapping around in the integer cast.

    Args:
        tensor: A ``(C, H, W)`` image, an ``(N, C, H, W)`` batch (only the
            first image is rendered), or an ``(H, W)`` single-channel image.

    Returns:
        The image built by ``Image.fromarray`` from a channels-last uint8
        array; a single-channel input is passed as a 2-D array.
    """
    img_arr = tensor.detach().cpu().numpy()
    if img_arr.ndim == 4:
        # Batched input: show the first image rather than failing on N > 1.
        img_arr = img_arr[0]
    img_arr = np.clip(img_arr * 0.5 + 0.5, 0.0, 1.0)  # remap to 0 to 1
    if img_arr.ndim == 3:
        img_arr = img_arr.transpose((1, 2, 0))
        if img_arr.shape[-1] == 1:
            # A trailing channel axis of size 1 is dropped so the array is 2-D.
            img_arr = img_arr[..., 0]
    pil_image = Image.fromarray((img_arr * 255).astype('uint8'))
    return pil_image

# Preview the batch's 'GT' tensor as an image (the cell's displayed value).
tensor_to_image(batch['GT'])

In [None]:
# Transfer every tensor entry of the batch to the target device, replacing the
# entry in place; values of any other type are left untouched.
for k, value in batch.items():
    if isinstance(value, th.Tensor):
        batch[k] = value.to(device)


In [None]:
# Confirm where the 'GT' tensor lives after the transfer.
print(batch['GT'].device)

In [None]:
# Collect the keyword arguments handed to the sampler: the 'GT' tensor, the
# keep-mask when the batch provides one, and the class labels "y".
model_kwargs = {"gt": batch['GT']}

gt_keep_mask = batch.get('gt_keep_mask')
if gt_keep_mask is not None:
    model_kwargs['gt_keep_mask'] = gt_keep_mask

batch_size = model_kwargs["gt"].shape[0]

if conf.cond_y is None:
    # No class configured: draw one random label per batch element.
    print("conf cond_y is None")
    classes = th.randint(0, NUM_CLASSES, (batch_size,), device=device)
    print(classes)
    model_kwargs["y"] = classes
else:
    # A class is configured: use it for every batch element.
    print("conf cond_y is not None")
    classes = th.ones(batch_size, dtype=th.long, device=device)
    model_kwargs["y"] = classes * conf.cond_y
model_kwargs.keys()

In [None]:
# Choose the sampling loop: the DDIM loop when conf.use_ddim is set, otherwise
# the regular p_sample loop.
print(f"use_ddim = {conf.use_ddim}")
if conf.use_ddim:
    sample_fn = diffusion.ddim_sample_loop
else:
    sample_fn = diffusion.p_sample_loop


In [None]:
# Full loop: run the chosen sampling loop once for this batch, with model_fn as
# the model callable, no cond_fn, and the kwargs assembled for this batch.
result = sample_fn(
    model_fn,
    (batch_size, 3, conf.image_size, conf.image_size),  # requested sample shape
    clip_denoised=conf.clip_denoised,
    model_kwargs=model_kwargs,
    cond_fn=None,
    device=device,
    progress=show_progress,
    return_all=True,
    conf=conf
)

In [None]:
# List the entries returned by the sampling loop.
result.keys()

### sample_fn 拆解 (step-by-step breakdown of sample_fn)

In [None]:
# Same shape tuple that was passed to sample_fn for the full loop.
shape = (batch_size, 3, conf.image_size, conf.image_size)
print(shape)

In [None]:
# Start from standard-normal noise of the sampling shape, print its value
# range, and preview it as an image.
image_after_step = th.randn(*shape, device=device)
print(th.min(image_after_step))
print(th.max(image_after_step))
tensor_to_image(image_after_step)

In [None]:
# Nothing has been predicted yet; this value is handed to the first
# diffusion.p_sample call.
pred_xstart = None

In [None]:
from RePaint.guided_diffusion.scheduler import get_schedule_jump
# Build the timestep schedule from conf and pair every entry with its
# successor; zip stops at the shorter sequence, so the last entry only ever
# appears as the second element of a pair.
times = get_schedule_jump(**conf.schedule_jump_params)
time_pairs = list(zip(times, times[1:]))
print(len(time_pairs))

In [None]:
# Take the first (t_last, t_cur) pair of the schedule.
t_last, t_cur = time_pairs[0]
print(t_last, t_cur)

In [None]:
# One copy of t_last per batch element (shape[0] entries), as a tensor on the device.
t_last_t = th.tensor([t_last] * shape[0], device=device)
t_last_t

In [None]:

# One p_sample step at timestep t_last_t, run without gradient tracking.
# model_fn is passed instead of the raw model so this step goes through the
# same wrapper as the full loop: the wrapper forwards "y" only when
# conf.class_cond is set and drops any extra model_kwargs entries.
with th.no_grad():
    out = diffusion.p_sample(
        model_fn,
        image_after_step,
        t_last_t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=model_kwargs,
        conf=conf,
        pred_xstart=pred_xstart
    )


In [None]:
# Preview the "pred_xstart" entry returned by the step.
tensor_to_image(out["pred_xstart"])

In [None]:
# Preview the "sample" entry returned by the step.
tensor_to_image(out['sample'])

In [None]:
# Carry the step's outputs forward as the current state. Note that this
# overwrites image_after_step, so re-running the p_sample cell afterwards
# continues from the new state rather than from the initial noise.
image_after_step = out["sample"]
pred_xstart = out["pred_xstart"]

## 继续分解 (continuing the breakdown)

In [None]:
# Fresh standard-normal noise with the same shape as the current image; it is
# used by the final sampling expression of this notebook.
noise = th.randn_like(image_after_step)
tensor_to_image(noise)

In [None]:
# Fetch the keep-mask back from model_kwargs. It is only present when the batch
# supplied one, so this is None otherwise and the preview call would then fail.
gt_keep_mask = model_kwargs.get('gt_keep_mask')
tensor_to_image(gt_keep_mask)

In [None]:
# The 'GT' tensor stored in model_kwargs, previewed as an image.
gt = model_kwargs['gt']
tensor_to_image(gt)

In [None]:
from RePaint.guided_diffusion.gaussian_diffusion import _extract_into_tensor
# Look up the cumulative alpha product for the current timesteps, extracted
# against the shape of the current image.
print(diffusion.alphas_cumprod.shape)
print(t_last_t)
alpha_cumprod = _extract_into_tensor(diffusion.alphas_cumprod, t_last_t, image_after_step.shape)
print(th.min(alpha_cumprod))
print(th.max(alpha_cumprod))
# Kept as the cell's final expression so the notebook actually renders the
# image; in the middle of the cell its result was silently discarded.
tensor_to_image(alpha_cumprod)

In [None]:
# Produce a noised version of gt for the current timestep, either through the
# diffusion object's own helper or by mixing gt with fresh Gaussian noise.
print(conf.inpa_inj_sched_prev_cumnoise)
if not conf.inpa_inj_sched_prev_cumnoise:
    # sqrt(alpha_cumprod) * gt + sqrt(1 - alpha_cumprod) * noise
    gt_weight = alpha_cumprod.sqrt()
    print(f"gt_weight = {th.min(gt_weight)}")
    gt_part = gt_weight * gt

    noise_weight = (1 - alpha_cumprod).sqrt()
    print(f"noise_weight = {th.min(noise_weight)}")
    noise_part = noise_weight * th.randn_like(image_after_step)

    weighed_gt = gt_part + noise_part
else:
    # The helper takes a plain int timestep; the first batch entry is used.
    weighed_gt = diffusion.get_gt_noised(gt, int(t_last_t[0].item()))
tensor_to_image(weighed_gt)


In [None]:

# Blend by mask: the noised gt is weighted by gt_keep_mask and the current
# image by its complement.
x = gt_keep_mask * weighed_gt + (1 - gt_keep_mask) * image_after_step
tensor_to_image(x)

In [None]:
# Evaluate the model on the blended input to obtain the step's statistics.
# - model_fn is used instead of the raw model so the call goes through the same
#   wrapper as the full sampling loop ("y" is forwarded only when
#   conf.class_cond is set).
# - th.no_grad() matches the p_sample cell; nothing here needs gradients.
# NOTE(review): clip_denoised is None here while the p_sample cell passes True;
# confirm whether the difference is intended.
with th.no_grad():
    out = diffusion.p_mean_variance(
        model_fn,
        x,
        t_last_t,
        clip_denoised=None,
        denoised_fn=None,
        model_kwargs=model_kwargs,
    )
print(out.keys())

In [None]:
# Preview the "pred_xstart" entry of the p_mean_variance output.
tensor_to_image(out['pred_xstart'])

In [None]:
# 1.0 where the timestep is non-zero and 0.0 where it is zero, reshaped to
# (batch, 1, ..., 1) with as many trailing singleton axes as x has after the
# first axis.
nonzero_mask = t_last_t.ne(0).float().view(-1, *([1] * (x.ndim - 1)))


In [None]:
# Draw the sample: the mean plus noise scaled by exp(0.5 * log_variance); the
# noise term is zeroed wherever nonzero_mask is 0 (timestep 0).
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
tensor_to_image(sample)