In [1]:
%matplotlib widget

from matplotlib.cm import ScalarMappable
from matplotlib.colors import BoundaryNorm, ListedColormap
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display

def create_3d_hit_plot(hits, color_col):
    """Display an interactive 3D scatter plot of hits colored by a discrete column.

    A range slider selects which values of ``color_col`` are shown, and the
    figure is redrawn each time the slider changes. Colors, the color bar and
    the axis limits are derived once from the full ``hits`` frame, so they
    stay the same while the selection changes.

    The plot's x-axis shows the "z" column, its y-axis the "x" column and its
    z-axis the "y" column.

    Args:
        hits: pandas DataFrame with columns "x", "y", "z" and ``color_col``.
        color_col: Name of the column whose distinct values pick the point
            colors and define the slider range.

    Raises:
        ValueError: If ``hits`` has no rows (there is nothing to build the
            color map or the slider range from).
    """
    if len(hits) == 0:
        raise ValueError("hits must contain at least one row to plot")

    # Distinct values of the color column in ascending order.
    unique_t_values = np.sort(hits[color_col].unique())
    N_t = len(unique_t_values)

    # Walk through the colormap with a stride so that neighbouring values get
    # clearly different colors. The stride must be coprime with N_t, otherwise
    # several values collapse onto the same color (a stride of 0 for N_t <= 3
    # would even give every value the same color).
    stride = max(N_t // 2 - 1, 1)
    while np.gcd(stride, N_t) != 1:
        stride += 1

    # Integer arguments index the colormap's lookup table directly, which
    # guarantees one distinct entry per value (float lookups of (k + 1) / N_t
    # skip entry 0 and map the two largest k onto the same last entry).
    cmap = plt.get_cmap("gist_rainbow", N_t)
    t_to_color = {
        t: cmap(int((i * stride) % N_t)) for i, t in enumerate(unique_t_values)
    }

    # Discrete colormap and norm for the color bar: one bin per distinct value,
    # bin k centred on the integer k.
    listed_cmap = ListedColormap([t_to_color[t] for t in unique_t_values])
    norm = BoundaryNorm(boundaries=np.arange(N_t + 1) - 0.5, ncolors=N_t)

    # Limits taken from the full data set so the view does not jump when the
    # slider filters points away.
    x_min, x_max = hits["x"].min(), hits["x"].max()
    y_min, y_max = hits["y"].min(), hits["y"].max()
    z_min, z_max = hits["z"].min(), hits["z"].max()

    # Figure drawn by the previous slider update, if any.
    current_fig = None

    def create_plot(filtered_hits):
        """Draw ``filtered_hits`` in a fresh figure, replacing the previous one."""
        nonlocal current_fig
        # Close the figure from the previous update; otherwise every slider
        # move leaves another open figure behind in pyplot.
        if current_fig is not None:
            plt.close(current_fig)

        fig = plt.figure(figsize=(10, 8))
        current_fig = fig
        ax = fig.add_subplot(111, projection="3d")

        # One RGBA tuple per point.
        point_colors = filtered_hits[color_col].map(t_to_color).tolist()
        ax.scatter(
            filtered_hits["z"],
            filtered_hits["x"],
            filtered_hits["y"],
            c=point_colors,
            alpha=0.5,
        )

        # The scatter uses explicit colors, so the color bar is driven by a
        # standalone mappable built from the same discrete colormap.
        sm = ScalarMappable(cmap=listed_cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax)
        cbar.set_label(color_col)
        cbar.set_ticks(np.arange(N_t))
        cbar.set_ticklabels(unique_t_values)

        ax.set_xlabel("Z")
        ax.set_ylabel("X")
        ax.set_zlabel("Y")
        ax.set_title("3D Scatter Plot of Hits")

        ax.set_xlim(z_min, z_max)
        ax.set_ylim(x_min, x_max)
        ax.set_zlim(y_min, y_max)

        plt.show()

    def update_plot(t_range):
        """Redraw the plot with the hits whose value lies in ``t_range``."""
        # Small tolerance so float rounding in the slider does not drop the
        # values sitting exactly on the range ends.
        tolerance = 0.0001
        lower, upper = t_range[0] - tolerance, t_range[1] + tolerance
        create_plot(hits[hits[color_col].between(lower, upper)])

    # Range slider spanning all values of the color column.
    t_lo = float(hits[color_col].min())
    t_hi = float(hits[color_col].max())
    t_slider = widgets.FloatRangeSlider(
        value=[t_lo, t_hi],
        min=t_lo,
        max=t_hi,
        step=1,
        description=f"{color_col} range:",
    )
    output = widgets.interactive_output(update_plot, {"t_range": t_slider})
    display(t_slider, output)

## Sinusoid Strategy

The cell below creates points with the sinusoid strategy. Try using different values for `step_size_multiplier`. Note the following:
1. Values below 2 are not a good choice because the generated values will not cover the full circle.
2. Even values are not a good choice because they will repeat the points after every rotation.

In [2]:
import sys
import os

import pandas as pd
import torch
from torch import Tensor
from torch.nn import functional as F

# Add the parent directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from src.Util.Globals import DATA_DIR
from src.Util import CoordinateSystemEnum
from src.Data.CollisionEventLoader import CollisionEventLoader
from src.TimeStep.ForAdjusting.PlacementStrategy import SinusoidStrategy
from src.TimeStep.ForAdjusting.VolumeLayer import VLTimeStep

# Build the placement strategy, the time step that uses it, and a loader over
# the sample events.
normalize_hits = False
strategy = SinusoidStrategy(step_size_multiplier=13.87193)
time_step = VLTimeStep(placement_strategy=strategy, normalize_hits=normalize_hits)

root_dir = os.path.dirname(os.path.abspath(os.getcwd()))
coordinate_system = CoordinateSystemEnum.CARTESIAN
data_loader = CollisionEventLoader(
    os.path.join(root_dir, DATA_DIR, "train_sample"),
    time_step,
    batch_size=1,
    coordinate_system=coordinate_system,
    normalize_hits=normalize_hits,
    device="cpu",
)

num_t = time_step.get_num_time_steps()
batch_size = data_loader.batch_size

# Only the first batch from the loader is inspected.
gt_hits_list, gt_batch_index, _ = next(iter(data_loader))

# Per-time-step frames are collected here and concatenated once after the
# loop instead of growing a DataFrame on every iteration.
hit_frames = []
gt_hit_frames = []

for t in range(num_t):
    gt_hits = gt_hits_list[t]
    gt_index = gt_batch_index[t]

    part_ids = time_step.assign_to_shell_parts(gt_hits, t, coordinate_system)
    num_parts = time_step.get_num_shell_parts(t)

    # Flat id of each hit's (batch entry, shell part) pair.
    batch_part_ids = gt_index * num_parts + part_ids

    # Count hits per (batch entry, shell part). bincount keeps a zero for
    # every part without hits, so each count stays at its own part's position.
    # NOTE(review): assumes the ids are non-negative integers below
    # batch_size * num_parts - confirm against assign_to_shell_parts.
    gt_size = torch.bincount(
        batch_part_ids.flatten(), minlength=batch_size * num_parts
    )
    gt_size = gt_size.view(batch_size, num_parts)

    # Let the strategy place as many hits per shell part as the ground truth has.
    hits_t = time_step.place_hits(
        t, gt_size, coordinate_system=CoordinateSystemEnum.CARTESIAN
    )

    hit_frames.append(
        pd.DataFrame(hits_t.cpu().numpy(), columns=["x", "y", "z"]).assign(t=t)
    )
    gt_hit_frames.append(
        pd.DataFrame(gt_hits.cpu().numpy(), columns=["x", "y", "z"]).assign(t=t)
    )

# Each time step contributes exactly once; fall back to an empty frame with
# the expected columns when there are no time steps.
hit_columns = ["x", "y", "z", "t"]
hits_df = (
    pd.concat(hit_frames, ignore_index=True)
    if hit_frames
    else pd.DataFrame(columns=hit_columns)
)
gt_hits_df = (
    pd.concat(gt_hit_frames, ignore_index=True)
    if gt_hit_frames
    else pd.DataFrame(columns=hit_columns)
)


In [3]:
# Plot a 10% subsample of both hit sets. The seeds are fixed so that the
# figure is the same on every run; the two frames use different seeds so they
# do not pick the same row positions.
SAMPLE_FRAC = 0.1
SAMPLE_SEED = 42

# Interleave the two sets on the color axis: placed hits of time step t get
# the even value 2t, ground-truth hits of the same step the odd value 2t + 1.
s_hits_df = hits_df.sample(frac=SAMPLE_FRAC, random_state=SAMPLE_SEED).assign(
    t=lambda df: df["t"] * 2
)
s_gt_hits_df = gt_hits_df.sample(
    frac=SAMPLE_FRAC, random_state=SAMPLE_SEED + 1
).assign(t=lambda df: df["t"] * 2 + 1)

combined_df = pd.concat([s_hits_df, s_gt_hits_df], ignore_index=True)

create_3d_hit_plot(combined_df, color_col="t")

FloatRangeSlider(value=(0.0, 15.0), description='t range:', max=15.0, step=1.0)

Output()