Imports

In [None]:
from nb201 import NB201Benchmark
import numpy as np
from warmstart.utils_templates import FullTemplate
import ConfigSpace as CS
from ConfigSpace import Configuration
import ollama
import torchvision
from exp_baselines.bayesmark.data import ProblemType
import ast
from llambo.llambo import LLAMBO
from utils import convert_LLAMBO_df_to_synetune_dict
from utils import convert_synetune_dict_to_LLAMBO_compatible_format

from syne_tune_local.experiments.benchmark_definitions.nas201 import nas201_benchmark
from syne_tune_local.blackbox_repository import BlackboxRepositoryBackend
from syne_tune_local.backend.simulator_backend.simulator_callback import SimulatorCallback
from syne_tune_local import Tuner, StoppingCriterion

from typing import Optional, Dict, Any, List, Union
import logging
from syne_tune_local.optimizer.schedulers import FIFOScheduler
from syne_tune_local.optimizer.schedulers.searchers import StochasticAndFilterDuplicatesSearcher

Load NB201 Benchmark

In [None]:
b = NB201Benchmark(path="./nb201.pkl", dataset='cifar10')
cs = b.get_configuration_space()
# Named `sampled_config` rather than `config`: a later cell binds the global
# `config` to the LLM prompt context read by `generate_init_conf`, so reusing
# that name here made warm-starting depend on which cell happened to run last.
sampled_config = cs.sample_configuration()  # samples a configuration uniformly at random

print(cs)
print("Numpy representation: ", sampled_config.get_array())
print("Dict representation: ", sampled_config.get_dictionary())

# Round-trip: rebuild an equivalent Configuration from its dict form.
new_config = Configuration(cs, values=sampled_config.get_dictionary())
print(new_config)

# objective_function returns a (value, cost) pair for the given configuration.
y, cost = b.objective_function(sampled_config)
print("Test error: %f %%" % y)
print("Runtime %f s" % cost)

Arguments for LLAMBO

In [None]:
# Task description passed to LLAMBO and to the warm-start prompt template.
task_context = {
    'model': 'CNN',
    'task': 'classification',
    'tot_feats': 32 * 32 * 3,
    'cat_feats': 0,
    'num_feat': 32 * 32 * 3,
    'n_classes': 10,
    'metric': 'loss',
    'lower_is_better': True,
    'num_samples': 50000,
    # One entry per `op_<i>_to_<j>` edge of the NB201 cell. Each value is
    # [type, transform, domain]; for these categorical hyperparameters the domain
    # is the list of allowed operations, not a [min_value, max_value] range.
    # The list literal is evaluated per key, so every edge gets its own list.
    'hyperparameter_constraints': {
        edge: ['categorical', None, ["none", "skip_connect", "avg_pool_3x3", "nor_conv_1x1", "nor_conv_3x3"]]
        for edge in ('op_0_to_1', 'op_0_to_2', 'op_0_to_3', 'op_1_to_2', 'op_1_to_3', 'op_2_to_3')
    }
}


def init_f():
    """Placeholder warm-start hook that supplies no initial configurations.

    NOTE(review): not referenced by the visible cells; LLAMBO is given
    ``generate_init_conf`` as its ``init_f`` instead.
    """
    return None


def eval_point(config):
    """Evaluate one NB201 configuration on the tabular benchmark ``b``.

    :param config: mapping of hyperparameter names to values, valid in
        ``b.get_configuration_space()``.
    :return: ``(config, {"score": ..., "train_time": ...})``. The two values are
        the first and second elements returned by ``b.objective_function``.
    """
    space = b.get_configuration_space()
    outcome = b.objective_function(Configuration(space, values=config))
    return config, {"score": outcome[0], "train_time": outcome[1]}

Ollama smoke test and CIFAR-10 dataset statistics

In [None]:
# Optional smoke test for the local Ollama server: the warm-start cell calls
# ollama.chat(model="llama3"). Uncomment to pull the model, send a test prompt
# and list the locally available models.
# chat_engine = "llama3"
# model = ollama.pull(chat_engine)
# response = ollama.chat(model="llama3", messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
# print(response)
# ollama.list()

In [None]:
# CIFAR-10 training split, downloaded to ./data if not already present. Used
# below to add dataset statistics to `task_context`.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
)


def fetch_statistics(dict, dataset):
    """Add image/label summary statistics of ``dataset`` to a task-context mapping.

    The mapping is updated in place and also returned, so both
    ``fetch_statistics(ctx, ds)`` and ``ctx = fetch_statistics(ctx, ds)`` work.

    :param dict: task-context mapping to enrich. The name shadows the builtin
        ``dict`` but is kept so existing keyword callers keep working.
    :param dataset: object exposing ``.data`` (image array; pixel values assumed
        to lie in 0..255, TODO confirm for non-uint8 datasets) and ``.targets``
        (non-negative integer class labels).
    :return: the same mapping with ``pixel_mean`` and ``pixel_std`` (over all
        pixels scaled by 1/255) and ``class_distribution`` (fraction of labels per
        class index, as a list) set.
    """
    # asarray avoids copying when the data is already an ndarray; the values are
    # only read, never modified.
    images_np = np.asarray(dataset.data)
    labels_np = np.asarray(dataset.targets)

    # Scale once and reuse. For the CIFAR-10 training set this temporary holds
    # ~154M float64 values, and it was previously rebuilt for the std.
    scaled = images_np / 255.
    pixel_mean = np.mean(scaled)
    pixel_std = np.std(scaled)
    del scaled

    class_counts = np.bincount(labels_np)
    class_distribution = class_counts / len(labels_np)

    dict['pixel_mean'] = pixel_mean
    dict['pixel_std'] = pixel_std
    dict['class_distribution'] = class_distribution.tolist()

    return dict


# Enrich the task description with CIFAR-10 pixel and label statistics
# (fetch_statistics updates the dict in place and returns it).
task_context = fetch_statistics(task_context, dataset=trainset)

Warm start: LLM-generated initial configurations

In [None]:
# Warm-start settings.
# `config` is the prompt context that generate_init_conf passes to FullTemplate.
config: str = "No_Context"
# NOTE(review): `metric`, `NUM_SEEDS` and `problem_type` are not referenced by
# the visible cells.
metric: str = "acc"
NUM_SEEDS: int = 10
problem_type = ProblemType.clf


def extract_configs_from_response(response):
    """Parse the list of configurations embedded in an ``ollama.chat`` reply.

    The model usually wraps the list in prose, so the span from the first ``[``
    to the last ``]`` of the message content is parsed with ``ast.literal_eval``.
    That function only accepts Python literals, so model output is never executed.

    :param response: chat response with the text at ``response['message']['content']``.
    :return: the parsed list.
    :raises ValueError: if the content has no ``[...]`` span, the span is not a
        valid literal, or it does not evaluate to a list.
    """
    content = response['message']['content']
    start = content.find("[")
    end = content.rfind("]")
    # Without this check a missing bracket produced content[-1:0] == "" and an
    # opaque SyntaxError from literal_eval.
    if start == -1 or end < start:
        raise ValueError(f"No bracketed list found in model response: {content!r}")
    list_str = content[start:end + 1]
    try:
        configurations = ast.literal_eval(list_str)
    except (SyntaxError, ValueError, TypeError) as exc:
        raise ValueError(f"Could not parse configurations from model response: {list_str!r}") from exc
    # e.g. "[a], [b]" parses to a tuple of lists, not a list.
    if not isinstance(configurations, list):
        raise ValueError(f"Expected a list of configurations, got {type(configurations).__name__}")
    return configurations


def is_dict_valid_in_config_space(d, config_space):
    """Return True if ``d`` can be turned into a ``CS.Configuration`` of ``config_space``.

    Validation is delegated to the ConfigSpace ``Configuration`` constructor. Any
    ordinary exception it raises for ``d`` is reported as "invalid".

    :param d: candidate hyperparameter dict (e.g. one config proposed by the LLM).
    :param config_space: configuration space to validate against.
    :return: True if the Configuration could be constructed, False otherwise.
    """
    try:
        CS.Configuration(config_space, values=d)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate instead of being reported as an invalid configuration.
        return False
    return True


def check_all_list(parsed_dicts, config_space):
    """Return True iff every dict in ``parsed_dicts`` is valid in ``config_space``.

    Stops at the first invalid dict. An empty iterable counts as valid.
    """
    return all(
        is_dict_valid_in_config_space(candidate, config_space)
        for candidate in parsed_dicts
    )


def obtain_all_list_valid(resp, config_space):
    """Return ``resp`` unchanged if all of its configs are valid in ``config_space``.

    Otherwise prints "fail" and returns None.
    NOTE(review): the None is handed on to the caller (the warm-start hook)
    rather than raising; confirm LLAMBO copes with it.
    """
    if not check_all_list(resp, config_space):
        print("fail")
        return None
    return resp


def generate_init_conf(n_samples, context=None, config_space=None, task_dict=None, chat_engine="llama3"):
    """Ask the local LLM (via Ollama) for ``n_samples`` initial configurations.

    Used as LLAMBO's ``init_f`` warm-start hook. All arguments after
    ``n_samples`` are optional. When omitted they fall back, at call time, to the
    notebook globals previously hard-wired here, so existing calls behave as before.

    :param n_samples: number of configurations to request.
    :param context: prompt context for ``FullTemplate``; defaults to the global
        ``config``. NOTE: the name ``config`` is bound in more than one cell, so
        pass this explicitly to avoid depending on cell execution order.
    :param config_space: space used to build the prompt and to validate the
        answer; defaults to the global ``cs``.
    :param task_dict: task description for the prompt; defaults to the global
        ``task_context``.
    :param chat_engine: Ollama model name.
    :return: the parsed configurations if all are valid in ``config_space``,
        otherwise whatever ``obtain_all_list_valid`` returns for an invalid list.
    :raises ValueError: propagated from ``extract_configs_from_response`` when
        the reply cannot be parsed (depending on that helper's implementation).
    """
    # Resolve defaults lazily so later changes to the globals are still honoured.
    if context is None:
        context = config
    if config_space is None:
        config_space = cs
    if task_dict is None:
        task_dict = task_context

    template_object = FullTemplate(context=context, provide_ranges=True)
    input_prompt = template_object.add_context(
        config_space=config_space, num_recommendation=n_samples, task_dict=task_dict
    )
    response = ollama.chat(model=chat_engine, messages=[{'role': 'user', 'content': input_prompt}])
    configs = extract_configs_from_response(response)
    return obtain_all_list_valid(configs, config_space)

#print(generate_init_conf(3))

LLAMBO optimizer

In [None]:
# LLAMBO optimizer. The searcher classes below reference this global by name.
# Initial configurations come from the LLM warm-start hook; evaluations come
# from the tabular NB201 benchmark via `eval_point`.
llambo = LLAMBO(
    task_context,
    sm_mode='discriminative',
    n_candidates=10,
    n_templates=2,
    n_gens=10,
    alpha=0.1,
    n_initial_samples=5,
    n_trials=4,
    init_f=generate_init_conf,
    bbox_eval_f=eval_point,
    chat_engine="llama3",
)
# NOTE(review): assigned after construction; confirm LLAMBO reads `seed` lazily,
# otherwise this has no effect on state created in __init__.
llambo.seed = 0

# To run LLAMBO's own optimisation loop directly instead of through Syne Tune:
# configs, fvals = llambo.optimize(test_metric="score")

Syne Tune searchers backed by LLAMBO

In [None]:

# Module logger and retry cap.
# NOTE(review): neither is referenced by the visible cells.
logger = logging.getLogger(name=__name__)
MAX_RETRIES: int = 100


class LlamboSearcher(StochasticAndFilterDuplicatesSearcher):
    """Syne Tune searcher that delegates suggestions to the global ``llambo`` optimizer.

    Initial configurations come from the base class (``_next_initial_config``,
    presumably the remaining ``points_to_evaluate``; TODO confirm). After that,
    every ``(config, result)`` pair recorded by ``_update`` is passed to
    ``llambo`` before it is asked for the next configuration.

    Attributes:
        X: observed configs, converted with
            ``convert_synetune_dict_to_LLAMBO_compatible_format``.
        y: the raw ``result`` dicts reported for those configs, in the same order.
    """

    def __init__(
            self,
            config_space: Dict[str, Any],
            metric: Union[List[str], str],
            points_to_evaluate: Optional[List[dict]] = None,
            **kwargs,
    ):
        super().__init__(
            config_space,
            metric=metric,
            points_to_evaluate=points_to_evaluate,
            **kwargs,
        )
        # Observation history, appended in lock-step by `_update`.
        self.X = []
        self.y = []

    def configure_scheduler(self, scheduler):
        """Require a ``TrialSchedulerWithSearcher`` and then register with it."""
        # Imported locally (as before), presumably to avoid an import cycle. TODO confirm.
        from syne_tune_local.optimizer.schedulers.scheduler_searcher import (
            TrialSchedulerWithSearcher,
        )

        assert isinstance(scheduler, TrialSchedulerWithSearcher), (
            "This searcher requires TrialSchedulerWithSearcher scheduler"
        )
        super().configure_scheduler(scheduler)

    def _train_model(self, train_data: np.ndarray, train_targets: np.ndarray) -> bool:
        """Hand the observations to the global ``llambo`` optimizer.

        :param train_data: Training input feature matrix X
        :param train_targets: Training targets y
        :return: Was training successful? (always True here)
        """
        # NOTE(review): the full history is passed on every call. Confirm that
        # `_update_observations` replaces rather than appends, otherwise earlier
        # observations are duplicated.
        llambo._update_observations(train_data, train_targets)
        return True

    def get_state(self) -> Dict[str, Any]:
        """Return a shallow copy of the base-class state (``X``/``y`` are not included)."""
        base_state = super().get_state()
        return dict(base_state)

    def _restore_from_state(self, state: Dict[str, Any]):
        """Restore only the base-class state; ``X``/``y`` are left untouched."""
        super()._restore_from_state(state)

    def get_config(self, **kwargs) -> Optional[Dict[str, Any]]:
        """Suggest the next configuration.

        The next initial config from the base class is used first. Once at least
        one result has been observed and ``_train_model`` succeeds, the result of
        ``llambo.get_config()`` converted to a Syne Tune dict is used. Otherwise
        None is returned.
        """
        initial = self._next_initial_config()
        if initial is not None:
            return initial
        # NOTE(review): None is returned both before any result arrives and when
        # `_train_model` declines; confirm how the scheduler treats a None suggestion.
        if not self.y:
            return None
        if not self._train_model(np.array(self.X), np.array(self.y)):
            return None
        return convert_LLAMBO_df_to_synetune_dict(llambo.get_config())

    def _update(self, trial_id: str, config: Dict[str, Any], result: Dict[str, Any]):
        """Record an observed ``(config, result)`` pair for later training."""
        converted = convert_synetune_dict_to_LLAMBO_compatible_format(config)
        self.X.append(converted)
        # NOTE(review): the whole result dict is stored, not the metric value.
        # Confirm this is what `llambo._update_observations` expects.
        self.y.append(result)

    def clone_from_state(self, state: Dict[str, Any]):
        """Cloning is not supported by this searcher."""
        raise NotImplementedError


In [None]:

# Re-declares the logger and retry cap from the previous cell, so this cell
# also runs on its own. getLogger returns the same logger object for a name.
logger = logging.getLogger(name=__name__)
MAX_RETRIES: int = 100


class MultiFidelityLLamboSearcher(LlamboSearcher):
    """Multi-fidelity variant of ``LlamboSearcher``.

    Also records the resource level of every observation. When asked for a
    suggestion, it trains LLAMBO only on the observations from the highest
    resource level that has at least ``min_data_points`` (default 4) of them.
    """

    def __init__(
            self,
            config_space: Dict[str, Any],
            metric: Union[List[str], str],
            points_to_evaluate: Optional[List[dict]] = None,
            resource_attr: Optional[str] = None,
            **kwargs,
    ):
        super().__init__(
            config_space,
            metric=metric,
            points_to_evaluate=points_to_evaluate,
            **kwargs,
        )
        self.resource_attr = resource_attr
        # Resource level of each observation, aligned index-by-index with self.X / self.y.
        self.resource_levels = []

    def configure_scheduler(self, scheduler):
        """Require a multi-fidelity scheduler and take its ``resource_attr``."""
        from syne_tune_local.optimizer.schedulers.multi_fidelity import (
            MultiFidelitySchedulerMixin,
        )

        super().configure_scheduler(scheduler)
        assert isinstance(
            scheduler, MultiFidelitySchedulerMixin
        ), "This searcher requires MultiFidelitySchedulerMixin scheduler"
        self.resource_attr = scheduler.resource_attr

    def _train_model(self, train_data: np.ndarray, train_targets: np.ndarray) -> bool:
        """Train on the rows observed at the highest sufficiently-populated resource level.

        :param train_data: all observed inputs, row-aligned with ``self.resource_levels``.
        :param train_targets: all observed targets, aligned the same way.
        :return: False if no level has enough observations yet, otherwise the
            result of ``LlamboSearcher._train_model`` on the selected rows.
        """
        highest_resource_level = self._highest_resource_model_can_fit()
        if highest_resource_level is None:
            return False
        # `self.resource_levels` is a plain list. Comparing it directly to a scalar
        # yields a single False, which selected no rows at all, so convert it to
        # an array to get an element-wise mask.
        mask = np.asarray(self.resource_levels) == highest_resource_level
        return super()._train_model(train_data[mask], train_targets[mask])

    def _highest_resource_model_can_fit(self, min_data_points: int = 4) -> Optional[int]:
        """Return the highest resource level with at least ``min_data_points`` observations.

        :param min_data_points: minimum number of observations a level needs.
        :return: that level, or None if no level qualifies (including before any
            observation has been recorded).
        """
        unique_resource_levels, counts = np.unique(
            self.resource_levels, return_counts=True
        )
        idx = np.flatnonzero(counts >= min_data_points)

        if idx.size == 0:
            return None

        # np.unique returns levels in ascending order, so the last qualifying one is the highest.
        return unique_resource_levels[idx[-1]]

    def get_state(self) -> Dict[str, Any]:
        """Return a shallow copy of the parent state (``resource_levels`` is not included)."""
        return dict(
            super().get_state(),
        )

    def _restore_from_state(self, state: Dict[str, Any]):
        """Restore only the parent state; ``resource_levels`` is left untouched."""
        super()._restore_from_state(state)

    def _update(self, trial_id: str, config: Dict, result: Dict):
        """Record the observation together with its resource level.

        The level is read first so that a result missing ``self.resource_attr``
        raises before ``self.X``/``self.y`` are extended. This keeps them aligned
        with ``self.resource_levels``.
        """
        resource_level = int(result[self.resource_attr])
        super()._update(trial_id=trial_id, config=config, result=result)
        self.resource_levels.append(resource_level)


Synchronous Hyperband on NAS-Bench-201 with the multi-fidelity LLAMBO searcher

In [None]:
from syne_tune_local.optimizer.schedulers.synchronous import SynchronousGeometricHyperbandScheduler

logging.getLogger().setLevel(logging.WARNING)

# --- Experiment settings ---------------------------------------------------
random_seed = 1            # seed passed to the scheduler
nb201_random_seed = 0      # seed passed to the NB201 blackbox backend
n_workers = 1
dataset_name = "cifar10"
n_initial_configs = 5      # LLM warm-start configs evaluated first
max_num_trials_started = 5
print_update_interval = 700
results_update_interval = 300

# --- Simulated NAS-Bench-201 backend ---------------------------------------
benchmark = nas201_benchmark(dataset_name)
max_resource_attr = benchmark.max_resource_attr
trial_backend = BlackboxRepositoryBackend(
    blackbox_name=benchmark.blackbox_name,
    elapsed_time_attr=benchmark.elapsed_time_attr,
    max_resource_attr=max_resource_attr,
    dataset=dataset_name,
    seed=nb201_random_seed,
)
blackbox = trial_backend.blackbox
nas_configuration_space = blackbox.configuration_space_with_max_resource_attr(max_resource_attr)

# --- Scheduler seeded with LLAMBO's warm-start configurations ---------------
points_to_evaluate = convert_LLAMBO_df_to_synetune_dict(
    llambo.initialize_configs(n_initial_configs)
)
# NOTE(review): `searcher` receives the class itself rather than an instance or
# a registered name; confirm that syne_tune_local instantiates it.
scheduler = SynchronousGeometricHyperbandScheduler(
    config_space=nas_configuration_space,
    max_resource_attr=max_resource_attr,
    mode=benchmark.mode,
    metric=benchmark.metric,
    random_seed=random_seed,
    searcher=MultiFidelityLLamboSearcher,
    resource_attr=blackbox.fidelity_name(),
    points_to_evaluate=points_to_evaluate,
)

# --- Run ---------------------------------------------------------------------
stop_criterion = StoppingCriterion(max_num_trials_started=max_num_trials_started)
tuner = Tuner(
    trial_backend=trial_backend,
    scheduler=scheduler,
    stop_criterion=stop_criterion,
    n_workers=n_workers,
    sleep_time=0,
    results_update_interval=results_update_interval,
    print_update_interval=print_update_interval,
    callbacks=[SimulatorCallback()],
)

tuner.run()