In [14]:
# Enable IPython's autoreload so edits to imported project modules (e.g. `mtt`)
# are picked up without restarting the kernel. Re-running this cell in the same
# kernel only prints an "already loaded" notice; it is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import numpy as np

# Fixed seed so anything drawn from `rng` is reproducible under Restart & Run All.
# NOTE(review): Simulator/OnlineDataset below are not passed `rng`; they may draw
# from their own RNG — TODO confirm and seed them as well if reproducibility matters.
SEED = 0
rng = np.random.default_rng(SEED)

In [16]:
from mtt.simulator import Simulator
from mtt.data import OnlineDataset

def init_simulator():
    """Build a new multi-target-tracking Simulator with this notebook's settings.

    Passed to ``OnlineDataset`` as a zero-argument factory (presumably invoked
    whenever the dataset needs a fresh scene — TODO confirm in ``mtt.data``).
    Note: ``width`` must agree with the ``width`` used for the plotting extent
    in the animation cell.
    """
    return Simulator(
        width=1000,
        p_initial=4,
        p_birth=2,
        p_survival=0.95,
        p_clutter=1e-4,
        p_detection=0.95,
        sigma_motion=5.0,
        sigma_initial_state=(50.0, 50.0, 2.0),
        n_sensors=1,
        noise_range=10.0,
        noise_bearing=0.1,
        dt=0.1,
    )


# Materialize the samples into a list so later cells can index (`dataset[i]`),
# take `len(dataset)`, and iterate repeatedly over the very same samples.
dataset = list(
    OnlineDataset(
        n_steps=100,
        sigma_position=0.01,
        length=20,
        img_size=128,
        init_simulator=init_simulator,
    )
)

In [17]:
# Summary statistics of the ground-truth targets, taken from the LAST time step
# of every sample (each sample's final element is its per-step info list).
positions = [sample_info[-1]["target_positions"] for *_, sample_info in dataset]
n_targets = list(map(len, positions))
print(
    f"# of targets: mean = {np.mean(n_targets):0.2f}, "
    f"std = {np.std(n_targets):0.2f}"
)
print(f"position std: {np.std(np.concatenate(positions), axis=0)}")

# of targets: mean = 19.45, std = 4.02
position std: [717.36619456 434.89755688]


In [18]:
import guild.ipy as guild
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
import os
from mtt.models import Conv2dCoder
from torchinfo import summary

# Identify the training run to visualize. Both values refer to one specific past
# experiment; change them to inspect a different run.
RUNS_SINCE = "2022-04-04"
RUN_ID = "6c0cf3b6"

runs = guild.runs()
runs = runs[runs.started > RUNS_SINCE]
matching = runs.run.apply(lambda run: run.value.short_id == RUN_ID)
if not matching.any():
    raise LookupError(f"no guild run with short id {RUN_ID!r} started after {RUNS_SINCE}")
# Named `run_dir` (not `dir`) so the builtin `dir()` stays usable in later cells.
run_dir = runs[matching].iloc[0].run.value.dir

# `**` only recurses when recursive=True; otherwise it behaves like `*` and
# matches exactly one directory level.
checkpoint_files = glob(os.path.join(run_dir, "**", "epoch*.ckpt"), recursive=True)
if not checkpoint_files:
    raise FileNotFoundError(f"no epoch*.ckpt checkpoint found under {run_dir}")
# glob order is arbitrary; if several checkpoints exist, pick the newest on disk.
checkpoint_file = max(checkpoint_files, key=os.path.getmtime)

model = Conv2dCoder.load_from_checkpoint(checkpoint_file)
model.eval()  # this notebook only runs inference
summary(model, (1,) + model.input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
Conv2dCoder                              --                        --
├─Sequential: 1-1                        [1, 128, 16, 16]          --
│    └─Conv2d: 2-1                       [1, 128, 64, 64]          207,488
│    └─ReLU: 2-2                         [1, 128, 64, 64]          --
│    └─Conv2d: 2-3                       [1, 128, 32, 32]          1,327,232
│    └─ReLU: 2-4                         [1, 128, 32, 32]          --
│    └─Conv2d: 2-5                       [1, 128, 16, 16]          1,327,232
│    └─ReLU: 2-6                         [1, 128, 16, 16]          --
├─Sequential: 1-2                        [1, 128, 16, 16]          --
│    └─Conv2d: 2-7                       [1, 1024, 16, 16]         132,096
│    └─ReLU: 2-8                         [1, 1024, 16, 16]         --
│    └─Conv2d: 2-9                       [1, 1024, 16, 16]         1,049,600
│    └─ReLU: 2-10                        [1, 1024, 16,

In [42]:
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import torch

# Which frame to show: index counted back from the end of each sample's window.
n_past = 5
# Scene width; must match the `width` given to the Simulator so the images and
# the scatter overlays share the same coordinates.
width = 1000
extent = (-width / 2, width / 2, -width / 2, width / 2)

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.set_facecolor('white')


def animate(i):
    """Redraw the three panels for sample ``i`` of ``dataset``.

    Panels: sensor-measurement image, ground-truth target-position image, and
    the CNN output, each taken at index ``-n_past``. Every panel is overlaid
    with the true target positions, sensor measurements and clutter points.
    """
    sensor_img, position_img, info = dataset[i]
    # Follow the model's own device instead of assuming CUDA, so this cell does
    # not rely on another cell having moved the model to the GPU.
    device = next(model.parameters()).device
    with torch.no_grad():
        # [None, ...] adds a batch dimension; [0, -n_past] selects batch item 0
        # and (presumably) the time step -n_past — TODO confirm output layout.
        output_img = model(sensor_img[None, ...].to(device))[0, -n_past].cpu().numpy()
    sensor_img = sensor_img[-n_past].numpy()
    position_img = position_img[-n_past].numpy()
    info = info[-n_past]

    target_positions = info["target_positions"]
    measurements = np.concatenate(info["measurements"])
    clutter = np.concatenate(info["clutter"])

    # `axis` (not `i`) so the frame index parameter is not shadowed.
    for axis in ax:
        axis.clear()
        axis.plot(*target_positions.T, "r.", markersize=5)
        axis.plot(*measurements.T, "bx", markersize=5)
        axis.plot(*clutter.T, "g^", markersize=5)

    ax[0].set_title("Sensor Measurements")
    ax[1].set_title("Target Positions (Ground Truth)")
    ax[2].set_title("CNN Output")

    ax[0].imshow(sensor_img, extent=extent, origin="lower", cmap="gray_r")
    ax[1].imshow(position_img, extent=extent, origin="lower", cmap="gray_r")
    ax[2].imshow(output_img, extent=extent, origin="lower", cmap="gray_r")
    # Labels follow the plotting order above: targets, measurements, clutter.
    ax[1].legend(["Target Position", "Sensor Measurement", "Clutter"], loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=3, fancybox=True)


anim = FuncAnimation(fig, animate, frames=len(dataset), interval=100)
# To preview inline instead of writing a file, display: HTML(anim.to_jshtml())
anim.save("./animation.mp4", fps=10, dpi=150, bitrate=-1)
plt.close(fig)  # close this figure explicitly rather than whatever is "current"