In [1]:
import torch
import numpy as np
from pathlib import Path
from cartesian_polar.exp_autoencoder.agent import Autoencoder
import random
import matplotlib.pyplot as plt
import tqdm
from environment import DuplicatedCoordsEnv
import pandas as pd
import itertools
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import LabelEncoder
from matplotlib.lines import Line2D
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from scipy.spatial.distance import cdist
from sklearn.manifold import TSNE
import seaborn as sns
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.preprocessing import normalize, StandardScaler
import plotly.express as px

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
save_path = Path("save")
save_path.exists()

True

In [4]:
data_dir_ew = save_path / "7-30-A-10-EW-2"
data_dir_lr = save_path / "8-1-10-LR"
print(data_dir_ew.exists())
print(data_dir_lr.exists())

True
True


In [5]:
data_path_ew = data_dir_ew / "data.tar"
data_path_lr = data_dir_lr / "data.tar"
print(data_path_ew.exists())
print(data_path_lr.exists())

True
True


In [6]:
data_dict_ew = torch.load(data_path_ew, weights_only=False, map_location=DEVICE)
data_dict_lr = torch.load(data_path_lr, weights_only=False, map_location=DEVICE)
data_dict_ew.keys()
data_dict_lr.keys()

dict_keys(['rewards', 'steps', 'episodes', 'all_states', 'all_actions', 'losses', 'p', 'epsilons', 'weights_val_stats', 'biases_val_stats', 'weights_grad_stats', 'biases_grad_stats', 'net', 'env', 'weights', 'biases'])

In [7]:
# ASSUMING EQUAL ARCHITECTURES
parameters = data_dict_ew['p']
n_observations = parameters.n_observations
n_actions = parameters.n_actions
n_units = parameters.n_hidden_units
bottleneck = parameters.bottleneck

In [8]:
env = data_dict_ew['env'] # assuming equal environments

### Helper Functions

In [9]:
"""
Function for converting angle degree to cardinal direction
"""
def degrees_to_cardinal(degree):
    # Normalize the degree to [0, 360)
    degree = degree % 360

    # Define the mapping
    directions = {
        0: 'N',
        90: 'E',
        180: 'S',
        270: 'W'
    }

    # Find the closest cardinal angle
    closest = min(directions.keys(), key=lambda x: abs(x - degree))
    return directions[closest]

In [10]:
def convert_state_sample_to_orig_sample(state_sample):
    orig_state_samples = []
    for state in state_sample:

        if state[0].item() == 1:
            odor_cue = 'No Odor'
        elif state[1].item() == 1:
            odor_cue = 'Odor A'
        else:
            odor_cue = 'Odor B'

        coords_orig = env.conv_north_cartesian2orig(state[3:7]).tolist()
        coords_orig.insert(0, odor_cue)
        
        head_deg = coords_orig[3]
        head_cardinal = degrees_to_cardinal(head_deg)
        coords_orig[3] = head_cardinal
        
        coords_orig.append(state[3].item()) # CNP x
        coords_orig.append(state[4].item()) # CNP y
        coords_orig.append(state[7].item()) # CSP x
        coords_orig.append(state[8].item()) # CSP y

        orig_state_samples.append(coords_orig)
    return orig_state_samples

In [11]:
def get_layer_activations(net, input_states, layer_index):
    activations_list = []

    def hook(module, input, output):
        activations_list.append(output.detach().cpu().numpy().squeeze())

    # Register hook
    handle = list(net.mlp.children())[layer_index].register_forward_hook(hook)

    # Forward pass through each input state
    with torch.no_grad():
        for state in input_states:
            net(state)

    # Clean up hook
    handle.remove()

    # Stack into numpy array
    activations_array = np.stack(activations_list)  # Shape: (500, 512)

    return activations_array


In [12]:
def compute_centroids_by_category(activations, labels):
    centroids = {}
    for label in np.unique(labels):
        mask = labels == label
        centroids[label] = activations[mask].mean(axis=0)
    return centroids

## States

In [241]:
# Discrete Points
all_x = [-2, -1, 0, 1, 2]
all_y = [-2, -1, 0, 1, 2]
all_head = [0, 90, 180, 270]
all_odor = [torch.tensor(0), torch.tensor(1), torch.tensor(2)]

all_possible_states = list(itertools.product(all_odor, all_x, all_y, all_head))

# Continuous Points
lin = np.linspace(-2, 2, 50)
xx, yy = np.meshgrid(lin, lin)
# flatten for easy iteration
# x_flat = xx.flatten()
# y_flat = yy.flatten()

cont = np.linspace(-2, 2, 50)


# state_dicts = [
#     {'cue': odor, 'x': x, 'y': y, 'direction': hd}
#     for (odor, x, y, hd) in all_possible_states
# ]

state_dicts_x = [
    {'cue': torch.tensor(2), 'x': x, 'y': y, 'direction': 180}
    for y in [-2,-1,0,1,2]
    for x in cont
]
state_dicts_y = [
    {'cue': torch.tensor(2), 'x': x, 'y': y, 'direction': 180}
    for x in [-2,-1,0,1,2]
    for y in cont
]
state_dicts = state_dicts_x + state_dicts_y
print(len(state_dicts))

500


In [242]:
all_possible_tensors = []

for state_dict in state_dicts:
    state_tensor = DuplicatedCoordsEnv.conv_dict_to_flat_duplicated_coords(env, state_dict).float()
    all_possible_tensors.append(state_tensor)

In [243]:
# think closely about the coordinates; why does y = 2 and y = -2 have this strange more precision

print(len(all_possible_tensors))

500


## Activations

In [244]:
cat = data_dir_ew

# if cat == data_dir_lr:
#     cat_label = 'L/R'
# else:
#     cat_label = 'E/W'

model = Autoencoder(n_observations, n_actions, bottleneck, n_units)

model_path = cat / f'trained-agent-state-0.pt'
model.load_state_dict(torch.load(model_path, weights_only=False, map_location=torch.device('cpu')))
model.eval()

Autoencoder(
  (mlp): Sequential(
    (0): Linear(in_features=21, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=10, bias=True)
    (3): ReLU()
    (4): Linear(in_features=10, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=3, bias=True)
  )
)

In [245]:
state_samples = all_possible_tensors

In [246]:
q_values = get_layer_activations(model, state_samples, 6)
chosen_actions = np.argmax(q_values, axis=1)

action_labels = ['Move Forward', 'Turn Left', 'Turn Right']
chosen_action_names = [action_labels[i] for i in chosen_actions]

In [247]:
orig_samples = convert_state_sample_to_orig_sample(state_sample=state_samples)
print(orig_samples)

upper_triangle_coords = [(-1,2), (0,2), (1,2), (2,2), (0,1), (1,1), (2,1), (1,0), (2,0), (2,-1)]
metadata = {
    "odor": [],
    "x": [],
    "y": [],
    "grid_pos": [],
    "head_dir": [],
    "action": [],
    "triangle": [],
    "CNP x": [],
    "CNP y": [],
    "CSP x": [],
    "CSP y": []
}

for state in orig_samples:
    metadata["odor"].append(state[0])
    metadata["x"].append(state[1])
    metadata["y"].append(state[2])
    metadata["grid_pos"].append((state[1], state[2]))
    metadata["head_dir"].append(state[3])
    if (state[1], state[2]) in upper_triangle_coords:
        metadata["triangle"].append('U')
    else:
        metadata["triangle"].append('L')
    metadata["CNP x"].append(state[4])
    metadata["CNP y"].append(state[5])
    metadata["CSP x"].append(state[6])
    metadata["CSP y"].append(state[7])
metadata["action"] = chosen_action_names

metadata_df = pd.DataFrame(metadata)
print(metadata_df.to_string())

[['Odor B', -2.0, -2.0, 'S', 4.0, 4.0, 0.0, 0.0], ['Odor B', -1.9183673858642578, -2.0, 'S', 3.918367385864258, 4.0, 0.08163265138864517, 0.0], ['Odor B', -1.8367347717285156, -2.0, 'S', 3.8367347717285156, 4.0, 0.16326530277729034, 0.0], ['Odor B', -1.7551021575927734, -2.0, 'S', 3.7551021575927734, 4.0, 0.2448979616165161, 0.0], ['Odor B', -1.6734693050384521, -2.0, 'S', 3.673469305038452, 4.0, 0.3265306055545807, 0.0], ['Odor B', -1.59183669090271, -2.0, 'S', 3.59183669090271, 4.0, 0.40816327929496765, 0.0], ['Odor B', -1.5102040767669678, -2.0, 'S', 3.5102040767669678, 4.0, 0.4897959232330322, 0.0], ['Odor B', -1.4285714626312256, -2.0, 'S', 3.4285714626312256, 4.0, 0.5714285969734192, 0.0], ['Odor B', -1.3469388484954834, -2.0, 'S', 3.3469388484954834, 4.0, 0.6530612111091614, 0.0], ['Odor B', -1.2653062343597412, -2.0, 'S', 3.265306234359741, 4.0, 0.7346938848495483, 0.0], ['Odor B', -1.18367338180542, -2.0, 'S', 3.18367338180542, 4.0, 0.8163265585899353, 0.0], ['Odor B', -1.1020

In [248]:
activations = get_layer_activations(model, state_samples, 3)

In [249]:
activations.shape

(500, 10)

In [250]:
pca = PCA(n_components=5)
pca_result = pca.fit_transform(activations)

In [251]:
df = pd.DataFrame({
    "PC1": pca_result[:, 0],
    "PC2": pca_result[:, 1],
    "PC3": pca_result[:, 2],
    "odor": metadata["odor"],
    "x": metadata["x"],
    "y": metadata["y"],
    "head": metadata["head_dir"],
    "action": metadata["action"],
    "triangle": metadata["triangle"]
})

In [252]:
print(df.to_string())

          PC1       PC2       PC3    odor         x         y head        action triangle
0    0.358289  0.195464 -0.003512  Odor B -2.000000 -2.000000    S     Turn Left        L
1    0.357692  0.191044 -0.009769  Odor B -1.918367 -2.000000    S     Turn Left        L
2    0.356967  0.186135 -0.015989  Odor B -1.836735 -2.000000    S     Turn Left        L
3    0.356354  0.181249 -0.022146  Odor B -1.755102 -2.000000    S     Turn Left        L
4    0.355584  0.176320 -0.028493  Odor B -1.673469 -2.000000    S     Turn Left        L
5    0.354235  0.173296 -0.034282  Odor B -1.591837 -2.000000    S     Turn Left        L
6    0.353503  0.169510 -0.039209  Odor B -1.510204 -2.000000    S     Turn Left        L
7    0.353064  0.165414 -0.043718  Odor B -1.428571 -2.000000    S     Turn Left        L
8    0.353175  0.161620 -0.048095  Odor B -1.346939 -2.000000    S     Turn Left        L
9    0.354195  0.158105 -0.051820  Odor B -1.265306 -2.000000    S     Turn Left        L
10   0.355

In [253]:
import plotly.graph_objects as go
import plotly.express as px

fig = go.Figure()

discrete = [-2, -1, 0, 1, 2]

# pick a palette with at least 5 distinct colors
palette = px.colors.qualitative.Set1  
color_map = {val: palette[i % len(palette)] for i, val in enumerate(discrete)}

# lines for discrete y
for y_val in discrete:
    sub = df[df["y"] == y_val].sort_values("x")

    fig.add_trace(go.Scatter3d(
        x=sub["PC1"],
        y=sub["PC2"],
        z=sub["PC3"],
        mode="lines",
        line=dict(
            color=color_map[y_val],  # assign one color per y line
            width=6
        ),
        #name=f"y={y_val}",
        showlegend=False,
        hoverinfo="text",
        text=[f"x: {x}, y: {y}, action: {a}"
              for x, y, a in zip(sub["x"], sub["y"], sub["action"])]
    ))

    # add a text label at the last point of the line
    fig.add_trace(go.Scatter3d(
        x=[sub["PC1"].iloc[-1]],
        y=[sub["PC2"].iloc[-1]],
        z=[sub["PC3"].iloc[-1]],
        mode="text",
        text=[f"y={y_val}"],
        textposition="top center",
        showlegend=False
    ))

# lines for discrete x
for x_val in discrete:
    sub = df[df["x"] == x_val].sort_values("y")

    fig.add_trace(go.Scatter3d(
        x=sub["PC1"],
        y=sub["PC2"],
        z=sub["PC3"],
        mode="lines",
        line=dict(
            color=color_map[x_val],  # assign one color per x line
            width=6
        ),
        name=f"x={x_val}",
        hoverinfo="text",
        text=[f"x: {x}, y: {y}, action: {a}"
              for x, y, a in zip(sub["x"], sub["y"], sub["action"])]
    ))

    # add a text label at the last point of the line
    fig.add_trace(go.Scatter3d(
        x=[sub["PC1"].iloc[0]],
        y=[sub["PC2"].iloc[0]],
        z=[sub["PC3"].iloc[0]],
        mode="text",
        text=[f"x={x_val}"],
        textposition="top center",
        showlegend=False
    ))

fig.update_layout(scene=dict(
    xaxis_title="PC1",
    yaxis_title="PC2",
    zaxis_title="PC3"
))

fig.show()

In [254]:
import plotly.graph_objects as go

fig = go.Figure()

discrete = [-2,-1,0,1,2]

# loop over unique y-values
for y_val in discrete:
    sub = df[df["y"] == y_val].sort_values("x")  # sort by x so line flows correctly

    fig.add_trace(go.Scatter3d(
        x=sub["PC1"],
        y=sub["PC2"],
        z=sub["PC3"],
        mode="lines",
        line=dict(
            color=sub["x"],         # gradient along x
            colorscale="Viridis",
            width=6
        ),
        showlegend=False,
        text=[f"x: {x}, y: {y}, action: {a}"
              for x, y, a in zip(sub["x"], sub["y"], sub["action"])],
        hoverinfo="text"
    ))

for x_val in discrete:
    sub = df[df["x"] == x_val].sort_values("y")  # sort by x so line flows correctly

    fig.add_trace(go.Scatter3d(
        x=sub["PC1"],
        y=sub["PC2"],
        z=sub["PC3"],
        mode="lines",
        line=dict(
            color=sub["y"],         # gradient along x
            colorscale="Viridis",
            width=6
        ),
        showlegend=False,
        text=[f"x: {x}, y: {y}, action: {a}"
              for x, y, a in zip(sub["x"], sub["y"], sub["action"])],
        hoverinfo="text"
    ))

# add dummy scatter to create colorbar for x
fig.add_trace(go.Scatter3d(
    x=[None], y=[None], z=[None],  # invisible points
    mode="markers",
    marker=dict(
        colorscale="Viridis",
        cmin=df["x"].min(),
        cmax=df["x"].max(),
        colorbar=dict(title="x-value"),
        showscale=True
    ),
    showlegend=False
))

fig.update_layout(scene=dict(
    xaxis_title="PC1",
    yaxis_title="PC2",
    zaxis_title="PC3"
))

fig.show()

In [255]:
import plotly.graph_objects as go

fig = go.Figure()

unique_actions = sorted(df["action"].unique())
colors_map = {action: px.colors.qualitative.Set1[i % 9] for i, action in enumerate(unique_actions)}

discrete = [-2,-1,0,1,2]

# loop over unique y-values
for y_val in discrete:
    sub = df[df["y"] == y_val].sort_values("x")  # sort by x so line flows correctly
  
    # iterate over consecutive points
    for k in range(len(sub) - 1):
        x_pair = sub["PC1"].iloc[k:k+2]
        y_pair = sub["PC2"].iloc[k:k+2]
        z_pair = sub["PC3"].iloc[k:k+2]
        action_pair = sub["action"].iloc[k]  # color by first point of segment

        fig.add_trace(go.Scatter3d(
            x=x_pair,
            y=y_pair,
            z=z_pair,
            mode="lines",
            line=dict(color=colors_map[action_pair], width=6),
            showlegend=False,  # we’ll handle legend separately
            text=[f"x: {x}, y: {y}, action: {a}"
              for x, y, a in zip(sub["x"].iloc[k:k+2], sub["y"].iloc[k:k+2], sub["action"].iloc[k:k+2])]
        ))

for x_val in discrete:
    sub = df[df["x"] == x_val].sort_values("y")  # sort by x so line flows correctly

    # iterate over consecutive points
    for k in range(len(sub) - 1):
        x_pair = sub["PC1"].iloc[k:k+2]
        y_pair = sub["PC2"].iloc[k:k+2]
        z_pair = sub["PC3"].iloc[k:k+2]
        action_pair = sub["action"].iloc[k]  # color by first point of segment

        fig.add_trace(go.Scatter3d(
            x=x_pair,
            y=y_pair,
            z=z_pair,
            mode="lines",
            line=dict(color=colors_map[action_pair], width=6),
            showlegend=False,  # we’ll handle legend separately
            text=[f"x: {x}, y: {y}, action: {a}"
              for x, y, a in zip(sub["x"].iloc[k:k+2], sub["y"].iloc[k:k+2], sub["action"].iloc[k:k+2])]
        ))

# optional: create dummy traces for legend
for action, color in colors_map.items():
    fig.add_trace(go.Scatter3d(
        x=[None],
        y=[None],
        z=[None],
        mode="lines",
        line=dict(color=color, width=6),
        name=action
    ))

fig.update_layout(scene=dict(
    xaxis_title="PC1",
    yaxis_title="PC2",
    zaxis_title="PC3"
))

fig.show()

In [None]:
# need to find some way to verify these actions with the rendering -- compute the actions, go along the grid paths
# POLICY MAPS