# RICO analysis
This notebook qualitatively analyzes models trained on the RICO dataset.

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before
# each cell runs, so edits to the project source are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

##### Editable parameters

In [None]:
# Checkpoint location passed to load_model.
ckpt_dir = "../results/rico/ours-exp-ft/checkpoints"
# Dataset name and root directory passed to DataSpec.
dataset_name = "rico"
db_root = "../data/rico"
# Batch size passed to DataSpec.
batch_size = 4

##### Initialization

In [None]:
import copy
import itertools
import logging
import random
import sys

import numpy as np
import tensorflow as tf
from IPython.display import display, HTML
%matplotlib inline

sys.path.append("../src/mfp")

from mfp.models.mfp import MFP, merge_inputs_and_prediction
from mfp.models.architecture.mask import get_seq_mask
from mfp.models.masking import get_initial_masks
from mfp.data import DataSpec
from mfp.helpers import svg_rico as svg
from util import grouper, load_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fix seeds for debugging. Two generators are involved: TensorFlow's own RNG
# and Python's `random` module, which picks the element to hide in the demo.
# Seeding only one of them leaves the visualization non-reproducible.
tf.random.set_seed(0)
random.seed(0)

##### Load datasets

In [None]:
# Build the "test" split without shuffling and keep only its first element as
# the example visualized below.
# NOTE(review): presumably one batch of `batch_size` samples — confirm against DataSpec.
dataspec = DataSpec(dataset_name, db_root, batch_size)
test_dataset = dataspec.make_dataset("test", shuffle=False)
iterator = iter(test_dataset.take(1))
example = next(iterator)

##### Load pre-trained models

In [None]:
# Column specification produced by the dataspec; passed to load_model.
input_columns = dataspec.make_input_columns()
# Models to visualize, keyed by name. Only the dict values are used for rendering.
models = {"main": load_model(ckpt_dir, input_columns=input_columns)}

##### Define some helpers for the ELEM-filling task

In [None]:
# SVG builder used to render samples; it shares the dataspec's preprocessor.
# NOTE(review): max_width/max_height presumably bound the rendered size and
# key="type" presumably selects the attribute used to style elements —
# confirm in mfp.helpers.svg_rico.
builder0 = svg.SVGBuilder(
    max_width=128,
    max_height=192,
    key="type",
    preprocessor=dataspec.preprocessor,
)

def visualize_reconstruction(models, example, dataspec, input_builders, output_builders):
    """Demo of ELEM prediction: hide one random element per sample and render the result.

    For every sample in `example`, one valid element is picked at random. The
    element is dropped from a copy of the batch (rendered as the "input") and
    masked in the full batch before that batch is passed to each model
    (rendered as the "prediction"). The untouched batch is rendered last as
    the ground truth.

    Args:
        models: Iterable of callables invoked as
            `model(example, training=False, demo_args={"masks": ...})`.
        example: Dict of batched tensors; must contain "length" and "left".
        dataspec: Object providing `make_input_columns()` and `unbatch()`.
        input_builders: Callables applied to each unbatched sample of the
            input and of the ground truth.
        output_builders: Callables applied to each unbatched sample of a
            prediction.

    Returns:
        One list per sample holding that sample's renderings in the order
        input, prediction(s), ground truth, chunked by `grouper` into groups
        of `len(input_builders)`.
    """
    # Derive the column spec from the `dataspec` argument rather than reading a
    # notebook-level global, so the function depends only on its inputs.
    input_columns = dataspec.make_input_columns()

    seq_mask = get_seq_mask(example["length"])
    mfp_masks = get_initial_masks(input_columns, seq_mask)
    example_copy = copy.deepcopy(example)

    # Count the valid elements of each sample and pick one of them to hide.
    n_elem = tf.reduce_sum(tf.cast(seq_mask, tf.int32), axis=1).numpy()
    target_indices = [random.randint(0, n - 1) for n in n_elem]

    # Input view: remove the chosen position from every tensor that has more
    # than one entry along axis 1, then shorten the recorded length to match.
    B, S = example_copy["left"].shape[:2]
    kept_positions = [
        [j for j in range(S) if j != target_indices[i]] for i in range(B)
    ]
    indices = tf.convert_to_tensor(np.array(kept_positions))
    for key in example_copy.keys():
        if example_copy[key].shape[1] > 1:
            example_copy[key] = tf.gather(example_copy[key], indices, batch_dims=1)
    example_copy["length"] -= 1

    svgs = []
    for builder in input_builders:
        svgs.append(list(map(builder, dataspec.unbatch(example_copy))))

    # Model view: keep the full batch but mask the chosen element in every
    # sequence column so the model has to predict it.
    for key in mfp_masks.keys():
        if not input_columns[key]["is_sequence"]:
            continue
        mask = mfp_masks[key].numpy()
        for i, target in enumerate(target_indices):
            mask[i, target] = True  # hide a single element for each sample
        mfp_masks[key] = tf.convert_to_tensor(mask)

    for model in models:
        pred = model(example, training=False, demo_args={"masks": mfp_masks})
        pred = merge_inputs_and_prediction(example, input_columns, mfp_masks, pred)

        for builder in output_builders:
            svgs.append(list(map(builder, dataspec.unbatch(pred))))

    # Ground truth: the unmodified batch.
    for builder in input_builders:
        svgs.append(list(map(builder, dataspec.unbatch(example))))

    # Transpose from per-rendering lists to per-sample rows.
    return [list(grouper(row, len(input_builders))) for row in zip(*svgs)]


##### Visualization of results
From left to right: input (one element missing), prediction, ground truth

In [None]:
# Render every sample of the example batch: print its index, then show all of
# its SVG renderings side by side in one HTML block.
svgs = visualize_reconstruction(
    models=models.values(),
    example=example,
    dataspec=dataspec,
    input_builders=[builder0],
    output_builders=[builder0],
)
for sample_idx, groups in enumerate(svgs):
    print(sample_idx)
    markup = " ".join(itertools.chain.from_iterable(groups))
    display(HTML("<div>%s</div>" % markup))