In [1]:
import os, sys
sys.path.append('..')

import io
import torch
import numpy as np
import numpy.random as npr
import hydra
import dill
from tqdm import tqdm

import plotly.io as pio
pio.renderers.default = 'jupyterlab'
from scipy.interpolate import griddata
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import imageio

from irrep_actions.gym_util.multistep_wrapper import MultiStepWrapper
from irrep_actions.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
from irrep_actions.workflow.base_workflow import BaseWorkflow
from irrep_actions.utils import mcmc

pygame 2.5.2 (SDL 2.28.2, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Checkpoint produced by the SO(2) harmonic implicit-policy training run on PushT (low-dim).
checkpoint = '../data/outputs/2024.02.06/23.18.33_train_so2_harmonic_implicit_policy_pusht_lowdim/checkpoints/epoch=0300-test_mean_score=1.000.ckpt'

# SECURITY NOTE: torch.load with pickle_module=dill unpickles arbitrary Python
# objects, which can execute code embedded in the file. Only load checkpoints
# from a trusted source.
# The file is opened in a context manager so the handle is closed as soon as
# the payload has been read (the original left it open until garbage collection).
with open(checkpoint, 'rb') as ckpt_file:
    payload = torch.load(ckpt_file, pickle_module=dill)

# NOTE(review): the GPU index is hard-coded; adjust it for the machine at hand.
device = torch.device('cuda:2')

In [3]:
# Rebuild the workflow from the config stored inside the checkpoint payload,
# then restore its saved state from that same payload.
cfg = payload['config']
cls = hydra.utils.get_class(cfg._target_)
workflow: BaseWorkflow = cls(cfg)
workflow.load_payload(payload, exclude_keys=None, include_keys=None)

# Move the policy to the evaluation device and switch it to eval mode.
# `policy.eval()` is kept as the last expression so the cell displays the module.
policy = workflow.model.to(device)
policy.eval()

ImplicitPolicy(
  (normalizer): LinearNormalizer(
    (params_dict): ParameterDict(
        (obs): Object of type: ParameterDict
        (action): Object of type: ParameterDict
      (obs): ParameterDict(
          (offset): Parameter containing: [torch.cuda.FloatTensor of size 38 (cuda:2)]
          (scale): Parameter containing: [torch.cuda.FloatTensor of size 38 (cuda:2)]
          (input_stats): Object of type: ParameterDict
        (input_stats): ParameterDict(
            (max): Parameter containing: [torch.cuda.FloatTensor of size 38 (cuda:2)]
            (mean): Parameter containing: [torch.cuda.FloatTensor of size 38 (cuda:2)]
            (min): Parameter containing: [torch.cuda.FloatTensor of size 38 (cuda:2)]
            (std): Parameter containing: [torch.cuda.FloatTensor of size 38 (cuda:2)]
        )
      )
      (action): ParameterDict(
          (offset): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:2)]
          (scale): Parameter containing: [torch.c

In [4]:
def _compose_frame(img, energy, radius):
    """Render one video frame: the environment image plus a polar energy plot.

    Args:
        img: Image array returned by ``env.render``; shown with ``imshow``.
        energy: 1-D array of probabilities over the rotation grid.
        radius: Selected (unnormalized) displacement magnitude, shown in the title.

    Returns:
        A uint8 array of shape (height, width, channels) holding the raw
        rasterized figure.
    """
    fig = plt.figure(figsize=(10,3))
    try:
        ax_env = fig.add_subplot(111)
        ax_energy = fig.add_subplot(141, projection='polar')
        ax_env.imshow(img)
        ax_energy.plot(np.linspace(0, 2*np.pi, energy.shape[0]), energy)
        ax_energy.set_rticks(list())
        ax_energy.grid(True)
        ax_energy.set_title(f"R={radius:.3f}", va="bottom")

        with io.BytesIO() as io_buf:
            fig.savefig(io_buf, format='raw')
            raw = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)
        # Target shape is passed positionally: the `newshape=` keyword is
        # deprecated in recent NumPy releases.
        return np.reshape(raw, (int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
    finally:
        # Close this specific figure (not just "the current one") so figures
        # do not accumulate over a long rollout, even if rendering fails.
        plt.close(fig)


def eval(num_eps, start_seed=100000, num_disp = 100, num_rot=360, sample_act=False, temp=1.0, vid_out=True):
    """Roll out the loaded policy on PushT and record the best reward per episode.

    NOTE: this function shadows the builtin ``eval``; the name is kept so that
    existing calls in the notebook continue to work.

    Relies on the notebook globals ``policy`` and ``device``.

    Args:
        num_eps: Number of episodes to run.
        start_seed: Episode ``i`` seeds its environment with ``start_seed + i``.
        num_disp: Number of candidate displacement magnitudes (grid over [-1, 1]).
        num_rot: Number of candidate rotation angles (grid over [0, 2*pi]).
        sample_act: If True, sample the action from the softmax over the energy
            grid; otherwise take the argmax.
        temp: Temperature dividing the logits before the softmax.
        vid_out: If True, render a frame per step and write one GIF per episode
            to ``plots/lmax_3/{eps}.gif``. If False, nothing is rendered or written.

    Returns:
        List of length ``num_eps`` with the maximum reward seen in each episode.
    """
    # Function-scope import: previously this statement was re-executed inside
    # the episode loop.
    import time

    B = 1

    # Candidate action grids. They do not depend on the observation, so they
    # are built once instead of on every environment step.
    mag_grid = torch.linspace(-1.0, 1.0, num_disp)
    mag_grid = mag_grid.view(1, -1).repeat(B, 1).view(B, -1, 1, 1).to(device)
    # NOTE(review): linspace includes both 0 and 2*pi, i.e. the same angle twice;
    # presumably this matches the grid used by get_energy_ball -- confirm.
    theta_grid = torch.linspace(0, 2*np.pi, num_rot).to(device)

    if vid_out:
        # Make sure the GIF destination exists before the first write.
        os.makedirs('plots/lmax_3', exist_ok=True)

    max_rewards = [0] * num_eps
    with tqdm(total=num_eps) as pbar:
        for eps in range(num_eps):
            env = MultiStepWrapper(
                PushTKeypointsEnv(render_action=False, random_goal_pose=False),
                n_obs_steps=2,
                n_action_steps=1,
                max_episode_steps=300
            )
            env.seed(start_seed + eps)

            obs = env.reset()
            d = False
            plots = list()

            while not d:
                t0 = time.time()

                # Keep the first half of the observation features and view
                # them as (x, y) pairs.
                Do = obs.shape[-1] // 2
                obs_t = torch.from_numpy(obs[:, :Do].astype(np.float32).reshape(B,2,-1,2)).to(device) # 1x2x19x2
                # Shift coordinates by 255 and flip y.
                # NOTE(review): presumably converts image pixel coordinates to a
                # centered, y-up frame -- confirm against the training pipeline.
                x_obs = (obs_t.reshape(1,38,2)[:,:,0] - 255.0)
                y_obs = (obs_t.reshape(1,38,2)[:,:,1] - 255.0) * -1.0
                new_d = torch.concatenate((x_obs.unsqueeze(-1), y_obs.unsqueeze(-1)), dim=-1).view(1, -1).view(1,2,19*2)
                nobs = policy.normalizer['obs'].normalize(new_d)

                with torch.no_grad():
                    logits = policy.energy_model.get_energy_ball(nobs, mag_grid).view(1, -1)
                action_probs = torch.softmax(logits/temp, dim=-1).view(1, num_disp, num_rot)

                # Pick one (displacement, rotation) cell from the flattened grid.
                flat_probs = action_probs.flatten(start_dim=-2)
                if sample_act:
                    flat_index = torch.multinomial(flat_probs, num_samples=1, replacement=True)
                else:
                    flat_index = flat_probs.argmax(1)
                disp_idx, rot_idx = divmod(flat_index.item(), action_probs.shape[-1])

                actions = torch.tensor([mag_grid[0, disp_idx, 0, 0], theta_grid[rot_idx]])
                print(time.time() - t0)

                # Only the magnitude goes through the action normalizer; the
                # angle is used as-is. Convert polar (mag, theta) to Cartesian.
                mag = policy.normalizer["action"].unnormalize(actions)[0]
                theta = actions[1]
                x = mag * torch.cos(theta)
                y = mag * torch.sin(theta)

                if vid_out:
                    # Energy slice over rotations at the displacement holding
                    # the overall highest probability.
                    best_disp = torch.argmax(torch.max(action_probs, dim=-1)[0], dim=1).item()
                    E = action_probs[0, best_disp].cpu().numpy()
                    img = env.render('human')
                    plots.append(_compose_frame(img, E, mag.item()))

                # Flip y back before stepping the environment (mirrors the flip
                # applied to the observations above).
                env_action = torch.concat([x.view(B,1), -y.view(B,1)], dim=1)
                obs, r, d, _ = env.step(env_action.cpu().numpy())
                max_rewards[eps] = max(r, max_rewards[eps])

            # Previously the GIF was written unconditionally, i.e. with an empty
            # frame list when vid_out=False.
            if vid_out:
                imageio.mimwrite(f'plots/lmax_3/{eps}.gif', plots)
            print(max_rewards[eps])
            pbar.update(1)
    return max_rewards

In [6]:
# Roll out a single episode (seed 100000) with actions sampled from the softmax
# at temperature 1.0. `vid_out` is left at its default (True), so a GIF of the
# episode is written.
max_rewards = eval(num_eps=1, start_seed=100000, num_disp=100, num_rot=360, sample_act=True, temp=1.0)


  0%|                                                                                                                                                                                                                                                        | 0/1 [00:00<?, ?it/s][A

0.04672551155090332


ALSA lib confmisc.c:855:(parse_card) cannot find card '0'
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_card_inum returned error: No such file or directory
ALSA lib confmisc.c:422:(snd_func_concat) error evaluating strings
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_concat returned error: No such file or directory
ALSA lib confmisc.c:1334:(snd_func_refer) error evaluating name
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_refer returned error: No such file or directory
ALSA lib conf.c:5701:(snd_config_expand) Evaluate error: No such file or directory
ALSA lib pcm.c:2664:(snd_pcm_open_noupdate) Unknown PCM default


0.0032472610473632812
0.0029153823852539062
0.0028002262115478516
0.002916097640991211
0.0028378963470458984
0.003218412399291992
0.0028808116912841797
0.002755880355834961
0.0027701854705810547
0.0027704238891601562
0.0027947425842285156
0.003570556640625
0.0029685497283935547
0.0032253265380859375
0.002918243408203125
0.003191709518432617
0.011161327362060547
0.0030672550201416016
0.002821207046508789
0.0027627944946289062
0.002889871597290039


KeyboardInterrupt: 

In [None]:
# Mean of the per-episode maximum rewards returned by eval().
np.mean(max_rewards)

In [None]:
# Per-episode maximum rewards (one entry per evaluated episode).
max_rewards