In [1]:
%load_ext autoreload
%autoreload 2
%env WANDB_NOTEBOOK_NAME analysis.ipynb
%env WANDB_SILENT true
%matplotlib agg
# ipympl

from collections import defaultdict
from itertools import product
import re
import os
import tempfile

import matplotlib as mpl
import matplotlib.collections as mpl_col
import matplotlib.gridspec as mpl_grid
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import mpl_toolkits.mplot3d as mp3d
import numpy as np
import seaborn as sns
import sklearn.metrics
import torch
from tqdm import tqdm
import wandb

import data
import inept

# Runtime parameters: use the first CUDA device when one is visible, else the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
# `abspath('')` resolves to the current working directory; data and plots
# live one level above it
BASE_FOLDER = os.path.abspath('')
DATA_FOLDER = os.path.join(BASE_FOLDER, '../data')
PLOT_FOLDER = os.path.join(BASE_FOLDER, '../plots')

# Seaborn look-and-feel shared by every figure in the notebook
sns.set_context('paper', font_scale=1.25)
sns.set_style('white')
sns.set_palette('husl')

# Raise matplotlib's cap on the size of embedded animations
mpl.rcParams.update({'animation.embed_limit': 100})

# This notebook only runs inference, so autograd is switched off globally
# (the trailing `;` suppresses the cell's output)
torch.set_grad_enabled(False);

env: WANDB_NOTEBOOK_NAME=analysis.ipynb
env: WANDB_SILENT=true


- HIGH PRIORITY
  - Add more accuracy metrics
  - Perturbation analysis with inverse transform
  - Add 2D functionality
  - Add optional UMAP

- LOW PRIORITY
  - Switch from matplotlib (`mpl`) to `mayavi` for true 3D rendering and correct depth layering

### Load All Classes

In [2]:
# Run selection and overrides
candidate_runs = (
    'k52g4dx3',  # Random 100x
    '2dt27jy2',  # No random 20k
)
run_id = candidate_runs[1]
stage_override = None  # Manually override policy stage selection
num_nodes = int(1e2)
seed_override = None  # 43

# Fetch the run and regroup its flat 'section/option' config keys into
# one nested dict per section
api = wandb.Api()
run = api.run(f'oafish/INEPT/{run_id}')
config = {}
for flat_key, value in run.config.items():
    section, option = flat_key.split('/')
    config.setdefault(section, {})[option] = value

# Seed every RNG used below; the run's recorded seed applies unless overridden
if seed_override is None:
    seed = config['note']['seed']
else:
    seed = seed_override
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Load the dataset, then preprocess with the run's data config
# (standardization forced on, node count taken from `num_nodes`)
modalities, types, features = data.load_data(config['data']['dataset'], DATA_FOLDER)
preprocessing_config = inept.utilities.overwrite_dict(
    config['data'],
    {'standardize': True, 'num_nodes': num_nodes},
)
ppc = inept.utilities.Preprocessing(**preprocessing_config, device=DEVICE)
modalities = ppc.fit_transform(modalities)
modalities, types = ppc.subsample(modalities, types)
modalities = ppc.cast(modalities)

# Build the environment from the preprocessed modalities
env = inept.environments.trajectory(*modalities, **config['env'], device=DEVICE)

# Get latest policy
# Each entry is [stage, wandb file]; stage -1 / file None means "nothing found"
latest_mdl = [-1, None]  # Pkl
latest_wgt = [-1, None]  # State dict
for file in run.files():
    # Find policy files. Raw string so `\d` / `\.` are regex escapes; only
    # numeric stages are accepted so that `int()` below cannot fail.
    matches = re.findall(r'^(?:models|trained_models)/policy_(\d+)\.(mdl|wgt)$', file.name)
    if len(matches) == 0: continue
    stage = int(matches[0][0]); ftype = matches[0][1]

    # Record: newest stage by default, or exactly `stage_override` when given
    latest_known_stage = latest_mdl[0] if ftype == 'mdl' else latest_wgt[0]
    if (stage_override is None and stage > latest_known_stage) or (stage_override is not None and stage == stage_override):
        if ftype == 'mdl': latest_mdl = [stage, file]
        elif ftype == 'wgt': latest_wgt = [stage, file]

# Select the file that will actually be loaded
load_type = 'wgt'
if load_type == 'mdl': latest_stage, latest_file = latest_mdl
elif load_type == 'wgt': latest_stage, latest_file = latest_wgt
else: raise ValueError(f'Unknown load_type {load_type!r}, expected "mdl" or "wgt"')
if latest_file is None:
    wanted_stage = 'any stage' if stage_override is None else f'stage {stage_override}'
    raise FileNotFoundError(f'No policy "{load_type}" file found for run {run_id} ({wanted_stage})')
# Report the stage of the file being loaded (not always the `mdl` one)
print(f'Policy found at stage {latest_stage}')

# Load file
with tempfile.TemporaryDirectory() as tmpdir:
    latest_file.download(tmpdir, replace=True)
    checkpoint_path = os.path.join(tmpdir, latest_file.name)
    if load_type == 'mdl':
        # SECURITY: this unpickles a whole Python object and can execute
        # arbitrary code; only use it with runs you trust.
        policy = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    elif load_type == 'wgt':
        # Mainly used in the case of old argument names, also more secure
        # config_to_use = config['policy']
        # NOTE(review): the overridden dims are hard-coded for the selected run - confirm when switching runs
        config_to_use = inept.utilities.overwrite_dict(config['policy'], {'positional_dim': 6, 'modal_dims': [76]})
        policy = inept.models.PPO(**config_to_use)
        # `map_location` lets checkpoints saved on GPU load on CPU-only hosts
        incompatible_keys = policy.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE, weights_only=True))
policy = policy.to(DEVICE).eval()
# Near-zero action std for evaluation
policy.actor.set_action_std(1e-7)

Policy found at stage 14


### Generate Runs

In [3]:
# Container for recorded simulations, keyed by experiment name
memories = dict()

##### Integration

In [4]:
def get_present(timestep, present, labels, *args, deployment, state=None):
    """
    Stochastically spawn not-yet-present nodes according to a deployment schedule.

    Parameters
    ----------
    timestep : number
        Current timestep; a schedule row is active once `timestep >= delay`.
    present : torch.Tensor
        Per-node boolean presence mask. It is cloned, never modified in place.
    labels : array-like
        Per-node labels, compared against the schedule's labels and origins.
    *args
        Ignored.
    deployment : dict
        Mapping whose values, in insertion order, are parallel sequences of
        (label, delay, rate, origin). `rate` is the per-call spawn probability
        of each absent node with that label. `origin` is either None or the
        label whose present nodes act as spawn locations.
    state : torch.Tensor, optional
        Per-node state rows. When given, it is cloned and a spawned node with
        an origin copies the row of a randomly chosen present origin node.

    Returns
    -------
    `inept.utilities.clean_return` applied to `(present,)`, or to
    `(present, state)` when `state` was provided.

    Notes
    -----
    * If an origin is configured but no node of that label is present yet,
      the spawn is deferred (the node stays absent and is rolled again on a
      later call) instead of raising from `np.random.choice` on an empty set.
    * If `state` is None, origins are ignored and nodes simply become present.
    * Random draws happen in the same order as before for all other inputs:
      one `np.random.rand()` per candidate, then one `np.random.choice` per
      spawn with an origin.
    """
    # Copy status
    present = present.clone()
    if state is not None: state = state.clone()
    labels = np.asarray(labels)
    # Host-side boolean mirror of `present`, kept in sync below so the mask
    # is not copied off the device for every spawned node
    present_np = present.cpu().numpy().astype(bool)

    # Iterate over each label
    for label, delay, rate, origin in zip(*deployment.values()):
        # Skip rows whose delay has not been reached
        if timestep < delay: continue

        # Absent nodes with a matching label. Spawns in this pass only flip
        # the node being visited, so the candidate list can be fixed up front.
        for i in map(int, np.flatnonzero((labels == label) & ~present_np)):
            # Roll for appearance
            if not (np.random.rand() < rate): continue

            # Set origin
            if origin is not None and state is not None:
                parents = np.flatnonzero((labels == origin) & present_np)
                if len(parents) == 0: continue  # Nowhere to spawn from yet
                state[i] = state[np.random.choice(parents)]

            # Mark as present
            present[i] = True
            present_np[i] = True

    # Return nicely
    ret = (present,)
    if state is not None: ret += (state,)
    return inept.utilities.clean_return(ret)

In [5]:
# Variable deployment times
# `get_present` zips the dict's values in insertion order as
# (label, delay, rate, origin), so keep the four lists in this order and
# aligned by index. Set to `None` to have every node present from the start.
deployment = {
    'labels': ['Lwm', 'L6', 'L5', 'L4', 'L2/3'],
    'delay': 50*np.arange(5),
    'rates': [1, .015, .015, .015, .015],
    'origins': [None, 'Lwm', 'L6', 'L5', 'L4'],
}

# Initialize
env.reset(); memories['discovery'] = defaultdict(list)
if deployment is not None:
    present = torch.zeros(num_nodes, dtype=bool, device=DEVICE)
    present, state = get_present(0, present, types[0], deployment=deployment, state=env.get_state())
    env.set_state(state)
else:
    # Same device as in the branch above so masks and states can be stacked and indexed together
    present = torch.ones(num_nodes, dtype=bool, device=DEVICE)
memories['discovery']['present'].append(present)
memories['discovery']['states'].append(env.get_state())
memories['discovery']['rewards'].append(torch.zeros(num_nodes, device=DEVICE))

# Simulate
for timestep in tqdm(range(config['train']['max_ep_timesteps'])):
    # Step: only present nodes get a policy action, the rest keep a zero action
    state = env.get_state(include_modalities=True)
    actions = torch.zeros((num_nodes, env.dim), device=DEVICE)
    actions[present] = policy.act_macro(
        state[present],
        keys=torch.arange(num_nodes, device=DEVICE)[present],
        max_batch=config['train']['max_batch'],
        max_nodes=config['train']['max_nodes'],
    )
    # TODO: Currently, rewards factor in non-present nodes
    rewards, _, _ = env.step(actions, return_itemized_rewards=True)
    new_state = env.get_state()
    # Restore the pre-step values (first 2*env.dim state columns) of absent nodes
    new_state[~present] = state[~present, :2*env.dim]  # Don't move un-spawned nodes
    env.set_state(new_state)

    # Record
    if deployment is not None:
        present, state = get_present(timestep+1, present, types[0], deployment=deployment, state=env.get_state())
        env.set_state(state)
    memories['discovery']['present'].append(present)
    memories['discovery']['states'].append(env.get_state())
    memories['discovery']['rewards'].append(rewards)

# Stack each per-timestep list into a single tensor with time as the first axis
for memory_key in ('present', 'states', 'rewards'):
    memories['discovery'][memory_key] = torch.stack(memories['discovery'][memory_key])
memories['discovery'] = dict(memories['discovery'])

100%|██████████| 1000/1000 [00:07<00:00, 137.40it/s]


### Plot Memories

##### Integration

In [6]:
# Prepare data (on CPU, time along the first axis and nodes along the second)
present = memories['discovery']['present'].cpu()
states = memories['discovery']['states'].cpu()
rewards = memories['discovery']['rewards'].cpu()

# Parameters
interval = 1e3*env.delta/3  # Time between frames (3x speedup)
min_max_vel = 1e-2  # Stop at first frame all vels are below target
frame_override = None  # Manually enter number of frames to draw
num_lines = 25  # Number of attraction and repulsion lines
rotations_per_second = .1  # Camera azimuthal rotations per second

# Create plot
# Axes rectangles are written in inches and divided by the figure size to
# obtain the figure-fraction coordinates `add_axes` expects
figsize = (17, 10)
fig = plt.figure(figsize=figsize)
ax1 = fig.add_axes([1 /figsize[0], 1 /figsize[1], 8 /figsize[0], 8 /figsize[1]], projection='3d')
ax2 = fig.add_axes([12 /figsize[0], 1 /figsize[1], 4 /figsize[0], 8 /figsize[1]])

# Initialize nodes (one empty marker-only line per label, filled in by `update`)
def get_node_data(frame):
    """Return the first three state columns (plotted as x, y, z) of every node at `frame`."""
    return states[frame, :, :3]

node_labels = np.unique(types[0])
nodes = [
    ax1.plot(
        [], [],
        label=label,
        linestyle='',
        marker='o',
        ms=6,
        zorder=2.3,
    )[0]
    for label in node_labels
]

# Initialize velocity arrows
arrow_length_scale = 1

def get_arrow_xyz_uvw(frame):
    """Return (start points, directions) of the arrows: state columns `:3` and `env.dim:env.dim+3`."""
    return states[frame, :, :3], states[frame, :, env.dim:env.dim+3]

arrows = ax1.quiver(
    [], [], [],
    [], [], [],
    arrow_length_ratio=0,
    length=arrow_length_scale,
    lw=2,
    color='gray',
    alpha=.4,
    zorder=2.2,
)

# Initialize modal lines
_pair_cache = {}

def get_pair_indices(shape):
    """
    Return every [j, k] index pair with j < k for a matrix of `shape`, in
    row-major order. Results are cached per shape because the helpers below
    need the same list on every frame.
    """
    shape = tuple(shape)
    if shape not in _pair_cache:
        _pair_cache[shape] = [[j, k] for j, k in product(*[range(s) for s in shape]) if j < k]
    return _pair_cache[shape]

def get_distance_discrepancy(frame):
    """
    For each matrix in `env.dist`, return an array with one entry per node
    pair (j < k): the Euclidean distance between the two nodes' first three
    state columns at `frame`, minus `dist[j, k]`.
    """
    return [
        np.array([
            ((states[frame, j, :3] - states[frame, k, :3]).square().sum().sqrt() - dist[j, k].cpu()).item()
            for j, k in get_pair_indices(dist.shape)
        ])
        for dist in env.dist
    ]

def get_modal_lines_segments(frame, dist):
    """Return a (pairs, 2, 3) array holding both endpoints of every node pair (j < k) at `frame`."""
    return np.array(states[frame, get_pair_indices(dist.shape), :3])

def clip_dd_to_alpha(dd):
    """Map a distance discrepancy to an opacity in [0, 1]; |dd| saturates at 2."""
    return np.clip(np.abs(dd), 0, 2) / 2

# Randomly select lines to show
line_indices = get_pair_indices(env.dist[0].shape)
total_lines = len(line_indices)  # Only considers first modality
line_selection = [
    # Never request more lines than exist (sampling is without replacement)
    np.random.choice(total_lines, min(num_lines, total_lines), replace=False) if num_lines is not None else list(range(total_lines))
    for _ in env.dist
]
modal_lines = [
    mp3d.art3d.Line3DCollection(
        get_modal_lines_segments(0, dist)[line_selection[i]],
        label=f'Modality {i}',
        lw=2,
        zorder=2.1,
    )
    for i, dist in enumerate(env.dist)
]
for ml in modal_lines: ax1.add_collection(ml)

# Silhouette scoring
def get_silhouette_samples(frame):
    """Return the per-node silhouette coefficients of the plotted coordinates at `frame`, grouped by `types[0]`."""
    return sklearn.metrics.silhouette_samples(states[frame, :, :3].cpu(), types[0])

bars = [ax2.bar(label, 0) for label in node_labels]
ax2.axhline(y=0, color='black')

# Limits
# `present` is a (frame, node) boolean mask, so `states[present]` keeps one
# row per (frame, node) combination at which the node was present
visible_xyz = states[present][:, :3]
ax1.set(
    xlim=(visible_xyz[:, 0].min(), visible_xyz[:, 0].max()),
    ylim=(visible_xyz[:, 1].min(), visible_xyz[:, 1].max()),
    zlim=(visible_xyz[:, 2].min(), visible_xyz[:, 2].max()),
)
ax2.set(ylim=(-1, 1))

# Legends (the first one must be re-added as an artist, otherwise the second replaces it)
l1 = ax1.legend(handles=nodes, loc='upper left')
ax1.add_artist(l1)
l2 = ax1.legend(handles=[
    ax1.plot([], [], color='red', label='Repulsion')[0],
    ax1.plot([], [], color='blue', label='Attraction')[0],
], loc='upper right')
ax1.add_artist(l2)
ax2.spines[['right', 'top', 'bottom', 'left']].set_visible(False)

# Styling
ax1.set(xlabel='x', ylabel='y', zlabel='z')

def get_angle(frame):
    """Return the `view_init` arguments at `frame`: fixed 30 degree elevation, azimuth advancing with animation time from -60 degrees, no roll."""
    return (30, (360*rotations_per_second)*(frame*interval/1000)-60, 0)

ax1.view_init(*get_angle(0))
# Update function
def update(frame, nodes, arrows, modal_lines, bars):
    """
    Redraw every artist of the discovery animation for one frame.

    Parameters
    ----------
    frame : int
        Index into the first axis of `states`, `present` and `rewards`.
    nodes : list
        One 3D line artist per unique label of `types[0]`, in `np.unique` order.
    arrows : Line3DCollection
        Arrow artist; one segment is drawn per present node.
    modal_lines : list
        One line collection per matrix in `env.dist`.
    bars : list
        One bar container per unique label, showing its mean silhouette score.

    Returns
    -------
    tuple
        `(nodes, arrows, modal_lines)`.

    Notes
    -----
    Only the pairs picked in `line_selection` are evaluated, looked up through
    `line_indices`. This assumes every matrix in `env.dist` indexes the same
    nodes as `env.dist[0]`, which the presence check on the selected lines
    already relied on.
    """
    unique_labels = np.unique(types[0])
    present_now = present[frame]

    # Adjust nodes
    positions = get_node_data(frame)
    for node, label in zip(nodes, unique_labels):
        visible = present_now * torch.tensor(types[0] == label)
        xyz = positions[visible].T
        node.set_data(*xyz[:2])
        node.set_3d_properties(xyz[2])

    # Adjust arrows (present nodes only)
    arrow_xyz, arrow_uvw = get_arrow_xyz_uvw(frame)
    xyz_xyz = [
        [xyz, xyz+arrow_length_scale*uvw]
        for xyz, uvw, shown in zip(arrow_xyz, arrow_uvw, present_now)
        if shown
    ]
    arrows.set_segments(xyz_xyz)

    # Adjust lines
    for i, (dist, ml) in enumerate(zip(env.dist, modal_lines)):
        selected_pairs = [line_indices[line_index] for line_index in line_selection[i]]
        ml.set_segments(np.array(states[frame, selected_pairs, :3]))

        # Current distance minus target distance for each shown pair
        distance_discrepancy = np.array([
            ((states[frame, j, :3] - states[frame, k, :3]).square().sum().sqrt() - dist[j, k].cpu()).item()
            for j, k in selected_pairs
        ])

        # Blue when the pair is farther apart than its target, red otherwise;
        # opacity grows with the size of the discrepancy
        color = np.array(
            [(0., 0., 1.) if dd > 0 else (1., 0., 0.) for dd in distance_discrepancy],
            dtype=float,
        ).reshape(-1, 3)
        alpha = np.expand_dims(clip_dd_to_alpha(distance_discrepancy), -1)

        # Hide lines touching a node that is not present yet
        hidden = np.array([not present_now[pair].all() for pair in selected_pairs], dtype=bool)
        alpha[hidden] = 0.
        ml.set_color(np.concatenate((color, alpha), axis=-1))

    # Barplots (silhouette samples are computed once per frame and reused)
    silhouette = get_silhouette_samples(frame)
    for bar, label in zip(bars, unique_labels):
        bar[0].set_height(silhouette[types[0] == label].mean())

    # Styling
    ax1.set_title(f'{frame: 4} : {rewards[frame].mean():5.2f}')
    ax2.set_title(f'Silhouette Coefficient : {silhouette.mean():5.2f}')
    ax1.view_init(*get_angle(frame))

    # CLI
    print(f'{frame} / {frames-1}', end='\r')
    if frame == frames-1: print()

    return nodes, arrows, modal_lines

# Compile animation
# Frames at which the largest per-node speed (norm of state columns
# `env.dim:env.dim+3`) is below `min_max_vel`
low_motion = (states[..., env.dim:env.dim+3].square().sum(dim=2).sqrt().max(dim=1).values < min_max_vel).numpy()
# Disregard interrupted sections of low movement
# NOTE(review): this marks frame i when it OR the next frame is slow, which
# starts the slow section one frame early; an AND would require two
# consecutive slow frames. Behaviour kept as-is - confirm the intent.
low_motion = np.concatenate((low_motion[:-1] | low_motion[1:], low_motion[-1:]))
slow_frames = np.flatnonzero(low_motion)
# Draw up to and including the first slow frame, or everything if there is none
frames = int(slow_frames[0])+1 if len(slow_frames) > 0 else states.shape[0]
frames = frames if frame_override is None else frame_override
ani = animation.FuncAnimation(
    fig=fig,
    func=update,
    fargs=(nodes, arrows, modal_lines, bars),
    frames=frames,
    interval=interval,
)

# Display animation as it renders
# plt.show()

# Display complete animation
# from IPython.display import HTML
# HTML(ani.to_jshtml())

# Save animation
file_type = 'mp4'
if file_type == 'mp4': writer = animation.FFMpegWriter(fps=int(1e3/interval), extra_args=['-vcodec', 'libx264'], bitrate=8e3)  # Faster
elif file_type == 'gif': writer = animation.FFMpegWriter(fps=int(1e3/interval))  # Slower
else: raise ValueError(f'Unsupported file type {file_type!r}, expected "mp4" or "gif"')
# The plot folder may not exist on a fresh checkout
os.makedirs(PLOT_FOLDER, exist_ok=True)
ani.save(os.path.join(PLOT_FOLDER, f'{config["data"]["dataset"]}_discovery.{file_type}'), writer=writer, dpi=300)

560 / 560
0 / 560