In [None]:
import os
import pickle

from datasets import load_dataset, Dataset, DatasetDict
from diffusers.pipelines import ScoreSdeVePipeline
import torch
from torchvision.io import read_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from mvp_score_modelling.utils import (crop,resize, plt_img)
from mvp_score_modelling.pipelines.utils import VeTweedie
from matplotlib import pyplot as plt
from PIL.Image import Image
from mvp_score_modelling.pipelines.inpainting import MaskGenerator
from mvp_score_modelling.utils import tensor_to_PIL, plt_img

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
FIGURES_DIRECTORY = '../figures/'

# Images in this notebook are scaled to [0, 1] (uint8 / 256). torchmetrics' LPIPS assumes
# [-1, 1] unless normalize=True, so without it the scores are computed on mis-scaled input.
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(DEVICE)
def lpips_eval(img, reference):
    """LPIPS distance between `img` and `reference`, rounded to 4 decimals (lower is better).

    Both tensors are moved to DEVICE and clamped to [0, 1]: with normalize=True torchmetrics
    rejects values outside that range, and reconstructions / denoised estimates can overshoot.
    torchmetrics' LPIPS requires batched (N, 3, H, W) input, so callers pass `.unsqueeze(0)`.
    """
    img = img.to(DEVICE).clamp(0, 1)
    reference = reference.to(DEVICE).clamp(0, 1)
    return round(lpips(reference, img).item(), 4)

# Pin the dynamic range to the [0, 1] image scale. The default (data_range=None) infers it
# from each pair's min/max, which makes SSIM scores incomparable across methods.
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(DEVICE)
def ssim_eval(img, reference):
    """SSIM between `img` and `reference` (moved to DEVICE), rounded to 4 decimals (higher is better)."""
    return round(ssim(reference, img.to(DEVICE)).item(), 4)

In [None]:
# PRETRAINED = "google/ncsnpp-celebahq-256"
PRETRAINED = "google/ncsnpp-church-256"
# EVAL_IMAGE = "data/test_celabhq.png"
# NOTE: data/church/*.png are written by the "Loading ground truth" cell below; on a fresh
# checkout run that cell once before this one.
EVAL_IMAGE = "data/church/1.png"
# Index into scheduler.sigmas selecting the noise level used for the Tweedie figure.
SIGMA_INDEX = 800
# Seed for the noise draw so the Tweedie figure is reproducible across runs.
NOISE_SEED = 0

# uint8 -> float in [0, 1); NOTE(review): /255 would map white to exactly 1.0.
eval_img = crop(resize(read_image(EVAL_IMAGE))) / 256
eval_img = eval_img.to(DEVICE)

unconditional_pipeline: ScoreSdeVePipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
tweedie = VeTweedie(unconditional_pipeline.unet)
sigma = unconditional_pipeline.scheduler.sigmas[SIGMA_INDEX]
torch.manual_seed(NOISE_SEED)
z = torch.randn_like(eval_img) * sigma
noised_eval_image = eval_img + z


# Generate Qualitative Sample Images

In [None]:
# Loading ground truth: export a few fixed dataset images as PNGs into STORE_LOCATION so that
# later cells can read them back with read_image.

# DATASET_SOURCE = "Ryan-sjtu/celebahq-caption"
DATASET_SOURCE = "tglcourse/lsun_church_train"
#STORE_LOCATION = 'data/celab'
STORE_LOCATION = 'data/church'

dataset: Dataset = load_dataset(DATASET_SOURCE)['train']

# PIL's save() does not create parent directories; make sure the target exists.
os.makedirs(STORE_LOCATION, exist_ok=True)

idxs = [1,10,100,1000,10000]
for i in idxs:
    image: Image = dataset[i]['image']
    image.save(os.path.join(STORE_LOCATION, f"{i}.png"))

In [None]:
# Utils
def run_pipelines(y, pipelines, unet, scheduler, inference_steps = 1000):
    """Run several conditional sampling pipelines on the same measurement.

    Args:
        y: Measurement forwarded unchanged as the first positional argument of every
            pipeline call (in this notebook e.g. an ``(image, mask)`` tuple).
        pipelines: Iterable of pipeline classes; each is instantiated as ``cls(unet, scheduler)``.
        unet: UNet model shared by all pipelines.
        scheduler: Scheduler object shared by all pipelines (the same instance is reused).
        inference_steps: Forwarded to each call as ``num_inference_steps``.

    Returns:
        A list holding element ``[0]`` of each pipeline's output (called with
        ``output_type=None``), in the same order as ``pipelines``.
    """
    return [
        pipeline_cls(unet, scheduler)(
            y,
            num_inference_steps=inference_steps,
            output_type=None,
        )[0]
        for pipeline_cls in pipelines
    ]

In [None]:
# Random mask
from mvp_score_modelling.pipelines.inpainting import (
    PrYtGuidedInpaintingPipeline,
    InpaintingProjectionPipeline,
    ManifoldConstrainedGradientInpaintingPipeline,
    PseudoinverseGuidedInpaintingPipeline
)

# Inpainting methods, in the column order used by the summary figure.
pipelines = [
    InpaintingProjectionPipeline,
    PrYtGuidedInpaintingPipeline,
    PseudoinverseGuidedInpaintingPipeline,
    ManifoldConstrainedGradientInpaintingPipeline,
]
# Reset torch's global seed before any mask generation / sampling in this cell.
torch.manual_seed(0)

mask_gen = MaskGenerator((1,3,256,256), device=DEVICE)
random_mask = mask_gen.generate_random_mask()

# random_mask_image = read_image('data/celab/1.png')
random_mask_image = (crop(resize(read_image('data/church/1.png'))) / 256).to(DEVICE)

random_mask_results = run_pipelines(
    y=(random_mask_image.unsqueeze(0), random_mask),
    pipelines=pipelines,
    unet=unconditional_pipeline.unet,
    scheduler=unconditional_pipeline.scheduler,
    inference_steps=1000,
)

In [None]:
# Persist the random-mask reconstructions so figures can be rebuilt without re-running the
# 1000-step samplers. To reload, open the same path in 'rb' mode and call pickle.load
# (only on files you produced yourself: unpickling can execute arbitrary code).
os.makedirs("../output/test_images", exist_ok=True)
with open("../output/test_images/church_random_mask_results.pkl", 'wb') as f:
    pickle.dump(random_mask_results, f)

In [None]:
# Show each random-mask reconstruction as a PIL image (pipeline order).
display(*map(tensor_to_PIL, random_mask_results))

In [None]:
# Box Mask: inpaint a square region of size 80 on image 1000.
# Reset torch's global seed before any mask generation / sampling in this cell.
torch.manual_seed(0)
pipelines = [
    InpaintingProjectionPipeline,
    PrYtGuidedInpaintingPipeline,
    PseudoinverseGuidedInpaintingPipeline,
    ManifoldConstrainedGradientInpaintingPipeline,
]

box_mask = mask_gen.generate_box_mask(size=80)

box_mask_image = (crop(resize(read_image('data/church/1000.png'))) / 256).to(DEVICE)

box_mask_results = run_pipelines(
    y=(box_mask_image.unsqueeze(0), box_mask),
    pipelines=pipelines,
    unet=unconditional_pipeline.unet,
    scheduler=unconditional_pipeline.scheduler,
    inference_steps=1000,
)

In [None]:
# Persist the box-mask reconstructions. To reload, open the same path in 'rb' mode and call
# pickle.load (only on files you produced yourself: unpickling can execute arbitrary code).
os.makedirs("../output/test_images", exist_ok=True)
with open("../output/test_images/church_box_mask_results.pkl", 'wb') as f:
    pickle.dump(box_mask_results, f)

In [None]:
# Inverted box mask: the measurement is x * mask, so using 1 - mask swaps the observed and
# missing regions relative to a plain box mask. Applied to image 10000.
# Reset torch's global seed before any mask generation / sampling in this cell.
torch.manual_seed(0)
pipelines = [
    InpaintingProjectionPipeline,
    PrYtGuidedInpaintingPipeline,
    PseudoinverseGuidedInpaintingPipeline,
    ManifoldConstrainedGradientInpaintingPipeline,
]

box_mask_2 = 1 - mask_gen.generate_box_mask(size=80)
box_mask_image_2 = (crop(resize(read_image('data/church/10000.png'))) / 256).to(DEVICE)

box_mask_results_2 = run_pipelines(
    y=(box_mask_image_2.unsqueeze(0), box_mask_2),
    pipelines=pipelines,
    unet=unconditional_pipeline.unet,
    scheduler=unconditional_pipeline.scheduler,
    inference_steps=1000,
)

In [None]:
# Preview the measurement y for the inverted box mask (only the unmasked pixels survive).
tensor_to_PIL(box_mask_image_2 * box_mask_2)

In [None]:
# Show each inverted-box-mask reconstruction as a PIL image (pipeline order).
display(*map(tensor_to_PIL, box_mask_results_2))

In [None]:
# Persist the inverted-box-mask reconstructions. To reload, open the same path in 'rb' mode
# and call pickle.load (only on files you produced yourself: unpickling can execute code).
os.makedirs("../output/test_images", exist_ok=True)
with open("../output/test_images/church_box_mask_results_2.pkl", 'wb') as f:
    pickle.dump(box_mask_results_2, f)

In [None]:
# Super resolution
from mvp_score_modelling.pipelines.super_resolution import (
    SuperResolutionProjectionPipeline,
    PrYtGuidedSuperResolutionPipeline,
    PseudoinverseGuidedSuperResolutionPipeline,
    ManifoldConstrainedGradientSuperResolutionPipeline
)

# Super-resolution methods, in the column order used by the summary figure.
pipelines = [
    SuperResolutionProjectionPipeline,
    PrYtGuidedSuperResolutionPipeline,
    PseudoinverseGuidedSuperResolutionPipeline,
    ManifoldConstrainedGradientSuperResolutionPipeline,
]
# Reset torch's global seed before sampling in this cell.
torch.manual_seed(0)

# Downsampling factor: the measurement is the image average-pooled over
# non-overlapping KERNEL_SIZE x KERNEL_SIZE windows.
KERNEL_SIZE = 16
pool = torch.nn.AvgPool2d(kernel_size=KERNEL_SIZE, stride=KERNEL_SIZE)

super_resolution_image = (crop(resize(read_image('data/church/100.png'))) / 256).to(DEVICE)

super_resolution_results = run_pipelines(
    (pool(super_resolution_image), KERNEL_SIZE),
    pipelines,
    unconditional_pipeline.unet,
    unconditional_pipeline.scheduler,
)

In [None]:
# Show each super-resolution reconstruction as a PIL image (pipeline order).
display(*map(tensor_to_PIL, super_resolution_results))

In [None]:
# Persist the super-resolution reconstructions. To reload, open the same path in 'rb' mode
# and call pickle.load (only on files you produced yourself: unpickling can execute code).
os.makedirs("../output/test_images", exist_ok=True)
with open("../output/test_images/church_super_resolution_results.pkl", 'wb') as f:
    pickle.dump(super_resolution_results, f)

In [None]:
# Colorisation
from mvp_score_modelling.pipelines.colorisation import (
    ColorisationProjectionPipeline,
    PrYtGuidedColorisationPipeline,
    ManifoldConstrainedGradientColorisationPipeline,
    PseudoinverseGuidedColorisationPipeline,
    greyscale
)
# Colorisation methods, in the column order used by the summary figure.
pipelines = [
    ColorisationProjectionPipeline,
    PrYtGuidedColorisationPipeline,
    PseudoinverseGuidedColorisationPipeline,
    ManifoldConstrainedGradientColorisationPipeline,
]
# Reset torch's global seed before sampling in this cell.
torch.manual_seed(0)


colorisation_image = (crop(resize(read_image('data/church/10.png'))) / 256).to(DEVICE)
# The measurement is the greyscale version of the ground-truth image.
grey = greyscale(colorisation_image)

colorisation_results = run_pipelines(
    grey,
    pipelines,
    unconditional_pipeline.unet,
    unconditional_pipeline.scheduler,
)

In [None]:
# Persist the colorisation reconstructions. To reload, open the same path in 'rb' mode and
# call pickle.load (only on files you produced yourself: unpickling can execute code).
os.makedirs("../output/test_images", exist_ok=True)
with open("../output/test_images/church_colorisation_results.pkl", 'wb') as f:
    pickle.dump(colorisation_results, f)

In [None]:
# Show each colorisation reconstruction as a PIL image (pipeline order).
display(*map(tensor_to_PIL, colorisation_results))

In [None]:
# Visualising all together: one row per task (random mask, box mask, inverted box mask,
# super-resolution, colorisation). Columns: ground truth, measurement y, then the four
# reconstructions in the order of each task's `pipelines` list. Each reconstruction is
# annotated with LPIPS / SSIM against the ground truth and the MSE between its measurement
# and the observed measurement (data consistency).

fig, axs = plt.subplots(5,6)
fig.set_size_inches(14,12)

for ax in axs.flat:
    ax.axis('off')

upscale = SuperResolutionProjectionPipeline.upscale

# (ground truth, forward operator producing the measurement, reconstructions)
data = [
    (random_mask_image, lambda x: x * random_mask, random_mask_results),
    (box_mask_image, lambda x: x * box_mask, box_mask_results),
    (box_mask_image_2, lambda x: x * box_mask_2, box_mask_results_2),
    (super_resolution_image, lambda x: upscale(pool(x), KERNEL_SIZE), super_resolution_results),
    (colorisation_image, lambda x: greyscale(x) , colorisation_results)
]

column_titles = ["Ground Truth", "Measurement Y", "Constraint", "PrYt", "Pseudoinverse", "MCG"]
for ax, title in zip(axs[0], column_titles):
    ax.set_title(title)

for i, (ground_truth, forward_op, reconstructions) in enumerate(data):
    # The observed measurement is the same for every method; compute it once per row.
    y_true = forward_op(ground_truth)
    reference = ground_truth.unsqueeze(0)

    plt_img(axs[i][0], ground_truth)
    plt_img(axs[i][1], y_true)
    axs[i][1].text(0,290,"lpips / ssim / y_mse")

    for j, img in enumerate(reconstructions):
        plt_img(axs[i][j+2], img)

        lpips_ = lpips_eval(img.unsqueeze(0), reference)
        ssim_ = ssim_eval(img.unsqueeze(0), reference)
        mse = ((y_true - forward_op(img))**2).mean().item()

        axs[i][j+2].text(0,290, f"{lpips_:.3g}/{ssim_:.3g}/{mse:.3g}")

plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
os.makedirs(FIGURES_DIRECTORY, exist_ok=True)
fig.savefig(os.path.join(FIGURES_DIRECTORY, "visualise-church.png"),bbox_inches='tight', transparent=True, pad_inches=0)

# Tweedie

In [None]:
# Tweedie: the clean evaluation image, its noisy version at noise level `sigma`, and the
# VeTweedie denoised estimate. LPIPS / SSIM against the clean image are written under the
# last two panels; the noisy image is clamped to [0, 1] before scoring.

fig, axs = plt.subplots(1,3)
for ax in axs:
    ax.axis('off')

reference = eval_img.unsqueeze(0)

axs[0].set_title("Original")
axs[0].text(0,300,"lpips")
axs[0].text(0,340,"ssim")
plt_img(axs[0], eval_img)

noised_clipped = noised_eval_image.unsqueeze(0).clamp(0,1)
axs[1].set_title(f"Noise Clipped (σ {round(sigma.item(), 3)})")
axs[1].text(0,300,lpips_eval(noised_clipped, reference))
axs[1].text(0,340,ssim_eval(noised_clipped, reference))
plt_img(axs[1], noised_eval_image)

axs[2].set_title("Tweedie")
tw = tweedie(noised_eval_image, sigma)
axs[2].text(0,300,lpips_eval(tw, reference))
axs[2].text(0,340,ssim_eval(tw, reference))
plt_img(axs[2], tw.squeeze())


plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
os.makedirs(FIGURES_DIRECTORY, exist_ok=True)
fig.savefig(os.path.join(FIGURES_DIRECTORY, "tweedie.png"),bbox_inches='tight', transparent=True, pad_inches=0)
