In [1]:
import gymnasium as gym
import numpy as np
import torch as th
from rllte.env.utils import Gymnasium2Torch
from rllte.xplore.reward import ICM

In [2]:
class FakeAtari(gym.Env):
    """Stand-in for an Atari environment, used to exercise intrinsic-reward modules.

    Observations are random samples from a ``(4, 84, 84)`` Box with bounds
    ``[0, 1]``, there are 7 discrete actions, and the extrinsic reward is always 0.
    Once more than 100 steps have been taken, each step ends the episode with
    probability 0.1. The end is reported as both terminated and truncated.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(7)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))
        self.count = 0  # steps taken in the current episode

    def reset(self, *, seed=None, options=None):
        """Start a new episode.

        Accepts the standard Gymnasium ``seed``/``options`` keywords so that
        vector envs forwarding them (e.g. ``envs.reset(seed=...)``) work.
        ``options`` is ignored.

        Returns:
            A random observation and an empty info dict.
        """
        # Seeds self.np_random when a seed is supplied.
        super().reset(seed=seed)
        if seed is not None:
            # Observations come from the space's own RNG, so seed it as well.
            self.observation_space.seed(seed)
        self.count = 0
        return self.observation_space.sample(), {}

    def step(self, action):
        """Advance one step. ``action`` is ignored.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` with reward 0 and an
            empty info dict.
        """
        self.count += 1
        # Use the env's own RNG rather than the global np.random: workers
        # started via fork can inherit identical global RNG state, which would
        # make every env in the vector end its episode on the same step.
        done = self.count > 100 and self.np_random.random() < 0.1
        return self.observation_space.sample(), 0, done, done, {}

In [3]:
# Run on CUDA when it is available, otherwise on CPU.
device = 'cuda' if th.cuda.is_available() else 'cpu'
n_envs = 8     # number of parallel FakeAtari environments
n_steps = 128  # rollout length collected before computing intrinsic rewards

envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])
# Gymnasium2Torch wraps the vector env so that reset/step return torch tensors
# (they are combined with th.stack below).
envs = Gymnasium2Torch(envs, device)
irs = ICM(envs, device, obs_norm_type="rms")
# Rollout buffer: one list entry per step for each transition field.
samples = {"observations": [], "actions": [], "rewards": [],
           "terminateds": [], "truncateds": [], "next_observations": []}

obs, infos = envs.reset()
try:
    for _ in range(n_steps):
        # One action per env, drawn from the wrapper's action space
        # (assumes Gymnasium2Torch exposes the single-env space - TODO confirm).
        actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])
        next_obs, rwds, terms, truncs, infos = envs.step(actions)
        # Feed every transition to ICM as it is collected.
        irs.watch(observations=obs,
                  actions=actions,
                  rewards=rwds,
                  terminateds=terms,
                  truncateds=truncs,
                  next_observations=next_obs
                  )

        samples['observations'].append(obs)
        samples['actions'].append(actions)
        samples['rewards'].append(rwds)
        samples['terminateds'].append(terms)
        samples['truncateds'].append(truncs)
        samples['next_observations'].append(next_obs)

        obs = next_obs
finally:
    # AsyncVectorEnv keeps worker processes alive until closed; release them
    # even if the rollout raises.
    envs.close()

# Stack along a new leading time axis: each field becomes (n_steps, ...).
samples = {k: th.stack(v) for k, v in samples.items()}
intrinsic_rewards = irs.compute(samples)
print(intrinsic_rewards.size(), intrinsic_rewards)

torch.Size([128, 8]) tensor([[ 8.5521,  9.2151,  6.7879,  ...,  8.1058, 10.1724,  6.8122],
        [ 9.0602,  8.9318,  8.3807,  ...,  9.0816,  8.7549,  7.4512],
        [ 8.1448,  9.9906,  8.9156,  ...,  7.0198,  9.0013,  8.5745],
        ...,
        [ 8.3094,  8.5377,  8.4231,  ...,  8.9208,  7.4855,  6.7978],
        [11.7878,  7.6400,  9.3886,  ...,  7.9695,  8.0754,  7.0947],
        [ 8.9067,  9.5402,  9.5420,  ...,  7.2864,  8.6413,  7.3011]])
