In [None]:
# import sys
# import gymnasium as gym
# sys.modules["gym"] = gym
import gym
gym.__version__

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback

In [None]:
import numpy as np
import torch as th
from torch import nn
import matplotlib.pyplot as plt

In [None]:
import gnwrapper

In [None]:
# Notebook-wide configuration: the Gym environment id and the degree of parallelism.
env_id = "CarRacing-v0"
NUM_CPU = 4  # Number of processes to use

In [None]:
from utils import CarRacingGroundTruthObsWrapper
def wrapper(env):
    """Stack the two outer wrappers used for expert rollouts.

    The environment is first passed to ``CarRacingGroundTruthObsWrapper``
    (project-local, imported from ``utils``) and the result is then passed to
    ``gnwrapper.Animation``.

    NOTE(review): presumably the first wrapper swaps the observation for the
    one the expert was trained on and the second enables notebook rendering —
    confirm against ``utils`` and the gnwrapper documentation.

    Args:
        env: The environment to wrap.

    Returns:
        The environment wrapped by both layers, ``Animation`` outermost.
    """
    ground_truth_env = CarRacingGroundTruthObsWrapper(env)
    return gnwrapper.Animation(ground_truth_env)

In [None]:
# Load the previously trained PPO expert from the local `policy/` directory.
# `print_system_info=True` is forwarded to `PPO.load`
# (presumably it prints version/system details of the saved model — see the SB3 docs).
expert = PPO.load("./policy/ppo_CarRacing_expert-1kk.zip", print_system_info=True)

In [None]:
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
import dataclasses
from stable_baselines3.common.vec_env import VecTransposeImage

class TransformaObservacions(gym.ObservationWrapper):
    """Observation wrapper that turns channel-last images into channel-first.

    The wrapped environment's observation space is read as a 3-D shape
    ``(rows, cols, channels)`` and re-declared as ``(channels, rows, cols)``.
    Each observation is converted by moving the channel axis to the front.
    """

    def __init__(self, env):
        """Wrap ``env`` and declare the channel-first observation space.

        Args:
            env: Environment whose ``observation_space.shape`` has the channel
                axis at index 2.
        """
        super().__init__(env)

        # Axis 0 and axis 1 are the two spatial axes, axis 2 is the channel axis.
        n_rows = self.observation_space.shape[0]
        n_cols = self.observation_space.shape[1]
        n_channels = self.observation_space.shape[2]

        self.new_obs_shape = (n_channels, n_rows, n_cols)

        # NOTE(review): no dtype is passed, so Box falls back to its default
        # dtype — confirm this matches the dtype of the wrapped env's images.
        self.observation_space = gym.spaces.Box(shape=self.new_obs_shape, low=0, high=255)

    def observation(self, observ):
        """Return ``observ`` with its channel axis moved to the front.

        A true axis permutation is used instead of a Fortran-order reshape:
        the reshape only coincides with an axis move when there is exactly one
        channel and scrambles pixels otherwise. For single-channel input the
        values produced are the same as before.

        Args:
            observ: Array laid out as ``(rows, cols, channels)``.

        Returns:
            The same data laid out as ``(channels, rows, cols)``.
        """
        return np.transpose(observ, (2, 0, 1))

def make_env(env_id):
    """Return a zero-argument factory that builds one wrapped environment.

    Nothing is constructed until the returned callable is invoked, which is
    the form expected by vectorized-environment constructors.

    Args:
        env_id: Gym id of the environment to create.

    Returns:
        A callable that creates the environment and applies, innermost first:
        grayscale conversion (keeping the channel dimension),
        ``TransformaObservacions``, ``RolloutInfoWrapper`` and ``wrapper``.
    """
    def _build():
        base_env = gym.make(env_id)
        gray_env = gym.wrappers.gray_scale_observation.GrayScaleObservation(base_env, keep_dim=True)
        channel_first_env = TransformaObservacions(gray_env)
        # Wrapper to save the original observations.
        recorded_env = RolloutInfoWrapper(channel_first_env)
        # Observation wrapper applied last, so it is the outermost layer.
        return wrapper(recorded_env)

    return _build

# env = gym.wrappers.gray_scale_observation.GrayScaleObservation(env, keep_dim=True)
# env = RolloutInfoWrapper(env) # Wrapper to save origin obs
# env = wrapper(env) # Wrapper Obs
# env = DummyVecEnv([lambda: env]) # Vectorized env

# Vectorize NUM_CPU environments built from the same factory; the factory is
# called once per list entry, so each entry yields its own environment.
# NOTE(review): DummyVecEnv is used rather than SubprocVecEnv, so NUM_CPU
# presumably sets the number of environment copies, not OS processes — confirm.
env = DummyVecEnv([make_env(env_id)]*NUM_CPU)

In [None]:
NUM_EPISODES = 100
# Seed the generator so that everything drawing from `rng` (the rollout call
# below and the BC trainer that reuses it) is repeatable after a kernel restart.
# NOTE(review): the environments themselves are not seeded here, so episodes
# may still differ between runs — confirm whether that matters.
RNG_SEED = 0
rng = np.random.default_rng(RNG_SEED)

# Collect expert demonstrations until at least NUM_EPISODES episodes are done.
# `unwrap=True` presumably restores the observations recorded by
# RolloutInfoWrapper instead of the wrapped ones — confirm in the imitation docs.
rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=NUM_EPISODES),
    rng=rng,
    unwrap=True,
)

# Flatten the list of trajectories into a single transitions object for BC.
transitions = rollout.flatten_trajectories(rollouts)

In [None]:
# Report how many trajectories were collected, how many transitions they
# flatten to, and which fields the transitions object carries.
# (A stray double quote after the final period used to be printed verbatim.)
print(
    f"""The `rollout` function generated a list of {len(rollouts)} {type(rollouts[0])}.
After flattening, this list is turned into a {type(transitions)} object containing {len(transitions)} transitions.
The transitions object contains arrays for: {', '.join(transitions.__dict__.keys())}.
"""
)

In [None]:
from imitation.algorithms import bc
from stable_baselines3.common.policies import ActorCriticCnnPolicy

# Build a fresh single-environment pipeline for training/evaluating the BC policy.
# NOTE: this rebinds `env`, replacing the vectorized rollout environment created earlier.
env = gym.make(env_id)
env = gym.wrappers.gray_scale_observation.GrayScaleObservation(env, keep_dim=True)
env = gnwrapper.Animation(env)
# The lambda looks `env` up when it is called, not when it is defined.
# NOTE(review): this relies on DummyVecEnv calling it before `env` is rebound
# to the vectorized env on this same line — confirm.
env = DummyVecEnv([lambda: env])
# NOTE(review): presumably moves image channels first, matching the layout of
# the collected demonstrations — confirm.
env = VecTransposeImage(env)

# Behavioural-cloning trainer: learns from `transitions` with a CNN actor-critic policy.
bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
    batch_size=256,
    policy=ActorCriticCnnPolicy(
                         observation_space=env.observation_space,
                         action_space=env.action_space,
                         # NOTE(review): the schedule always returns the largest float32;
                         # looks like a placeholder that BC's own optimizer settings
                         # are expected to override — confirm.
                         lr_schedule=lambda _: th.finfo(th.float32).max,
                         # net_arch=dict(pi=[256], vf=[256]),
                         net_arch=[256, 256],
                         activation_fn = nn.LeakyReLU,
                         ortho_init=False,
                         )
)

In [None]:
# Fit the behavioural-cloning policy on the demonstrations for 10 epochs.
bc_trainer.train(n_epochs=10)

In [None]:
# Play one episode with the trained BC policy, rendering every step.
obs = env.reset()
episode_finished = False
while not episode_finished:
    action, _states = bc_trainer.policy.predict(obs.copy())
    obs, rewards, dones, info = env.step(action)
    env.render()
    # `dones` comes from the vectorized env; its truth value ends the episode.
    episode_finished = bool(dones)

# NOTE(review): `env` is closed here but is passed to `evaluate_policy` in a
# later cell — confirm the environment can still be used after `close()`.
env.close()

In [None]:
from stable_baselines3.common.evaluation import evaluate_policy

In [None]:
# Evaluate the BC policy over 5 episodes; the first returned value is reported.
# NOTE(review): `env.close()` was already called in the preceding cell — confirm
# the environment is still usable here.
reward, reward_std = evaluate_policy(bc_trainer.policy, env, n_eval_episodes=5)
print(f"BC reward: {reward}")