# Data Visualization Tool

We provide this tool for visualizing samples in our polar image / surface normal dataset (see more on the dataset [here](https://github.com/alexrgilbert/deepsfp/blob/master/data/README.md)). It can be used to examine data before and after transformations, generated crop indices (see more on data preparation [here](https://github.com/alexrgilbert/deepsfp/blob/master/README.md)), and test set reconstructions and test results.

You can use the same configuration override/command line options process for specifying the dataset to use (see *'Set Config'* cell below). 

Outputs are [optionally] saved to `$SfP_ROOT/<$cfg.output_dir>/<$cfg.(train|test).dataloader.dataset.name>/<$EXPERIMENT_ID>`, where experiment IDs are of the following forms:

- Data Visualization: `<$CONFIG_FILENAME>_<$TIME_STRING>_(mat|pth)-data-viz`
- Crop Visualization: `<$CONFIG_FILENAME>_<$TIME_STRING>_<crop_h>-<crop_w>-<thresh>-crop-idcs-viz`
- Test Reconstruction Visualization + Test Results: `<$CONFIG_FILENAME>_<$TIME_STRING>_pred-viz`

Alongside the saved visualizations (if enabled), the following metadata files are written (for reproducing visualization runs):

- Frozen Config (`<$EXPERIMENT_ID>.yaml`): YAML configuration override file with the exact configuration used for this run.
- Frozen Train/Test Split (`(train|test)_set.csv`): Copy of the training or test set data list (specified by `<$data_cfg.root>/<$data_cfg.name>/<$data_cfg.data_list>` where `data_cfg = <$cfg.(train|test).dataloader.dataset>`).
- Log file (`<$EXPERIMENT_ID>.log`): Includes a printout of the current commit hash for this repo, the experiment config, and logging messages from the process.

## Imports

In [1]:
import os
import torch
import scipy.io as sio
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
from tqdm import tqdm
from yacs.config import CfgNode as CN
from __init__ import config as default_config, update_config, setup_experiment

## Utils

In [2]:
def load_mat(filename):
    """Load a ``.mat`` sample file into a dict of channel-first torch tensors.

    Every user variable in the file (keys not starting with ``'__'``) is
    promoted to at least 3-D with ``np.atleast_3d`` and its last axis is moved
    to the front, so an (H, W) array becomes (1, H, W) and an (H, W, C) array
    becomes (C, H, W). The aliases used by the plotting helpers are then added:
    ``'est'`` -> ``'normals_prior'``, ``'image'`` -> ``'images'`` and
    ``'label'`` -> ``'normals_gt'``. Each alias is the same tensor object as
    its source entry.

    Raises:
        KeyError: if ``normals_prior``, ``images`` or ``normals_gt`` is absent.
    """
    contents = sio.loadmat(filename)
    sample = {}
    for key, value in contents.items():
        if key.startswith('__'):
            # loadmat's own metadata entries (e.g. '__header__'), not data
            continue
        sample[key] = torch.from_numpy(np.atleast_3d(value).transpose((2, 0, 1)))
    for alias, source_key in (('est', 'normals_prior'),
                              ('image', 'images'),
                              ('label', 'normals_gt')):
        sample[alias] = sample[source_key]
    return sample


def load_pth(filename, map_location='cpu'):
    """Load an object saved with ``torch.save`` (e.g. a transformed sample).

    Args:
        filename: path to the ``.pth`` file.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'``, so
            files written from GPU tensors still open on CPU-only machines;
            matplotlib can only plot CPU tensors anyway. Pass ``None`` to keep
            each tensor on the device recorded in the file.

    Returns:
        The object stored in the file.
    """
    return torch.load(filename, map_location=map_location)


def normalize_normals(normals):
    """Min-max rescale a tensor to roughly [0, 1] for display with ``imshow``.

    The minimum and maximum are taken over the whole tensor (all channels and
    pixels together), so relative channel magnitudes are preserved. The 1e-8
    term avoids division by zero; a constant input maps to all zeros.
    """
    shifted = normals - normals.min()
    return shifted / (shifted.max() + 1e-8)

## Set Config

In [3]:
# YAML configuration override file (relative path); the original default is None.
CONFIG_PATH = '../experiments/deepsfp.yaml' # Default: None
# Extra overrides as a flat [KEY, VALUE, KEY, VALUE, ...] list, passed as cli_options.
OPTS = []  # Default: []
OPTS += ['enable_tblogging', 'False']  # Disable TB-logging for visualization runs
# Apply the override file and OPTS to a clone, so default_config itself is not passed in.
config = update_config(default_config.clone(), config_filepath=CONFIG_PATH, cli_options=OPTS, )

## Visualize Dataset

In [4]:
def _viz_polar_images(images, axes=None):
    # Plot polar images
    if axes is None:
        _, ax_arr = plt.subplots(2,2)
        axes = [ax_arr[i//2][i%2] for i in range(4)]
    angles = [0,45,90,135]
    for i in range(len(images)):
        img,ax = images[i], axes[i]
        ax.imshow(img)
        ax.set_title(f'Polar Image @ {angles[i]}°')


def _viz_mask(mask, ax=None):
    # Plot Binary Mask
    if ax is None:
        _, ax = plt.subplots(1,1)
    ax.imshow(mask[0])
    ax.set_title(f'Binary Foreground Mask')


def _viz_normals_gt(normals, ax=None):
    """Render ground-truth normals as an RGB image on *ax* (new figure if None).

    *normals* is channel-first; it is min-max rescaled with normalize_normals
    and permuted to channel-last for imshow.
    """
    target = plt.subplots(1, 1)[1] if ax is None else ax
    rgb = normalize_normals(normals).permute(1, 2, 0)
    target.imshow(rgb)
    target.set_title('Normals GT')


def _viz_normals_prior(normals, axes=None):
    """Render the stacked normal-prior solutions, one 3-channel group per axis.

    *normals* is channel-first, with channels grouped in threes. Channels 0-2
    are titled as the diffuse solution; 3-5 and 6-8 as the two specular
    solutions. Trailing channels that do not complete a group are skipped.
    A new 1x3 figure is created when *axes* is None.
    """
    if axes is None:
        axes = plt.subplots(1, 3)[1]
    names = ['Diffuse', 'Specular 1', 'Specular 2']
    for group in range(len(normals) // 3):
        start = 3 * group
        channels = normals[start:start + 3]
        panel = axes[group]
        panel.imshow(normalize_normals(channels).permute(1, 2, 0))
        panel.set_title(f'Normals Prior:\n{names[group]} Solution')


def visualize_data(sample, sample_name=None, figsize=(12,12), disp=False, dest_dir=None, logger=None):
    """Plot one dataset sample on a 3x3 grid, then optionally show and/or save it.

    Layout:
    - The four polar images fill the top-left 2x2 cells.
    - The binary mask and GT normals fill the right column of the first two rows.
    - The three normal priors span the bottom row.

    Args:
        sample: dict with 'image', 'mask', 'label' and 'est' entries
            (e.g. the output of load_mat).
        sample_name: figure title and base name of the saved PNG ('/' is
            replaced by '-'). When None, no title is drawn and the PNG is
            saved as 'sample.png'.
        figsize: matplotlib figure size in inches.
        disp: if True, call plt.show().
        dest_dir: if not None, save the figure as '<dest_dir>/<name>.png'.
        logger: optional logger used to report the save path.
    """
    # Constrained layout positions both the grid and the suptitle.
    # fig.tight_layout() is deliberately not called: on recent matplotlib it
    # replaces the constrained-layout engine. The suptitle is left at its
    # default position so the layout engine keeps it inside the saved image.
    fig = plt.figure(constrained_layout=True, figsize=figsize)
    if sample_name is not None:
        fig.suptitle(f'{sample_name}')
    gs = GridSpec(3, 3, figure=fig)

    # Polar images (top-left 2x2 block)
    polar_image_axes = [fig.add_subplot(gs[i // 2, i % 2]) for i in range(4)]
    _viz_polar_images(sample['image'], axes=polar_image_axes)

    # Binary mask
    binary_mask_ax = fig.add_subplot(gs[0, 2])
    _viz_mask(sample['mask'], ax=binary_mask_ax)

    # Normals GT
    normals_gt_ax = fig.add_subplot(gs[1, 2])
    _viz_normals_gt(sample['label'], ax=normals_gt_ax)

    # Normals priors (bottom row)
    normals_prior_axes = [fig.add_subplot(gs[2, i]) for i in range(3)]
    _viz_normals_prior(sample['est'], axes=normals_prior_axes)

    # Display / saving / cleanup
    if disp:
        plt.show()
    if dest_dir is not None:
        # Without a name there is nothing to derive the file name from
        stem = 'sample' if sample_name is None else sample_name.replace('/', '-')
        dest = os.path.join(dest_dir, f'{stem}.png')
        if logger is not None:
            logger.info(f'\tSaving data viz to {os.path.realpath(dest)}')
        fig.savefig(dest, facecolor='w')
    plt.close(fig)

In [None]:
FIGSIZE=(12,12)
DISP=True
SAVE=True  # Save PNGs to the experiment directory (the log is written regardless)

PHASE = 'test'  # 'train' or 'test'
PTH_FILES = False  # Whether to visualize transformed data (i.e. converted from .mat to .pth,
                   # see https://github.com/alexrgilbert/deepsfp/blob/master/README.md)

load = load_pth if PTH_FILES else load_mat
ext = '.pth' if PTH_FILES else '.mat'

datacfg = config.get(PHASE).dataloader.dataset
dataroot, dataset, datafile = datacfg.root, datacfg.name, datacfg.data_list
datadir = os.path.join('..', dataroot, dataset)
obj_dir = os.path.join(datadir, 'objects')

logger, exp_dir, _, _, _ = setup_experiment(config, CONFIG_PATH, root_dir='..',
        quiet=True, phase=PHASE, name=f'{PHASE}-set-{"pth" if PTH_FILES else "mat"}-data-viz',
        meta=[{'config_override_filepath': CONFIG_PATH, 'cli_config_overrides': OPTS}, config])
# Only give visualize_data a destination when saving was requested
dest_dir = exp_dir if SAVE else None
print(f'See {exp_dir} for logs{" and plots" if SAVE else ""}.')  # Logging is silent in Jupyter, so point at the files

logger.info(f'Visualizing the "{dataset}" {PHASE} dataset located at {os.path.realpath(datadir)}.')
# DataFrame.squeeze('columns') replaces read_csv(squeeze=True), which pandas 2.0 removed
datalist = pd.read_csv(os.path.join(datadir, datafile), header=None).squeeze('columns')
# Swap each listed file's extension to match the format being visualized
datalist = datalist.apply(lambda f: f'{os.path.splitext(f)[0]}{ext}')
pbar = tqdm(datalist)
for i, filename in enumerate(pbar):
    pbar.set_description(filename)
    logger.info(f'[{i+1}/{len(datalist)}] Visualizing {filename}...')
    filepath = os.path.join(obj_dir, filename)
    if not os.path.exists(filepath):
        logger.warning(f'{filepath} does not exist! Skipping!')
        continue
    sample = load(filepath)
    visualize_data(sample, sample_name=filename, figsize=FIGSIZE, disp=DISP, dest_dir=dest_dir, logger=logger)

## Visualize Random Crops

Read more on random crop indices [here](https://github.com/alexrgilbert/deepsfp/blob/master/README.md). The indices must be generated before this block can be run.

In [5]:
def viz_crop_masks(sample, sample_name, crop_idcs, crop_h_w, n=5, figsize=(12,12), disp=False, dest_dir=None, logger=None):
    """Plot an n x n grid of randomly selected crops of a sample's GT normals.

    Args:
        sample: dict whose 'label' entry holds channel-first GT normals.
        sample_name: used in the figure title and, with '/' replaced by '-',
            as the PNG file name.
        crop_idcs: 2-D tensor/array whose rows are (top, left) crop corners.
        crop_h_w: (crop height, crop width).
        n: grid side; n**2 rows of crop_idcs are drawn uniformly, with replacement.
        figsize: matplotlib figure size in inches.
        disp: if True, call plt.show().
        dest_dir: if not None, save the figure as '<dest_dir>/<name>.png'.
        logger: optional logger for the save path and the empty-indices warning.

    Returns:
        None. When crop_idcs is empty nothing is plotted (a warning is logged
        if a logger is given), because there is nothing to sample from.
    """
    if len(crop_idcs) == 0:
        if logger is not None:
            logger.warning(f'\tNo crop indices for {sample_name}! Skipping crop mask viz.')
        return
    crop_h, crop_w = crop_h_w
    sel = np.random.randint(0, len(crop_idcs), n**2)
    top_left = crop_idcs[sel, :]
    gt = normalize_normals(sample['label'])
    # squeeze=False keeps `axes` 2-D even when n == 1. Constrained layout
    # (instead of an early tight_layout call) keeps the per-crop titles from
    # overlapping neighbouring panels.
    fig, axes = plt.subplots(n, n, figsize=figsize, squeeze=False, constrained_layout=True)
    fig.suptitle(f'[{sample_name}] Random {crop_h}x{crop_w} Crops of GT Normals')
    for i in range(n**2):
        ax = axes[i // n][i % n]
        # Plain ints work for both tensor and ndarray index rows
        top, left = (int(v) for v in top_left[i])
        crop = gt[:, top:top + crop_h, left:left + crop_w]
        ax.imshow(crop.permute(1, 2, 0))
        ax.set_title(f'({top},{left})->({top + crop_h},{left + crop_w})')
    if disp:
        plt.show()
    if dest_dir is not None:
        dest = os.path.join(dest_dir, f"{sample_name.replace('/', '-')}.png")
        if logger is not None:
            logger.info(f'\tSaving crop mask viz to {os.path.realpath(dest)}')
        fig.savefig(dest, facecolor='w')
    plt.close(fig)

In [None]:
FIGSIZE=(12,12)
DISP=True
SAVE=True  # Save PNGs to the experiment directory (the log is written regardless)

PHASE = 'test' # 'train' or 'test'
N = 5  # Sqrt of total number of random crops to sample and visualize


datacfg = config.get(PHASE).dataloader.dataset
dataroot, dataset, datafile = datacfg.root, datacfg.name, datacfg.data_list
datadir = os.path.join('..', dataroot, dataset)
objdir = os.path.join(datadir, 'objects')
assert 'RandomCrop' in datacfg.transforms, f'No RandomCrop configured for {PHASE} phase!'
cropcfg = datacfg.transforms.RandomCrop
crop_h,crop_w = cropcfg.crop_size
thresh = cropcfg.foreground_ratio_threshold
idcsdir = os.path.join(datadir, 'crop_indices',f'{crop_h}_{crop_w}_{thresh}')

logger, exp_dir, _, _, _ = setup_experiment(config, CONFIG_PATH, root_dir='..', quiet=True, phase=PHASE,
                         name=f'{PHASE}-set-{crop_h}-{crop_w}-{str(thresh).replace(".","_")}-crop-idcs-viz',
                        meta=[{'config_override_filepath': CONFIG_PATH, 'cli_config_overrides': OPTS}, config])
# Only give viz_crop_masks a destination when saving was requested
dest_dir = exp_dir if SAVE else None

print(f'See {exp_dir} for logs{" and plots" if SAVE else ""}.')
logger.info(f'Visualizing crop indices for crops of dimensions {crop_h}x{crop_w} with foreground '
            f'ratio > {thresh} on the "{dataset}" {PHASE} dataset located at '
            f'{os.path.realpath(datadir)}')

# DataFrame.squeeze('columns') replaces read_csv(squeeze=True), which pandas 2.0 removed
datalist = pd.read_csv(os.path.join(datadir,datafile),header=None).squeeze('columns')
# Objects and their crop indices are both stored as .pth files
datalist = datalist.apply(lambda f: f'{os.path.splitext(f)[0]}.pth')
pbar = tqdm(datalist)
for i,filename in enumerate(pbar):
    pbar.set_description(filename)
    logger.info(f'[{i+1}/{len(datalist)}] Visualizing crop masks for {filename}...')
    objpath = os.path.join(objdir, filename)
    idcspath = os.path.join(idcsdir,filename)
    if not os.path.exists(objpath):
        logger.warning(f'Object @ {objpath} does not exist! Skipping!')
        continue
    if not os.path.exists(idcspath):
        logger.warning(f'Crop indices @ {idcspath} do not exist! Skipping!')
        continue
    # Load onto CPU: plotting needs CPU tensors even if the files were saved from GPU
    sample = torch.load(objpath, map_location='cpu')
    idcs = torch.load(idcspath, map_location='cpu')
    viz_crop_masks(sample, filename, idcs, (crop_h,crop_w), n=N,
                    figsize=FIGSIZE, disp=DISP, dest_dir=dest_dir, logger=logger)

## Test Reconstructions

Visualizing test set reconstructions requires providing the path to a test experiment directory (i.e. `$SfP_ROOT/<$cfg.output_dir>/<$cfg.test.dataloader.dataset.name>/<$CONFIG_FILENAME>_<$TIME_STRING>_test`). The config used for the corresponding test run will be used automatically (the *Set Config* cell above will be ignored) to plot the test set reconstructions (from `reconstructions.pth`) alongside their ground-truth normals.

TODO: Currently expects reconstructions to represent surface normals. For alternative formats, `_ae_map` must be replaced.

#### Path to Test Experiment Directory

In [3]:
# Test experiment directory containing reconstructions.pth, results.csv and
# test_set.csv; fill this in before running the cells below.
EXPERIMENT_DIRECTORY = ""

In [32]:
# Single-letter matplotlib colour codes used to colour bars by group.
COLORS = list('rgbcmyk')

def _ae_map(sample, label):
    target, mask = label['label'], label['mask']
    dot_product = (sample * target).sum(0)
    output_norm = torch.norm(sample, dim=0)
    target_norm = torch.norm(target, dim=0)
    dot_product = (dot_product / (output_norm * target_norm + 1e-8)).clamp(-1, 1)

    error_map = torch.acos(dot_product) # [-pi, pi]
    angular_map = error_map * 180.0 / np.pi
    angular_map = angular_map * mask[0].float()
    return angular_map.squeeze()

def visualize_reconstruction(sample, label, object_name, figsize=(16,4), disp=False, dest_dir=None, logger=None):
    """Plot one test reconstruction beside its input, error map and ground truth.

    Panels, left to right:
    1. The first polar image, in grayscale.
    2. The reconstruction, masked to the foreground.
    3. The per-pixel angular error map on a fixed 0-180° scale, with a colour bar.
    4. The GT normals, masked to the foreground.

    Args:
        sample: dict with 'reconstruction' (channel-first normals) and 'error'
            (printed as the MAE in the title).
        label: GT dict with 'image', 'label' and 'mask' entries.
        object_name: title prefix and base name of the saved PNG.
        figsize: matplotlib figure size in inches.
        disp: if True, call plt.show().
        dest_dir: if not None, save the figure as '<dest_dir>/<object_name>.png'.
        logger: optional logger used to report the save path.
    """
    fig, (img_ax, sample_ax, err_ax, label_ax) = plt.subplots(1, 4, figsize=figsize)
    fig.suptitle(f'{object_name} | MAE = {sample["error"]:.3f}')
    fig.tight_layout()

    # Panel 1: input polar image
    img_ax.imshow(label['image'][0], cmap='gray')
    img_ax.set_title('Polar Image @ 0°')

    # Panel 2: reconstruction, zeroed outside the mask
    masked_pred = normalize_normals(sample['reconstruction']) * label['mask'].float()
    sample_ax.imshow(masked_pred.permute(1, 2, 0))
    sample_ax.set_title('Reconstruction')

    # Panel 3: angular error with a colour bar attached to the axis's right edge
    error_image = err_ax.imshow(_ae_map(sample['reconstruction'], label).abs(),
                                cmap='Reds', vmin=0, vmax=180)
    colorbar_ax = make_axes_locatable(err_ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(error_image, cax=colorbar_ax)
    err_ax.set_title('Angular Error Map (°)')

    # Panel 4: ground truth, masked the same way
    masked_gt = normalize_normals(label['label']) * label['mask'].float()
    label_ax.imshow(masked_gt.permute(1, 2, 0))
    label_ax.set_title('Ground Truth')

    if disp:
        plt.show()
    if dest_dir is not None:
        dest = os.path.join(dest_dir, f"{object_name}.png")
        if logger is not None:
            logger.info(f'\tSaving reconstruction viz to {os.path.realpath(dest)}')
        fig.savefig(dest, facecolor='w')
    plt.close()

def _annotate_bars(ax, bars, horizontal=False):
    """Write each bar's length (2 decimals) just beyond the end of the bar."""
    for bar in bars:
        if horizontal:
            value = bar.get_width()
            ax.annotate(f'{value:.2f}', xy=(value, bar.get_y() + bar.get_height() / 2),
                        xytext=(3, 0), textcoords="offset points", ha='left', va='center')
        else:
            value = bar.get_height()
            ax.annotate(f'{value:.2f}', xy=(bar.get_x() + bar.get_width() / 2, value),
                        xytext=(0, 3), textcoords="offset points", ha='center', va='bottom')


def _plot_mae_by_sample(ax, df, names, lightings, lighting_colors):
    """Horizontal bar per sample (in df row order), coloured by lighting, with a lighting legend."""
    bars = ax.barh(range(len(df)), df.error, tick_label=names,
                   color=df.lighting.map(lighting_colors).to_list())
    # One legend handle per lighting condition: the last bar drawn with it
    last_bar = {lighting: k for k, lighting in enumerate(df.lighting)}
    ax.legend([bars[last_bar[lighting]] for lighting in lightings], lightings)
    _annotate_bars(ax, bars, horizontal=True)
    ax.minorticks_on()
    ax.tick_params(axis='y', which='minor', left=False)
    ax.grid(axis='x', which='both')
    ax.set_title('MAE (°) by Sample')


def _plot_mae_by_lighting(ax, df, lighting_colors):
    """Vertical bar of mean error per lighting condition."""
    mbl = df.groupby('lighting')['error'].mean()
    bars = ax.bar(range(len(mbl)), mbl.values, tick_label=mbl.index,
                  color=[lighting_colors[lighting] for lighting in mbl.index])
    _annotate_bars(ax, bars)
    ax.minorticks_on()
    ax.tick_params(axis='x', which='minor', bottom=False)
    ax.grid(axis='y', which='both')
    ax.set_title('Average MAE (°) by Lighting Condition')


def _plot_mae_by_object(ax, df):
    """Vertical bar of mean error per (object, orientation) pair."""
    # groupby drops NaN keys, so fill missing orientations to keep those objects
    # (the per-sample labels use '' for them as well)
    mbo = df.fillna({'orientation': ''}).groupby(['object', 'orientation'])['error'].mean()
    labels = ['_'.join(map(str, key)) for key in mbo.index]
    bars = ax.bar(range(len(mbo)), mbo.values, tick_label=labels, color=COLORS[:len(mbo)])
    # Rotate only after bar() has installed the labels; calling set_xticklabels
    # before the ticks are fixed triggers a FixedFormatter warning
    ax.tick_params(axis='x', labelrotation=90)
    _annotate_bars(ax, bars)
    ax.minorticks_on()
    ax.tick_params(axis='x', which='minor', bottom=False)
    ax.grid(axis='y', which='both')
    ax.set_title('Average MAE (°) by Object')


def plot_results(results, exp_name, figsize=(12,16), disp=False, dest_dir=None, logger=None):
    """Plot per-sample, per-lighting and per-object mean angular error of a test run.

    Args:
        results: DataFrame with at least 'lighting', 'object', 'orientation'
            and 'error' columns ('orientation' may contain NaN).
        exp_name: experiment id shown in the figure title.
        figsize: matplotlib figure size in inches.
        disp: if True, call plt.show().
        dest_dir: if not None, save the figure as '<dest_dir>/results_plots.png'.
        logger: optional logger used to report the save path.
    """
    # Constrained layout handles spacing and suptitle placement. tight_layout()
    # is not called because on recent matplotlib it replaces this engine.
    fig = plt.figure(constrained_layout=True, figsize=figsize)
    gs = GridSpec(4, 3, figure=fig)
    df = results.sort_values(['lighting', 'object', 'orientation'], axis=0, ascending=False)
    fig.suptitle(f'Experiment {exp_name}: Results\nOverall Avg MAE (°) = {df.error.mean():.3f}')

    # One colour per lighting condition, shared by both lighting-aware panels
    # (cycled when there are more conditions than COLORS entries)
    lightings = sorted(df.lighting.dropna().unique())
    lighting_colors = {lighting: COLORS[k % len(COLORS)] for k, lighting in enumerate(lightings)}
    # '<object>_<orientation>' sample labels; a missing orientation becomes ''
    names = df[['object', 'orientation']].fillna('').astype(str).agg('_'.join, axis=1)

    _plot_mae_by_sample(fig.add_subplot(gs[:, :1]), df, names, lightings, lighting_colors)
    _plot_mae_by_lighting(fig.add_subplot(gs[:2, 1:]), df, lighting_colors)
    _plot_mae_by_object(fig.add_subplot(gs[2:, 1:]), df)

    # Display & save
    if disp:
        plt.show()
    if dest_dir is not None:
        dest = os.path.join(dest_dir, "results_plots.png")
        if logger is not None:
            logger.info(f'\tSaving results plots to {os.path.realpath(dest)}')
        fig.savefig(dest, facecolor='w')
    plt.close(fig)

In [None]:
RECONSTRUCTIONS_FIGSIZE=(16,4)
RESULTS_FIGSIZE=(12,16)
DISP=True
SAVE=True
RECONSTRUCTIONS = True  # Whether to visualize reconstructions
RESULTS = True  # Whether to plot results

# realpath('') silently resolves to the current directory, so fail early instead
if not EXPERIMENT_DIRECTORY:
    raise ValueError('Set EXPERIMENT_DIRECTORY to a test experiment directory before running this cell.')
test_exp_dir = os.path.realpath(EXPERIMENT_DIRECTORY)
exp_id = os.path.basename(test_exp_dir)
config_path = os.path.join(test_exp_dir, f'{exp_id}.yaml')
reconstructions_path = os.path.join(test_exp_dir, 'reconstructions.pth')
results_path = os.path.join(test_exp_dir, 'results.csv')
testset_path = os.path.join(test_exp_dir, 'test_set.csv')

# Load config used for testing
with open(config_path, 'r') as config_file:
    test_config = CN.load_cfg(config_file)
# Disable TB-Logging for visualization run and force set testset list.
# Keep the returned config, as the Set Config cell does.
test_config = update_config(test_config, cli_options=['enable_tblogging', 'False',
                                        'test.dataloader.dataset.data_list', testset_path])

logger, exp_dir, _, _, _ = setup_experiment(test_config, config_path, root_dir='..',
                                quiet=True, phase='test', name='pred-viz', meta=[test_config])
dest_dir = exp_dir if SAVE else None
print(f'See {exp_dir} for logs{" and plots" if SAVE else ""}.')

if RECONSTRUCTIONS:
    logger.info(f'Visualizing test set reconstructions from experiment {exp_id}')
    # Load onto CPU: plotting needs CPU tensors even if the file was saved from GPU
    reconstructions = torch.load(reconstructions_path, map_location='cpu')
    datacfg = test_config.test.dataloader.dataset
    dataroot, dataset = datacfg.root, datacfg.name
    datadir = os.path.join('..', dataroot, dataset)
    objdir = os.path.join(datadir, 'objects')
    # DataFrame.squeeze('columns') replaces read_csv(squeeze=True), which pandas 2.0 removed
    matlist = pd.read_csv(testset_path, header=None).squeeze('columns')
    pthlist = matlist.apply(lambda f: f'{os.path.splitext(f)[0]}.pth')
    pbar = tqdm(pthlist)
    for i, filename in enumerate(pbar):
        pbar.set_description(filename)
        logger.info(f'[{i+1}/{len(pthlist)}] Visualizing {filename}...')
        filepath = os.path.join(objdir, filename)
        if not os.path.exists(filepath):
            logger.warning(f'No ground truth for {filename}! {filepath} does not exist! Skipping!')
            continue
        # Reconstructions are keyed by the extension-less path with '/' -> '_'
        objname = os.path.splitext(filename)[0].replace('/','_')
        if objname not in reconstructions:
            logger.warning(f'No reconstruction for {objname}! Skipping!')
            continue
        sample = reconstructions[objname]
        label = load_pth(filepath)
        visualize_reconstruction(sample, label, object_name=objname, figsize=RECONSTRUCTIONS_FIGSIZE,
                                 disp=DISP, dest_dir=dest_dir, logger=logger)
if RESULTS:
    logger.info(f'Visualizing test set results from experiment {exp_id}')
    results = pd.read_csv(results_path, index_col=0)
    plot_results(results, exp_id, figsize=RESULTS_FIGSIZE, disp=DISP, dest_dir=dest_dir, logger=logger)