# Reprojection of tracks to 3D

In [None]:
import os
import pickle
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyvista as pv
from matplotlib.pyplot import cm
from tqdm import tqdm

from collab_env.alignment import reprojection
from collab_env.data.file_utils import get_project_root
from collab_env.utils import visualization as viz

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Get info from GCloud

In [None]:
# Identify which recording to work on
data_type = "fieldwork_processed"
session = "2024_02_06-session_0001"
splat_video = "C0043"
camera_id = "rgb_1"

# Path of the session directory inside the project's data folder
session_data_dir = get_project_root().joinpath("data", data_type, session)

# Directories required for alignment
environment_dir = session_data_dir.joinpath("environment")
aligned_frames_dir = session_data_dir.joinpath("aligned_frames", camera_id)
aligned_splat_dir = session_data_dir.joinpath("aligned_splat", camera_id)

# Files read by the reprojection steps below
mesh_fn = environment_dir.joinpath(splat_video, "rade-features", "mesh", "mesh.ply")
aligned_camera_fn = aligned_splat_dir.joinpath(f"{camera_id}_mesh-aligned.pkl")
tracking_fn = aligned_frames_dir.joinpath(f"{camera_id}_tracked_bboxes.csv")

Load the required files

In [None]:
# 2D tracks over the video, with track IDs cast to integers
df_tracks = pd.read_csv(tracking_fn).astype({"track_id": int})

# Parameters of the camera aligned to the mesh
# (pickle can execute arbitrary code on load -- only open trusted files)
with aligned_camera_fn.open("rb") as camera_file:
    camera_params = pickle.load(camera_file)

### Create reprojection objects

In [None]:
# Camera to reproject from, built from the loaded camera parameters
camera_fields = ("K", "c2w", "width", "height")
camera = reprojection.Camera(
    **{field: camera_params[field] for field in camera_fields}
)

# Mesh to reproject onto
mesh_environment = reprojection.MeshEnvironment(mesh_fn)

# Render the camera view of the mesh --> updates camera.image and camera.depth
image, depth = mesh_environment.render_camera(camera)

Visualize camera view of the mesh

In [None]:
# Show the rendered image and the rendered depth side by side
fig, axs = plt.subplots(1, 2, figsize=(20, 10))

axs[0].imshow(image)
axs[0].set_title("Image")
axs[0].axis("off")

# Give the depth panel the same title/axis treatment as the image panel;
# the trailing semicolon suppresses the repr of the last expression.
axs[1].imshow(depth)
axs[1].set_title("Depth")
axs[1].axis("off");

### Visualize depth on mesh

In [None]:
# Compute depths on the mesh from our camera and the environment's mesh
depth_options = {"smooth": True, "radius": 0.01}
mesh_depths = reprojection.get_depths_on_mesh(
    camera=camera,
    mesh=mesh_environment.mesh,
    **depth_options,
)

Use visualization tools to view the projected depths

In [None]:
# Load the mesh as a pyvista object
pv_mesh = pv.read(mesh_fn)

# Convert depths to RGBA via the colormap. plt.get_cmap is used because
# matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9.
# NOTE(review): colormaps expect floats in [0, 1] -- confirm mesh_depths is normalized
depth_rgb = plt.get_cmap("viridis")(mesh_depths)
depth_rgb[np.isnan(mesh_depths)] = [0, 0, 0, 1]  # opaque black where depth is NaN

# Set as an attribute of the mesh
pv_mesh.point_data["depths"] = depth_rgb

Make a nice plot of depths visualized on the mesh

In [None]:
# Format a deep copy of the camera parameters for pyvista
# TLB --> NEED TO FIX, IDK WHY IT INTERNALLY ALTERS PV_CAMERA
pv_camera = deepcopy(camera_params)
_ = viz.format_pyvista_camera_params(pv_camera)

# The single camera pose to draw
poses = [pv_camera["c2w"]]

# Frustum styling: start from the default camera kwargs and override a few
camera_kwargs = viz.CAMERA_KWARGS.copy()
camera_kwargs.update(
    line_width=5,
    scale=0.025,
    opacity=0.9,
    color=[0.9, 0.9, 0.9],
    show_axes=True,
)

# Color the mesh by the "depths" attribute, interpreted as RGB
depth_mesh_kwargs = dict(scalars="depths", rgb=True)

# Build the scene with the mesh and the camera
plotter = viz.visualize_splat(
    pv_mesh,
    poses,
    mesh_kwargs=depth_mesh_kwargs,
    viz_kwargs=viz.VIZ_KWARGS,
    camera_kwargs=camera_kwargs,
)

# Take a screenshot and show it as an image
screenshot = plotter.screenshot(window_size=(3000, 3000), return_img=True)
mesh_image = np.array(screenshot)

plt.imshow(mesh_image)
plt.axis("off")

### Map tracks between spaces

Start by selecting which of the loaded tracks to reproject

In [None]:
# Track selection and filtering parameters
n_agents = 20  # Maximum number of tracks to keep
n_min_frames = 150  # Keep only tracks with more than this many rows
mesh_bounds = 375  # Y coordinate to consider points off the mesh
std_threshold = 2  # Filter tracks outside of this many stds of the moving average

# Get track counts per ID (groupby sorts the IDs in ascending order)
track_counts = df_tracks.groupby("track_id").size()

# Keep the first n_agents IDs, in ID order, among the tracks that are long enough
track_ids = track_counts[track_counts > n_min_frames].index[:n_agents]

# Create a subset of the tracks. Copy explicitly so that columns added to the
# subset later never write into a slice of df_tracks (SettingWithCopyWarning).
df_subset_tracks = df_tracks[df_tracks["track_id"].isin(track_ids)].copy()

Get image coordinates from the bounding box

In [None]:
# Map each bounding box to image coordinates (u, v) using the "bottom_center" method
uv_coords = df_subset_tracks.apply(
    lambda x: reprojection.bbox_to_coords(
        x[["x1", "y1", "x2", "y2"]].values.astype(float), method="bottom_center"
    ),
    axis=1,
)
uv_coords = np.stack(uv_coords)

# Add the coordinates as new columns. assign() returns a new frame, so this never
# writes into a slice of df_tracks and the cell can be re-run safely.
df_subset_tracks = df_subset_tracks.assign(u=uv_coords[:, 0], v=uv_coords[:, 1])

Now put those coordinates in 3D and filter

In [None]:
projected_tracks = []

for track_id, df_id in tqdm(df_subset_tracks.groupby("track_id")):
    # Remove points off the mesh
    mesh_filter = df_id["y2"] > mesh_bounds
    df_id = df_id.loc[mesh_filter]

    if df_id.empty:
        continue

    # Sort by frame FIRST and reset the index afterwards, so that index labels
    # equal row positions. The positional indices from np.where below are used
    # with .loc; resetting before sorting would leave labels permuted and write
    # world points onto the wrong rows whenever a track is not already in order.
    df_id = df_id.sort_values("frame").reset_index(drop=True)

    # Filter outliers and smooth tracks, using the configured threshold
    filtered_coords = reprojection.filter_coords(df_id, std_threshold=std_threshold)
    smoothed_coords = reprojection.smooth_coords(filtered_coords)

    # Replace the raw image coordinates with the filtered and smoothed ones
    df_id.loc[df_id.index, ["u", "v"]] = smoothed_coords.values
    uv_coords = df_id.loc[:, ["u", "v"]].values

    # TLB need to add size in here
    world_points = camera.project_to_world(uv_coords)

    # Filter out points that are not on the mesh (any NaN coordinate)
    mesh_contact_filter = ~np.isnan(world_points).any(1)
    world_points = world_points[mesh_contact_filter]

    # Add the points to the dataframe; rows without mesh contact are not assigned
    mesh_contact_idxs = np.where(mesh_contact_filter)[0]
    df_id.loc[mesh_contact_idxs, ["x", "y", "z"]] = world_points

    # Add to list of dataframes
    projected_tracks.append(df_id)

# Fail with an actionable message instead of pd.concat's generic
# "No objects to concatenate" when every track was filtered out
if not projected_tracks:
    raise ValueError(
        "No tracks left after filtering; check mesh_bounds and the track selection"
    )

df_projected_tracks = pd.concat(projected_tracks).reset_index(drop=True)

#### Plot over images 

Plot over the original image, mesh view, and mesh depth 2D images

In [None]:
# Grab the original image used to align the camera
sampled_frame_dir = aligned_splat_dir / "sampled-frames"

# Sort so the chosen frame does not depend on filesystem listing order
sampled_frame_fns = sorted(sampled_frame_dir.glob("*.png"))
if not sampled_frame_fns:
    raise FileNotFoundError(f"No .png frames found in {sampled_frame_dir}")
sampled_frame_fn = sampled_frame_fns[0]

# NOTE: this rebinds `image`, which previously held the rendered mesh view
image = plt.imread(sampled_frame_fn)

Set up plotting parameters

In [None]:
# Colors and line style for plotting tracks over the 2D images
cmap = "rainbow"  # Using rainbow for bright, distinct colors
track_ids = df_projected_tracks["track_id"].unique()
n_tracks = len(track_ids)

# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# matplotlib 3.7 and removed in 3.9
colormap = plt.get_cmap(cmap)

# One color per track, evenly spaced across the colormap
colors = [colormap(i) for i in np.linspace(0, 1, n_tracks)]
track_colors = dict(zip(track_ids, colors))

line_kwargs = {"linestyle": "-", "linewidth": 6, "label": None, "alpha": 1}

# Create track_images directory if it doesn't exist
os.makedirs("track_images", exist_ok=True)

Go through each image and plot

In [None]:
# Backgrounds to draw the tracks on
all_images = [camera.image, camera.depth, image]

for image_idx, background in enumerate(all_images):
    # Each background gets its own numbered file in the track_images folder
    save_path = f"track_images/track-image-{image_idx}.png"

    # Draw the projected tracks over this background
    track_image = viz.plot_tracks_on_image(
        df_tracks=df_projected_tracks,
        image=background,
        colors=track_colors,
        line_kwargs=line_kwargs,
    )

    # Save with a tight layout and no padding around the image
    track_image.savefig(save_path, bbox_inches="tight", pad_inches=0)

#### Plot tracks in 3D

In [None]:
# Camera pose taken from the pyvista-formatted camera parameters
pose = [pv_camera["c2w"]]

# Frustum styling: start from the default camera kwargs and override a few
camera_kwargs = viz.CAMERA_KWARGS.copy()
camera_kwargs.update(
    line_width=5,
    scale=0.025,
    opacity=0.9,
    color=[0.9, 0.9, 0.9],
    show_axes=True,
)

# Build the scene from the mesh file and the camera pose
plotter = viz.visualize_splat(
    mesh_fn.as_posix(),
    pose,
    mesh_kwargs=viz.MESH_KWARGS,
    viz_kwargs=viz.VIZ_KWARGS,
    camera_kwargs=camera_kwargs,
)

# Draw the projected tracks into the scene
plotter = viz.add_tracks_to_mesh(df_projected_tracks, plotter)

# To open the plotter window instead of taking a screenshot:
# plotter.show(window_size=(800, 800))

# Take a screenshot and show it as an image
screenshot = plotter.screenshot(window_size=(800, 800), return_img=True)
mesh_image = np.array(screenshot)

plt.imshow(mesh_image)
plt.axis("off")