In [1]:
from gymnasium import spaces
import yaml
import torch

from agilerl.modules.configs import MlpNetConfig, CnnNetConfig, MultiInputNetConfig
from agilerl.networks.q_networks import QNetwork, RainbowQNetwork
from agilerl.networks.value_functions import ValueFunction
from agilerl.networks.actors import StochasticActor, DeterministicActor

from agilerl.algorithms.dqn import DQN
from agilerl.utils.utils import create_population

  from .autonotebook import tqdm as notebook_tqdm


### QNetwork (RainbowQNetwork with a CNN encoder)

In [2]:
from tests.helper_functions import generate_dict_or_tuple_space

# Observation spaces used by the cells below: an image-shaped Box, a 4-element
# vector Box, and a Dict space that combines the two.
img_space = spaces.Box(low=0, high=255, shape=(4, 84, 84))
vec_space = spaces.Box(low=-1, high=1, shape=(4,), dtype='float32')
dict_space = spaces.Dict({'img': img_space, 'vec': vec_space})

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Encoder configurations, one per kind of observation space.
img_config = CnnNetConfig(
    channel_size=[16],
    kernel_size=[4],
    stride_size=[1],
)
vec_config = MlpNetConfig(
    hidden_size=[64],
)
multi_input_config = MultiInputNetConfig(
    channel_size=[8, 8, 8],
    kernel_size=[2, 2, 2],
    stride_size=[2, 2, 2],
    hidden_size=[32, 32, 32],
    vector_space_mlp=False
)

# Rainbow Q-network over the image space with 4 discrete actions.
actor = RainbowQNetwork(
    observation_space=img_space,
    action_space=spaces.Discrete(4),
    # Create the 51-atom support on the same device as the network; a CPU
    # support combined with a CUDA network can raise a device-mismatch error.
    support=torch.linspace(-10, 10, 51, device=device),
    encoder_config=img_config,
    latent_dim=64,
    device=device
)

In [3]:
def check_equal_params_ind(before_ind, mutated_ind):
    """Report which parameters of ``mutated_ind`` differ from ``before_ind``.

    Parameters are matched by name using ``named_parameters()``; names that
    only exist in ``mutated_ind`` are ignored.

    * Same shape: the two tensors must be exactly equal.
    * Different shape: only the overlapping region along the first two
      dimensions is compared, and keys containing ``"norm"`` are skipped.

    Neither module is modified by the check.

    Args:
        before_ind: Module providing the reference parameters.
        mutated_ind: Module whose parameters are checked against the reference.

    Returns:
        list[str]: Names of the parameters that are not equal. The list is
        also printed, as before.
    """
    before_params = dict(before_ind.named_parameters())
    _not_eq = []
    for key, param in mutated_ind.named_parameters():
        if key not in before_params:
            continue

        old_param = before_params[key]
        old_size = old_param.data.size()
        new_size = param.data.size()
        if old_size == new_size:
            # Compare rather than overwrite: assigning old_param.data to
            # param.data would hide real differences and alias both modules
            # to the same tensor.
            if not torch.equal(param.data, old_param.data):
                _not_eq.append(key)
        elif "norm" not in key:
            # Create a slicing index to handle tensors with varying sizes
            slice_index = tuple(slice(0, min(o, n)) for o, n in zip(old_size[:2], new_size[:2]))
            if not torch.all(torch.eq(param.data[slice_index], old_param.data[slice_index])):
                _not_eq.append(key)

    print(_not_eq)
    return _not_eq


In [4]:
from agilerl.modules.bert import EvolvableBERT

# Build an EvolvableBERT module on `device` and keep a clone of it next to the
# original so the two can be compared.
# NOTE(review): the meaning of the two `[12]` arguments is not visible here —
# confirm against the EvolvableBERT signature.
mod = EvolvableBERT([12], [12], device=device)
new_mod = mod.clone()

In [None]:
# Call add_node() on the clone only (not on `mod`).
new_mod.add_node()

In [None]:
# Display the `linear1` weight of encoder layer 0 for the original module.
mod.encoder.state_dict()['bert_encoder_layer_0.linear1.weight']

In [None]:
# Display the same `linear1` weight (encoder layer 0) for the cloned module.
new_mod.encoder.state_dict()['bert_encoder_layer_0.linear1.weight']

In [3]:
from agilerl.modules.cnn import EvolvableCNN
from agilerl.hpo.mutation import Mutations

In [13]:
from accelerate import Accelerator

# Read the PPO training configuration from the YAML file.
with open('configs/training/ppo.yaml') as f:
    config = yaml.safe_load(f)

# Two candidate action spaces: 4 discrete actions, or a 4-element Box.
discrete_actions = spaces.Discrete(4)
vector_actions = spaces.Box(low=-1, high=1, shape=(4,), dtype='float32')

# Hyperparameters come from the config; the agent ids are added on top.
n_agents = 4
INIT_HP = config["INIT_HP"]
INIT_HP['AGENT_IDS'] = [f'agent_{idx}' for idx in range(n_agents)]

# Population built on the image observation space with discrete actions.
# accelerator = Accelerator()
agent_pop = create_population(
    algo=INIT_HP["ALGO"],
    observation_space=img_space,
    action_space=discrete_actions,
    net_config={'encoder_config': img_config},
    INIT_HP=INIT_HP,
    population_size=INIT_HP["POP_SIZE"],
    num_envs=INIT_HP["NUM_ENVS"],
    device=device,
    # accelerator=accelerator
)

In [16]:
# Display the `networks` attribute of the first agent's optimizer.
agent_pop[0].optimizer.networks

[StochasticActor(
   (encoder): EvolvableCNN(
     (model): Sequential(
       (encoder_conv_layer_1): Conv2d(4, 16, kernel_size=(4, 4), stride=(1, 1))
       (encoder_layer_norm_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (encoder_activation_1): ReLU()
       (encoder_flatten): Flatten(start_dim=1, end_dim=-1)
       (encoder_linear_output): Linear(in_features=104976, out_features=32, bias=True)
       (encoder_output_activation): ReLU()
     )
   )
   (head_net): EvolvableDistribution(
     (_wrapped): EvolvableMLP(
       (model): Sequential(
         (actor_linear_layer_1): Linear(in_features=32, out_features=16, bias=True)
         (actor_layer_norm_1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
         (actor_activation_1): ReLU()
         (actor_linear_layer_output): Linear(in_features=16, out_features=4, bias=True)
         (actor_activation_output): Softmax(dim=-1)
       )
     )
   )
 ),
 ValueFunction(
   (encoder): 

In [6]:
# Positional arguments are passed through exactly as written.
# NOTE(review): their meaning is not visible in this notebook — confirm each
# value against the Mutations signature.
mutations = Mutations(
    'PPO',
    0, 1, 0.5,
    0, 0, 0,
    ["batch_size", "lr", "learn_step"],
    0.5,
    device=device,
)

# Mutate wrapped clones rather than the members of agent_pop themselves.
new_population = list(map(lambda member: member.clone(wrap=True), agent_pop))
mutated_population = mutations.mutation(new_population, True)
# print([ind.mut for ind in mutated_population])

head_net.add_node
RainbowQNetwork(
  (encoder): EvolvableCNN(
    (model): Sequential(
      (encoder_conv_layer_1): Conv2d(4, 16, kernel_size=(4, 4), stride=(1, 1))
      (encoder_layer_norm_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (encoder_activation_1): ReLU()
      (encoder_flatten): Flatten(start_dim=1, end_dim=-1)
      (encoder_linear_output): Linear(in_features=104976, out_features=32, bias=True)
      (encoder_output_activation): ReLU()
    )
  )
  (head_net): EvolvableMLP(
    (model): Sequential(
      (value_linear_layer_1): NoisyLinear(in_features=32, out_features=80)
      (value_layer_norm_1): LayerNorm((80,), eps=1e-05, elementwise_affine=True)
      (value_activation_1): ReLU()
      (value_linear_layer_output): NoisyLinear(in_features=80, out_features=51)
      (value_activation_output): ReLU()
    )
  )
  (advantage_net): EvolvableMLP(
    (model): Sequential(
      (advantage_linear_layer_1): NoisyLinear(in_features=3

In [55]:
# Display the full state dict of the first mutated agent's actor.
mutated_population[0].actor.state_dict()

OrderedDict([('encoder.model.encoder_conv_layer_1.weight',
              tensor([[[[ 1.2011e-01, -2.3543e-01, -8.6037e-03,  3.5908e-02],
                        [-1.8133e-02, -9.5431e-02,  1.5800e-01,  1.9248e-01],
                        [ 1.8469e-01,  2.1109e-01, -2.7355e-03, -8.7567e-02],
                        [-3.3392e-01, -2.5585e-01,  1.8480e-01, -1.3048e-01]],
              
                       [[-1.0107e-01, -7.1645e-02,  1.5379e-02, -1.8511e-03],
                        [-1.1292e-01,  5.4702e-02, -2.5859e-01, -5.7284e-02],
                        [-2.7937e-01,  5.6650e-02, -1.2157e-01, -3.0416e-01],
                        [ 1.5155e-01, -2.0830e-01,  3.1269e-01, -2.5233e-02]],
              
                       [[-4.8939e-02,  8.9096e-02,  5.2152e-02,  1.3511e-01],
                        [-1.1432e-05, -1.7608e-01, -2.5988e-01, -5.6473e-02],
                        [-3.7252e-01,  1.8827e-01,  1.0439e-01,  3.1282e-01],
                        [-1.8274e-01,  2.8279e-01, 

In [30]:
# Keep a handle on the first member of the mutated population.
ind = mutated_population[0]

In [51]:
# State dict of the first agent's actor in agent_pop; shown as the cell output.
before_dict = agent_pop[0].actor.state_dict()
before_dict

OrderedDict([('encoder.model.encoder_conv_layer_1.weight',
              tensor([[[[ 1.2011e-01, -2.3543e-01, -8.6037e-03,  3.5908e-02],
                        [-1.8133e-02, -9.5431e-02,  1.5800e-01,  1.9248e-01],
                        [ 1.8469e-01,  2.1109e-01, -2.7355e-03, -8.7567e-02],
                        [-3.3392e-01, -2.5585e-01,  1.8480e-01, -1.3048e-01]],
              
                       [[-1.0107e-01, -7.1645e-02,  1.5379e-02, -1.8511e-03],
                        [-1.1292e-01,  5.4702e-02, -2.5859e-01, -5.7284e-02],
                        [-2.7937e-01,  5.6650e-02, -1.2157e-01, -3.0416e-01],
                        [ 1.5155e-01, -2.0830e-01,  3.1269e-01, -2.5233e-02]],
              
                       [[-4.8939e-02,  8.9096e-02,  5.2152e-02,  1.3511e-01],
                        [-1.1432e-05, -1.7608e-01, -2.5988e-01, -5.6473e-02],
                        [-3.7252e-01,  1.8827e-01,  1.0439e-01,  3.1282e-01],
                        [-1.8274e-01,  2.8279e-01, 

In [56]:
# State dict of the first mutated agent's actor; shown as the cell output.
ind = mutated_population[0]
after_dict = ind.actor.state_dict()
after_dict

OrderedDict([('encoder.model.encoder_conv_layer_1.weight',
              tensor([[[[ 1.2011e-01, -2.3543e-01, -8.6037e-03,  3.5908e-02],
                        [-1.8133e-02, -9.5431e-02,  1.5800e-01,  1.9248e-01],
                        [ 1.8469e-01,  2.1109e-01, -2.7355e-03, -8.7567e-02],
                        [-3.3392e-01, -2.5585e-01,  1.8480e-01, -1.3048e-01]],
              
                       [[-1.0107e-01, -7.1645e-02,  1.5379e-02, -1.8511e-03],
                        [-1.1292e-01,  5.4702e-02, -2.5859e-01, -5.7284e-02],
                        [-2.7937e-01,  5.6650e-02, -1.2157e-01, -3.0416e-01],
                        [ 1.5155e-01, -2.0830e-01,  3.1269e-01, -2.5233e-02]],
              
                       [[-4.8939e-02,  8.9096e-02,  5.2152e-02,  1.3511e-01],
                        [-1.1432e-05, -1.7608e-01, -2.5988e-01, -5.6473e-02],
                        [-3.7252e-01,  1.8827e-01,  1.0439e-01,  3.1282e-01],
                        [-1.8274e-01,  2.8279e-01, 

In [61]:
def assert_equal_state_dict(before_pop, mutated_pop):
    """Assert that every mutated network still carries the original weights.

    Individuals are paired positionally, and so are the modules returned by
    ``evolvable_attributes(networks_only=True)``. For each pair the
    ``state_dict()`` entries are matched by key; keys that only exist in the
    mutated module are ignored.

    * Same shape: the tensors must be exactly equal.
    * Different shape: only the overlapping region along the first two
      dimensions must be equal, and keys containing ``"norm"`` are skipped.

    ``state_dict()`` contains buffers as well as parameters, so buffers are
    checked too.

    Args:
        before_pop: Iterable of individuals before mutation.
        mutated_pop: Iterable of the corresponding mutated individuals.

    Raises:
        AssertionError: If any compared entry differs.
    """
    for before_ind, mutated in zip(before_pop, mutated_pop):
        before_modules = before_ind.evolvable_attributes(networks_only=True).values()
        mutated_modules = mutated.evolvable_attributes(networks_only=True).values()
        for before_mod, mutated_mod in zip(before_modules, mutated_modules):
            before_dict = before_mod.state_dict()
            after_dict = mutated_mod.state_dict()
            for key, param in after_dict.items():
                if key not in before_dict:
                    continue

                old_param = before_dict[key]
                old_size = old_param.data.size()
                new_size = param.data.size()
                if old_size == new_size:
                    # Actually compare: assigning to ``param.data`` here only
                    # touched the detached tensor returned by state_dict(),
                    # so same-shaped entries were never verified.
                    assert torch.equal(param.data, old_param.data), (
                        f"Entry {key} changed after mutation"
                    )
                elif "norm" not in key:
                    # Create a slicing index to handle tensors with varying sizes
                    slice_index = tuple(slice(0, min(o, n)) for o, n in zip(old_size[:2], new_size[:2]))
                    assert torch.all(torch.eq(param.data[slice_index], old_param.data[slice_index])), (
                        f"Overlapping region of {key} changed after mutation"
                    )


In [18]:
# Print the critic's `last_mutation_attr`, then display the critic itself.
print(ind.critic.last_mutation_attr)
ind.critic

encoder.add_channel


ValueFunction(
  (encoder): EvolvableCNN(
    (model): Sequential(
      (encoder_conv_layer_1): Conv2d(4, 32, kernel_size=(4, 4), stride=(1, 1))
      (encoder_layer_norm_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (encoder_activation_1): ReLU()
      (encoder_flatten): Flatten(start_dim=1, end_dim=-1)
      (encoder_linear_output): Linear(in_features=209952, out_features=32, bias=True)
      (encoder_output_activation): ReLU()
    )
  )
  (head_net): EvolvableMLP(
    (model): Sequential(
      (value_linear_layer_1): Linear(in_features=32, out_features=16, bias=True)
      (value_layer_norm_1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (value_activation_1): ReLU()
      (value_linear_layer_output): Linear(in_features=16, out_features=1, bias=True)
      (value_activation_output): Identity()
    )
  )
)