In [None]:
import os, sys
sys.path.append('..')

import io
import torch
import numpy as np
import numpy.random as npr
import hydra
import dill

import plotly.io as pio
pio.renderers.default = 'jupyterlab'
from scipy.interpolate import griddata
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import imageio

from irrep_actions.gym_util.multistep_wrapper import MultiStepWrapper
from irrep_actions.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
from irrep_actions.workflow.base_workflow import BaseWorkflow
from irrep_actions.utils import mcmc

In [None]:
# Checkpoint selection: exactly one `checkpoint = ...` line should be active.
# CH SO2 Lmax 5
# checkpoint = '../data/outputs/2024.01.30/16.43.31_train_so2_harmonic_implicit_policy_pusht_lowdim/checkpoints/latest.ckpt'

# CH SO2 Lmax 3
checkpoint = '../data/outputs/2024.01.30/06.16.01_train_so2_harmonic_implicit_policy_pusht_lowdim/checkpoints/epoch=0400-test_mean_score=0.975.ckpt'

# CH skip
# checkpoint = '../data/outputs/2024.01.30/19.08.21_train_so2_harmonic_implicit_policy_pusht_lowdim/checkpoints/latest.ckpt'

# NOTE: torch.load with a pickle module executes arbitrary code embedded in the
# file; only load checkpoints from a trusted source.
# The context manager closes the file handle once deserialisation is done, and
# map_location='cpu' keeps the load independent of whichever GPU the checkpoint
# was saved from (the policy is moved to `device` explicitly in a later cell).
with open(checkpoint, 'rb') as ckpt_file:
    payload = torch.load(ckpt_file, pickle_module=dill, map_location='cpu')

# Prefer GPU index 2 (the original setting); fall back to the first GPU, then
# to the CPU, so the notebook still runs on hosts with fewer devices.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

In [None]:
# Rebuild the training workflow from the configuration stored in the checkpoint,
# restore its saved state, and put the model on `device` in evaluation mode.
cfg = payload['config']
cls = hydra.utils.get_class(cfg._target_)

workflow: BaseWorkflow = cls(cfg)
workflow.load_payload(
    payload,
    exclude_keys=None,
    include_keys=None,
)

policy = workflow.model.to(device)
policy.eval()

In [None]:
# Keypoint-observation Push-T environment wrapped for multi-step interaction:
# two observation steps per query, one action step per call, 300-step episodes.
base_env = PushTKeypointsEnv(render_action=False)
env = MultiStepWrapper(
    base_env,
    n_obs_steps=2,
    n_action_steps=1,
    max_episode_steps=300,
)
env.seed(100000)

In [None]:
# Rollout state consumed by the loop in the next cell.
B = 1                # batch size used for every tensor reshape in the rollout
plots = []           # one rendered RGBA frame per environment step
d = False            # episode-done flag returned by env.step
obs = env.reset()    # initial observation from the wrapped environment

In [None]:
# Greedy rollout of the energy-based policy.
#
# Each step evaluates the energy model on a dense (magnitude, angle) grid,
# executes the highest-probability action, and records one frame showing the
# environment render with a polar heat map of the action distribution.
# Reads `obs`, `d`, `B`, `plots`, `env`, `policy` and `device` from earlier cells.

CENTER_PX = 255.0  # pixel offset subtracted to recentre keypoint coordinates
num_disp = 500     # magnitude samples, evenly spaced in [-1, 1]
num_rot = 360      # angle samples, evenly spaced in [0, 2*pi]

# The candidate-action grid does not depend on the observation, so build it once.
# mag_grid:    (B, num_disp * num_rot, 1, 1)
# theta_grid:  (B * num_disp * num_rot, 1)
# action_grid: (B, num_disp * num_rot, 1, 2), last axis = [magnitude, angle]
mag_grid = torch.linspace(-1, 1.0, num_disp)
mag_grid = mag_grid.view(1, -1, 1).repeat(B, 1, num_rot).view(B, -1, 1, 1).to(device)
theta_grid = torch.linspace(0, 2 * np.pi, num_rot)
theta_grid = theta_grid.view(1, 1, -1).repeat(B, num_disp, 1).view(-1, 1).to(device)
action_grid = torch.cat((mag_grid, theta_grid.view(B, -1, 1, 1)), dim=-1)

# Plot coordinates for the heat map: angle in radians, radius as the magnitude index.
azm = np.linspace(0, 2 * np.pi, num_rot)
disp_idx = np.arange(num_disp)

while not d:
    # Keep only the first half of the observation vector and view it as (x, y) pairs.
    Do = obs.shape[-1] // 2
    keypoints = torch.from_numpy(obs[:, :Do].astype(np.float32)).to(device).reshape(B, -1, 2)

    # Recentre the coordinates and flip the y axis before normalising.
    x_obs = keypoints[:, :, 0] - CENTER_PX
    y_obs = (keypoints[:, :, 1] - CENTER_PX) * -1.0
    centred_obs = torch.stack((x_obs, y_obs), dim=-1).view(B, 2, -1)
    nobs = policy.normalizer['obs'].normalize(centred_obs)

    with torch.no_grad():
        logits = policy.energy_model(nobs, mag_grid, theta_grid)
    action_probs = torch.softmax(logits, dim=-1)

    # Greedy choice: the grid entry with the highest probability, shape (B, 1, 2).
    best_idx = torch.argmax(action_probs, dim=-1).unsqueeze(-1)
    batch_idx = torch.arange(B, device=best_idx.device).unsqueeze(-1)
    action = action_grid[batch_idx, best_idx].squeeze(1)

    prob_map = action_probs.view(B, num_disp, num_rot)[0].cpu().numpy()

    img = env.render('human')

    f = plt.figure(figsize=(20, 6))
    ax1 = f.add_subplot(111)
    ax2 = f.add_subplot(141, projection='polar')
    ax1.imshow(img)
    # Pass the angular coordinates explicitly: without them pcolormesh uses the
    # column index (0..num_rot-1) as the angle in radians and the map wraps
    # around the circle many times.
    ax2.pcolormesh(azm, disp_idx, prob_map, shading='nearest')
    ax2.set_rticks([])
    ax2.grid(False)

    # Capture the frame before showing the figure. Pinning dpi to the figure's
    # own dpi keeps the raw buffer size consistent with f.bbox regardless of the
    # `savefig.dpi` rc setting.
    with io.BytesIO() as io_buf:
        f.savefig(io_buf, format='raw', dpi=f.dpi)
        frame = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)
    frame_h, frame_w = int(f.bbox.bounds[3]), int(f.bbox.bounds[2])
    plots.append(frame.reshape(frame_h, frame_w, -1))
    plt.show()
    plt.close(f)

    # Polar -> Cartesian. Only the magnitude is unnormalised; the angle is used
    # exactly as sampled on the grid.
    mag = policy.normalizer["action"].unnormalize(action)[:, :, 0]
    theta = action[:, :, 1]
    x = mag * torch.cos(theta)
    y = mag * torch.sin(theta)

    # Flip y back, mirroring the flip applied to the observations.
    # NOTE(review): the CENTER_PX offset removed from the observations is not
    # added back here - confirm the environment expects centred action coordinates.
    new_act = torch.stack((x, y * -1.0), dim=-1)  # (B, 1, 2)
    obs, reward, d, _ = env.step(new_act.squeeze(0).cpu().numpy())

# Create the output directory first so a long rollout is not lost to a missing folder.
os.makedirs('plots', exist_ok=True)
imageio.mimwrite('plots/test.gif', plots)