# Inverse Modeling Exploration

In [None]:
# Imports
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torchvision.transforms import Resize, ToPILImage, GaussianBlur, CenterCrop
from torchvision.io import read_image
from diffusers.pipelines import ScoreSdeVePipeline
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from matplotlib import pyplot as plt
from IPython.display import display
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

from mvp_score_modelling.utils import (
    crop,
    resize,
    tensor_to_PIL,
    plt_img
)
from mvp_score_modelling.pipelines import (
    VeTweedie,
    CustomConditionalScoreVePipeline
)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PRETRAINED = "google/ncsnpp-celebahq-256"
EVAL_IMAGE = "test_celabhq.png"
INFERENCE_STEPS = 1000

# Load the evaluation image, resize/crop it with the project helpers, and
# scale pixel values into [0, 1]. read_image yields uint8 values in [0, 255],
# so the divisor is 255 (dividing by 256 capped the brightest pixel at
# 255/256 ~= 0.996 instead of 1.0).
# NOTE(review): assumes crop/resize keep the [0, 255] value range — confirm.
eval_img = crop(resize(read_image(EVAL_IMAGE))) / 255
eval_img = eval_img.to(DEVICE)
display(tensor_to_PIL(eval_img))


In [None]:
# Metrics
# LPIPS network. normalize=True tells torchmetrics the inputs are in [0, 1]
# (its default, normalize=False, expects [-1, 1]), which matches how
# eval_img is scaled in the setup cell.
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(DEVICE)
def lpips_eval(img, reference=None):
    """Return the LPIPS distance between ``img`` and ``reference``, rounded to 4 places.

    Args:
        img: image batch to score. Presumably (N, 3, H, W) in [0, 1]; TODO confirm.
        reference: reference batch. Defaults to ``eval_img`` with a leading
            batch dimension added.
    """
    if reference is None:
        reference = eval_img.unsqueeze(0)
    # With normalize=True torchmetrics rejects values outside [0, 1]. Model
    # outputs such as the one-step Tweedie estimate can overshoot slightly,
    # so clamp both sides onto the valid range before scoring.
    img = img.to(DEVICE).clamp(0, 1)
    reference = reference.to(DEVICE).clamp(0, 1)
    return round(lpips(reference, img).item(), 4)

# SSIM metric. NOTE(review): data_range is left at its default, so it is
# inferred from the inputs; consider data_range=1.0 for [0, 1] images.
ssim = StructuralSimilarityIndexMeasure().to(DEVICE)
def ssim_eval(img, reference=None):
    """Return the SSIM between ``img`` and ``reference``, rounded to 4 places.

    Args:
        img: image batch to score. Presumably (N, 3, H, W); TODO confirm.
        reference: reference batch. Defaults to ``eval_img`` with a leading
            batch dimension added.
    """
    if reference is None:
        reference = eval_img.unsqueeze(0)
    # Move both tensors so a caller-supplied reference on another device works.
    return round(ssim(reference.to(DEVICE), img.to(DEVICE)).item(), 4)


## Score-Based Generative Modeling Through Stochastic Differential Equations
Yang Song et al. (ICLR 2021)

In [None]:
# Pretrained NCSN++ score-SDE (VE) pipeline for CelebA-HQ 256, placed on DEVICE.
unconditional_pipeline: ScoreSdeVePipeline = (
    ScoreSdeVePipeline
    .from_pretrained(PRETRAINED)
    .to(device=DEVICE)
)
# One-step denoiser that reuses the same score network.
tweedie = VeTweedie(unconditional_pipeline.unet)

# Draw an unconditional sample with the full step budget and show the first image.
images = unconditional_pipeline(
    num_inference_steps=INFERENCE_STEPS,
).images
display(images[0])

In [None]:
INITIAL_START = int(INFERENCE_STEPS * 0.5)

# Perturb the evaluation image with Gaussian noise at a single scheduler sigma.
# NOTE(review): this reads scheduler.sigmas as left by the previous cell's
# sampling run (INFERENCE_STEPS steps); confirm that index
# INFERENCE_STEPS - INITIAL_START matches how the custom pipeline interprets
# INITIAL_START.
sigma = unconditional_pipeline.scheduler.sigmas[INFERENCE_STEPS - INITIAL_START]
z = sigma * torch.randn_like(eval_img)
noised_eval_image = eval_img + z

# Reuse the pretrained score network and scheduler in the custom conditional
# sampler, seeded with the noised image (batch dim added) and INITIAL_START.
pipeline = CustomConditionalScoreVePipeline(
    unconditional_pipeline.unet,
    unconditional_pipeline.scheduler,
)
pipeline = pipeline.to(DEVICE)
images = pipeline(
    num_inference_steps=INFERENCE_STEPS,
    initial_sample=(noised_eval_image.unsqueeze(0), INITIAL_START),
    output_type=None,
)

In [None]:
# Compare the original image, the one-step Tweedie denoise of the noised
# input, and the conditional sample, annotating each with LPIPS / SSIM.
# The Tweedie estimate is computed once and reused: each call runs the score
# network, and the plotted image and both metrics should describe the same
# tensor.
denoised = tweedie(noised_eval_image, sigma)
generated = images[0].unsqueeze(0)

fig, axs = plt.subplots(1, 3)
for ax in axs:
    ax.axis('off')

# Metric labels at y=300 / y=340 are placed below the image; this assumes a
# 256-px tall image as the model name suggests — TODO confirm for other sizes.
axs[0].set_title("Original")
axs[0].text(0, 300, "lpips")
axs[0].text(0, 340, "ssim")
plt_img(axs[0], eval_img)

axs[1].set_title("One-step denoised")
axs[1].text(0, 300, lpips_eval(denoised))
axs[1].text(0, 340, ssim_eval(denoised))
plt_img(axs[1], denoised.squeeze())

axs[2].set_title("Generated")
axs[2].text(0, 300, lpips_eval(generated))
axs[2].text(0, 340, ssim_eval(generated))
plt_img(axs[2], images[0])

In [None]:
# LPIPS of the generated sample against eval_img. Uses the shared helper so
# device handling and rounding match the plot annotations.
lpips_eval(images[0].unsqueeze(0))