<h1> CV 2023 </h1>

In [1]:
# Uncomment when running on Google Colab: mounts Google Drive so that the
# /content/drive/... paths used in the next cell are available.
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import os
import shutil

# Extract the target (y), parameter and input (x) archives from Google Drive
# into the shared CVB folder, in that order.
cv_target_directory = '/content/drive/MyDrive/Colab_Notebooks/CVB'
_cv_archives = (
    '/content/drive/MyDrive/Colab_Notebooks/CVB/y.tar.gz',
    '/content/drive/MyDrive/Colab_Notebooks/CVB/params.tar.gz',
    '/content/drive/MyDrive/Colab_Notebooks/CVB/x.tar.gz',
)
for cv_directory in _cv_archives:
    shutil.unpack_archive(cv_directory, cv_target_directory)

In [None]:
# Quick check that the kernel executes cells; prints the project label.
print("CV2023")

CV2023




*   Explain the integrator (with images and the focal planes)
*   Inpainting model
*   Diagram of the model
*   Preprocessing





<h1> Encoder - Decoder Implementation </h1>

To run the code:


*   Download the zip file from WeTransfer.
*   Either mount Google Drive in Google Colab and upload the images there,
*   or upload them to this session's `/content` folder (files there only persist for the current instance).
*   Update the file-path strings in the code accordingly.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
%matplotlib inline
from mpl_toolkits.axes_grid1 import ImageGrid
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import shutil
import torch.nn as nn
from tqdm.notebook import tqdm

In [None]:
class InpaintingDataset(Dataset):
    """Dataset pairing six stacked integral images with a ground-truth image.

    Each file in ``target_root`` defines one sample. With ``a`` and ``b`` being the
    first two ``_``-separated fields of the target filename, a sample consists of:

    * the integral images ``{a}_{b}_{f}_integral.png`` from ``input_root`` for every
      ``f`` in ``FOCAL_SUFFIXES``, converted to grayscale and concatenated along
      dim 0 (one channel per focal plane),
    * the target image itself, converted to grayscale,
    * the raw text of ``{a}_{b}_Parameters.txt`` from ``params_root``.

    Args:
        input_root: Directory containing the integral (input) images.
        target_root: Directory containing the ground-truth (target) images.
        params_root: Directory containing the per-sample parameter text files.
    """

    # Focal-plane suffixes of the integral images, in output channel order.
    FOCAL_SUFFIXES = ('0', '0.5', '1', '1.5', '2', '2.5')

    def __init__(self, input_root, target_root, params_root):
        self.input_root = input_root
        self.target_root = target_root
        self.params_root = params_root
        # Sorted so that the sample order (and thus DataLoader batches when
        # shuffle=False) does not depend on the filesystem's listing order.
        self.input_list = sorted(os.listdir(input_root))
        self.target_list = sorted(os.listdir(target_root))

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        """Return ``(stacked_inputs, target, params_text)`` for sample ``idx``.

        ``stacked_inputs`` has one grayscale channel per entry of ``FOCAL_SUFFIXES``;
        ``target`` has a single grayscale channel; ``params_text`` is the unparsed
        contents of the matching parameters file.
        """
        target_name = os.path.join(self.target_root, self.target_list[idx])
        parts = os.path.basename(target_name).split('_')
        prefix = f"{parts[0]}_{parts[1]}"
        params_name = os.path.join(self.params_root, prefix + "_Parameters.txt")

        to_tensor = transforms.ToTensor()

        # Load the focal-plane images; `with` closes each file handle right away.
        input_images = []
        for suffix in self.FOCAL_SUFFIXES:
            input_name = os.path.join(self.input_root, f"{prefix}_{suffix}_integral.png")
            with Image.open(input_name) as img:
                input_images.append(to_tensor(img.convert("L")))
        stacked_images = torch.cat(input_images, dim=0)

        with Image.open(target_name) as img:
            target_image = to_tensor(img.convert("L"))

        with open(params_name, 'r') as file:
            params_content = file.read()

        return stacked_images, target_image, params_content

In [None]:
# This splitting assumes that you are working offline

def reset_folder(p):
    """Make ``p`` an empty directory.

    An existing directory at ``p`` is deleted recursively first; missing parent
    directories are created.
    """
    target_dir = p
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
        return
    shutil.rmtree(target_dir)
    os.makedirs(target_dir)

def _copy_split(src_dir, train_dir, val_dir, train_nums, val_nums, desc):
    """Copy each file of ``src_dir`` into the split owning its image number; return unmatched paths."""
    unmatched = []
    for name in tqdm(os.listdir(src_dir), desc=desc):
        num = name.split('_')[1]
        src = os.path.join(src_dir, name)
        if num in train_nums:
            shutil.copy(src, train_dir)
        elif num in val_nums:
            shutil.copy(src, val_dir)
        else:
            unmatched.append(src)
    return unmatched


def split_dataset(path, train_size, seed):
    """Split ``path/{x,y,params}`` into ``path/train`` and ``path/val`` by image number.

    Expected layout (``models`` is never touched)::

        path/
            params/  x/  y/  models/

    Resulting additional layout::

        path/
            train/{x,y,params}/
            val/{x,y,params}/

    The image number is the second ``_``-separated field of every filename. The
    unique numbers found in ``path/y`` are sorted, shuffled with ``seed``, and the
    first ``int(n * train_size)`` of them go to ``train``; the rest go to ``val``.
    Every file in ``x``, ``y`` and ``params`` is then copied into the split its
    number belongs to. Existing ``train``/``val`` sub-folders are wiped first.
    Files whose number belongs to neither split are not copied; their count is
    printed.

    Args:
        path: Dataset root containing ``x``, ``y`` and ``params``.
        train_size: Fraction in [0, 1] of image numbers assigned to training.
        seed: Shuffle seed; given the same files, the split is reproducible.

    Raises:
        ValueError: if ``train_size`` is outside [0, 1].
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError(f"train_size must be in [0, 1], got {train_size}")

    x_dir = os.path.join(path, "x")
    y_dir = os.path.join(path, "y")
    p_dir = os.path.join(path, "params")

    # Deduplicate so a number cannot end up in both splits, and sort so the seeded
    # shuffle does not depend on the (arbitrary) os.listdir order.
    im_numbers = sorted({name.split('_')[1] for name in os.listdir(y_dir)})
    split_index = int(len(im_numbers) * train_size)

    # A private RandomState leaves the global NumPy RNG untouched.
    np.random.RandomState(seed).shuffle(im_numbers)
    train_image_num = set(im_numbers[:split_index])
    test_image_num = set(im_numbers[split_index:])

    train_path = os.path.join(path, "train")
    train_x_path = os.path.join(train_path, 'x')
    train_y_path = os.path.join(train_path, 'y')
    train_p_path = os.path.join(train_path, 'params')

    val_path = os.path.join(path, "val")
    val_x_path = os.path.join(val_path, 'x')
    val_y_path = os.path.join(val_path, 'y')
    val_p_path = os.path.join(val_path, 'params')

    # (Re)create every split folder empty.
    for folder in (train_x_path, train_y_path, train_p_path,
                   val_x_path, val_y_path, val_p_path):
        reset_folder(folder)

    x_image_errors = _copy_split(x_dir, train_x_path, val_x_path,
                                 train_image_num, test_image_num, "x images")
    y_image_errors = _copy_split(y_dir, train_y_path, val_y_path,
                                 train_image_num, test_image_num, 'y images')
    param_errors = _copy_split(p_dir, train_p_path, val_p_path,
                               train_image_num, test_image_num, 'Parameters')

    unmatched = x_image_errors + y_image_errors + param_errors
    if unmatched:
        print(f"split_dataset: {len(unmatched)} file(s) matched neither split and were not copied")

# Split cnn_test_input/{x,y,params} 80/20 into train/ and val/ with a fixed seed.
split_dataset(path="cnn_test_input", train_size=0.8, seed=123)

In [None]:
# Dataset locations produced by split_dataset: <data_path>/{train,val}/{x,y,params}
data_path = 'cnn_test_input'

train_path = os.path.join(data_path, "train")
input_images_path_train, target_images_path_train, params_path_train = (
    os.path.join(train_path, sub) for sub in ('x', 'y', 'params')
)

val_path = os.path.join(data_path, "val")
input_images_path_val, target_images_path_val, params_path_val = (
    os.path.join(val_path, sub) for sub in ('x', 'y', 'params')
)

# One dataset per split
train_dataset = InpaintingDataset(
    input_root=input_images_path_train,
    target_root=target_images_path_train,
    params_root=params_path_train,
)
val_dataset = InpaintingDataset(
    input_root=input_images_path_val,
    target_root=target_images_path_val,
    params_root=params_path_val,
)

In [None]:
# Create dataloaders. The training loader reshuffles every epoch (seeded for
# reproducibility) so batches are not always built from the same fixed file
# order. Validation order does not affect the averaged loss, so it stays fixed.
batch_size = 32
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(123),
)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

# Grab the first training batch for a quick visual sanity check.
# (The previous `for batch in dataloader` referenced an undefined name.)
input_images, target_images, params = next(iter(train_dataloader))

In [None]:
# Focal-length labels of the six input channels, in channel order.
FOCAL_LENGTHS = ('0', '0.5', '1', '1.5', '2', '2.5')

# Show the six focal-plane inputs of the first sample plus its target on a 3x3 grid.
print(input_images[0].shape)
print(target_images[0].shape)

plt.figure(figsize=(15, 10))
for channel, focal_length in enumerate(FOCAL_LENGTHS):
    plt.subplot(3, 3, channel + 1)
    plt.title(f"Input Image: Focal Length {focal_length}")
    plt.imshow(TF.to_pil_image(input_images[0][channel]), cmap="gray")
    plt.axis('off')

plt.subplot(3, 3, len(FOCAL_LENGTHS) + 1)
plt.title("Target Image")
plt.imshow(TF.to_pil_image(target_images[0]), cmap="gray")
plt.axis('off')

plt.show()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class UNetLike(nn.Module):
    """U-Net style encoder-decoder mapping a stack of input planes to an image.

    Four 2x max-pool downsampling stages (16 -> 256 channels) are mirrored by four
    transposed-conv upsampling stages with skip connections. A final 3x3 conv and
    a sigmoid give an output in [0, 1] with the same spatial size as the input.

    Because of the four pooling stages, input height and width must be divisible
    by 16; otherwise the skip-connection concatenation fails.

    The model is created on the CPU; move it with ``model.to(device)``.

    Args:
        in_channels: Number of input planes (default 6, one per focal plane).
        out_channels: Number of output channels (default 1).
    """

    def __init__(self, in_channels: int = 6, out_channels: int = 1):
        super().__init__()

        # Attribute names are the state_dict keys of saved checkpoints; keep them stable.
        # Encoder
        self.down_conv1 = self.conv_block(in_channels, 16)
        self.down_conv2 = self.conv_block(16, 32)
        self.down_conv3 = self.conv_block(32, 64)
        self.down_conv4 = self.conv_block(64, 128)
        self.down_conv5 = self.conv_block(128, 256)
        # Parameter-free, so it adds no state_dict entries.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder (numbering starts at 2 to match existing checkpoints)
        self.up_trans_2 = self.up_transpose(256, 128)
        self.up_conv2 = self.conv_block(256, 128)
        self.up_trans_3 = self.up_transpose(128, 64)
        self.up_conv3 = self.conv_block(128, 64)
        self.up_trans_4 = self.up_transpose(64, 32)
        self.up_conv4 = self.conv_block(64, 32)
        self.up_trans_5 = self.up_transpose(32, 16)
        self.up_conv5 = self.conv_block(32, 16)

        # Final output layer
        self.out = nn.Conv2d(16, out_channels, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

    def up_transpose(self, in_channels, out_channels):
        """2x spatial upsampling via a stride-2 transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        # Encoder path
        x1 = self.down_conv1(x)
        x2 = self.down_conv2(self.pool(x1))
        x3 = self.down_conv3(self.pool(x2))
        x4 = self.down_conv4(self.pool(x3))
        x5 = self.down_conv5(self.pool(x4))

        # Decoder path: upsample, concatenate the matching skip feature, convolve.
        x = self.up_conv2(torch.cat([self.up_trans_2(x5), x4], dim=1))
        x = self.up_conv3(torch.cat([self.up_trans_3(x), x3], dim=1))
        x = self.up_conv4(torch.cat([self.up_trans_4(x), x2], dim=1))
        x = self.up_conv5(torch.cat([self.up_trans_5(x), x1], dim=1))

        return self.sigmoid(self.out(x))

In [None]:
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchgeometry.image import get_gaussian_kernel2d

def _gaussian_kernel2d(window_size: int, sigma: float) -> torch.Tensor:
    """Return a normalized ``(window_size, window_size)`` Gaussian kernel (float32).

    Outer product of a 1-D Gaussian centred at ``window_size // 2``, normalized to
    sum to one. It is intended to match
    ``torchgeometry.image.get_gaussian_kernel2d((k, k), (s, s))``, so this module
    does not need the unmaintained torchgeometry package.

    Raises:
        TypeError: if ``window_size`` is not an odd positive integer.
    """
    if not isinstance(window_size, int) or window_size <= 0 or window_size % 2 == 0:
        raise TypeError(
            "window_size must be an odd positive integer. Got {}".format(window_size))
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    gauss_1d = torch.exp(-coords.pow(2) / (2.0 * sigma ** 2))
    gauss_1d = gauss_1d / gauss_1d.sum()
    return gauss_1d.unsqueeze(1) * gauss_1d.unsqueeze(0)


class SSIM(nn.Module):
    r"""Structural-dissimilarity (DSSIM) loss between images ``img1`` and ``img2``.

    The local SSIM index is computed with a Gaussian window (sigma = 1.5):

    .. math::

      \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
      {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}

    where :math:`c_1=(0.01 L)^2` and :math:`c_2=(0.03 L)^2` stabilize the division
    and :math:`L` is ``max_val``, the dynamic range of the pixel values.

    The returned loss is

    .. math::

      \text{loss}(x, y) = \frac{\text{clamp}(1 - \text{SSIM}(x, y), 0, 1)}{2}

    so every element lies in [0, 0.5] and identical images give 0.

    Arguments:
        window_size (int): odd size of the Gaussian window.
        reduction (str, optional): ``'none'`` (element-wise loss map),
            ``'mean'`` or ``'sum'``. Default: ``'none'``.
        max_val (float): dynamic range of the images. Default: 1.

    Raises:
        ValueError: for an unknown ``reduction`` (at construction time).

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Target: :math:`(B, C, H, W)`
        - Output: scalar, or :math:`(B, C, H, W)` if reduction is ``'none'``.

    Examples::

        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> criterion = SSIM(5, reduction='none')
        >>> loss = criterion(input1, input2)  # 1x4x5x5
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(
            self,
            window_size: int,
            reduction: str = 'none',
            max_val: float = 1.0) -> None:
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                "Invalid reduction {!r}; expected one of {}".format(reduction, self._REDUCTIONS))
        self.window_size: int = window_size
        self.max_val: float = max_val
        self.reduction: str = reduction

        self.window: torch.Tensor = _gaussian_kernel2d(window_size, 1.5)
        self.padding: int = self.compute_zero_padding(window_size)

        self.C1: float = (0.01 * self.max_val) ** 2
        self.C2: float = (0.03 * self.max_val) ** 2

    @staticmethod
    def compute_zero_padding(kernel_size: int) -> int:
        """Zero padding that keeps the spatial size for an odd ``kernel_size``."""
        return (kernel_size - 1) // 2

    def filter2D(
            self,
            input: torch.Tensor,
            kernel: torch.Tensor,
            channel: int) -> torch.Tensor:
        """Depthwise (per-channel) convolution of ``input`` with ``kernel``."""
        return F.conv2d(input, kernel, padding=self.padding, groups=channel)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        if not torch.is_tensor(img1):
            raise TypeError("Input img1 type is not a torch.Tensor. Got {}"
                            .format(type(img1)))
        if not torch.is_tensor(img2):
            raise TypeError("Input img2 type is not a torch.Tensor. Got {}"
                            .format(type(img2)))
        if not len(img1.shape) == 4:
            raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}"
                             .format(img1.shape))
        if not len(img2.shape) == 4:
            raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}"
                             .format(img2.shape))
        if not img1.shape == img2.shape:
            raise ValueError("img1 and img2 shapes must be the same. Got: {} and {}"
                             .format(img1.shape, img2.shape))
        if not img1.device == img2.device:
            raise ValueError("img1 and img2 must be in the same device. Got: {} and {}"
                             .format(img1.device, img2.device))
        if not img1.dtype == img2.dtype:
            raise ValueError("img1 and img2 must be in the same dtype. Got: {} and {}"
                             .format(img1.dtype, img2.dtype))

        # Prepare one copy of the window per channel: shape (C, 1, k, k).
        b, c, h, w = img1.shape
        tmp_kernel: torch.Tensor = self.window.to(img1.device).to(img1.dtype)
        kernel: torch.Tensor = tmp_kernel.repeat(c, 1, 1, 1)

        # Local means per channel
        mu1: torch.Tensor = self.filter2D(img1, kernel, c)
        mu2: torch.Tensor = self.filter2D(img2, kernel, c)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances per channel: E[xy] - E[x]E[y]
        sigma1_sq = self.filter2D(img1 * img1, kernel, c) - mu1_sq
        sigma2_sq = self.filter2D(img2 * img2, kernel, c) - mu2_sq
        sigma12 = self.filter2D(img1 * img2, kernel, c) - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)) / \
            ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2))

        loss = torch.clamp(1. - ssim_map, min=0, max=1) / 2.

        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss  # 'none'

def ssim(
        img1: torch.Tensor,
        img2: torch.Tensor,
        window_size: int,
        reduction: str = 'none',
        max_val: float = 1.0) -> torch.Tensor:
    r"""Functional form of :class:`SSIM`.

    Builds a one-off ``SSIM(window_size, reduction, max_val)`` criterion and
    applies it to ``img1`` and ``img2``; see the class for shapes and the exact
    loss definition.
    """
    criterion = SSIM(window_size=window_size, reduction=reduction, max_val=max_val)
    return criterion(img1, img2)

In [None]:
class SSIMLoss(nn.Module):
    """``nn.Module`` wrapper around :class:`SSIM` with training-friendly defaults.

    Defaults: 3x3 window, ``'mean'`` reduction, ``max_val`` of 1.0. ``forward``
    returns exactly what the wrapped ``SSIM`` criterion returns for
    ``(img1, img2)``.
    """

    def __init__(self, window_size: int = 3, reduction: str = 'mean', max_val: float = 1.0):
        super().__init__()
        criterion = SSIM(window_size, reduction, max_val)
        self.ssim = criterion

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        loss_value = self.ssim(img1, img2)
        return loss_value

In [None]:
def get_person_x_y(param):
    """Extract the person's integer x and y position from a parameters text.

    Looks for ``person pose (x,y,z,rot x, rot y, rot z) = `` and parses the first
    two values after it as integers. Values may be separated by whitespace
    and/or commas.

    Args:
        param: Full contents of a ``*_Parameters.txt`` file.

    Returns:
        ``(x, y)`` as ints, or the sentinel ``(11, 11)`` when the pose entry is
        absent (no person in the picture, so no additional loss).

    Raises:
        ValueError: if the pose entry is present but fewer than two values follow
            it, or the values are not integers.
    """
    marker = "person pose (x,y,z,rot x, rot y, rot z) = "
    pose_start_idx = param.find(marker)
    if pose_start_idx == -1:
        # No person in the picture = no additional loss
        return 11, 11

    values_str = param[pose_start_idx + len(marker):].strip()
    tokens = values_str.replace(",", " ").split()[:2]
    if len(tokens) < 2:
        raise ValueError(
            f"Expected at least two pose values after {marker!r}, got {values_str!r}")

    x_value, y_value = (int(token) for token in tokens)
    return x_value, y_value

In [None]:
class SecondaryLoss(nn.Module):
    """SSIM-based loss restricted to a square patch around the person.

    For every sample whose parameters contain a person pose, a
    ``crop_size x crop_size`` patch is cropped from prediction and target with
    ``top = center - pixels_per_unit * x - crop_size // 2`` and
    ``left = center - pixels_per_unit * y - crop_size // 2``, where ``(x, y)``
    come from ``get_person_x_y``. The wrapped ``SSIM`` criterion is applied to the
    stacked patches and the result is multiplied by ``scale_factor``. If no sample
    in the batch contains a person, a zero scalar is returned.

    Samples are skipped when ``x == 11`` or ``y == 11`` (the "no person" sentinel
    of ``get_person_x_y``), so a real person at coordinate 11 is also ignored.

    Args:
        window_size, reduction, max_val: forwarded to ``SSIM``.
        scale_factor: Multiplier applied to the patch loss.
        center: Pixel coordinate of pose origin (default 256).
        pixels_per_unit: Pixels per pose unit (default 17).
        crop_size: Side length of the square patch in pixels (default 50).
    """

    def __init__(self, window_size: int = 3, reduction: str = 'mean', max_val: float = 1.0,
                 scale_factor: int = 10, center: int = 256, pixels_per_unit: int = 17,
                 crop_size: int = 50):
        super().__init__()
        self.scale_factor = scale_factor
        self.center = center
        self.pixels_per_unit = pixels_per_unit
        self.crop_size = crop_size
        self.ssim = SSIM(window_size, reduction, max_val)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor, params) -> torch.Tensor:
        """Patch loss for batch ``img1`` (prediction) vs ``img2`` (target); ``params`` holds one text per sample."""
        half = self.crop_size // 2
        cropped_out = []
        cropped_gt = []
        for i, (x_val, y_val) in enumerate(get_person_x_y(p) for p in params):
            if x_val == 11 or y_val == 11:
                continue  # sentinel: no person in this sample
            top = self.center - self.pixels_per_unit * x_val - half
            left = self.center - self.pixels_per_unit * y_val - half
            cropped_out.append(TF.crop(img1[i], top, left, self.crop_size, self.crop_size))
            cropped_gt.append(TF.crop(img2[i], top, left, self.crop_size, self.crop_size))

        if not cropped_out:
            # No person anywhere in the batch: contribute nothing instead of letting
            # torch.stack raise on an empty list and abort training.
            return img1.new_zeros(())

        # Stack the patches into batches again and score them together.
        return self.ssim(torch.stack(cropped_out), torch.stack(cropped_gt)) * self.scale_factor

In [None]:
from torch.optim.lr_scheduler import StepLR

# Create the model and move it to the GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetLike().to(device)

# Full-image SSIM loss plus a person-centred SSIM term (training only).
ssim_loss_fn = SSIMLoss()
secondary_loss_fn = SecondaryLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00015)
# Halve the learning rate every 3 epochs.
scheduler = StepLR(optimizer, step_size=3, gamma=0.5)

num_epochs = 4
# Checkpoints are written whenever the validation loss improves. Create the folder
# up front so torch.save cannot fail after a full epoch of training.
models_dir = os.path.join("cnn_test_input", "models")
os.makedirs(models_dir, exist_ok=True)

best_loss = float('inf')
loss_train = []
loss_val = []
for epoch in tqdm(range(num_epochs)):
    # ---- training ----
    model.train()
    running_loss_train = 0.0
    for inputs, targets, params in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs} Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)

        loss = ssim_loss_fn(outputs, targets)
        secondary_loss = secondary_loss_fn(outputs, targets, params)
        # Sum the losses and backpropagate once:
        # https://stackoverflow.com/questions/53994625/how-can-i-process-multi-loss-in-pytorch
        full_loss = loss + secondary_loss

        optimizer.zero_grad()
        full_loss.backward()
        optimizer.step()

        # Only the full-image SSIM term is accumulated (not full_loss), which keeps
        # it comparable with the validation loss below.
        running_loss_train += loss.item()

    scheduler.step()

    # ---- validation ----
    model.eval()
    running_loss_val = 0.0
    with torch.no_grad():
        for inputs, targets, params in tqdm(val_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs} Validation"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            running_loss_val += ssim_loss_fn(outputs, targets).item()

    # Average loss per batch for the epoch
    average_loss_train = running_loss_train / len(train_dataloader)
    average_loss_val = running_loss_val / len(val_dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_loss_train}  Validation Loss {average_loss_val}")

    loss_train.append(average_loss_train)
    loss_val.append(average_loss_val)

    if average_loss_val < best_loss:
        best_loss = average_loss_val
        # Save the model with the best validation loss so far
        torch.save(model.state_dict(), os.path.join(models_dir, f"model{epoch + 1}.pth"))

epoch_x = np.arange(1, len(loss_train) + 1)
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(epoch_x, loss_train, marker='o', label="Training Loss")
ax.plot(epoch_x, loss_val, marker='o', label="Validation Loss")
ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
ax.set(xlabel="Epochs", ylabel="Loss", title="SSIM loss per epoch")
ax.legend()
plt.show()

In [None]:
# Reload a saved checkpoint for inference.
# NOTE(review): the training cell above (num_epochs = 4) writes at most
# model1.pth..model4.pth, so model10.pth must come from an earlier, longer run.
# Confirm this is the intended checkpoint.
saved_model_path = "cnn_test_input/models/model10.pth"
model = UNetLike()
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
state_dict = torch.load(saved_model_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

<h3> Known issue in the next cell: if you have time, please look into displaying a single image from a forward pass. </h3>

In [None]:
# Run the model on one shuffled batch of the validation set for visual inspection.
# (The previous version referenced an undefined `dataset`.)
batch_size = 64
dataloader_single = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

model.eval()
input_images, target_images, _ = next(iter(dataloader_single))
with torch.no_grad():
    # Bring outputs back to the CPU without grad tracking so that
    # TF.to_pil_image can convert them in the plotting cells.
    outputs = model(input_images.to(device)).cpu()

print(input_images.shape)
print(target_images.shape)
print(outputs.shape)

In [None]:
# Input, target and prediction for every sample in the batch.
# Each input stacks six focal-plane channels, which TF.to_pil_image cannot render
# as one image, so only the first focal plane (channel 0) is shown.
for i in range(len(input_images)):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        ("Input Image: Focal Length 0", input_images[i][0]),
        ("Target Image", target_images[i]),
        # detach/cpu so this also works if outputs is on the GPU or tracks gradients
        ("Output", outputs[i].detach().cpu()),
    )
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(TF.to_pil_image(image), cmap="gray")
        ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per sample

In [None]:
# Target vs. prediction for every sample in the batch, labelled by batch index.
for i in range(len(input_images)):
    target_image_pil = TF.to_pil_image(target_images[i])
    # detach/cpu so this also works if outputs is on the GPU or tracks gradients
    output_image_pil = TF.to_pil_image(outputs[i].detach().cpu())
    print(f'idx: {i}')

    fig, (ax_target, ax_output) = plt.subplots(1, 2, figsize=(6, 4))

    ax_target.set_title("Target Image")
    ax_target.imshow(target_image_pil, cmap="gray")
    ax_target.axis('off')

    ax_output.set_title("Output")
    ax_output.imshow(output_image_pil, cmap="gray")
    ax_output.axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per sample

In [None]:
# Convert tensors to PIL images for visualization
def get_images(input_im, target_im, out_im, idx):
    """Plot the six focal-plane inputs, the target and the model output of one sample.

    Args:
        input_im: Batch of stacked inputs, indexed as ``input_im[idx][channel]`` for
            the six focal-plane channels.
        target_im: Batch of target images.
        out_im: Batch of model outputs (detached and moved to the CPU before plotting).
        idx: Index of the sample within the batch.
    """
    focal_lengths = ('0', '0.5', '1', '1.5', '2', '2.5')

    plt.figure(figsize=(15, 10))
    for channel, focal_length in enumerate(focal_lengths):
        plt.subplot(3, 3, channel + 1)
        plt.title(f"Input Image: Focal Length {focal_length}")
        plt.imshow(TF.to_pil_image(input_im[idx][channel]), cmap="gray")
        plt.axis('off')

    plt.subplot(3, 3, 7)
    plt.title("Target Image")
    plt.imshow(TF.to_pil_image(target_im[idx]), cmap="gray")
    plt.axis('off')

    plt.subplot(3, 3, 8)
    plt.title("Output Image")
    plt.imshow(TF.to_pil_image(out_im[idx].detach().cpu()), cmap="gray")
    plt.axis('off')

    plt.show()

# Inspect sample 18 of the batch (requires a batch of at least 19 samples).
get_images(input_im=input_images, target_im=target_images, out_im=outputs, idx=18)