<a href="https://colab.research.google.com/github/Cauch-BS/cscg-hippo/blob/main/notebooks/Data_Visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Mount Drive
# @markdown 1. Accept the requested permissions from Google Drive
from google.colab import drive
drive.mount('/content/drive')

# @markdown  2. Download the UMAP embedding data
!mkdir ./CSCG_example_NABI
!mkdir ./CSCG_example_NABI/first_run
!gdown -O ./CSCG_example_NABI/first_run/embedding_42.npy 1b-2_JbeuvsQcasH3YB0d7S65s-y_F_fR
!gdown -O ./CSCG_example_NABI/first_run/day_data.npy 18T1Pp5DKQMqHdQOoiUd7cPSHjTozvNJv
!gdown -O ./CSCG_example_NABI/first_run/selected_pos.tar.gz 1x_BP265r9gZU2mAEhvp7ePYpaU8wYdNQ

# @markdown 3. Load the UMAP data with `numpy.load()`
import numpy as np

base = "/content/CSCG_example_NABI/first_run"

umap_embedding = np.load(f"{base}/embedding_42.npy")
day_ind_array = np.load(f"{base}/day_data.npy")

# Later cells select one session with
# `umap_embedding[np.squeeze(day_ind_array == day_num), :]` and plot columns 0-2,
# so fail here (with the shapes) if the two files do not line up.
assert umap_embedding.ndim == 2 and umap_embedding.shape[1] >= 3, umap_embedding.shape
assert day_ind_array.size == umap_embedding.shape[0], (day_ind_array.shape, umap_embedding.shape)

# @markdown 4. Load the dataframe with pandas.

import re
import tarfile
from pathlib import Path
import pandas as pd

# Extract tar.gz
parquet_root = Path(base) / "selected_pos_parquet"
parquet_root.mkdir(exist_ok=True)

with tarfile.open(f"{base}/selected_pos.tar.gz", "r:gz") as tar:
    tar.extractall(parquet_root, filter = "data")

# Load parquet files (mirrors original list semantics). Downstream code indexes
# this list with `day_num`, the same value used to pick this session's rows of
# the UMAP embedding, so the list must follow the numeric suffix of the file
# names. A plain `sorted()` is lexicographic and would place `selected_pos_10`
# before `selected_pos_2` once there are 11 or more files.
parquet_dir = parquet_root / "vr2p_extracted_selected_pos_parquet"


def _numeric_aware_key(path):
    """Sort key for a Path that compares digit runs in its name as integers (2 < 10)."""
    # re.split with a capturing group alternates text / digits, so the digit
    # runs are exactly the odd-indexed parts; str and int never get compared.
    parts = re.split(r"(\d+)", path.name)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


selected_pos_big = [
    pd.read_parquet(p)
    for p in sorted(parquet_dir.glob("selected_pos_*.parquet"), key=_numeric_aware_key)
]
if not selected_pos_big:
    raise FileNotFoundError(f"no selected_pos_*.parquet files found in {parquet_dir}")

Mounted at /content/drive
Downloading...
From: https://drive.google.com/uc?id=1b-2_JbeuvsQcasH3YB0d7S65s-y_F_fR
To: /content/CSCG_example_NABI/first_run/embedding_42.npy
100% 5.65M/5.65M [00:00<00:00, 174MB/s]
Downloading...
From: https://drive.google.com/uc?id=18T1Pp5DKQMqHdQOoiUd7cPSHjTozvNJv
To: /content/CSCG_example_NABI/first_run/day_data.npy
100% 3.77M/3.77M [00:00<00:00, 82.4MB/s]
Downloading...
From: https://drive.google.com/uc?id=1x_BP265r9gZU2mAEhvp7ePYpaU8wYdNQ
To: /content/CSCG_example_NABI/first_run/selected_pos.tar.gz
100% 5.17M/5.17M [00:00<00:00, 25.6MB/s]


In [None]:
# Enable Colab's custom widget manager. This is presumably needed so that the
# plotly FigureWidget / ipywidgets UI displayed in the next cell renders in
# Colab. NOTE(review): Colab-specific; not needed on plain Jupyter.
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
from matplotlib import cm, colors
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# Session (day) to visualise: indexes `selected_pos_big` and is matched against
# `day_ind_array` when slicing the UMAP embedding.
day_num = 8

# Default scene camera. "Reset View" restores it, and the camera text boxes are
# seeded from it.
initial_camera_position = dict(
    up=dict(x=0, y=0, z=1),
    center=dict(x=0, y=0, z=0),
    eye=dict(x=-1.5, y=1.5, z=1.5),
)

# Push the values typed into the camera text boxes onto the figure.
def update_view(b):
    """Button callback: set the figure's scene camera from the nine text boxes.

    ``b`` (the clicked Button) is unused.
    """
    def _xyz(box_x, box_y, box_z):
        return dict(x=box_x.value, y=box_y.value, z=box_z.value)

    camera = dict(
        up=_xyz(up_x, up_y, up_z),
        center=_xyz(center_x, center_y, center_z),
        eye=_xyz(eye_x, eye_y, eye_z),
    )
    fig.update_layout(scene_camera=camera)

def get_current_values(b):
    """Button callback: copy the figure's current scene camera into the text boxes.

    ``b`` (the clicked Button) is unused. Camera components that are not set on
    the figure read back as ``None`` (for example when the figure's camera has
    never been set). Those are skipped so the matching box keeps its value,
    because assigning ``None`` to a ``FloatText`` raises a ``TraitError``.
    """
    camera = fig.layout.scene.camera
    for vector, xyz_boxes in (
        (camera.up, (up_x, up_y, up_z)),
        (camera.center, (center_x, center_y, center_z)),
        (camera.eye, (eye_x, eye_y, eye_z)),
    ):
        for axis, box in zip(('x', 'y', 'z'), xyz_boxes):
            value = vector[axis]
            if value is not None:
                box.value = value

def reset_view(b):
    """Button callback: restore ``initial_camera_position`` on the figure and in the text boxes.

    ``b`` (the clicked Button) is unused.
    """
    fig.update_layout(scene_camera=initial_camera_position)

    # Mirror the restored camera into the nine text boxes (up, center, eye; x, y, z).
    box_groups = (
        ('up', (up_x, up_y, up_z)),
        ('center', (center_x, center_y, center_z)),
        ('eye', (eye_x, eye_y, eye_z)),
    )
    for vector_name, xyz_boxes in box_groups:
        vector = initial_camera_position[vector_name]
        for axis, box in zip(('x', 'y', 'z'), xyz_boxes):
            box.value = vector[axis]

def save_pdf(b):
    """Button callback: export the figure to ``session_<day_num>.pdf``.

    Rendered with kaleido at 600x600 and ``scale=10``. ``b`` (the clicked
    Button) is unused.
    """
    pdf_path = f"session_{day_num}.pdf"
    fig.write_image(pdf_path, engine="kaleido", width=600, height=600, scale=10)


def save_png(b):
    """Button callback: export the figure to ``session_<day_num>.png``.

    Rendered with kaleido at 600x600 px and ``scale=10``. ``b`` (the clicked
    Button) is unused.
    """
    png_path = f"session_{day_num}.png"
    fig.write_image(png_path, engine="kaleido", width=600, height=600, scale=10)

# Button that triggers the PNG export.
save_png_button = widgets.Button(description='Save PNG')
save_png_button.on_click(save_png)


# Write the camera currently shown in the text boxes to `coordinates.npy`.
def save_coordinates(b):
    """Button callback: ``np.save`` the text-box camera as a 3x3 array.

    Rows are up, center and eye; columns are x, y and z. ``b`` is unused.
    """
    camera_rows = [
        [box.value for box in (up_x, up_y, up_z)],
        [box.value for box in (center_x, center_y, center_z)],
        [box.value for box in (eye_x, eye_y, eye_z)],
    ]
    np.save('coordinates.npy', np.array(camera_rows))

# Button that triggers the coordinate export.
save_coordinates_button = widgets.Button(description='Save Coordinates')
save_coordinates_button.on_click(save_coordinates)


# Camera text boxes: one FloatText per component of the up, center and eye
# vectors, seeded from `initial_camera_position`.
_cam = initial_camera_position
up_x, up_y, up_z = (
    widgets.FloatText(value=_cam['up']['x'], description='up_x'),
    widgets.FloatText(value=_cam['up']['y'], description='up_y'),
    widgets.FloatText(value=_cam['up']['z'], description='up_z'),
)
center_x, center_y, center_z = (
    widgets.FloatText(value=_cam['center']['x'], description='center_x'),
    widgets.FloatText(value=_cam['center']['y'], description='center_y'),
    widgets.FloatText(value=_cam['center']['z'], description='center_z'),
)
eye_x, eye_y, eye_z = (
    widgets.FloatText(value=_cam['eye']['x'], description='eye_x'),
    widgets.FloatText(value=_cam['eye']['y'], description='eye_y'),
    widgets.FloatText(value=_cam['eye']['z'], description='eye_z'),
)
del _cam

# Action buttons, each wired to its callback defined earlier in this cell.
update_button = widgets.Button(description='Update View')
get_values_button = widgets.Button(description='Get Current Values')
reset_button = widgets.Button(description='Reset View')
save_button = widgets.Button(description='Save PDF')
for _button, _handler in (
    (update_button, update_view),
    (get_values_button, get_current_values),
    (reset_button, reset_view),
    (save_button, save_pdf),
):
    _button.on_click(_handler)
del _button, _handler

# Vertical panel: the nine camera boxes followed by the action buttons.
view_params_container = widgets.VBox([
    up_x, up_y, up_z,
    center_x, center_y, center_z,
    eye_x, eye_y, eye_z,
    update_button, get_values_button, reset_button, save_button, save_png_button,
])


# Define the plotting code
# `matplotlib.colormaps[name]` replaces the deprecated `cm.get_cmap` and already
# returns a fresh copy of the colormap.
from matplotlib import colormaps, colors

selected_position = selected_pos_big[day_num]

# markers = [
#     {'name': 'Track', 'color': '#808080', 'position': 0.9},
#     {'name': 'Indicator-Near', 'color': '#FFD700', 'position': 0.85},
#     {'name': 'R1-Near', 'color': '#FF7F00', 'position': 0.8},
#     {'name': 'R2-Near', 'color': '#FF5500', 'position': 0.75},
#     {'name': 'Indicator-Far', 'color': '#74a9cf', 'position': 0.7},
#     {'name': 'R1-Far', 'color': '#2b8cbe', 'position': 0.65},
#     {'name': 'R2-Far', 'color': '#045a8d', 'position': 0.6},
#     {'name': 'Teleportation', 'color': '#000000', 'position': 0.55},
# ]

markers = [
    {'name': 'Track', 'color': '#808080', 'position': 0.9},
    {'name': 'Indicator-Near', 'color': '#FBB4B9', 'position': 0.85},  # Light Magenta
    {'name': 'R1-Near', 'color': '#F768A1', 'position': 0.8},  # Medium Magenta
    {'name': 'R2-Near', 'color': '#C51B8A', 'position': 0.75},  # Base Magenta
    {'name': 'Indicator-Far', 'color': '#A8D8A7', 'position': 0.7},  # Light Green
    {'name': 'R1-Far', 'color': '#41AE76', 'position': 0.65},  # Medium Green
    {'name': 'R2-Far', 'color': '#006D2C', 'position': 0.6},  # Dark Green
    {'name': 'Teleportation', 'color': '#000000', 'position': 0.55},
]

# Rows not covered by a colour scheme below (a `position_marker` missing from
# `markers`, or a `reward_id` other than 1/2) would otherwise be left as NaN.
# NaN is not a valid plotly colour and is not JSON-serialisable when the widget
# syncs. NOTE(review): light grey is an arbitrary neutral choice.
unmapped_color = '#D3D3D3'

# "Trial Type - Areas": one fixed colour per track region.
for marker in markers:
    selected_position.loc[selected_position['position_marker'] == marker['name'], 'area-color'] = marker['color']

# Colormap per reward_id: 1 -> Blues, 2 -> YlOrBr.
reward_cmap_names = {1: 'Blues', 2: 'YlOrBr'}

# "Trial Type - Position": shade by track position, normalised over 0-230.
norm = colors.Normalize(vmin=0, vmax=230)
for reward_id, cmap_name in reward_cmap_names.items():
    cmap = colormaps[cmap_name]
    ind = selected_position['reward_id'] == reward_id
    selected_position.loc[ind, 'position-color'] = list(map(colors.rgb2hex, cmap(norm(selected_position.loc[ind, 'position']))))

# "Cue Sets": black for Cue Set A, grey for every other set.
ind_A = selected_position['set'] == 'Cue Set A'
ind_else = selected_position['set'] != 'Cue Set A'
selected_position.loc[ind_A, 'set-color'] = '#000000'
selected_position.loc[ind_else, 'set-color'] = '#808080'

# "Trial Type - Trial Number": shade by trial number, normalised over 0-100.
norm = colors.Normalize(vmin=0, vmax=100)
for reward_id, cmap_name in reward_cmap_names.items():
    cmap = colormaps[cmap_name]
    ind = selected_position['reward_id'] == reward_id
    trial_number_list = norm(selected_position.loc[ind, 'trial_number']).astype(float)
    selected_position.loc[ind, 'trial-color'] = list(map(colors.rgb2hex, cmap(trial_number_list)))

# Replace any NaN the partial assignments above left behind.
for color_column in ('area-color', 'position-color', 'set-color', 'trial-color'):
    selected_position[color_column] = selected_position[color_column].fillna(unmapped_color)

# Keep only the UMAP points recorded in the selected session.
embedding = umap_embedding[np.squeeze(day_ind_array == day_num), :]
# Colours and hover data come from `selected_position` row for row, so the two
# must be the same length; otherwise points would be silently mislabelled.
assert len(embedding) == len(selected_position), (
    f"embedding has {len(embedding)} points but selected_position has {len(selected_position)} rows"
)

# Per-point hover fields. The hovertemplate indexes them by position:
# 0=position_marker, 1=trial_number, 2=reward_id, 3=position, 4=set.
# A column slice + to_numpy avoids building a Series per row as iterrows() does.
customdata = selected_position[
    ['position_marker', 'trial_number', 'reward_id', 'position', 'set']
].to_numpy(dtype=object).tolist()
scatter_data = go.Scatter3d(
    x=embedding[:, 0], y=embedding[:, 1], z=embedding[:, 2],
    mode='markers',
    marker=dict(
        size=1.6,
        color=selected_position['area-color'],
        opacity=0.8
    ),
    customdata=customdata,
    hovertemplate="<br>".join([
        "Trial Type: %{customdata[2]}",
        "Position: %{customdata[3]}",
        "Area: %{customdata[0]}",
        "Trial Number: %{customdata[1]}",
        "Set: %{customdata[4]}",
    ])
)

fig = go.FigureWidget(data=[scatter_data])

template = 'simple_white'

#template = 'plotly_dark'

# Strip every scene axis bare: no background, axis line, zero line, grid,
# ticks, tick labels, spikes or title, so only the point cloud remains.
# Setting `showline`/`ticks` on the layout axes overrides the template's
# defaults, so there is no need to edit the private
# `fig.layout._compound_props` template object.
hidden_axis = dict(
    showbackground=False,
    showline=False,
    zeroline=False,
    showgrid=False,
    ticks='',
    showticklabels=False,
    showspikes=False,
    title='',
)

fig.update_layout(
    margin=dict(t=40),
    template=template,
    width=1200,
    height=650,
    # NOTE(review): white text on the white 'simple_white' background is
    # invisible; presumably intentional (a text-free figure). Confirm.
    font_color="white",
    scene=dict(
        # Open on the same camera the text boxes show and "Reset View" restores;
        # without this the figure starts from plotly's default camera instead.
        camera=initial_camera_position,
        xaxis=hidden_axis,
        yaxis=hidden_axis,
        zaxis=hidden_axis,
    ),
)



# Add region marker annotations.
# for marker in markers:
#     fig.add_annotation(x=0.1, y=marker['position'], text=marker['name'], font_color=marker['color'], showarrow=False)

# Select trial UI: a "Highlight selected" toggle plus a slider spanning this
# session's trial numbers.
# Only Checkbox has an `indent` trait. IntSlider and Dropdown do not, and traitlets
# merely ignores the unknown keyword with a deprecation warning, so it is omitted there.
trial_use_selected = widgets.Checkbox(value=False, description='Highlight selected', indent=False, layout=widgets.Layout(width='150px'))
trial_selector = widgets.IntSlider(
    value=selected_position.trial_number.min(),
    min=selected_position.trial_number.min(),
    max=selected_position.trial_number.max(),
    step=1,
    description='Trial Number',
)
trial_container = widgets.HBox([trial_use_selected, trial_selector])

# Select color scheme
color_options = ['Trial Type - Areas', 'Trial Type - Position', 'Trial Type - Trial Number', 'Cue Sets']
color_scheme_selector = widgets.Dropdown(
    options=color_options,
    value=color_options[0],
    description='Color:',
    layout=widgets.Layout(width='250px'),
)
color_scheme_container = widgets.HBox([color_scheme_selector])

# Keep the camera text boxes in sync with the figure (registered below as a
# plotly `on_change` callback on the scene camera).
def update_text_box_values(change, _):
    """plotly ``on_change`` callback: copy the figure's scene camera into the text boxes.

    Parameters
    ----------
    change :
        First argument plotly passes to ``on_change`` callbacks (the observed
        object). Unused, because the camera is re-read from ``fig``.
    _ :
        New value of the watched property; unused.

    Components that read back as ``None`` (not set on the figure) are skipped,
    because assigning ``None`` to a ``FloatText`` raises a ``TraitError``.
    """
    camera = fig.layout.scene.camera
    for vector, xyz_boxes in (
        (camera.up, (up_x, up_y, up_z)),
        (camera.center, (center_x, center_y, center_z)),
        (camera.eye, (eye_x, eye_y, eye_z)),
    ):
        for axis, box in zip(('x', 'y', 'z'), xyz_boxes):
            value = vector[axis]
            if value is not None:
                box.value = value


# Observe the camera position changes so the text boxes follow the figure
# (e.g. when the user rotates the plot).
fig.layout.scene.camera.on_change(update_text_box_values, 'up')
fig.layout.scene.camera.on_change(update_text_box_values, 'center')
fig.layout.scene.camera.on_change(update_text_box_values, 'eye')

# Wire the colour / trial controls to the scatter. Without these observers the
# dropdown, checkbox and slider built above are displayed but change nothing.
# Maps each dropdown option to the `selected_position` column holding its colours.
color_scheme_columns = {
    'Trial Type - Areas': 'area-color',
    'Trial Type - Position': 'position-color',
    'Trial Type - Trial Number': 'trial-color',
    'Cue Sets': 'set-color',
}
# Colour for points outside the highlighted trial.
dimmed_color = '#E0E0E0'


def apply_color_scheme(change=None):
    """Recolour the scatter from the current control values.

    Used as a traitlets ``observe`` handler. ``change`` (the change dict) is
    unused because the state is re-read from the widgets. When "Highlight
    selected" is ticked, points whose ``trial_number`` differs from the slider
    value are drawn in ``dimmed_color``.
    NOTE(review): "highlight" is implemented as greying out the other trials;
    adjust if a different treatment was intended.
    """
    point_colors = selected_position[color_scheme_columns[color_scheme_selector.value]]
    if trial_use_selected.value:
        other_trials = selected_position['trial_number'] != trial_selector.value
        point_colors = point_colors.mask(other_trials, dimmed_color)
    fig.data[0].marker.color = point_colors.tolist()


for control in (color_scheme_selector, trial_use_selected, trial_selector):
    control.observe(apply_color_scheme, names='value')

# Display full UI
ui = widgets.VBox([color_scheme_container, trial_container, view_params_container, save_coordinates_button, fig])
display(ui)




The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.


Message serialization failed with:
Out of range float values are not JSON compliant: nan
Supporting this message is deprecated in jupyter-client 7, please make sure your message is JSON-compliant



VBox(children=(HBox(children=(Dropdown(description='Color:', layout=Layout(width='250px'), options=('Trial Typ…