In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import pickle
import time
from tqdm.notebook import tqdm

import torch
# Make float64 the default floating-point dtype for newly created tensors.
# torch.set_default_tensor_type is deprecated in recent PyTorch releases;
# set_default_dtype is the supported way to select the default float dtype.
torch.set_default_dtype(torch.float64)

from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.visualization import *
from spatial_scene_grammars_examples.planar_clusters_gaussians.grammar import *
from spatial_scene_grammars.parsing import *
from spatial_scene_grammars.sampling import *

import meshcat
import meshcat.geometry as meshcat_geom

In [2]:
from IPython.display import HTML

# Keep an already-created visualizer when this cell is re-run, so that
# re-execution does not construct another one.
if 'vis' not in globals():
    vis = meshcat.Visualizer()

# Rebuild the viewer address from a fixed local host plus everything that
# follows the last ":" in the URL reported by the visualizer.
base_url = "http://127.0.0.1"
url_tail = vis.url().split(":")[-1]
meshcat_url = "%s:%s" % (base_url, url_tail)
print("Meshcat url: ", meshcat_url)

# The HTML object is the last expression of the cell, so the notebook renders
# it: a resizable container holding an iframe that points at the viewer.
viewer_markup = """
    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">
    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>
</div>
""".format(url=meshcat_url)
HTML(viewer_markup)

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7002/static/
Meshcat url:  http://127.0.0.1:7002/static/


In [3]:
# Build a dataset of scenes sampled from the grammar with its default
# parameters, cache it on disk, and draw the first sampled scene in meshcat.
# RESAMPLE forces regeneration even when the cache file already exists.
torch.random.manual_seed(2)
N_samples = 1
RESAMPLE = True
scenes_file = "sampled_scenes_%d.dat" % N_samples

ground_truth_grammar = SpatialSceneGrammar(
    root_node_type = Desk,
    root_node_tf = torch.eye(4)
)

if RESAMPLE or not os.path.exists(scenes_file):
    # Each dataset entry pairs a sampled tree with its observed nodes.
    samples = []
    for sample_idx in range(N_samples):
        tree = ground_truth_grammar.sample_tree(detach=True)
        samples.append((tree, tree.get_observed_nodes()))

    with open(scenes_file, "wb") as cache_out:
        pickle.dump(samples, cache_out)

# The scenes are always read back from disk, whether or not they were just
# regenerated, so later cells work on exactly what the cache file holds.
# NOTE(review): pickle.load can execute arbitrary code; only load cache files
# written by this notebook.
with open(scenes_file, "rb") as cache_in:
    samples = pickle.load(cache_in)
print("Loaded %d scenes." % len(samples))
observed_node_sets = [entry[1] for entry in samples]

first_sampled_tree = samples[0][0]
draw_scene_tree_contents_meshcat(first_sampled_tree, zmq_url=vis.window.zmq_url, prefix="sample")



Loaded 1 scenes.


In [4]:
# Build the grammar used for inference, with sample_params_from_prior disabled.
grammar = SpatialSceneGrammar(
    root_node_type = Desk,
    root_node_tf = torch.eye(4),
    sample_params_from_prior=False
)
# Overwrite every rule's "width" parameter (for rules that have one) with a
# same-shaped tensor filled with 5, so the initial guess is as wide as possible.
# TODO: Make this a grammar method.
for node_type in grammar.all_types:
    rule_param_pairs = grammar.rule_params_by_node_type[node_type.__name__]
    for xyz_param_dict, rot_param_dict in rule_param_pairs:
        # Rules without a "width" entry are left untouched.
        if "width" not in xyz_param_dict.keys():
            continue
        width_param = xyz_param_dict["width"]
        width_param.set(torch.ones_like(width_param()) * 5.)
        

def do_vis(tree):
    """Draw the structure of `tree` in meshcat under the "sampled_in_progress" prefix.

    Uses the notebook-level `vis` visualizer to find the meshcat ZMQ address.
    """
    draw_scene_tree_structure_meshcat(
        tree,
        prefix="sampled_in_progress",
        zmq_url=vis.window.zmq_url,
    )
    
def get_posterior_tree_samples_from_observation(
        grammar, observed_nodes, num_mcmc_steps=15, subsample_rate=3, verbose=0,
        max_scene_extent_in_any_dir=10., translation_variance=1.0, rotation_variance=1.0):
    """Sample scene trees explaining `observed_nodes` under `grammar`.

    Pipeline: draw the observation in meshcat, find a tree structure with the
    MIP solver, refine it with the NLP optimizer, then run fixed-structure
    MCMC starting from the refined tree and subsample the resulting chain.

    Args:
        grammar: grammar handed to the MIP solver and to the MCMC sampler.
        observed_nodes: observed nodes of one scene.
        num_mcmc_steps: passed to the sampler as `num_samples`.
        subsample_rate: slice step applied to the sampled trees; every
            `subsample_rate`-th tree, starting from the first, is kept.
        verbose: verbosity forwarded to the MIP, NLP and MCMC calls.
        max_scene_extent_in_any_dir: forwarded to the MIP solver under the
            same keyword name.
        translation_variance: forwarded to the MCMC sampler under the same
            keyword name.
        rotation_variance: forwarded to the MCMC sampler under the same
            keyword name.

    Returns:
        The subsampled result of the MCMC call, or None when no optimized
        tree could be recovered from the MIP results.

    Note:
        The three trailing parameters default to the values that used to be
        hard-coded here. They are exposed because the run recorded in this
        notebook reports an accept rate of 0 with these defaults, so the
        proposal variances likely need tuning (NOTE(review): confirm).
    """
    draw_scene_tree_contents_meshcat(
        SceneTree.make_from_observed_nodes(observed_nodes), zmq_url=vis.window.zmq_url, prefix="observed"
    )

    # Use a MIP to get MAP structure.
    mip_results = infer_mle_tree_with_mip(
        grammar, observed_nodes, verbose=verbose,
        max_scene_extent_in_any_dir=max_scene_extent_in_any_dir
    )
    mip_optimized_tree = get_optimized_tree_from_mip_results(mip_results)
    # Truthiness check kept as in the original: any falsy result means there
    # is no tree to refine.
    if not mip_optimized_tree:
        return None

    draw_scene_tree_structure_meshcat(mip_optimized_tree, zmq_url=vis.window.zmq_url, prefix="mip_refined")

    # Use NLP to refine that to a MAP estimate.
    refinement_results = optimize_scene_tree_with_nlp(mip_optimized_tree, verbose=verbose)
    refined_tree = refinement_results.refined_tree

    # And sample trees around that MAP estimate with the
    # same structure. do_vis draws trees while sampling is in progress.
    sampled_trees = do_fixed_structure_mcmc(
        grammar, refined_tree, num_samples=num_mcmc_steps, verbose=verbose,
        perturb_in_config_space=True,
        translation_variance=translation_variance, rotation_variance=rotation_variance,
        do_hit_and_run_postprocess=False, vis_callback=do_vis
    )

    # Finally, subsample the sampled trees as requested and return
    # the sampled set.
    return sampled_trees[::subsample_rate]

def collect_posterior_sample_sets(grammar, observed_node_sets):
    """Gather posterior tree samples for every observed scene.

    Each entry of `observed_node_sets` is run through
    get_posterior_tree_samples_from_observation (verbose output, 20 MCMC
    steps, keeping every 2nd sample) behind a progress bar. Scenes for which
    that call returns None are left out, so the returned list may be shorter
    than the input.
    """
    progress = tqdm(observed_node_sets, desc='Collecting posterior samples')
    # Lazy generator: scenes are still processed one at a time, in order,
    # as the list below is built.
    per_scene_samples = (
        get_posterior_tree_samples_from_observation(
            grammar, scene_nodes, verbose=1, subsample_rate=2, num_mcmc_steps=20)
        for scene_nodes in progress
    )
    return [sample_set for sample_set in per_scene_samples if sample_set is not None]
# Run posterior sampling on every observed scene using the grammar built above.
posterior_sample_sets = collect_posterior_sample_sets(grammar, observed_node_sets)

HBox(children=(FloatProgress(value=0.0, description='Collecting posterior samples', max=1.0, style=ProgressSty…

Starting setup.
Activation vars allocated.
Continuous variables allocated.
Setup time:  1.1809122562408447
Num vars:  7058
Num constraints:  24656
Optimization success?:  True
Logfile: 

Gurobi 9.0.2 (linux64) logging started Tue Sep  7 02:22:02 2021

Gurobi Optimizer version 9.0.2 build v9.0.2rc0 (linux64)
Optimize a model with 18046 rows, 7058 columns and 123162 nonzeros
Model fingerprint: 0xf1d0a99e
Model has 381 quadratic objective terms
Variable types: 6762 continuous, 296 integer (296 binary)
Coefficient statistics:
  Matrix range     [2e-01, 1e+01]
  Objective range  [1e-01, 1e+01]
  QObjective range [4e+00, 6e+04]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+01]
Presolve removed 17176 rows and 6676 columns
Presolve time: 0.22s
Presolved: 870 rows, 382 columns, 2702 nonzeros
Presolved model has 360 quadratic objective terms
Variable types: 288 continuous, 94 integer (94 binary)

Root relaxation: objective -2.745839e+00, 1184 iterations, 0.01 seconds

    Nodes

New score -93768532.584353, old score 102.306179, alpha 0.000000
0: Accept rate 0.000000
New score -99640040.195265, old score 102.306179, alpha 0.000000
1: Accept rate 0.000000
New score -90667334.059732, old score 102.306179, alpha 0.000000
2: Accept rate 0.000000
New score -42787456.560281, old score 102.306179, alpha 0.000000
3: Accept rate 0.000000
New score -101645064.295172, old score 102.306179, alpha 0.000000
4: Accept rate 0.000000
New score -186245453.024509, old score 102.306179, alpha 0.000000
5: Accept rate 0.000000
New score -125991512.941829, old score 102.306179, alpha 0.000000
6: Accept rate 0.000000
New score -172362707.677133, old score 102.306179, alpha 0.000000
7: Accept rate 0.000000
New score -66173565.774476, old score 102.306179, alpha 0.000000
8: Accept rate 0.000000
New score -130450691.598468, old score 102.306179, alpha 0.000000
9: Accept rate 0.000000
New score -74706908.137102, old score 102.306179, alpha 0.000000
10: Accept rate 0.000000
New score -1925

In [5]:
# Draw every tree of the last collected posterior sample set, each under its
# own "guesses/<index>" meshcat prefix.
# posterior_sample_sets can be empty: scenes whose inference returned None are
# skipped during collection, and indexing [-1] on an empty list would raise.
if posterior_sample_sets:
    for k, tree in enumerate(posterior_sample_sets[-1]):
        draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url, prefix="guesses/%d" % k)
else:
    print("No posterior sample sets were collected; nothing to draw.")