In [None]:
# Imports
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torchvision.transforms import Resize, ToPILImage, GaussianBlur
from torchvision.transforms._functional_tensor import _get_gaussian_kernel2d
from torchvision.io import read_image
from diffusers.pipelines import ScoreSdeVePipeline
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from matplotlib import pyplot as plt
from IPython.display import display

In [None]:
# Hyperparameters
# All tunable settings for the notebook live in this cell.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'     # Edit this line to target another device, e.g. a TPU or Apple's MPS backend
# Model identifier handed to ScoreSdeVePipeline.from_pretrained in the data-loading cell.
PRETRAINED = "google/ncsnpp-celebahq-256"
# Path of the reference image read with torchvision's read_image.
TEST_IMAGE = "test_celabhq.png"
# NOTE(review): not referenced by the cells visible in this notebook, which pass
# num_inference_steps=1000 explicitly -- confirm which value is intended.
N_INFERENCE_STEPS = 2000

In [None]:
# Load data
# Pretrained unconditional score model + scheduler; the task-specific pipelines
# defined further down reuse its `unet` and `scheduler`.
unconditional_pipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED).to(device=DEVICE)
resize = Resize(256)
# read_image is documented to return a uint8 tensor with values in [0, 255];
# dividing by 255 (not 256) maps that range onto [0, 1] exactly, which is also the
# convention ToPILImage uses when converting float tensors back to 8-bit images.
test_img = resize(read_image(TEST_IMAGE)) / 255

In [None]:
# Utilities

def show_img(img, size = 3) -> None:
    """Render an image tensor inline, without axes.

    Args:
        img: Tensor that, once singleton dimensions are squeezed away, has three
            dimensions ordered channels-first; it is permuted to channels-last
            before being handed to matplotlib.
        size: Edge length of the square figure, in inches.
    """
    plt.figure(figsize=(size, size))
    channels_last = img.squeeze().permute(-2, -1, -3)
    plt.imshow(channels_last)
    plt.axis('off')
    plt.show()
# Shared converter used by the display cells below to turn image tensors into PIL images.
tensor_to_PIL = ToPILImage()


class TractableInversePipeline(ScoreSdeVePipeline):
    """Score-SDE (VE) sampling loop with a pluggable conditioning term.

    At every correction and prediction step the U-Net output is passed through
    ``add_conditional_gradient`` before it is handed to the scheduler, so a
    subclass can steer generation towards an observation ``y``.
    """

    def add_conditional_gradient(self, sample, output, y, sigma):
        """Return ``output`` adjusted by the conditioning term for observation ``y``.

        Subclasses must override this hook; the base implementation always raises.

        Args:
            sample: Current noisy sample that was fed to the U-Net.
            output: U-Net output for ``sample`` at noise level ``sigma``.
            y: Observation given to ``__call__``; forwarded here untouched.
            sigma: Noise level of the current step, one entry per batch element.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @torch.no_grad()
    def __call__(
        self,
        y,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        debug: bool = False,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Sample ``batch_size`` images conditioned on the observation ``y``.

        Args:
            y: Observation forwarded verbatim to ``add_conditional_gradient``.
            batch_size: Number of images generated in parallel.
            num_inference_steps: Number of scheduler timesteps / noise levels.
            generator: Optional RNG(s) used for the initial noise and passed on
                to the scheduler steps.
            output_type: ``"pil"`` converts the result with ``numpy_to_pil``;
                any other value returns the channels-last numpy array.
            return_dict: When false, return a one-element tuple instead of an
                ``ImagePipelineOutput``.
            debug: When true, show the first sample of the batch every 500 steps.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.
        """
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        model = self.unet

        sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
        sample = sample.to(self.device)
        # Fallback so the post-processing below is well defined even when the
        # schedule is empty and the loop body never runs (previously this
        # raised UnboundLocalError).
        sample_mean = sample

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)

            # correction step
            for _ in range(self.scheduler.config.correct_steps):
                model_output = model(sample, sigma_t).sample
                model_output = self.add_conditional_gradient(sample, model_output, y, sigma_t)
                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

            # prediction step
            model_output = model(sample, sigma_t).sample
            model_output = self.add_conditional_gradient(sample, model_output, y, sigma_t)
            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

            sample, sample_mean = output.prev_sample, output.prev_sample_mean

            if i % 500 == 0 and debug:
                show_img(sample[0].cpu())

        # The result is built from the mean returned by the last prediction step
        # (not the sample with fresh noise added), clamped to the displayable range.
        sample = sample_mean.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        return ImagePipelineOutput(images=sample)

# Super Resolution

In [None]:
class SuperResolutionPipeline(TractableInversePipeline):
    """ Super resolution from a smaller image

    Recall from controllable generation we need to find grad(log p(y | x(t))) to add as a conditional gradient
    This is directly tractable by pooling x(t), and considering a normal distribution N' from the reference image.

    The normal distribution has a variance that can be obtained by considering sums of gaussians, and multiplying
    a gaussian by a constant: summing KERNEL**2 independent pixels of variance sigma**2 gives variance
    sigma**2 * KERNEL**2, and dividing that sum by KERNEL**2 scales the variance by 1 / KERNEL**4.

    N' = (reference, (sigma**2 * KERNEL**2) / (KERNEL**4))

    i.e. a standard deviation of sigma / KERNEL.
    """

    def add_conditional_gradient(self, sample, output, y, sigma):
        """Add grad_x log N(avg_pool(x) - reference; 0, sigma / kernel_size) to ``output``.

        Args:
            sample: Current noisy sample; it is neither modified nor flagged
                for autograd by this method.
            output: U-Net output for ``sample``; updated in place and returned.
            y: ``(reference, kernel_size)`` pair: the low-resolution reference
                image and the edge length of the average-pooling window.
            sigma: Noise level per batch element; only the first entry is used
                because the level is shared across the batch.

        Returns:
            ``output`` with the likelihood gradient added.
        """
        y, kernel_size = y
        pool = nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size)
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy so the caller's tensor keeps
            # its autograd flags and no stale `.grad` can accumulate across calls.
            x = sample.detach().requires_grad_(True)
            diff = pool(x) - y
            std = sigma[0] / kernel_size # STD, Not Variance!!
            log_likelihood = torch.distributions.Normal(0, std).log_prob(diff).sum()
            (grad,) = torch.autograd.grad(log_likelihood, x)
        output += grad
        return output

In [None]:
# Execution
KERNEL_SIZE = 16

# Build the low-resolution observation by average-pooling the reference image.
pool = nn.AvgPool2d(kernel_size=KERNEL_SIZE, stride=KERNEL_SIZE)
downsampled_img = pool(test_img).to(device=DEVICE)
# Preview the observation with the notebook's plotting helper (3x3 inch, axes hidden).
show_img(downsampled_img.cpu())

pipeline = SuperResolutionPipeline(unconditional_pipeline.unet, unconditional_pipeline.scheduler)
# The former `target=False` argument was dropped: the pipeline's __call__ only
# swallowed it through **kwargs and never read it.
images = pipeline((downsampled_img, KERNEL_SIZE), batch_size=3, num_inference_steps=1000).images
display("Original", tensor_to_PIL(test_img), "Generated", *images)

# Colorisation

In [None]:
def grayscale(image_tensor: torch.Tensor) -> torch.Tensor:
    """Replace every channel of a batched image by the mean over dimension 1.

    The channel mean is computed with the channel dimension kept, then tiled
    three times along that dimension, so a ``(B, C, H, W)`` input yields a
    ``(B, 3, H, W)`` output whose three channels are identical.
    """
    channel_mean = image_tensor.mean(dim=1, keepdim=True)
    return channel_mean.repeat(1, 3, 1, 1)

class ColorisationPipeline(TractableInversePipeline):
    """ Colorisation

    Guides sampling so that the channel-mean of the generated image matches a
    grayscale reference. Averaging three independent channels of noise level
    sigma leaves a standard deviation of sigma / sqrt(3).
    """
    def add_conditional_gradient(self, sample, output, y, sigma):
        """Add grad_x log N(gray(x) - reference; 0, sigma / sqrt(3)) to ``output``.

        Args:
            sample: Current noisy sample; it is neither modified nor flagged
                for autograd by this method.
            output: U-Net output for ``sample``; updated in place and returned.
            y: Grayscale reference; ``y[0, 0]`` is taken as the single
                reference plane.
            sigma: Noise level per batch element; only the first entry is used
                because the level is shared across the batch.

        Returns:
            ``output`` with the likelihood gradient added.
        """
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy so the caller's tensor keeps
            # its autograd flags and no stale `.grad` can accumulate across calls.
            x = sample.detach().requires_grad_(True)
            gray_sample = grayscale(x)
            # Broadcast the reference plane over the batch dimension. The previous
            # `.expand(3, 256, 256)` only lined up for batch_size == 3 at 256x256,
            # crashed for other batch sizes, and tripled the gradient for batch_size == 1.
            diff = gray_sample[:, 0] - y[0, 0]
            std = sigma[0] / 3.0**0.5 # STD, Not Variance!!
            log_likelihood = torch.distributions.Normal(0, std).log_prob(diff).sum()
            (grad,) = torch.autograd.grad(log_likelihood, x)
        output += grad
        return output

In [None]:
# Grayscale observation derived from the reference image (batch dimension added for `grayscale`).
gray = grayscale(test_img.unsqueeze(0)).to(device=DEVICE)
pipeline = ColorisationPipeline(unconditional_pipeline.unet, unconditional_pipeline.scheduler)
# The former `target=False` argument was dropped: the pipeline's __call__ only
# swallowed it through **kwargs and never read it.
images = pipeline(gray, batch_size=3, num_inference_steps=1000).images
# Label the trailing grayscale input so it is not mistaken for a generated sample.
display("Original", tensor_to_PIL(test_img), "Generated", *images, "Input", tensor_to_PIL(gray.squeeze()))

# Blur

In [None]:
# Edge length of the Gaussian blur kernel passed to GaussianBlur below.
BLUR_KERNEL_SIZE = 27
# Blur operator shared by the deblurring pipeline (applied to the evolving sample)
# and by the cell that creates the blurred observation.
gaussian_blur = GaussianBlur(BLUR_KERNEL_SIZE, sigma=2)

class DeblurringPipeline(TractableInversePipeline):
    """ Deblurring

    Guides sampling so that the blurred version of the generated image matches
    a blurred reference, using the module-level ``gaussian_blur`` operator.
    """
    def add_conditional_gradient(self, sample, output, y, sigma):
        """Add grad_x log N(reference - blur(x); 0, sigma) to ``output``.

        Args:
            sample: Current noisy sample; it is neither modified nor flagged
                for autograd by this method.
            output: U-Net output for ``sample``; updated in place and returned.
            y: Blurred reference image the blurred sample is compared against.
            sigma: Noise level per batch element; only the first entry is used.

        Returns:
            ``output`` with the likelihood gradient added.
        """
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy so the caller's tensor keeps
            # its autograd flags and no stale `.grad` can accumulate across calls.
            x = sample.detach().requires_grad_(True)
            blurred_sample = gaussian_blur(x)
            diff = y - blurred_sample
            # NOTE(review): the noise level is used as-is here, with no rescaling
            # for the blur kernel -- kept exactly as in the original; confirm intent.
            std = sigma[0]
            log_likelihood = torch.distributions.Normal(0, std).log_prob(diff).sum()
            (grad,) = torch.autograd.grad(log_likelihood, x)
        output += grad
        return output

In [None]:
# Blurred observation derived from the reference image.
blurred = gaussian_blur(test_img).to(device=DEVICE)
pipeline = DeblurringPipeline(unconditional_pipeline.unet, unconditional_pipeline.scheduler)
# The former `target=False` argument was dropped: the pipeline's __call__ only
# swallowed it through **kwargs and never read it.
images = pipeline(blurred, batch_size=1, num_inference_steps=1000).images
# Label the trailing blurred input so it is not mistaken for a generated sample.
display("Original", tensor_to_PIL(test_img), "Generated", *images, "Input", tensor_to_PIL(blurred))