# Rendering Uncertainty Test

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from dataset import get_rays
from rendering import render_uncert
from ml_helpers import test_uncert

In [None]:
# Render on the GPU when one is available; fall back to CPU so the notebook
# still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Presumably the near/far ray-sampling bounds passed to test_uncert — confirm.
tn = 1
tf = 10
test_o, test_d, test_target_px_values = get_rays('datasets/monkey_3_15aug', mode='test')
# torch.load unpickles the checkpoint, so only load trusted files.
# map_location lets checkpoints saved on a GPU be loaded onto `device`
# (including a CPU-only machine).
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects pickled
# nn.Module objects; pass weights_only=False there if loading fails.
model_full = torch.load('nerf_models/monkey_3_big.pth', map_location=device).to(device)
model_aug = torch.load('nerf_models/monkey_3_big_aug.pth', map_location=device).to(device)

### Compute the per-pixel entropy (uncertainty) of a rendered test view

In [None]:

test_img_idx = 2

# Ray origins/directions of the chosen test view, on the render device as float32.
view_origins = torch.from_numpy(test_o[test_img_idx]).to(device=device, dtype=torch.float32)
view_dirs = torch.from_numpy(test_d[test_img_idx]).to(device=device, dtype=torch.float32)
# Ground truth for the same view, reshaped to a 400x400 RGB image.
view_target = test_target_px_values[test_img_idx].reshape(400, 400, 3)

img, mse, psnr = test_uncert(
    model_full,
    view_origins,
    view_dirs,
    tn,
    tf,
    nb_bins=100,
    chunk_size=20,
    target=view_target,
)

# Show the per-pixel uncertainty map with a colour scale.
plt.imshow(img, cmap='inferno')
plt.colorbar();

In [None]:
def compute_view_entropy(uncert_img):
    '''Compute the total entropy of a view/rendered image.

    Sums the per-pixel uncertainty values of ``uncert_img`` over every
    element, whatever its shape.

    Args:
        uncert_img: array-like of per-pixel uncertainty values (for example
            the image returned by ``test_uncert``).

    Returns:
        float: the sum over all pixels; 0.0 for an empty image.
    '''
    # Vectorised float64 sum. The previous per-element Python loop was slow
    # on a 400x400 image. It also returned a 1-element array rather than a
    # scalar, and could accumulate in float32.
    return float(np.asarray(uncert_img, dtype=np.float64).sum())
    
# Total uncertainty of the view rendered above (`img` from test_uncert); as the
# last expression in the cell, its value is shown as the cell output.
compute_view_entropy(img)

### Compare the entropy of an object between two trained NeRFs at specific views
Note: the object mask is currently built from color — specifically dark pink, as in the monkey datasets. Adjust the HSV thresholds to suit your data.

In [None]:
test_img_idx = 3
import cv2

# Build this view's rays and ground truth once and share them between both models.
# Assumes test_uncert does not modify its ray/target inputs in place — TODO confirm.
rays_o = torch.from_numpy(test_o[test_img_idx]).to(device).float()
rays_d = torch.from_numpy(test_d[test_img_idx]).to(device).float()
target = test_target_px_values[test_img_idx].reshape(400, 400, 3)

img_aug, mse, psnr = test_uncert(model_aug, rays_o, rays_d,
                                 tn, tf, nb_bins=150, chunk_size=20, target=target)

img_full, mse, psnr = test_uncert(model_full, rays_o, rays_d,
                                  tn, tf, nb_bins=150, chunk_size=20, target=target)

image = target
# Convert the ground truth to uint8 before the OpenCV colour conversion below.
if image.dtype in (np.float32, np.float64):
    print(f"Converting from {image.dtype} to uint8.")
    # Min-max stretch over the whole image (all channels jointly) to 0-255.
    image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
elif image.dtype == np.uint16:
    print("Converting from uint16 to uint8.")
    # Keep the high byte of each 16-bit value.
    image = (image / 256).astype(np.uint8)

# NOTE(review): the ground truth is displayed below as RGB, but it is converted
# with COLOR_BGR2HSV, so red and blue are swapped before the hue test. The
# thresholds may have been tuned around this; verify before switching to
# cv2.COLOR_RGB2HSV.
img_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# OpenCV 8-bit HSV bounds (H in [0, 179], S and V in [0, 255]).
lower_dark_pink = np.array([130, 50, 50])
upper_dark_pink = np.array([170, 255, 255])

# The mask is 255 inside the dark-pink range and 0 elsewhere. Uncertainty
# outside the object is zeroed so that only object pixels are compared.
mask = cv2.inRange(img_hsv, lower_dark_pink, upper_dark_pink)
img_aug = cv2.bitwise_and(img_aug, img_aug, mask=mask)
img_full = cv2.bitwise_and(img_full, img_full, mask=mask)

# Preview the mask; plt.subplots below opens a separate figure for the comparison.
plt.imshow(mask, cmap='gray')
# cv2.subtract saturates for integer dtypes instead of wrapping around,
# and is identical to `-` for float images.
diff = cv2.subtract(img_aug, img_full)

# Ground truth next to the masked uncertainty difference (augmented - full).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(target)
ax1.set_title('Ground truth image')
# Fixed colour range; negative differences clip to the bottom of the colormap.
im2 = ax2.imshow(diff, cmap='inferno', vmin=0, vmax=3)
ax2.set_title('Entropy difference')
fig.colorbar(im2, ax=ax2);

# Compare entropy for N images
Generates one figure per test view comparing the entropy of the two trained NeRF models.

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.image import imread
import cv2

# Create the output directory if it doesn't exist
output_dir = 'output_change_figures/monkey_big_4'
os.makedirs(output_dir, exist_ok=True)

# Views 0..num_views-1 are rendered; assumes the test split has at least this many.
num_views = 10

# OpenCV 8-bit HSV bounds for the "dark pink" object mask
# (H in [0, 179], S and V in [0, 255]).
lower_dark_pink = np.array([130, 50, 50])
upper_dark_pink = np.array([170, 255, 255])

# The comparison panel is identical for every view, so read it once.
third_plot_img = imread('figures/anomaly.png')


def _to_uint8(image):
    """Return *image* converted to uint8 for OpenCV colour conversion (other dtypes pass through)."""
    if image.dtype in (np.float32, np.float64):
        print(f"Converting from {image.dtype} to uint8.")
        # Min-max stretch over the whole image (all channels jointly) to 0-255.
        return cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    if image.dtype == np.uint16:
        print("Converting from uint16 to uint8.")
        # Keep the high byte of each 16-bit value.
        return (image / 256).astype(np.uint8)
    return image


def _dark_pink_mask(image):
    """Return a mask that is 255 where *image* falls in the dark-pink HSV range and 0 elsewhere."""
    # NOTE(review): the ground truth is displayed as RGB but converted with
    # COLOR_BGR2HSV (red/blue swapped). The thresholds may have been tuned
    # around this; verify before switching to cv2.COLOR_RGB2HSV.
    img_hsv = cv2.cvtColor(_to_uint8(image), cv2.COLOR_BGR2HSV)
    return cv2.inRange(img_hsv, lower_dark_pink, upper_dark_pink)


for test_img_idx in range(num_views):
    # Build this view's rays and ground truth once and share them between both models.
    # Assumes test_uncert does not modify its inputs in place — TODO confirm.
    rays_o = torch.from_numpy(test_o[test_img_idx]).to(device).float()
    rays_d = torch.from_numpy(test_d[test_img_idx]).to(device).float()
    target = test_target_px_values[test_img_idx].reshape(400, 400, 3)

    img_aug, mse, psnr = test_uncert(model_aug, rays_o, rays_d,
                                     tn, tf, nb_bins=150, chunk_size=20, target=target)
    img_full, mse, psnr = test_uncert(model_full, rays_o, rays_d,
                                      tn, tf, nb_bins=150, chunk_size=20, target=target)

    # Zero the uncertainty outside the object so only its pixels are compared.
    mask = _dark_pink_mask(target)
    img_aug = cv2.bitwise_and(img_aug, img_aug, mask=mask)
    img_full = cv2.bitwise_and(img_full, img_full, mask=mask)

    # Augmented minus full uncertainty; cv2.subtract saturates for integer dtypes.
    diff = cv2.subtract(img_aug, img_full)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    ax1.imshow(target)
    ax1.set_title('Ground truth image')

    # Fixed colour range so saved figures are comparable across views.
    im2 = ax2.imshow(diff, cmap='inferno', vmin=0, vmax=2.5)
    ax2.set_title('Entropy difference')
    fig.colorbar(im2, ax=ax2)

    ax3.imshow(third_plot_img)
    ax3.set_title('Anomaly view for comparison')
    ax3.axis('off')

    fig.savefig(os.path.join(output_dir, f'image_{test_img_idx}.png'))
    plt.close(fig)  # Close the figure so one is not left open per view

print("Images saved successfully.")
