# RICO analysis

This notebook qualitatively analyzes models trained on the RICO dataset.

In [None]:
import itertools
import json
import logging
import os
import sys

import numpy as np
import tensorflow as tf
from IPython.display import display, HTML

sys.path.append('../src/canvas-vae')

from canvasvae.models.vae import VAE
from canvasvae.data import DataSpec
from canvasvae.helpers.svg import SVGBuilder

# Configure the root logger at INFO level and create this notebook's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Load datasets

In [None]:
# Build the RICO data spec (batches of 2) and the train/test datasets from it.
# Both splits are requested with shuffle=True.
# NOTE(review): no seed is set in this notebook, so the examples drawn in later
# cells presumably differ between runs — confirm against `DataSpec.make_dataset`.
dataspec = DataSpec('rico', '../data/rico', batch_size=2)
train_dataset = dataspec.make_dataset('train', shuffle=True)
test_dataset = dataspec.make_dataset('test', shuffle=True)

Load the trained models. Edit the configuration tuples below so that each one matches the hyperparameters and the job ID of a training run whose checkpoint you want to load.

In [None]:
def load_model(
    decoder_type,
    block_type,
    num_blocks,
    latent_dim,
    job_id,
    best_or_final='final',
):
    """Rebuild a VAE with the given architecture and restore a job checkpoint.

    The input columns come from the notebook-level ``dataspec``. The model is
    compiled with the ``'adam'`` optimizer before the weights are loaded.

    Args:
        decoder_type: Forwarded to ``VAE`` as ``decoder_type``.
        block_type: Forwarded to ``VAE`` as ``block_type``.
        num_blocks: Forwarded to ``VAE`` as ``num_blocks``.
        latent_dim: Forwarded to ``VAE`` as ``latent_dim``.
        job_id: Name of the job directory under ``tmp/canvasvae/jobs/``.
        best_or_final: Checkpoint file stem inside the job's ``checkpoints``
            directory; defaults to ``'final'``.

    Returns:
        The ``VAE`` instance after ``load_weights`` has been called on it.
    """
    architecture = dict(
        decoder_type=decoder_type,
        latent_dim=latent_dim,
        num_blocks=num_blocks,
        block_type=block_type,
    )
    input_columns = dataspec.make_input_columns()
    vae = VAE(input_columns, **architecture)
    vae.compile(optimizer='adam')

    # The path is relative to the directory the notebook kernel runs in.
    checkpoint_path = f"tmp/canvasvae/jobs/{job_id}/checkpoints/{best_or_final}.ckpt"
    vae.load_weights(checkpoint_path)
    return vae


# Each tuple is unpacked positionally into `load_model`, i.e. it reads
#     (decoder_type, block_type, num_blocks, latent_dim, job_id)
# and the model is stored under the key "<decoder_type>_<block_type><num_blocks>".
# Uncomment the entries you need and set the config and job id to match your own jobs.
# While every entry stays commented out, `models` is an empty dict and the
# loops over `models.items()` in later cells do nothing.
models = {
    f"{args[0]}_{args[1]}{args[2]}": load_model(*args) for args in [
#         ('autoregressive', 'lstm', 1, 256, '20210818223504'),
#         ('autoregressive', 'deepsvg', 1, 256, '20210818223504'),
#         ('autoregressive', 'deepsvg', 4, 256, '20210818223504'),
#         ('oneshot', 'lstm', 1, 256, '20210818223504'),
#         ('oneshot', 'deepsvg', 1, 256, '20210818223504'),
#         ('oneshot', 'deepsvg', 4, 256, '20210818223504'),
    ]
}
models

Prepare SVG document builder for visualization

In [None]:
# Read the component legend and keep only each entry's 'hex' value,
# giving a {legend key: hex} mapping that is used as the builder's colormap.
with open('../src/canvas-vae/canvasvae/data/component_legend.json', 'r') as f:
    rico_colormap = {k: v['hex'] for k, v in json.load(f).items()}

# Single builder reused by every visualization cell below
# (keyed on 'component', canvas of width 144 and height 256).
builder = SVGBuilder(
    key='component',
    colormap=rico_colormap,
    canvas_width=144,
    canvas_height=256,
)

### Reconstruction and sampling

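The following visualizes reconstructions of test examples. Each input is shown first, followed by one deterministic reconstruction (sampling disabled) and then reconstructions drawn with stochastic sampling.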

In [None]:
def grouper(iterable, n, fillvalue=None):
    """Split ``iterable`` into consecutive tuples of length ``n``.

    The last tuple is padded with ``fillvalue`` when the input runs out early,
    e.g. ``grouper('ABCDEFG', 3, 'x')`` yields ``ABC``, ``DEF``, ``Gxx``.

    Returns:
        A lazy ``itertools.zip_longest`` iterator over the chunks.
    """
    # Every slot of the zip pulls from the same iterator, so each output tuple
    # consumes the next ``n`` items of the input.
    shared_source = iter(iterable)
    slots = itertools.repeat(shared_source, n)
    return itertools.zip_longest(*slots, fillvalue=fillvalue)


def visualize_reconstruction(model, example, dataspec, input_builders, output_builders, num_samples=3):
    """Render a batch of inputs together with several reconstructions of each.

    Args:
        model: Callable invoked as ``model(example, sampling=<bool>)``.
        example: A batch accepted by both ``model`` and ``dataspec.unbatch``.
        dataspec: Object whose ``unbatch`` splits a batch into single items.
        input_builders: Callables applied to every unbatched input item.
        output_builders: Callables applied to every unbatched prediction item.
        num_samples: Number of forward passes through ``model``. The first
            pass uses ``sampling=False``; every later pass uses
            ``sampling=True``. Defaults to 3, the value that used to be
            hard-coded.

    Returns:
        One list per batch item. Each list holds the rendered inputs followed
        by the rendered reconstructions, chunked into tuples of
        ``len(input_builders)`` elements. Chunks line up with one input or one
        forward pass only when ``input_builders`` and ``output_builders`` have
        the same length.
    """
    # One column per (builder, stage); each column has one entry per batch item.
    columns = [
        [render(item) for item in dataspec.unbatch(example)]
        for render in input_builders
    ]
    for pass_index in range(num_samples):
        # Pass 0 is the deterministic reconstruction; the rest are sampled.
        prediction = model(example, sampling=(pass_index > 0))
        for render in output_builders:
            columns.append([render(item) for item in dataspec.unbatch(prediction)])
    # Transpose so that each row collects everything rendered for one batch item.
    return [list(grouper(row, len(input_builders))) for row in zip(*columns)]


# Take a single batch from the test split and reconstruct it with every loaded model.
example = next(iter(test_dataset.take(1)))

# svgs[key] keeps the rendered rows for each model, keyed by model name.
svgs = {}
for key, model in models.items():
    print(key)
    svgs[key] = visualize_reconstruction(model, example, dataspec, [builder], [builder])
    for row in svgs[key]:
        # Flatten the row's tuples and join the builder outputs with spaces inside one <div>.
        display(HTML('<div>%s</div>' % ' '.join(itertools.chain.from_iterable(row))))

### Interpolation

Show interpolation between two samples.

In [None]:
def visualize_interpolation(model, example1, example2, dataspec, input_builders, output_builders, num_samples=8):
    """Render a linear interpolation in latent space between two batches.

    Both batches are encoded with ``sampling=False``. Then ``num_samples``
    evenly spaced blends ``z1 * (1 - u) + z2 * u`` with ``u`` running from 0 to
    1 (both ends included) are decoded and rendered.

    Args:
        model: Object exposing ``encoder(batch, sampling=False)`` and
            ``decoder(z)``.
        example1: Batch shown at the start of every row (``u = 0`` side).
        example2: Batch shown at the end of every row (``u = 1`` side).
        dataspec: Object whose ``unbatch`` splits a batch into single items.
        input_builders: Callables applied to every unbatched input item.
        output_builders: Callables applied to every unbatched decoded item.
        num_samples: Number of interpolation steps, endpoints included.

    Returns:
        One list per batch item: rendered ``example1`` item, the rendered
        interpolation steps, then the rendered ``example2`` item, chunked into
        tuples of ``len(input_builders)`` elements. Chunks line up with one
        stage only when ``input_builders`` and ``output_builders`` have the
        same length.
    """
    z1 = model.encoder(example1, sampling=False)
    z2 = model.encoder(example2, sampling=False)

    def render_batch(builders, batch):
        # One column per builder; each column has one entry per batch item.
        return [[render(item) for item in dataspec.unbatch(batch)] for render in builders]

    cols = render_batch(input_builders, example1)
    for u in np.linspace(0, 1, num_samples):
        z = z1 * (1 - u) + z2 * u
        cols.extend(render_batch(output_builders, model.decoder(z)))
    cols.extend(render_batch(input_builders, example2))

    # Transpose so that each row collects everything rendered for one batch item.
    return [list(grouper(row, len(input_builders))) for row in zip(*cols)]


# Draw two consecutive batches from the test split to serve as interpolation endpoints.
itr = iter(test_dataset.take(2))
example1, example2 = next(itr), next(itr)

# layouts[key] keeps the rendered rows for each model, keyed by model name.
layouts = {}
for key, model in models.items():
    print(key)
    layouts[key] = visualize_interpolation(model, example1, example2, dataspec, [builder], [builder])
    for row in layouts[key]:
        # Flatten the row's tuples and join the builder outputs with spaces inside one <div>.
        display(HTML('<div>%s</div>' % ' '.join(itertools.chain.from_iterable(row))))

### Random generation

Show randomly generated documents.

In [None]:
def generate_random(model, builders, batch_size=20):
    """Decode random latent vectors and render every decoded item.

    Latents are drawn with ``tf.random.normal`` in a
    ``(batch_size, model.decoder.input_shape[-1])`` tensor. Decoded batches are
    split with the notebook-level ``dataspec``.

    Args:
        model: Object exposing a ``decoder`` that is callable and has an
            ``input_shape`` attribute.
        builders: Callables applied to every decoded item.
        batch_size: Number of latent vectors to draw; defaults to 20.

    Returns:
        A list with one tuple per decoded item, holding the output of each
        builder in ``builders`` order.
    """
    latent_width = model.decoder.input_shape[-1]
    z = tf.random.normal(shape=(batch_size, latent_width))
    decoded = model.decoder(z)
    return [
        tuple(render(item) for render in builders)
        for item in dataspec.unbatch(decoded)
    ]
    
# generated[key] keeps the rendered samples for each model, keyed by model name.
generated = {}
for key, model in models.items():
    print(key)
    generated[key] = generate_random(model, [builder])
    # Flatten the per-item tuples and join all builder outputs with spaces inside one <div>.
    display(HTML('<div>%s</div>' % ' '.join(itertools.chain.from_iterable(generated[key]))))