In [None]:
import os
from diffusers.pipelines import ScoreSdeVePipeline
import torch
from torchvision.io import read_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from mvp_score_modelling.utils import (crop,resize, plt_img)
from mvp_score_modelling.pipelines.utils import VeTweedie
from matplotlib import pyplot as plt

# Compute device shared by the models, metrics and tensors in this notebook:
# the GPU when torch can see one, otherwise the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Where saved figures are written, relative to the notebook's working directory.
FIGURES_DIRECTORY = '../figures/'

# LPIPS perceptual distance (lower = more similar). normalize=True tells torchmetrics
# that inputs are in [0, 1]; the default (normalize=False) expects [-1, 1] and would
# silently mis-score the [0, 1] images used in this notebook.
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(DEVICE)
def lpips_eval(img, reference):
    """Return the LPIPS distance between ``img`` and ``reference``, rounded to 4 places.

    Args:
        img: batched image tensor to score, values expected in [0, 1]
            (assumed shape (N, 3, H, W) — TODO confirm). Values outside [0, 1]
            are clipped before scoring.
        reference: batched ground-truth image tensor in [0, 1], same shape as ``img``.

    Returns:
        The LPIPS score as a Python float rounded to 4 decimal places.
    """
    # Clip to [0, 1]: torchmetrics validates the LPIPS input range, and un-clipped
    # model outputs (e.g. a denoiser estimate) may overshoot it.
    preds = img.to(DEVICE).clamp(0, 1)
    target = reference.to(DEVICE)
    # Pure evaluation: no autograd graph is needed even if ``img`` requires grad.
    with torch.no_grad():
        score = lpips(preds, target)
    return round(score.item(), 4)

# SSIM (higher = more similar). data_range=1.0 pins the dynamic range of the [0, 1]
# images; the default (None) infers it from each input pair's min/max, which makes
# scores depend on the particular images being compared.
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(DEVICE)
def ssim_eval(img, reference):
    """Return the SSIM between ``img`` and ``reference``, rounded to 4 places.

    Args:
        img: batched image tensor to score, values expected in [0, 1]
            (assumed shape (N, C, H, W) — TODO confirm). Values outside [0, 1]
            are clipped so they agree with ``data_range=1.0``.
        reference: batched ground-truth image tensor in [0, 1], same shape as ``img``.

    Returns:
        The SSIM score as a Python float rounded to 4 decimal places.
    """
    preds = img.to(DEVICE).clamp(0, 1)
    target = reference.to(DEVICE)
    # Pure evaluation: no autograd graph is needed even if ``img`` requires grad.
    with torch.no_grad():
        score = ssim(preds, target)
    return round(score.item(), 4)

In [None]:
PRETRAINED = "google/ncsnpp-celebahq-256"
EVAL_IMAGE = "data/test_celabhq.png"
SEED = 0           # fixes the noise draw below so the figure and metrics are reproducible
SIGMA_INDEX = 800  # index into the scheduler's sigma table; TODO: document why this noise level

# read_image returns uint8 pixels in [0, 255]; dividing by 255 (not 256) maps them
# onto exactly [0, 1]. Assumes crop/resize preserve that value range — TODO confirm.
eval_img = crop(resize(read_image(EVAL_IMAGE))) / 255
eval_img = eval_img.to(DEVICE)

unconditional_pipeline: ScoreSdeVePipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
# Tweedie estimator wrapping the pipeline's score UNet (see mvp_score_modelling.pipelines.utils).
tweedie = VeTweedie(unconditional_pipeline.unet)
sigma = unconditional_pipeline.scheduler.sigmas[SIGMA_INDEX]

# Seed immediately before sampling so the noised image is identical on every run.
torch.manual_seed(SEED)
z = torch.randn_like(eval_img) * sigma
noised_eval_image = eval_img + z


In [None]:
# Tweedie denoising figure: the original image, its noised version (scored after
# clipping to [0, 1]), and the Tweedie estimate. Each panel is annotated with
# LPIPS / SSIM against the original.

LPIPS_TEXT_Y, SSIM_TEXT_Y = 300, 340  # text rows just below the image (assumes ~256 px tall — TODO confirm)

reference = eval_img.unsqueeze(0)
noised_clipped = noised_eval_image.unsqueeze(0).clamp(0, 1)

# Inference only: no_grad avoids building an autograd graph through the UNet.
with torch.no_grad():
    tw = tweedie(noised_eval_image, sigma)

fig, axs = plt.subplots(1, 3)
for ax in axs:
    ax.axis('off')

axs[0].set_title("Original")
axs[0].text(0, LPIPS_TEXT_Y, "lpips")
axs[0].text(0, SSIM_TEXT_Y, "ssim")
plt_img(axs[0], eval_img)

axs[1].set_title(f"Noise Clipped (σ {round(sigma.item(), 3)})")
axs[1].text(0, LPIPS_TEXT_Y, lpips_eval(noised_clipped, reference))
axs[1].text(0, SSIM_TEXT_Y, ssim_eval(noised_clipped, reference))
# NOTE(review): the displayed image is the un-clipped noised tensor, while the scores
# above use the clipped one — confirm plt_img clips/normalizes as intended.
plt_img(axs[1], noised_eval_image)

axs[2].set_title("Tweedie")
axs[2].text(0, LPIPS_TEXT_Y, lpips_eval(tw, reference))
axs[2].text(0, SSIM_TEXT_Y, ssim_eval(tw, reference))
plt_img(axs[2], tw.squeeze())

plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(FIGURES_DIRECTORY, exist_ok=True)
plt.savefig(os.path.join(FIGURES_DIRECTORY, "tweedie.png"), bbox_inches='tight', transparent=True, pad_inches=0)
