## Build HTML file to visualize results

In [32]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
import pandas as pd
from airium import Airium
import re

In [2]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, ensure_dir
from utils.df_utils import load_and_preprocess_csv, get_sorted_idxs
from utils.html_utils import save_visualizations_separately, build_html
from utils.visualizations import bar_graph
from parse_config import ConfigParser
from data_loader import data_loaders
import model.model as module_arch

In [20]:
# Variables
results_timestamp = '0125_114341'
target_class = 'cat'
n_select = '50'
paths_timestamp = '0125_112421'
data_type = 'softmax'
sort_columns = ['Post Accuracy']

In [39]:
# ID Regex
id_regex = '/+[a-z0-9_]*\-[a-z0-9_]*\-[a-z0-9_]*/.*/'
def get_image_id(path):
    return re.search(id_regex, path).group()[1:-1]

# Constant paths
class_list_path = os.path.join('metadata', 'cinic-10', 'class_names.txt')
# Results paths
results_dir = os.path.join('saved', 'edit', 'trials', 'CINIC10_ImageNet-VGG_16', results_timestamp)
csv_path = os.path.join(results_dir, 'results_table.csv')
trial_paths_path = os.path.join(results_dir, 'trial_paths.txt')

paths_dir = os.path.join('paths', 'edits', 'semantics', 
                         '{}_{}'.format(target_class, n_select), 
                         paths_timestamp)
value_image_paths_path = os.path.join(paths_dir, 'value_images_{}.txt'.format(data_type))

# HTML file directories
html_save_dir = os.path.join('html', '{}_{}'.format(target_class, n_select))
html_assets_dir = os.path.join(html_save_dir, 'assets')
ensure_dir(html_assets_dir)

In [22]:
# Load class list
class_list = read_lists(class_list_path)
# Load CSV and paths
df = load_and_preprocess_csv(
    csv_path,
    drop_duplicates=['ID'])

value_image_paths = read_lists(value_image_paths_path)
trial_paths = read_lists(trial_paths_path)

## Sanity Checks

In [23]:
# Sanity check same number of rows
n_rows = len(df)
assert len(value_image_paths) == n_rows, "{} rows in paths; {} rows in data frame".format(len(value_image_paths), n_rows)
assert len(trial_paths) == n_rows, "{} rows in paths; {} rows in data frame".format(len(trial_paths), n_rows)

# Sanity check that each row corresponds to one another
for image_id, value_image_path, trial_path in zip(df['ID'], value_image_paths, trial_paths):
    image_id = image_id.split('/')
    for id_part in image_id:
        assert id_part in value_image_path
        assert id_part in trial_path

# Check columns in sort_columns are in dataframe
for column in sort_columns:
    assert column in df.columns

In [37]:
# Get sorted idxs based on sort columns
sorted_df, sorted_idxs = get_sorted_idxs(
    df=df,
    columns=sort_columns,
    increasing=False)

# Sort image paths and trial paths accordingly
sorted_value_image_paths = [value_image_paths[idx] for idx in sorted_idxs]
sorted_trial_paths = [trial_paths[idx] for idx in sorted_idxs]
sorted_IDs = [re.search(id_regex, path).group()[1:-1] for path in sorted_value_image_paths]

# Sanity check
for id_, trial_path in zip(sorted_IDs, sorted_trial_paths):
    assert id_ in trial_path


In [25]:
# Save metrics for each data point as string
sorted_df = sorted_df.round(3)
metrics = ['Recall', 'Precision', 'F1']
groups = ['Mean', 'Target', 'Orig Pred']
metric_strings = []
for idx in range(n_rows):
    metric_string = ''
    key = 'Post Accuracy'
    metric_string += "Accuracy: {}".format(sorted_df[key].iloc[idx])

    for group in groups:
        metric_string += "\n{}".format(group)
        
        for metric in metrics:
            key = 'Post {} {}'.format(group, metric)
            metric_string += "\n\t{:<15} {}".format(metric, sorted_df[key].iloc[idx])

    metric_strings.append(metric_string)
assert len(metric_strings) == n_rows

### Get paths to all things we want to visualize: 

1) cumulative masking graphic
2) cumulative masking graph
3) class distribution pre/post edit

In [26]:
# Copy files from segmentation process
file_names = [
    '{}_cumulative_modifying.png'.format(data_type),
    'target_{}_v_n_images.png'.format(data_type)]
input_dirs = [os.path.dirname(path) for path in sorted_value_image_paths]
html_asset_save_dirs, save_id_paths = save_visualizations_separately(
    input_dirs=input_dirs,
    file_names=file_names,
    output_dir=html_assets_dir,
    overwrite=False)



In [27]:
# Create class distribution bar graphs per row
columns = ['Pre Class Dist', 'Post Class Dist']
bar_graph_save_paths = []
for idx, (trial_dir, html_asset_save_dir) in enumerate(zip(sorted_trial_paths, html_asset_save_dirs)):
    image_id = os.path.join(os.path.basename(os.path.dirname(html_asset_save_dir)),
                            os.path.basename(html_asset_save_dir))
    assert image_id in trial_dir
    
    data = []
    for column in columns:
        data.append(sorted_df.iloc[idx][column])
    data = np.stack(data, axis=0)
    
    bar_graph_save_path = os.path.join(html_asset_save_dir, 'class_distribution_bar_graph.png')
    bar_graph_save_paths.append(bar_graph_save_path)
    
    bar_graph(
        data=data,
        labels=class_list,
        groups=columns,
        title='Class Distribution for {}'.format(image_id),
        xlabel_rotation=30,
        ylabel='Counts',
        save_path=bar_graph_save_path,
        show_plot=False)

In [51]:
# Combine the paths 
image_ids, save_visualization_paths = save_id_paths

marker_idx = 0
id_idx = 0
asset_paths = []
last_id = image_ids[0]
while marker_idx < len(save_visualization_paths) and id_idx < len(bar_graph_save_paths):
    cur_id = get_image_id(save_visualization_paths[marker_idx])
    if cur_id == last_id:
        asset_paths.append(save_visualization_paths[marker_idx])
        marker_idx += 1
    else:
        asset_paths.append(bar_graph_save_paths[id_idx])
        id_idx += 1
        last_id = cur_id


## Create HTML file

In [62]:
def build_html(file_paths,
               asset_ids,
               html_save_path,
               texts=None,
               id_regex='/+[a-z0-9_]*\-[a-z0-9_]*\-[a-z0-9_]*/.*/'):
    '''
    Given paths to assets to embed, build HTML page

    Arg(s):
        file_paths : list[str]
            paths to each asset (sorted to group assets together)
        html_save_path : str
            where the html file will be saved to
        id_regex : str
            Regular expression to extract ID

    Returns:
        html_string : str
            html as a string
    '''

    # Create Airium object
    air = Airium()

    air('<!DOCTYPE html>')
    with air.html(lang="pl"):
        # Set HTML header
        with air.head():
            air.meta(charset="utf-8")
            air.title(_t="Cumulative Image Visualization")

        # Set HTML body
        text_idx = 0
        with air.body():
            prev_id = ""
            for path in file_paths:
                # asset_id = os.path.join(
                #     os.path.basename(os.path.dirname(path)),
                #     os.path.basename(path))
                asset_id = re.search(id_regex, path).group()
                # Remove the start and trailing backslashes
                asset_id = asset_id[1:-1]
                # Create new header
                if asset_id != prev_id:
                    with air.h3():
                        air(asset_id)
                    if texts is not None:
                        with air.p():
                            air(texts[text_idx])
                        text_idx += 1
                    prev_id = asset_id

                # Embed asset as image
                relative_asset_path = os.path.relpath(path, os.path.dirname(html_save_path))
                air.img(src=relative_asset_path, height=350)
                air.p("\n\n")
    # Turn Airium object to html string
    html_string = str(air)
    return html_string

In [63]:
html_save_path = os.path.join(html_save_dir, 'visualization.html')
html_string = build_html(
    asset_paths,
    asset_ids=image_ids,
    texts=metric_strings,
    html_save_path=html_save_path)           

print(html_string)
with open(html_save_path, 'wb') as f:
    f.write(bytes(html_string, encoding='utf-8'))
print("Saved HTML file to {}".format(html_save_path))

<!DOCTYPE html>
<html lang="pl">
  <head>
    <meta charset="utf-8" />
    <title>Cumulative Image Visualization</title>
  </head>
  <body>
    <h3>
      cat-train-n02121808_1820/felzenszwalb_masked
    </h3>
    <p>
      Accuracy: 0.69
Mean
	Recall          0.69
	Precision       0.695
	F1              0.689
Target
	Recall          0.593
	Precision       0.492
	F1              0.538
Orig Pred
	Recall          0.83
	Precision       0.785
	F1              0.807
    </p>
    <img src="assets/cat-train-n02121808_1820/felzenszwalb_masked/softmax_cumulative_modifying.png" height="350" />
    <p 

></p>
    <img src="assets/cat-train-n02121808_1820/felzenszwalb_masked/target_softmax_v_n_images.png" height="350" />
    <p 

></p>
    <img src="assets/cat-train-n02121808_1820/felzenszwalb_masked/class_distribution_bar_graph.png" height="350" />
    <p 

></p>
    <h3>
      cat-train-n02125494_5963/felzenszwalb_gaussian
    </h3>
    <p>
      Accuracy: 0.688
Mean
	Recall          0.688
	Prec

In [30]:
import re

string = 'html/cat_50/assets/cat-train-n01322898_3222/felzenszwalb_gaussian/class_distribution_bar_graph.png'
found = re.search('/+[a-z0-9_]*\-[a-z0-9_]*\-[a-z0-9_]*/.*/', string).group()
print(found)


/cat-train-n01322898_3222/felzenszwalb_gaussian/
