In [None]:
# Automatically reload imported modules (e.g. score_inverse) before each cell runs,
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# The imports below (`configs`, `score_inverse`) and the relative checkpoint path
# resolve from the parent of the notebook's directory. Record that directory the
# first time this cell runs so re-executing it does not keep walking upwards.
if "REPO_ROOT" not in globals():
    REPO_ROOT = os.path.abspath("..")
os.chdir(REPO_ROOT)
print("New Working Directory ", os.getcwd())

In [None]:
import torch
import warnings

# Model/checkpoint selection. To use the CIFAR-10 model instead, swap both of:
#   from configs.ve.cifar10_ncsnpp_deep_continuous import get_config
#   ckpt_path = "checkpoints/ve/cifar10_ncsnpp_deep_continuous/checkpoint_12.pth"
from configs.ve.celebahq_256_ncsnpp_continuous import get_config
from score_inverse.models.utils import create_model
from score_inverse.models.ema import ExponentialMovingAverage

ckpt_path = "checkpoints/ve/celebahq_256_ncsnpp_continuous/checkpoint_48.pth"

config = get_config()
config.model.num_scales = 100  # Number of discretisation steps
config.eval.batch_size = 1  # Number of samples per generation

# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from trusted sources.
loaded_state = torch.load(ckpt_path, map_location=config.device)

score_model = create_model(config)
# Still need to load the base model state since non-trainable params aren't covered by EMA.
# strict=False tolerates key mismatches; report them rather than silently leaving
# those entries (e.g. buffers the EMA copy below does not touch) at their initial values.
load_result = score_model.load_state_dict(loaded_state["model"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint/model key mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected "
        f"(e.g. missing={load_result.missing_keys[:3]}, "
        f"unexpected={load_result.unexpected_keys[:3]})"
    )

# Replace trainable model params with EMA params
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
ema.load_state_dict(loaded_state["ema"])
ema.copy_to(score_model.parameters())

# This notebook only runs inference: switch layers such as dropout to evaluation mode.
score_model.eval()

In [None]:
from PIL import Image
import numpy as np


def display_img(im, size):
    """Render a PIL image inline after resizing it to ``size``.

    Nearest-neighbour resampling keeps individual pixels crisp when a small
    image is enlarged for viewing.
    """
    resized = im.resize(size, resample=Image.NEAREST)
    display(resized)


def display_sample(sample, size=(128, 128), index=0):
    """Display one image from a tensor of images.

    Args:
        sample: image tensor laid out as (batch, channels, height, width), or a
            single unbatched (channels, height, width) image. Values are
            expected in [0, 1]; they are scaled to [0, 255] and clipped.
        size: target size passed to ``PIL.Image.resize`` (width, height).
        index: which element of the batch to show (default: the first).
    """
    if sample.dim() == 3:
        # Accept an unbatched CHW image, e.g. straight from the dataset.
        sample = sample.unsqueeze(0)
    # detach() so tensors attached to an autograd graph can still be converted
    # to numpy; only the selected image is moved to the CPU.
    img = sample[index].detach().permute(1, 2, 0).cpu().numpy()
    img = np.clip(img * 255.0, 0, 255).astype(np.uint8)
    if img.shape[-1] == 1:
        # PIL expects a 2-D array for single-channel (grayscale) images.
        img = img[:, :, 0]

    display_img(Image.fromarray(img), size)

In [None]:
from score_inverse.tasks.deblur import DeblurTask
from score_inverse.datasets import CelebA, CIFAR10

# Blur-kernel settings handed to DeblurTask.
DEBLUR_KERNEL_TYPE = "gaussian"
DEBLUR_KERNEL_SIZE = 5

# Images are loaded at the resolution configured in config.data.image_size.
dataset = CelebA(img_size=config.data.image_size)
inverse_task = DeblurTask(
    dataset.img_size,
    kernel_type=DEBLUR_KERNEL_TYPE,
    kernel_size=DEBLUR_KERNEL_SIZE,
)
inverse_task = inverse_task.to(config.device)

In [None]:
# Take one ground-truth image, add a leading batch dimension, and apply the
# task's operator `A` to obtain the blurred observation used as the measurement.
x = dataset[1][None].to(config.device)
y = inverse_task.A(x)

for shown in (x, y):
    display_sample(shown)

In [None]:
from score_inverse.sde import get_sde
from score_inverse.datasets.scalers import get_data_inverse_scaler, get_data_scaler
from score_inverse.sampling import get_corrector, get_predictor
from score_inverse.sampling.inverse import get_pc_inverse_solver

# Weight forwarded to the inverse solver as `lambda_`; its exact role is defined
# in score_inverse.sampling.inverse.
INVERSE_LAMBDA = 0.1

# NOTE(review): `scaler` is not used anywhere below — confirm whether the
# measurement should be scaled before it is handed to the solver.
scaler = get_data_scaler(config)

inverse_scaler = get_data_inverse_scaler(config)
sde, sampling_eps = get_sde(config)

sampling_cfg = config.sampling
sampling_shape = (config.eval.batch_size, *dataset.img_size)
predictor = get_predictor(sampling_cfg.predictor.lower())
corrector = get_corrector(sampling_cfg.corrector.lower())

sampling_fn = get_pc_inverse_solver(
    sde=sde,
    shape=sampling_shape,
    predictor=predictor,
    corrector=corrector,
    inverse_scaler=inverse_scaler,
    snr=sampling_cfg.snr,
    n_steps=sampling_cfg.n_steps_each,
    probability_flow=sampling_cfg.probability_flow,
    continuous=config.training.continuous,
    denoise=sampling_cfg.noise_removal,
    eps=sampling_eps,
    device=config.device,
    inverse_task=inverse_task,
    lambda_=INVERSE_LAMBDA,
)

In [None]:
# with torch.autocast("cuda", torch.float16):
# Run the predictor-corrector inverse solver conditioned on the measurement y.
# To experiment with mixed precision, wrap the call in
# `with torch.autocast("cuda", torch.float16):`.
# NOTE(review): y is passed without applying `scaler`; confirm the solver
# expects measurements in that space. `n` is presumably a step/evaluation count.
measurement = y.to(config.device)
sample, n = sampling_fn(score_model, measurement)

In [None]:
# Shown top to bottom: blurred measurement, reconstruction, ground truth.
for shown in (y, sample, x):
    display_sample(shown)