In [4]:
%matplotlib widget

from matplotlib.cm import ScalarMappable
from matplotlib.colors import BoundaryNorm, ListedColormap
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display

def create_3d_hit_plot(hits, color_col):
    # Define a colormap
    unique_t_values = np.sort(hits[color_col].unique())
    N_t = len(unique_t_values)

    # Create a mapping from t values to colors
    cmap = plt.get_cmap("gist_rainbow", len(unique_t_values))
    t_to_color = {t: cmap(((i * (N_t // 2 - 1)) % N_t + 1) / N_t) for i, t in enumerate(unique_t_values)}

    # Create a ListedColormap and BoundaryNorm for the color bar
    colors = [t_to_color[t] for t in unique_t_values]
    listed_cmap = ListedColormap(colors)
    norm = BoundaryNorm(boundaries=np.arange(N_t + 1) - 0.5, ncolors=N_t)

    # Calculate consistent limits for x, y, and z
    x_min, x_max = hits["x"].min(), hits["x"].max()
    y_min, y_max = hits["y"].min(), hits["y"].max()
    z_min, z_max = hits["z"].min(), hits["z"].max()

    # Function to create the plot
    def create_plot(filtered_hits):
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d")
        
        colors = filtered_hits[color_col].map(t_to_color)
        sc = ax.scatter(
            filtered_hits["z"], filtered_hits["x"], filtered_hits["y"], c=colors, alpha=0.5
        )

        # Add color bar
        sm = ScalarMappable(cmap=listed_cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax)
        cbar.set_label(color_col)
        cbar.set_ticks(np.arange(N_t))
        cbar.set_ticklabels(unique_t_values)

        # Set labels
        ax.set_xlabel("Z")
        ax.set_ylabel("X")
        ax.set_zlabel("Y")
        ax.set_title("3D Scatter Plot of Hits")

        # Set consistent limits
        ax.set_xlim(z_min, z_max)
        ax.set_ylim(x_min, x_max)
        ax.set_zlim(y_min, y_max)

        plt.show()


    # Function to update the plot based on the selected t range
    def update_plot(t_range):
        min_t, max_t = t_range[0] - 0.0001, t_range[1] + 0.0001
        filtered_hits = hits[(hits[color_col] >= min_t) & (hits[color_col] <= max_t)]
        create_plot(filtered_hits)


    # Create a range slider widget for t selection
    t_slider = widgets.FloatRangeSlider(
        value=[hits[color_col].min(), hits[color_col].max()],
        min=hits[color_col].min(),
        max=hits[color_col].max(),
        step=1,
        description="t range:",
    )
    output = widgets.interactive_output(update_plot, {"t_range": t_slider})
    display(t_slider, output)

## Squeezed Sinusoid Strategy

Th cell below creates points with teh squeezed sinusoid strategy. Try using different values for `step_size_multiplier`. Note the following:
1. Values below 2 are not a good choice because the generated values will not cover the full circle
2. Even values are not a good choice because they will repeat the points after every rotation

In [9]:
import os
import sys
import pandas as pd
import torch
from torch import Tensor

# Add the parent directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.TimeStep.ForAdjusting.PlacementStrategy import SqueezedSinusoidStrategy
from src.TimeStep.ForAdjusting.VolumeLayer import VLTimeStep

strategy = SqueezedSinusoidStrategy(step_size_multiplier=2)
time_step = VLTimeStep(placement_strategy=strategy)

hits_df = pd.DataFrame(columns=["x", "y", "z", "t"])
num_t = time_step.get_num_time_steps()
sizes = torch.randint(low=1000, high=2000, size=(num_t,), dtype=torch.int32)

for t in range(num_t):
    hits_t = time_step.place_hits(t, sizes[t:t+1])
    hits_df_new = pd.DataFrame(hits_t.cpu().numpy(), columns=["x", "y", "z"])
    hits_df_new["t"] = t
    hits_df = pd.concat([hits_df, hits_df_new], ignore_index=True)
    

  hits_df = pd.concat([hits_df, hits_df_new], ignore_index=True)


In [10]:
create_3d_hit_plot(hits_df, color_col = "t")

FloatRangeSlider(value=(0.0, 7.0), description='t range:', max=7.0, step=1.0)

Output()