In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell is executed.
# (Comments are kept on their own lines: trailing text would be passed to the magic.)
%load_ext autoreload
%autoreload 2

In [None]:
import sys

import gymnasium
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from cycler import cycler
from gymnasium.utils import seeding
from hydra import compose, initialize
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from tqdm import tqdm

import __init__
from scripts.train_rl import setup_environments
from src.data import load
from src.data.loading import ConstantRandomSampler
from src.environments.utils import antialias
from src.evaluation.utils import mm2in
from src.metrics.transforms import AffineTransform
from src.models.sae import assemble_sae
from src.utils import Bunch, deflate, get_display, gl, inflate, print_cfg

sys.modules['gym'] = gymnasium  # see [PR](https://github.com/DLR-RM/stable-baselines3/pull/780)
from stable_baselines3 import SAC

---

#### Plotting setup

In [None]:
# Global Matplotlib style for the paper figures: serif LaTeX text, a fixed
# ten-colour cycle, no top/right spines, tight paddings and constrained layout
# without extra spacing.
plt.rcParams.update({
    # colour cycle (NOTE(review): values look like seaborn's "colorblind" palette — confirm)
    'axes.prop_cycle': cycler('color', ["#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC",
                                        "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9"]),
    # axes
    'axes.titlepad': 4.0,
    'axes.xmargin': 0.025,
    'axes.ymargin': 0.025,
    'axes.titlesize': 'medium',
    'axes.labelpad': 1.0,
    'axes.spines.right': False,
    'axes.spines.top': False,
    # fonts / LaTeX rendering
    'font.family': 'serif',
    'font.size': 8,
    'text.usetex': True,
    # Must be a single string: passing a list of strings was deprecated in
    # Matplotlib 3.3 and is rejected by later releases.
    'text.latex.preamble': r'\usepackage{lmodern}',
    # grid
    'grid.alpha': 0.1,
    'grid.color': '#000000',
    # legend
    'legend.borderaxespad': 0.25,
    'legend.borderpad': 0.0,
    'legend.frameon': False,
    'legend.columnspacing': 1.0,
    'legend.handletextpad': 0.5,
    'legend.handlelength': 1.0,
    # lines
    'lines.solid_capstyle': 'round',
    'lines.solid_joinstyle': 'round',
    # ticks (minor ticks hidden via zero size)
    'xtick.major.pad': 2.0,
    'xtick.major.size': 2.0,
    'xtick.minor.size': 0.0,
    'ytick.major.pad': 2.0,
    'ytick.major.size': 2.0,
    'ytick.minor.size': 0.0,
    # constrained layout without any padding between or around subplots
    'figure.constrained_layout.h_pad': 0.0,
    'figure.constrained_layout.hspace': 0.0,
    'figure.constrained_layout.use': True,
    'figure.constrained_layout.w_pad': 0.0,
    'figure.constrained_layout.wspace': 0.0
})

---

In [None]:
# Hydra overrides for the RL configuration. The SAE checkpoint, name and
# experiment point at the pre-trained SAE run; W&B logging is switched off.
rl_overrides = [
    '+experiment=rl-feat',
    'training.observation.keypoints=True',
    'training.sae_checkpoint=logs/sae/panda_push_custom/basic+basic/2023-02-12--21-53-02--32879987/checkpoint_final.pth',
    'training.sae_name=amber-thunder-15',
    'training.sae_experiment=sae-basic',
    'wandb=off',
    'hydra=hush',
]

# Hydra overrides for the SAE configuration (same experiment as referenced above).
sae_overrides = [
    '+experiment=sae-basic',
    'wandb=off',
    'hydra=hush',
]

# Each config is composed inside its own Hydra initialisation context.
with initialize(version_base=None, config_path='../configs'):
    rl_cfg = compose(config_name='train_rl', overrides=rl_overrides)

with initialize(version_base=None, config_path='../configs'):
    sae_cfg = compose(config_name='train_sae', overrides=sae_overrides)

---

In [None]:
# Read the stored SAE checkpoint, mapping its tensors onto `gl.device`, and
# wrap the resulting dict in a Bunch for attribute-style access.
sae_checkpoint_path = '../logs/sae/panda_push_custom/basic+basic/2023-02-12--21-53-02--32879987/checkpoint_final.pth'
checkpoint_dict = torch.load(sae_checkpoint_path, map_location=gl.device)
checkpoint = Bunch(**checkpoint_dict)

# Rebuild the model from its config and restore the trained weights.
model = assemble_sae(sae_cfg)
model.load_state_dict(checkpoint.model_state_dict)

In [None]:
# Load the validation dataset; `load` returns a one-element sequence here,
# which is unpacked directly.
(dataset_valid,) = load(sae_cfg, valid=True)

# Sampler constructed from the dataset and the configured dataset seed
# (presumably yields a fixed pseudo-random order — verify in src.data.loading).
valid_sampler = ConstantRandomSampler(dataset_valid, sae_cfg.dataset.seed)

loader_valid = DataLoader(
    dataset_valid,
    batch_size=sae_cfg.training.batch_size,
    sampler=valid_sampler,
    shuffle=False,  # ordering is delegated to the sampler
    drop_last=True,
    num_workers=8,
    pin_memory=False,
)

In [None]:
# Run the encoder over the whole validation set and collect, per snippet, the
# feature points and site coordinates that are later used for the tracking
# error computation.
model.eval()  # put model into evaluation state

fps_batches = []  # feature points of the first image of each snippet
kps_batches = []  # site coordinates of the first image of each snippet

with torch.no_grad():
    for inputs, _, sites in tqdm(loader_valid, leave=False):
        # only the images are needed on the compute device
        inputs = inputs.to(gl.device)

        # encoder pass; deflate/inflate wrap the call so that the result is
        # indexable per snippet again (NOTE(review): exact reshaping lives in src.utils)
        feature_points = inflate(model.encoder(deflate(inputs)), len(inputs))

        # keep only the first image of each snippet (avoiding duplicates)
        fps_batches.append(feature_points[:, 0])
        kps_batches.append(sites[:, 0])

# concatenate all batches and bring the results to host memory
track_fps = torch.cat(fps_batches).cpu()
track_kps = torch.cat(kps_batches).cpu()

In [None]:
# For every (site, feature point) pair fit an AffineTransform from the feature
# point to the site, record its MSE, and remember for each site the feature
# point (and fitted transform) with the smallest error.
n_sites = track_kps.shape[1]
n_fps = track_fps.shape[1]

# the error table starts at +inf so any finite error beats an unfilled entry
pairwise_errors = torch.full((n_sites, n_fps), np.inf)
regrs = [None for _ in range(n_sites)]
closest_fps = [None for _ in range(n_sites)]

for site_idx in range(n_sites):
    site_coords = track_kps[:, site_idx]
    for fp_idx in range(n_fps):
        fp_coords = track_fps[:, fp_idx]

        transform = AffineTransform()
        transform.fit(fp_coords, site_coords)  # fit transformation
        mse = transform.mse(fp_coords, site_coords)

        # strictly better than every error stored for this site so far
        # (checked before the new error is written into the table)
        is_new_best = torch.all(pairwise_errors[site_idx] > mse)
        if is_new_best:
            regrs[site_idx] = transform
            closest_fps[site_idx] = fp_idx

        pairwise_errors[site_idx, fp_idx] = mse

---

In [None]:
# Load the trained SAC policy from its saved archive.
policy_path = '../logs/rl/PandaPush-custom/2023-03-08--04-05-00--33793974/final_model.zip'
policy = SAC.load(policy_path)

In [None]:
# Build the environment(s) for the rollout from the RL config.
# NOTE(review): the second argument is presumably the number of parallel
# environments (the rollout below indexes a single env) — confirm in scripts.train_rl.
venv = setup_environments(rl_cfg, 1)

In [None]:
# Roll out the policy until the first episode termination, recording one
# rendered frame and one observation per step.
images, observations = [], []

_ = venv.reset()  # the initial observation is discarded
_ = venv.seed(13)  # 5, 7, 11, 13, 28
action = np.array([[0, 0, 0]])  # zero action for the very first step

done = False
while not done:
    obs, _, dones, info = venv.step(action)
    done = bool(np.any(dones))

    # on termination, store the terminal observation reported in `info`
    # instead of `obs`; otherwise store a (shallow) copy of `obs`
    observations.append(info[0]['terminal_observation'] if done else obs.copy())
    images.append(venv.render())

    # the policy is queried without the 'keypoints' entry; the stored copy
    # above is unaffected by this pop
    obs.pop('keypoints')
    action, _ = policy.predict(obs, deterministic=True)

In [None]:
# Figure: trajectories of selected keypoints (white), their closest feature
# points (C3) and the affinely transformed feature points (C0), drawn on top
# of the first rendered frame.
n_steps = 12    # number of leading rollout steps that form the trajectory
px_scale = 128  # (x + 1) * 128 maps coordinates to pixels (assumes [-1, 1] range and a 256 px image — TODO confirm)

fig, ax = plt.subplots()

first_frame = torch.tensor(images[0] / 255.0)
ax.imshow(antialias(first_frame, 2), interpolation='none')

# stack per-step observations along axis 1 -> (point, step, coordinate)
kps_steps = [np.squeeze(observations[t]['keypoints']) for t in range(n_steps)]
fps_steps = [np.squeeze(observations[t]['feature_points']) for t in range(n_steps)]
kps = (np.stack(kps_steps, axis=1) + 1) * px_scale
fps_orig = np.stack(fps_steps, axis=1)
fps = (fps_orig + 1) * px_scale

keypoints = [4]  # [0, 1, 4]
matched_fps = [closest_fps[k] for k in keypoints]
selected_regrs = [regrs[k] for k in keypoints]

# ground-truth keypoint trajectories
for k in keypoints:
    ax.plot(kps[k, :, 0], kps[k, :, 1], color='w', marker='.')

# raw feature-point trajectories (note the swapped coordinate order)
for f in matched_fps:
    ax.plot(fps[f, :, 1], fps[f, :, 0], color='C3', marker='.')

# feature-point trajectories after applying the fitted transform
for regr_k, f in zip(selected_regrs, matched_fps):
    fps_t = (regr_k.transform(torch.tensor(fps_orig[f])) + 1) * px_scale
    ax.plot(fps_t[:, 1], fps_t[:, 0], color='C0', marker='.')

ax.axis('off')

fig.set_size_inches(mm2in(122 * 0.49, 122 * 0.49))
fig.savefig('../local/paper/img_basic_trajectories.pdf')

In [None]:
# Figure: positions at the first rollout step of selected keypoints (white),
# their closest feature points (C3) and the affinely transformed feature
# points (C0), on top of the first rendered frame.
# Relies on `kps`, `fps` and `fps_orig` computed in the trajectory-figure cell.
px_scale = 128  # (x + 1) * 128 maps coordinates to pixels (assumes [-1, 1] range and a 256 px image — TODO confirm)

fig, ax = plt.subplots()

first_frame = torch.tensor(images[0] / 255.0)
ax.imshow(antialias(first_frame, 2), interpolation='none')

keypoints = [0, 1, 4]
matched_fps = [closest_fps[k] for k in keypoints]
selected_regrs = [regrs[k] for k in keypoints]

# ground-truth keypoints at step 0
for k in keypoints:
    ax.plot(kps[k, 0, 0], kps[k, 0, 1], color='w', marker='o')

# raw feature points at step 0 (note the swapped coordinate order)
for f in matched_fps:
    ax.plot(fps[f, 0, 1], fps[f, 0, 0], color='C3', marker='o')

# transformed feature points at step 0
for regr_k, f in zip(selected_regrs, matched_fps):
    fps_t = (regr_k.transform(torch.tensor(fps_orig[f])) + 1) * px_scale
    ax.plot(fps_t[0, 1], fps_t[0, 0], color='C0', marker='o')

ax.axis('off')

fig.set_size_inches(mm2in(122 * 0.49, 122 * 0.49))
fig.savefig('../local/paper/img_basic_points.pdf')