In [None]:
from pair_dataset import PairDataset

from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Sample directory holding the two RTSTRUCT DICOM files being compared.
DATASET_PATH = "dataloader/data/full/SAMPLE_001"

# The two structure-set files that make up the pair.
rtstruct_path1 = (
    DATASET_PATH + "/RS.1.2.246.352.221.53086809173815688567595866456863246500.dcm"
)
rtstruct_path2 = (
    DATASET_PATH + "/RS.1.2.246.352.221.46272062591570509005209218152822185346.dcm"
)

# Side length (pixels) of the square slices requested from PairDataset.
# The plotting cell derives its row cutoff from this value.
size = 1024

# Build the dataset exactly once (constructing it loads DICOM data).
# NOTE(review): presumably the first path feeds batch["item1"] and the second
# batch["item2"], and (size, size) is the (height, width) slices are resized
# to -- confirm against PairDataset.
dataset = PairDataset(rtstruct_path2, rtstruct_path1, (size, size))

In [None]:
# Compare the binarised CT slices of the two paired items, batch by batch,
# and plot each pair with their pixel-wise differences overlaid in red.
BATCH_SIZE = 4
BINARY_THRESHOLD = 0.2  # pixels strictly above this value count as "on"
REFERENCE_SIZE = 128  # resolution at which REFERENCE_Y_CUTOFF was chosen
REFERENCE_Y_CUTOFF = 87  # rows at/after this index are blanked (at 128 px)
OVERLAY_ALPHA = 0.6  # opacity of the red difference overlay
MAX_BATCHES = None  # set to an int to preview only the first few batches

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

# Scale the cutoff row to the current resolution. Multiplying before the floor
# division keeps the scaling proportional for sizes that are not multiples of
# REFERENCE_SIZE (int(size / 128) would truncate e.g. 200 / 128 down to 1).
y_cutoff = REFERENCE_Y_CUTOFF * size // REFERENCE_SIZE


def _binarize_slice(ct_slice, cutoff_row, threshold=BINARY_THRESHOLD):
    """Return a 0/1 integer mask of ``ct_slice`` with rows >= ``cutoff_row`` cleared.

    ``np.where`` allocates a new array, so the source buffer (which
    ``Tensor.numpy()`` shares with the batch tensor) is never modified.
    """
    mask = np.where(ct_slice > threshold, 1, 0)
    mask[cutoff_row:, :] = 0
    return mask


def _diff_ratio(num_diff, num_white1, num_white2):
    """Differing pixels divided by the mean "on" count of both masks; 0.0 if both are empty."""
    mean_white = (num_white1 + num_white2) / 2
    return num_diff / mean_white if mean_white > 0 else 0.0


def _red_overlay(diff_mask, alpha=OVERLAY_ALPHA):
    """RGBA image that is red where ``diff_mask`` is set and fully transparent elsewhere.

    A per-pixel alpha channel avoids dimming the background image, which a
    uniform ``imshow(..., alpha=...)`` of an RGB array with black pixels would do.
    """
    overlay = np.zeros(diff_mask.shape + (4,), dtype=float)
    overlay[..., 0] = 1.0
    overlay[..., 3] = diff_mask * alpha
    return overlay


def _scalar(value):
    """Unwrap 0-d tensors / numpy scalars into plain Python numbers for readable labels."""
    if hasattr(value, "item") and getattr(value, "ndim", None) == 0:
        return value.item()
    return value


for batch_idx, batch in enumerate(dataloader):
    if MAX_BATCHES is not None and batch_idx >= MAX_BATCHES:
        break

    ct_batch1 = batch["item1"]["ct"]
    ct_batch2 = batch["item2"]["ct"]
    z_positions = batch["item1"]["z_position"]

    print(f"1 Batch shape: {ct_batch1.shape}")
    print(f"2 Batch shape: {ct_batch2.shape}")

    # Assumes ct tensors are (batch, channel, H, W); index 0 takes the first
    # channel -- TODO confirm against PairDataset.
    for i in range(ct_batch1.shape[0]):
        ct_img1 = _binarize_slice(ct_batch1[i, 0].numpy(), y_cutoff)
        ct_img2 = _binarize_slice(ct_batch2[i, 0].numpy(), y_cutoff)
        z_pos = _scalar(z_positions[i])

        diff_mask = (ct_img1 != ct_img2).astype(np.uint8)

        num_diff_pixels = int(diff_mask.sum())
        num_white_pixels_img1 = int(ct_img1.sum())  # masks are 0/1, so sum == count of 1s
        num_white_pixels_img2 = int(ct_img2.sum())
        diff_ratio = _diff_ratio(
            num_diff_pixels, num_white_pixels_img1, num_white_pixels_img2
        )

        print(
            f"Item {i} - Z: {z_pos} | Diff Pixels: {num_diff_pixels}, "
            f"White Pixels (img1): {num_white_pixels_img1}, "
            f"White Pixels (img2): {num_white_pixels_img2}, Ratio: {diff_ratio:.4f}"
        )

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        # Fixed vmin/vmax so an all-0 or all-1 mask is not auto-scaled to black.
        for ax, img, label in ((axs[0], ct_img1, "Item1"), (axs[1], ct_img2, "Item2")):
            ax.imshow(img, cmap="gray", vmin=0, vmax=1)
            ax.set_title(f"{label} - Z: {z_pos} {y_cutoff}")
            ax.axis("off")

        axs[2].imshow(ct_img1, cmap="gray", vmin=0, vmax=1)
        axs[2].imshow(_red_overlay(diff_mask))
        axs[2].set_title(f"Differences (red) Ratio: {diff_ratio:.4f} ")
        axs[2].axis("off")

        fig.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure; this loop can create many of them