In [1]:
# Autoreload extension in mode 2: imported modules are reloaded before each cell runs
%load_ext autoreload
%autoreload 2
# WANDB_* environment variables (notebook name and silent mode) for the wandb client
%env WANDB_NOTEBOOK_NAME analysis.ipynb
%env WANDB_SILENT true
# Non-interactive Agg backend: figures are rendered off-screen rather than displayed
%matplotlib agg
# ipympl (presumably the interactive backend to swap in for agg; confirm before use)

from collections import defaultdict
from itertools import product
import os
import tempfile

import matplotlib as mpl
import matplotlib.collections as mpl_col
import matplotlib.gridspec as mpl_grid
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import mpl_toolkits.mplot3d as mp3d
import numpy as np
import seaborn as sns
import sklearn.metrics
import torch
from tqdm import tqdm
import wandb

import data
import inept

# Compute device: the first CUDA GPU when one is visible, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Folders, resolved against the current working directory
BASE_FOLDER = os.path.abspath('')
DATA_FOLDER, PLOT_FOLDER = (
    os.path.join(BASE_FOLDER, relative_path)
    for relative_path in ('../data', '../plots')
)

# Seaborn look: paper context with enlarged fonts, white style, husl palette
sns.set_context('paper', font_scale=1.25)
sns.set_style('white')
sns.set_palette('husl')

# Matplotlib: raise the size limit for animations embedded in the notebook
mpl.rcParams['animation.embed_limit'] = 100

# The notebook never trains, so switch autograd off globally
torch.set_grad_enabled(False);

env: WANDB_NOTEBOOK_NAME=analysis.ipynb
env: WANDB_SILENT=true


### Load Run, Data, Environment, and Policy

In [2]:
# Parameters
run_id = '2dt27jy2'
num_nodes = int(1e2)
seed_override = None

# Load run and regroup its flat 'section/key' config entries into one dict per section
api = wandb.Api()
run = api.run(f'oafish/INEPT/{run_id}')
config = defaultdict(dict)
for k, v in run.config.items():
    dict_name, key = k.split('/')
    config[dict_name][key] = v
config = dict(config)

# Reproducibility
seed = seed_override if seed_override is not None else config['note']['seed']
torch.manual_seed(seed)
# DEVICE is 'cuda:0' or 'cpu' (never the bare string 'cuda'), so match on the prefix
if DEVICE.startswith('cuda'): torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Load data, overriding the stored data settings with standardization and `num_nodes`
modalities, types, features = data.load_data(config['data']['dataset'], DATA_FOLDER)
modalities, types = inept.utilities.modify_data(
    modalities, types,
    **inept.utilities.overwrite_dict(config['data'], {'standardize': True, 'num_nodes': num_nodes}),
    device=DEVICE,
)

# Load env
env = inept.environments.trajectory(*modalities, **config['env'], device=DEVICE)

# Load policy from the checkpoint of the run's last recorded stage
fname = f'trained_models/policy_{run.summary["stage"]:02}.mdl'
# fname = f'models/policy_{run.summary["stage"]:02}.mdl'
with tempfile.TemporaryDirectory() as tmpdir:
    run.file(fname).download(tmpdir, replace=True)
    # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted runs.
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    policy = torch.load(os.path.join(tmpdir, fname), map_location=DEVICE).to(DEVICE).eval()
# Near-zero action standard deviation
policy.actor.set_action_std(1e-7)

### Generate Runs

In [3]:
# Empty container for recorded rollouts; later cells add one entry per experiment
memories = dict()

##### Integration

In [4]:
# Initialize: reset the environment and record the starting state with zero reward
env.reset()
memories['integration'] = defaultdict(list)
memories['integration']['states'].append(env.get_state())
memories['integration']['rewards'].append(torch.zeros(num_nodes, device=DEVICE))

# Simulate
for _ in tqdm(range(config['train']['max_ep_timesteps'])):
    # Step
    state = env.get_state(include_modalities=True)
    actions = policy.act_macro(
        state,
        # One key per simulated node. The data was loaded with the notebook-level
        # `num_nodes` override, so use it here too (as the initial reward vector does)
        # instead of the node count stored in the training config.
        keys=list(range(num_nodes)),
        max_batch=config['train']['max_batch'],
        max_nodes=config['train']['max_nodes'],
    )
    rewards, finished, itemized_rewards = env.step(actions, return_itemized_rewards=True)

    # Record
    memories['integration']['states'].append(env.get_state())
    memories['integration']['rewards'].append(rewards)

# Stack the per-step records along a new leading (time) dimension
memories['integration']['states'] = torch.stack(memories['integration']['states'])
memories['integration']['rewards'] = torch.stack(memories['integration']['rewards'])
# Convert to a plain dict so that missing keys raise instead of silently creating lists
memories['integration'] = dict(memories['integration'])

100%|██████████| 1000/1000 [00:06<00:00, 158.02it/s]


### Plot Memories

##### Integration

In [5]:
# Prepare data
states = memories['integration']['states'].cpu()
rewards = memories['integration']['rewards'].cpu()
type_labels = np.unique(types[0])

# Create plot; axes are placed in inches and converted to figure fractions
figsize = (17, 10)
fig_w, fig_h = figsize
fig = plt.figure(figsize=figsize)
ax1 = fig.add_axes([1 / fig_w, 1 / fig_h, 8 / fig_w, 8 / fig_h], projection='3d')
ax2 = fig.add_axes([12 / fig_w, 1 / fig_h, 4 / fig_w, 8 / fig_h])


# Initialize nodes
def get_node_data(frame):
    """Return the first three state columns (plotted as x, y, z) of every node at `frame`."""
    return states[frame, :, :3]


nodes = [
    ax1.plot(
        *get_node_data(0)[types[0]==l].T,
        label=l,
        linestyle='',
        marker='o',
        ms=6,
        zorder=3,
    )[0]
    for l in type_labels
]

# Initialize velocity arrows
arrow_length_scale = 1


def get_arrow_xyz_uvw(frame):
    """Return (arrow origins, arrow directions) at `frame`: columns :3 and env.dim:env.dim+3."""
    return (states[frame, :, :3], states[frame, :, env.dim:env.dim+3])


arrows = ax1.quiver(
    *get_arrow_xyz_uvw(0)[0].T,
    *get_arrow_xyz_uvw(0)[1].T,
    arrow_length_ratio=0,
    length=arrow_length_scale,
    lw=2,
    color='gray',
    alpha=.4,
    zorder=2,
)

# Initialize modal lines
total_lines = int((env.dist[0].shape[0]**2 - env.dist[0].shape[0]) / 2)  # Only considers first modality
# Clamp so sampling without replacement cannot fail for small node counts;
# replace 25 with total_lines to draw every pair
num_lines = min(25, total_lines)

# Pair indices only depend on the shape of a distance matrix, so build them once per shape
# instead of re-enumerating them on every frame
_pair_index_cache = {}


def _get_pair_index(dist):
    """Return a (P, 2) LongTensor of index pairs (j, k) with j < k over `dist.shape`, row-major."""
    shape = tuple(dist.shape)
    if shape not in _pair_index_cache:
        pairs = [[j, k] for j, k in product(*[range(s) for s in shape]) if j < k]
        _pair_index_cache[shape] = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2)
    return _pair_index_cache[shape]


def get_distance_discrepancy(frame):
    """
    For each matrix in `env.dist`, return a list with one float per pair (j, k), j < k:
    the Euclidean distance between the plotted positions of j and k minus `dist[j, k]`.
    Positive values mean the nodes are further apart in the plot than `dist` prescribes.
    """
    positions = states[frame, :, :3]
    discrepancies = []
    for dist in env.dist:
        pair_index = _get_pair_index(dist)
        first, second = pair_index[:, 0], pair_index[:, 1]
        plotted = (positions[first] - positions[second]).square().sum(dim=-1).sqrt()
        discrepancies.append((plotted - dist.cpu()[first, second]).tolist())
    return discrepancies


def get_modal_lines_segments(frame, dist):
    """Return a (P, 2, 3) array holding both endpoints of the segment for every pair of `dist`."""
    return np.array(states[frame, _get_pair_index(dist), :3])


def clip_dd_to_alpha(dd):
    """Map discrepancy magnitudes to opacities in [0, 1], saturating at a magnitude of 2."""
    return np.clip(np.abs(dd), 0, 2) / 2


line_selection = [
    # Randomly select lines to show
    np.random.choice(total_lines, num_lines, replace=False) for _ in env.dist
]
modal_lines = [
    mp3d.art3d.Line3DCollection(
        get_modal_lines_segments(0, dist)[line_selection[i]],
        label=f'Modality {i}',
        lw=2,
        zorder=1,
    )
    for i, dist in enumerate(env.dist)
]
for ml in modal_lines: ax1.add_collection(ml)


# Silhouette scoring
def get_silhouette_samples(frame):
    """Return the per-node silhouette coefficient of the plotted positions, labelled by types[0]."""
    return sklearn.metrics.silhouette_samples(states[frame, :, :3].cpu(), types[0])


# Score the first frame once and reuse it for every bar
initial_silhouette = get_silhouette_samples(0)
bars = [
    ax2.bar(
        l,
        initial_silhouette[types[0]==l].mean(),
    )
    for l in type_labels
]
ax2.axhline(y=0, color='black')

# Limits: fixed over the whole trajectory so the view does not jump between frames
ax1.set(
    xlim=(states[:, :, 0].min(), states[:, :, 0].max()),
    ylim=(states[:, :, 1].min(), states[:, :, 1].max()),
    zlim=(states[:, :, 2].min(), states[:, :, 2].max()),
)
ax2.set(ylim=(-1, 1))

# Legends; the first one is re-added as an artist because the second call replaces it
l1 = ax1.legend(handles=nodes, loc='upper left')
ax1.add_artist(l1)
l2 = ax1.legend(handles=[
    ax1.plot([], [], color='red', label='Repulsion')[0],
    ax1.plot([], [], color='blue', label='Attraction')[0],
], loc='upper right')
ax1.add_artist(l2)
ax2.spines[['right', 'top', 'bottom', 'left']].set_visible(False)

# Update function
def update(frame, nodes, arrows, modal_lines, bars):
    """
    Redraw every artist for animation frame `frame`.

    Parameters
    ----------
    frame : int
        Index into the recorded `states` / `rewards`.
    nodes : list
        One 3D line artist per unique label in `types[0]`, in `np.unique` order.
    arrows : Line3DCollection
        Velocity arrows returned by `ax1.quiver`.
    modal_lines : list
        One `Line3DCollection` per entry of `env.dist`.
    bars : list
        One bar container per unique label in `types[0]`, in `np.unique` order.

    Returns
    -------
    tuple
        `(nodes, arrows, modal_lines)`.
    """
    labels = np.unique(types[0])

    # Adjust nodes
    positions = get_node_data(frame)
    for node, l in zip(nodes, labels):
        xyz = positions[types[0]==l].T
        node.set_data(*xyz[:2])
        node.set_3d_properties(xyz[2])

    # Adjust arrows: each segment runs from the node to the node plus its scaled velocity
    xyz_xyz = [[xyz, xyz+arrow_length_scale*uvw] for xyz, uvw in zip(*get_arrow_xyz_uvw(frame))]
    arrows.set_segments(xyz_xyz)

    # Adjust lines; discrepancies cover every modality, so compute them once per frame
    discrepancies = get_distance_discrepancy(frame)
    for i, (dist, ml) in enumerate(zip(env.dist, modal_lines)):
        selection = line_selection[i]
        ml.set_segments(get_modal_lines_segments(frame, dist)[selection])
        # Restrict the discrepancies to the displayed pairs so that every colour
        # belongs to the segment it is applied to
        distance_discrepancy = np.asarray(discrepancies[i])[selection]
        # Blue when the pair is further apart than its target distance, red otherwise
        color = np.where(distance_discrepancy[:, None] > 0, (0., 0., 1.), (1., 0., 0.))
        alpha = np.expand_dims(clip_dd_to_alpha(distance_discrepancy), -1)
        ml.set_color(np.concatenate((color, alpha), axis=-1))

    # Barplots; the silhouette samples are computed once and shared with the title
    silhouette = get_silhouette_samples(frame)
    for bar, l in zip(bars, labels):
        bar[0].set_height(silhouette[types[0]==l].mean())

    # Styling
    ax1.set_title(f'{frame: 4} : {rewards[frame].mean():5.2f}')
    ax2.set_title(f'Silhouette Coefficient : {silhouette.mean():5.2f}')

    # CLI
    print(f'{frame} / {frames-1}', end='\r')
    if frame == frames-1: print()

    return nodes, arrows, modal_lines

# Compile animation
interval = 1e3*env.delta/3  # 3x speedup
min_max_vel = 1e-2  # Stop at first frame all vels are below target
# Largest node speed per frame; rows of `settled` are the frames where it is below target
max_speed = states[..., env.dim:env.dim+3].square().sum(dim=2).sqrt().max(dim=1).values
settled = torch.argwhere(max_speed < min_max_vel)
# Include the first settled frame itself; fall back to the full recording otherwise
frames = settled[0, 0].item()+1 if len(settled) > 0 else states.shape[0]
ani = animation.FuncAnimation(
    fig=fig,
    func=update,
    fargs=(nodes, arrows, modal_lines, bars),
    frames=frames,
    interval=interval,
)

# Display animation as it renders
# plt.show()

# Display complete animation
# from IPython.display import HTML
# HTML(ani.to_jshtml())

# Save animation
file_type = 'mp4'
# Round instead of truncating: floating-point error can turn a nominal 30 fps into 29.999...
fps = max(1, int(round(float(1e3/interval))))
if file_type == 'mp4': writer = animation.FFMpegWriter(fps=fps, extra_args=['-vcodec', 'libx264'], bitrate=8e3)  # Faster
elif file_type == 'gif': writer = animation.FFMpegWriter(fps=fps)  # Slower
else: raise ValueError(f'Unsupported file_type {file_type!r}, expected \'mp4\' or \'gif\'')
# Make sure the output folder exists before the writer tries to create the file
os.makedirs(PLOT_FOLDER, exist_ok=True)
ani.save(os.path.join(PLOT_FOLDER, f'{config["data"]["dataset"]}_integration.{file_type}'), writer=writer, dpi=300)

204 / 204
0 / 204