In [1]:
import gymnasium as gym
import numpy as np
import torch as th

from rllte.env.utils import Gymnasium2Torch
from rllte.xplore.reward import ICM

**Create a fake Atari environment with image observations**

In [2]:
class FakeAtari(gym.Env):
    """Stand-in for an Atari game that emits random image-shaped observations.

    The environment exposes 7 discrete actions and a ``(4, 84, 84)`` box
    observation space with values in ``[0, 1]``. Observations are drawn
    directly from the observation space, every reward is zero, and the
    chosen action is ignored.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(7)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))
        # Number of steps taken since the last reset.
        self.count = 0

    def reset(self, *, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that the
                environment's ``np_random`` generator is (re)seeded.
            options: Accepted for compatibility with the Gymnasium ``reset``
                signature; unused by this environment.

        Returns:
            A ``(observation, info)`` tuple, where ``info`` is an empty dict.
        """
        # Vector environments and wrappers forward ``seed``/``options`` as
        # keyword arguments, so the signature must accept them.
        super().reset(seed=seed)
        self.count = 0
        return self.observation_space.sample(), {}

    def step(self, action):
        """Advance the episode by one step.

        Args:
            action: Ignored; the dynamics do not depend on it.

        Returns:
            A ``(observation, reward, terminated, truncated, info)`` tuple.
            The reward is always ``0.0``. Once more than 100 steps have been
            taken, each step ends the episode with probability 0.1, and in
            that case both ``terminated`` and ``truncated`` are ``True``.
        """
        self.count += 1
        # Draw from the environment's own generator so that a seed passed to
        # ``reset`` also controls the episode length.
        done = bool(self.count > 100 and self.np_random.random() < 0.1)
        return self.observation_space.sample(), 0.0, done, done, {}

**Synchronous mode: `.update()` is invoked automatically inside `.compute()`. This mode is typically used with on-policy RL algorithms.**

In [3]:
# set the parameters: prefer MPS, then CUDA, and fall back to the CPU
device = 'mps' if th.backends.mps.is_available() else 'cuda' if th.cuda.is_available() else 'cpu'
n_steps = 128  # number of vectorized steps to collect
n_envs = 8     # number of parallel environment copies
# keys of the transition fields that are watched and stored at every step
sample_keys = ('observations', 'actions', 'rewards',
               'terminateds', 'truncateds', 'next_observations')
# create the vectorized environments
envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])
try:
    # wrap the environments to convert the observations to torch tensors
    envs = Gymnasium2Torch(envs, device)
    # create the intrinsic reward module
    irs = ICM(envs, device)
    # reset the environments and get the initial observations
    obs, infos = envs.reset()
    # per-step lists of transitions, stacked into tensors after the loop
    rollout = {key: [] for key in sample_keys}
    # sampling loop
    for _ in range(n_steps):
        # sample one random action per environment
        actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])
        # environment step
        next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)
        transition = {'observations': obs,
                      'actions': actions,
                      'rewards': rewards,
                      'terminateds': terminateds,
                      'truncateds': truncateds,
                      'next_observations': next_obs}
        # watch the interactions and get necessary information for the intrinsic reward computation
        irs.watch(**transition)
        # store the samples
        for key, value in transition.items():
            rollout[key].append(value)
        obs = next_obs
    # stack the per-step lists along a new leading (step) dimension
    samples = {key: th.stack(values) for key, values in rollout.items()}
    # compute the intrinsic rewards
    intrinsic_rewards = irs.compute(samples=samples)
finally:
    # always shut down the environment worker processes, even if the rollout failed
    envs.close()
print(intrinsic_rewards)
print(intrinsic_rewards.shape)

Box(0.0, 1.0, (4, 84, 84), float32)
Discrete(7)
tensor([[5.3476, 7.1356, 6.8465,  ..., 7.4179, 7.0747, 5.7406],
        [6.2589, 7.1989, 5.0291,  ..., 6.7728, 7.8519, 6.2785],
        [6.5028, 7.3675, 5.8179,  ..., 6.5310, 5.7410, 6.5957],
        ...,
        [7.3416, 6.4600, 5.5094,  ..., 7.5757, 8.3019, 6.7766],
        [7.3124, 6.6850, 6.6613,  ..., 6.3896, 7.5636, 7.0359],
        [8.4999, 6.5634, 7.4811,  ..., 7.7395, 7.5860, 7.3720]],
       device='mps:0')
torch.Size([128, 8])


**Asynchronous mode: `.update()` must be invoked separately from `.compute()`. This mode is typically used with off-policy RL algorithms.**

In [5]:
# set the parameters: prefer MPS, then CUDA, and fall back to the CPU
device = 'mps' if th.backends.mps.is_available() else 'cuda' if th.cuda.is_available() else 'cpu'
n_steps = 128  # number of vectorized steps to collect
n_envs = 8     # number of parallel environment copies
# keys of the transition fields that are watched and stored at every step
sample_keys = ('observations', 'actions', 'rewards',
               'terminateds', 'truncateds', 'next_observations')
# create the vectorized environments
envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])
try:
    # wrap the environments to convert the observations to torch tensors
    envs = Gymnasium2Torch(envs, device)
    # create the intrinsic reward module
    irs = ICM(envs, device)
    # reset the environments and get the initial observations
    obs, infos = envs.reset()
    # per-step lists of transitions, stacked into tensors after the loop
    rollout = {key: [] for key in sample_keys}
    # sampling loop
    for _ in range(n_steps):
        # sample one random action per environment
        actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])
        # environment step
        next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)
        transition = {'observations': obs,
                      'actions': actions,
                      'rewards': rewards,
                      'terminateds': terminateds,
                      'truncateds': truncateds,
                      'next_observations': next_obs}
        # watch the interactions and get necessary information for the intrinsic reward computation
        irs.watch(**transition)
        # compute the intrinsic rewards at each step; every field gets a leading
        # dimension of size 1, and sync=False leaves the update to the explicit call below
        intrinsic_rewards = irs.compute(
            samples={key: value.unsqueeze(0) for key, value in transition.items()},
            sync=False)
        print(intrinsic_rewards, intrinsic_rewards.shape)
        # store the samples
        for key, value in transition.items():
            rollout[key].append(value)
        obs = next_obs
    # stack the per-step lists along a new leading (step) dimension
    samples = {key: th.stack(values) for key, values in rollout.items()}
    # update the intrinsic reward module
    irs.update(samples=samples)
finally:
    # always shut down the environment worker processes, even if the rollout failed
    envs.close()

tensor([[1.8394, 1.8430, 2.5142, 2.0302, 2.1765, 1.6593, 1.8448, 1.6650]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0185, 0.0136, 0.0185, 0.0149, 0.0150, 0.0161, 0.0186, 0.0180]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0159, 0.0131, 0.0197, 0.0171, 0.0144, 0.0159, 0.0196, 0.0190]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0156, 0.0186, 0.0207, 0.0172, 0.0214, 0.0157, 0.0186, 0.0208]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0224, 0.0202, 0.0201, 0.0224, 0.0201, 0.0202, 0.0224, 0.0187]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0240, 0.0248, 0.0165, 0.0183, 0.0242, 0.0183, 0.0182, 0.0249]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0231, 0.0264, 0.0264, 0.0265, 0.0265, 0.0175, 0.0257, 0.0265]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0243, 0.0187, 0.0280, 0.0187, 0.0206, 0.0225, 0.0187, 0.0206]],
       device='mps:0') torch.Size([1, 8])
tensor([[0.0295, 0.0197, 0.0294, 0.0238, 0.0197, 0.0285, 0.0295,