In [1]:
from gymnasium import spaces
import yaml
import torch
import steelix

from agilerl.modules.configs import MlpNetConfig, CnnNetConfig, MultiInputNetConfig
from agilerl.networks.q_networks import QNetwork, RainbowQNetwork
from agilerl.networks.value_functions import ValueFunction
from agilerl.networks.actors import StochasticActor, DeterministicActor

from agilerl.algorithms.dqn import DQN
from agilerl.utils.utils import create_population, make_vect_envs

  from .autonotebook import tqdm as notebook_tqdm


### QNetwork

In [None]:
from tests.helper_functions import generate_dict_or_tuple_space
from agilerl.utils.evolvable_networks import is_image_space

img_space = spaces.Box(low=0, high=255, shape=(4, 84, 84))
vec_space = spaces.Box(low=-1, high=1, shape=(4,), dtype='float32')
dict_space = spaces.Dict({'img': img_space, 'vec': vec_space})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

img_config = CnnNetConfig(
    channel_size=[16],
    kernel_size=[4],
    stride_size=[1],
)
vec_config = MlpNetConfig(
    hidden_size=[64],
)
multi_input_config = MultiInputNetConfig(
    channel_size=[8],
    kernel_size=[2],
    stride_size=[1],
    hidden_size=[32],
    vector_space_mlp=False
)

actor = StochasticActor(
    observation_space=env.single_observation_space,
    action_space=spaces.Discrete(18),
    encoder_config=multi_input_config,
    # support=torch.linspace(-10, 10, 51),
    latent_dim=64,
    device=device
)

In [None]:
actor

In [19]:
actor.filter_mutation_methods('kernel')

In [None]:
actor.mutation_methods

In [3]:
pop = DQN.population(4, observation_space=img_space, action_space=spaces.Discrete(4))

In [3]:
from agilerl.modules.cnn import EvolvableCNN
from agilerl.hpo.mutation import Mutations

In [4]:
from accelerate import Accelerator

with open('configs/training/multi_input.yaml') as f:
    config = yaml.safe_load(f)

vector_actions = spaces.Box(low=-1, high=1, shape=(4,), dtype='float32')
discrete_actions = spaces.Discrete(4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# accelerator = Accelerator()
INIT_HP = config["INIT_HP"]
MUTATION_PARAMS = config["MUTATION_PARAMS"]
INIT_HP['AGENT_IDS'] = [f'agent_{i}' for i in range(4)]
n_agents = 4
agent_pop = create_population(
    algo=INIT_HP["ALGO"],
    observation_space=env.single_observation_space,
    action_space=env.single_action_space, 
    net_config=config["NET_CONFIG"],
    INIT_HP=INIT_HP,
    population_size=INIT_HP["POP_SIZE"],
    num_envs=INIT_HP["NUM_ENVS"],
    device=device,
    # accelerator=accelerator
)

In [None]:
ind = agent_pop[0]
ind.inspect_attributes(input_args_only=True)

In [None]:
ind = agent_pop[0]
ind.optimizer

In [6]:
mutations = Mutations(
    0,
    1,
    0.5,
    0,
    0,
    0,
    0.5,
    agent_ids=INIT_HP['AGENT_IDS'],
    device=device
)
new_population = [agent.clone(wrap=False) for agent in agent_pop]
mutated_population = [
    mutations.architecture_mutate(agent) for agent in new_population
]

Applied mutation: encoder.feature_net.BatchFurnaceD.remove_channel
Applied mutation: None
Applied mutation: None
Applied mutation: None


In [7]:
mutated_population[0].actor.mutation_methods

['head_net.add_layer',
 'head_net.remove_layer',
 'remove_latent_node',
 'add_latent_node',
 'encoder.remove_latent_node',
 'encoder.add_latent_node',
 'encoder.feature_net.BatchFurnaceA.add_channel',
 'encoder.feature_net.BatchFurnaceA.change_kernel',
 'encoder.feature_net.BatchFurnaceA.remove_channel',
 'encoder.feature_net.BatchFurnaceB.add_channel',
 'encoder.feature_net.BatchFurnaceB.change_kernel',
 'encoder.feature_net.BatchFurnaceB.remove_channel',
 'encoder.feature_net.BatchFurnaceC.add_channel',
 'encoder.feature_net.BatchFurnaceC.change_kernel',
 'encoder.feature_net.BatchFurnaceC.remove_channel',
 'encoder.feature_net.BatchFurnaceD.add_channel',
 'encoder.feature_net.BatchFurnaceD.change_kernel',
 'encoder.feature_net.BatchFurnaceD.remove_channel',
 'head_net.remove_node',
 'head_net.add_node']

In [9]:
ind = agent_pop[0]
print(ind.lr)
print(ind.learn_step)
print(ind.batch_size)

0.001
2048
128
