### Scratch ideas

- Three timescales for plots
  - Overall Experiment plots (cumulatively calculated)
  - Light-cycle plots (every 12 hours)
  - Block plots (every block)

- All subjects
- Per subject
- By patch number
- By patch rate

- Env plots
  - Block stats
    - How long were the blocks?
  - Patch stats
    - How many times and for what % of total time was each patch at the low (L), medium (M), or high (H) rate?

- Features to plot
  - Weights
  - Position / Ethogram
  - Pellets / feeding bouts
  - Wheel distance spun
  - Patch preference
___

### Plots

- Block Page
  - Header: "Arena: block_start_time - block_end_time"
  - Plot 1: Patch stats: Patch rate for each patch as a bar, sampled patch threshold values as overlaid scatter
  - Plot 2: Weight: Weight over time
  - Plot 3: Position: X-Y line position over time with running cumulative total distance traveled
    - Option 1: 3d plot with time in z-axis, cumulative distance in colorbar
    - Option 2: 2d plot with time / cumulative distance in colorbar / linewidth
  - Plot 4: Ethogram: Contiguous blocks of time spent in given regions. Downsampled to 1s windows?
    - One row for each: nest, corridor, field, patch1, patch2, patch3; text overlay of total distance traveled within each contiguous block?
    - Color-coded for when in the same location as 1+ others (one color per subject)
    - Text: percentage of time together overall, and by region
  - Plot 5: Total 
  - Plot 5: Pellets: Pellets per patch - cumulative line plots, x-axis time
  - Plot 6: Wheel distance: Wheel distance per patch - cumulative line plots, x-axis time
  - Plot 7: Running cumulative normalized patch preference (5s bins?): by wheel distance, pellets, and time spent 
  - Plot 8: Running windowed normalized patch preference: by wheel distance (2m?), pellets (5?), time spent (10 mins?)

  - Notes
    - By "plot" I mean "axis". How these different axes are arranged into figures is TBD.
    - Per subject plots: All except 1
    - All subject plots: 6 onward
    - For each patch plot, info should include patch number and patch rate
    - Each element of the set (each subject, each patch rate, each patch number) should have its own unique color
    - Wheel distance is the preferred of the three metrics (wheel distance, pellets, time spent) for patch preference

- Light-cycle Page
  - Header: "Arena: light-cycle_start_time - light-cycle_end_time"
  - Plot 1: Block viz calendar: Horizontal bars with each block's duration, x-axis in datetime 
    - Clicking on block takes you to that block page?
  - Plot 2: Block viz histogram: Histogram of block durations
  - Plot 3: Patch rate counts: split bar plot, 3 groups (patches) of 3 groups (rate counts)
  - Plot 4: Weight (same as 'Plot 2' in 'Block Page')
  - Plot 5: Position (same as 'Plot 3' in 'Block Page')
  - Plot 6: Ethogram (same as 'Plot 4' in 'Block Page')
  - Plot 7: Pellets per patch by patch number (similar to 'Plot 5' in 'Block Page')
  - Plot 8: Pellets per patch by patch rate (similar to 'Plot 5' in 'Block Page')
  - Plot 9: Wheel distance per patch by patch number (similar to 'Plot 6' in 'Block Page')
  - Plot 10: Wheel distance per patch by patch rate (similar to 'Plot 6' in 'Block Page')
  - Plot 11: Running cumulative normalized patch preference by patch number (similar to 'Plot 7' in 'Block Page')
  - Plot 12: Running cumulative normalized patch preference by patch rate (similar to 'Plot 7' in 'Block Page')
  - Plot 13: Running windowed normalized patch preference by patch number (similar to 'Plot 8' in 'Block Page')
  - Plot 14: Running windowed normalized patch preference by patch rate (similar to 'Plot 8' in 'Block Page')

  - Notes
    - For plots 7 onward, vertical bars at the time of each block switch. Text between consecutive vertical bars indicating the patch rates for that block.

- Overall Experiment Page
  - Text info: Exp Name, epoch root paths, total recording time?
  - Plot 1: Arena light-cycle calendar:
    - Clicking on item takes you to that Light-cycle page?
  - Rest:
    - Same as Light-cycle Page?
    - Can index / x-axis by block-number instead of time in some cases?
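
The ethogram (Plot 4 on the Block Page) needs contiguous blocks of time spent in each region. A minimal sketch of that run-length step follows, assuming positions have already been mapped to a region label and downsampled to 1 s windows. The `region` series is synthetic and `contiguous_blocks` is a hypothetical helper, not part of `aeon`.

```python
import numpy as np
import pandas as pd


def contiguous_blocks(region):
    """Collapse a time-indexed series of region labels into contiguous runs.

    Returns one row per run: its region, start time, end time (timestamp of
    the last sample in the run), and number of samples.
    """
    # A new run starts wherever the label differs from the previous sample.
    run_id = (region != region.shift()).cumsum()
    times = region.index.to_series()
    blocks = pd.DataFrame(
        {
            "region": region.groupby(run_id).first(),
            "start": times.groupby(run_id).first(),
            "end": times.groupby(run_id).last(),
            "n_samples": region.groupby(run_id).size(),
        }
    )
    return blocks.reset_index(drop=True)


# Synthetic 1 s region labels, for illustration only.
index = pd.Timestamp("2023-12-05 14:10:13") + pd.to_timedelta(np.arange(8), unit="s")
region = pd.Series(
    ["nest", "nest", "corridor", "patch1", "patch1", "patch1", "corridor", "nest"],
    index=index,
)
blocks = contiguous_blocks(region)
```

Each row of `blocks` could then be drawn as one horizontal bar on the row for its region, with `start` and `end` giving the bar extent.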


In [None]:
"""Imports."""

from datetime import date
from pathlib import Path

from dotmap import DotMap
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

import aeon
from aeon.io import reader
from aeon.io.device import Device, register
from aeon.analysis.utils import visits, distancetravelled
from aeon.schema import core, foraging, social

In [None]:
"""Build social schema."""

core.metadata  # binder function: "device-name passed": returns a `reader.Metadata` Reader object
metadata = Device("Metadata", core.metadata)

# ---

# Environment (will be a nested, multiple binder function Device object)
# ---

# BlockState
# binder function: "device-name passed"; `pattern` will be set by `Device` object name: "Environment"
block_state_b = lambda pattern: {
    "BlockState": reader.Csv(
        f"{pattern}_BlockState*", ["pellet_ct", "pellet_ct_thresh", "due_time"]
    )
}

# EnvironmentState
core.environment_state  # binder function: "device-name passed"

# Combine EnvironmentState and BlockState
env_block_state_b = lambda pattern: register(
    pattern, core.environment_state, block_state_b
)

# LightEvents
cols = ["channel", "value"]
light_events_r = reader.Csv("Environment_LightEvents*", cols)
light_events_b = lambda pattern: {
    "LightEvents": light_events_r
}  # binder function: "empty pattern"

# MessageLog
core.message_log  # binder function: "device-name passed"

# SubjectState
cols = ["id", "weight", "type"]
subject_state_r = reader.Csv("Environment_SubjectState*", cols)
subject_state_b = lambda pattern: {
    "SubjectState": subject_state_r
}  # binder function: "empty pattern"

# SubjectVisits
cols = ["id", "type", "region"]
subject_visits_r = reader.Csv("Environment_SubjectVisits*", cols)
subject_visits_b = lambda pattern: {
    "SubjectVisits": subject_visits_r
}  # binder function: "empty pattern"

# SubjectWeight
cols = ["weight", "confidence", "subject_id", "int_id"]
subject_weight_r = reader.Csv("Environment_SubjectWeight*", cols)
subject_weight_b = lambda pattern: {
    "SubjectWeight": subject_weight_r
}  # binder function: "empty pattern"

# Nested binder fn Device object.
environment = Device(
    "Environment",  # device name
    env_block_state_b,
    light_events_b,
    core.message_log,
)

# Separate Device object for subject-specific streams.
subject = Device("Subject", subject_state_b, subject_visits_b, subject_weight_b)

# ---

# Camera
# ---

camera_top_b = lambda pattern: {"CameraTop": reader.Video("CameraTop*")}
camera_top_pos_b = lambda pattern: {
    "CameraTopPos": social.Pose("CameraTop_test-node1*")
}

cam_names = ["North", "South", "East", "West", "Patch1", "Patch2", "Patch3", "Nest"]
cam_names = ["Camera" + name for name in cam_names]
camera_b = [
    lambda pattern, name=name: {name: reader.Video(name + "*")} for name in cam_names
]

camera = Device("Camera", camera_top_b, camera_top_pos_b, *camera_b)

# ---

# Nest
# ---

weight_raw_b = lambda pattern: {
    "WeightRaw": reader.Harp("Nest_200*", ["weight(g)", "stability"])
}
weight_filtered_b = lambda pattern: {
    "WeightFiltered": reader.Harp("Nest_202*", ["weight(g)", "stability"])
}
weight_baselined_b = lambda pattern: {
    "WeightBaselined": reader.Harp("Nest_203*", ["weight(g)", "stability"])
}

nest = Device("Nest", weight_raw_b, weight_filtered_b, weight_baselined_b)

# ---

# Patch
# ---

patches = ["1", "2", "3"]
patch_streams = ["32", "35", "90", "201", "202", "203", "State"]
patch_names = [
    "Patch" + name + "_" + stream for name in patches for stream in patch_streams
]
patch_b = []
for stream in patch_names:
    if "32" in stream:
        fn = lambda pattern, stream=stream: {
            stream: reader.BitmaskEvent(stream + "*", value=34, tag="beambreak")
        }
    elif "35" in stream:
        fn = lambda pattern, stream=stream: {
            stream: reader.BitmaskEvent(stream + "*", value=1, tag="delivery")
        }
    elif "90" in stream:
        fn = lambda pattern, stream=stream: {stream: reader.Encoder(stream + "*")}
    elif "201" in stream:
        fn = lambda pattern, stream=stream: {
            stream: reader.Harp(stream + "*", ["manual_delivery"])
        }
    elif "202" in stream:
        fn = lambda pattern, stream=stream: {
            stream: reader.Harp(stream + "*", ["missed_pellet"])
        }
    elif "203" in stream:
        fn = lambda pattern, stream=stream: {
            stream: reader.Harp(stream + "*", ["retried_delivery"])
        }
    elif "State" in stream:
        fn = lambda pattern, stream=stream: {
            stream: reader.Csv(stream + "*", ["threshold", "offset", "rate"])
        }
    patch_b.append(fn)

patch = Device("Patch", *patch_b)

# ---

# Rfid
# ---

rfid_names = [
    "EventsGate",
    "EventsNest1",
    "EventsNest2",
    "EventsPatch1",
    "EventsPatch2",
    "EventsPatch3",
]
rfid_names = ["Rfid" + name for name in rfid_names]
rfid_b = [
    lambda pattern, name=name: {name: reader.Harp(name + "*", ["rfid"])}
    for name in rfid_names
]

rfid = Device("Rfid", *rfid_b)

social01 = DotMap(
    [
        metadata,
        environment,
        subject,
        camera,
        nest,
        patch,
        rfid,
    ]
)

In [None]:
"""Set root, times, and get metadata."""

root = Path("/ceph/aeon/aeon/data/raw/AEON3/social0.1")

# Block start and end time
start_time = pd.Timestamp("2023-12-05 14:10:13")
end_time = pd.Timestamp("2023-12-05 15:10:27")

metadata = aeon.load(root, social01.Metadata, start_time, end_time)

In [None]:
"""Get weight data."""

subjects_weights = aeon.load(root, social01.Subject.SubjectWeight, start_time, end_time)
subjects = subjects_weights.subject_id.unique()

In [None]:
"""Plot weight data."""

fig = go.Figure()
for subject in subjects:
    subject_weights = subjects_weights[subjects_weights.subject_id == subject]
    fig.add_trace(
        go.Scatter(
            x=subject_weights.index,
            y=subject_weights.weight,
            mode="lines",
            name=subject,
        )
    )
fig.show()

In [None]:
"""Get position data."""

subjects_positions = aeon.load(root, social01.Camera.CameraTopPos, start_time, end_time)

# Replace class integers with class strings (this should not be this hard...)
pose_meta = metadata.metadata.values[0].PoseTrackingTop
model_path = pose_meta.ModelPath
# Override the recorded path prefix with the processed-data root.
model_path_prefix = "\\ceph\\aeon\\aeon\\data\\processed"
pose_meta.PathPrefix = model_path_prefix
# Convert Windows-style separators before joining the prefix and model path.
model_dir = Path(model_path_prefix.replace("\\", "/")) / model_path.replace("\\", "/")
model_dir = Path(str(model_dir).replace("_", "-"))
subjects_positions = social.class_int2str(subjects_positions, model_dir)

# Set confidence threshold to return position values and downsample by 5x
subjects_positions = subjects_positions[subjects_positions.class_likelihood > 0.9].iloc[::5]

In [None]:
"""Plot position over time."""

fig = go.Figure()
for subject in subjects:
    subject_parts = subjects_positions[subjects_positions["class"] == subject]
    subject_pos = subject_parts[subject_parts["part"] == "centroid"]
    fig.add_trace(
        go.Scatter3d(
            x=subject_pos.x,
            y=subject_pos.y,
            z=subject_pos.index,
            mode="lines",
            name=subject,
        )
    )
fig.show()
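
Plot 3 in the plan pairs position with a running cumulative total distance traveled, which the cell above does not yet compute. A minimal sketch of that step on synthetic centroid data, assuming `x` and `y` columns in consistent units (pixels here) with low-confidence samples already dropped; `cumulative_distance` is a hypothetical helper, not an `aeon` function.

```python
import numpy as np
import pandas as pd


def cumulative_distance(pos):
    """Running total of Euclidean step lengths between consecutive samples."""
    # The first sample has no previous position, so its step length is 0.
    steps = np.hypot(pos["x"].diff(), pos["y"].diff()).fillna(0.0)
    return steps.cumsum().rename("cum_dist")


# Synthetic centroid positions, for illustration only.
pos = pd.DataFrame(
    {"x": [0.0, 3.0, 3.0, 6.0], "y": [0.0, 4.0, 4.0, 8.0]},
    index=pd.Timestamp("2023-12-05 14:10:13")
    + pd.to_timedelta(np.arange(4), unit="s"),
)
cum_dist = cumulative_distance(pos)
```

The resulting series shares the position index, so it could be supplied as the color values of the trace (Option 1) or mapped to line width (Option 2).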

In [None]:
"""Get wheel data."""

# Cumulative wheel distance per patch (sign flipped), downsampled by 2x.
p1 = -distancetravelled(aeon.load(root, social01.Patch.Patch1_90, start_time, end_time).angle).iloc[::2]
p2 = -distancetravelled(aeon.load(root, social01.Patch.Patch2_90, start_time, end_time).angle).iloc[::2]
p3 = -distancetravelled(aeon.load(root, social01.Patch.Patch3_90, start_time, end_time).angle).iloc[::2]

In [None]:
"""Plot wheel data over time (per subject?)."""
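
Plots 7 and 8 in the plan call for running cumulative and running windowed normalized patch preference. A minimal sketch by wheel distance follows, assuming the per-patch cumulative distances have been aligned to a shared time index. The data are synthetic, `patch_preference` is a hypothetical helper, and the window is a number of samples rather than the 2 m of wheel distance floated in the plan.

```python
import numpy as np
import pandas as pd


def patch_preference(cum_dist, window=None):
    """Normalize per-patch wheel distance so that each row sums to 1.

    cum_dist: cumulative distance per patch, one column per patch.
    window: if given, use the distance gained over the trailing `window`
        samples instead of the running total.
    """
    if window is None:
        dist = cum_dist
    else:
        # Per-sample gains, summed over the trailing window.
        dist = cum_dist.diff().fillna(0.0).rolling(window, min_periods=1).sum()
    total = dist.sum(axis=1)
    # Rows with no wheel movement have undefined preference and become NaN.
    return dist.div(total.where(total > 0), axis=0)


# Synthetic cumulative wheel distances, for illustration only.
cum_dist = pd.DataFrame(
    {
        "Patch1": [0.0, 1.0, 2.0, 2.0, 2.0],
        "Patch2": [0.0, 0.0, 1.0, 3.0, 5.0],
        "Patch3": [0.0, 1.0, 1.0, 1.0, 1.0],
    },
    index=pd.Timestamp("2023-12-05 14:10:13")
    + pd.to_timedelta(np.arange(5), unit="s"),
)
cumulative_pref = patch_preference(cum_dist)
windowed_pref = patch_preference(cum_dist, window=2)
```

Each column of either result can be drawn as one line per patch against time. The same function would apply to cumulative pellet counts or time spent, the other two metrics named in the plan.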