First, do the basic setup of the environment, trajectory generator, expert, etc.

In [1]:
# Toggle between a quick demo run (False) and a real training run (True).
# When True, later cells use 8 parallel environments and sample 1000 steps
# instead of 100.
real_training = False

In [6]:
import torch as th
import gym
from gym.wrappers import TimeLimit
import numpy as np

from seals.util import AutoResetWrapper

from stable_baselines3 import PPO
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.ppo import CnnPolicy

from imitation.algorithms import preference_comparisons
from imitation.policies.base import NormalizeFeaturesExtractor
from imitation.rewards.reward_nets import CnnRewardNet

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")

# Shared NumPy random generator; it is created without a seed.
rng = np.random.default_rng()

def constant_length_asteroids(num_steps):
    """Build an Asteroids environment with constant-length episodes.

    The environment is reset automatically whenever it reports done, and the
    outer time limit ends the episode once ``num_steps`` timesteps have
    elapsed. For real training, you will want a much longer time limit
    than 100.
    """
    env = gym.make('AsteroidsNoFrameskip-v4')
    env = AtariWrapper(env)  # Atari preprocessing
    env = AutoResetWrapper(env)  # reset on done instead of ending the episode
    return TimeLimit(env, max_episode_steps=num_steps)

# For real training, you will want a vectorized environment with 8 environments
# in parallel; the quick demo run keeps a single environment.
venv = make_vec_env(
    constant_length_asteroids,
    n_envs=8 if real_training else 1,
    env_kwargs={"num_steps": 100},
)
# Stack the 4 most recent frames into each observation.
venv = VecFrameStack(venv, n_stack=4)

# Reward network built from the vectorized environment's observation and action
# spaces, then moved to the selected device.
# NOTE(review): this is constructed before PPO(seed=0) below; keep that order in
# case the PPO seed affects global RNG state — confirm.
reward_net = CnnRewardNet(
    venv.observation_space,
    venv.action_space,
).to(device)

# Note that for trajectory encoding we use TotalFragmenter so as to get all possible fragments.
# NOTE(review): TotalFragmenter does not appear to be part of upstream imitation;
# presumably a local addition to preference_comparisons — confirm.
fragmenter = preference_comparisons.TotalFragmenter()
# Preference source used for comparisons; shares the notebook-wide rng.
# Presumably derives preferences from the environment's rewards rather than
# from a human — confirm against the imitation docs.
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
# Wraps reward_net so it can be trained from preference comparisons.
preference_model = preference_comparisons.PreferenceModel(reward_net)
# Trainer for the preference model, using a cross-entropy loss and epochs=3.
reward_trainer = preference_comparisons.BasicRewardTrainer(
    preference_model=preference_model,
    loss=preference_comparisons.CrossEntropyRewardLoss(),
    epochs=3,
    rng=rng
)

# PPO agent with a CNN policy acting in venv. The small n_steps / batch_size
# values keep the demo fast; see the inline notes for realistic settings.
# The agent is seeded with 0, whereas rng above is created without a seed.
agent = PPO(
    policy=CnnPolicy,
    env=venv,
    seed=0,
    n_steps=16,  # To train on atari well, set this to 128
    batch_size=16,  # To train on atari well, set this to 256
    ent_coef=0.01,
    learning_rate=0.00025,
    n_epochs=4,
)

# Trajectory generator backed by the PPO agent, with reward_net passed as its
# reward function. Its sample() method is used later to collect trajectories.
# exploration_frac=0.0 — presumably disables extra exploratory sampling; confirm.
trajectory_generator = preference_comparisons.AgentTrainer(
    algorithm=agent,
    reward_fn=reward_net,
    venv=venv,
    exploration_frac=0.0,
    rng=rng
)

# pref_comparisons = preference_comparisons.PreferenceComparisons(
#     trajectory_generator,
#     reward_net,
#     num_iterations=2,
#     fragmenter=fragmenter,
#     preference_gatherer=gatherer,
#     reward_trainer=reward_trainer,
#     fragment_length=10,
#     transition_oversampling=1,
#     initial_comparison_frac=0.1,
#     allow_variable_horizon=False,
#     initial_epoch_multiplier=1,
# )

AttributeError: 'NoneType' object has no attribute 'cvtColor'

Generate a set number of trajectories; from these we then generate all fragments and use preference comparisons to create a total order over them.

In [None]:
# Sample trajectories from the trajectory generator and split them into
# fragments; the fragments are put into a total order further down.
# Real training samples 1000 steps, the quick demo run only 100.
num_steps = 1000 if real_training else 100
trajectories = trajectory_generator.sample(num_steps)
fragments = fragmenter(trajectories)

# Can't just use 'sorted' directly since that requires key-generation, and we only have a
# pairwise comparison (functools.cmp_to_key could adapt one, but is not used here).
# Could easily do bubble or insertion sort, but those are O(n^2). Merge sort is better, but
# requires a little more thought on how to divide (would require enough comparisons that in
# the worst case it is also O(n^2)).

# For the moment we just use bubble sort since it is easy to implement. Improve complexity later.
def bubble_sorted(array, prefers_first=None):
    """Bubble-sort ``array`` in place using pairwise comparisons and return it.

    Adjacent items are compared and swapped whenever the first of the pair is
    preferred over the second, so preferred items move towards the end of the
    list. The sort stops early once a full pass makes no swap.

    Args:
        array: Mutable sequence to sort. It is modified in place.
        prefers_first: Optional callable ``(first, second) -> bool`` returning a
            truthy value when ``first`` should be placed after ``second``. When
            omitted, the module-level ``gatherer`` is queried with
            ``[first, second]`` and the pair is swapped if the first returned
            entry equals 1 (presumably meaning "first is preferred" — confirm
            against the gatherer's documentation).

    Returns:
        The same ``array`` object, now sorted.
    """
    if prefers_first is None:
        # Default keeps the original behavior: ask the global preference gatherer.
        def prefers_first(first, second):
            return gatherer([first, second])[0] == 1

    # After each pass the tail beyond `unsorted_end` is already in final position.
    for unsorted_end in range(len(array) - 1, -1, -1):
        swapped = False
        for j in range(unsorted_end):
            # Feed in the two neighbouring items and compare.
            if prefers_first(array[j], array[j + 1]):
                array[j], array[j + 1] = array[j + 1], array[j]
                swapped = True
        if not swapped:
            # A pass without swaps means the sequence is already ordered.
            break
    return array

# bubble_sorted reorders `fragments` in place and returns that same list, so
# `fragments_total_order` and `fragments` refer to one (now sorted) list.
fragments_total_order = bubble_sorted(fragments)