# Super resolution with Latent Diffusion Models

First, let's check which type of GPU is available.

In [None]:
# Print the GPU model(s), driver version and current memory usage reported by the NVIDIA driver.
!nvidia-smi

Load the pretrained BSR super-resolution model (the `bsr_sr` config and checkpoint).

In [None]:
#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config


def load_model_from_config(config, ckpt=None, device="cuda", verbose=True):
    """Instantiate the model described by ``config`` and optionally load weights.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: optional path to a checkpoint file containing a ``"state_dict"``
            entry. When ``None`` the freshly instantiated weights are kept.
        device: device the model is moved to before being returned
            (defaults to ``"cuda"``, as before).
        verbose: if True, print the keys reported as missing/unexpected by
            ``load_state_dict`` (loading uses ``strict=False``).

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        # Deserialise onto the CPU first: the checkpoint may have been saved from
        # a different GPU, and this avoids holding an extra copy of the weights
        # in GPU memory. The model is moved to `device` below.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        # strict=False tolerates key mismatches; surface them instead of
        # silently discarding the result.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose and len(missing) > 0:
            print(f"missing keys ({len(missing)}):", missing)
        if verbose and len(unexpected) > 0:
            print(f"unexpected keys ({len(unexpected)}):", unexpected)
    else:
        print("Instantiated model from config")

    model.to(device)
    model.eval()
    return model


def get_model(
    config_path="/home/alban/ImSeqCond/latent-diffusion/models/ldm/bsr_sr/config.yaml",
    ckpt_path="/home/alban/ImSeqCond/latent-diffusion/models/ldm/bsr_sr/model.ckpt",
):
    """Load the pretrained BSR super-resolution LDM.

    Args:
        config_path: path to the model's OmegaConf YAML config.
        ckpt_path: path to the checkpoint passed to ``load_model_from_config``.

    The defaults point at the original author's local checkout; pass explicit
    paths to run the notebook on another machine.

    Returns:
        The loaded model (as returned by ``load_model_from_config``).
    """
    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    return model

In [None]:
from ldm.models.diffusion.ddim import DDIMSampler

model = get_model()
sampler = DDIMSampler(model)

# Count parameters. `params` keeps its previous meaning (trainable parameters
# only); the total is reported too, because parameters with requires_grad
# disabled (e.g. frozen sub-modules) would otherwise be missing from a number
# labelled as the model size.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
params_total = sum(p.numel() for p in model.parameters())
print(f"Model has {params_total/1e6:.2f}M parameters ({params/1e6:.2f}M trainable)")

In [None]:
# Load some custom data
from ldm.data.siar import SIAR

# Validation split of the SIARmini data, loaded with resolution=256.
siar_root = "/home/alban/ImSeqCond/data/SIARmini"
dataset = SIAR(
    siar_root,
    set_type='val',
    resolution=256,
)

Now sample. Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` variables. As a rule of thumb, higher values of `scale` produce better samples at the cost of reduced output diversity. Increasing `ddim_steps` also generally gives higher-quality samples, but the returns diminish for values > 250. Fast sampling (i.e. low values of `ddim_steps`) with good quality can be achieved by using `ddim_eta = 0.0`.

In [None]:
import numpy as np
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
from torchvision.transforms.functional import resize

# --- sampling configuration --------------------------------------------------
i = 1  # dataset index to super-resolve (also used by the next cell)

images_indexes = [i]
n_samples_per_image = 6

ddim_steps = 500
ddim_eta = 0.0
# Guidance scale. With the usual formula eps_uc + s * (eps_c - eps_uc),
# s = 1 reduces to plain conditional sampling -- confirm for this DDIMSampler.
scale = 1
seed = 42  # fixed seed so the sampled grid is reproducible on re-run

lr_size = 64                 # side of the low-resolution conditioning image
latent_shape = [3, 64, 64]   # `shape` passed to sampler.sample

torch.manual_seed(seed)

all_samples = list()

with torch.no_grad():
    with model.ema_scope():

        # All-zeros conditioning used as the "unconditional" input. Allocated
        # directly on the model's device (default float32 dtype).
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.zeros(n_samples_per_image, 3, lr_size, lr_size,
                                               device=model.device)}
        )

        for image_index in images_indexes:
            print(f"rendering {n_samples_per_image} examples of images '{image_index}' in {ddim_steps} steps and using s={scale:.2f}.")

            # Dataset images are stored H x W x C. Cast to float32 so the
            # conditioning has the same dtype as `uc` (a float64 numpy array
            # would otherwise stay float64), move channels first and downsample
            # to the conditioning resolution.
            hr_image = torch.as_tensor(dataset[image_index]['data']).float()
            resized = resize(rearrange(hr_image, 'h w c -> c h w'), (lr_size, lr_size))
            xc = resized.unsqueeze(0).repeat(n_samples_per_image, 1, 1, 1)

            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})

            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c[model.cond_stage_key],
                                             batch_size=n_samples_per_image,
                                             shape=latent_shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc[model.cond_stage_key],
                                             eta=ddim_eta)

            # Decode latents and map from [-1, 1] to [0, 1].
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                         min=0.0, max=1.0)
            all_samples.append(x_samples_ddim)


# display as grid: one row per conditioning image, one column per sample
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_samples_per_image)

# to image: round (instead of truncating) to the nearest 8-bit value
grid = (255. * rearrange(grid, 'c h w -> h w c')).round().cpu().numpy()
Image.fromarray(grid.astype(np.uint8))

In [None]:
# Ground-truth high-resolution image for comparison. The data is mapped with
# (x + 1) * 127.5, i.e. presumably stored in [-1, 1]; clip to [0, 255] and round
# before the uint8 cast so out-of-range values saturate instead of wrapping.
original = np.clip((dataset[i]['data'] + 1) * 127.5, 0, 255)

Image.fromarray(original.round().astype(np.uint8))