## Analysis of tangling across layers (panel E)

In [None]:
%matplotlib inline
import numpy as np
from legacy.functions_notebook import get_layers, measure_tangling
from definitions import ROOT_DIR
import os
import pickle
import umap
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from envs.environment_factory import EnvironmentFactory
import matplotlib
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import VecNormalize
from legacy.functions_notebook import make_parallel_envs
from myosuite.envs.myo.myochallenge.baoding_v1 import Task


In [None]:
# Final-model rollouts ("small variations"), one HDF file per rotation direction;
# ccw rows come first in the concatenated frame.
df1 = pd.read_hdf(
    os.path.join(
        ROOT_DIR, "data", "rollouts",
        "final_model_500_episodes_activations_info_small_variations_ccw", "data.hdf",
    )
)
df2 = pd.read_hdf(
    os.path.join(
        ROOT_DIR, "data", "rollouts",
        "final_model_500_episodes_activations_info_small_variations_cw", "data.hdf",
    )
)
# reset_index() (without drop=True) keeps the old per-file index as an "index" column.
df = pd.concat((df1, df2)).reset_index()

# "step_12" rollouts (used below as the early-baoding condition); cw rows first here.
cw_path = os.path.join(ROOT_DIR, "data", "rollouts", "step_12_500_episodes_activations_info_cw", "data.hdf")
ccw_path = os.path.join(ROOT_DIR, "data", "rollouts", "step_12_500_episodes_activations_info_ccw", "data.hdf")
rollouts_cw = pd.read_hdf(cw_path)
rollouts_ccw = pd.read_hdf(ccw_path)
early_baoding_df = pd.concat((rollouts_cw, rollouts_ccw)).reset_index()
early_baoding_df.keys()


In [None]:
def average_by_timestep(vec, timesteps):
    """Average the rows of ``vec`` that share the same timestep label.

    Parameters
    ----------
    vec : np.ndarray
        Array whose first axis is aligned element-wise with ``timesteps``.
    timesteps : array-like
        One timestep label per row of ``vec`` (e.g. a pandas Series of steps);
        the boolean mask is applied positionally.

    Returns
    -------
    np.ndarray
        One row per distinct timestep, in ascending timestep order. Each row is
        the mean over axis 0 of the rows of ``vec`` carrying that timestep.
    """
    # np.unique already returns its values sorted in ascending order.
    per_step_means = [np.mean(vec[timesteps == step], axis=0) for step in np.unique(timesteps)]
    return np.vstack(per_step_means)


In [None]:
n_comp = 25
layer_names = ["observation", "lstm_state_1", "lstm_out", "layer_1_out", "layer_2_out", "action"]
readable_layer_names = ["Observation", "LSTM state", "LSTM out", "Layer 1 out", "Layer 2 out", "Action"]

# A single PCA object is refit from scratch on every layer. Each layer's
# projected rows are stored in a new "<layer>_pc" column of df, next to the raw
# activations. After the loop, `pca` holds the fit for the last layer in the list.
pca = PCA(n_components=n_comp)

for layer in layer_names:
    print("Layer: ", layer)
    # Stack the per-row activation vectors into a (rows, features) matrix.
    data = np.stack(df[layer].to_list())
    embeddings = pca.fit_transform(data)
    df[f"{layer}_pc"] = list(embeddings)

In [None]:
# Mean PCA trajectories: for each layer, average the first 3 stored PCs over all
# episodes at every timestep. Plot clockwise (Reds) and counter-clockwise (Blues)
# rotations in one 3D figure, with colour darkening as time advances.
n_comp = 3
first_step = 0
df_trunc = df[df.step > first_step]
for layer in layer_names:
    data = np.array(df_trunc[layer + "_pc"].to_list())
    out_cw = data[df_trunc.task == "cw"]
    out_ccw = data[df_trunc.task == "ccw"]
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(projection="3d")

    for task, out, cmap_name, label in (
        ("cw", out_cw, "Reds", "Clockwise"),
        ("ccw", out_ccw, "Blues", "Counter-clockwise"),
    ):
        steps = df_trunc.step[df_trunc.task == task]
        cmap = matplotlib.colormaps[cmap_name]
        # One colour per distinct timestep. average_by_timestep returns rows in
        # ascending timestep order, so colours run light (0.5) -> dark (1.0) in time.
        color_list = [cmap(i) for i in np.linspace(0.5, 1, len(steps.unique()))]
        mean_traj = average_by_timestep(out[:, :3], steps)
        # NOTE(review): `label` is set but ax.legend() is never called; add it if a legend is wanted.
        ax.scatter(mean_traj[:, 0], mean_traj[:, 1], mean_traj[:, 2], c=color_list, label=label)

    # Transparent panes and grid lines (grid colour relies on the private _axinfo dict).
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.xaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
    ax.yaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
    ax.zaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
    ax.set_xlabel("\nPC 1")
    ax.set_ylabel("\nPC 2")
    ax.set_zlabel("\nPC 3")
    ax.set_box_aspect(aspect=None, zoom=0.75)

    # plt.savefig(os.path.join(ROOT_DIR, "data", "figures", "panel_5", f"pca_{layer}.png"), format="png", dpi=600, bbox_inches="tight")
    plt.show()

In [None]:
# Joint UMAP embedding of the action vectors from `df` ("Baoding") and from
# `early_baoding_df`. Both sets are fit together so they share one 3D space.
baoding_actions = np.vstack(df.action)
early_baoding_actions = np.vstack(early_baoding_df.action)

actions = np.vstack((baoding_actions, early_baoding_actions))
n_comp_pca = 25
n_comp_umap = 3
random_state = 42

pca = PCA(n_components=n_comp_pca)
reducer = umap.UMAP(n_components=n_comp_umap, random_state=random_state)

# NOTE(review): pca_actions is not used in this cell; kept so the notebook namespace is unchanged.
pca_actions = pca.fit_transform(actions)
umap_actions = reducer.fit_transform(actions)

# Rows [0, n_baoding) come from df; the remaining rows come from early_baoding_df.
n_baoding = len(baoding_actions)
umap_baoding = umap_actions[:n_baoding]
umap_early = umap_actions[n_baoding:]

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(projection="3d")
ax.scatter(umap_baoding[:, 0], umap_baoding[:, 1], umap_baoding[:, 2], c="red", label="Baoding")
# The early-baoding half of the joint embedding was previously computed but
# never drawn. Plot it too so the two conditions can actually be compared.
ax.scatter(umap_early[:, 0], umap_early[:, 1], umap_early[:, 2], c="blue", label="Early baoding")
ax.legend()
plt.show()

In [None]:
# Graph per layer with UMAP. For each layer:
# - embed the stored "<layer>_pc" projections into 3D (timesteps > first_step only),
# - pickle the cw / ccw embeddings,
# - scatter every sample, coloured by timestep (Reds = cw, Blues = ccw).
n_comp = 3
first_step = 13
random_state = 0
reducer = umap.UMAP(n_components=n_comp, random_state=random_state)
df_trunc = df[df.step > first_step]
out_dir = os.path.join(ROOT_DIR, "data", "embeddings")
# open(..., "wb") below raises FileNotFoundError if the folder does not exist yet.
os.makedirs(out_dir, exist_ok=True)
for layer in layer_names:
    print("Layer", layer)
    data = np.array(df_trunc[layer + "_pc"].to_list())
    data = reducer.fit_transform(data)
    out_cw = data[df_trunc.task == "cw"]
    out_ccw = data[df_trunc.task == "ccw"]

    with open(os.path.join(out_dir, f"new_umap_embeddings_ccw_{layer}.pkl"), "wb") as file:
        pickle.dump(out_ccw, file)
    with open(os.path.join(out_dir, f"new_umap_embeddings_cw_{layer}.pkl"), "wb") as file:
        pickle.dump(out_cw, file)

    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(projection="3d")

    for task, out, cmap_name in (("cw", out_cw, "Reds"), ("ccw", out_ccw, "Blues")):
        steps = df_trunc.step[df_trunc.task == task]
        cmap = matplotlib.colormaps[cmap_name]
        color_list = [cmap(i) for i in np.linspace(0.3, 1, len(steps.unique()))]
        # Map each sample's step to its colour. This assumes the kept steps are the
        # contiguous range first_step + 1, first_step + 2, ...; otherwise it indexes out of range.
        colors = [color_list[idx] for idx in (steps - first_step - 1)]
        ax.scatter(out[:, 0], out[:, 1], out[:, 2], alpha=0.1, s=1, c=colors)

    # Transparent panes and grid lines (grid colour relies on the private _axinfo dict).
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.xaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
    ax.yaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
    ax.zaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)

    ax.set_xlabel("\nUMAP 1", fontsize=12)
    ax.set_ylabel("\nUMAP 2", fontsize=12)
    ax.set_zlabel("\nUMAP 3", fontsize=12)
    ax.set_box_aspect(aspect=None, zoom=0.75)
    ax.ticklabel_format(style="sci", scilimits=(-2, 2))
    ax.locator_params(axis='both', nbins=4)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.tick_params(axis='both', which='minor', labelsize=10)

    # plt.savefig(os.path.join(ROOT_DIR, "data", "figures", "panel_5", f"umap_{layer}_1000_episodes_from_step_{first_step}_seed_{random_state}.png"), format="png", dpi=600, bbox_inches="tight")
    plt.show()

In [None]:
def measure_tangling(data, sample_rate=40, epsilon=1e-10):
    """Return the mean trajectory tangling of a time series.

    For every timestep t this computes

        Q(t) = max_t' ||dx(t) - dx(t')||^2 / (epsilon + ||x(t) - x(t')||^2)

    and returns the mean of Q(t) over all t. Here dx is the finite-difference
    time derivative (``np.gradient`` along axis 0) multiplied by ``sample_rate``.
    The t' = t term is always 0, so Q(t) >= 0.

    This definition replaces the ``measure_tangling`` imported from
    ``legacy.functions_notebook`` at the top of the notebook.

    NOTE(review): ``epsilon`` is a tiny absolute constant. If two distinct
    timesteps have (almost) identical states but different derivatives, Q
    becomes very large. Some formulations instead scale epsilon with the data
    variance; confirm the constant is intended.

    Parameters
    ----------
    data : array-like, shape (T, D) or (T,)
        Trajectory with time along axis 0. A 1-D input is treated as (T, 1).
    sample_rate : float, default 40
        Factor converting the per-sample gradient into a per-time-unit
        derivative (the original hard-coded "sample frequency").
    epsilon : float, default 1e-10
        Added to the squared state distance to avoid division by zero.

    Returns
    -------
    numpy.float64
        Mean over timesteps of the per-timestep maximum tangling ratio.

    Raises
    ------
    ValueError
        From ``np.gradient`` when ``data`` has fewer than 2 timesteps.
    """
    data = np.asarray(data, dtype=float)
    if data.ndim == 1:
        # Single-channel signal: give it an explicit feature axis.
        data = data[:, np.newaxis]
    derivative = np.gradient(data, axis=0) * sample_rate

    q_per_timestep = np.empty(data.shape[0])
    for t in range(data.shape[0]):
        # Squared distances from time t to every time t' (row-wise).
        deriv_dist2 = np.sum((derivative[t] - derivative) ** 2, axis=1)
        state_dist2 = np.sum((data[t] - data) ** 2, axis=1)
        q_per_timestep[t] = np.max(deriv_dist2 / (epsilon + state_dist2))
    return np.mean(q_per_timestep)

In [None]:
# Mean tangling per layer. Tangling is computed separately for every
# (episode, task) trajectory, using the first `num_components` columns of the
# stored "<layer>_pc" projections, then averaged across trajectories.
num_components = 8

for layer in layer_names:
    per_episode_pcs = (
        df.groupby(["episode", "task"])[layer + "_pc"]
        .agg(lambda x: np.vstack(x)[:, :num_components])
        .tolist()
    )
    tangling_list = [measure_tangling(episode_pc) for episode_pc in per_episode_pcs]
    print(layer, np.mean(tangling_list))

In [None]:
# Scatter plot tangling memory vs observation and memory vs action.
# Per-episode tangling is computed on the first `num_components` stored PCs.
# `num_components` is defined in the previous cell, which must run first.
scatter_layer_names = ["observation", "lstm_state_1", "action"]
tangling_dict = {}
for layer in scatter_layer_names:
    episode_pc_list = df.groupby(["episode", "task"])[layer + "_pc"].agg(lambda x: np.vstack(x)[:, :num_components]).tolist()
    tangling_dict[layer] = [measure_tangling(episode_pc) for episode_pc in episode_pc_list]

figure_dir = os.path.join(ROOT_DIR, "data", "figures", "panel_2")
# savefig raises FileNotFoundError if the target folder does not exist yet.
os.makedirs(figure_dir, exist_ok=True)

# Each spec is (x-axis layer, x-axis label, shared axis limits, output file).
# The y-axis is always the memory layer ("lstm_state_1").
# NOTE(review): the "(86D)/(39D)/(256D)" labels name raw layer sizes, while
# tangling is computed on `num_components` PCs; confirm the wording is intended.
scatter_specs = [
    ("observation", "Observation space (86D)", (0, 8500), "scatter_tangling_memory_vs_observation.png"),
    ("action", "Action space (39D)", (0, 12500), "scatter_tangling_memory_vs_action.png"),
]
for x_layer, x_label, lim, out_name in scatter_specs:
    fig, ax = plt.subplots(figsize=(3.6, 3.6))
    ax.scatter(tangling_dict[x_layer], tangling_dict["lstm_state_1"], s=0.5, color="dodgerblue")
    # Identity line: episodes above it are more tangled in memory than in x_layer.
    ax.plot(lim, lim, color="red")
    ax.set_xlabel(x_label, fontsize=13)
    ax.set_ylabel("Memory layer\nrepresentation (256D)", fontsize=13)
    ax.set_xlim(lim)
    ax.set_ylim(lim)
    ax.ticklabel_format(scilimits=(3, 3))
    fig.savefig(os.path.join(figure_dir, out_name), format="png", dpi=800, bbox_inches="tight")
    # plt.show() (as in the other cells) rather than Figure.show(), which targets GUI backends.
    plt.show()
        

In [None]:
# Export the Reds / Blues colour gradients as standalone horizontal strips.
# They are sampled over 0.5-1 of each colormap, the same range as the PCA
# mean-trajectory plots.
cmap_dir = os.path.join(ROOT_DIR, "data", "figures", "panel_5")
# savefig raises FileNotFoundError if the target folder does not exist yet.
os.makedirs(cmap_dir, exist_ok=True)
for map_name in ["Reds", "Blues"]:
    cmap = matplotlib.colormaps[map_name]
    color_list = [cmap(i) for i in np.linspace(0.5, 1, 200)]
    c_mat = np.array(color_list)  # (200, 4) RGBA rows
    # Repeat the 1-pixel-high gradient into a 30-pixel-high RGBA image of shape (30, 200, 4).
    c_mat = np.tile(c_mat, (30, 1, 1))
    plt.imshow(c_mat)
    plt.axis("off")
    plt.savefig(os.path.join(cmap_dir, f"cmap_{map_name}.png"), format="png", dpi=600, bbox_inches="tight")
    plt.show()