# Sequence restoration with Latent Diffusion Models

In [12]:
import matplotlib.pyplot as plt
import matplotlib
#%matplotlib inline

import os

import numpy as np
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
from torchvision.transforms.functional import resize

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Let's also check what type of GPU we've got.

In [None]:
import os

ROOT_PATH = "C:/Users/Shadow/Documents/ImSeqCond/"
#os.chdir("/home/alban/ImSeqCond/latent-diffusion/")
os.chdir(os.path.join(ROOT_PATH, "latent-diffusion/"))

Load it.

In [None]:
#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config


def load_model_from_config(config, ckpt=None):
    model = instantiate_from_config(config.model)
    
    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt)#, map_location="cpu")
        sd = pl_sd["state_dict"]
        m, u = model.load_state_dict(sd, strict=False)
    else:
        print("Instantiated model from config")
        
    model.cuda()
    model.eval()
    return model

cond_key = 'label'
#cond_key = 'LR_image'

model_folder = os.path.join(ROOT_PATH, "latent-diffusion/logs_saved/2024-01-18T22-34-24_config_siar_recon")
#model_folder = "/home/alban/ImSeqCond/latent-diffusion/logs_saved/2023-12-21T15-15-42_config_siar_recon"
checkpoint = "epoch=000048.ckpt"

files = os.listdir(os.path.join(model_folder, "configs"))
config_file = ""
for file in files:
    if file.endswith("project.yaml"):
        config_file = file
        break

if config_file == "":
    raise ValueError("No config file found")

def get_model(model_folder, config_file, checkpoint):
    config = OmegaConf.load(os.path.join(model_folder, 'configs', config_file))
    model = load_model_from_config(config, os.path.join(model_folder, "checkpoints", checkpoint))
    return model

In [None]:
from ldm.models.diffusion.ddim import DDIMSampler

model = get_model(model_folder, config_file, checkpoint)
sampler = DDIMSampler(model)

# count model parameters
params = sum([p.numel() for p in model.parameters() if p.requires_grad])
print(f"Model has {params/1e6:.2f}M parameters")

In [13]:
# Load some custom data
from ldm.data.siar import SIAR

dataset = SIAR(os.path.join(ROOT_PATH, "data/SIAR"), set_type='val', resolution=256, max_sequence_size=10) #, downscale_f=4)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

ModuleNotFoundError: No module named 'ldm'

In [None]:
# Find the index of an image in the dataset

""" for i in range(len(dataset)):
    if dataset[i]['name'] == "7139":
        print(i)
        break """

And go. Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` variables. As a rule of thumb, higher values of `scale` produce better samples at the cost of a reduced output diversity. Furthermore, increasing `ddim_steps` generally also gives higher quality samples, but returns are diminishing for values > 250. Fast sampling (i e. low values of `ddim_steps`) while retaining good quality can be achieved by using `ddim_eta = 0.0`.

In [None]:
i = 10

images_indexes = [i]
n_samples_per_image = 6

ddim_steps = 20
ddim_eta = 1.0
scale = 1# for unconditional guidance


all_samples = list()

with torch.no_grad():
    with model.ema_scope():
        
        """ uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.zeros(n_samples_per_image, 3, 64, 64).cuda().to(model.device)}
            ) """
        
        uc = model.get_learned_conditioning(
            torch.zeros(n_samples_per_image, 3, 4, 256, 256).cuda().to(model.device)
            )

        for image_index in images_indexes:
            print(f"rendering {n_samples_per_image} examples of images '{dataset[image_index]['name']}' in {ddim_steps} steps and using s={scale:.2f}.")
            
            if cond_key == 'LR_image':
                xc = rearrange(torch.tensor(dataset[image_index]['LR_image']), 'h w c -> c h w').unsqueeze(0).repeat(n_samples_per_image, 1, 1, 1)
            elif cond_key == 'label':
                xc = rearrange(torch.tensor(dataset[image_index][cond_key]), 's h w c -> c s h w').unsqueeze(0).repeat(n_samples_per_image, 1, 1, 1, 1)
            else:
                raise ValueError(f"Unknown cond_key '{cond_key}'")

            c = model.get_learned_conditioning(xc.to(model.device))

            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples_per_image,
                                             shape=[3, 64, 64],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                         min=0.0, max=1.0)
            all_samples.append(x_samples_ddim)
 

# display as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_samples_per_image)

# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8))

In [None]:
def plot_image(data, predict=None):
    """ For a single data point, plot the ground truth image, the input images and the predicted image
    Args:
        gt (torch.tensor): ground truth image
        input (torch.tensor): input images
        predict (torch.tensor): predicted image
    """
    gt, input = data['data'], data['label']
    
    label_images = data['label'].shape[0]
    
    fig, axes = plt.subplots(2, 6, figsize=(20, 10))

    axes[0, 0].imshow(gt)
    axes[0, 0].set_title("Ground truth")
    
    for i in range(label_images):
        axes[i//5, i%5 + 1].imshow(input[i])
        axes[i//5, i%5 + 1].set_title("Input " + str(i+1))
        
    if predict is not None:
        axes[1, 0].imshow(predict)
        axes[1, 0].set_title("Predicted")
        
    plt.show()

In [None]:
def prepare_for_plot(data, all_samples=None):
    
    data_prepared = dict()
    for key, value in data.items():
        if key in ['data', 'label']:
            data_prepared[key] = (value + 1) / 2
    
    predict_prepared = rearrange(all_samples[0][0], 'c h w -> h w c')
    #predict_prepared = (predict_prepared + 1) / 2
    predict_prepared = predict_prepared.cpu().detach().numpy()
    
    return data_prepared, predict_prepared

In [None]:
plot_image(*prepare_for_plot(dataset[i], all_samples))

In [None]:
# STUDY OF THE LATENT SPACE

""" cond = c[0]

# convert cond in 0, 1
cond = (cond - cond.min()) / (cond.max() - cond.min())
cond = rearrange(cond, 'c h w -> h w c')
cond = cond.detach().cpu().numpy()

cond_decode = model.decode_first_stage(c[0].unsqueeze(0))

cond_decode = torch.clamp((cond_decode+1.0)/2.0,
                                        min=0.0, max=1.0)
cond_decode = rearrange(cond_decode.squeeze(), 'c h w -> h w c')
cond_decode = cond_decode.detach().cpu().numpy()


fig, axes = plt.subplots(1, 2, figsize=(20, 10))

axes[0].imshow(cond)
axes[0].set_title("Cond in latent space")

axes[1].imshow(cond_decode)
axes[1].set_title("Cond in pixel space") """

In [None]:
from benchmark import Benchmark

class BenchmarkLDM(Benchmark):
    
    def __init__(self, model, dataloader, mse=True, clip=False, lpips=False, cond_key='label'):
        super().__init__(model, dataloader, mse, clip, lpips, cond_key)
    
    def sample(self, data, ddim_steps=20, ddim_eta=1.0, scale=1):
        """ Method used to sample from the model with the data as conditionning 
            Args:
                data (torch.tensor): conditionning data. size: (batch_size, 3, W, H) or (batch_size, 10, 3, W, H)
            Output:
                torch.tensor: restored image. size: (batch_size, 3, W, H)
        """

        uc = self.model.get_learned_conditioning(
            torch.zeros(data.shape[0], 3, 4, 256, 256).cuda().to(self.model.device)
            )
        
        if self.cond_key == 'LR_image':
            xc = rearrange(torch.tensor(data), 'b h w c -> b c h w')
        elif self.cond_key == 'label':
            xc = rearrange(torch.tensor(data), 'b s h w c -> b c s h w')
        else:
            raise ValueError(f"Unknown cond_key '{cond_key}'")

        c = self.model.get_learned_conditioning(xc.to(self.model.device))

        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                            conditioning=c,
                                            batch_size=data.shape[0],
                                            shape=[3, 64, 64],
                                            verbose=False,
                                            unconditional_guidance_scale=scale,
                                            unconditional_conditioning=uc,
                                            eta=ddim_eta)

        x_samples_ddim = self.model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                        min=0.0, max=1.0)
    
        return x_samples_ddim
    

In [None]:
benchmark = BenchmarkLDM(model, dataloader, mse=True, clip=True, lpips=True, cond_key=cond_key)

In [None]:
results = benchmark.evaluate()

In [None]:
print(benchmark.results)

import json

print(json.dumps(benchmark.results, indent=4))

In [None]:
def rescale(data):
    """ Rescale data between 0 and 1 from -1 and 1
        Args:
            data (torch.tensor): data to rescale
        Output:
            torch.tensor: rescaled data
    """
    return {
        'data': (data['data'] + 1) / 2,
        'label': (data['label'] + 1) / 2,
        'name': data['name'],
    }

In [None]:
# GENERATE SAMPLES

from PIL import Image
import numpy as np

output_folder = os.path.join(model_folder, 'test_predictions_random')

print(output_folder)

for i in range(min(40, len(dataset))):
    
    j = np.random.randint(len(dataset))
    data = dataset[j]
    
    y = data[cond_key]

    predict = benchmark.sample(y[None,...],20)
    #predict = model.predict(y.unsqueeze(0).to(device))
    out = predict.detach().cpu()
    
    out = out[0].transpose(0,1).transpose(1,2)
    
    if os.path.exists(output_folder) == False:
        os.makedirs(output_folder)
        
    # rescale data
    data = rescale(data) # scale between 0 and 1
    
    plot_image(data, out)
    
    out_im = (out.numpy()* 255).astype(np.uint8) # rescale to 0-255
    
    im_pil = Image.fromarray(out_im)
    im_pil.save(os.path.join(output_folder, f'{data["name"]}.png'))

In [21]:
import torch.nn as nn
from benchmark import Benchmark
import os
import json

class Baseline(nn.Module):
    """ 
        Naive approach to solve the restoration problem
    """
    
    def __init__(self):
        super(Baseline, self).__init__()
    
    def forward(self, x):
        """ 
            Forward pass
            Args:
                x (torch.tensor): input image sequence. size: (batch_size, 10, 3, 256, 256)
            Returns:
                torch.tensor: restored image
        """
        out = torch.mean(x, dim=1)
        
        return out.permute(0, 3, 1, 2)
    
class BaselineBenchmark(Benchmark):
    
    def __init__(self, model, dataloader, mse=True, clip=False, lpips=False, cond_key='label'):
        super().__init__(model, dataloader, mse, clip, lpips, cond_key)
        
    def sample(self, data):
        """ Overwrite this method to sample from the model with data as conditioning """
        
        data = data.to(device)
        
        return self.model(data)

    def sample_repeat(self, data, n_rep=4):
        
        samples = []

        for i in range(n_rep):
            samples.append(self.sample(data))
        
        return rearrange(torch.stack(samples), 'n b c h w -> b n c h w')

import torch
from torchvision.transforms import transforms
import os
import numpy as np
from PIL import Image
import random

class SIAR(torch.utils.data.Dataset):
    
    """
    Dataset class for SIAR dataset
    
    The dataset is constitued of N folders, N is the number of images of the dataset
    Each folder contains 11 images, gt.png is the ground truth image, the others are distorted images (from 1 to 10)
    """
    SPLIT_FILE = "split.csv"
    WRONG_FILE = "wrongs.txt"
    
    def __init__(self, root, set_type='train', transform=None, generate_split=False, resolution=64, downscale_f=None, max_sequence_size=10, random_label=False):
        """
        Args:   
            path (str): path to the dataset
            set (str): train, val or test
            transform (torchvision.transforms): transform to apply to the images
            generate_split (bool): if True, generate a new split file
            resolution (int): resolution to resize the images to
            downscale_f (int): if not None, downscale the ground truth to create a Lower Resolution image (LR_image)
            max_sequence_size (int): maximum number of label images in a sequence
            random_label (bool): if False, returns the first max_sequence_size images as label, 
                                 if True, returns a random sequence of [1, max_sequence_size] images as label
        """
        
        assert set_type in ['train', 'val', 'test'], "set_type must be train, val or test"
        
        self.root = root
        
        if not self._split_exists():
            if generate_split:
                self._generate_split()
            else:
                raise RuntimeError("Split file does not exist, please generate it with generate_split=True")

        self.wrong_images = self._get_wrong_images_list()

        self.set_type = set_type
        self.split = self._load_split()
        self.images = self._load_images()
        
        self.len = len(self.images)

        self.transform = transform
        
        self.resolution = resolution
        self.downscale_f = downscale_f
        
        self.max_sequence_size = max_sequence_size
        self.random_label = random_label
        
    def __getitem__(self, index):
        """
        Args:
            index (int or slice): index of item
        returns:
            dict: {'gt': gt, 'input': inputs} if single index
            list: [{'gt': gt, 'input': inputs}, ...] if slice
        """
        if isinstance(index, slice):
            return [self._getitem(i) for i in range(*index.indices(self.len))]
        else:
            return self._getitem(index)
        
    def _getitem(self, index):
        """ Get an item
        Args:
            index (int): index of item
        Returns:
            dict: {'gt': gt, 'input': inputs}
                gt: np.array of shape (resolution, resolution, 3), ground truth image
                label: np.array of shape (max_sequence_size, resolution, resolution, 3), input images
        """

        if index >= self.len:
            raise IndexError("Index out of range")
        if index < 0:
            index += self.__len__()
        
        # read grund truth image
        gt = Image.open(os.path.join(self.root, self.images[index], "gt.png"))
        gt = gt.resize((self.resolution, self.resolution))
        
        # read input images
        input = self.__get_label_images(index)
        
        # apply transform if any
        if self.transform:
            # CHANGES MADE THIS PART NOT
            gt = self.transform(gt)
            input = [self.transform(im) for im in input]
            
            input = torch.stack(input)
            
            if self.downscale_f is not None:
                raise NotImplementedError("Downscale not implemented yet with transforms")
        else:
            """ to_tensor = transforms.ToTensor()
            gt = to_tensor(gt)
            input = [to_tensor(im) for im in input]
            
            input = torch.stack(input) """
            
            if self.downscale_f is not None:
                lr_image = gt.resize((self.resolution // self.downscale_f, self.resolution // self.downscale_f))
            else:
                lr_image = gt
            lr_image = np.array(lr_image).astype(np.uint8)
            
            gt = np.array(gt).astype(np.uint8)
                        
            if len(input) == 0:
                input = gt
            else:
                input = np.stack(input)
                
                input = np.pad(input, ((0, self.max_sequence_size - input.shape[0]), (0, 0), (0, 0), (0, 0)), 'constant', constant_values=0)

        return {
            'data': (gt/127.5 - 1).astype(np.float32), 
            'label': (input/127.5 - 1).astype(np.float32),
            'LR_image': (lr_image/127.5 - 1).astype(np.float32),
            'name': self.images[index]            
        }
        
    def __get_label_images(self, index):
        """ Read the label images from the disk 
            if self.random_label is True, returns a random sequence of [1, max_sequence_size] images as label
            else, returns the first max_sequence_size images as label   
        """
        label = []
        
        if self.random_label:
            label_size = random.randint(1, self.max_sequence_size)
            indexes = random.sample(range(1, self.max_sequence_size + 1), label_size)
        else:
            indexes = range(1, self.max_sequence_size + 1)
        
        for i in indexes:
            im = Image.open(os.path.join(self.root, self.images[index], str(i) + ".png"))
            im = im.resize((self.resolution, self.resolution))
            label.append(np.array(im).astype(np.uint8))
        
        return label
        
    def __len__(self):
        """ Get the length of the dataset """
        return self.len
    
    def _split_exists(self):
        return os.path.exists(os.path.join(self.root, self.SPLIT_FILE))
    
    def _wrong_exists(self):
        return os.path.exists(os.path.join(self.root, self.WRONG_FILE))
    
    def _generate_split(self):
        """ Generate a split for the dataset 
            If the split already exists, please delete the file before generating a new one
        """
        print("Generating split...")
        images = os.listdir(self.root)
        
        split = []
        count = [0, 0, 0]
    
        for im in images:
            r = random.random()
            if r <= 0.8:
                split.append(im + ',train')
                count[0] += 1
            elif r <= 0.9:
                split.append(im + ',val')
                count[1] += 1
            else:
                split.append(im + ',test')
                count[2] += 1
                
        with open(os.path.join(self.root, self.SPLIT_FILE), 'w') as f:
            f.write('\n'.join(split))
            
        print("Split generated")
        print("Train: {}, Val: {}, Test: {}".format(count[0], count[1], count[2]))

    def _load_split(self):
        """ Load the split file """
        
        with open(os.path.join(self.root, self.SPLIT_FILE), 'r') as f:
            data = f.read().split("\n")
        
        split = {}
        for line in data:
            if line == '':
                continue
            name, set_type = line.split(',')
            split[name] = set_type
            
        return split
    
    def _load_images(self):
        """ Load the images according to the split """
        im = []
        
        for name, set_type in self.split.items():
            if set_type == self.set_type and name not in self.wrong_images:
                im.append(name)
                
        return im
    
    def _get_wrong_images_list(self):
        """ Some images are misslabelled in the dataset, this function returns the list of those images """
        
        if self._wrong_exists():
            with open(os.path.join(self.root, self.WRONG_FILE), 'r') as f:
                data = f.read().split("\n")
                print("Wrong images excluded")
                return data

        print("No wrong images file found")
        return None

In [None]:
model_folder = "d:\\Cours\\Master_Arbeit\\code\\"
ROOT_PATH = "d:\\Cours\\Master_Arbeit\\code\\"
model = Baseline()
cond_key = 'label'

In [22]:
# generate samples with different sequence length
max_images = 10

#model = get_model(model_folder, config_file, checkpoint)

output_folder = os.path.join(model_folder, 'test_predictions_sequence')
os.makedirs(output_folder, exist_ok=True)

for i in range(1, 11):
    
    dataset = SIAR(os.path.join(ROOT_PATH, "data/SIAR"), set_type='val', resolution=256, max_sequence_size=i, random_label=True)[:max_images] #, downscale_f=4)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
    
    print(f"Evaluating the model with sequence length {i}")
    
    benchmark = BaselineBenchmark(model, dataloader, mse=True, clip=False, lpips=False, cond_key=cond_key)
    
    results = benchmark.evaluate()
    
    # save results as json
    with open(os.path.join(output_folder, f"results_{i}.json"), 'w') as f:
        json.dump(results, f)
    

Wrong images excluded
Evaluating the model with sequence length 1
Evaluating model on metrics mse


  im = torch.tensor(data['data']).permute(0, 3, 1, 2).to(device)
100%|██████████| 3/3 [00:00<00:00, 152.02it/s]

Metric mse score: 0.7434548139572144, lowest score: 0.19783706963062286 for image 10, highest score: 1.2263375520706177 for image 10
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 2
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 78.02it/s]

Metric mse score: 0.873989462852478, lowest score: 0.2751019299030304 for image 10, highest score: 1.4087573289871216 for image 10
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 3
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 98.70it/s]

Metric mse score: 0.9964872598648071, lowest score: 0.4447501003742218 for image 10, highest score: 2.1954457759857178 for image 10034
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 4
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 114.65it/s]

Metric mse score: 0.9247682690620422, lowest score: 0.2682325839996338 for image 10, highest score: 1.810653805732727 for image 10034
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 5
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 90.17it/s]

Metric mse score: 1.0989768505096436, lowest score: 0.13379880785942078 for image 10, highest score: 2.4827871322631836 for image 10
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 6
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 59.73it/s]

Metric mse score: 1.2641005516052246, lowest score: 0.4447110891342163 for image 10034, highest score: 2.2697014808654785 for image 10
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 7
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 85.56it/s]

Metric mse score: 1.1160790920257568, lowest score: 0.29216137528419495 for image 10, highest score: 2.3366546630859375 for image 10
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 8
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 67.93it/s]

Metric mse score: 0.8373936414718628, lowest score: 0.20230530202388763 for image 10, highest score: 1.6802958250045776 for image 10034
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 9
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 59.58it/s]

Metric mse score: 1.3749839067459106, lowest score: 0.4753604829311371 for image 10034, highest score: 2.5789358615875244 for image 10
Find results in 'results' attribute
Wrong images excluded





Evaluating the model with sequence length 10
Evaluating model on metrics mse


100%|██████████| 3/3 [00:00<00:00, 61.19it/s]

Metric mse score: 0.7398250699043274, lowest score: 0.13284368813037872 for image 10, highest score: 1.498531699180603 for image 10078
Find results in 'results' attribute



