In [1]:
#auto reload modules
%load_ext autoreload
%autoreload 2

In [2]:
import os.path as osp
import random
from typing import Dict,List

import gym.spaces as spaces
import hydra
import numpy as np
import torch
from hydra.utils import instantiate as hydra_instantiate
from omegaconf import DictConfig
from rl_utils.envs import create_vectorized_envs
from rl_utils.logging import Logger
from tensordict.tensordict import TensorDict
from torchrl.envs.utils import step_mdp
from typing import Tuple
from imitation_learning.common.evaluator import Evaluator

In [3]:
from snake_env import SnakeEnv 

In [4]:
import yaml 
cfg = yaml.load(open("bc-irl-snake.yaml", 'r'), Loader=yaml.SafeLoader)
cfg = DictConfig(cfg)

In [5]:
cfg

{'env_settings': {}, 'obs_shape': '???', 'action_dim': '???', 'total_num_updates': '???', 'action_is_discrete': '???', 'num_steps': 50, 'num_envs': 256, 'device': 'cpu', 'only_eval': False, 'seed': 3, 'num_eval_episodes': 100, 'num_env_steps': 30000000, 'recurrent_hidden_state_size': 128, 'gamma': 0.8, 'log_interval': 10, 'eval_interval': 500, 'save_interval': 10000000000, 'load_checkpoint': None, 'load_policy': True, 'resume_training': False, 'policy': {'_target_': 'imitation_learning.policy_opt.policy.Policy', 'hidden_size': 512, 'recurrent_hidden_size': 128, 'is_recurrent': False, 'obs_shape': '${obs_shape}', 'action_dim': '${action_dim}', 'action_is_discrete': '${action_is_discrete}', 'std_init': 0, 'num_envs': '${num_envs}'}, 'policy_updater': {'_target_': 'imitation_learning.bcirl.BCIRL', '_recursive_': False, 'use_clipped_value_loss': True, 'clip_param': 0.2, 'value_loss_coef': 0.5, 'entropy_coef': 0.0001, 'max_grad_norm': 0.5, 'num_epochs': 2, 'num_mini_batch': 4, 'num_envs': '

In [6]:

def set_seed(seed: int) -> None:
    """
    Sets the seed for numpy, python random, and pytorch.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

In [7]:
from gymnasium import Env
class vectorized_env():
    def __init__(self, envs : List[Env]):
        self.envs = envs
        self.num_envs = len(self.envs)
        self.observation_space = self.envs[0].observation_space
        self.action_space = self.envs[0].action_space

    def reset(self):

        return torch.tensor([env.reset()[0].tolist() for env in self.envs],dtype=torch.float32)
    
    def step(self, action) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Dict]]:
        steps = [env.step(action[i]) for i,env in enumerate(self.envs)]
        return_value = (torch.tensor([step[0].tolist() for step in steps],dtype=torch.float32),
                torch.tensor([step[1] for step in steps],dtype=torch.float32),
                torch.tensor([step[2] for step in steps],dtype=torch.bool),
                [step[3] for step in steps])
        return return_value
    
    def render(self, mode = "rgb_array"):
        if(self.envs[0].render_mode == "rgb_array"):
            return [env.render(mode) for env in self.envs]
        else:
            self.envs[0].render(mode)

In [8]:
set_seed(cfg.seed)

device = torch.device(cfg.device)

# Setup the environments
set_env_settings = {
    k: hydra_instantiate(v) if isinstance(v, DictConfig) else v
    for k, v in cfg.env.env_settings.items()
}
envs = vectorized_env([SnakeEnv(cfg.env.env_settings.params.config) for _ in range(cfg.num_envs)])

steps_per_update = cfg.num_steps * cfg.num_envs

num_updates = int(cfg.num_env_steps) // steps_per_update

# Set dynamic variables in the config.
cfg.obs_shape = envs.observation_space.shape
cfg.action_dim = envs.action_space.shape[0]
cfg.action_is_discrete = isinstance(cfg.action_dim, spaces.Discrete)
cfg.total_num_updates = num_updates

logger: Logger = hydra_instantiate(cfg.logger, full_cfg=cfg)
print("policy",cfg.policy)
policy = hydra_instantiate(cfg.policy)
policy = policy.to(device)
print("policy_updater",cfg.policy_updater)
updater = hydra_instantiate(cfg.policy_updater, policy=policy, device=device).to(device)


Assigning full prefix 65-3-aJmVtQ
policy {'_target_': 'imitation_learning.policy_opt.policy.Policy', 'hidden_size': 512, 'recurrent_hidden_size': 128, 'is_recurrent': False, 'obs_shape': [4], 'action_dim': 4, 'action_is_discrete': False, 'std_init': 0, 'num_envs': 256}
policy_updater {'_target_': 'imitation_learning.bcirl.BCIRL', '_recursive_': False, 'use_clipped_value_loss': True, 'clip_param': 0.2, 'value_loss_coef': 0.5, 'entropy_coef': 0.0001, 'max_grad_norm': 0.5, 'num_epochs': 2, 'num_mini_batch': 4, 'num_envs': '${num_envs}', 'num_steps': '${num_steps}', 'gae_lambda': 0.95, 'gamma': '${gamma}', 'optimizer_params': {'_target_': 'torch.optim.Adam', 'lr': 0.0003}, 'batch_size': 256, 'plot_interval': '${eval_interval}', 'norm_expert_actions': False, 'n_inner_iters': 1, 'reward_update_freq': 1, 'device': '${device}', 'total_num_updates': '${total_num_updates}', 'use_lr_decay': False, 'get_dataset_fn': {'_target_': 'imitation_learning.common.utils.get_transition_dataset', 'dataset_pa

In [9]:

start_update = 0
if cfg.load_checkpoint is not None:
    # Load a checkpoint for the policy/reward. Also potentially resume
    # training.
    ckpt = torch.load(cfg.load_checkpoint)
    updater.load_state_dict(ckpt["updater"], should_load_opt=cfg.resume_training)
    if cfg.load_policy:
        policy.load_state_dict(ckpt["policy"])
    if cfg.resume_training:
        start_update = ckpt["update_i"] + 1

eval_info = {"run_name": logger.run_name}


In [10]:
import warnings 
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")

In [11]:




# Storage for the rollouts
obs = envs.reset()
td = TensorDict({"observation": obs}, batch_size=[cfg.num_envs])

# Storage for the rollouts
storage_td = TensorDict({}, batch_size=[cfg.num_envs, cfg.num_steps], device=device)

for update_i in range(start_update, num_updates):
    is_last_update = update_i == num_updates - 1
    for step_idx in range(cfg.num_steps):

        # Collect experience.
        with torch.no_grad():
            policy.act(td)
        next_obs, reward, done, infos = envs.step(td["action"])

        td["next_observation"] = next_obs
        for env_i, info in enumerate(infos):
            if "final_obs" in info:
                td["next_observation"][env_i] = info["final_obs"]
        td["reward"] = reward.reshape(-1, 1)
        td["done"] = done
    
        storage_td[:, step_idx] = td
        # Log to CLI/wandb.
        logger.collect_env_step_info(infos)
    
    # Call method specific update function
    updater.update(policy, storage_td, logger, envs=envs)



    if cfg.log_interval != -1 and (
        update_i % cfg.log_interval == 0 or is_last_update
    ):
        logger.interval_log(update_i, steps_per_update * (update_i + 1))
        height = 2
        width = 2
        eval_env = SnakeEnv(cfg.env.env_settings.params.config)
        fig, ax = plt.subplots(nrows=height, ncols=width, sharex=True, sharey=True, gridspec_kw={'wspace': 0, 'hspace': 0})

        last_reward_map = np.zeros((eval_env.screen_width//eval_env.block_size, eval_env.screen_height//eval_env.block_size))

        for i in range(height):
            for j in range(width):
                
                reward_map = np.zeros((eval_env.screen_width//eval_env.block_size, eval_env.screen_height//eval_env.block_size))
                apple_pos = eval_env.reset()[0][:2]
                #test what you got so far by plotting a heat map of the reward using the snake only 
                for x in range(eval_env.screen_width//eval_env.block_size):
                    for y in range(eval_env.screen_height//eval_env.block_size ):
                        x_grid = x * eval_env.block_size / eval_env.screen_width
                        y_grid = y * eval_env.block_size / eval_env.screen_height
                        reward_map[x,y] = updater.reward(next_obs = torch.tensor([*apple_pos,x_grid,y_grid]+[0]*(eval_env.observation_space.shape[0]-4),dtype=torch.float32).to(device).view(1,1,-1))
                # for x in range(eval_env.screen_width//eval_env.block_size):
                #     for y in range(eval_env.screen_height//eval_env.block_size ):
                #         print(f"{reward_map[x,y]:.2f}", end=" ")
                #     print()
                    
                # Define the color map
                cmap = plt.cm.get_cmap('hot')

                # Plot the reward map without axis and numbers
                image = ax[i,j].imshow(reward_map, cmap=cmap, interpolation='nearest')
                ax[i,j].axis('off')

                # Plot the apple
                ax[i,j].scatter(
                    apple_pos[1] * eval_env.screen_height // eval_env.block_size,
                    apple_pos[0] * eval_env.screen_width // eval_env.block_size,
                    c='blue',
                    s=60
                )
                # map_diff = reward_map - last_reward_map
                # for x in range(eval_env.screen_width//eval_env.block_size):
                #     for y in range(eval_env.screen_height//eval_env.block_size ):
                #          print(f"{reward_map[x,y]:.2f}", end=" ")
                #     print()
                    
                # print("reward_maps diff " ,np.linalg.norm(reward_map-last_reward_map))
                last_reward_map = reward_map
        plt.tight_layout()
        plt.savefig(osp.join(logger.save_path, f"reward_map.{update_i}.png"))
        print(f"Saved to {osp.join(logger.save_path, f'reward_map.{update_i}.png')}")
        # for x in range(eval_env.screen_width//eval_env.block_size):
        #     for y in range(eval_env.screen_height//eval_env.block_size ):
        #         print(f"{reward_map[x,y]:.2f}", end=" ")
        #     print()

    if cfg.save_interval != -1 and (
        (update_i + 1) % cfg.save_interval == 0 or is_last_update
    ):
        save_name = osp.join(logger.save_path, f"ckpt.{update_i}.pth")
        torch.save(
            {
                "policy": policy.state_dict(),
                "updater": updater.state_dict(),
                "update_i": update_i,
            },
            save_name,
        )
        print(f"Saved to {save_name}")
        eval_info["last_ckpt"] = save_name

logger.close()
print(eval_info)


Updates 0, Steps 12800, FPS 2800
Over the last 10 episodes:
    - episode.reward: -0.15801115251422143
    - episode.score: 0.0
    - episode.distance_to_goal: 0.4444028851440428
    - inferred_episode_reward: 0.7712286368012429
    - value_loss: 0.29877109453082085
    - action_loss: -0.010995903692673892
    - dist_entropy: 5.675752639770508
    - irl_loss: 0.22453956305980682

Updates 10, Steps 140800, FPS 4365
Over the last 10 episodes:
    - episode.reward: -0.2522793998092183
    - episode.score: 0.0
    - episode.distance_to_goal: 0.377349639981784
    - inferred_episode_reward: 11.18510410785675
    - value_loss: 14.09570255279541
    - action_loss: -0.0020032030995935203
    - dist_entropy: 5.676194190979004
    - irl_loss: 0.22218220978975295

Updates 20, Steps 268800, FPS 4380
Over the last 10 episodes:
    - episode.reward: -0.2558032473069802
    - episode.score: 0.0
    - episode.distance_to_goal: 0.4473079018129836
    - inferred_episode_reward: 6.824732875823974
    - 

In [None]:
ckpt = torch.load(save_name)
updater.load_state_dict(ckpt["updater"], should_load_opt=cfg.resume_training)
policy.load_state_dict(ckpt["policy"])


<All keys matched successfully>

In [None]:
import pygame

In [None]:
cfg.env.env_settings.params.config["render_mode"] = "human"

In [None]:
envs = vectorized_env([SnakeEnv(cfg.env.env_settings.params.config) for _ in range(cfg.num_envs)])

while True:

    obs = envs.reset()
    td = TensorDict({"observation": obs}, batch_size=cfg.num_envs)
    terminated = False 
    while not terminated : 
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()

        with torch.no_grad():
            policy.act(td)
        next_obs, reward, done, infos = envs.step(td["action"])
        envs.render(mode="human")
        td["next_observation"] = next_obs
        td["reward"] = reward.reshape(-1, 1)

        td["done"] = done

        td["observation"] = next_obs
        terminated = done[0]
        # Log to CLI/wandb.
        logger.collect_env_step_info(infos)
        pygame.time.wait(100)

error: display Surface quit