In [None]:
# This notebook computes the MSE
# between the predicted shape and the actual shape.
# Steps involved:
# 1. Fold the original image horizontally at its center
#    (mirror the bottom half onto the top half and add the two halves,
#    then set every pixel with a value greater than 0 to 1).
# 2. Fold the predicted image the same way
#    (mirror the bottom half onto the top half and add the two halves,
#    then normalize the image).
# 3. Compute the MSE between the two folded images.

In [None]:
# Import necessary libraries

import matplotlib.pyplot as plt
import numpy as np

In [None]:
## Test set: load the shape masks and turn them into opacity maps
test_shape_dir = '../data/data_npy/shape_npy/shape_filled8.npy'
test_shape = np.load(test_shape_dir)

# Rescale so the largest pixel value becomes exactly 1
# (NOTE(review): assumes the maximum is > 0 — a zero max would produce NaNs).
test_shape = test_shape / test_shape.max()

# Capture both masks BEFORE writing, so the swap below cannot re-flip pixels.
test_shape_where_0 = np.nonzero(test_shape == 0)  # pixels that will become the shape
test_shape_where_1 = np.nonzero(test_shape == 1)  # pixels that will become background

# Swap 0 <-> 1: 1 represents the shape (1 opacity), 0 the background (0 opacity).
# Values other than exactly 0 or 1 are left unchanged.
test_shape[test_shape_where_0] = 1
test_shape[test_shape_where_1] = 0

In [None]:
# Verification
# Plot - Test LCs
num = 3
fig,ax=plt.subplots(num,3, figsize=(4,3), gridspec_kw={ 'width_ratios': [1,1,1],
        'wspace': 0.2,'hspace': 0.4})
plt.rcParams['figure.dpi'] = 400

ax[0][1].set_title('Shape',size=10)

# advance = 60

i = 0
for i in np.arange(0,num):
    k = np.random.randint(0, len(test_shape)-1)
    ax[i][0].tick_params(left = False, right = False , labelleft = False ,labelbottom = False, bottom = False)
    img = ax[i][0].imshow(test_shape[k],cmap='inferno')
    plt.colorbar(img)
    
    k = np.random.randint(0, len(test_shape)-1)
    ax[i][1].tick_params(left = False, right = False , labelleft = False ,labelbottom = False, bottom = False)
    img = ax[i][1].imshow(test_shape[k],cmap='inferno')
    plt.colorbar(img)
    
    k = np.random.randint(0, len(test_shape)-1)
    ax[i][2].tick_params(left = False, right = False , labelleft = False ,labelbottom = False, bottom = False)
    img = ax[i][2].imshow(test_shape[k],cmap='inferno')
    plt.colorbar(img)

    i = i + 1

In [None]:
def fold_original_image(org_image):
    """
    Fold an image in half about its horizontal centre line and binarize the result.

    The bottom half of 'org_image' is mirrored (flipped upside down) onto the top
    half and the two halves are added, so row ``i`` of the output combines rows
    ``i`` and ``H - 1 - i`` of the input (``H`` = number of rows). Every pixel of
    the sum that is greater than 0 is then set to 1.

    When ``H`` is odd, the middle row lies on the fold line. It has no partner
    and is kept once, as the last output row.

    Parameters:
    ----------
    org_image : array_like
        The image to fold, typically a 2D array of pixel values (e.g. an opacity
        map). Rows are indexed by the first axis. The input is not modified.

    Returns:
    -------
    folded_image : numpy.ndarray
        Float array with ``ceil(H / 2)`` rows and the same remaining dimensions
        as 'org_image'. Pixels whose folded sum is > 0 are set to 1. All other
        pixels keep their folded sum, which is 0 for non-negative input.

    Raises:
    ------
    ValueError
        If 'org_image' is a 0-d (scalar) array.

    Example:
    --------
    import numpy as np
    import matplotlib.pyplot as plt
    org_image = np.random.random((100, 200))  # Replace with your input image
    folded_result = fold_original_image(org_image)  # shape (50, 200)
    plt.imshow(folded_result, cmap='gray')
    plt.title('Folded Image')
    plt.show()
    """
    image = np.asarray(org_image)
    if image.ndim == 0:
        raise ValueError("fold_original_image expects an array with at least one dimension")

    n_rows = image.shape[0]
    n_pairs = n_rows // 2          # rows that have a mirror partner
    n_out = n_rows - n_pairs       # ceil(n_rows / 2): includes the middle row when odd

    # Work in floating point so summing two integer pixels cannot overflow
    # (e.g. uint8 128 + 128 would wrap to 0). astype() returns a copy, so the
    # caller's array is never mutated.
    folded_image = image[:n_out].astype(float)

    # Mirror the bottom half onto the top: row i is paired with row n_rows-1-i.
    folded_image[:n_pairs] += image[n_rows - n_pairs:][::-1]

    folded_image[folded_image > 0] = 1
    return folded_image

# Code to test the above function on the first test shape
sample_image = test_shape[0]
fold_test_shape = fold_original_image(sample_image)
print("fold_test_shape.shape = ",fold_test_shape.shape)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))

im = ax[0].imshow(sample_image, cmap='inferno')
ax[0].set_title('Original image')
# Pass ax explicitly so each colorbar is attached to its own panel
# (older matplotlib versions otherwise use the current axes for both).
fig.colorbar(im, ax=ax[0])

im = ax[1].imshow(fold_test_shape, cmap='inferno')
# fold_original_image folds and thresholds (it does not normalize), so label it that way.
ax[1].set_title('Folded output')
fig.colorbar(im, ax=ax[1])

plt.show()