# IoU

In [1]:
import os
import numpy as np
import pickle
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as mpatches
import torch

from torchmetrics import JaccardIndex
from matplotlib.colors import ListedColormap
from tqdm import tqdm

In [2]:
# Load the persisted train/test split (arrays of build-layer indexes).
train_indexes, test_indexes = (
    np.load("./train_indexes.npy"),
    np.load("./test_indexes.npy"),
)

for message in (
    f"train_indexes ({len(train_indexes)}): {train_indexes}",
    f"test_indexes ({len(test_indexes)}): {test_indexes}",
):
    print(message)

train_indexes (120): [123  89 103  10 157   2 158 121  72 109  19  65  91  66  33  25 102  64
  59 118 141  60   7  14  70  79 130  53 115 122  96  63  97 136  36 135
  42  56  15  27 129  93 107  24   4  31  40  95  18 144  67  92 119 154
  78  84  62  87 142   6  76 146  51 155 124 156 143  20  61 101 106 131
  43  90 153  44 127 151  45  21  46 140  16 149 108 152 120  11 125  13
  28  82  86  69  88  39 132 114 150  98  54  30 134 116   9   8  37  38
 111 139  81  71  57  50  26  49   1  75  41 105]
test_indexes (39): [148  58  55  22 110  68 147 100  34  32 133  48  80  52  94   5 126 117
  73  83  17 104 113  29  85 138  99  12 128  74   0 112  35  77  23 137
  47   3 145]


In [3]:
# Epochs whose pickled predictions (./predictions/<epoch>.p) are evaluated below.
EPOCHS = [
    50,
]

In [4]:
# Compute, for every epoch in EPOCHS, the IoU of each test build layer against its
# ground truth: once over both classes and once with class 0 ignored
# (`ignore_index=0`). Per-epoch IoU lists are stored in ascending build-layer
# order, i.e. aligned with `sorted(test_indexes)`; the raw per-epoch predictions
# are kept in `epochs_predictions` in the original `test_indexes` order.
jaccard = JaccardIndex(task="multiclass", num_classes=2)
jaccard_ignore = JaccardIndex(task="multiclass", num_classes=2, ignore_index=0)

os.makedirs("iou", exist_ok=True)

epochs_iou = []
epochs_iou_ignore = []
epochs_predictions = []

# NOTE(review): pickle.load can execute arbitrary code; only load trusted,
# locally produced files here.
with open("./targets.p", "rb") as targets_file:
    targets = pickle.load(targets_file)

test_length = len(test_indexes)


def _sort_by_test_index(values):
    """Reorder `values` (aligned with `test_indexes`) by ascending build-layer index.

    Sorting on the index alone avoids ever comparing the IoU tensors themselves.
    """
    order = np.argsort(test_indexes, kind="stable")
    return [values[position] for position in order]


for epoch in EPOCHS:
    os.makedirs(f"iou/{epoch}", exist_ok=True)

    # The predictions file depends only on the epoch, so it is loaded once here
    # rather than once per test layer. This also guarantees `predictions` is
    # defined below even when there are no test layers.
    with open(f"./predictions/{epoch}.p", "rb") as predictions_file:
        predictions = pickle.load(predictions_file)

    build_layers_iou = []
    build_layers_iou_ignore = []
    for index in range(test_length):
        prediction = np.transpose(predictions[index].squeeze(), (1, 0, 2))
        target = np.transpose(targets[index].squeeze(), (1, 0, 2))

        prediction_tensor = torch.tensor(prediction)
        target_tensor = torch.tensor(target)

        build_layers_iou.append(jaccard(prediction_tensor, target_tensor))
        build_layers_iou_ignore.append(jaccard_ignore(prediction_tensor, target_tensor))

    epochs_iou.append(_sort_by_test_index(build_layers_iou))
    epochs_iou_ignore.append(_sort_by_test_index(build_layers_iou_ignore))
    epochs_predictions.append(predictions)

  build_layer_iou = jaccard(torch.tensor(prediction), torch.tensor(target))
  build_layer_iou_ignore = jaccard_ignore(torch.tensor(prediction), torch.tensor(target))


In [None]:
# Bar charts of the per-layer test IoU for every epoch, first over both classes,
# then with the background class ignored. Figures are saved under iou/<epoch>/.
sorted_test_indexes = sorted(test_indexes)


def plot_epoch_iou(epoch, iou_values, title_suffix, filename):
    """Print min/mean/max of `iou_values`, then save and show a bar chart of them.

    Args:
        epoch: epoch number, used in the printed summary, title and output path.
        iou_values: per-layer IoUs ordered like `sorted_test_indexes`.
        title_suffix: text appended to the plot title (may be empty).
        filename: file name written inside iou/<epoch>/.
    """
    # Suffixed names so the built-in min/max are not shadowed in the kernel.
    min_iou, mean_iou, max_iou = np.min(iou_values), np.mean(iou_values), np.max(iou_values)
    print(f"Epoch {epoch} -> min: {min_iou}, mean: {mean_iou}, max: {max_iou}")

    fig, ax = plt.subplots(figsize=(15, 5))
    ax.bar(sorted_test_indexes, iou_values, color='red', width=1.8)

    ax.set_xlabel('Build Layer Index')
    ax.set_ylabel('IoU')
    ax.set_title(f"Intersection over Union (IoU) Epoch {epoch}{title_suffix}")

    fig.savefig(f"iou/{epoch}/{filename}")
    plt.show()


for epoch, epoch_iou in zip(EPOCHS, epochs_iou):
    plot_epoch_iou(epoch, epoch_iou, "", "test.png")

for epoch, epoch_iou_ignore in zip(EPOCHS, epochs_iou_ignore):
    plot_epoch_iou(epoch, epoch_iou_ignore, " (Ignore Background)", "test_ignore_background.png")


In [None]:
# Render one 2x2 GIF per epoch and test build layer (prediction, ground truth,
# overlap, error), animated over the first axis of the (2, 0, 1)-transposed volumes.

# `epochs_iou` / `epochs_iou_ignore` hold per-layer IoUs sorted by build-layer index,
# while `index` below walks `test_indexes` in its original, unsorted order. Map each
# original position to its position in the sorted lists, so the title reports the IoU
# of the layer actually being rendered. This assumes the indexes in `test_indexes`
# are unique.
layer_sort_order = np.argsort(test_indexes, kind="stable")
sorted_position = np.empty_like(layer_sort_order)
sorted_position[layer_sort_order] = np.arange(len(layer_sort_order))

# Copy of viridis whose "bad" (NaN) entries are fully transparent. NaN is used
# below to hide pixels in the overlap and error panels.
cmap = plt.get_cmap('viridis')
cmap_zero_transparent = ListedColormap(cmap(np.arange(cmap.N)))
cmap_zero_transparent.set_bad(alpha=0)

for epoch_index, epoch in enumerate(EPOCHS):
    predictions = epochs_predictions[epoch_index]

    for index in tqdm(range(test_length)):
        prediction = np.transpose(predictions[index].squeeze(), (2, 0, 1))
        target = np.transpose(targets[index].squeeze(), (2, 0, 1))

        # Whole-layer IoUs (both classes / background ignored) for this layer.
        iou_build_layer = epochs_iou[epoch_index][sorted_position[index]]
        iou_build_layer_ignore = epochs_iou_ignore[epoch_index][sorted_position[index]]

        fig, axes = plt.subplots(2, 2, figsize=(8, 7))
        fig.subplots_adjust(top=0.85)

        def animate(frame_index):
            """Redraw all four panels for frame `frame_index` of the current layer."""
            prediction_voxel = torch.tensor(prediction[frame_index])
            target_voxel = torch.tensor(target[frame_index])

            iou_voxel = jaccard(prediction_voxel, target_voxel)
            iou_voxel_ignore = jaccard_ignore(prediction_voxel, target_voxel)

            fig.suptitle(
                f"Index: {index}, Build Layer: {test_indexes[index]}, "
                f"IOU: {iou_build_layer:.3f} {iou_build_layer_ignore:.3f}, \n"
                f"Frame: {frame_index}, Frame IOU: {iou_voxel:.3f} {iou_voxel_ignore:.3f}",
                fontsize=18,
            )

            axes[0][0].clear()
            axes[0][0].set_title("Prediction", fontsize=12)
            axes[0][0].imshow(prediction[frame_index], vmin=0, vmax=2)

            axes[0][1].clear()
            axes[0][1].set_title("Ground Truth", fontsize=12)
            axes[0][1].imshow(target[frame_index], vmin=0, vmax=3)

            # Bottom-left: target + 2 * prediction is 1 for ground truth only,
            # 2 for prediction only and 3 for both (assuming 0/1 labels). Zero
            # becomes NaN so the background is transparent.
            axes[1][0].clear()
            axes[1][0].set_title("Intersection", fontsize=12)

            scaled_prediction = prediction[frame_index] * 2
            values = target[frame_index] + scaled_prediction
            values_nan = np.where(values == 0, np.nan, values)

            img_2 = axes[1][0].imshow(values_nan, vmin=0, vmax=3, cmap=cmap_zero_transparent)

            legend_elements = [
                mpatches.Patch(color=img_2.cmap(0.33), label='Ground Truth'),
                mpatches.Patch(color=img_2.cmap(0.66), label='Prediction'),
                mpatches.Patch(color=img_2.cmap(1.0), label='Intersection'),
            ]
            axes[1][0].legend(handles=legend_elements, loc="upper right")

            # Bottom-right: target - prediction + 1 is 0 where only the prediction
            # is set, 1 where both agree (hidden as NaN) and 2 where only the ground
            # truth is set (assuming 0/1 labels). With vmin=0 and vmax=2, value 2 maps
            # to cmap(1.0) (ground truth only) and value 0 maps to cmap(0.0)
            # (prediction only).
            error = target[frame_index] - prediction[frame_index] + 1
            error_nan = np.where(error == 1, np.nan, error)

            axes[1][1].clear()
            axes[1][1].set_title("Error", fontsize=12)

            img_3 = axes[1][1].imshow(error_nan, vmin=0, vmax=2, cmap=cmap_zero_transparent)

            legend_elements = [
                mpatches.Patch(color=img_3.cmap(1.0), label='Ground Truth'),
                mpatches.Patch(color=img_3.cmap(0.0), label='Prediction'),
            ]
            axes[1][1].legend(handles=legend_elements, loc="upper right")

        anim = animation.FuncAnimation(fig, animate, frames=len(prediction), interval=1000)
        anim.save(f"iou/{epoch}/{index}_layer_{test_indexes[index]}_2x2.gif", writer="imagemagick")

        plt.close(fig)
