In [1]:
# %% [markdown]
# # Jupyter Notebook: Interactive Heatmap via ipywidgets
#
# **Requirements**:
# - ipywidgets
# - matplotlib
# - xarray
# - numpy
# - TNT_Tensor.py with the MultiDimTensorZarr class
# - The Zarr store created by your code (e.g., "travel_times_3d.zarr")
#
# This notebook will:
# 1. Load your Zarr data from disk using `MultiDimTensorZarr`.
# 2. Convert the loaded NumPy array + dims_info into an Xarray DataArray.
# 3. Use `ipywidgets.interact` or `interactive` to create a time slider and display a Matplotlib heatmap.

# %% [markdown]
# ## 0. Imports and Setup

# %%
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

# Import your custom class
from TNT_Tensor import MultiDimTensorZarr

# %% [markdown]
# ## 1. Load Zarr data via `MultiDimTensorZarr`

# %%
# Adjust these paths if necessary
zarr_path = "output/travel_times_3d.zarr"
group_name = "od_time_data"

# Load your custom data (3D => time × O × D)
od_tensor = MultiDimTensorZarr.load_from_zarr(zarr_path, group_name)

# Extract the numpy data and dimension info
data_3d = od_tensor.data       # shape: [time, O, D]
dims_info = od_tensor.dims_info

print("Data shape:", data_3d.shape)
print("dims_info:", dims_info)

# %% [markdown]
# ## 2. Convert to an Xarray DataArray
#
# - We assume dims_info is like:
#   ```python
#   [
#     {"name": "time", "ids": [...list of times...]},
#     {"name": "O",    "ids": [...list of origins...]},
#     {"name": "D",    "ids": [...list of destinations...]},
#   ]
#   ```
# - If your `time` labels are strings that represent real dates, we can parse them to datetimes.

# %%
dim_names = [d["name"] for d in dims_info]  # e.g. ["time", "O", "D"]

# Build coords dict
coords_dict = {}
for d in dims_info:
    label = d["name"]
    ids_list = d["ids"]
    coords_dict[label] = (label, ids_list)

da_xr = xr.DataArray(
    data_3d,
    dims=dim_names,         # (time, O, D)
    coords=coords_dict,     # {"time": (time, [...]), "O": (O, [...]), "D": (D, [...])}
    name="travel_time"
)

# Optional: if your time coords are date-like strings, parse them:
if "time" in da_xr.coords:
    try:
        da_xr["time"] = pd.to_datetime(da_xr["time"].values)
    except Exception:
        pass  # If they're not parseable, ignore

print("Xarray DataArray dims:", da_xr.dims)
print("Coordinates keys:", list(da_xr.coords))

# %% [markdown]
# ## 3. Interactive Heatmap with `ipywidgets`
#
# We'll create an integer slider for the time dimension (index 0..N-1). When it changes, we'll select that time slice and plot a heatmap of (O × D).

# %%
# Grab the actual time coordinates
time_dim = "time"
if time_dim not in da_xr.dims:
    raise KeyError(f"Dimension '{time_dim}' not found in the DataArray!")

time_values = da_xr.coords[time_dim].values
num_times = len(time_values)

print(f"We have {num_times} time steps")

# A helper function to plot a given time index
def plot_heatmap_time(time_index=0):
    """
    time_index: integer index into 'time' dimension [0..num_times-1]
    """
    # 1) Safety check
    if time_index < 0 or time_index >= num_times:
        print("Invalid time index")
        return

    # 2) Slice the data at that index
    da_slice = da_xr.isel({time_dim: time_index})  # shape: (O, D)

    # 3) Prepare the figure
    plt.figure(figsize=(8, 6))
    # We'll just plot with plt.imshow
    # If the shape is [O, D] = [78, 78], you might skip large ticks for clarity
    plt.imshow(da_slice, origin="upper", aspect="auto", cmap="viridis")

    # 4) Optional: label the axes with partial info
    # For a 78×78 matrix, labeling every coordinate would be cluttered
    plt.title(f"Travel Time\nTime = {time_values[time_index]}")
    plt.xlabel("D dimension (destination index)")
    plt.ylabel("O dimension (origin index)")

    plt.colorbar(label="Travel Time (units?)")
    plt.show()

# %% [markdown]
# ### 3.1. Create an `interact` slider

# %%
time_slider = IntSlider(
    value=0,
    min=0,
    max=num_times - 1,
    step=1,
    description="Time Index",
    continuous_update=False
)

_ = interact(plot_heatmap_time, time_index=time_slider)


Data shape: (16, 78, 78)
dims_info: [{'name': 'time', 'ids': ['2025-01-16 00:00', '2025-01-16 01:00', '2025-01-16 02:00', '2025-01-16 03:00', '2025-01-16 04:00', '2025-01-16 05:00', '2025-01-16 06:00', '2025-01-16 07:00', '2025-01-16 08:00', '2025-01-16 09:00', '2025-01-16 10:00', '2025-01-16 11:00', '2025-01-16 12:00', '2025-01-16 13:00', '2025-01-16 14:00', '2025-01-16 15:00']}, {'name': 'O', 'ids': ['APACHE BLVD_EASTBOUND_33.41467_-111.89196', 'APACHE BLVD_EASTBOUND_33.41664_-111.9399', 'APACHE BLVD_WESTBOUND_33.41487_-111.88485', 'APACHE BLVD_WESTBOUND_33.41487_-111.89073', 'APACHE BLVD_WESTBOUND_33.41487_-111.89196', 'APACHE BLVD_WESTBOUND_33.41668_-111.93975', 'AZ-101-LOOP_CLOCKWISE_33.41286_-111.89151', 'AZ-101-LOOP_CLOCKWISE_33.41746_-111.89146', 'AZ-101-LOOP_CLOCKWISE_33.42214_-111.89137', 'AZ-101-LOOP_CLOCKWISE_33.42721_-111.89122', 'AZ-101-LOOP_CLOCKWISE_33.44349_-111.89104', 'AZ-101-LOOP_COUNTERCLOCKWISE_33.40159_-111.89098', 'AZ-101-LOOP_COUNTERCLOCKWISE_33.41301_-111.8912

interactive(children=(IntSlider(value=0, continuous_update=False, description='Time Index', max=15), Output())…