## This notebook is for making the demo / video grid

In [None]:
from pathlib import Path
import torch
import json
from torch import Tensor
from jaxtyping import Float, UInt8
from io import BytesIO
import torchvision.transforms as tf
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.image as mpimg
import cv2

### Make an image grid for ground truth, pixelSplat, InstructIR and ControlNet images

In [None]:
def make_img_grid(real_dir, gen_dir, output_dir, max_rooms=1):
    """Save one four-panel comparison figure per rendered frame of each room.

    Every sub-directory ``<key>`` of ``gen_dir`` is treated as a room. For
    each entry in ``<key>/color`` a frame index ``i`` is assumed and the image
    ``{i:06d}.png`` is read from ``gen_dir/<key>/color`` and from
    ``real_dir/<key>/target``, ``real_dir/<key>/instructir`` and
    ``real_dir/<key>/controlnet``. The four images are drawn next to each
    other and written to ``output_dir/<key>/<i>.png``.

    Args:
        real_dir: Root containing the ``target``, ``instructir`` and
            ``controlnet`` folders of every room (str or Path).
        gen_dir: Root containing the ``color`` folder of every room
            (str or Path).
        output_dir: Destination root (str or Path); created when missing.
        max_rooms: Maximum number of rooms to process. The default of 1
            keeps the former "only the first room" testing behaviour; pass
            ``None`` to process every room.
    """
    real_dir = Path(real_dir)
    gen_dir = Path(gen_dir)
    output_dir = Path(output_dir)

    rooms_done = 0
    for key_dir in tqdm(gen_dir.iterdir(), desc='Processing directories'):
        if max_rooms is not None and rooms_done >= max_rooms:
            break

        # Index files (*.json) and any other stray files are not rooms.
        if key_dir.suffix.lower() == '.json' or not key_dir.is_dir():
            continue

        key = key_dir.name
        output_path = output_dir / key
        output_path.mkdir(exist_ok=True, parents=True)

        color_dir = key_dir / "color"
        # (panel title, directory the frame is read from), in display order.
        panels = (
            ('Ground Truth', real_dir / key / "target"),
            ('pixelSplat', color_dir),
            ('InstructIR', real_dir / key / "instructir"),
            ('ControlNet', real_dir / key / "controlnet"),
        )

        # Frames are addressed by index, so only the number of entries matters.
        num_frames = sum(1 for _ in color_dir.iterdir())
        for i in range(num_frames):
            frame_name = f"{i:06d}.png"
            fig, axes = plt.subplots(1, len(panels), figsize=(10, 5))
            try:
                for ax, (title, img_dir) in zip(axes, panels):
                    ax.imshow(mpimg.imread(img_dir / frame_name))
                    ax.axis('off')
                    ax.set_title(title)

                fig.subplots_adjust(wspace=0.05)
                fig.savefig(output_path / f"{i}.png", bbox_inches='tight')
            finally:
                # Always release the figure, even when an image is missing.
                plt.close(fig)

        rooms_done += 1

In [None]:
# The ground-truth / restored images and the pixelSplat renders live under the
# same root in this run, so both roots point at the same directory.
real_dir = Path('../outputs/re10k_test_hard/re10k')
gen_root = Path('../outputs/re10k_test_hard/re10k')
output_dir = "../outputs/re10k_img_grid"

# Pass real_dir as the first argument (it was previously defined but unused).
make_img_grid(real_dir, gen_root, output_dir)

### Create a video out of the image grids

In [None]:
def create_video(path_imgs, output_path, framerate=30):
    """Write one ``.avi`` video per room folder found in ``path_imgs``.

    Each sub-directory of ``path_imgs`` is a room whose frames are named
    ``0.png``, ``1.png``, ...; the number of entries in the folder determines
    how many consecutive frames are read. The video is written to
    ``output_path/<room name>.avi`` with the ``DIVX`` codec.

    Args:
        path_imgs: Directory holding one folder of frames per room
            (str or Path).
        output_path: Directory the videos are written to (str or Path);
            created when missing.
        framerate: Frames per second of the resulting videos.

    Raises:
        FileNotFoundError: If an expected frame cannot be read.
    """
    path_imgs = Path(path_imgs)
    output_dir = Path(output_path)
    output_dir.mkdir(exist_ok=True, parents=True)

    for room in path_imgs.iterdir():
        if not room.is_dir():
            continue

        num_frames = sum(1 for _ in room.iterdir())
        frames = []
        for i in range(num_frames):
            frame_path = room / f"{i}.png"
            img = cv2.imread(str(frame_path))
            if img is None:
                # cv2.imread signals failure by returning None, not raising.
                raise FileNotFoundError(f"Could not read frame: {frame_path}")
            frames.append(img)

        # Nothing to encode; also avoids reusing a size from a previous room.
        if not frames:
            continue

        # All frames of a room are expected to share one size; the size of
        # the last frame is used for the writer.
        height, width = frames[-1].shape[:2]
        out = cv2.VideoWriter(
            str(output_dir / f"{room.name}.avi"),
            cv2.VideoWriter_fourcc(*'DIVX'),
            framerate,
            (width, height),
        )
        try:
            for frame in frames:
                out.write(frame)
        finally:
            out.release()


In [None]:
# Directory with one folder of grid frames per room (same path as the
# output_dir used for make_img_grid above).
path_imgs = "../outputs/re10k_img_grid"
# Directory that receives one .avi file per room.
output_path = "../outputs/re10k_video_grid"

create_video(path_imgs, output_path)

### Images from torch files

The next part generates the ground-truth images for the grid videos from the torch files.

In [None]:
# Directories of the real data
# NOTE(review): video_index_path is not referenced in the cells shown here —
# confirm it is still needed.
video_index_path = "../assets/evaluation_index_re10k_video.json"
# Dataset root; every stage listed in data_stages must contain an index.json.
root = Path('../datasets/re10k')
data_stages = ["test"]

# Directories of the generated data such that the folders within are the room keys
gen_root = Path('../outputs/re10k_test_hard_more/re10k')

# Number of rooms to export; read as a global by save_images below.
num_vid = 5

In [None]:
# Build a room-key -> torch-file lookup across all stages, plus the list of
# distinct torch files in first-seen order.
merged_index = {}
path_to_torch = []
seen_paths = set()

for data_stage in data_stages:
    stage_dir = root / data_stage
    with (stage_dir / "index.json").open("r") as f:
        raw_index = json.load(f)

    # The index stores file names relative to the stage directory.
    index = {k: stage_dir / v for k, v in raw_index.items()}

    for path in index.values():
        if path in seen_paths:
            continue
        seen_paths.add(path)
        path_to_torch.append(path)

    # A room key must not appear in more than one stage.
    assert not (merged_index.keys() & index.keys())

    merged_index.update(index)

print(merged_index)
print(f"#rooms: {len(merged_index)}")
print(path_to_torch)
print(f"#torch files: {len(path_to_torch)}")

In [None]:
# PIL image -> float tensor converter shared by convert_images.
tensor = tf.ToTensor()

def convert_images(images: list[UInt8[Tensor, "..."]],) -> Float[Tensor, "batch 3 height width"]:
    """Decode a list of byte tensors into one stacked image tensor.

    The bytes of each entry are opened as an image file with PIL, converted
    with ``ToTensor`` and the results are stacked along a new first dimension.
    """
    decoded = [
        tensor(Image.open(BytesIO(raw.numpy().tobytes())))
        for raw in images
    ]
    return torch.stack(decoded)

def center_crop(
    images: Float[Tensor, "*#batch c h w"],
    shape: tuple[int, int],
) -> Float[Tensor, "*#batch c h_out w_out"]:
    """Crop the central ``shape`` window out of the last two dimensions.

    Args:
        images: Tensor whose last two dimensions are height and width.
        shape: Target ``(height, width)`` of the crop.

    Returns:
        The cropped tensor (a slice of ``images``, not a tuple).

    Raises:
        ValueError: If the requested crop is larger than the input. Without
            this check the negative slice start would silently return a
            region of the wrong size.
    """
    *_, h_in, w_in = images.shape
    h_out, w_out = shape

    if h_out > h_in or w_out > w_in:
        raise ValueError(
            f"Crop size {(h_out, w_out)} exceeds image size {(h_in, w_in)}"
        )

    # Note that odd input dimensions induce half-pixel misalignments.
    row = (h_in - h_out) // 2
    col = (w_in - w_out) // 2

    # Center-crop the image.
    return images[..., :, row : row + h_out, col : col + w_out]

def save_images(path_to_torch, out_dir, max_rooms=None, crop_shape=(360, 360)):
    """Export the images of the first rooms found in the torch chunks as PNGs.

    Chunks are loaded in the given order; each chunk is iterated as a
    sequence of rooms providing a ``"key"`` and an ``"images"`` entry. The
    images of a room are decoded with ``convert_images``, center-cropped and
    written to ``out_dir/<key>/<i>.png``.

    Args:
        path_to_torch: Iterable of paths to torch chunk files.
        out_dir: Destination root (str or Path); created when missing.
        max_rooms: Maximum number of rooms to export. ``None`` falls back to
            the module-level ``num_vid`` setting.
        crop_shape: ``(height, width)`` passed to ``center_crop``.

    Returns:
        The keys of the exported rooms, in export order.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    limit = num_vid if max_rooms is None else max_rooms
    keys = []

    for path in path_to_torch:
        # Checked before loading so no chunk is read once the limit is hit.
        if len(keys) >= limit:
            break

        # NOTE: torch.load unpickles the file; only use trusted dataset chunks.
        chunk = torch.load(path)
        for room in tqdm(chunk, desc="Processing rooms", unit="room"):
            if len(keys) >= limit:
                break

            key = room["key"]
            keys.append(key)
            print(key)

            room_path = out_dir / key
            room_path.mkdir(exist_ok=True, parents=True)

            for i, image in enumerate(convert_images(room["images"])):
                img = center_crop(image, crop_shape)
                # Round before casting: truncating the float product can land
                # one intensity level below the original 8-bit value.
                pixels = (img * 255).round().clamp(0, 255).to(torch.uint8)
                image_array = pixels.permute(1, 2, 0).numpy()
                plt.imsave(f'{room_path}/{i}.png', image_array)

    return keys


In [None]:
# NOTE(review): the output directory is an absolute path on one specific
# machine — adjust it before running elsewhere.
keys = save_images(path_to_torch, "/home/rooshutter/Documents/Master/CV2/project/diffusion-augmented-pixelsplat/outputs/re10k_full")