In [1]:
import hydra
import torch
from lightning.fabric import Fabric
from lightning.pytorch.utilities.seed import isolate_rng
from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights
from sheeprl.models.models import MLP
from sheeprl.utils.utils import dotdict
from omegaconf import OmegaConf
import pathlib
from sheeprl.utils.env import make_env
import gymnasium as gym
import numpy as np
from torch import nn
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
import os
from sheeprl.utils.distribution import (
    MSEDistribution
)

# path of your checkpoint
ckpt_path = pathlib.Path(r"C:\Users\user\Documents\MasterThesis\sheep RL\sheeprl-main\logs\runs\p2e_dv3_exploration\PongNoFrameskip-v4\2024-07-11_03-11-12_p2e_dv3_exploration_PongNoFrameskip-v4_42\version_0\checkpoint\ckpt_80000_0.ckpt")

seed = 12
fabric = Fabric(accelerator="cuda", devices=1)
fabric.launch()
#state = fabric.load(ckpt_path)
cfg = dotdict(OmegaConf.to_container(OmegaConf.load(ckpt_path.parent.parent / "config.yaml"), resolve=True))

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


In [2]:
torch.set_float32_matmul_precision('medium')
# Environment setup
vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv
envs = vectorized_env(
    [
        make_env(
            cfg,
            cfg.seed + 0 * cfg.env.num_envs + i,
            0 * cfg.env.num_envs,
            "./imagination",
            "imagination",
            vector_env_idx=i,
        )
        for i in range(cfg.env.num_envs)
    ]
)
action_space = envs.single_action_space
observation_space = envs.single_observation_space

is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
actions_dim = tuple(
    action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r

  logger.warn(


In [3]:
# The number of environments is set to 1
cfg.env.num_envs = 4

ens_list = []
cfg_ensembles = cfg.algo.ensembles
cfg_ensembles.n = 4
ensembles_ln_cls = hydra.utils.get_class(cfg_ensembles.layer_norm.cls)
with isolate_rng():
    for i in range(cfg_ensembles.n):
        fabric.seed_everything(cfg.seed + i)
        ens_list.append(
            MLP(
                input_dims=int(
                    sum(actions_dim)
                    + cfg.algo.world_model.recurrent_model.recurrent_state_size
                    + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size
                ),
                output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size,
                hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers,
                activation=hydra.utils.get_class(cfg_ensembles.dense_act),
                flatten_dim=None,
                layer_args={"bias": ensembles_ln_cls == nn.Identity},
                norm_layer=ensembles_ln_cls,
                norm_args={
                    **cfg_ensembles.layer_norm.kw,
                    "normalized_shape": cfg_ensembles.dense_units,
                },
            ).apply(init_weights)
        )

Seed set to 42
Seed set to 43
Seed set to 44
Seed set to 45


In [4]:
ensembles = nn.ModuleList(ens_list)
for i in range(len(ensembles)):
    ensembles[i] = fabric.setup_module(ensembles[i])

In [6]:
if cfg.checkpoint.resume_from:
    state = fabric.load(cfg.checkpoint.resume_from)

world_size = fabric.world_size
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4
rb = EnvIndependentReplayBuffer(
    buffer_size,
    n_envs=cfg.env.num_envs,
    memmap=cfg.buffer.memmap,
    memmap_dir=os.path.join(".", "memmap_buffer", f"rank_{fabric.global_rank}"),
    buffer_cls=SequentialReplayBuffer,
)

if cfg.checkpoint.resume_from and cfg.buffer.checkpoint:
    if isinstance(state["rb"], list) and world_size == len(state["rb"]):
        rb = state["rb"][fabric.global_rank]
    elif isinstance(state["rb"], EnvIndependentReplayBuffer):
        rb = state["rb"]
    else:
        raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated")

In [7]:
import copy
import pathlib

import gymnasium as gym
import numpy as np
import torch
import torchvision
from lightning.fabric import Fabric
from omegaconf import OmegaConf
from PIL import Image
import torch.nn as nn

from sheeprl.algos.dreamer_v3.agent import build_agent
#from sheeprl.algos.p2e_dv3.agent import build_agent
from sheeprl.data.buffers import SequentialReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

# path of your checkpoint
ckpt_path = pathlib.Path(r"C:\Users\user\Documents\MasterThesis\sheep RL\sheeprl-main\logs\runs\dreamer_v3\PongNoFrameskip-v4\2024-06-17_00-51-22_dreamer_v3_PongNoFrameskip-v4_42\version_0\checkpoint\ckpt_120000_0.ckpt")
seed = 12
# fabric_model = Fabric(accelerator="cuda", devices=1)
# fabric_model.launch()
model_state = fabric.load(ckpt_path)
model_cfg = dotdict(OmegaConf.to_container(OmegaConf.load(ckpt_path.parent.parent / "config.yaml"), resolve=True))

# The number of environments is set to 1
# model_cfg.env.num_envs = 1
# torch.set_float32_matmul_precision('medium')
# envs = gym.vector.SyncVectorEnv(
#     [
#         make_env(
#             model_cfg,
#             model_cfg.seed + 0 * model_cfg.env.num_envs + i,
#             0 * model_cfg.env.num_envs,
#             "./imagination",
#             "imagination",
#             vector_env_idx=i,
#         )
#         for i in range(model_cfg.env.num_envs)
#     ]
# )
# action_space = envs.single_action_space
# observation_space = envs.single_observation_space

# obs_keys = model_cfg.algo.cnn_keys.encoder + model_cfg.algo.mlp_keys.encoder
# is_continuous = isinstance(action_space, gym.spaces.Box)
# is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
# actions_dim = tuple(
#     action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
# )
# is_continuous = isinstance(action_space, gym.spaces.Box)
# is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
# actions_dim = tuple(
#     action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
# )
world_model, actor, critic, critic_target, player = build_agent(
    fabric,
    actions_dim,
    is_continuous,
    cfg,
    observation_space,
    model_state["world_model"],
    model_state["actor"],
    model_state["critic"],
    model_state["target_critic"]
)

In [9]:
from tqdm import tqdm

In [10]:
policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1

for iter_num in tqdm(range(39999, total_iters)):

    local_data = rb.sample_tensors(
        cfg.algo.per_rank_batch_size,
        sequence_length=cfg.algo.per_rank_sequence_length,
        n_samples=64,
        dtype=None,
        device=fabric.device,
        from_numpy=cfg.buffer.from_numpy,
    )
    batch = {k: v[i].float() for k, v in local_data.items()}
    data = batch

    batch_size = cfg.algo.per_rank_batch_size
    sequence_length = cfg.algo.per_rank_sequence_length
    recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size
    stochastic_size = cfg.algo.world_model.stochastic_size
    discrete_size = cfg.algo.world_model.discrete_size
    device = fabric.device
    data = {k: data[k] for k in data.keys()}
    batch_obs = {k: data[k] / 255.0 - 0.5 for k in cfg.algo.cnn_keys.encoder}
    batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})
    data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])

    # Given how the environment interaction works, we remove the last actions
    # and add the first one as the zero action
    batch_actions = torch.cat((torch.zeros_like(data["actions"][:1]), data["actions"][:-1]), dim=0)

    # Dynamic Learning
    stoch_state_size = stochastic_size * discrete_size
    recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device)
    posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
    recurrent_states = torch.empty(sequence_length, batch_size, recurrent_state_size, device=device)
    priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
    posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
    posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)

    # embedded observations from the environment
    embedded_obs = world_model.encoder(batch_obs)

    for i in range(0, sequence_length):
        recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
            posterior, recurrent_state, batch_actions[i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
        )
        recurrent_states[i] = recurrent_state
        priors_logits[i] = prior_logits
        posteriors[i] = posterior
        posteriors_logits[i] = posterior_logits
    latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1)

    ensemble_optimizer = hydra.utils.instantiate(
        cfg.algo.critic.optimizer, params=ensembles.parameters(), _convert_="all"
    )

    loss = 0.0
    ensemble_optimizer.zero_grad(set_to_none=True)
    for ens in ensembles:
        out = ens(
            torch.cat(
                (
                    posteriors.view(*posteriors.shape[:-2], -1).detach(),
                    recurrent_states.detach(),
                    data["actions"].detach(),
                ),
                -1,
            )
        )[:-1]
        next_state_embedding_dist = MSEDistribution(out, 1)
        loss -= next_state_embedding_dist.log_prob(posteriors.view(sequence_length, batch_size, -1).detach()[1:]).mean()
    loss.backward()

    #print(loss.item())
    if iter_num % 1000 == 0:
        print(loss.item())

    ensemble_grad = None
    if cfg.algo.ensembles.clip_gradients is not None and cfg.algo.ensembles.clip_gradients > 0:
        ensemble_grad = fabric.clip_gradients(
            module=ens,
            optimizer=ensemble_optimizer,
            max_norm=cfg.algo.ensembles.clip_gradients,
            error_if_nonfinite=False,
        )
    ensemble_optimizer.step()

    if (iter_num+1) % 10000 == 0:
        state = {
            "ensembles": ensembles.state_dict(),
            "ensemble_optimizer": ensemble_optimizer.state_dict(),
            "iter_num": iter_num * world_size,
            "batch_size": cfg.algo.per_rank_batch_size * world_size,
        }
        ckpt_path = fr"C:\Users\user\Documents\MasterThesis\sheep RL\sheeprl-main\notebooks\checkpoints_ensembles\ckpt_{iter_num}_{fabric.global_rank}.ckpt"
        fabric.save(ckpt_path, state)

  0%|          | 2/1210001 [00:09<1463:43:54,  4.35s/it]

83.65117645263672


  0%|          | 1002/1210001 [22:18<438:01:37,  1.30s/it]

84.34780883789062


  0%|          | 2002/1210001 [44:20<463:56:13,  1.38s/it]

83.34278106689453


  0%|          | 3002/1210001 [1:06:19<436:39:54,  1.30s/it]

83.37515258789062


  0%|          | 4002/1210001 [1:28:20<437:00:49,  1.30s/it]

83.92355346679688


  0%|          | 5002/1210001 [1:50:12<434:38:21,  1.30s/it]

83.25057983398438


  0%|          | 6002/1210001 [2:12:18<444:04:17,  1.33s/it]

83.67758178710938


  1%|          | 7002/1210001 [2:34:29<438:11:50,  1.31s/it]

83.06826782226562


  1%|          | 8002/1210001 [2:56:34<468:33:40,  1.40s/it]

84.1716537475586


  1%|          | 9002/1210001 [3:19:09<434:41:02,  1.30s/it]

82.94760131835938


  1%|          | 10002/1210001 [3:42:12<491:57:53,  1.48s/it]

82.81026458740234


  1%|          | 10130/1210001 [3:45:06<444:23:25,  1.33s/it]


KeyboardInterrupt: 

In [15]:
import os

os.environ["MUJOCO_GL"] = "egl"
import copy
import pathlib

import gymnasium as gym
import numpy as np
import torch
import torchvision
from lightning.fabric import Fabric
from omegaconf import OmegaConf
from PIL import Image
import torch.nn as nn

from sheeprl.algos.dreamer_v3.agent import build_agent
#from sheeprl.algos.p2e_dv3.agent import build_agent
from sheeprl.data.buffers import SequentialReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict
## Agent and Environment initialization

# path of your checkpoint
ckpt_path = pathlib.Path(r"C:\Users\user\Documents\MasterThesis\sheep RL\sheeprl-main\logs\runs\dreamer_v3\PongNoFrameskip-v4\2024-06-17_00-51-22_dreamer_v3_PongNoFrameskip-v4_42\version_0\checkpoint\ckpt_120000_0.ckpt")
seed = 12
fabric = Fabric(accelerator="cuda", devices=1)
fabric.launch()
state = fabric.load(ckpt_path)
cfg = dotdict(OmegaConf.to_container(OmegaConf.load(ckpt_path.parent.parent / "config.yaml"), resolve=True))

#fabric.seed_everything(cfg.seed)
# The number of environments is set to 1
cfg.env.num_envs = 1
torch.set_float32_matmul_precision('medium')
envs = gym.vector.SyncVectorEnv(
    [
        make_env(
            cfg,
            cfg.seed + 0 * cfg.env.num_envs + i,
            0 * cfg.env.num_envs,
            "./imagination",
            "imagination",
            vector_env_idx=i,
        )
        for i in range(cfg.env.num_envs)
    ]
)
action_space = envs.single_action_space
observation_space = envs.single_observation_space

obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder
# is_continuous = isinstance(action_space, gym.spaces.Box)
# is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
# actions_dim = tuple(
#     action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
# )
# (
#     world_model,
#     ensembles,
#     actor,
#     critic,
#     target_critic,
#     actor_exploration,
#     critics_exploration,
#     player,
# ) = build_agent(
#     fabric,
#     actions_dim,
#     is_continuous,
#     cfg,
#     observation_space,
#     state["world_model"] if cfg.checkpoint.resume_from else None,
#     state["ensembles"] if cfg.checkpoint.resume_from else None,
#     state["actor_task"] if cfg.checkpoint.resume_from else None,
#     state["critic_task"] if cfg.checkpoint.resume_from else None,
#     state["target_critic_task"] if cfg.checkpoint.resume_from else None,
#     state["actor_exploration"] if cfg.checkpoint.resume_from else None,
#     state["critics_exploration"] if cfg.checkpoint.resume_from else None,
# )
# is_continuous = isinstance(action_space, gym.spaces.Box)
# is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
# actions_dim = tuple(
#     action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
# )
# (
#     world_model_f,
#     ensembles_f,
#     actor_f,
#     critic_f,
#     target_critic_f,
#     actor_exploration_f,
#     critics_exploration_f,
#     player_f,
# ) = build_agent(
#     fabric,
#     actions_dim,
#     is_continuous,
#     cfg,
#     observation_space,
#     state["world_model"] if cfg.checkpoint.resume_from else None,
#     state["ensembles"] if cfg.checkpoint.resume_from else None,
#     state["actor_task"] if cfg.checkpoint.resume_from else None,
#     state["critic_task"] if cfg.checkpoint.resume_from else None,
#     state["target_critic_task"] if cfg.checkpoint.resume_from else None,
#     state["actor_exploration"] if cfg.checkpoint.resume_from else None,
#     state["critics_exploration"] if cfg.checkpoint.resume_from else None,
# )
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
actions_dim = tuple(
    action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
world_model, actor, critic, critic_target, player = build_agent(
    fabric,
    actions_dim,
    is_continuous,
    cfg,
    observation_space,
    state["world_model"],
    state["actor"],
    state["critic"],
    state["target_critic"]
)
## Buffer initialization


initial_steps = 600  # set according to your environment.
imagination_steps = 50  # number of imagination steps, must be lower than or equal to the `initial_steps`.
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
rb_initial = SequentialReplayBuffer(initial_steps, cfg.env.num_envs)
rb_play = SequentialReplayBuffer(imagination_steps, cfg.env.num_envs)
rb_imagination = SequentialReplayBuffer(imagination_steps, cfg.env.num_envs)
step_data = {}
player.init_states()
obs = envs.reset(seed=cfg.seed)[0]
for k in obs_keys:
    step_data[k] = obs[k][np.newaxis]
step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["is_first"] = np.ones_like(step_data["dones"])
## Environment interaction

def add_salt_and_pepper_noise(image, salt_prob, pepper_prob):
    noisy_image = image.copy()
    total_pixels = image.size
    num_salt = np.ceil(salt_prob * total_pixels)
    num_pepper = np.ceil(pepper_prob * total_pixels)

    # Adding salt noise (white pixels)
    coords = [np.random.randint(0, i, int(num_salt)) for i in image.shape]
    noisy_image[tuple(coords)] = 255

    # Adding pepper noise (black pixels)
    coords = [np.random.randint(0, i, int(num_pepper)) for i in image.shape]
    noisy_image[tuple(coords)] = 0

    return noisy_image

def add_gaussian_noise(image, mean=0, std=0.3):
    gauss = np.random.normal(mean, std, image.shape, )
    noisy_image = image + gauss
    #noisy_image = np.clip(noisy_image, 0, 1)  # Ensure values are within [0, 1]
    return noisy_image

salt_prob = 0.3  # 2% of the pixels will be salt noise
pepper_prob = 0.3  # 2% of the pixels will be pepper noise
info_list = []

avg_ld = 0
# play for `initial_steps` steps
for i in range(initial_steps):
    with torch.no_grad():
        preprocessed_obs = {}

        obs["rgb"] = add_salt_and_pepper_noise(obs["rgb"], salt_prob, pepper_prob)
        #obs["rgb"] = np.ones((1,3,64,64)) * 255
        for k, v in obs.items():
            preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=fabric.device)
            if k in cfg.algo.cnn_keys.encoder:
                preprocessed_obs[k] = preprocessed_obs[k] / 255.0 - 0.5
        mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
        if len(mask) == 0:
            mask = None
        real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)
        actions = torch.cat(actions, -1).cpu().numpy()

        temp_states = torch.cat((player.stochastic_state, player.recurrent_state), -1).to("cuda")
        test_input = torch.cat((temp_states, real_actions[0]), -1)

        next_state_predictions = []

        #ensembles_input[0][0][0] = 10000.0
        for ens in ensembles:
            next_state_predictions.append(
                ens(
                    test_input
                )
            )
        
        ld = torch.stack(next_state_predictions).var(0).mean(-1, keepdim=True).to("cpu").detach().item()
        print(ld)
        avg_ld += ld
        if is_continuous:
            real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
        else:
            real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

    step_data["stochastic_state"] = player.stochastic_state.detach().cpu().numpy()
    step_data["recurrent_state"] = player.recurrent_state.detach().cpu().numpy()
    step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
    rb_initial.add(step_data, validate_args=cfg.buffer.validate_args)

    next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape))
    #next_obs["rgb"] = add_salt_and_pepper_noise(next_obs["rgb"], salt_prob, pepper_prob)
    
    rewards = np.array(rewards).reshape((1, cfg.env.num_envs, -1))
    dones = np.logical_or(dones, truncated).astype(np.uint8).reshape((1, cfg.env.num_envs, -1))

    step_data["is_first"] = np.zeros_like(step_data["dones"])
    if "restart_on_exception" in infos:
        for i, agent_roe in enumerate(infos["restart_on_exception"]):
            if agent_roe and not dones[i]:
                last_inserted_idx = (rb_initial.buffer[i]._pos - 1) % rb_initial.buffer[i].buffer_size
                rb_initial.buffer[i]["dones"][last_inserted_idx] = np.ones_like(
                    rb_initial.buffer[i]["dones"][last_inserted_idx]
                )
                rb_initial.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like(
                    rb_initial.buffer[i]["is_first"][last_inserted_idx]
                )
                step_data["is_first"][i] = np.ones_like(step_data["is_first"][i])

    real_next_obs = copy.deepcopy(next_obs)
    if "final_observation" in infos:
        for idx, final_obs in enumerate(infos["final_observation"]):
            if final_obs is not None:
                for k, v in final_obs.items():
                    real_next_obs[k][idx] = v

    for k in obs_keys:
        step_data[k] = next_obs[k][np.newaxis]

    obs = next_obs

    rewards = rewards.reshape((1, cfg.env.num_envs, -1))
    step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1))
    step_data["rewards"] = clip_rewards_fn(rewards)
    step_data["rewards"] = clip_rewards_fn(rewards)
    dones_idxes = dones.nonzero()[0].tolist()
    reset_envs = len(dones_idxes)
    if reset_envs > 0:
        reset_data = {}
        for k in obs_keys:
            reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis]
        reset_data["dones"] = np.ones((1, reset_envs, 1))
        reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
        reset_data["rewards"] = step_data["rewards"][:, dones_idxes]
        reset_data["is_first"] = np.zeros_like(reset_data["dones"])
        rb_initial.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)

        # Reset already inserted step data
        step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"])
        step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes])
        step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes])
        player.init_states(dones_idxes)

    ## Save the recurrent and stochastic latent states for the imagination phase
    if i == initial_steps - imagination_steps:
        stochastic_state = player.stochastic_state.clone()
        recurrent_state = player.recurrent_state.clone()
## Imagination and Reconstruction

# deciede if you want to take the actions from the buffer
# (i.e., the actions actually played by the agent)
# or imagine them and compare with the actions actually played by the agent
imagine_actions = True
# imagination / reconstruction obs process
imagined_latent_states = torch.cat((stochastic_state, recurrent_state), -1)
step_data = {}
reconstruced_step_data = {}

imagined_trajectories = [imagined_latent_states]
imagined_actions = []
with torch.no_grad():
    for i in range(imagination_steps):
        if imagine_actions:
            # imagined actions
            actions = actor(imagined_latent_states.detach())[0][0]
            imagined_actions.append(actions)
        else:
            # actions actually played by the agent
            actions = torch.tensor(
                rb_initial["actions"][-imagination_steps + i],
                device=fabric.device,
                dtype=torch.float32,
            )[None]
            imagined_actions.append(actions)

        # imagination step
        stochastic_state, recurrent_state = world_model.rssm.imagination(stochastic_state, recurrent_state, actions)
        stochastic_state = stochastic_state.view(1, 1, -1)
        # update current state
        imagined_latent_states = torch.cat((stochastic_state, recurrent_state), -1)
        imagined_trajectories.append(imagined_latent_states)

        rec_obs = world_model.observation_model(imagined_latent_states)
        step_data["rgb"] = rec_obs["rgb"].unsqueeze(0).detach().cpu().numpy() + 0.5
        step_data["actions"] = actions.unsqueeze(0).detach().cpu().numpy()
        rb_imagination.add(step_data)

        # reconstruct the observations from the latent states used when interacting with the environment
        played_latent_states = torch.cat(
            (
                torch.tensor(rb_initial["stochastic_state"][-imagination_steps + i], device=fabric.device),
                torch.tensor(rb_initial["recurrent_state"][-imagination_steps + i], device=fabric.device),
            ),
            -1,
        )
        rec_obs_played = world_model.observation_model(played_latent_states)
        # The decoder has been trained to reconstruct the observations from the latent states in the range [-0.5, 0.5]
        # NOTE: Check how the observations are handled in older versions of SheepRL (before 0.5.5)
        # if you need to add 0.5 or not (in latest versions it is done automatically by the decoder in its forward method).
        reconstruced_step_data["rgb"] = rec_obs_played["rgb"].unsqueeze(0).detach().cpu().numpy() + 0.5
        rb_play.add(reconstruced_step_data)
ensembles_input = torch.cat((imagined_trajectories[1], imagined_actions[1]), -1)

Seed set to 42


0.0006449965294450521
0.00018769799498841166
0.0002079825644614175
0.0003028228529728949
0.0003159301704727113
0.0002625550841912627
0.00023666485503781587
0.00021871173521503806
0.0002210103120887652
0.00022103193623479456
0.00021067471243441105
0.00022386792988982052
0.00020755603327415884
0.00028745405143126845
0.00022042490309104323
0.00021592702250927687
0.00042325197136960924
0.00035299500450491905
0.00033605776843614876
0.00017703816411085427
0.00020869733998551965
0.00019944040104746819
0.00017044864944182336
0.00015068605716805905
0.00017090157780330628
0.00014473972260020673
0.0001942969101946801
0.0003904083860106766
0.0002764298114925623
0.00022933899890631437
0.00016008209786377847
0.0001863927027443424
0.00023655610857531428
0.00034613063326105475
0.000315347861032933
0.0003102956688962877
0.000256091560004279
0.0003233199822716415
0.00027932034572586417
0.00023034510377328843
0.0001903433003462851
0.00023273623082786798
0.00019829437951557338
0.0002154249232262373
0.0002

In [14]:
avg_ld/initial_steps

0.00016893557195241254

In [16]:
avg_ld/initial_steps

0.0002545898898097221

In [4]:
from sheeprl.algos.p2e_dv3.agent import build_agent
state_ens = fabric.load(r"C:\Users\user\Documents\MasterThesis\sheep RL\sheeprl-main\notebooks\checkpoints_ensembles\ckpt_39999_0.ckpt")

In [5]:
# ens_list = []
# cfg_ensembles = cfg.algo.ensembles
# ensembles_ln_cls = hydra.utils.get_class(cfg_ensembles.layer_norm.cls)
# with isolate_rng():
#     for i in range(4):
#         fabric.seed_everything(cfg.seed + i)
#         ens_list.append(
#             MLP(
#                 input_dims=int(
#                     sum(actions_dim)
#                     + cfg.algo.world_model.recurrent_model.recurrent_state_size
#                     + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size
#                 ),
#                 output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size,
#                 hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers,
#                 activation=hydra.utils.get_class(cfg_ensembles.dense_act),
#                 flatten_dim=None,
#                 layer_args={"bias": ensembles_ln_cls == nn.Identity},
#                 norm_layer=ensembles_ln_cls,
#                 norm_args={
#                     **cfg_ensembles.layer_norm.kw,
#                     "normalized_shape": cfg_ensembles.dense_units,
#                 },
#             ).apply(init_weights)
#         )
ensembles = nn.ModuleList(ens_list)
if state_ens["ensembles"]:
    ensembles.load_state_dict(state_ens["ensembles"])
    print("hey")
for i in range(len(ensembles)):
    ensembles[i] = fabric.setup_module(ensembles[i])

hey


In [64]:
state_ens.keys()

dict_keys(['ensembles', 'ensemble_optimizer', 'iter_num', 'batch_size'])

In [134]:
is_continuous

False

In [128]:
import cv2
image = obs["rgb"]

image = image.squeeze().transpose(1, 2, 0)

salt_prob = 0.3  # 2% of the pixels will be salt noise
pepper_prob = 0.05
image = add_salt_and_pepper_noise(image, salt_prob, pepper_prob)
# Since the image might be in the range [0, 1], convert it to [0, 255]
#image = (image * 255).astype(np.uint8)

# Display the image using OpenCV
cv2.imshow('Image', image)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [129]:
new_obs = {"rgb": image.reshape(1,3,64,64)}

In [130]:
preprocessed_obs = {}
for k, v in new_obs.items():
    preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=fabric.device)
    if k in cfg.algo.cnn_keys.encoder:
        preprocessed_obs[k] = preprocessed_obs[k] / 255.0 - 0.5
mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
if len(mask) == 0:
    mask = None
real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)

In [131]:
no_noise_rec_state = player.recurrent_state

In [132]:
player.stochastic_state
player.recurrent_state

imagined_latent_states = torch.cat((player.stochastic_state, player.recurrent_state), -1).to("cuda")
test_input = torch.cat((imagined_latent_states, real_actions[0]), -1)

In [133]:
with torch.no_grad():

# Latent disagreement evaluation

    salt_prob = 0.55  # 2% of the pixels will be salt noise
    pepper_prob = 0.55  # 2% of the pixels will be pepper noise
    # ensembles_input = torch.cat((imagined_trajectories[45], imagined_actions[45]), -1)
    #ensembles_input = test_input
    # ensembles_input = add_gaussian_noise(
    #     ensembles_input.cpu(),
    #     mean=0,
    #     std=0.000000000005
    #     ).to("cuda")


    #ensembles_input = torch.Tensor(add_salt_and_pepper_noise(ensembles_input.cpu().numpy(), salt_prob, pepper_prob)).to("cuda")
   # ensembles_input = ensembles_input.type(torch.float32)
    next_state_predictions = []

    #ensembles_input[0][0][0] = 10000.0
    for ens in ensembles:
        next_state_predictions.append(
            ens(
                test_input
            )
        )
    
    print(torch.stack(next_state_predictions).var(0).mean(-1, keepdim=True))

tensor([[[0.0004]]], device='cuda:0')


In [61]:
0.0002

0.0291

In [164]:
next_state_predictions

[tensor([[[ 0.0121, -0.0126, -0.0078,  ..., -0.0403,  0.2224,  0.0054]]],
        device='cuda:0'),
 tensor([[[ 0.0007, -0.0040,  0.0119,  ..., -0.0925,  0.1733,  0.0109]]],
        device='cuda:0'),
 tensor([[[-0.0012, -0.0083, -0.0115,  ..., -0.0630,  0.2246, -0.0140]]],
        device='cuda:0'),
 tensor([[[ 0.0042,  0.0216,  0.0047,  ..., -0.0007,  0.2231,  0.0054]]],
        device='cuda:0')]

In [80]:
ensembles_input[0][0][1000:]

tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.9204e-01,
        -2.6577e-02,  1.1485e-01, -2.9715e-01,  2.8504e-01,  9.7372e-01,
        -1.2104e-01,  9.2868e-02,  3.4227e-01, -1.8017e-01,  3.1242e-01,
        -9.2735e-01,  1.4306e-01,  2.5604e-01,  4.7433e-02, -8.8480e-01,
        -1.4090e-02, -6.4023e-02, -7.4501e-03, -1.1800e-01, -3.4504e-01,
         4.4829e-01, -2.4832e-03, -2.7185e-01, -1.2636e-01, -1.9267e-01,
         9.1845e-01, -4.2993e-02, -9.7375e-01,  8.3756e-02, -7.4939e-02,
        -2.0845e-01,  7.6471e-01, -1.2103e-01,  9.0938e-02,  1.8896e-01,
        -8.5878e-01, -1.4089e-02, -2.2358e-01,  2.5581e-01,  2.2538e-01,
        -2.4724e-02, -4.7102e-02,  7.9861e-04,  4.6