# Mapping tutorial

In [177]:
# Global setup code

import genstudio.plot as Plot
import itertools
import jax
import jax.numpy as jnp
import genjax
from urllib.request import urlopen
from genjax import SelectionBuilder as S
from genjax import ChoiceMapBuilder as C
from genjax import pretty
from genjax.typing import IntArray, FloatArray, PRNGKey
from penzai import pz
from typing import Any, Iterable

import os

# Turn on genjax's pretty-printing (presumably for traces / choice maps shown as
# cell output — confirm).
pretty()

# Shorthand for genstudio's hiccup-style HTML builder; not referenced by the
# cells visible in this notebook section.
html = Plot.Hiccup
# Render plots as HTML. NOTE(review): "dev": False presumably disables
# genstudio's development mode — confirm.
Plot.configure({"display_as": "html", "dev": False})

# Ensure a location for image generation (no error if it already exists).
os.makedirs("imgs", exist_ok=True)

## Basic data structures

In [178]:
@pz.pytree_dataclass
class Blocks(genjax.PythonicPytree):
    """A set of grid cells: one position and one occupancy value per cell.

    Structure-of-arrays layout: entry k of `pos` and entry k of `present`
    describe the same cell. Being a pytree dataclass, it can be returned from
    generative functions, vmapped over (one cell at a time), and averaged
    leaf-by-leaf.
    """

    # (x, y) coordinates of each cell: shape (num_cells, 2) for a whole grid,
    # a single (2,) row when vmapped per cell.
    pos: IntArray
    # Occupancy per cell; a cell counts as a block when present > 0. Averaged
    # values in (0, 1) are used as fill opacity when plotting.
    # NOTE(review): block_prior_single stores a Bernoulli sample here, which may
    # have an int/bool dtype rather than float — confirm.
    present: FloatArray

# Ground-truth map as ASCII art: "X" marks an occupied cell, "." an empty one.
# It is 21 lines of 21 characters, i.e. (2 * GRID_RADIUS + 1) cells per side,
# so the flattened string lines up one-to-one with ALL_GRID_POINTS (441 points).
grid = r"""
.....................
.....................
.XXX..........X......
.X.X..........X......
.X............XXX....
.X..............X....
.X..............X....
.XX............XXX...
..X..............X...
..X..............X...
..X..............X...
..X..............X...
..XXXXX......XXXXX...
......X......X.......
......X....XXX.......
......X....X.........
......X....X.........
......XXXXXX.........
.....................
.....................
.....................
"""

GRID_RADIUS = 10


def all_grid_points(radius):
    """Return every integer (x, y) point of the square [-radius, radius]^2.

    The result has shape ((2 * radius + 1) ** 2, 2). x varies fastest:
    (-r, -r), (-r + 1, -r), ..., (r, -r), (-r, -r + 1), ..., (r, r).
    """
    axis_values = jnp.arange(-radius, radius + 1)
    xs, ys = jnp.meshgrid(axis_values, axis_values)
    return jnp.stack([xs.ravel(), ys.ravel()], axis=1)


ALL_GRID_POINTS = all_grid_points(GRID_RADIUS)

# extract the points from the grid string:
def extract_points_from_grid_string(grid_string, occupied_char="X"):
    """Parse an ASCII grid into a flat 0/1 occupancy vector.

    Args:
    - grid_string: Rows of equal width separated by line breaks (LF or CRLF).
      Surrounding whitespace of the whole string is stripped, so the blank
      lines around a triple-quoted literal are ignored.
    - occupied_char: Character that marks an occupied cell; any other
      character counts as empty.

    Returns:
    - Float array of shape (height * width,), row-major: line i, column j of
      the text lands at index i * width + j. 1.0 = occupied, 0.0 = empty.
      An empty string yields an empty array.

    Raises:
    - ValueError: if the rows do not all have the same width.

    NOTE(review): paired with ALL_GRID_POINTS (as for true_blocks), line i maps
    to y = i - GRID_RADIUS, so the top line of the text sits at the most
    negative y. If the plot's y axis points up, the map is drawn vertically
    mirrored relative to the ASCII art — confirm this is intended.
    """
    lines = grid_string.strip().splitlines()
    if not lines:
        return jnp.zeros((0,))
    width = len(lines[0])
    ragged = [i for i, line in enumerate(lines) if len(line) != width]
    if ragged:
        # Previously a longer row was silently truncated and a shorter one
        # zero-padded; a malformed map is almost certainly a typo.
        raise ValueError(
            f"grid rows must all have width {width}; rows {ragged} differ"
        )
    # Build the array in one shot rather than with one functional `.at[].set`
    # (a full array copy) per occupied cell.
    occupancy = [
        [1.0 if char == occupied_char else 0.0 for char in line] for line in lines
    ]
    return jnp.array(occupancy, dtype=float).reshape(-1)

true_presence = extract_points_from_grid_string(grid)

# The ground-truth world: every grid point, with occupancy read off the ASCII map.
true_blocks = Blocks(pos=ALL_GRID_POINTS, present=true_presence)

## Plotting

In [179]:
# Corners of a unit square centred on a cell, counter-clockwise, starting at
# the (-x, -y) corner.
OFFSETS = jnp.array(
    [[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]]
)

# The same corners with the first one repeated at the end, so a polyline
# through them closes the square.
OFFSETS_CLOSED = jnp.concatenate((OFFSETS, OFFSETS[0:1]), axis=0)

def plot_blocks(blocks: Blocks):
    """Return one filled square mark per cell whose presence is > 0.

    Each square is a closed polyline around the cell centre, filled black with
    the cell's presence value as opacity (so averaged maps render as shades).
    """
    corner_paths = blocks.pos[:, None, :] + OFFSETS_CLOSED[None, :, :]
    marks = []
    for path, presence in zip(corner_paths, blocks.present):
        if presence > 0:
            marks.append(Plot.line(path, fill="black", fillOpacity=presence))
    return marks

def line_segments_of_block(block: IntArray):
    """Return the four edges of the unit square centred on `block`.

    Args:
    - block: (x, y) centre of one cell.

    Returns:
    - Array of shape (4, 2, 2): edge k runs from corner k to corner k - 1
      (wrapping around); each corner is an (x, y) pair.
    """
    corners = OFFSETS + block[None, :]
    previous_corners = jnp.roll(corners, shift=1, axis=0)
    return jnp.stack((corners, previous_corners), axis=1)


# Batched form: maps a stack of centres along axis 0 to a stack of edge sets.
line_segments = jax.vmap(line_segments_of_block, in_axes=0)

def world_plot_spec(blocks):
    """Marks for the occupied blocks plus a red dot at the sensor origin (0, 0)."""
    block_marks = plot_blocks(blocks)
    origin_marker = Plot.ellipse([[0, 0]], r=0.2, fill="red")
    return block_marks + origin_marker

def plot_world(blocks):
    """Render the blocks and the sensor origin as a framed, square plot."""
    spec = world_plot_spec(blocks)
    return wrap_plot(spec)

def wrap_plot(p, radius=None):
    """Wrap plot marks in a square, 500px-wide figure centred on the origin.

    Args:
    - p: Plot marks/spec to display.
    - radius: Half-width of the visible x and y range. Defaults to GRID_RADIUS,
      so the frame follows the grid size instead of a hard-coded 10.

    Returns:
    - The composed genstudio plot.
    """
    if radius is None:
        radius = GRID_RADIUS
    return Plot.new(
        p,
        {"width": 500, "aspectRatio": 1},
        Plot.domain([-radius, radius], [-radius, radius]),
    )

In [180]:
wrap_plot(world_plot_spec(true_blocks))

## Prior

In [181]:
@genjax.gen
def block_prior_single(position: IntArray):
    """Prior over one cell: sample whether it is occupied ("present").

    genjax.bernoulli receives 0.0 here. Read as a logit, that is p = 0.5, which
    is consistent with the -441 bits printed for 441 cells in the next cell —
    confirm the parameterisation.
    """
    return Blocks(position, genjax.bernoulli(0.0) @ "present")


# Independent prior over every cell: the single-cell prior vmapped over the
# leading axis of the positions array.
blocks_prior = block_prior_single.vmap(in_axes=0)

In [182]:
sample_prior_jitted = jax.jit(blocks_prior.simulate)
key = jax.random.key(0)
tr = sample_prior_jitted(key, (ALL_GRID_POINTS,))
blocks = tr.get_retval()
# The trace score is a natural-log probability; dividing by ln(2) gives bits.
nats_per_bit = jnp.log(2.0)
log2_weight = tr.get_score() / nats_per_bit
print(f"Log2(weight): {log2_weight}")
plot_world(blocks)

Log2(weight): -441.000244140625


In [183]:
num_samples = 50
key, sub_key = jax.random.split(key)
# One copy of the grid positions per sample, so vmap can map axis 0 of both
# the keys and the argument tuple.
repeated_grid_positions = jax.lax.broadcast(ALL_GRID_POINTS, (num_samples,))
traces = jax.vmap(sample_prior_jitted, in_axes=0)(
    jax.random.split(sub_key, num_samples), (repeated_grid_positions,)
)
blocks = traces.get_retval()
# Blocks is a pytree, so one tree_map averages every field over the sample axis.
average_blocks = jax.tree_util.tree_map(lambda leaf: jnp.mean(leaf, axis=0), blocks)
plot_world(average_blocks)

## Exact sensor model

In [184]:
def solve_lines(p, u, q, v, PARALLEL_TOL=1.0e-10):
    """
    Intersect the lines p + s*u and q + t*v.

    Args:
    - p, u: Point and direction vector of the first line.
    - q, v: Point and direction vector of the second line.
    - PARALLEL_TOL: Lines whose 2-D cross product |u x v| is below this are
      treated as parallel.

    Returns:
    - Array [s, t] with p + s*u == q + t*v at the intersection, or
      [-inf, -inf] when the lines are (numerically) parallel.
    """
    cross = u[0] * v[1] - u[1] * v[0]
    # Both quotients are evaluated unconditionally (jnp.where is not lazy);
    # a zero cross product yields inf/nan here, which the mask below discards.
    s = (v[0] * (p[1] - q[1]) - v[1] * (p[0] - q[0])) / cross
    t = (u[1] * (q[0] - p[0]) - u[0] * (q[1] - p[1])) / cross
    is_parallel = jnp.abs(cross) < PARALLEL_TOL
    no_hit = jnp.full((2,), -jnp.inf)
    return jnp.where(is_parallel, no_hit, jnp.stack([s, t]))


def distance(dir, seg):
    """
    Computes how far a ray from the origin travels before crossing a segment.

    Args:
    - dir: Direction of the ray from (0, 0), shape (2,). The returned value is
      the ray parameter, so it is a Euclidean distance only when `dir` has
      unit length (as produced by `unit_dir`).
    - seg: Segment as a (2, 2) array of endpoints, `seg[0]` and `seg[1]`.

    Returns:
    - float: Distance along the ray to the crossing point (segment endpoints
      included). Returns infinity if the ray points away from the segment,
      misses it, or is parallel to it (collinear overlap included).
    """
    origin = jnp.array([0, 0])
    hit = solve_lines(origin, dir, seg[0], seg[1] - seg[0])
    ray_param, seg_param = hit[0], hit[1]
    # A valid hit lies in front of the sensor (ray_param >= 0) and between the
    # endpoints (0 <= seg_param <= 1). solve_lines reports parallel lines as
    # [-inf, -inf], which fails the first test.
    return jnp.where(
        (ray_param >= 0.0) & (seg_param >= 0.0) & (seg_param <= 1.0),
        ray_param,
        jnp.inf,
    )

def unit_dir(angle):
    """Unit vector (cos angle, sin angle), `angle` in radians from the +x axis."""
    return jnp.stack([jnp.cos(angle), jnp.sin(angle)])


In [185]:
# Sensor ray directions: NUM_ANGLES evenly spaced angles covering [0, 2*pi).
# linspace pins the count exactly; a float-step arange can gain or lose an
# element depending on rounding of (stop - start) / step.
NUM_ANGLES = 25
ANGLES = jnp.linspace(0.0, 2 * jnp.pi, NUM_ANGLES, endpoint=False)
# Cap on a sensor reading (twice the grid radius); a ray that hits no block
# reports this value.
MAX_DISTANCE = 2 * GRID_RADIUS

def distance_to_block(dir, block):
    """Distance along `dir` from the origin to the nearest edge of one cell.

    Returns infinity when the cell is not occupied (present is not > 0) or
    when the ray misses all four edges.
    """
    edges = line_segments_of_block(block.pos)
    edge_distances = jax.vmap(lambda edge: distance(dir, edge))(edges)
    nearest = jnp.min(edge_distances, axis=0)
    return jnp.where(block.present > 0, nearest, jnp.inf)

def distance_to_blocks(dir, blocks):
    """Exact sensor reading along `dir`: nearest occupied cell, capped at MAX_DISTANCE."""
    per_block = jax.vmap(lambda single: distance_to_block(dir, single))(blocks)
    closest = jnp.min(per_block, axis=0)
    return jnp.minimum(closest, MAX_DISTANCE)

def sensor_distances(blocks):
    """Noise-free readings for every direction in ANGLES (one per angle)."""
    def reading_at(angle):
        return distance_to_blocks(unit_dir(angle), blocks)

    return jax.vmap(reading_at)(ANGLES)

def sensor_plot_spec(readings):
    """Marks for each sensor ray from the origin plus a dot at each reading.

    `readings` is assumed to hold one distance per entry of ANGLES.
    """
    directions = jax.vmap(unit_dir, in_axes=0)(ANGLES)
    endpoints = directions * readings[:, None]
    rays = [
        Plot.line([[0, 0], [ex, ey]], stroke=Plot.constantly("sensor rays"))
        for ex, ey in endpoints
    ]
    dots = [
        Plot.ellipse([point], r=0.1, fill=Plot.constantly("sensor readings"))
        for point in endpoints
    ]
    return rays + dots


true_readings = sensor_distances(true_blocks)


def world_and_sensors_plot_spec(blocks, readings):
    """Marks for the world (blocks + sensor origin) overlaid with sensor rays."""
    world_marks = world_plot_spec(blocks)
    sensor_marks = sensor_plot_spec(readings)
    return world_marks + sensor_marks


wrap_plot(world_and_sensors_plot_spec(true_blocks, true_readings))


## Noisy sensor model

In [186]:
NOISE = 0.5  # scale passed to genjax.normal for every reading


@genjax.gen
def sensor_model_single(blocks, angle):
    """One noisy reading: Gaussian centred on the exact distance along `angle`."""
    exact = distance_to_blocks(unit_dir(angle), blocks)
    reading = genjax.normal(exact, NOISE) @ "readings"
    return reading


# All readings at once: `blocks` is shared (None), `angle` is mapped over axis 0.
sensor_model = sensor_model_single.vmap(in_axes=(None, 0))

In [187]:
# Simulate one noisy scan of the true world and overlay it on the map.
key = jax.random.key(0)
sensor_args = (true_blocks, ANGLES)
trace = sensor_model.simulate(key, sensor_args)
readings = trace.get_retval()
wrap_plot(world_and_sensors_plot_spec(true_blocks, readings))

## Full model

In [188]:
@genjax.gen
def full_model():
    """Joint model: draw a world from the prior, then a noisy scan of it.

    The prior's choices live under "blocks"; the sensor model is traced at the
    empty address `()`, presumably splicing its "readings" choices into the
    top level (map_estimation constrains them as C["readings"]) — confirm.
    """
    world = blocks_prior(ALL_GRID_POINTS) @ "blocks"
    scan = sensor_model(world, ANGLES) @ ()
    return world, scan


In [189]:
key = jax.random.key(1)
trace = full_model.simulate(key, ())
(blocks, readings) = trace.get_retval()
# NOTE(review): Jupyter only displays a cell's last expression, so this plot is
# built but not shown; move it to its own cell if it should be visible.
world_view = wrap_plot(world_and_sensors_plot_spec(blocks, readings))
trace.get_choices()

In [190]:
# Fixed synthetic observation: one noisy scan of the true world.
key = jax.random.key(1)
trace = sensor_model.simulate(key, (true_blocks, ANGLES))
observed_readings = trace.get_retval()
observed_plot_marks = sensor_plot_spec(observed_readings)
wrap_plot(observed_plot_marks)

## Maximum a posteriori estimation

In [201]:
def map_estimation(key, observed_readings, N=100):
    """Very naive MAP estimation: try N random worlds and keep the best one.

    Each candidate comes from importance sampling full_model with the readings
    constrained to `observed_readings`; the candidate with the highest
    importance weight wins. Unconstrained choices (the blocks) are drawn by
    the model itself, so this only searches among N prior samples.

    Args:
    - key: JAX PRNG key.
    - observed_readings: One reading per entry of ANGLES.
    - N: Number of candidate worlds to sample.

    Returns:
    - Blocks: the world of the highest-weight candidate.

    Side effect: prints the normalized log weights that exceed -10.
    """
    model_importance = jax.jit(full_model.importance)
    sample_keys = jax.random.split(key, N)
    constraints = C["readings"].set(observed_readings)
    traces, log_weights = jax.vmap(
        lambda sample_key: model_importance(sample_key, constraints, ())
    )(sample_keys)
    # Normalize so the weights sum to 1 in probability space.
    log_weights = log_weights - jax.scipy.special.logsumexp(log_weights)
    # Filter on device in one operation; iterating a length-N device array in
    # Python would dispatch a comparison and a transfer per element.
    print("Normalized log weights > -10: ", log_weights[log_weights > -10].tolist())
    best_index = jnp.argmax(log_weights)
    best_trace = jax.tree_util.tree_map(lambda leaf: leaf[best_index], traces)
    best_blocks, _ = best_trace.get_retval()
    return best_blocks

# Run the naive estimator with 100k candidates and overlay its best guess.
key, subkey = jax.random.split(key)
inferred_blocks = map_estimation(subkey, observed_readings, N=100_000)
wrap_plot(world_and_sensors_plot_spec(inferred_blocks, observed_readings))

Log weights > -10:  [0.0]


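This is clearly bad: out of 100,000 samples from the prior, only one has a normalized log weight above -10, so the "MAP estimate" is just the single luckiest random map. Blind guessing cannot search the 2^441 possible maps (one bit per cell of the 21 × 21 grid), so we need better inference methods.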