# Inverse Modeling Exploration

In [None]:
# Imports
from typing import List

import torch
import torch.nn as nn
from torchvision.io import read_image
from diffusers.pipelines import ScoreSdeVePipeline
from matplotlib import pyplot as plt
from IPython.display import display
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

from mvp_score_modelling.utils import (
    crop,
    resize,
    tensor_to_PIL,
    plt_img
)
from mvp_score_modelling.pipelines.utils import (
    VeTweedie,
    CustomConditionalScoreVePipeline,
)
from mvp_score_modelling.pipelines.superresolution import (
    PseudoinverseGuidedSuperResolutionPipeline,
    SuperResolutionProjectionPipeline,
    PrYtGuidedSuperResolutionPipeline
)
from mvp_score_modelling.pipelines.colorisation import (
    greyscale,
    ColorisationProjectionPipeline,
    PrYtGuidedColorisationPipeline
)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PRETRAINED = "google/ncsnpp-celebahq-256"
EVAL_IMAGE = "data/test_celabhq.png"
INFERENCE_STEPS = 1000

# read_image yields uint8 pixels in [0, 255]; divide by 255 (not 256) so a
# full-intensity pixel maps to exactly 1.0 and the image spans [0, 1].
eval_img = crop(resize(read_image(EVAL_IMAGE))) / 255
eval_img = eval_img.to(DEVICE)
display(tensor_to_PIL(eval_img))

In [None]:
# Metrics
# eval_img is divided by a 255/256 scale above, so its values lie in [0, 1]; the
# images compared against it are presumably in the same range (TODO confirm for
# pipeline outputs). torchmetrics' LPIPS assumes inputs in [-1, 1] unless
# normalize=True, so enable it. SSIM is given the data range explicitly instead of
# inferring it per call, which would make scores incomparable across images.
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(DEVICE)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(DEVICE)


def _metric_eval(metric, img, reference=None):
    """Return ``metric(reference, img)`` as a float rounded to 4 decimal places.

    ``reference`` defaults to ``eval_img`` with a leading batch dimension added.
    Both tensors are moved to DEVICE before evaluation; both are expected to be
    batched images.
    """
    if reference is None:
        reference = eval_img.unsqueeze(0)
    return round(metric(reference.to(DEVICE), img.to(DEVICE)).item(), 4)


def lpips_eval(img, reference=None):
    """LPIPS distance between ``img`` and ``reference`` (default: eval_img); lower is more similar."""
    return _metric_eval(lpips, img, reference)


def ssim_eval(img, reference=None):
    """SSIM between ``img`` and ``reference`` (default: eval_img); higher is more similar."""
    return _metric_eval(ssim, img, reference)

# Utils
def run_pipelines(y, pipelines, unet, scheduler):
    """Run every pipeline class in ``pipelines`` on the same measurement ``y``.

    Each entry of ``pipelines`` is a class that is constructed as
    ``cls(unet, scheduler)``. The resulting instance is called with ``y``,
    ``num_inference_steps=INFERENCE_STEPS`` and ``output_type=None``, and only
    the first element of its output is kept.

    Returns:
        A list holding one result per pipeline class, in the same order.
    """
    return [
        pipeline_cls(unet, scheduler)(
            y,
            num_inference_steps=INFERENCE_STEPS,
            output_type=None,
        )[0]
        for pipeline_cls in pipelines
    ]

def show_pipeline_results(
        reference,
        y,
        images: List[torch.Tensor],
        measurement_function,
        denoiser=None
    ):
    """Plot a reference image next to each reconstruction, annotated with its measurement error.

    The first panel shows ``reference`` labelled "y error". Each following panel
    shows one entry of ``images`` (optionally passed through ``denoiser`` first).
    Beneath it is the mean squared error between ``measurement_function(img)``
    and the observed measurement ``y``, i.e. how consistent that reconstruction is
    with the measurement.

    Args:
        reference: image drawn in the first panel.
        y: observed measurement the reconstructions are compared against.
        images: reconstructions, one panel each.
        measurement_function: forward operator mapping an image to measurement space.
        denoiser: optional callable applied to each image before measuring and plotting.
    """
    # squeeze=False keeps axs 2-D even for a single panel (empty ``images``),
    # where plt.subplots would otherwise return a bare, non-iterable Axes.
    _, axs = plt.subplots(1, len(images) + 1, squeeze=False)
    axs = axs[0]
    for ax in axs:
        ax.axis('off')
    plt_img(axs[0], reference)
    # y=300 puts the text below the image; assumes ~256-pixel-tall images (TODO confirm).
    axs[0].text(0, 300, "y error")
    for ax, img in zip(axs[1:], images):
        if denoiser is not None:
            img = denoiser(img)
        y_s = measurement_function(img)
        mse = torch.mean((y_s - y) ** 2).item()
        plt_img(ax, img.squeeze())
        ax.text(0, 300, f"{mse:.3g}")

## Score-Based Generative Modeling Through Stochastic Differential Equations
Yang Song et al., ICLR 2021

In [None]:
unconditional_pipeline: ScoreSdeVePipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
tweedie = VeTweedie(unconditional_pipeline.unet)
# Draw one unconditional sample as a sanity check of the pretrained model.
images = unconditional_pipeline(num_inference_steps=INFERENCE_STEPS).images
display(images[0])


def denoiser(x):
    """Apply ``tweedie`` to ``x`` at the last entry of the scheduler's sigma schedule.

    For diffusers' ScoreSdeVeScheduler this last entry is presumably the smallest
    noise level (verify). The sigma is read at call time, so it reflects whatever
    schedule the scheduler currently holds.
    """
    return tweedie(x, unconditional_pipeline.scheduler.sigmas[-1])

In [None]:
# Start reverse diffusion from a noised copy of eval_img rather than from pure noise.
# INITIAL_START is half the step count. How CustomConditionalScoreVePipeline
# interprets it (steps skipped vs. steps remaining) is not visible here — TODO confirm.
INITIAL_START = int(INFERENCE_STEPS * 0.5)

# Noise level used to corrupt eval_img, read from the scheduler's sigma schedule
# (presumably populated by the unconditional pipeline call in the previous cell).
# NOTE(review): this indexes the schedule at INFERENCE_STEPS - INITIAL_START; with the
# 0.5 ratio that equals INITIAL_START, so both possible conventions coincide. Verify
# the index if the ratio is changed.
sigma = unconditional_pipeline.scheduler.sigmas[INFERENCE_STEPS-INITIAL_START]
z = torch.randn_like(eval_img) * sigma

noised_eval_image = eval_img + z
# Reuse the pretrained UNet and scheduler inside the project's conditional pipeline.
pipeline = CustomConditionalScoreVePipeline(
    unconditional_pipeline.unet,
    unconditional_pipeline.scheduler
).to(DEVICE)
# initial_sample is a (batched starting image, starting step) pair. output_type=None
# presumably yields raw tensors rather than PIL images (they are used as tensors
# in the next cell) — confirm.
images = pipeline(
    num_inference_steps=INFERENCE_STEPS,
    initial_sample=(noised_eval_image.unsqueeze(0), INITIAL_START),
    output_type = None
)

In [None]:
# Compare the original image, a one-step tweedie denoise of the noised image and the
# conditional pipeline's sample, each annotated with LPIPS and SSIM against eval_img.
# tweedie runs the UNet, so compute its output once and reuse it for both metrics
# and the plot instead of calling it three times.
one_step_denoised = tweedie(noised_eval_image, sigma)
generated = images[0].unsqueeze(0)

fig, axs = plt.subplots(1, 3)
for ax in axs:
    ax.axis('off')

axs[0].set_title("Original")
axs[0].text(0, 300, "lpips")
axs[0].text(0, 340, "ssim")
plt_img(axs[0], eval_img)

axs[1].set_title("One-step denoised")
axs[1].text(0, 300, lpips_eval(one_step_denoised))
axs[1].text(0, 340, ssim_eval(one_step_denoised))
plt_img(axs[1], one_step_denoised.squeeze())

axs[2].set_title("Generated")
axs[2].text(0, 300, lpips_eval(generated))
axs[2].text(0, 340, ssim_eval(generated))
plt_img(axs[2], images[0])

# Super-resolution

In [None]:
# Measurement operator: non-overlapping average pooling that shrinks each spatial
# dimension by a factor of KERNEL_SIZE.
KERNEL_SIZE = 16
pool = nn.AvgPool2d(kernel_size=KERNEL_SIZE, stride=KERNEL_SIZE)
downsampled_img = pool(eval_img).to(device=DEVICE)

# Candidate super-resolution methods, each run on the same low-resolution input.
pipelines = [
    SuperResolutionProjectionPipeline,
    PrYtGuidedSuperResolutionPipeline,
    PseudoinverseGuidedSuperResolutionPipeline,
]

# These pipelines take the measurement together with the downsampling factor.
images = run_pipelines(
    (downsampled_img, KERNEL_SIZE),
    pipelines,
    unconditional_pipeline.unet,
    unconditional_pipeline.scheduler,
)

# The reference panel is the low-resolution input blown back up to full size.
upscaled_reference = SuperResolutionProjectionPipeline.upscale(downsampled_img, KERNEL_SIZE)
show_pipeline_results(upscaled_reference, downsampled_img, images, pool)

# Colorisation

In [None]:
# Colorisation: the measurement is the greyscale image, and `greyscale` itself is
# the forward operator used to score how consistent each recolouring is with it.
grey = greyscale(eval_img)

pipelines = [ColorisationProjectionPipeline, PrYtGuidedColorisationPipeline]
images = run_pipelines(
    y=grey,
    pipelines=pipelines,
    unet=unconditional_pipeline.unet,
    scheduler=unconditional_pipeline.scheduler,
)
show_pipeline_results(
    reference=grey.squeeze(),
    y=grey,
    images=images,
    measurement_function=greyscale,
)