In [1]:
import torch
import numpy as np
import pickle
import os
import tqdm

from imitation.algorithms.bc import BC
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.rewards import reward_wrapper
from imitation.util import networks
from imitation.scripts.train_adversarial import save as save_trainer
from imitation.data.types import Trajectory

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.ppo import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv


from Environment import MultidatasetEnvironment

cyber_security_datasets = [f'Datasets/Cyber_security/{i}.tsv' for i in range(1,5)]
flight_delay_datasets = [f'Datasets/Flight_delay/{i}.tsv' for i in range(1,5)]


def train_BC(expert_trajectories_path, datasets):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    with open(expert_trajectories_path, 'rb') as f:
        expert_trajectories = pickle.load(f)

    print(f"Loaded {len(expert_trajectories)} expert trajectories")

    env = MultidatasetEnvironment(datasets)
    
    policy = ActorCriticPolicy(
        env.observation_space, 
        env.action_space, 
        lr_schedule=lambda x: 1e-4, 
        net_arch=dict(pi=[100, 100, 100], vf=[100, 100, 100])
    )
    policy = policy.to(device)

    rng = np.random.default_rng()

    trainer = BC(
        observation_space=env.observation_space, 
        action_space=env.action_space,
        policy=policy,
        demonstrations=expert_trajectories,
        rng=rng,
        device=device
    )

    try:
        trainer.train(n_epochs=100)
    except RuntimeError as e:
        print(f"Training error: {str(e)}")
        # Additional debugging information
        print("\nDebugging information:")
        print(f"Policy device: {next(trainer.policy.parameters()).device}")
        for name, param in trainer.policy.named_parameters():
            print(f"{name}: {param.device}")
            
    # Save the trained policy
    save_path = 'BC_Initialization/cyber_security'
    os.makedirs(save_path, exist_ok=True)
    torch.save(trainer.policy.state_dict(), f'{save_path}/policy.pth')
    print('Initialization parameters saved to', f'{save_path}/policy.pth')


def train_GAIL():
    total_steps = 2000000
    n_gen_updates_per_round = 1
    
    venv = DummyVecEnv([lambda:MultidatasetEnvironment(dataset_paths=cyber_security_datasets, max_steps= 12) for _ in range(8)])

    # Generator
    policy_kwargs = dict(net_arch = dict(pi = [100, 100, 100], vf = [100, 100, 100]))
    learner = PPO(ActorCriticPolicy, venv, n_steps=48 , batch_size=32, policy_kwargs=policy_kwargs)

    # Descrminator
    env = MultidatasetEnvironment(cyber_security_datasets, 12)
    reward_net = BasicRewardNet(env.observation_space, env.action_space)

    # Load intialized parametrs to generator
    bc_policy = torch.load(f'BC_Initialization/cyber_security/policy.pth')
    learner.policy.load_state_dict(bc_policy)

    # Get expert_trajectories
    expert_trajectories_path = 'Expert_trajectories/train/1.pkl'
    with open(expert_trajectories_path, 'rb') as f:
        expert_trajectories = pickle.load(f)

    trainer = GAIL(
        demonstrations = expert_trajectories,
        demo_batch_size= 96*2,
        gen_replay_buffer_capacity= 768,
        n_disc_updates_per_round= 2,
        venv= venv,
        gen_algo=learner,
        reward_net=reward_net,
        allow_variable_horizon=True,
    )

    print(trainer.gen_train_timesteps)

    def get_custom_reward_function(trainer):
        venv = trainer.venv_wrapped

        def repeat_penalty(old_obs, acts, obs, dones):
            return np.array([env.repeat_penalty() for env in venv.envs])

        def custom_reward_function(old_obs, acts, obs, dones):
            r1 = trainer.reward_train.predict_processed(old_obs, acts, obs, dones)
            r2 = np.array(repeat_penalty(old_obs, acts, obs, dones))
            return r1 + r2

        return custom_reward_function

    # r = get_custom_reward_function(trainer)
    # trainer.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
    #     trainer.venv_buffering,
    #     reward_fn=r
    # )
    # trainer.venv_train = trainer.venv_wrapped
    # trainer.gen_algo.set_env(trainer.venv_train)

    n_rounds = total_steps // (n_gen_updates_per_round * trainer.gen_train_timesteps)
    assert n_rounds >= 1, (
        "No updates (need at least "
        f"{trainer.gen_train_timesteps} timesteps, have only "
        f"total_timesteps={total_steps})!"
    )

    # Training loop
    for i in tqdm.tqdm(range(0, n_rounds), desc = "round"):
        for _ in range(n_gen_updates_per_round):
            trainer.train_gen(trainer.gen_train_timesteps)
        for _ in range(trainer.n_disc_updates_per_round):
            with networks.training(trainer.reward_train):
                trainer.train_disc()

    save_trainer(trainer, f'Final_Parameters/cyber_security.GAILAgent')    


if __name__ == '__main__':
    train_BC('Expert_trajectories/train/1.pkl', cyber_security_datasets)
    # train_GAIL()



Using device: cuda
Loaded 72 expert trajectories


0batch [00:00, ?batch/s]

Training error: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)

Debugging information:
Policy device: cuda:0
mlp_extractor.policy_net.0.weight: cuda:0
mlp_extractor.policy_net.0.bias: cuda:0
mlp_extractor.policy_net.2.weight: cuda:0
mlp_extractor.policy_net.2.bias: cuda:0
mlp_extractor.policy_net.4.weight: cuda:0
mlp_extractor.policy_net.4.bias: cuda:0
mlp_extractor.value_net.0.weight: cuda:0
mlp_extractor.value_net.0.bias: cuda:0
mlp_extractor.value_net.2.weight: cuda:0
mlp_extractor.value_net.2.bias: cuda:0
mlp_extractor.value_net.4.weight: cuda:0
mlp_extractor.value_net.4.bias: cuda:0
action_net.weight: cuda:0
action_net.bias: cuda:0
value_net.weight: cuda:0
value_net.bias: cuda:0
Initialization parameters saved to BC_Initialization/cyber_security/policy.pth



