In [10]:
import os
import glob
import json

import numpy as np
import pandas as pd

# Plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

# For PCA + clustering
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# ============================================================
# 1) POOL GEOMETRY  (roughly matches your MWM script)
# ============================================================

# Pool centre (pixels)
POOL_CX = 333
POOL_CY = 231

# Pool semi-axes in pixels (half of the full 453 x 451 extent)
POOL_A = 453 / 2
POOL_B = 451 / 2

# Concentric rings as (cx, cy, semi_axis_a, semi_axis_b), listed from the
# outermost (Ring 1) to the innermost (Ring 5). The table below gives each
# ring's centre and FULL extents; the comprehension halves the extents.
RINGS = [
    (cx, cy, width / 2, height / 2)
    for cx, cy, width, height in (
        (333, 231, 453, 451),  # Ring 1 (outermost)
        (333, 231, 392, 392),  # Ring 2
        (334, 228, 326, 325),  # Ring 3
        (336, 229, 268, 265),  # Ring 4
        (336, 229, 206, 204),  # Ring 5 (innermost)
    )
]

# Folder holding the SLK files, and the glob pattern used to find them
DATA_FOLDER = "test 2"   # <--- put your SLK files here
SLK_PATTERN = os.path.join(DATA_FOLDER, "*.slk")

# ============================================================
# 2) LOAD STRATEGY CLASSIFICATIONS FROM CSV
# ============================================================

def load_strategy_classifications(csv_file="strategy_per_mouse_by_phase_day_trial.csv"):
    """
    Load strategy classifications from a wide-format CSV file.

    The CSV must contain the columns 'phase', 'day' and 'trial'; every other
    column is treated as one mouse, with the column name being the mouse ID
    and the cell being that mouse's strategy label.

    Args:
        csv_file: Path of the CSV file to read.

    Returns:
        dict mapping (mouse_id, day, trial) -> strategy, where the three key
        components are plain Python ints. Cells with no strategy (or rows with
        no day/trial) are skipped. An empty dict is returned, after printing
        the error, when the file cannot be read or parsed.

    Note:
        'phase' is not part of the key, so rows that share the same
        (mouse_id, day, trial) in different phases overwrite one another
        (the last one read wins).
    """
    try:
        wide_df = pd.read_csv(csv_file)

        # Melt the dataframe to have mouse IDs as values instead of columns
        long_df = wide_df.melt(
            id_vars=['phase', 'day', 'trial'],
            var_name='mouse_id',
            value_name='strategy',
        )

        # Unlabelled cells would otherwise be stored as NaN and shadow the
        # caller's "not classified" fallback.
        long_df = long_df.dropna(subset=['day', 'trial', 'strategy'])

        # Force plain-int key components so lookups built with int(...) match.
        strategies = {
            (int(mouse_id), int(day), int(trial)): strategy
            for mouse_id, day, trial, strategy in zip(
                long_df['mouse_id'],
                long_df['day'],
                long_df['trial'],
                long_df['strategy'],
            )
        }
    except (OSError, KeyError, TypeError, ValueError) as e:
        # Best effort: missing file, malformed CSV, missing id columns or
        # non-numeric IDs all degrade to "no classifications available".
        print(f"❌ Error loading strategy classifications: {e}")
        return {}

    print(f"✅ Loaded {len(strategies)} strategy classifications from {csv_file}")
    return strategies


# ============================================================
# 3) SLK LOADER (AnyMaze → x,y)
# ============================================================

def load_slk(filepath: str) -> pd.DataFrame:
    """
    Read the x/y track out of an AnyMaze .slk export.

    Only cell records (lines starting with "C;") are considered. Within a
    record, the X field gives the column, the Y field gives the row and the
    K field gives the value. Numeric values found in columns 2 and 3 are
    taken as x and y respectively; rows that do not have both are dropped.

    Returns a DataFrame with columns: x, y (ordered by row number).
    """
    cells_by_row = {}

    with open(filepath, "r") as handle:
        for raw_line in handle:
            if not raw_line.startswith("C;"):
                continue

            col_idx = row_idx = value = None
            for field in raw_line.strip().split(";")[1:]:
                tag, payload = field[:1], field[1:]
                if tag == "X":
                    col_idx = int(payload) if payload.isdigit() else None
                elif tag == "Y":
                    row_idx = int(payload) if payload.isdigit() else None
                elif tag == "K":
                    # Values may be wrapped in double quotes
                    if payload.startswith('"') and payload.endswith('"'):
                        payload = payload[1:-1]
                    try:
                        value = float(payload)
                    except ValueError:
                        # Non-numeric cell (e.g. a header label)
                        value = None

            if col_idx not in (2, 3) or row_idx is None or value is None:
                continue
            cells_by_row.setdefault(row_idx, {})[col_idx] = value

    # Keep only rows where both coordinates were found, in row order
    complete_rows = [
        cells_by_row[row_idx]
        for row_idx in sorted(cells_by_row)
        if 2 in cells_by_row[row_idx] and 3 in cells_by_row[row_idx]
    ]

    return pd.DataFrame(
        {
            "x": [cells[2] for cells in complete_rows],
            "y": [cells[3] for cells in complete_rows],
        }
    )


# ============================================================
# 4) POOL SHAPES FOR PLOTLY (circular)
# ============================================================

def build_pool_shapes():
    """
    Build the Plotly layout 'shapes' that sketch the pool on axes x2/y2.

    The result holds one circle per entry of RINGS (each ellipse is drawn as
    a circle whose radius is the mean of its two semi-axes), followed by a
    dotted vertical midline and a dotted horizontal midline through the pool
    centre, each extending 20 px beyond the pool's semi-axes.
    """

    def _faint_line(**extra):
        # A fresh dict per shape so no two shapes share a mutable style object
        return dict(width=1, color="rgba(0,0,0,0.3)", **extra)

    # Concentric rings: average the semi-axes so each one looks circular
    ring_shapes = []
    for cx, cy, semi_a, semi_b in RINGS:
        radius = (semi_a + semi_b) / 2.0
        ring_shapes.append(
            {
                "type": "circle",
                "xref": "x2",
                "yref": "y2",
                "x0": cx - radius,
                "x1": cx + radius,
                "y0": cy - radius,
                "y1": cy + radius,
                "line": _faint_line(),
            }
        )

    vertical_midline = {
        "type": "line",
        "xref": "x2",
        "yref": "y2",
        "x0": POOL_CX,
        "x1": POOL_CX,
        "y0": POOL_CY - POOL_B - 20,
        "y1": POOL_CY + POOL_B + 20,
        "line": _faint_line(dash="dot"),
    }

    horizontal_midline = {
        "type": "line",
        "xref": "x2",
        "yref": "y2",
        "x0": POOL_CX - POOL_A - 20,
        "x1": POOL_CX + POOL_A + 20,
        "y0": POOL_CY,
        "y1": POOL_CY,
        "line": _faint_line(dash="dot"),
    }

    return ring_shapes + [vertical_midline, horizontal_midline]


# ============================================================
# 5) FEATURE EXTRACTION FOR CLUSTERING
#    (simple: resample each trajectory to N points and flatten)
# ============================================================

def resample_trajectory(df: pd.DataFrame, n_points: int = 100) -> np.ndarray:
    """
    Linearly resample an (x, y) trajectory onto n_points evenly spaced steps.

    The original samples are treated as evenly spaced on [0, 1]; x and y are
    interpolated independently onto a new evenly spaced grid of n_points.

    Returns a 1-D float array of length 2 * n_points with the coordinates
    interleaved: [x0, y0, x1, y1, ..., xN-1, yN-1]. An empty DataFrame gives
    an all-zero array.
    """
    interleaved = np.zeros(2 * n_points, dtype=float)
    if df.empty:
        return interleaved

    source_grid = np.linspace(0, 1, len(df))
    target_grid = np.linspace(0, 1, n_points)

    # Even slots hold x, odd slots hold y
    interleaved[0::2] = np.interp(target_grid, source_grid, df["x"].values)
    interleaved[1::2] = np.interp(target_grid, source_grid, df["y"].values)
    return interleaved


# ============================================================
# 6) INTERACTIVE PLOTLY VIEWER
# ============================================================

def _lookup_strategy(info, strategies_dict, default="Not classified"):
    """Return the strategy label for one trial, or `default` when its IDs are missing/unparseable."""
    try:
        key = (int(info["mouse_id"]), int(info["day"]), int(info["trial"]))
    except (KeyError, TypeError, ValueError):
        # Missing key, None (unparsed file name) or a non-numeric placeholder
        return default
    return strategies_dict.get(key, default)


def create_interactive_plot_with_trajectory(
    X_2d,
    file_info_list,
    trajectories_df,
    clusters,
    strategies_dict,
    best_score=None,
    output_html="interactive_trajectories.html",
):
    """
    Interactive PCA + trajectory viewer.

    Left: PCA scatter (one dot per trial).
    Right: pool + trajectory of the clicked dot.

    Args:
        X_2d: Array-like of shape (n_trials, >=2); columns 0 and 1 are plotted
            as PC1 and PC2.
        file_info_list: One dict per trial with the optional keys 'file',
            'mouse_id', 'day', 'trial' and 'reversal' (used for hover text and
            for the strategy lookup).
        trajectories_df: One DataFrame per trial with 'x' and 'y' columns, in
            the same order as file_info_list.
        clusters: One cluster label per trial. String/object labels are mapped
            to integer codes for colouring; numeric labels are used as-is.
        strategies_dict: Mapping (mouse_id, day, trial) -> strategy label.
        best_score: Optional silhouette score appended to the figure title.
        output_html: Path of the HTML file to write.

    Returns:
        The Plotly figure. As side effects the figure is written to
        `output_html` (with a click handler script appended) and shown via
        `fig.show()`.
    """

    X_2d = np.asarray(X_2d)
    clusters = np.asarray(clusters)
    xs = X_2d[:, 0]
    ys = X_2d[:, 1]

    # numeric cluster values for colormap
    if clusters.dtype.kind in {"U", "S", "O"}:
        unique_labels = np.unique(clusters)
        label_to_int = {lab: i for i, lab in enumerate(unique_labels)}
        cluster_vals = np.array([label_to_int[lab] for lab in clusters], dtype=float)
    else:
        cluster_vals = clusters.astype(float)

    # hover text (one entry per trial)
    hover_text = []
    for i, info in enumerate(file_info_list):
        mouse_id = info.get('mouse_id', 'NA')
        day = info.get('day', 'NA')
        trial = info.get('trial', 'NA')

        # Falls back to "Not classified" when any ID is missing or None,
        # instead of raising on int(None).
        strategy = _lookup_strategy(info, strategies_dict)

        hover_text.append(
            f"File: {os.path.basename(info.get('file', 'NA'))}<br>"
            f"Mouse: {mouse_id}<br>"
            f"Day: {day}<br>"
            f"Trial: {trial}<br>"
            f"Type: {'reversal' if info.get('reversal') else 'training'}<br>"
            f"Cluster: {clusters[i]}<br>"
            f"Strategy: {strategy}"
        )

    # customdata carries each point's index into the trajectory list (read by the JS below)
    customdata = list(range(len(file_info_list)))

    # pack trajectories as JSON
    traj_json = json.dumps(
        [dict(x=df["x"].tolist(), y=df["y"].tolist()) for df in trajectories_df]
    )

    # figure with 2 subplots
    fig = make_subplots(
        rows=1,
        cols=2,
        column_widths=[0.55, 0.45],
        subplot_titles=("Cluster map (PCA space)", "Trajectory"),
    )

    # LEFT: PCA scatter (trace 0)
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            mode="markers",
            marker=dict(
                size=8,
                opacity=0.7,
                color=cluster_vals,
                colorscale="Viridis",
                colorbar=dict(title="Cluster"),
            ),
            text=hover_text,
            customdata=customdata,
            hovertemplate="%{text}<extra></extra>",
        ),
        row=1,
        col=1,
    )

    # RIGHT: empty trajectory trace (trace 1), filled in by the click handler
    fig.add_trace(
        go.Scatter(
            x=[],
            y=[],
            mode="lines",
            line=dict(width=2, color="red"),
        ),
        row=1,
        col=2,
    )

    # axes for PCA
    fig.update_xaxes(title="PC1", row=1, col=1)
    fig.update_yaxes(title="PC2", row=1, col=1)

    # axes for pool + trajectory
    fig.update_xaxes(
        title="x (px)",
        range=[POOL_CX - POOL_A - 20, POOL_CX + POOL_A + 20],
        row=1,
        col=2,
    )
    # y range is given high -> low, i.e. the axis is drawn inverted
    fig.update_yaxes(
        title="y (px)",
        range=[POOL_CY + POOL_B + 20, POOL_CY - POOL_B - 20],
        autorange=False,
        row=1,
        col=2,
        scaleanchor="x2",
        scaleratio=1,
    )

    # pool shapes
    shapes = build_pool_shapes()

    # layout / title
    title = f"Trajectory clustering — {len(np.unique(clusters))} clusters"
    if best_score is not None:
        title += f" (silhouette = {best_score:.3f})"

    fig.update_layout(
        shapes=shapes,
        showlegend=False,
        dragmode="pan",
        hovermode="closest",
        title=title,
    )

    fig.add_annotation(
        text="Click a point to load the trajectory",
        xref="paper",
        yref="paper",
        x=0.78,
        y=1.06,
        showarrow=False,
        font=dict(size=12),
    )

    # JS: click scatter → update trajectory trace (index 1).
    # Clicks on any other trace carry no customdata, so they are ignored
    # rather than dereferencing an undefined trajectory.
    fig_js = f"""
    <script>
    var traj = {traj_json};

    document.addEventListener('DOMContentLoaded', function() {{
        var plot = document.querySelector('.plotly-graph-div');
        if (!plot) {{
            console.warn("Plotly graph div not found for click handler.");
            return;
        }}

        plot.on('plotly_click', function(ev) {{
            var pt = ev.points[0];
            if (!pt || pt.curveNumber !== 0) {{
                return;  // only the PCA scatter (trace 0) maps to a trajectory
            }}
            var t = traj[pt.customdata];
            if (!t) {{
                return;
            }}

            Plotly.restyle(plot, {{
                x: [t.x],
                y: [t.y]
            }}, [1]);  // trace 1 = right subplot (trajectory)
        }});
    }});
    </script>
    """

    # export HTML
    html_core = pio.to_html(fig, full_html=False, include_plotlyjs="cdn")
    html_str = html_core + fig_js

    with open(output_html, "w", encoding="utf-8") as f:
        f.write(html_str)

    print("✅ Saved interactive HTML to:", os.path.abspath(output_html))
    print("➡ Open this file in your browser and click a dot to see its trajectory.")

    fig.show()
    return fig


# ============================================================
# 7) MAIN PIPELINE: load SLKs, build features, cluster, plot
# ============================================================

def _parse_trial_filename(path):
    """Best-effort parse of mouse ID, day, trial and reversal flag from an SLK file name."""
    name = os.path.basename(path).lower()  # digits are unaffected by lower()

    # very rough parsing (adapt if needed): first three characters = mouse ID
    try:
        mouse_id = int(name[:3])
    except ValueError:
        mouse_id = None

    # day = the single digit right after the first "j"
    j_idx = name.find("j")
    day = None
    if j_idx != -1:
        try:
            day = int(name[j_idx + 1])
        except (IndexError, ValueError):
            day = None

    is_reversal = "reversal" in name

    # crude trial guess
    trial = None
    try:
        if is_reversal:
            # digits between the last "reversal" and the file extension
            start = name.rfind("reversal") + len("reversal")
            trial_str = name[start : name.rfind(".")]
            if trial_str:
                trial = int(trial_str)
        elif j_idx != -1:
            # second character after the "j" (only when this name has one)
            trial = int(name[j_idx + 2])
    except (IndexError, ValueError):
        trial = None

    return dict(
        file=path,
        mouse_id=mouse_id,
        day=day,
        trial=trial,
        reversal=is_reversal,
    )


def main():
    """
    Run the full pipeline: load strategies and SLK files, build features,
    project to 2-D with PCA, cluster with KMeans and write/show the
    interactive plot.

    Prints a message and returns early when no SLK file matches SLK_PATTERN
    or when fewer than two trajectories are available.
    """
    # ---- Load strategy classifications ----
    strategies_dict = load_strategy_classifications()

    # ---- load SLK files ----
    files = sorted(glob.glob(SLK_PATTERN))
    if not files:
        print("No SLK files found in", SLK_PATTERN)
        return

    print(f"Found {len(files)} SLK files.")

    trajectories_df = [load_slk(path) for path in files]
    # Each file name is parsed independently, so no index from a previous
    # file can leak into the next one.
    file_info_list = [_parse_trial_filename(path) for path in files]

    # ---- build feature matrix ----
    X = np.vstack(
        [resample_trajectory(df, n_points=100) for df in trajectories_df]  # 200-D per trial
    )

    if len(X) < 2:
        # A 2-component PCA needs at least two samples
        print("Need at least 2 trajectories for PCA/clustering, found", len(X))
        return

    # ---- PCA to 2D ----
    pca = PCA(n_components=2)
    X_2d = pca.fit_transform(X)

    # ---- clustering ----
    k = min(3, len(X_2d))  # choose number of clusters you like (capped at n_samples)
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
    labels = kmeans.fit_predict(X_2d)

    # silhouette needs 2 <= n_labels <= n_samples - 1
    best_score = None
    n_labels = len(np.unique(labels))
    if 1 < n_labels < len(X_2d):
        best_score = silhouette_score(X_2d, labels)

    # ---- interactive plot ----
    create_interactive_plot_with_trajectory(
        X_2d=X_2d,
        file_info_list=file_info_list,
        trajectories_df=trajectories_df,
        clusters=labels,
        strategies_dict=strategies_dict,
        best_score=best_score,
        output_html="interactive_trajectories.html",
    )


# Entry point: run the pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

✅ Loaded 80 strategy classifications from strategy_per_mouse_by_phase_day_trial.csv
Found 160 SLK files.
✅ Saved interactive HTML to: /datasets/_deepnote_work/interactive_trajectories.html
➡ Open this file in your browser and click a dot to see its trajectory.


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=02ffb3a7-f1f8-4631-b7f7-44afdef896f1' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>