In [None]:
# imports

import os
import cv2 as cv
import numpy as np

import matplotlib.pyplot as plt

from psfyp import PSSolver, PSIntegrator, scale_image, corrupt_images, angular_error, multiplot

In [None]:
# seed NumPy's legacy global RNG up front so stochastic steps later in the
# notebook (presumably the synthetic image corruption) are reproducible
SEED = 28
np.random.seed(SEED)

In [None]:
# PARAMETERS

# Paths are built with os.path.join so the notebook runs on any OS. The trailing
# '' component keeps a trailing separator because later cells build file names
# by plain string concatenation (e.g. f"{IMG_PATH}{name}").
IMG_PATH = os.path.join('data', 'frog', 'Objects', '')

LDIRS_FILE = os.path.join('data', 'frog', 'light_directions.txt')

RESULTS_PATH = os.path.join('results', '')

# exist_ok makes the cell idempotent and avoids a check-then-create race
os.makedirs(RESULTS_PATH, exist_ok=True)

# reconstruction algorithm: 'basic', otherwise the SBL branch is used
ALGORITHM = 'sbl'

# forwarded to the solver / smoother; presumably -1 means "use all cores" — TODO confirm in psfyp
NUM_PROCESSES = -1
# crop applied after loading: 'head', 'tummy', or anything else for the whole image
REGION = 'whole'

# strengths passed to corrupt_images for the synthetic specular and noise layers
SPECULAR_WEIGHT = 0.3
NOISE_WEIGHT = 0.1

# options forwarded to PSSolver.ps_sbl
SBL_LAMBDA = 1.0e-3
SBL_SIGMA = 1.0e6
SBL_MAX_ITERS = 100
SBL_USE_PAPER_ALGORITHM = False

# NOTE(review): SMOOTH is not read by any cell in this notebook; the smoothing
# cell always runs regardless of its value.
SMOOTH = True

# options forwarded to PSIntegrator.smooth_normals
SIGMA_F = 2.0
SIGMA_G = 0.1
WINDOW_SIZE = 2 * int(SIGMA_F)


In [None]:
def write_normals(normals, path):
    """Save a normal map to disk as an 8-bit, 3-channel image.

    Parameters
    ----------
    normals : ndarray
        Array indexed as ``normals[row, col, channel]``; channels 0, 1 and 2 are
        written (presumably the x, y, z normal components — TODO confirm in psfyp).
    path : str
        Destination file. OpenCV chooses the encoder from the extension.

    Notes
    -----
    Values are multiplied by 255, rounded and saturated to [0, 255]. This mirrors
    the saturating conversion OpenCV applies when given non-8-bit input, so any
    negative component is still written as 0.

    Raises
    ------
    OSError
        If ``cv.imwrite`` reports that the file could not be written.
    """
    # OpenCV stores channels in BGR order, so write channels 2, 1, 0
    bgr = normals[:, :, [2, 1, 0]]

    # explicit float -> uint8 conversion (round, then clip) instead of relying
    # on imwrite's implicit one; NaNs become 0
    scaled = np.nan_to_num(255 * bgr, nan=0.0)
    encoded = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    # imwrite signals failure through its return value, not an exception
    if not cv.imwrite(path, encoded):
        raise OSError(f'could not write normal map to {path!r}')

In [None]:
# read images
# Files are read in sorted name order so the stack is deterministic on every
# filesystem (os.listdir order is arbitrary). Image k must correspond to light
# direction k — TODO confirm the file names sort in acquisition order.
images = []
for img_name in sorted(os.listdir(IMG_PATH)):
    img_file = os.path.join(IMG_PATH, img_name)
    new_img = cv.imread(img_file)
    if new_img is None:
        # cv.imread returns None (no exception) for unreadable / non-image files
        raise ValueError(f'could not read image file {img_file!r}')

    # convert images to greyscale; cv.imread returns channels in BGR order
    new_img = cv.cvtColor(new_img, cv.COLOR_BGR2GRAY)

    # scale image to range [0, 1]
    new_img = scale_image(new_img)

    images.append(new_img)

if not images:
    raise ValueError(f'no images found in {IMG_PATH!r}')

# if the images are from the "frog" dataset, crop them to appropriate regions
if REGION == 'head':
    images = [image[10:310, 170:470] for image in images]  # frog head
elif REGION == 'tummy':
    images = [image[280:480, 250:450] for image in images]  # frog tummy

IMAGE_SHAPE = images[0].shape[:2]

# flatten images into an m x n observation matrix: one column per image
O = np.matrix([image.flatten() for image in images]).T

NUM_IMAGES = len(images)

In [None]:
# read light directions
# The file is read as whitespace-separated numbers. The first three non-empty
# lines are taken as the x, y and z components, one light per column.
with open(LDIRS_FILE) as f:
    rows = [[np.float64(comp) for comp in line.split()] for line in f]

# ignore blank lines (e.g. a trailing empty line), which split(' ') used to choke on
rows = [row for row in rows if row]

if len(rows) < 3:
    raise ValueError(
        f'expected x, y and z rows in {LDIRS_FILE!r}, found {len(rows)}'
    )

# one row per light: [x, y, z]; zip stops at the shortest component row
ldirs = np.matrix([
    [x, y, z] for x, y, z in zip(rows[0], rows[1], rows[2])
])

L = ldirs.T

# explicit check rather than `assert`, which is stripped under `python -O`
if len(images) != ldirs.shape[0]:
    raise ValueError('Number of images and lighting directions does not match!')

In [None]:
# create solver
solver = PSSolver(O, L)

# get the mask
mask = solver.mask

print('Reconstructing normals...')

if ALGORITHM == 'basic':
    normals = solver.ps_basic(out_shape=IMAGE_SHAPE)
else:  # sbl
    normals, errors = solver.ps_sbl(
        max_iters=SBL_MAX_ITERS,
        lambda_=SBL_LAMBDA,
        sigma=SBL_SIGMA,
        num_processes=NUM_PROCESSES,
        out_shape=IMAGE_SHAPE,
        use_paper_algorithm=SBL_USE_PAPER_ALGORITHM
    )

    # write one plot of the absolute SBL error variances per input image
    error_dir = os.path.join(RESULTS_PATH, 'sbl-error-variances')
    os.makedirs(error_dir, exist_ok=True)

    for idx in range(NUM_IMAGES):
        # np.absolute already returns a new array, so no defensive copy is needed
        errors_show = np.absolute(errors[:, :, idx])

        # the colour bar uses the imshow image itself; its norm autoscales to the
        # data min/max, so the data no longer has to be redrawn with pcolormesh
        im = plt.imshow(errors_show)
        plt.axis('off')
        # also updates the global default colormap, as the original cell did
        plt.set_cmap('plasma')
        plt.title('Abs. SBL Error Variances', size=20)
        cb = plt.colorbar(im, fraction=0.046, pad=0.04)
        cb.ax.tick_params(labelsize=12)
        cb.set_label('Abs. Variance', rotation=270, labelpad=18, size=12)

        # zero-padded index: the old '{idx:2}' produced names like 'image- 3.png'
        plt.savefig(
            os.path.join(error_dir, f'image-{idx:02d}.png'),
            dpi=plt.gcf().dpi,
            bbox_inches='tight'
        )
        plt.close()

# save normals
write_normals(normals, fr"{RESULTS_PATH}normals.png")

plt.imshow(normals)
plt.show()
plt.close()

In [None]:
# integrate the recovered normals into a surface, export it as OBJ and display it
print('Integrating surface...')

surface_path = f'{RESULTS_PATH}surface.obj'
integrator = PSIntegrator(normals, IMAGE_SHAPE, mask)
integrator.to_obj(surface_path)

integrator.display()

In [None]:
# synthesise corrupted copies of the inputs (specular + noise layers)
print('Corrupting images...')

images_corrupted, aux_corrupted = corrupt_images(
    images,
    normals,
    ldirs,
    specular_weight=SPECULAR_WEIGHT,
    noise_weight=NOISE_WEIGHT,
)

# stack the flattened corrupted images as columns of the observation matrix
O_corrupted = np.matrix([frame.flatten() for frame in images_corrupted]).T

In [None]:
# show one corrupted image next to its specular and noise layers and the clean original
idx = 0

panels = [
    (images_corrupted[idx], 'Corrupted Image'),
    (aux_corrupted[idx]['mask_specular'], 'Specular Layer'),
    (aux_corrupted[idx]['mask_gaussian'], 'Noise Layer'),
    (images[idx], 'Original Image'),
]

multiplot([panel for panel, _ in panels], [title for _, title in panels])

In [None]:
# save the corrupted images with zero-padded indices
corrupted_dir = os.path.join(RESULTS_PATH, 'corrupted-images')
os.makedirs(corrupted_dir, exist_ok=True)

for i in range(NUM_IMAGES):
    # zero-padded index: the old '{i:2}' produced names like 'image-corrupted- 3.png'
    out_file = os.path.join(corrupted_dir, f'image-corrupted-{i:02d}.png')
    # presumably [0, 1] data scaled to 8-bit range; imwrite saturates out-of-range values
    if not cv.imwrite(out_file, 255 * images_corrupted[i]):
        raise OSError(f'could not write corrupted image to {out_file!r}')

In [None]:
# create the corrupted solver, reusing the mask from the clean reconstruction
solver_corrupted = PSSolver(O_corrupted, L)

solver_corrupted.set_mask(use_mask=mask)

print('Reconstructing normals (corrupted)...')

if ALGORITHM == 'basic':
    normals_corrupted = solver_corrupted.ps_basic(out_shape=IMAGE_SHAPE)
else:  # sbl
    normals_corrupted, errors_corrupted = solver_corrupted.ps_sbl(
        max_iters=SBL_MAX_ITERS,
        lambda_=SBL_LAMBDA,
        sigma=SBL_SIGMA,
        num_processes=NUM_PROCESSES,
        out_shape=IMAGE_SHAPE,
        use_paper_algorithm=SBL_USE_PAPER_ALGORITHM
    )

    # write one plot of the absolute SBL error variances per corrupted image.
    # The directory is created here too, so this cell does not silently depend
    # on the clean-reconstruction cell having run first.
    error_dir = os.path.join(RESULTS_PATH, 'sbl-error-variances')
    os.makedirs(error_dir, exist_ok=True)

    for idx in range(NUM_IMAGES):
        errors_show = np.absolute(errors_corrupted[:, :, idx])

        # colour bar built from the imshow image instead of a redundant pcolormesh
        im = plt.imshow(errors_show)
        plt.set_cmap('plasma')
        plt.axis('off')
        plt.title('Abs. SBL Error Variances', size=20)
        cb = plt.colorbar(im, fraction=0.046, pad=0.04)
        cb.ax.tick_params(labelsize=12)
        cb.set_label('Abs. Variance', rotation=270, labelpad=18, size=12)

        # zero-padded index (no embedded spaces in file names)
        plt.savefig(
            os.path.join(error_dir, f'image-corrupted-{idx:02d}.png'),
            dpi=plt.gcf().dpi,
            bbox_inches='tight'
        )
        plt.close()

# save normals
write_normals(normals_corrupted, fr"{RESULTS_PATH}normals-corrupted.png")

plt.imshow(normals_corrupted)
plt.show()
plt.close()

In [None]:
# per-pixel angular difference between the clean and corrupted reconstructions
diff_dir = os.path.join(RESULTS_PATH, 'angular-differences')
os.makedirs(diff_dir, exist_ok=True)

ae = angular_error(normals, normals_corrupted, solver.mask)

plt.set_cmap('viridis')
im = plt.imshow(ae)
# raw strings: '\c' is an invalid escape sequence in a normal string literal.
# NOTE(review): ae.mean() averages every pixel; whether masked-out pixels are
# excluded depends on angular_error's return value — confirm in psfyp.
plt.title(rf'Mean Angular Difference: {ae.mean():.4f}$^\circ$')
plt.axis('off')
# colour bar built from the imshow image instead of a redundant pcolormesh
cb = plt.colorbar(im, fraction=0.046, pad=0.04)
cb.ax.tick_params(labelsize=12)
cb.set_label(r'Difference ($^\circ$)', rotation=270, labelpad=18, size=12)

plt.savefig(
    os.path.join(diff_dir, 'baseline-corrupted.png'),
    dpi=plt.gcf().dpi,
    bbox_inches='tight'
)

plt.show()
plt.close()

In [None]:
# integrate the normals recovered from the corrupted images, export and display them
print('Integrating surface (corrupted)...')

surface_corrupted_path = f'{RESULTS_PATH}surface-corrupted.obj'
integrator_corrupted = PSIntegrator(normals_corrupted, IMAGE_SHAPE, mask)
integrator_corrupted.to_obj(surface_corrupted_path)

integrator_corrupted.display()

In [None]:
# smooth the corrupted normal map with the configured filter parameters
print('Smoothing corrupted normals...')

smoothing_options = dict(
    sigma_f=SIGMA_F,
    sigma_g=SIGMA_G,
    window_size=WINDOW_SIZE,
    num_processes=NUM_PROCESSES,
)
normals_smoothed = integrator_corrupted.smooth_normals(**smoothing_options)

# persist and preview the smoothed normals
write_normals(normals_smoothed, f"{RESULTS_PATH}normals-smoothed.png")

plt.imshow(normals_smoothed)
plt.show()
plt.close()

In [None]:
# per-pixel angular difference between the clean and smoothed reconstructions.
# The directory is (re)created here so this cell does not depend on an earlier cell.
diff_dir = os.path.join(RESULTS_PATH, 'angular-differences')
os.makedirs(diff_dir, exist_ok=True)

ae = angular_error(normals, normals_smoothed, solver.mask)

plt.set_cmap('viridis')
im = plt.imshow(ae)
# raw strings: '\c' is an invalid escape sequence in a normal string literal
plt.title(rf'Mean Angular Difference: {ae.mean():.4f}$^\circ$')
plt.axis('off')
# colour bar built from the imshow image instead of a redundant pcolormesh
cb = plt.colorbar(im, fraction=0.046, pad=0.04)
cb.ax.tick_params(labelsize=12)
cb.set_label(r'Difference ($^\circ$)', rotation=270, labelpad=18, size=12)

plt.savefig(
    os.path.join(diff_dir, 'baseline-smoothed.png'),
    dpi=plt.gcf().dpi,
    bbox_inches='tight'
)

plt.show()
plt.close()

In [None]:
# integrate the smoothed normals, export the surface as OBJ and display it
print('Integrating surface (smoothed)...')

surface_smoothed_path = f'{RESULTS_PATH}surface-smoothed.obj'
integrator_smoothed = PSIntegrator(normals_smoothed, IMAGE_SHAPE, mask)
integrator_smoothed.to_obj(surface_smoothed_path)

integrator_smoothed.display()