# License
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Instructions

This notebook reproduces the experiments reported in the publication titled:

"[*A Multi-Agent Framework for the Asynchronous and Collaborative Extension of Multitask ML Systems*](https://arxiv.org/abs/2209.14745)" (2022)

---
To start an experiment:
---

1. Set `EXPERIMENT_NAME` in the configuration below to a name of your choice. It is also used as the name of the folder that stores the state of the µNet system generated by the experiment (the system state folder).

1. By default, the system state is written to a temporary folder (`/tmp/`) on the current Virtual Machine (VM). This folder is deleted every time the VM is stopped or restarted.
To save the system state folder on your Google Drive instead, activate the
`SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE` option. In this case you will be prompted to approve access, and the system state folder will be saved in a folder named `"munet_experiments"` under your Google Drive root folder. The state can also be stored in a Google Drive folder shared among multiple users: add a link to the shared folder to your Google Drive, then set `EXPERIMENTS_ROOT_DIR` (below) to the path of the linked folder.

1. `AGENT` can be set to:
   - `VitT3` to use a ViT-Tiny root model capped to 3 layers and apply the experiment configuration used for the experiments on the *Multitask Character Classification Benchmark*.
   - `VitB` to use a ViT-Base root model and apply the experiment configuration used for the experiments on the *Visual Domain Decathlon Benchmark*.
   - `Vit` to use a ViT-Large root model and apply the large-scale continual-learning experiment configuration.

1. Set `TASK_NAME` to the task id string,
for example a task id from the [TensorFlow Datasets image classification catalog](https://www.tensorflow.org/datasets/catalog/beans).
Refer to `TFDS_IMAGE_CLASSIFCATON_DATASETS` and `VTAB_TASKS` (below) for lists of task ids that have been tested with the current code. Note that some tasks require a manual download; refer to the corresponding catalog page for instructions. **WARNING**: The system state needs to be populated with at least one root model before running an agent training on any task. To generate the root model, set `TASK_NAME` to either `"root_model/checkpoint"` or `"root_model/random_init"` to load a pretrained root model or to generate a randomly initialized one, respectively.

1. To start the experiment, select "Connect to a hosted runtime" from the dropdown menu on the top right, and then select "Run all" from the "Runtime" menu.

---
During the experiment execution:
---

1. The logging output is printed after the last cell of this Colab.

1. The system state folder is populated with a subfolder for each agent.
The name of each agent folder is the agent name followed by the task name, joined by `~`; any `/` in the task name is also replaced by `~`.
Each agent folder is populated with incremental state subfolders containing the sharded state of the architectures and parameters generated by the agent.

1. Agents can be started asynchronously and run in parallel in varying numbers.
An interrupted agent training can be resumed by restarting the execution with the same configuration.
A completed training can be continued further by increasing the `config.num_cycles_max` value set in the agent configurations below.

1. To achieve a multi-agent execution, run multiple Colabs in parallel, each set to the same configuration but with a different `TASK_NAME`.

1. To achieve heterogeneous hardware execution, parallel Colab notebooks can be connected to runtimes of different types.
You can switch between CPU, GPU and TPU runtimes by selecting `Change runtime type` in the `Resources` tab of this Colab notebook.
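
The run order described above (root model first, then one agent per task) can be sketched as a list of parameter settings for successive or parallel executions of this notebook. This is only an illustration: `RUNS` is a hypothetical name that the notebook never reads, and the two task ids are arbitrary picks from `TFDS_IMAGE_CLASSIFCATON_DATASETS`.

```python
# Hypothetical illustration: each dict holds the parameters of one execution
# of this notebook (one Colab session per entry). Task names are examples.
RUNS = [
    # 1) Populate the empty system state with a pretrained root model.
    dict(EXPERIMENT_NAME="munet_test", AGENT="VitT3",
         TASK_NAME="root_model/checkpoint"),
    # 2) Then start one agent per task, possibly in parallel Colabs, keeping
    #    EXPERIMENT_NAME and AGENT fixed and changing only TASK_NAME.
    dict(EXPERIMENT_NAME="munet_test", AGENT="VitT3",
         TASK_NAME="cmaterdb/bangla"),
    dict(EXPERIMENT_NAME="munet_test", AGENT="VitT3",
         TASK_NAME="emnist/digits"),
]

for run in RUNS:
  print(f'{run["AGENT"]} -> {run["TASK_NAME"]}')
```

All runs share the same system state folder because they share `EXPERIMENT_NAME`.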

In [None]:
# @title Agent parameters
EXPERIMENT_NAME = "munet_test"  # @param { type: "string", isTemplate: true }
SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE = False  # @param { type: "boolean", isTemplate: true }
AGENT = "VitT3" # @param ["VitT3", "VitB", "Vit"] { type: "string", isTemplate: true }
# Set TASK_NAME to "root_model/checkpoint" or "root_model/random_init" to initialize the population.
TASK_NAME = "root_model/checkpoint"  # @param { type: "string", isTemplate: true }
assert TASK_NAME

# Set to False to disable autotune.
AUTO_TUNE = True
# Override the scale factor set in the agent config (1.0 means no cost penalties).
SCALE_FACTOR = None
if SCALE_FACTOR is not None:
  assert 0 < SCALE_FACTOR <= 1
# Print debug statements.
VERBOSE = False
# Skip intermediate state save if last state was written within this time range.
SKIP_INTERMEDIATE_STATE_SECS = 10 * 60  # 10 minutes.

In [None]:
if SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE:
  from google.colab import drive
  drive.mount('/content/gdrive')
  EXPERIMENTS_ROOT_DIR = "/content/gdrive/My Drive/munet_experiments/"
  print("Saving system state in Google Drive.")
else:
  EXPERIMENTS_ROOT_DIR = "/tmp/"
  print("WARNING: Saving system state in VM, changes will be lost after reboot!!")

In [None]:
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git

In [None]:
!pip install -q ml_collections

In [None]:
!pip install -q tensorflow_addons

In [None]:
![ -d task_adaptation ] || git clone --depth=1 https://github.com/google-research/task_adaptation
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
import sys
if './task_adaptation' not in sys.path:
  sys.path.append('./task_adaptation')
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')

In [None]:
try:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
except Exception:
  pass  # Not a TPU runtime, or jax.tools.colab_tpu cannot be imported.

In [None]:
import copy
import datetime
import gc
import inspect
import jax
import jax.numpy as jnp
import json
import math
import matplotlib
import numpy as np
import os
import optax
import pandas as pd
import random
import re
import time
from collections import defaultdict
from functools import partial
from matplotlib import pyplot as plt
from threading import Thread, Lock
from typing import Optional

In [None]:
import flax
import flax.linen as nn
from flax.training import checkpoints as flax_checkpoints

In [None]:
import tensorflow as tf
import tensorflow.io.gfile as gfile
import tensorflow_datasets as tfds
tf.compat.v1.enable_eager_execution()

In [None]:
from ml_collections import ConfigDict, FrozenConfigDict
from vision_transformer.vit_jax import input_pipeline
from vision_transformer.vit_jax import checkpoint
from vision_transformer.vit_jax.configs import models as models_config  # Model configurations.
from vision_transformer.vit_jax import models_vit as models # Actual model code.

In [None]:
import task_adaptation.registry as task_adapt_registry
import task_adaptation.data.caltech
import task_adaptation.data.cifar
import task_adaptation.data.dtd
import task_adaptation.data.oxford_flowers102
import task_adaptation.data.oxford_iiit_pet
import task_adaptation.data.sun397
import task_adaptation.data.svhn
import task_adaptation.data.patch_camelyon
import task_adaptation.data.eurosat
import task_adaptation.data.resisc45
import task_adaptation.data.diabetic_retinopathy
import task_adaptation.data.clevr
import task_adaptation.data.dmlab
import task_adaptation.data.dsprites
import task_adaptation.data.kitti
import task_adaptation.data.smallnorb

In [None]:
# Ref Tfds catalog: https://www.tensorflow.org/datasets/catalog/beans
TFDS_IMAGE_CLASSIFCATON_DATASETS = set([
    "beans",
    "binary_alpha_digits",
    "caltech_birds2010",
    "caltech_birds2011",
    "cars196",
    "cassava",
    "cats_vs_dogs",
    "cifar10",
    "cifar100",
    "citrus_leaves",
    "cmaterdb/bangla",
    "cmaterdb/devanagari",
    "cmaterdb/telugu",
    "colorectal_histology",
    "controlled_noisy_web_labels/mini_imagenet_red",
    "controlled_noisy_web_labels/mini_imagenet_blue",
    "curated_breast_imaging_ddsm/patches",
    "cycle_gan/apple2orange",
    "cycle_gan/summer2winter_yosemite",
    "cycle_gan/horse2zebra",
    "cycle_gan/monet2photo",
    "cycle_gan/cezanne2photo",
    "cycle_gan/ukiyoe2photo",
    "cycle_gan/vangogh2photo",
    "cycle_gan/maps",
    "cycle_gan/cityscapes",
    "cycle_gan/facades",
    "cycle_gan/iphone2dslr_flower",
    "deep_weeds",
    "domainnet/real",
    "domainnet/painting",
    "domainnet/clipart",
    "domainnet/quickdraw",
    "domainnet/infograph",
    "domainnet/sketch",
    "emnist/balanced",
    "emnist/byclass",
    "emnist/bymerge",
    "emnist/digits",
    "emnist/letters",
    "emnist/mnist",
    "fashion_mnist",
    "food101",
    "horses_or_humans",
    "i_naturalist2017",
    "i_naturalist2018",
    "imagenet2012",
    "imagenet_a",
    "imagenet_lt",
    "imagenet_r",
    "imagenet_sketch",
    "imagenette",
    "imagewang",
    "kmnist",
    "malaria",
    "mnist",
    "omniglot",
    "pet_finder",
    "places365_small",
    "plant_village",
    "plantae_k",
    "quickdraw_bitmap",
    "rock_paper_scissors",
    "siscore/rotation",
    "siscore/size",
    "siscore/location",
    "stanford_dogs",
    "stanford_online_products",
    "stl10",
    "tf_flowers",
    "uc_merced",
    "visual_domain_decathlon/aircraft",
    "visual_domain_decathlon/cifar100",
    "visual_domain_decathlon/daimlerpedcls",
    "visual_domain_decathlon/dtd",
    "visual_domain_decathlon/gtsrb",
    "visual_domain_decathlon/imagenet12",
    "visual_domain_decathlon/omniglot",
    "visual_domain_decathlon/svhn",
    "visual_domain_decathlon/ucf101",
    "visual_domain_decathlon/vgg-flowers",
    ])

In [None]:
# Append suffix "/1k" to get the 1k version of each task.
VTAB_TASKS = [
              "caltech101",
              # cifar100 and cifar10 are already in TFDS_IMAGE_CLASSIFCATON_DATASETS
              # with a slightly different validation split but the same test set,
              # so only their 1k versions are added here.
              "cifar100/1k",
              "cifar10/1k",
              "dtd",
              "oxford_flowers102",
              "oxford_iiit_pet",
              "sun397",
              "svhn_cropped",
                ###
              "patch_camelyon",
              "eurosat",
              "resisc45",
              "diabetic_retinopathy_detection/btgraham-300",
                ###
              "clevr/count_cylinders",  # Not in results table.
              "clevr/count_all",  # Clevr-Count
              "clevr/closest_object_distance",  # Clevr-Dist
              "dmlab",
              "dsprites/label_x_position",  # dSpr-Loc
              "dsprites/label_orientation",  # dSpr-Ori
              "kitti/closest_object_distance",  # Not in results table.
              "kitti/count_vehicles",  # Not in results table.
              "kitti/closest_vehicle_distance",  # Kitti-dist
              "smallnorb/label_category",  # Not in results table.
              "smallnorb/label_lighting",  # Not in results table.
              "smallnorb/label_azimuth",  # Azim
              "smallnorb/label_elevation",  # Elev
              ]

for tn in VTAB_TASKS:
  assert tn not in TFDS_IMAGE_CLASSIFCATON_DATASETS, tn

In [None]:
TFDS_BUILDERS_CACHE = {}

def get_tfds_builder(tfds_name):
  global TFDS_BUILDERS_CACHE
  if tfds_name not in TFDS_BUILDERS_CACHE:
    TFDS_BUILDERS_CACHE[tfds_name] = tfds.builder(tfds_name)
    TFDS_BUILDERS_CACHE[tfds_name].download_and_prepare()
  return TFDS_BUILDERS_CACHE[tfds_name]

# ViT

In [None]:
def ids_str2ints(ids_str):
  return [int(v) for v in str(ids_str).split("_")] if ids_str else []
def ids_ints2str(ids_ints):
  return "_".join([str(v) for v in sorted(ids_ints)])
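
The `adapter_layers` hyperparameter encodes a set of layer ids as an underscore-separated string. The round trip below copies the two helpers so that it runs on its own; the ids are arbitrary examples.

```python
def ids_str2ints(ids_str):
  # "" (or None) means "no ids"; otherwise "0_2_11" -> [0, 2, 11].
  return [int(v) for v in str(ids_str).split("_")] if ids_str else []

def ids_ints2str(ids_ints):
  # Sort numerically before joining, so each set has a single encoding.
  return "_".join([str(v) for v in sorted(ids_ints)])

print(ids_str2ints("0_2_11"))    # [0, 2, 11]
print(ids_ints2str([11, 0, 2]))  # 0_2_11
print(ids_ints2str(range(3)))    # 0_1_2 (adapters on all 3 layers)
print(ids_str2ints(""))          # []
```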

In [None]:
AddPositionEmbs = models.AddPositionEmbs
Encoder1DBlock = models.Encoder1DBlock
VisionTransformer = models.VisionTransformer

class ResidualAdapter(nn.Module):
  adapter_dim: int

  @nn.compact
  def __call__(self, x):
    hidden_dim = x.shape[-1]
    y = nn.LayerNorm()(x)
    y = nn.Dense(self.adapter_dim)(y)
    y = nn.gelu(y)
    # Zero Initialization so that added adapter does not change the representation.
    y = nn.Dense(hidden_dim, kernel_init=jax.nn.initializers.zeros)(y)
    return x + y  # Residual.
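
Zero-initializing the up-projection makes a freshly inserted adapter an identity map, so adding it does not change the representation of an already-trained model. The framework-free sketch below checks this on a single token vector; `layer_norm`, `dense` and `gelu` are hypothetical stand-ins for the Flax layers, and the learned LayerNorm scale and bias are omitted because at initialization they are 1 and 0.

```python
import math
import random

def layer_norm(x, eps=1e-6):
  # LayerNorm without learned scale/bias (identity at initialization).
  mean = sum(x) / len(x)
  var = sum((v - mean) ** 2 for v in x) / len(x)
  return [(v - mean) / math.sqrt(var + eps) for v in x]

def dense(x, kernel, bias):
  # kernel[i][j] maps input feature i to output feature j.
  return [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
          for j in range(len(bias))]

def gelu(v):
  # Tanh approximation of GELU.
  return 0.5 * v * (1.0 + math.tanh(
      math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))

def residual_adapter(x, down_kernel, down_bias, up_kernel, up_bias):
  y = layer_norm(x)
  y = [gelu(v) for v in dense(y, down_kernel, down_bias)]
  y = dense(y, up_kernel, up_bias)
  return [a + b for a, b in zip(x, y)]  # Residual connection.

hidden_dim, adapter_dim = 6, 2
rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(hidden_dim)]
down_kernel = [[rng.gauss(0.0, 1.0) for _ in range(adapter_dim)]
               for _ in range(hidden_dim)]
down_bias = [0.0] * adapter_dim
# Zero-initialized up-projection, as in ResidualAdapter above.
up_kernel = [[0.0] * hidden_dim for _ in range(adapter_dim)]
up_bias = [0.0] * hidden_dim

print(residual_adapter(x, down_kernel, down_bias, up_kernel, up_bias) == x)  # True
```

Once training moves the up-projection away from zero, the adapter starts contributing a task-specific residual.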

# Modified from vision_transformer/vit_jax/models Encoder to add residual adapters.
class Encoder(nn.Module):
  num_layers: int
  mlp_dim: int
  num_heads: int
  adapter_layers: str  # <MOD
  adapter_dim: int  # MOD>
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, inputs, *, train):
    assert inputs.ndim == 3  # (batch, len, emb)

    x = AddPositionEmbs(
        posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
        name="posembed_input")(
            inputs)

    # Input Encoder
    adapter_layers_ids = ids_str2ints(self.adapter_layers)  # <MOD>
    for lyr in range(self.num_layers):
      if lyr in adapter_layers_ids:  # <MOD
        x = ResidualAdapter(
            adapter_dim=self.adapter_dim,
            name=f"residual_adapter_{lyr}"
            )(x)  # MOD>
      x = Encoder1DBlock(
          mlp_dim=self.mlp_dim,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          name=f"encoderblock_{lyr}",
          num_heads=self.num_heads)(x,
                                    deterministic=True)  # Disable dropout.
    encoded = nn.LayerNorm(name="encoder_norm")(x)
    return encoded
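
The loop in `Encoder.__call__` places each adapter directly before the encoder block with the same id, which also fixes the parameter names that appear under `Transformer`. The hypothetical helper below restates only that naming logic (it builds no modules); note that ids at or above `num_layers` are silently ignored.

```python
def encoder_module_names(num_layers, adapter_layers):
  """Lists the submodule names created by Encoder, in call order."""
  adapter_ids = ([int(v) for v in adapter_layers.split("_")]
                 if adapter_layers else [])
  names = ["posembed_input"]
  for lyr in range(num_layers):
    if lyr in adapter_ids:
      # The adapter transforms the input of encoder block `lyr`.
      names.append(f"residual_adapter_{lyr}")
    names.append(f"encoderblock_{lyr}")
  names.append("encoder_norm")
  return names

print(encoder_module_names(3, "0_2"))
# ['posembed_input', 'residual_adapter_0', 'encoderblock_0', 'encoderblock_1',
#  'residual_adapter_2', 'encoderblock_2', 'encoder_norm']
```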

In [None]:
def get_vit_filename(query):
  df = checkpoint.get_augreg_df()
  res = df.query(query).filename.unique()
  assert len(res) == 1
  return res[0]

In [None]:
VIT_CONFIG_CACHE = {}

def get_vit_config(query):
  if query not in VIT_CONFIG_CACHE:
    filename = get_vit_filename(query)
    config = models_config.AUGREG_CONFIGS[filename.split("-")[0]].copy_and_resolve_references()
    # Overwrite with custom Encoder.
    config.unlock()
    config.encoder = Encoder
    config.transformer.adapter_layers = ""
    config.transformer.adapter_dim = -1
    # Disable dropout.
    config.transformer.dropout_rate = 0.0
    config.transformer.attention_dropout_rate = 0.0
    config.lock()
    VIT_CONFIG_CACHE[query] = config
  return VIT_CONFIG_CACHE[query].copy_and_resolve_references()

def get_max_num_layers(query):
  config = get_vit_config(query)
  return config.transformer.num_layers

# Agents

## Agents configs

In [None]:
DATASET_HPARAMS_KEYS_PRERFIX = "ds_"
OPTIMIZER_HPARAMS_KEYS_PRERFIX = "opt_"

In [None]:
def get_config_ti3_chars():
  config = ConfigDict()
  config.num_train_examples_between_validations_max = 51200  # 100 batches.
  config.num_validations_per_path_training = 4
  config.num_validation_examples_max = 5120  # 10 batches.
  config.batch_size = 512
  config.num_cycles_max = 5
  config.num_samples_per_cycle = 4*8
  # Always fine-tune the final layer norm, which is technically part of the head.
  config.force_mutations = ["clone:encoder_norm"]
  config.policy_class = "PPDecay"
  config.policy_kwargs = {}
  config.scorer_class = "ScorerDecay"
  config.scorer_kwargs = dict(
      scale_factor=1.0,
      num_params=1_484_162,  # Params with Ti/16 3 layers.
      flops=15_000_000,  # Flops with Ti/16 3 layers and image size 32.
      )

  config.vit_checkpoint_query = 'name=="Ti/16" and ds=="i21k" and aug=="light1" and wd==0.1 and sd==0.0'

  config.models_default_hparams = {
      "_mu_": 0.1,
      # Default num_classes has no effect since it is always overwritten or used
      # for rand init models whose head is always replaced.
      "num_classes": 1,
      "mutate_adapters": True,
      # Set to ids_ints2str(range(max_num_layers)) to activate all adapters.
      "adapter_layers": "",
      "num_layers": 3,
      "adapter_dim": 32,
      "opt_lr": 0.01,
      "opt_lr_schedule": "cosine",
      "opt_lr_warmup_ratio": 0.1,
      "opt_momentum": 0.9,
      "opt_nesterov": False,
      "ds_image_size": 32,
      "ds_area_range_min": 1.0,
      "ds_aspect_ratio_range_min": 1.0,
      "ds_flip_left_right": False,
      "ds_brightness_delta": 0.0,
      "ds_contrast_delta": 0.0,
      "ds_saturation_delta": 0.0,
      "ds_hue_delta": 0.0,
      "ds_quality_delta": 0.0,
  }
  config.models_mutation_ranges = {}
  return config

In [None]:
def get_config_base_deca():
  config = ConfigDict()
  config.num_train_examples_between_validations_max = 51200
  config.num_validations_per_path_training = 4
  config.num_validation_examples_max = 5120
  config.batch_size = 128
  config.num_cycles_max = 10  # Set to 30 for convergence.
  config.num_samples_per_cycle = 4*8
  config.force_mutations = ["clone:encoder_norm"]
  config.policy_class = "PPDecay"
  config.policy_kwargs = {}
  config.scorer_class = "ScorerDecay"
  config.scorer_kwargs = dict(
      scale_factor=1.0,
      num_params=85_652_738,
      flops=100_000_000_000,
      )
  config.vit_checkpoint_query = 'name=="B/16" and ds=="i21k" and aug=="medium1" and wd==0.1 and sd==0'

  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)
  config.models_default_hparams = {
      "_mu_": 0.1,
      "num_classes": 1,
      "mutate_adapters": True,
      "adapter_layers": "",
      "num_layers": max_num_layers,
      "adapter_dim": 32,
      "opt_lr": 0.01,
      "opt_lr_schedule": "cosine",
      "opt_lr_warmup_ratio": 0.1,
      "opt_momentum": 0.9,
      "opt_nesterov": False,
      "ds_image_size": 80,
      "ds_area_range_min": 1.0,
      "ds_aspect_ratio_range_min": 1.0,
      "ds_flip_left_right": False,
      "ds_brightness_delta": 0.0,
      "ds_contrast_delta": 0.0,
      "ds_saturation_delta": 0.0,
      "ds_hue_delta": 0.0,
      "ds_quality_delta": 0.0,
  }
  config.models_mutation_ranges = {}
  return config

In [None]:
# Configuration for quick tests.
def get_config_ti0_cmaterdb():
  config = ConfigDict()
  config.num_train_examples_between_validations_max = 51200  # 100 batches.
  config.num_validations_per_path_training = 2
  config.num_validation_examples_max = 5120  # 10 batches.
  config.batch_size = 512
  config.num_cycles_max = 8
  config.num_samples_per_cycle = 2*8
  config.max_task_population_size = 7
  config.force_mutations = ["clone:encoder_norm"]
  config.policy_class = "PPDecay"
  config.policy_kwargs = {}
  config.scorer_class = "ScorerDecay"
  config.scorer_kwargs = dict(
      scale_factor=0.99,
      num_params=1_000_000,
      flops=1_000_000,
      )
  # The query is used to get the model configs even if the checkpoint is not loaded.
  config.vit_checkpoint_query = 'name=="Ti/16" and ds=="i21k" and aug=="light1" and wd==0.1 and sd==0.0'

  config.models_default_hparams = {
      "_mu_": 0.2,
      "num_classes": 1,
      "mutate_adapters": True,
      "adapter_layers": "",
      "num_layers": 0,
      "adapter_dim": 32,
      "opt_lr": 0.01,
      "opt_lr_schedule": "cosine",
      "opt_lr_warmup_ratio": 0.1,
      "opt_momentum": 0.9,
      "opt_nesterov": False,
      "ds_image_size": 32,
      "ds_area_range_min": 1.0,
      "ds_aspect_ratio_range_min": 1.0,
      "ds_flip_left_right": False,
      "ds_brightness_delta": 0.0,
      "ds_contrast_delta": 0.0,
      "ds_saturation_delta": 0.0,
      "ds_hue_delta": 0.0,
      "ds_quality_delta": 0.0,
  }
  config.models_mutation_ranges = {}
  return config

In [None]:
def get_config_large():
  config = ConfigDict()
  config.num_train_examples_between_validations_max = 100_000
  config.num_validations_per_path_training = 8
  config.num_validation_examples_max = 10_000
  config.batch_size = 16
  config.num_cycles_max = 24
  config.num_samples_per_cycle = 16
  config.max_task_population_size = 16
  config.force_mutations = ["clone:encoder_norm"]
  config.policy_class = "PPDecay"
  config.policy_kwargs = {}
  config.scorer_class = "ScorerDecay"
  config.scorer_kwargs = dict(
      scale_factor=0.99,
      num_params=2_200_000_000,
      flops=3_800_000_000_000,
      )
  config.vit_checkpoint_query = 'name=="L/16" and ds=="i21k" and aug=="medium2" and wd==0.03 and sd==0.1'

  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)
  config.models_default_hparams = {
      "_mu_": 0.2,
      "num_classes": 1,
      "num_layers": max_num_layers,
      "opt_lr": 0.01,
      "opt_lr_schedule": "cosine",
      "opt_lr_warmup_ratio": 0.05,
      "opt_momentum": 0.9,
      "opt_nesterov": False,
      "ds_image_size": 384,
      "ds_area_range_min": 1.0,
      "ds_aspect_ratio_range_min": 1.0,
      "ds_flip_left_right": False,
      "ds_brightness_delta": 0.0,
      "ds_contrast_delta": 0.0,
      "ds_saturation_delta": 0.0,
      "ds_hue_delta": 0.0,
      "ds_quality_delta": 0.0,
  }
  config.models_mutation_ranges = {}
  return config
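
The `_max` example counts in these configurations translate into batch counts once divided by `batch_size`. The sketch below assumes plain integer division, which matches the `# 100 batches.` and `# 10 batches.` comments in `get_config_ti3_chars`; the exact conversion used by the training loop is not shown in this section, and any cap due to dataset size is ignored.

```python
# Values copied from the get_config_* functions above.
CONFIGS = {
    "ti3_chars": dict(batch_size=512, train_examples_max=51_200,
                      validation_examples_max=5_120),
    "base_deca": dict(batch_size=128, train_examples_max=51_200,
                      validation_examples_max=5_120),
    "ti0_cmaterdb": dict(batch_size=512, train_examples_max=51_200,
                         validation_examples_max=5_120),
    "large": dict(batch_size=16, train_examples_max=100_000,
                  validation_examples_max=10_000),
}

def batches_per_round(c):
  """Upper bounds on train batches between validations and validation batches."""
  return (c["train_examples_max"] // c["batch_size"],
          c["validation_examples_max"] // c["batch_size"])

for name, c in CONFIGS.items():
  train_batches, val_batches = batches_per_round(c)
  print(f"{name}: {train_batches} train batches, {val_batches} validation batches")
```

The same example budget therefore yields four times as many optimizer steps for `VitB` (batch size 128) as for `VitT3` (batch size 512).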

In [None]:
def config_add_auto_tune(config, agent_class):
  """Extend config with auto tune parameters."""
  config.models_mutation_ranges["_mu_"] = [
      0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.22, 0.24,
      0.26, 0.28, 0.30
  ]
  config.models_mutation_ranges["num_layers"] = list(
      range(config.models_default_hparams["num_layers"] + 1))
  config.models_mutation_ranges["opt_lr"] = [
      0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2,
      0.5
  ]
  config.models_mutation_ranges["opt_lr_warmup_ratio"] = [
      0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.3
  ]
  config.models_mutation_ranges["opt_momentum"] = [
      0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99
  ]
  config.models_mutation_ranges["opt_nesterov"] = [True, False]
  config.models_mutation_ranges["ds_image_size"] = [224, 384]
  config.models_mutation_ranges["ds_area_range_min"] = [
      0.05, 0.5, 0.95, 1.0
  ]
  config.models_mutation_ranges["ds_aspect_ratio_range_min"] = [
      0.5, 0.75, 1.0
  ]
  config.models_mutation_ranges["ds_flip_left_right"] = [True, False]
  config.models_mutation_ranges["ds_brightness_delta"] = [
      0.0, 0.01, 0.02, 0.05, 0.1, 0.2
  ]
  config.models_mutation_ranges["ds_contrast_delta"] = [
      0.0, 0.01, 0.02, 0.05, 0.1, 0.2
  ]
  config.models_mutation_ranges["ds_saturation_delta"] = [
      0.0, 0.01, 0.02, 0.05, 0.1, 0.2
  ]
  config.models_mutation_ranges["ds_hue_delta"] = [
      0.0, 0.01, 0.02, 0.05, 0.1, 0.2
  ]
  config.models_mutation_ranges["ds_quality_delta"] = [
      0.0, 0.01, 0.02, 0.05, 0.1, 0.2
  ]

  # agent_class specific overrides.
  if agent_class == "VitT0":
    config.models_mutation_ranges["num_layers"] = list(range(4))
    config.models_mutation_ranges["ds_image_size"] = [32, 48]
  elif agent_class in ["VitT3", "VitB"]:
    config.models_mutation_ranges["num_layers"] = list(
        range(config.models_default_hparams["num_layers"] + 1))
    config.models_mutation_ranges["adapter_dim"] = [8, 16, 32, 64, 128]
    config.models_mutation_ranges["ds_image_size"] = [
        16 * i for i in (range(1, 1 + int(112 / 16)))
    ]
  else:
    assert agent_class == "Vit", (
        f"Undefined mutation ranges for benchmark: {agent_class}")
  return config

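def config_validate(config):
  # Each default hparam that can mutate must lie within its mutation range.
  for khp in config.models_default_hparams:
    if khp in config.models_mutation_ranges:
      assert (config.models_default_hparams[khp]
              in config.models_mutation_ranges[khp]), (
                  f"Default value of {khp} is not in its mutation range.")
  # Mutation ranges may only be defined for known hparams.
  for khp in config.models_mutation_ranges:
    assert khp in config.models_default_hparams, (
        f"Mutation range defined for unknown hparam: {khp}")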
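
The checks in `config_validate` guarantee that every default hyperparameter is a value the mutations can also reach, and that no range is defined for an unknown hyperparameter. A plain-dict restatement (the names `validate` and `image_size_range` are hypothetical) using the image-size range from `config_add_auto_tune` shows why the defaults of 32 (`VitT3`) and 80 (`VitB`) pass.

```python
def image_size_range(max_size=112, step=16):
  # Same expression as in config_add_auto_tune for "VitT3" and "VitB".
  return [step * i for i in range(1, 1 + int(max_size / step))]

def validate(default_hparams, mutation_ranges):
  """Plain-dict restatement of config_validate."""
  for key, value in default_hparams.items():
    if key in mutation_ranges:
      assert value in mutation_ranges[key], f"default {key}={value!r} not in range"
  for key in mutation_ranges:
    assert key in default_hparams, f"range given for unknown hparam {key!r}"

sizes = image_size_range()
print(sizes)  # [16, 32, 48, 64, 80, 96, 112]
validate({"ds_image_size": 32}, {"ds_image_size": sizes})  # VitT3 default.
validate({"ds_image_size": 80}, {"ds_image_size": sizes})  # VitB default.
```

A default of, say, 40 would fail the first check, since mutations only ever propose multiples of 16.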

## Agents utils

In [None]:
def format_agent_id(agent_id: str):
  assert "~" not in agent_id, f"Invalid agent id: {agent_id}"
  return agent_id.replace("/", "~")
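
Task ids such as `cmaterdb/bangla` contain `/`, so `format_agent_id` replaces every `/` with `~`, presumably to give each agent a single flat folder, and rejects ids that already contain `~`, which keeps the mapping reversible. The example copies the function and assumes an experiment directory of `/tmp/munet_test`; how `experiment_dir` is built is not shown in this section.

```python
import posixpath

def format_agent_id(agent_id: str):
  assert "~" not in agent_id, f"Invalid agent id: {agent_id}"
  return agent_id.replace("/", "~")

# As in Agent.complete_config, the raw id is "<class name>/<task name>".
agent_id = format_agent_id("VitT3/cmaterdb/bangla")
print(agent_id)  # VitT3~cmaterdb~bangla

# Hypothetical experiment directory; agent_dir = join(experiment_dir, agent_id).
print(posixpath.join("/tmp/munet_test", agent_id))
# /tmp/munet_test/VitT3~cmaterdb~bangla
```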

In [None]:
def run_cycles(agent):
  config = agent.config
  devices = jax.local_devices()
  print("DEVICE COUNT:", len(devices))
  task_name = config.task_name
  num_cycles = config.num_cycles_max

  for _ in range(num_cycles):
    agent.load_state()
    if agent.cycle_id >= num_cycles:
      break
    print("\n\n====")
    print(f"CYCLE: [{agent.cycle_id+1}/{num_cycles}]")
    task = Path.cached_tasks(task_name=task_name)
    agent.pop.start_cycle()
    agent_cycle(
        task, devices, agent.pop, agent.generation_id, agent.cycle_id, config)
    agent.pop.end_cycle()
    agent.cycle_id += 1
    agent.generation_id = 0

    write_threads = save_state(
        agent.pop, agent.cycle_id, agent.generation_id, config)
    # Display stats.
    avg_time_per_sample = (
        agent.pop.paths_df["metrics.end_time"].mean() \
            - agent.pop.paths_df["metrics.start_time_loop"].mean()
        ) / len(devices)
    print(f"Avg time per path: {avg_time_per_sample:.2f} s")
    # Wait for last state write to complete.
    for t in write_threads:
      t.join()
    if agent.cycle_id >= num_cycles:
      break
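
The loop in `run_cycles` reloads the persisted state at the start of every cycle and exits once `cycle_id` reaches `num_cycles_max`; this is what makes interrupted runs resumable and completed runs extendable. The toy simulation below mimics only that control flow: `ToyAgent` and `run_cycles_sketch` are hypothetical, a dict replaces the system state folder, and recording the cycle id replaces `agent_cycle`.

```python
class ToyAgent:
  """Hypothetical stand-in for an agent; `store` plays the system state folder."""

  def __init__(self, store, num_cycles_max):
    self.store = store
    self.num_cycles_max = num_cycles_max
    self.cycle_id = 0
    self.completed = []  # Cycle ids executed by this run.

  def load_state(self):
    self.cycle_id = self.store.get("cycle_id", 0)

  def save_state(self):
    self.store["cycle_id"] = self.cycle_id


def run_cycles_sketch(agent, interrupt_after=None):
  num_cycles = agent.num_cycles_max
  for _ in range(num_cycles):
    agent.load_state()
    if agent.cycle_id >= num_cycles:
      break
    agent.completed.append(agent.cycle_id)  # Stands in for agent_cycle(...).
    agent.cycle_id += 1
    agent.save_state()
    if interrupt_after is not None and len(agent.completed) >= interrupt_after:
      return  # Simulate the VM being stopped mid-training.
    if agent.cycle_id >= num_cycles:
      break


store = {}
first = ToyAgent(store, num_cycles_max=5)
run_cycles_sketch(first, interrupt_after=2)  # Interrupted after 2 cycles.
resumed = ToyAgent(store, num_cycles_max=5)
run_cycles_sketch(resumed)                   # Same config: resumes.
extended = ToyAgent(store, num_cycles_max=7)
run_cycles_sketch(extended)                  # Larger num_cycles_max: continues.
print(first.completed, resumed.completed, extended.completed)
# [0, 1] [2, 3, 4] [5, 6]
```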

def run_root_model(agent):
  agent.load_state()
  pop = agent.pop
  config = agent.config
  cycle_id = agent.cycle_id
  generation_id = agent.generation_id
  write_threads = save_state(pop, cycle_id, generation_id, config)
  for t in write_threads:
    t.join()

## Agent classes

In [None]:
class Agent():
  @property
  def class_name(self):
    return self.__class__.__name__

  @property
  def id(self):
    return self.config.agent_id

  def run(self):
    raise NotImplementedError("Subclasses must implement run().")

  def complete_config(self, task_name, experiment_dir):
    self.config.task_name = task_name
    self.config.experiment_dir = experiment_dir
    self.config.agent_id = format_agent_id(f"{self.class_name}/{task_name}")
    self.config.agent_dir = os.path.join(experiment_dir, self.id)


class Vit(Agent):
  """ViT large / all tasks"""

  def __init__(self, task_name, experiment_dir, auto_tune, scale_factor):
    self.config = self.init_config_fn()
    self.complete_config(task_name, experiment_dir, auto_tune, scale_factor)

  def load_state(self):
    task_name = self.config.task_name
    self.pop = Population(self.config)
    self.cycle_id = 0
    self.generation_id = 0
    # Root models.
    if task_name.startswith("root_model/"):
      hparams = self.config.models_default_hparams.as_configdict()
      if task_name == "root_model/random_init":
        path_params = Path.cached_init_params(
          query=self.config.vit_checkpoint_query,
          **hparams)
      else:
        assert task_name == "root_model/checkpoint", task_name
        path_params = get_vit_checkpoint_mapped(
            hparams["ds_image_size"],
            self.config.vit_checkpoint_query)
      path = Path(
          hparams,
          params2comps(
              path_params, train_locks=[self.id], agent_id=self.id),
          parent=None,
          agent_id=self.id,
          task_name=task_name)
      self.pop.paths[self.id].append(path)
      return

    # Load latest agent state.
    def validate_df(df):
      assert len(df["agent_id"].unique()) == 1, len(df["agent_id"].unique())
      assert df["agent_id"].unique()[0] == self.id, df["agent_id"].unique()[0]
    agent_checkpoint = latest_checkpoint(
        os.path.join(self.config.agent_dir, "state_*_*/"))
    if agent_checkpoint:
      matched = re.findall(r"checkpoint_([0-9]+)_([0-9]+)$", agent_checkpoint)
      assert len(matched) == 1
      self.cycle_id = int(matched[0][0])
      self.generation_id = int(matched[0][1])
      state_dir = os.path.dirname(agent_checkpoint)
      self.pop.paths_df = df_read_from_csv(state_dir, "paths")
      self.pop.comps_df = df_read_from_csv(state_dir, "components")
      validate_df(self.pop.paths_df)
      validate_df(self.pop.comps_df)
      # Set globals.
      Path.paths = []
      Path.counter = 1 + int(self.pop.paths_df.id.max())
      Component.counter = 1 + int(self.pop.comps_df.id.max())
      # Get the id of the last component saved in a non-intermediate checkpoint.
      non_intermediated_checkpoint = latest_checkpoint(
          os.path.join(self.config.agent_dir, "state_*_0/"))
      if non_intermediated_checkpoint:
        ni_paths_df = df_read_from_csv(
            os.path.dirname(non_intermediated_checkpoint), "paths")
        validate_df(ni_paths_df)
        Path.last_saved = int(ni_paths_df.id.max())
        ni_comps_df = df_read_from_csv(
            os.path.dirname(non_intermediated_checkpoint), "components")
        validate_df(ni_comps_df)
        Component.last_saved = int(ni_comps_df.id.max())
      print("CONTINUING FROM STATE", self.cycle_id, self.generation_id)

    # Load all available paths.
    all_agents_dirs = os.path.join(self.config.experiment_dir, "*")
    state_dir = os.path.dirname(agent_checkpoint) if agent_checkpoint else None
    load_paths(self.pop, state_dir, all_agents_dirs)

    assert self.pop.paths, "Empty population, run an agent creating a " \
        "root model to initialize the population."
    df_leaderboard(pop_to_df(self.pop))

  @property
  def init_config_fn(self):
    return get_config_large

  def run(self):
    if self.config.task_name.startswith("root_model/"):
      run_root_model(self)
      return
    run_cycles(self)

  def complete_config(self, task_name, experiment_dir, auto_tune, scale_factor):
    super().complete_config(task_name, experiment_dir)
    if auto_tune:
      self.config = config_add_auto_tune(self.config, self.class_name)
    if scale_factor:
      self.config.scorer_kwargs["scale_factor"] = scale_factor
    self.config = FrozenConfigDict(self.config)
    config_validate(self.config)


class VitT0(Vit):
  """ViT tiny 0 layers / cmaterdb benchmark"""
  @property
  def init_config_fn(self):
    return get_config_ti0_cmaterdb


class VitT3(Vit):
  """ViT tiny 3 layers / characters benchmark"""
  @property
  def init_config_fn(self):
    return get_config_ti3_chars


class VitB(Vit):
  """ViT base / decathlon benchmark"""
  @property
  def init_config_fn(self):
    return get_config_base_deca

# ViT Model

In [None]:
def get_sample_images(image_size: int, batch_size: int):
  return np.zeros((batch_size, image_size, image_size, 3))

def get_sample_labels(batch_size: int):
  return np.zeros(batch_size, dtype=np.int32)

def get_sample_batch(image_size: int, batch_size: int):
  return {"image": get_sample_images(image_size, batch_size),
          "label": get_sample_labels(batch_size),}

In [None]:
def get_vit_checkpoint(image_size, query):
  filename = get_vit_filename(query)

  config = get_vit_config(query)

  model = VisionTransformer(**config, num_classes=2)  # num_classes is unused.
  init_params = copy.deepcopy(jax.device_get(
      model.init(jax.random.PRNGKey(random.randrange(10**10)),
                 get_sample_images(image_size=image_size,
                                   batch_size=1),
                 train=False  # Disable dropout.
                 )["params"]))

  params = checkpoint.load_pretrained(
    pretrained_path=f"gs://vit_models/augreg/{filename}.npz",
    init_params=init_params,
    model_config=config)

  return params

def get_vit_checkpoint_mapped(image_size, query):
  params = get_vit_checkpoint(image_size, query)
  params = params_model_to_comps(params)
  return params

def get_reshaped_posembed_component(
    agent_id: str, ds_image_size: int, query: str):
  params = get_vit_checkpoint_mapped(ds_image_size, query)["posembed_input"]
  return Component(name="posembed_input",
                   agent_id=agent_id,
                   params=params,
                   train_locks=[])

In [None]:
# Parameter mapping.
TRANSFORMER_KEYS = set(
    ["encoder_norm", "posembed_input" ] + \
    [f"encoderblock_{k}" for k in range(30)])

def params_model_to_comps(params):
  global TRANSFORMER_KEYS
  TRANSFORMER_KEYS.update(params["Transformer"].keys())
  new_params = {}
  for k in params.keys():
    if k == "Transformer":
      t_params = params[k]
      for t_k in t_params.keys():
        new_params[t_k] = t_params[t_k]
    else:
      new_params[k] = params[k]
  return flax.core.freeze(new_params)

def params_comps_to_model(params):
  params = params.unfreeze()

  params["Transformer"] = {}
  keys = list(params.keys())
  assert len(TRANSFORMER_KEYS) != 0
  for k in keys:
    if k in TRANSFORMER_KEYS:
      params["Transformer"][k] = params.pop(k)
  return flax.core.freeze(params)
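The two mapping functions only move subtrees between two nestings. The model nests the encoder blocks, `encoder_norm` and `posembed_input` under a `"Transformer"` key. The component system instead treats every top-level key as a separately owned component.

Below is a minimal pure-Python sketch of that round trip, plus the disjoint merge that `format_params` performs. It makes three assumptions:
- plain nested `dict`s stand in for flax `FrozenDict`s;
- the parameter tree is hypothetical, with shape tuples as leaves;
- the `_sketch` names are illustrative helpers that are not part of the notebook.

```python
# Sketch: hoist the "Transformer" subtree to top-level component keys and
# nest it back. Plain dicts stand in for flax FrozenDicts.
TRANSFORMER_KEYS_SKETCH = set()


def model_to_comps_sketch(params: dict) -> dict:
  """Hoist every child of params["Transformer"] to the top level."""
  TRANSFORMER_KEYS_SKETCH.update(params["Transformer"].keys())
  comps = {k: v for k, v in params.items() if k != "Transformer"}
  comps.update(params["Transformer"])
  return comps


def comps_to_model_sketch(comps: dict) -> dict:
  """Move the known Transformer keys back under params["Transformer"]."""
  assert TRANSFORMER_KEYS_SKETCH, "model_to_comps_sketch must run first"
  model = {"Transformer": {}}
  for k, v in comps.items():
    if k in TRANSFORMER_KEYS_SKETCH:
      model["Transformer"][k] = v
    else:
      model[k] = v
  return model


# Hypothetical tree: leaves are shape tuples for readability.
model_params = {
    "embedding": (16, 16, 3, 192),
    "head": (192, 10),
    "Transformer": {
        "encoderblock_0": (192, 768),
        "encoder_norm": (192,),
        "posembed_input": (1, 197, 192),
    },
}
comps = model_to_comps_sketch(model_params)
print(sorted(comps))
# ['embedding', 'encoder_norm', 'encoderblock_0', 'head', 'posembed_input']

# Like format_params: merge two disjoint groups, then re-nest for the model.
trainable = {k: comps[k] for k in ("head", "encoderblock_0")}
fixed = {k: v for k, v in comps.items() if k not in trainable}
merged = {**trainable, **fixed}
assert len(merged) == len(trainable) + len(fixed)  # Keys must not overlap.
assert comps_to_model_sketch(merged) == model_params
```

Because the key set is recorded globally on the first flattening, the reverse mapping only works after at least one model-to-components conversion. The notebook asserts the same precondition with `assert len(TRANSFORMER_KEYS) != 0`.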

In [None]:
def get_vit_model(
    num_classes: int,
    num_layers: int,
    query: str,
    adapter_layers: str = "",
    adapter_dim: int = -1,
    ):
  config = get_vit_config(query)
  config["transformer"]["num_layers"] = num_layers
  config["transformer"]["adapter_layers"] = adapter_layers
  config["transformer"]["adapter_dim"] = adapter_dim
  config = FrozenConfigDict(config)
  model = VisionTransformer(**config, num_classes=num_classes)
  return model

def get_vit_model_and_params(
    num_classes, num_layers, ds_image_size, query,
    adapter_layers, adapter_dim):
  model = get_vit_model(
      num_classes, num_layers, query, adapter_layers, adapter_dim)
  init_params = copy.deepcopy(jax.device_get(
      model.init(
          jax.random.PRNGKey(random.randrange(int(1e10))),
          get_sample_images(image_size=ds_image_size, batch_size=1),
          train=False  # Disable dropout.
          )["params"]))
  return model, init_params

def get_vit_params_mapped(
    num_classes: int,
    num_layers: int,
    ds_image_size: int,
    query: str,
    adapter_layers: str = "",
    adapter_dim: int = -1):
  model, init_params = get_vit_model_and_params(
      num_classes, num_layers, ds_image_size, query,
      adapter_layers, adapter_dim)
  init_params = params_model_to_comps(init_params)
  return init_params

In [None]:
def format_params(a, b):
  params = a.copy(b)
  assert len(params) == len(a) + len(b)  # Dict keys must not overlap.
  params = params_comps_to_model(params)
  return params

In [None]:
def get_optimizer(
    opt_lr: float,
    opt_lr_schedule: str,
    opt_lr_warmup_ratio: float,
    opt_momentum: float,
    opt_nesterov: bool,
    num_train_batches_between_validations: int,
    num_validations_per_path_training: int,
    ):
  min_lr = opt_lr / 1000.0
  if opt_lr_schedule == "constant":
    # Divide by 2 so that average lr is the same as other types.
    learning_rate = 0.5 * opt_lr
  elif opt_lr_schedule == "linear":
    train_steps = int(num_train_batches_between_validations
                      * num_validations_per_path_training)
    warmup_steps = int(opt_lr_warmup_ratio * train_steps)
    schedules = [
        optax.linear_schedule(
            init_value=min_lr,
            end_value=opt_lr,
            transition_steps=warmup_steps),
        optax.linear_schedule(
            init_value=opt_lr,
            end_value=min_lr,
            transition_steps=train_steps-warmup_steps)]
    learning_rate = optax.join_schedules(schedules, [warmup_steps])
  elif opt_lr_schedule == "cosine":
    train_steps = int(num_train_batches_between_validations
                      * num_validations_per_path_training)
    learning_rate = optax.warmup_cosine_decay_schedule(
        init_value=min_lr,
        peak_value=opt_lr,
        warmup_steps=int(opt_lr_warmup_ratio * train_steps),
        decay_steps=train_steps)
  elif opt_lr_schedule == "restarts":
    train_steps = num_train_batches_between_validations
    repeats = num_validations_per_path_training
    kwargs = dict(
        init_value=min_lr,
        peak_value=opt_lr,
        warmup_steps=int(opt_lr_warmup_ratio * train_steps),
        decay_steps=train_steps,
    )
    kwargs = [kwargs] * repeats
    learning_rate = optax.sgdr_schedule(kwargs)
  else:
    assert False, f"Invalid lr schedule: {opt_lr_schedule}"

  return optax.chain(
      optax.clip_by_global_norm(1.0),
      optax.sgd(
          learning_rate=learning_rate,
          momentum=opt_momentum,
          nesterov=opt_nesterov,
          accumulator_dtype=jnp.bfloat16))
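The `"constant"` branch uses `0.5 * opt_lr` so that its average learning rate matches the other schedules. The check below makes that concrete for the `"linear"` schedule: a warmup ramp from `min_lr` to `opt_lr`, joined to a decay ramp back down to `min_lr`.

This is a hand-written sketch of the same piecewise-linear formula, not a call into optax. It assumes `optax.join_schedules` restarts the second schedule's step count at the boundary. The step counts and `opt_lr` are hypothetical.

```python
def linear_warmup_decay_sketch(step, opt_lr, warmup_steps, train_steps):
  """LR at `step`: linear min_lr -> opt_lr, then linear opt_lr -> min_lr."""
  min_lr = opt_lr / 1000.0  # Same floor as in get_optimizer.
  if step < warmup_steps:
    return min_lr + (step / warmup_steps) * (opt_lr - min_lr)
  decay_steps = train_steps - warmup_steps
  frac = min(step - warmup_steps, decay_steps) / decay_steps
  return opt_lr + frac * (min_lr - opt_lr)


opt_lr, train_steps = 0.1, 1000          # Hypothetical values.
warmup_steps = int(0.1 * train_steps)    # opt_lr_warmup_ratio = 0.1.
lrs = [linear_warmup_decay_sketch(s, opt_lr, warmup_steps, train_steps)
       for s in range(train_steps)]
print(sum(lrs) / len(lrs))  # Close to 0.5 * opt_lr, the "constant" setting.
```

Both ramps are linear between `min_lr` (about 0) and `opt_lr`. The mean over the training budget is therefore about `opt_lr / 2`, whatever the warmup ratio.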

In [None]:
def get_default_splits(tfds_name):
  info = get_tfds_builder(tfds_name).info
  splits = list(info.splits.keys())
  assert "train" in splits, splits
  splits.remove("train")
  used_percent = 0
  slice_percent = 5
  pp = {}
  for k in ["test", "validation"]:
    if k in splits:
      pp[k] = k
      splits.remove(k)
    else:
      pp[k] = f"train[{used_percent}%:{used_percent+slice_percent}%]"
      used_percent += slice_percent
  pp["train"] = f"train[{used_percent}%:]"
  return pp

def get_dataset_and_splits(tfds_name: str):
  vtab_class = None
  if tfds_name in ["imagenet_v2", "cifar10_1"]:
    assert False,  f"{tfds_name} used as validation set for other tasks."

  if tfds_name == "imagenet2012":
    dataset = {
        "train":"imagenet2012", "validation":"imagenet_v2", "test":"imagenet2012"}
    splits = {
        "train":"train", "validation":"test", "test":"validation"}
  elif tfds_name == "cifar100":
    dataset = tfds_name
    splits = {
        "train":"train[:98%]", "validation":"train[98%:]", "test":"test"}
  elif tfds_name == "cifar10":
    dataset = {
        "train":"cifar10", "validation":"cifar10_1", "test":"cifar10"}
    splits = {
        "train":"train", "validation":"test", "test":"test"}
  elif (tfds_name.startswith("visual_domain_decathlon/") or
        tfds_name in ["i_naturalist2017", "i_naturalist2018", "places365_small"]):
    dataset = tfds_name
    # Test has no labels, split validation in half.
    splits =  {
        "train":"train", "validation":"validation[:50%]", "test":"validation[50%:]"}
  elif tfds_name.startswith("cmaterdb/"):
    dataset = tfds_name
    # Increase size of validation set due to small dataset size.
    splits =  {
        "train":"train[20%:]", "validation":"train[:20%]", "test":"test"}
  elif tfds_name == "omniglot":
    # Test has no labels and there is no validation split: use the extra splits.
    dataset = tfds_name
    splits = {"train":"train", "validation":"small1", "test":"small2"}
  elif tfds_name.startswith("controlled_noisy_web_labels/"):
    dataset = tfds_name
    splits =  {
        "train":"train_00",
        "validation":"validation[:50%]",
        "test":"validation[50%:]"}
  elif tfds_name.startswith("cycle_gan/"):
    dataset = tfds_name
    splits =  {
        "train":"trainA[10%:]+trainB[10%:]",
        "validation":"trainA[:10%]+trainB[:10%]",
        "test":"testA+testB"}
  elif tfds_name in ["imagenet_a", "imagenet_r", "imagenet_sketch",
                     "siscore/rotation", "siscore/size", "siscore/location",]:
    # Only test split.
    dataset = tfds_name
    splits =  {
        "train":"test[10%:]",
        "validation":"test[5%:10%]",
        "test":"test[:5%]"}
  elif tfds_name in ["pet_finder"]:
    # Use only the train split, e.g. because the test split has no labels.
    dataset = tfds_name
    splits =  {
        "train":"train[10%:]",
        "validation":"train[5%:10%]",
        "test":"train[:5%]"}
  elif tfds_name == "quickdraw_bitmap":
    dataset = tfds_name
    # Cap size of test and validation set.
    splits =  {
        "train":"train[20000:]", "validation":"train[10000:20000]", "test":"train[:10000]"}
  elif tfds_name == "stanford_online_products":
    dataset = tfds_name
    # Use the first 10k test samples as validation since test has 60k.
    splits =  {
        "train":"train", "validation":"test[:10000]", "test":"test[10000:]"}
  elif tfds_name in VTAB_TASKS or (
      tfds_name.endswith("/1k") and tfds_name.replace("/1k", "") in VTAB_TASKS):
    is_vtab_1k = tfds_name.endswith("/1k")
    tfds_name = tfds_name.replace("/1k", "")
    registry_name = {
        "diabetic_retinopathy_detection/btgraham-300": "diabetic_retinopathy",
        "svhn_cropped": "svhn",
        "cifar100": "cifar",
        "cifar10": "cifar",
    }.get(tfds_name, tfds_name.split("/")[0])
    args = {
        "clevr/count_all": ("count_all",),
        "clevr/count_cylinders": ("count_cylinders",),
        "clevr/closest_object_distance": ("closest_object_distance",),
        "dsprites/label_x_position": ("label_x_position",),
        "dsprites/label_orientation": ("label_orientation",),
        "kitti/closest_object_distance": ("closest_object_distance",),
        "kitti/count_vehicles": ("count_vehicles",),
        "kitti/closest_vehicle_distance": ("closest_vehicle_distance",),
        "smallnorb/label_category": ("label_category",),
        "smallnorb/label_lighting": ("label_lighting",),
        "smallnorb/label_azimuth": ("label_azimuth",),
        "smallnorb/label_elevation": ("label_elevation",),
        "cifar100": (100,),
        "cifar10": (10,),
    }.get(tfds_name, ())
    vtab_class = task_adapt_registry.Registry.lookup(
        f"data.{registry_name}")(*args)
    vtab_splits = vtab_class._tfds_splits
    dataset = {
        "caltech101": "caltech101:3.*.*",
        "dtd": "dtd:3.*.*",
        "oxford_flowers102": "oxford_flowers102:2.*.*",
        "oxford_iiit_pet": "oxford_iiit_pet:3.*.*",
        "sun397": "sun397/tfds:4.*.*",
        "svhn": "svhn_cropped:3.*.*",
        "patch_camelyon": "patch_camelyon:2.*.*",
        "eurosat": "eurosat/rgb:2.*.*",
        "resisc45": "resisc45:3.*.*",
        "diabetic_retinopathy": "diabetic_retinopathy_detection/btgraham-300:3.*.*",
        "clevr": "clevr:3.*.*",
        "dmlab": "dmlab:2.0.1",
        "dsprites": "dsprites:2.*.*",
        "kitti": "kitti:3.2.0",
        "smallnorb": "smallnorb:2.*.*",
        "cifar" : "cifar100:3.*.*" if tfds_name == "cifar100" else "cifar10:3.*.*",
    }[registry_name]
    if is_vtab_1k:
      splits =  {
          "train": str(vtab_splits["train800"]),
          "validation": str(vtab_splits["val200"]),
          "test": str(vtab_splits["test"]),
          }
    else:
      splits =  {
          "train": str(vtab_splits["train"]),
          "validation": str(vtab_splits["val"]),
          "test": str(vtab_splits["test"]),
          }
  else:
    dataset = tfds_name
    splits = get_default_splits(tfds_name)
  return dataset, splits, vtab_class


class Task():
  def __init__(self, name, config):
    self.config = config

    self.dataset, self.splits, self.vtab_class = get_dataset_and_splits(name)
    self.name = name
    if self.vtab_class:
      self.num_classes = self.vtab_class.get_num_classes()
    else:
      self.num_classes = self.get_builder(
          "train").info.features[self.get_label_key()].num_classes
    num_train_examples = self.get_builder(
        "train").info.splits[self.splits["train"]].num_examples
    self.train_batch_size = config.batch_size
    self.num_train_batches_between_validations = math.ceil(
        min(num_train_examples,
            config.num_train_examples_between_validations_max)
        / self.train_batch_size)
    self.cache_train = num_train_examples < min(100_000, (
        config.num_validations_per_path_training
        * self.num_train_batches_between_validations
        * self.train_batch_size))

    num_validation_examples_tot = self.get_builder(
        "validation").info.splits[self.splits["validation"]].num_examples
    if config.num_validation_examples_max <= num_validation_examples_tot:
      self.validation_batch_size = config.batch_size
      self.num_validation_batches = math.floor(
          config.num_validation_examples_max / self.validation_batch_size)
    else:
      # Adjust batch_size and num_batches to cover the smaller validation sets.
      self.num_validation_batches = math.ceil(
          num_validation_examples_tot / config.batch_size)
      self.validation_batch_size = math.floor(
          num_validation_examples_tot / self.num_validation_batches)
      assert num_validation_examples_tot >= (
          self.num_validation_batches*self.validation_batch_size)
    self.num_validation_examples = (
        self.num_validation_batches * self.validation_batch_size)

    print(f"Task: {self.name}")
    print(f"  Train batches between validations: {self.num_train_batches_between_validations}")
    print(f"  Validation batches: {self.num_validation_batches}")
    print(f"  Validation batch size: {self.validation_batch_size}")
    print(f"  Dataset {{\n{self.dataset}}}")
    print(f"  Splits {{\n{self.splits}}}")

  def get_label_key(self):
    return {
        "stanford_online_products": "super_class_id",
        }.get(self.name, "label")

  def get_builder(self, mode):
    if isinstance(self.dataset, str):
      return get_tfds_builder(self.dataset)
    return get_tfds_builder(self.dataset[mode])

  def __str__(self):
    return f"Task_{self.name}"

  def get_ds(self, mode, hparams):
    data = self.get_builder(mode).as_dataset(
        split=self.splits[mode],
        shuffle_files=mode=='train')

    def _pp(data):
      im = data["image"]
      tf.debugging.assert_type(im, tf.uint8)

      if mode == "train":
        if hparams.get("ds_quality_delta", 0.0) > 0.0:
          im = tf.image.random_jpeg_quality(
              im,
              min_jpeg_quality=int(100 * (1 - hparams["ds_quality_delta"])),
              max_jpeg_quality=100)

      # Must have 3 channels.
      if im.shape[-1] == 1:
        im = tf.squeeze(tf.stack([im] * 3, -1), axis=-2)
      assert im.shape[-1] == 3

      if mode == "train":
        if hparams.get("ds_area_range_min", 1.0) < 1.0:
          channels = im.shape[-1]
          begin, size, _ = tf.image.sample_distorted_bounding_box(
              tf.shape(im),
              tf.zeros([0, 0, 4], tf.float32),
              aspect_ratio_range=[hparams["ds_aspect_ratio_range_min"],
                                  1.0/hparams["ds_aspect_ratio_range_min"]],
              area_range=[hparams["ds_area_range_min"], 1.0],
              # No overlap with a bounding box is required: since no boxes are
              # given, the whole image is used as the bounding box.
              min_object_covered=0,
              use_image_if_no_bounding_boxes=True)
          im = tf.slice(im, begin, size)
          # Restore the depth-dimension lost by the above operation.
          im.set_shape([None, None, channels])
        if hparams.get("ds_flip_left_right", False):
          if tf.random.uniform(shape=[]) > 0.5:
            im = tf.image.flip_left_right(im)
        if hparams.get("ds_brightness_delta", 0.0) > 0.0:
          im = tf.image.random_brightness(
              im, max_delta=hparams["ds_brightness_delta"])
        if hparams.get("ds_contrast_delta", 0.0) > 0.0:
          im = tf.image.random_contrast(
              im, lower=1 - hparams["ds_contrast_delta"],
              upper=1 + hparams["ds_contrast_delta"])
        if hparams.get("ds_saturation_delta", 0.0) > 0.0:
          im = tf.image.random_saturation(
              im, lower=1 - hparams["ds_saturation_delta"],
              upper=1 + hparams["ds_saturation_delta"])
        if hparams.get("ds_hue_delta", 0.0) > 0.0:
          im = tf.image.random_hue(im, max_delta=hparams["ds_hue_delta"])

      tf.debugging.assert_type(im, tf.uint8)
      im = tf.image.resize(im, [hparams["ds_image_size"],
                                hparams["ds_image_size"]])
      tf.debugging.assert_type(im, tf.float32)

      # Scale values to the range [-1, 1].
      im = im / 127.5 - 1
      im = tf.clip_by_value(im, -1, 1)

      return {"image": im, "label": data[self.get_label_key()]}

    if mode == "validation":
      data = data.take(self.num_validation_examples)
    if mode == "validation" or (mode == "train" and self.cache_train):
      data = data.cache()
    if mode != "test":
      data = data.repeat()
    if self.vtab_class and self.vtab_class._base_preprocess_fn:
      data = data.map(self.vtab_class._base_preprocess_fn, tf.data.AUTOTUNE)
    data = data.map(_pp, tf.data.AUTOTUNE)
    if mode == "train":
      batch_size = self.train_batch_size
    else:
      batch_size = self.validation_batch_size
    data = data.batch(batch_size)
    if mode == "train":
      data = data.shuffle(10)
    return tfds.as_numpy(data.prefetch(tf.data.AUTOTUNE))

def get_task_factory_fn(config):
  def get_task(task_name: str):
    return Task(name=task_name, config=config)
  return get_task
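Two pieces of arithmetic in this cell are easy to misread.
- `get_default_splits` carves 5% slices off the start of `train` for any missing `test` or `validation` split.
- `Task.__init__` shrinks the validation batch size when the validation set is smaller than `num_validation_examples_max`. It covers the set with equal batches instead of one full and one partial batch.

Below is a stdlib-only sketch of both rules. It assumes a plain list of split names in place of a TFDS builder's `info.splits`. The dataset sizes are hypothetical.

```python
import math


def default_splits_sketch(split_names, slice_percent=5):
  """Carve missing test/validation splits out of the start of train."""
  assert "train" in split_names, split_names
  pp, used = {}, 0
  for k in ["test", "validation"]:
    if k in split_names:
      pp[k] = k
    else:
      pp[k] = f"train[{used}%:{used + slice_percent}%]"
      used += slice_percent
  pp["train"] = f"train[{used}%:]"
  return pp


def validation_batching_sketch(num_examples_tot, num_examples_max, batch_size):
  """Return (num_batches, batch_size) following Task.__init__."""
  if num_examples_max <= num_examples_tot:
    return math.floor(num_examples_max / batch_size), batch_size
  num_batches = math.ceil(num_examples_tot / batch_size)
  return num_batches, math.floor(num_examples_tot / num_batches)


print(default_splits_sketch(["train", "test"]))
# {'test': 'test', 'validation': 'train[0%:5%]', 'train': 'train[5%:]'}
print(validation_batching_sketch(1_000, 10_000, 512))
# (2, 500): two equal batches instead of 512 + 488.
print(validation_batching_sketch(50_000, 10_000, 512))
# (19, 512): capped at 19 * 512 = 9728 examples.
```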

In [None]:
def get_num_params(params):
  return sum(jax.tree_util.tree_flatten(
      jax.tree_util.tree_map(lambda p: np.prod(p.shape), params)
      )[0])

In [None]:
def params2comps(params, train_locks, agent_id, name=None):
  """Convert a frozen dict of params to a list of components."""
  components = []
  for k in params:
    if name is None or name == k:
      c = Component(
          name=k, agent_id=agent_id,
          params=params[k], train_locks=train_locks)
      components.append(c)
  return components

def params2comp_names(params):
  return list(params.keys())

In [None]:
def fingerprint_params(params):
  return np.sum(np.array(jax.tree_util.tree_leaves(
      jax.tree_util.tree_map(jnp.sum, params))))

class Component():
  counter = 0
  # Components of retained paths with id <= last_saved are saved in checkpoint.
  last_saved = -1

  def reset_globals():
    Component.counter = 0
    Component.last_saved = -1

  def __init__(
      self, name: str, agent_id: str, params, train_locks, opt_state=None):
    self.name = name
    self.agent_id = agent_id
    self.params = jax.device_get(params)
    self.opt_state = jax.device_get(opt_state)
    self.num_params = None
    self.train_locks = set(train_locks)
    self.id = Component.counter
    Component.counter += 1

  def __str__(self):
    rtn = f"Component: {self.id}\n  Name: {self.name}"
    rtn += f"\n  Train locks: {self.train_locks}"
    rtn += f"\n  Fingerprint: {self.fingerprint()}"
    rtn += f"\n  Num params: {self.get_num_params()}"
    return rtn

  def get_num_params(self):
    if self.num_params is None:
      self.num_params = get_num_params(self.params)
    return self.num_params

  def fingerprint(self):
    return fingerprint_params(self.params)

  def is_trainable(self):
    return len(self.train_locks) == 0

  def clone(self, agent_id):
    return Component(name=self.name,
                     agent_id=agent_id,
                     params=copy.deepcopy(jax.device_get(self.params)),
                     train_locks=set(),
                     opt_state=copy.deepcopy(self.opt_state))

In [None]:
class ObjectCache():
  def __init__(self, factory_fn, max_size=None):
    self.factory_fn = factory_fn
    self.factory_fn_signature = inspect.signature(factory_fn)
    self.cache = {}
    self.max_size = max_size

  def __call__(self, *args, **kwargs):
    assert not args, "No positional arguments allowed."
    kw_params = {}
    fn_name = self.factory_fn.__name__
    fn_params = inspect.signature(self.factory_fn).parameters
    for k_param, v_param in fn_params.items():
      if k_param in kwargs:
        kw_params[k_param] = kwargs[k_param]
      elif v_param.default != v_param.empty:
        # Fall back to the declared default value.
        kw_params[k_param] = fn_params[k_param].default
      else:
        assert False, (
            f"Missing value for argument {k_param} for function {fn_name}")

      if v_param.annotation != v_param.empty:
        # Apply annotated type.
        assert isinstance(v_param.annotation, type)
        kw_params[k_param] = v_param.annotation(kw_params[k_param])

    key = json.dumps(kw_params, sort_keys=True)
    if key not in self.cache:
      if self.max_size and self.max_size <= len(self.cache):
        rm_key = random.choice(list(self.cache.keys()))
        print(f"Removed from cache: {fn_name}({rm_key})  [cache size {len(self.cache)}]")
        rm_obj = self.cache.pop(rm_key)
        del rm_obj
      self.cache[key] = self.factory_fn(**kw_params)
      print(f"Added to cache: {fn_name}({key})  [cache size {len(self.cache)}]")
    return self.cache[key]
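`ObjectCache` makes equivalent calls share one entry by canonicalizing the arguments before building the key:
- keep only the factory's own parameters;
- fill in declared defaults;
- coerce each value with its type annotation;
- serialize with `json.dumps(..., sort_keys=True)`.

Below is a minimal sketch of just the key construction. `make_model` is a hypothetical stand-in for a factory such as `get_vit_model`, and the `"Ti/16"` query string is illustrative.

```python
import inspect
import json


def cache_key_sketch(fn, **kwargs):
  """Build an ObjectCache-style key: fn's params, defaults filled, coerced."""
  kw = {}
  for name, p in inspect.signature(fn).parameters.items():
    if name in kwargs:
      kw[name] = kwargs[name]
    elif p.default is not p.empty:
      kw[name] = p.default
    else:
      raise TypeError(f"Missing value for argument {name} of {fn.__name__}")
    if p.annotation is not p.empty:
      kw[name] = p.annotation(kw[name])  # e.g. int("3") -> 3
  return json.dumps(kw, sort_keys=True)


# Hypothetical factory standing in for get_vit_model.
def make_model(num_layers: int, query: str, adapter_dim: int = -1):
  return f"model({num_layers}, {query}, {adapter_dim})"


k1 = cache_key_sketch(make_model, num_layers=3, query="Ti/16")
k2 = cache_key_sketch(make_model, query="Ti/16", num_layers="3",
                      ds_image_size=224)  # Unused kwargs are ignored.
print(k1 == k2)  # True
```

This filtering is what lets `Path.model` pass the whole `**self.hparams` dict. Dataset hparams and mutation probabilities that the factory does not declare never reach the key, so paths that differ only in those values reuse the same compiled model.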

In [None]:
def incremental_mutation(value, values_list:list):
  assert value in values_list, f"{value} not in {values_list}"
  idx = values_list.index(value)
  idx += 1 if np.random.uniform() < 0.5 else -1
  idx = max(0, min(len(values_list)-1, idx))
  return values_list[idx]
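`incremental_mutation` moves a hyperparameter one step up or down its list of allowed values, with equal probability. At either end of the list it clamps, so the value stays unchanged half of the time. Below is a sketch of the same rule. Two things are assumptions: it uses Python's seeded `random.Random` instead of `np.random` so it runs deterministically without NumPy, and the learning-rate list is hypothetical.

```python
import random


def incremental_mutation_sketch(value, values_list, rng):
  """Move to an adjacent entry of values_list, clamped at both ends."""
  idx = values_list.index(value)
  idx += 1 if rng.random() < 0.5 else -1
  idx = max(0, min(len(values_list) - 1, idx))
  return values_list[idx]


rng = random.Random(0)
lrs = [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005]  # Hypothetical range.
samples = [incremental_mutation_sketch(0.001, lrs, rng) for _ in range(1000)]
print(sorted(set(samples)))  # Only the neighbours of 0.001 appear.
edge = {incremental_mutation_sketch(0.0001, lrs, rng) for _ in range(100)}
print(sorted(edge))  # At the lower end the value can only stay or go up.
```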

In [None]:
def compute_flops_hlo(flax_module, *a, **kw):
  """Compute FLOPs in flax_module."""
  # Compute flops on cpu for cross platform consistency.
  analysis = jax.jit(flax_module, backend='cpu').lower(*a, **kw).cost_analysis()
  return analysis["flops"]

In [None]:
class Path():
  def reset_globals(config):
    Path.config = config
    Path.counter = 0
    Path.last_saved = -1
    Path.paths = []
    Path.scorer = globals()[config.scorer_class](**config.scorer_kwargs)
    # Cache output of functions calls with same args.
    Path.cached_tasks = ObjectCache(get_task_factory_fn(config))
    Path.cached_posembed_components = ObjectCache(get_reshaped_posembed_component)
    Path.cached_optimizers = ObjectCache(get_optimizer)
    Path.cached_models = ObjectCache(get_vit_model)
    Path.cached_init_params = ObjectCache(get_vit_params_mapped, max_size=10)

  def __init__(self, hparams, components, parent, agent_id, task_name):
    self.components = components
    self.id = Path.counter
    Path.counter += 1
    self.agent_id = agent_id
    self.task_name = task_name
    self.parent = parent
    self.hparams = hparams
    self.metrics = {
        "generation": 0 if parent is None else parent.metrics["generation"] + 1,
    }
    Path.paths.append(self)

  def __str__(self):
    rtn = f"Path: {self.id}"
    rtn += f"\n  Components: {[c.id for c in self.components]}"
    if self.parent:
      rtn += f"\n  Parent: {self.parent.id}"
    rtn += f"\n  Agent: {self.agent_id}"
    rtn += f"\n  Task: {self.task_name}"
    for k,v in self.hparams.items():
      rtn += f"\n    {k}: {v}"
    for k,v in self.metrics.items():
      rtn += f"\n    {k}: {v}"
    rtn += f"\n  Score: {self.score()}"
    return rtn

  @property
  def task(self):
    return Path.cached_tasks(task_name=self.task_name)

  @property
  def model(self):
    return Path.cached_models(query=self.config.vit_checkpoint_query,
                              **self.hparams)

  def score(self):
    return Path.scorer.score(self)

  def get_all_params(self):
    params = {}
    for c in self.components:
      params[c.name] = c.params
    return flax.core.freeze(params)

  def get_trainable_params(self):
    params = {}
    for c in self.components:
      if c.is_trainable():
        params[c.name] = c.params
    return flax.core.freeze(params)

  def get_fixed_params(self):
    params = {}
    for c in self.components:
      if not c.is_trainable():
        params[c.name] = c.params
    return flax.core.freeze(params)

  def update_trainable(self, trained_params, opt_state):
    assert len(trained_params.keys()) == len(opt_state.keys())
    trainable_count = 0
    for c in self.components:
      if c.is_trainable():
        trainable_count += 1
        assert c.name in trained_params.keys()
        assert c.name in opt_state.keys()
        c.params = trained_params[c.name]
        c.opt_state = opt_state[c.name]
    assert len(trained_params.keys()) == trainable_count, (
        f"{len(trained_params.keys())} {trainable_count}")

  def get_num_accounted_params(self):
    rtn = 0
    for c in self.components:
      tl = copy.copy(c.train_locks)
      assert type(tl) is set
      tl.add(self.agent_id)
      assert tl
      rtn += c.get_num_params() / len(tl)
    return rtn

  def get_flops(self):
    return compute_flops_hlo(
          self.model.apply,
          {"params": format_params(
              self.get_trainable_params(),
              self.get_fixed_params(),
            )},
          get_sample_images(
              self.hparams["ds_image_size"],
              batch_size=1,
          ),
          train=False)

  def clone(self, ds_hparams, policy):
    config = Path.config
    agent_id = config.agent_id
    task_name = config.task_name
    comps = []
    new_hparams = copy.deepcopy(self.hparams)
    # Overwrite dataset hparams with those sampled for the generation batch.
    new_hparams.update(ds_hparams)

    def get_component_ref(c, clone):
      if c.is_trainable() or clone:
        # Clone trainable component.
        return c.clone(agent_id=agent_id)
      # Refer to frozen component.
      return c

    for k in sorted(config.models_mutation_ranges):
      if ((k in ["_mu_", "num_layers", "adapter_dim"]
            or k.startswith(OPTIMIZER_HPARAMS_KEYS_PRERFIX)) and
          policy.do_mutate(new_hparams, f"hp:{k}")):
        new_hparams[k] = incremental_mutation(
            new_hparams[k],
            config.models_mutation_ranges[k])
    if "adapter_layers" in new_hparams or new_hparams.get("mutate_adapters", False):
      new_hparams["adapter_layers"] = mutate_adapters(
          hparams=new_hparams,
          policy=policy)

    init_params = Path.cached_init_params(
        query = config.vit_checkpoint_query,
        **new_hparams)
    new_comp_names = params2comp_names(init_params)
    for new_comp_name in new_comp_names:
      comp = None
      # Attempt to reuse a matching component from the closest ancestor.
      ancestor = self
      while ancestor is not None:
        comps_lookup = {c.name:c for c in ancestor.components}
        if new_comp_name in comps_lookup:
          # The head must be trainable, so skip frozen heads (owned by other
          # agents); if none is found, fall back to random init of correct shape.
          if new_comp_name == "head" and not comps_lookup[new_comp_name].is_trainable():
            assert agent_id != ancestor.agent_id, f"{agent_id} != {ancestor.agent_id}"
            ancestor = ancestor.parent
            continue

          # Check shapes match otherwise skip.
          if (jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]) !=
              jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params)):
            if new_comp_name == "posembed_input":
              # Change of image size changed shape of position embeddings,
              # this can happen if ds_image_size is tuned,
              # continue searching through ancestors for matching size.
              assert "ds_image_size" in config.models_mutation_ranges
              assert new_hparams["ds_image_size"] != ancestor.hparams["ds_image_size"]
              ancestor = ancestor.parent
              continue
            if new_comp_name.startswith("residual_adapter_"):
              # Change of adapter inner dimension changed shape of dense layers,
              # this can happen if adapter_dim is tuned,
              # continue searching through ancestors for matching size.
              assert "adapter_dim" in config.models_mutation_ranges
              assert new_hparams["adapter_dim"] != ancestor.hparams["adapter_dim"]
              ancestor = ancestor.parent
              continue

            print(f"WARNING: Shapes do not match for component: {new_comp_name}  {ancestor.agent_id}->{agent_id}")
            print(jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]))
            print(jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params))
            assert False  # Should not happen in current configuration.

          ancestor_comp = comps_lookup[new_comp_name]
          comp = get_component_ref(
              ancestor_comp, clone=(
                  ancestor_comp.is_trainable() or policy.do_mutate(
                      new_hparams, f"clone:{new_comp_name}")))
          break
        ancestor = ancestor.parent

      # Get reshaped posembed_input.
      if comp is None and new_comp_name == "posembed_input":
        pe_comp = Path.cached_posembed_components(
            agent_id=agent_id,
            query=config.vit_checkpoint_query,
            **new_hparams)
        # Clone to make the component trainable.
        comp = get_component_ref(pe_comp, clone=True)

      # Otherwise create one from random init params.
      if comp is None:
        if VERBOSE:
          print("Init:", new_comp_name)
        # Combinations that can trigger a random init in current configurations.
        assert (
            new_comp_name == "head"
            or new_comp_name.startswith("residual_adapter_")
            or (new_comp_name.startswith("encoderblock_") and \
                config.models_default_hparams["num_layers"] < max(
                config.models_mutation_ranges.get("num_layers", [-1])))
            )
        comp = params2comps(
            init_params, train_locks=[],
            agent_id=agent_id, name=new_comp_name)[0]
      assert comp is not None
      comps.append(comp)

    rtn = Path(
        new_hparams, comps, parent=self, agent_id=agent_id, task_name=task_name)

    if agent_id == self.agent_id:
      self.metrics["offsprings"] = self.metrics.get("offsprings", 0) + 1

    return rtn

  def get_optimizer(self):
    return Path.cached_optimizers(
        num_train_batches_between_validations=
            self.task.num_train_batches_between_validations,
        num_validations_per_path_training=
            self.task.config.num_validations_per_path_training,
        **self.hparams)
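The core of `Path.clone` is an ancestor walk. For each component name the child needs, it climbs from the parent toward the root and reuses the first component whose parameter shapes match. A mismatch, such as a `posembed_input` built for a different `ds_image_size`, sends the search further up. If no ancestor qualifies, the component is randomly initialized.

Below is a simplified sketch of that search. `PathSketch` is a hypothetical stand-in that maps component names to shapes. It omits the train-lock handling, the frozen-head rule and the clone-or-reference decision of the real method.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PathSketch:
  """Hypothetical stand-in for Path: component name -> parameter shape."""
  agent_id: str
  components: dict
  parent: Optional["PathSketch"] = None


def find_reusable_sketch(path, comp_name, wanted_shape):
  """Return (agent_id, shape) of the closest shape-compatible ancestor."""
  ancestor = path
  while ancestor is not None:
    shape = ancestor.components.get(comp_name)
    if shape == wanted_shape:
      return ancestor.agent_id, shape
    # Missing or mismatched (e.g. posembed_input after a ds_image_size
    # change): keep searching further up the ancestry.
    ancestor = ancestor.parent
  return None  # Caller falls back to random init of the right shape.


root = PathSketch("root", {"posembed_input": (1, 197, 192),
                           "head": (192, 1000)})
child = PathSketch("agent_a", {"posembed_input": (1, 50, 192),
                               "head": (192, 10)}, parent=root)

print(find_reusable_sketch(child, "posembed_input", (1, 197, 192)))
# ('root', (1, 197, 192)): the child's 50-token embedding is skipped.
print(find_reusable_sketch(child, "head", (192, 100)))
# None: no ancestor has a 100-class head, so it would be randomly initialized.
```

Searching nearest-first means a child inherits the most recently trained version of each component. It reaches back to older ancestors only when a structural hparam mutation made the closer versions incompatible.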

In [None]:
def mutate_adapters(hparams, policy, allow_removal=False):
  num_layers = hparams["num_layers"]
  a_ids = set(ids_str2ints(hparams.get("adapter_layers", "")))
  if hparams.get("mutate_adapters", False):
    for a_id in range(num_layers):
      if a_id in a_ids:
        if allow_removal and policy.do_mutate(
            hparams, f"remove:residual_adapter_{a_id}"):
          a_ids.remove(a_id)
      else:
        if policy.do_mutate(hparams, f"add:residual_adapter_{a_id}"):
          a_ids.add(a_id)
  # Drop adapters of layers dropped by a possible mutation in num_layers.
  a_ids = [a_id for a_id in a_ids if a_id < num_layers]
  return ids_ints2str(a_ids)

In [None]:
class ScorerDecay():
  def __init__(self, scale_factor, num_params, flops=0):
    assert 0.0 < scale_factor <= 1.0
    self.scale_factor = scale_factor
    self.num_params = num_params
    self.flops = flops

  def score(self, path):
    if ("quality" not in path.metrics
        or math.isnan(path.metrics["quality"])):
      return None
    assert path.metrics["quality"] >= 0, (
        f"{path.task_name} {path.metrics['quality']}")
    score = path.metrics["quality"]
    if self.num_params > 0:
      # Accounted params needs to be updated since it depends on the
      # changing structure of the system.
      path.metrics["accounted_params"] = path.get_num_accounted_params()
      score *= self.scale_factor ** (
          path.metrics["accounted_params"] / self.num_params)
    if self.flops > 0:
      if "flops" not in path.metrics:
        path.metrics["flops"] = path.get_flops()
      score *= self.scale_factor ** (path.metrics["flops"] / self.flops)
    assert score >= 0
    path.metrics["score"] = score
    return score
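`ScorerDecay` multiplies the validation quality by `scale_factor ** (accounted_params / num_params)`. When `flops > 0`, it applies an analogous FLOPs factor. The accounted parameters come from `Path.get_num_accounted_params`: each component's size is divided by the number of agents in its `train_locks` plus the current agent. Parameters shared with many agents therefore cost each of them less.

Below is a stdlib sketch of the parameter term only. Components are reduced to hypothetical `(num_params, train_locks)` pairs, and the agent ids, sizes and scorer settings are illustrative.

```python
def accounted_params_sketch(components, agent_id):
  """components: list of (num_params, train_locks) pairs."""
  total = 0.0
  for num_params, train_locks in components:
    sharers = set(train_locks) | {agent_id}  # Cost is split among sharers.
    total += num_params / len(sharers)
  return total


def decay_score_sketch(quality, accounted_params, scale_factor, num_params):
  """Quality discounted by scale_factor per num_params accounted params."""
  return quality * scale_factor ** (accounted_params / num_params)


# Hypothetical path: a frozen 5M-param backbone locked by two other agents,
# plus a 0.1M-param trainable head owned by this agent alone.
path_comps = [(5_000_000, {"agent_b", "agent_c"}), (100_000, set())]
acc = accounted_params_sketch(path_comps, "agent_a")
print(acc)  # 5_000_000 / 3 + 100_000
print(decay_score_sketch(0.9, acc, scale_factor=0.99, num_params=1_000_000))
# Slightly below the raw quality of 0.9.
```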

In [None]:
class PPDecay():
  def __init__(self, config):
    self.config = config

  def do_mutate(self, hparams, mutation_name):
    """Returns True if mutation is sampled to be applied."""
    if mutation_name in self.config.get("force_mutations", []):
      return True
    mutation_prob_k = f"_mu_|{mutation_name}"
    # Fallback is used for batch shared sampling.
    mu = hparams.get("_mu_", self.config.models_default_hparams["_mu_"])
    mutation_prob = hparams.get(mutation_prob_k, mu)
    if "_mu_" in self.config.models_mutation_ranges:
      if mu > np.random.uniform():
        mutation_prob = incremental_mutation(
            mutation_prob, self.config.models_mutation_ranges["_mu_"])
      hparams[mutation_prob_k] = mutation_prob
    return mutation_prob > np.random.uniform()

  def sample_path(self, pop, ds_hparams):
    parent = None
    for path in sorted(pop.paths[self.config.agent_id],
                       key=lambda p: p.score(),
                       reverse=True):
      offsprings = path.metrics.get("offsprings", 0)
      print(" ", path.id, int(offsprings))
      assert not math.isnan(offsprings)
      if np.random.uniform() < 0.5 ** offsprings:
        parent = path
        break

    if not parent:  # Random sample.
      parent = random.choice([p for paths in pop.paths.values() for p in paths])
      print("  random", parent.id, parent.agent_id)

    child = parent.clone(ds_hparams, self)

    gc.collect()

    # Store record of mutations.
    mutations = {}
    for k in child.hparams:
      if parent.hparams.get(k) != child.hparams[k]:
        mutations[k] = (parent.hparams.get(k), child.hparams[k])
    child.metrics["mutations"] = json.dumps(mutations)
    print(child.id, child.metrics["mutations"])
    return child

  def sample_ds_hparams(self, pop):
    """Samples hparams that must be shared by all paths in a generation."""
    assert pop.config is self.config
    ds_hparams = {}

    # Initialize shared hparams with defaults.
    for key in self.config.models_default_hparams:
      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):
        ds_hparams[key] = self.config.models_default_hparams[key]

    # Overwrite with values from best path if available.
    best_path = pop.get_best_path()
    if best_path:
      ds_hparams.update(
          {k : best_path.hparams[k] for k in ds_hparams if k in best_path.hparams})
      ds_hparams.update(
          {k : best_path.hparams[k] for k in best_path.hparams if k.startswith(
              f"_mu_|hp:{DATASET_HPARAMS_KEYS_PRERFIX}")})

      # Sample mutations.
      for k in list(ds_hparams):
        if (k in self.config.models_mutation_ranges
            and pop.policy.do_mutate(ds_hparams, f"hp:{k}")):
          ds_hparams[k] = incremental_mutation(
              ds_hparams[k],
              self.config.models_mutation_ranges[k])

    for k in ds_hparams:
      assert (k.startswith(DATASET_HPARAMS_KEYS_PRERFIX) or
              k.startswith(f"_mu_|hp:{DATASET_HPARAMS_KEYS_PRERFIX}"))
    return ds_hparams
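
The parent selection in `PPDecay.sample_path` scans the agent's paths from best to worst score and accepts each with probability `0.5 ** offsprings`. The sketch below restates that rule in plain Python. The function name and the injectable `uniform` parameter (used here in place of `np.random.uniform` so the behavior can be checked deterministically) are hypothetical.

```python
import random


def select_parent(paths, score, offsprings, uniform=random.random):
    """Pick a parent by scanning paths from best to worst score.

    Each candidate is accepted with probability 0.5 ** offsprings, so a path
    that was never used as a parent is accepted as soon as it is reached, and
    every offspring halves its chance of being picked again.
    """
    for path in sorted(paths, key=score, reverse=True):
        if uniform() < 0.5 ** offsprings(path):
            return path
    return None  # The caller then falls back to a uniformly random path.
```

The decay balances exploitation and exploration: the best path is favored, but each time it produces a child, lower-ranked paths become more likely to be sampled.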

In [None]:
class Population():
  def __init__(self, config):
    Path.reset_globals(config)
    Component.reset_globals()
    self.paths = defaultdict(list)
    self.config = config
    self.paths_df = pd.DataFrame()
    self.comps_df = pd.DataFrame()
    self.policy = globals()[config.policy_class](
        **config.policy_kwargs,
        config=config)

  def get_best_path(self):
    if len(self.paths[self.config.agent_id]) == 0:
      return None
    # Most recent path achieving max score.
    return max(sorted(
        self.paths[self.config.agent_id], key=lambda p: p.id, reverse=True),
        key=lambda p: p.score())

  def prune_population(self):
    if self.config.get("max_task_population_size", None) and (
        len(self.paths[self.config.agent_id]) > self.config.max_task_population_size):
      best_path = self.get_best_path()
      self.paths[self.config.agent_id] = sorted(
          self.paths[self.config.agent_id], key=lambda p: p.score(), reverse=True
          )[:self.config.max_task_population_size]
      assert best_path in self.paths[self.config.agent_id] or best_path.score() == self.get_best_path().score()

  def sample_path(self, ds_hparams):
    return self.policy.sample_path(pop=self, ds_hparams=ds_hparams)

  def sample_ds_hparams(self):
    return self.policy.sample_ds_hparams(pop=self)

  def add_train_locks(self):
    # Check.
    for ps in self.paths.values():
      for p in ps:
        for c in p.components:
          assert self.config.agent_id not in c.train_locks
    # Add locks.
    paths = self.paths[self.config.agent_id]
    for p in paths:
      for c in p.components:
        c.train_locks.add(self.config.agent_id)

  def rm_train_locks(self):
    # Remove locks.
    paths = self.paths[self.config.agent_id]
    for p in paths:
      for c in p.components:
        if self.config.agent_id in c.train_locks:
          c.train_locks.remove(self.config.agent_id)
    # Check.
    for ps in self.paths.values():
      for p in ps:
        for c in p.components:
          assert self.config.agent_id not in c.train_locks

  def start_cycle(self):
    self.rm_train_locks()

  def end_cycle(self):
    # Keep only best one.
    best_path = self.get_best_path()
    assert best_path is not None
    best_path.metrics["num_cycles"] = best_path.metrics.get("num_cycles", 0) + 1
    self.paths[self.config.agent_id] = [best_path]

    self.add_train_locks()
    self.garbage_collect_paths()

  def garbage_collect_paths(self):
    # Store history before dropping references to unused paths to trigger
    # garbage collection of components and parameters.
    # DataFrame.append() was removed in pandas 2.0; use pd.concat() instead.
    self.paths_df = pd.concat(
        [self.paths_df, paths_to_df(Path.paths)], ignore_index=True
        ).query(f'agent_id=="{self.config.agent_id}" and id>{Path.last_saved}')
    self.comps_df = pd.concat(
        [self.comps_df, components_to_df(Path.paths)], ignore_index=True
        ).query(f'agent_id=="{self.config.agent_id}" and id>{Component.last_saved}'
        ).drop_duplicates()

    # Drop unused paths generated in this task iteration for garbage collection.
    Path.paths = []
    # Simplify ancestor tree to contain only live paths.
    live_paths_ids = [p.id for paths in self.paths.values() for p in paths]
    # Notice that the simplification is done also for paths of other tasks,
    # since they may be pointing to a path of this task that was just pruned.
    for path in [path for paths in self.paths.values() for path in paths]:
      ancestor = path.parent
      if ancestor is None:
        continue
      while True:
        if ancestor.id in live_paths_ids:
          path.parent = ancestor
          break
        ancestor = ancestor.parent
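
The ancestor-tree simplification at the end of `garbage_collect_paths` re-points every live path at its nearest live ancestor, so that pruned paths can be freed. Below is a minimal sketch on a hypothetical `Node` class standing in for `Path`. Unlike the loop above, which assumes a live ancestor always exists (and would fail on `None.id` otherwise), this sketch stops at `None`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for a path: an id plus a parent pointer."""
    id: int
    parent: Optional["Node"] = None


def reparent_to_live_ancestors(live_nodes):
    """Point each live node at its nearest ancestor that is also live."""
    live_ids = {n.id for n in live_nodes}
    for node in live_nodes:
        ancestor = node.parent
        # Skip over pruned ancestors; stop at the first live one (or the root).
        while ancestor is not None and ancestor.id not in live_ids:
            ancestor = ancestor.parent
        node.parent = ancestor
```

Using a set for the live ids keeps each membership check constant-time, whereas the notebook uses a list.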

In [None]:
pd.set_option("display.expand_frame_repr", False)
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)

def pop_to_df(pop):
  return paths_to_df([p for paths in pop.paths.values() for p in paths])

def paths_to_df(paths):
  # Collect all metrics names.
  metrics_keys = set()
  hparams_keys = set()
  for path in paths:
    path.score()  # Update scores.
    metrics_keys.update(path.metrics)
    hparams_keys.update(path.hparams)

  data = defaultdict(list)
  for path in paths:
    data["agent_id"].append(path.agent_id)
    data["task_name"].append(path.task_name)
    data["id"].append(path.id)
    data["parent_id"].append(path.parent.id if path.parent else -1)
    data["parent_agent_id"].append(path.parent.agent_id if path.parent else None)
    data["components"].append(",".join(
        [f"{c.agent_id}:{c.id}" for c in path.components]))
    for k in hparams_keys:
      data[f"hparams.{k}"].append(path.hparams[k] if k in path.hparams else None)
    for k in metrics_keys:
      data[f"metrics.{k}"].append(path.metrics[k] if k in path.metrics else None)
  return pd.DataFrame(data)

def components_to_df(paths):
  # Collect all components.
  comps = set()
  for p in paths:
    comps.update(p.components)

  data = defaultdict(list)
  for c in comps:
    data["id"].append(c.id)
    data["name"].append(c.name)
    data["agent_id"].append(c.agent_id)
    data["num_params"].append(c.get_num_params())
  return pd.DataFrame(data)

def print_df_segments(df, segment_length:int = 10):
  tot_length = df.shape[0]
  # Pad column title with spaces to keep alignment across segments.
  def prepend_spaces(original_str, pad_to_len):
    return " " * (pad_to_len-len(original_str)) + original_str
  pad_to_len = max([len(tn) for tn in set(df["agent_id"].to_list())])+1
  df = df.rename(columns={
    "agent_id": prepend_spaces("agent_id", pad_to_len),
    "task_name": prepend_spaces("task_name", pad_to_len),
    "parent_agent_id": prepend_spaces("parent_agent_id", pad_to_len),
    })
  for x in range(0, tot_length, segment_length):
    print(df[x:min(x+segment_length, tot_length)])

def df_leaderboard(df):
  # Place columns on the left for readability.
  all_keys = sorted(df.columns.tolist())
  first_keys = ["agent_id", "task_name",
                "metrics.test_quality", "metrics.score", "metrics.quality",
                "metrics.accounted_params", "metrics.flops",
                "id", "parent_id", "parent_agent_id"]
  first_keys = [k for k in first_keys if k in all_keys]
  sorted_keys = first_keys + [k for k in all_keys if k not in first_keys]
  # Filter mu function parameters.
  sorted_keys = [k for k in sorted_keys if "_mu_|" not in k]
  df = df[sorted_keys]
  if "metrics.score" in df:
    df = df.sort_values(["agent_id", "metrics.score"], ascending=[True, False],
                        ignore_index=True)
  else:
    df = df.sort_values("agent_id", ignore_index=True)
  print_df_segments(df)
  for k in ["metrics.score", "metrics.quality", "metrics.test_quality",
            "metrics.accounted_params", "metrics.flops"]:
    if k in df:
      print(f"Avg {k}: {df[k].mean():.6f}")
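
`paths_to_df` makes two passes: first it collects the union of hparams and metrics keys, then it appends one value per key for every path (`None` when missing). This keeps all columns the same length, which `pd.DataFrame(data)` requires. The stdlib-only sketch below shows the same flattening. The function name and the record layout (`{"id", "hparams", "metrics"}` dicts) are hypothetical; the returned dict could be passed to `pd.DataFrame`.

```python
from collections import defaultdict


def records_to_columns(records):
    """Flatten per-path hparams/metrics dicts into prefixed, equal-length columns."""
    # Pass 1: union of keys, so every record contributes to every column.
    hparams_keys, metrics_keys = set(), set()
    for r in records:
        hparams_keys.update(r["hparams"])
        metrics_keys.update(r["metrics"])
    # Pass 2: one entry per record per column, None where a key is missing.
    data = defaultdict(list)
    for r in records:
        data["id"].append(r["id"])
        for k in sorted(hparams_keys):
            data[f"hparams.{k}"].append(r["hparams"].get(k))
        for k in sorted(metrics_keys):
            data[f"metrics.{k}"].append(r["metrics"].get(k))
    return dict(data)
```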

In [None]:
def prp(path):
  rtn = []
  if VERBOSE:
    rtn.append(str(path))
    for c in path.components:
      rtn.append(str(c))
  else:
    rtn.append(str(path.id))
  return "\n".join(rtn)

# Checkpointing

In [None]:
def df_write_to_csv(df, dir_path, df_name):
  filename_df = os.path.join(dir_path, f"{df_name}.csv")
  with gfile.GFile(filename_df, "w") as outfile:
    df.to_csv(outfile, index=False)

def df_read_from_csv(dir_path, df_name):
  filename_df = os.path.join(dir_path, f"{df_name}.csv")
  with gfile.GFile(filename_df, "r") as infile:
    df = pd.read_csv(infile)
  # pandas read_csv() reads empty strings as NaNs. Set NaNs back to empty
  # strings in columns of type string/object.
  for c in df.columns:
    if df[c].dtype == np.object_:
      df[c] = df[c].fillna("")
  return df

def get_comps_params_to_save(pop: Population):
  """Returns a dictionary containing the parameters of the used components."""
  comps_params = {}
  # All components generated by this agent.
  all_comps = set(
      [c for p in pop.paths[pop.config.agent_id] for c in p.components if c.agent_id == pop.config.agent_id])
  # Check that there are not duplicate ids.
  assert len(all_comps) == len(set([c.id for c in all_comps])), (
      [f"{c.name}:{c.agent_id}:{c.id}" for c in all_comps])
  for c in all_comps:
    if c.id <= Component.last_saved:
      continue
    assert c.agent_id == pop.config.agent_id
    c_id_string = f"{c.name}:{c.agent_id}:{c.id}"
    comps_params[c_id_string] = c.params
    if c.opt_state is not None:
      comps_params[f"opt_state:{c_id_string}"] = c.opt_state
  return comps_params
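
Checkpoint keys are `"<name>:<agent_id>:<id>"` for parameters, and the same key prefixed with `"opt_state:"` for the optimizer state; `load_paths` later splits them back on `":"`. The sketch below shows that round trip and why no field may contain a colon. The helper names are hypothetical, and plain strings stand in for parameter trees.

```python
OPT_PREFIX = "opt_state:"


def component_key(name, agent_id, comp_id):
    """Key format used when saving: "<name>:<agent_id>:<id>"."""
    return f"{name}:{agent_id}:{comp_id}"


def split_checkpoint(flat):
    """Group a flat checkpoint dict into {(name, agent_id, id): (params, opt_state)}."""
    comps = {}
    for key, params in flat.items():
        if key.startswith(OPT_PREFIX):
            continue  # Optimizer states are attached to their component below.
        name, agent_id, comp_id = key.split(":")  # Fails if a field contains ":".
        opt_state = flat.get(OPT_PREFIX + key)  # None if no optimizer state saved.
        comps[(name, agent_id, int(comp_id))] = (params, opt_state)
    return comps
```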

In [None]:
def latest_checkpoint(ckpt_dir: str, prefix: str = "checkpoint_"):
  """Returns the latest checkpoint under the dir pattern or None if missing."""
  ckpt_dir = os.fspath(ckpt_dir)
  glob_path = os.path.join(ckpt_dir, f"{prefix}*")
  checkpoint_files = flax_checkpoints.natural_sort(gfile.glob(glob_path))
  checkpoint_files = [f for f in checkpoint_files if not f.endswith("_tmp")]
  return checkpoint_files[-1] if checkpoint_files else None
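
`latest_checkpoint` relies on natural ordering, so that `checkpoint_0_10` comes after `checkpoint_0_2`, and it ignores unfinished `*_tmp` entries. The stdlib-only sketch below approximates the ordering that `flax_checkpoints.natural_sort` provides; it handles unsigned digit runs only, and the helper names are hypothetical.

```python
import re


def natural_key(path):
    """Split digit runs out of a string so that "_2" sorts before "_10"."""
    # re.split with a capture group alternates text and digit runs, so strings
    # sit at even indices and ints at odd ones and lists compare element-wise.
    return [int(tok) if i % 2 else tok
            for i, tok in enumerate(re.split(r"(\d+)", path))]


def pick_latest(files):
    """Return the naturally-last file, ignoring unfinished *_tmp entries."""
    done = sorted((f for f in files if not f.endswith("_tmp")), key=natural_key)
    return done[-1] if done else None
```

A plain lexicographic sort would place `checkpoint_0_10` before `checkpoint_0_2` and return a stale checkpoint.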

In [None]:
LAST_CHECKPOINT_TIME = time.time()

def save_checkpoint(
    ckpt_dir: str,
    comps_params,
    cycle_id: int,
    generation_id: int):
  print("SAVING", cycle_id, generation_id, comps_params.keys())
  # Write checkpoint.
  flax_checkpoints.save_checkpoint(
      ckpt_dir=ckpt_dir,
      target=comps_params,
      step=generation_id,
      prefix=f"checkpoint_{cycle_id}_",
      overwrite=True)
  # Update time of last checkpoint save.
  global LAST_CHECKPOINT_TIME
  LAST_CHECKPOINT_TIME = time.time()
  # Delete intermediate checkpoint directories.
  if generation_id == 0:
    intermediate_ckpt_dirs = gfile.glob(
        os.path.join(os.path.dirname(ckpt_dir), "state_*_[^0]*"))
    for d in intermediate_ckpt_dirs:
      print("Deleting intermediate checkpoint:", d)
      gfile.rmtree(d)

def save_state(
    pop: Population,
    cycle_id: int,
    generation_id: int,
    config: FrozenConfigDict):
  """Save checkpoint and data needed to resume exp."""
  intermediate = (generation_id != 0)
  write_threads = []
  # Skip intermediate state save if last state was written recently.
  if intermediate and (
      (time.time() - LAST_CHECKPOINT_TIME) < SKIP_INTERMEDIATE_STATE_SECS):
    print("Skip checkpointing, seconds since last save:",
          f"{time.time() - LAST_CHECKPOINT_TIME:.0f}")
    return write_threads

  # Save data needed to resume exp.
  write_start = time.time()
  df_leaderboard(pop_to_df(pop))
  pop.garbage_collect_paths()
  state_dir = os.path.join(
      config.agent_dir, f"state_{cycle_id}_{generation_id}")
  gfile.makedirs(state_dir)

  assert not latest_checkpoint(state_dir), (
      f"checkpoint already present in folder: {state_dir}")
  print("WRITING CHECKPOINT:", cycle_id, generation_id)

  # Write state in background threads.

  write_threads.append(Thread(
      target=df_write_to_csv,
      args=(paths_to_df([p for p in pop.paths[config.agent_id]]),
            state_dir,
            "published")))
  write_threads[-1].start()

  write_threads.append(Thread(
      target=df_write_to_csv,
      args=(paths_to_df([p for paths in pop.paths.values() for p in paths]),
            state_dir,
            "population")))
  write_threads[-1].start()

  write_threads.append(Thread(
      target=df_write_to_csv,
      args=(pop.paths_df, state_dir, "paths")))
  write_threads[-1].start()

  write_threads.append(Thread(
      target=df_write_to_csv,
      args=(pop.comps_df, state_dir, "components")))
  write_threads[-1].start()

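  def write_config_json():
    # Close the file explicitly so the JSON is flushed before the thread ends.
    with gfile.GFile(os.path.join(state_dir, "config.json"), "w") as outfile:
      json.dump(config.as_configdict().to_dict(), outfile, indent=2)
  write_threads.append(Thread(target=write_config_json))
  write_threads[-1].start()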

  write_threads.append(Thread(
      target=save_checkpoint,
      args=(state_dir, get_comps_params_to_save(pop), cycle_id, generation_id)))
  write_threads[-1].start()

  # Update last saved.
  if not intermediate:
    Path.last_saved = pop.paths_df.id.max()
    Component.last_saved = pop.comps_df.id.max()

  print(f"WRITE START TIME: {time.time() - write_start:.2f} s")
  return write_threads
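
`save_state` throttles intermediate saves (skipping them if the last save was less than `SKIP_INTERMEDIATE_STATE_SECS` ago), never skips end-of-cycle saves, and writes each file in its own thread, leaving the caller to join them. The sketch below reproduces that pattern with local files instead of `gfile`. The names `write_text`, `maybe_save` and the `state` dict are hypothetical. One difference: here the timestamp is updated when the writes are launched, while the notebook updates `LAST_CHECKPOINT_TIME` inside `save_checkpoint`, after the checkpoint is written.

```python
import os
import tempfile
import threading
import time


def write_text(path, text):
    with open(path, "w") as f:
        f.write(text)


def maybe_save(files, out_dir, intermediate, min_interval_secs, state):
    """Write `files` ({filename: text}) in background threads unless throttled.

    Intermediate saves are skipped when the previous save started less than
    `min_interval_secs` ago; final saves (intermediate=False) always run.
    Returns the started threads; the caller must join them.
    """
    if intermediate and time.monotonic() - state["last_save"] < min_interval_secs:
        return []
    threads = []
    for name, text in files.items():
        t = threading.Thread(
            target=write_text, args=(os.path.join(out_dir, name), text))
        t.start()
        threads.append(t)
    state["last_save"] = time.monotonic()
    return threads


if __name__ == "__main__":
    state = {"last_save": time.monotonic()}
    with tempfile.TemporaryDirectory() as d:
        for t in maybe_save({"a.csv": "id\n1\n"}, d, False, 3600, state):
            t.join()
        print(os.listdir(d))  # ['a.csv']
        print(maybe_save({"b.csv": ""}, d, True, 3600, state))  # []
```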

In [None]:
def load_paths(
    pop: Population,
    state_dir: str,
    all_agents_dirs: str,
    ):
  if state_dir:
    state_dir = state_dir.rstrip("/")

  load_start = time.time()

  # Load system state info.
  population_df = pd.DataFrame()
  skip_agent_dir = None
  if state_dir:
    # Load active agent state, possibly intermediate.
    population_df = pd.concat(
        [population_df, df_read_from_csv(state_dir, "published")])
    skip_agent_dir = os.path.dirname(state_dir)
  for agent_dir in gfile.glob(all_agents_dirs):
    if agent_dir == skip_agent_dir:
      continue
    agent_checkpoint = latest_checkpoint(os.path.join(agent_dir, "state_*_0/"))
    if agent_checkpoint is None:
      continue
    population_df = pd.concat([
        population_df,
        df_read_from_csv(os.path.dirname(agent_checkpoint), "published")])

  # Load parameters from sharded system checkpoint.
  loaded_params = {}  # Dictionary to accumulate loaded parameters.
  lock = Lock()
  duplicate_keys = set()
  def append_loaded_params(add_chkp_dir: str):
    # Skip folders without a completed checkpoint.
    if latest_checkpoint(add_chkp_dir) is None:
      return
    lp_add = flax_checkpoints.restore_checkpoint(
        ckpt_dir=add_chkp_dir,
        target=None)
    if lp_add:
      with lock:
        print("LOADED COMPONENTS", add_chkp_dir, lp_add.keys())
        duplicate_keys.update(loaded_params.keys() & lp_add.keys())
        loaded_params.update(lp_add)
  all_state_dirs = []
  if state_dir:
    # Include active agent state, possibly intermediate.
    all_state_dirs.append(state_dir)
    all_state_dirs.extend(
        gfile.glob(os.path.join(os.path.dirname(state_dir), "state_*_0"))
    )
  for agent_dir in gfile.glob(all_agents_dirs):
    all_state_dirs.extend(
        gfile.glob(os.path.join(agent_dir, "state_*_0"))
    )
  threads = []
  for s_dir in set(all_state_dirs):
    threads.append(Thread(target=append_loaded_params, args=(s_dir,)))
    threads[-1].start()
  for t in threads:
    t.join()
  assert not duplicate_keys, duplicate_keys

  print(f"LOAD TIME: {time.time() - load_start:.2f} s")

  frozen_params = flax.core.freeze(loaded_params)
  sid_2_comp = {}
  for k in frozen_params.keys():
    if k.startswith("opt_state:"):
      continue
    assert len(k.split(":")) == 3, k
    name, agent_id, id = k.split(":")
    if "opt_state:" + k in frozen_params.keys():
      opt_state = frozen_params["opt_state:" + k]
    else:
      opt_state = None
    c = Component(
        name=name, agent_id=agent_id,
        params=frozen_params[k], train_locks=[], opt_state=opt_state)
    c.id = int(id)
    source_id = f"{agent_id}:{id}"
    assert source_id not in sid_2_comp, source_id
    sid_2_comp[source_id] = c
  # For parent assignment.
  sid_2_path = {}
  path_2_parent_sid = {}
  for index, row in population_df.iterrows():
    comps_sids = row["components"].split(",")
    comps = []
    for sid in comps_sids:
      comps.append(sid_2_comp[sid])
    task_name = row["task_name"]
    agent_id = row["agent_id"]

    # Retrieve hparams and metrics.
    hparams = {}
    metrics = {}
    for k in row.keys():
      if type(row[k]) is float and math.isnan(row[k]):
        continue
      if k.startswith("hparams."):
        hparams[k[len("hparams."):]] = row[k]
      if k.startswith("metrics."):
        metrics[k[len("metrics."):]] = row[k]
    # Fix adapter_layers string.
    if "adapter_layers" in hparams and type(hparams["adapter_layers"]) is float:
      hparams["adapter_layers"] = str(int(hparams["adapter_layers"]))
    # Create path.
    path = Path(
        hparams=hparams,
        components=comps,
        parent=None,
        agent_id=agent_id,
        task_name=task_name,
        )
    path.metrics = metrics
    path.id = int(row["id"])
    # Add train locks.
    for c in path.components:
      c.train_locks.add(agent_id)
    pop.paths[agent_id].append(path)
    path_sid = f"{path.agent_id}/{path.id}"
    assert path_sid not in sid_2_path
    sid_2_path[path_sid] = path
    if row["parent_id"] >= 0:
      parent_sid = f'{row["parent_agent_id"]}/{row["parent_id"]}'
      path_2_parent_sid[path] = parent_sid

  # Set parents.
  for path, parent_sid in path_2_parent_sid.items():
    if parent_sid not in sid_2_path:
      # This can happen if parent is retired by a parallel agent.
      # In this case fall back to root model.
      for k in sid_2_path.keys():
        if "root_model" in k:
          parent_sid = k
      print(f"{path.agent_id}:{path.id} orphaned, fallback: {parent_sid}")
    path.parent = sid_2_path[parent_sid]
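
`load_paths` restores the checkpoint shards in parallel threads, serializes updates to the shared dict with a lock, and records keys that appear in more than one shard (the notebook then asserts there are none). The sketch below shows this pattern with plain callables standing in for `flax_checkpoints.restore_checkpoint`. The name `merge_shards` is hypothetical.

```python
import threading


def merge_shards(loaders):
    """Run each loader in its own thread and merge the returned dicts.

    Returns (merged, duplicates); duplicates holds keys seen in more than one
    shard, which the notebook treats as an error.
    """
    merged, duplicates = {}, set()
    lock = threading.Lock()

    def load(fn):
        shard = fn()  # Slow I/O runs outside the lock, in parallel.
        if not shard:
            return  # Nothing to merge, like a folder without a checkpoint.
        with lock:  # Serialize updates to the shared dict and set.
            duplicates.update(merged.keys() & shard.keys())
            merged.update(shard)

    threads = [threading.Thread(target=load, args=(fn,)) for fn in loaders]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return merged, duplicates
```

The duplicate check does not depend on thread order: whichever shard is merged second reports the overlapping keys.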

# Training

In [None]:
@partial(jax.jit, static_argnames="model")
def eval_step(params, images, labels, model):
  logits = model.apply({"params": params}, images,
                       train=False)  # Disable dropout.
  # Avg accuracy on the batch.
  return (logits.argmax(axis=-1) == labels).mean()

In [None]:
@partial(jax.jit, static_argnames=["model", "optimizer"], donate_argnums=[0, 2])
def train_step(params, fixed_params, opt_state, images, labels, model, optimizer):
  def loss_fn(params, fixed_params, images, labels):
    logits = model.apply(
        {"params": format_params(params, fixed_params)}, images,
        train=False)  # Disable dropout.
    labels = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(labels * nn.log_softmax(logits), axis=-1))
  grads = jax.grad(loss_fn)(params, fixed_params, images, labels)
  updates, opt_state = optimizer.update(grads, opt_state, params=params)
  params = optax.apply_updates(params, updates)
  return params, opt_state
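
`loss_fn` in `train_step` is the mean softmax cross-entropy: one-hot labels multiplied by `log_softmax(logits)`, summed over classes and averaged over the batch, which equals `-log_softmax(logits)[label]` per example. The framework-free sketch below restates it for clarity; it is not the notebook's JAX code, and the function names are illustrative.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax: subtract the max before exponentiating."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def softmax_cross_entropy(batch_logits, labels):
    """Mean over the batch of -sum(one_hot(label) * log_softmax(logits))."""
    losses = []
    for logits, label in zip(batch_logits, labels):
        one_hot = [1.0 if c == label else 0.0 for c in range(len(logits))]
        losses.append(-sum(o * lp for o, lp in zip(one_hot, log_softmax(logits))))
    return sum(losses) / len(losses)
```

With uniform logits over `C` classes the loss is `log(C)`, a useful sanity check for a freshly initialized head.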

In [None]:
LOOP_START = time.time()

def train_loop(paths, ds_train, ds_validation, devices, config):
  global LOOP_START
  timing = {"start_time": time.time(),
            "start_time_loop": LOOP_START}
  task = paths[0].task
  # The following values should be shared by all paths in this generation batch.
  for path in paths:
    assert task.name == path.task_name
    assert paths[0].hparams["ds_image_size"] == path.hparams["ds_image_size"]

  gc.collect()

  # Compile.
  compile_train_batches_arr = jax.device_put_replicated(
      get_sample_batch(
        paths[0].hparams["ds_image_size"],
        task.train_batch_size),
      devices)
  compile_eval_batches_arr = jax.device_put_replicated(
      get_sample_batch(
          paths[0].hparams["ds_image_size"],
          task.validation_batch_size),
      devices)

  for p_id, path in enumerate(paths):
    if VERBOSE:
      print("Parent")
      print(prp(path.parent))
      print(prp(path))
    path.device_id = p_id % len(devices)
    path.device = devices[path.device_id]
    path.optimizer = path.get_optimizer()
    path.optimizer_init_fn = jax.jit(path.optimizer.init, device=path.device)
    path.best_params_local = None
    path.best_opt_state_local = None
    path.best_quality = None
    path.best_score = path.parent.score() if path.task_name == path.parent.task_name else -np.inf
    path.evals = []

    # Launch parallel compilation of eval and train step functions.
    params_local = path.get_trainable_params()
    path.compile_params_device = jax.device_put(params_local, path.device)
    path.compile_fixed_params_device = jax.device_put(
        path.get_fixed_params(),
        path.device)
    path.compile_train = Thread(
        target=train_step,
        args=(path.compile_params_device,
              path.compile_fixed_params_device,
              path.optimizer_init_fn(params_local),
              compile_train_batches_arr["image"][path.device_id],
              compile_train_batches_arr["label"][path.device_id],
              path.model,
              path.optimizer))
    path.compile_eval = Thread(
        target=eval_step,
        args=(format_params(
                  path.compile_params_device,
                  path.compile_fixed_params_device),
              compile_eval_batches_arr["image"][path.device_id],
              compile_eval_batches_arr["label"][path.device_id],
              path.model))
    path.compile_eval.start()

  for path in paths:
    path.compile_eval.join()
    del path.compile_eval
    timing["end_compile_eval"] = time.time()
    path.compile_train.start()
  del compile_eval_batches_arr

  for path in paths:
    path.compile_train.join()
    del path.compile_train
    del path.compile_params_device
    del path.compile_fixed_params_device
    timing["end_compile"] = time.time()
  del compile_train_batches_arr

  gc.collect()

  # Parameter transfer.
  for path in paths:
    path.params_device = jax.device_put(
        path.get_trainable_params(),
        path.device)
    path.fixed_params_device = jax.device_put(
        path.get_fixed_params(),
        path.device)
    path.opt_state_device = path.optimizer_init_fn(path.params_device)
    # Set opt state.
    for c in path.components:
      if c.is_trainable():
        assert c.name in path.opt_state_device[1][0].trace.keys()
        if c.opt_state is not None:
          path.opt_state_device = (
              path.opt_state_device[0],
              (optax.TraceState(
                  trace=path.opt_state_device[1][0].trace.copy(
                      {c.name: jax.device_put(c.opt_state,
                                              path.device)})),
               path.opt_state_device[1][1]
               )
          )

  iter_ds_validation = iter(ds_validation)
  # TRAIN
  for t_step, train_batch in zip(
      range(config.num_validations_per_path_training
            * task.num_train_batches_between_validations),
      ds_train,
  ):
    train_batch_arr = jax.device_put_replicated(train_batch, devices)
    for p_id, path in enumerate(paths):
      if t_step == 0:
        timing["end_prep"] = time.time()
        t_step_0_time = time.time()
      path.params_device, path.opt_state_device = train_step(
          path.params_device,
          path.fixed_params_device,
          path.opt_state_device,
          train_batch_arr["image"][path.device_id],
          train_batch_arr["label"][path.device_id],
          path.model,
          path.optimizer)
      if t_step == 0 and time.time() - t_step_0_time > 1:
        print(f"WARNING: First train step took: {time.time()-t_step_0_time:.2f} s")
    del train_batch, train_batch_arr

    # EVAL
    if (t_step+1) % task.num_train_batches_between_validations == 0:
      first_eval = ((t_step+1) == task.num_train_batches_between_validations)
      if first_eval:
        timing["start_eval"] = time.time()
      for path in paths:
        path.accs = []
      for e_step, eval_batch in zip(
          range(task.num_validation_batches),
          iter_ds_validation,
          ):
        eval_batch_arr = jax.device_put_replicated(eval_batch, devices)
        for p_id, path in enumerate(paths):
          if first_eval and e_step == 0:
            e_step_0_time = time.time()
          path.accs.append(
              eval_step(
                  format_params(path.params_device, path.fixed_params_device),
                  eval_batch_arr["image"][path.device_id],
                  eval_batch_arr["label"][path.device_id],
                  path.model))
          if first_eval and e_step == 0 and time.time() - e_step_0_time > 1:
            print(f"WARNING: First eval step took: {time.time()-e_step_0_time:.2f} s")
      del eval_batch, eval_batch_arr

      # Get params of best models.
      qs = []
      eval_idx = (t_step+1) // task.num_train_batches_between_validations
      for path in paths:
        quality = np.mean(path.accs)
        del path.accs
        qs.append(f"{quality:.4f}")
        path.evals.append(quality)
        # Set quality in metrics for current score computation.
        path.metrics["quality"] = quality
        path_score = path.score()
        if path_score >= path.best_score:
          path.best_params_local = jax.device_get(path.params_device)
          path.best_opt_state_local = jax.device_get(path.opt_state_device[1][0].trace)
          path.best_score = path_score
          path.best_quality = quality
          qs[-1] += "*"
      train_time = time.time() - timing["end_compile"]
      avg_path_time = (train_time / eval_idx) / len(paths)
      print(("\t".join(qs) + f"\t< Eval {eval_idx}").expandtabs(8),
            f"tot:{train_time:.1f}s", f"avg/path:{avg_path_time:.1f}s")

      if first_eval:
        timing["end_eval"] = time.time()

  for path in paths:
    del path.params_device
    del path.fixed_params_device
    del path.opt_state_device
    del path.optimizer
    del path.optimizer_init_fn
  gc.collect()

  timing["end_train"] = time.time()

  loop_time = timing["start_time"] - LOOP_START
  compile_time = timing["end_compile"] - timing["start_time"]
  compile_eval_time = timing["end_compile_eval"] - timing["start_time"]
  compile_train_time = timing["end_compile"] - timing["end_compile_eval"]
  prep_time = timing["end_prep"] - timing["end_compile"]
  train_time = timing["end_train"] - timing["end_prep"]
  eval_time = timing["end_eval"] - timing["start_eval"]
  LOOP_START = time.time()

  for path in paths:
    path.metrics["loop_time"] = loop_time
    path.metrics["compile_time"] = compile_time
    path.metrics["prep_time"] = prep_time
    path.metrics["train_time"] = train_time
    path.metrics["eval_time"] = eval_time
    path.metrics["start_time"] = timing["start_time"]
    path.metrics["start_time_loop"] = timing["start_time_loop"]
    path.metrics["end_time"] = time.time()
    path.metrics["num_params"] = get_num_params(path.get_all_params())
    path.metrics["num_trainable_params"] = get_num_params(path.get_trainable_params())
    path.metrics["quality"] = max(path.evals)
    path.metrics["evals"] = json.dumps([float(v) for v in path.evals])

    if path.best_params_local:
      path.metrics["improved"] = True
      path.update_trainable(path.best_params_local,
                            path.best_opt_state_local)
      assert path.best_quality == path.metrics["quality"]
      assert path.best_score == path.score()
    else:
      path.metrics["improved"] = False
      # A sampled path that did not improve is dropped, so skip the parameter update.
      assert path.best_params_local is None
      assert path.best_opt_state_local is None
      assert path.best_quality is None

    del path.best_params_local
    del path.best_opt_state_local
    del path.best_score
    del path.best_quality
    del path.evals

    if VERBOSE:
      print("UPDATED:")
      print(prp(path))

  pqs = []
  qs = []
  psc = []
  sc = []
  for path in paths:
    if path.task_name == path.parent.task_name:
      pqs.append(f"{path.parent.metrics['quality']:.4f}")
      psc.append(f"{path.parent.score():.4f}")
    else:
      pqs.append("NEW")
      psc.append("NEW")
    qs.append(f"{path.metrics['quality']:.4f}")
    sc.append(f"{path.score():.4f}")
    if path.metrics["improved"]:
      sc[-1] += "+"

  print(("\t".join([f"{path.parent.id}" for path in paths]) +
        "\t< Parent id").expandtabs(8))
  print(("\t".join([f"{path.id}" for path in paths]) +
        "\t< Path id").expandtabs(8))
  print(("\t".join(pqs) + "\t< Parent best quality").expandtabs(8))
  print(("\t".join(qs) + "\t< Path best quality").expandtabs(8))
  print(("\t".join(psc) + "\t< Parent score").expandtabs(8))
  print(("\t".join(sc) + "\t< Path score").expandtabs(8))

  print("time\tINIT\tCOMPevl\tCOMPtrn\tPREP\tTRN+EVL\t1stEVAL".expandtabs(8))
  print(f"(s)\t{loop_time:.1f}\t{compile_eval_time:.1f}\t{compile_train_time:.1f}\t{prep_time:.1f}\t{train_time:.1f}\t{eval_time:.1f}".expandtabs(8))
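
During training, `train_loop` snapshots a path's parameters whenever the current score is `>=` the best so far. The running best starts from the parent's score when the parent solves the same task, and from `-inf` otherwise. The sketch below isolates that selection over a list of validation qualities. The name `select_best_eval` is hypothetical, and `score_fn` stands in for `path.score()`.

```python
import math


def select_best_eval(evals, initial_best, score_fn):
    """Return (index, score) of the eval whose parameters would be kept.

    Mirrors the `path_score >= path.best_score` check: ties go to the later
    evaluation, and (None, initial_best) means the path never improved.
    """
    best_idx, best = None, initial_best
    for i, quality in enumerate(evals):
        s = score_fn(quality)
        if s >= best:
            best_idx, best = i, s
    return best_idx, best


if __name__ == "__main__":
    # Same task: start from the parent's score. New task: start from -inf.
    print(select_best_eval([0.5, 0.7, 0.7, 0.6], 0.6, lambda q: q))  # (2, 0.7)
    print(select_best_eval([0.1, 0.3], -math.inf, lambda q: q))  # (1, 0.3)
```

Because the comparison is `>=`, a child that only matches its parent's score still counts as improved and is added to the population.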

In [None]:
# Run a full path-sampling cycle for a task.
def agent_cycle(
    task, devices, pop:Population, generation_id:int, cycle_id:int,
    config:FrozenConfigDict):
  num_devices = len(devices)
  # Track best path.
  best_path = pop.get_best_path()
  num_gen_batches = math.ceil(config.num_samples_per_cycle/num_devices)
  for _ in range(num_gen_batches):
    if generation_id >= num_gen_batches:
      break
    print("----")
    print(f"GENERATION: [{generation_id+1}/{num_gen_batches}]")
    ds_hparams = pop.sample_ds_hparams()
    ds_hparams["num_classes"] = task.num_classes
    ds_train = task.get_ds("train", ds_hparams)
    ds_validation = task.get_ds("validation", ds_hparams)
    paths = []
    for i in range(num_devices):
      paths.append(pop.sample_path(ds_hparams))
    train_loop(paths, ds_train, ds_validation, devices, config)
    for path in paths:
      if path.metrics["improved"]:
        assert path not in pop.paths[config.agent_id]
        pop.paths[config.agent_id].append(path)
    pop.prune_population()
    # Track best path.
    curr_best_path = pop.get_best_path()
    if curr_best_path != best_path:
      if best_path:
        assert curr_best_path.score() >= best_path.score()
      best_path = curr_best_path
      best_path.metrics["new_best"] = True
      print(f"Best id:{best_path.id}",
            f"score:{best_path.score():.4f}",
            f"quality:{best_path.metrics['quality']:.4f}",
            f"gen:{generation_id}",
            f"\n{best_path.hparams}")
    generation_id += 1
    if generation_id < num_gen_batches:
      save_state(pop, cycle_id, generation_id, config)
  assert best_path in pop.paths[config.agent_id]
  run_test_eval(best_path)
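
`agent_cycle` trains one path per device in each generation batch, runs `ceil(num_samples_per_cycle / num_devices)` batches, and can resume from a nonzero `generation_id` restored from an intermediate state. After every batch except the last it requests an intermediate save. The sketch below enumerates that schedule. The name `plan_cycle` is hypothetical.

```python
import math


def plan_cycle(num_samples_per_cycle, num_devices, generation_id=0):
    """List (generation_id after the batch, requests_intermediate_save) per batch.

    Mirrors the loop above: one batch trains num_devices paths in parallel,
    and an intermediate save is requested after every batch except the last.
    """
    num_gen_batches = math.ceil(num_samples_per_cycle / num_devices)
    plan = []
    while generation_id < num_gen_batches:
        generation_id += 1
        plan.append((generation_id, generation_id < num_gen_batches))
    return num_gen_batches, plan
```

Since the batch count is rounded up, 10 samples per cycle on 8 devices trains 16 paths in 2 batches.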

In [None]:
def has_test_quality(path):
  return ("test_quality" in path.metrics
          and not math.isnan(path.metrics["test_quality"]))

# Run final eval on test set.
def run_test_eval(path, test_immutability=False):
  # Skip if test_quality already computed and no immutability test required.
  if not test_immutability and has_test_quality(path):
    return
  eval_st = time.time()
  ds_test = path.task.get_ds("test", path.hparams)
  # Running on the same device should allow reusing the function compiled for
  # validation, if the batch size matches.
  params = path.get_all_params()
  if not hasattr(path, "device"):
    path.device = random.choice(jax.local_devices())
  params_device = jax.device_put(params_comps_to_model(params), path.device)
  acc_sum = []
  tot_num_samples = 0
  # Warning: if repeat() is called on this dataset, then this loop never ends.
  for batch in ds_test:
    acc_avg = eval_step(
        params_device,
        batch["image"],
        batch["label"],
        path.model)
    batch_size = batch["image"].shape[0]
    # Weight each batch accuracy by the batch size: the last batch can be
    # smaller, since the test set is evaluated exactly.
    acc_sum.append(acc_avg * batch_size)
    tot_num_samples += batch_size
  del params_device
  acc_avg = np.sum(acc_sum) / tot_num_samples
  # If a test quality was already recorded, compare it with the new value
  # (immutability test).
  if has_test_quality(path):
    print(f"Testing immutability of path {path.id} : {acc_avg} ~= {path.metrics['test_quality']}")
    assert test_immutability
    if not np.isclose(path.metrics["test_quality"], acc_avg):
      print("WARNING IMMUTABILITY TEST FAILED, delta:",
            path.metrics["test_quality"]-acc_avg)
  path.metrics["test_quality"] = acc_avg
  print(f"TEST EVAL TIME: {time.time() - eval_st:.2f} s")
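
`run_test_eval` weights each batch's mean accuracy by its size because the last test batch may be smaller. A plain average of per-batch means would over-weight the examples in that last batch. The sketch below shows the arithmetic. The name `weighted_accuracy` is hypothetical.

```python
def weighted_accuracy(batch_accs, batch_sizes):
    """Combine per-batch mean accuracies into the exact dataset accuracy."""
    total = sum(batch_sizes)
    return sum(acc * n for acc, n in zip(batch_accs, batch_sizes)) / total


if __name__ == "__main__":
    # Two full batches of 4 examples and a last batch of 2 examples.
    print(weighted_accuracy([1.0, 0.5, 0.0], [4, 4, 2]))  # 0.6
```

The unweighted mean of the same three batch accuracies would be 0.5, which understates the exact accuracy of 6 correct out of 10.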

# Main

In [None]:
# Main loop over tasks.

print("Devices type:", jax.local_devices()[0].device_kind)
print("Devices count:", len(jax.local_devices()))

agent = globals()[AGENT](
    task_name=TASK_NAME,
    experiment_dir=os.path.join(
        EXPERIMENTS_ROOT_DIR,
        EXPERIMENT_NAME),
    auto_tune=AUTO_TUNE,
    scale_factor=SCALE_FACTOR,
)
agent.run()