Imports

In [None]:
from nb201 import NB201Benchmark
import numpy as np
from warmstart.utils_templates import FullTemplate
import ConfigSpace as CS
from ConfigSpace import Configuration
from transformers import AutoTokenizer, AutoModelForCausalLM
import ollama
import torchvision
from exp_baselines.bayesmark.data import ProblemType
import ast
from llambo.llambo import LLAMBO


Load NB201 Benchmark

In [None]:
# Load the NB201 benchmark for the 'cifar10' dataset from ./nb201.pkl and fetch
# its configuration space. `b` and `cs` are reused as globals by later cells
# (eval_point and generate_init_conf).
b = NB201Benchmark(path="./nb201.pkl", dataset='cifar10')
cs = b.get_configuration_space()
# NOTE: the Warmstart cell later rebinds `config` to the string "No_Context".
config = cs.sample_configuration()  # samples a configuration uniformly at random
print(cs)
print("Numpy representation: ", config.get_array())
print("Dict representation: ", config.get_dictionary())

# Rebuild a Configuration from the dict representation of the sampled one.
new_config = Configuration(cs, values=config.get_dictionary())
print(new_config)

# Query the benchmark for the sampled configuration; the result unpacks into a
# (value, cost) pair, reported below as test error and runtime.
y, cost = b.objective_function(config)
print("Test error: %f %%" % y)
print("Runtime %f s" % cost)

Arguments for LLAMBO

In [3]:
# Task description handed to LLAMBO. Every hyperparameter shares the same
# constraint triple: [type, transform, [min_value, max_value]].
# NOTE(review): the LLAMBO log captured at the end of this notebook lists these
# same keys as 'categorical' with operation names, and eval_point passes the
# values straight into a Configuration -- confirm that the integer encoding
# used here matches the benchmark's configuration space.
task_context = {
    'model': 'CNN',
    'task': 'classification',
    'tot_feats': 3 * 32 * 32,
    'cat_feats': 0,
    'num_feat': 3 * 32 * 32,
    'n_classes': 10,
    'metric': 'loss',
    'lower_is_better': True,
    'num_samples': 50000,
    'hyperparameter_constraints': {
        # A fresh list per key, so entries never alias each other.
        edge: ['int', 'linear', [1, 5]]
        for edge in (
            'op_0_to_1',
            'op_0_to_2',
            'op_0_to_3',
            'op_1_to_2',
            'op_1_to_3',
            'op_2_to_3',
        )
    },
}


def init_f():
    """Placeholder initialiser: takes no arguments and always returns ``None``."""
    return None


def eval_point(config):
    """Evaluate one hyperparameter assignment on the benchmark ``b``.

    Parameters
    ----------
    config
        Values used to build a ``Configuration`` over the benchmark's
        configuration space.

    Returns
    -------
    tuple
        ``(config, result)`` where ``config`` is the argument passed in and
        ``result`` is a dict holding the benchmark's first return value under
        ``"score"`` and its second under ``"train_time"``.
    """
    score, train_time = b.objective_function(
        Configuration(b.get_configuration_space(), values=config)
    )
    return config, {"score": score, "train_time": train_time}

Hugging Face

In [None]:
# Load the tokenizer and causal-LM weights for the "google/gemma-2b-it"
# checkpoint and move the model to the "cuda" device (hard-coded, so this cell
# needs a CUDA-capable GPU).
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it").to("cuda")

In [None]:
# Smoke test for the Hugging Face model: tokenize a prompt, generate from it on
# the "cuda" device and print the decoded first sequence.
input_text = "Write me a poem about Machine Learning."
# Despite its name, `input_ids` holds the whole tokenizer output; it is
# unpacked as keyword arguments to generate() below.
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_length=512)
print(tokenizer.decode(outputs[0]))

Ollama

In [None]:
# Smoke test for the Ollama backend: pull the model named by `chat_engine`,
# send it one chat message, print the raw response and list the local models.
chat_engine = "llama3"
# The pull result is intentionally not bound to `model`: that name holds the
# Hugging Face model used by the generation cell above.
ollama.pull(chat_engine)
# Use `chat_engine` rather than a second hard-coded model name so the pulled
# and queried models cannot drift apart.
response = ollama.chat(model=chat_engine, messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
print(response)
ollama.list()

In [4]:
# Load the CIFAR-10 training split (train=True) through torchvision, using
# ./data as the dataset root with downloading enabled.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True)

Files already downloaded and verified


In [5]:
def fetch_statistics(dict, dataset):
    """Add global pixel statistics and class frequencies of ``dataset`` to ``dict``.

    Parameters
    ----------
    dict
        Mapping that receives the statistics; it is modified in place. The
        parameter name shadows the builtin and is kept only for compatibility
        with existing callers.
    dataset
        Object exposing ``data`` (pixel values) and ``targets`` (integer class
        labels).

    Returns
    -------
    The same ``dict`` object, now also containing ``'pixel_mean'`` and
    ``'pixel_std'`` (mean / standard deviation over all values of
    ``data / 255``) and ``'class_distribution'`` (a list with the fraction of
    samples per class label).
    """
    pixels = np.array(dataset.data)
    targets = np.array(dataset.targets)

    # Rescale once and reuse the result for both statistics.
    scaled = pixels / 255.
    mean_value = np.mean(scaled)
    std_value = np.std(scaled)

    per_class = np.bincount(targets)
    fractions = per_class / len(targets)

    # Write results only after every statistic has been computed.
    for key, value in (
        ('pixel_mean', mean_value),
        ('pixel_std', std_value),
        ('class_distribution', fractions.tolist()),
    ):
        dict[key] = value

    return dict


# Extend the task description with pixel and class statistics of the training
# set; fetch_statistics updates task_context in place and returns it.
task_context = fetch_statistics(task_context, trainset)

Warmstart

In [6]:
# Warm-start settings.
# `config` is read by generate_init_conf as the `context` argument of
# FullTemplate. NOTE: this rebinds the name that the benchmark-loading cell
# used for a sampled Configuration.
config = "No_Context"
# NOTE(review): `metric`, `NUM_SEEDS` and `problem_type` are not referenced by
# any other cell in this notebook, and task_context['metric'] is 'loss' --
# confirm whether they are still needed.
metric = "acc"
NUM_SEEDS = 10
problem_type = ProblemType.clf


def extract_configs_from_response(response):
    """Parse the list literal embedded in a chat response.

    The text between the first ``[`` and the last ``]`` of
    ``response['message']['content']`` is evaluated with
    ``ast.literal_eval``, so only Python literals are accepted.

    Parameters
    ----------
    response
        Mapping with the reply text under ``response['message']['content']``.

    Returns
    -------
    The evaluated literal (the configurations suggested in the reply).

    Raises
    ------
    ValueError
        If the reply contains no ``[ ... ]`` span at all. Previously this
        case surfaced as an opaque ``SyntaxError`` from evaluating an empty
        string.
    ValueError, SyntaxError
        Propagated from ``ast.literal_eval`` when the bracketed text is not a
        valid Python literal.
    """
    content = response['message']['content']
    start = content.find("[")
    end = content.rfind("]")
    # find/rfind return -1 when the bracket is missing; also reject a closing
    # bracket that precedes the opening one.
    if start == -1 or end < start:
        raise ValueError(
            "no bracketed list found in model response: %r" % content[:200]
        )
    return ast.literal_eval(content[start:end + 1])


def is_dict_valid_in_config_space(d, config_space):
    """Report whether ``d`` can be used to build a Configuration of ``config_space``.

    Parameters
    ----------
    d
        Dictionary of hyperparameter values to check.
    config_space
        Configuration space to check ``d`` against.

    Returns
    -------
    bool
        ``True`` if ``CS.Configuration(config_space, values=d)`` succeeds,
        ``False`` if it raises any ordinary exception.
    """
    try:
        # Only success or failure of the construction matters; the resulting
        # object is discarded.
        CS.Configuration(config_space, values=d)
    except Exception:
        # Narrowed from a bare ``except`` so that KeyboardInterrupt and
        # SystemExit still propagate instead of being reported as "invalid".
        return False
    return True


def check_all_list(parsed_dicts, config_space):
    """Return ``True`` only if every dict in ``parsed_dicts`` is valid for ``config_space``.

    Checking stops at the first invalid entry; an empty ``parsed_dicts``
    yields ``True``.
    """
    return all(
        is_dict_valid_in_config_space(candidate, config_space)
        for candidate in parsed_dicts
    )


def obtain_all_list_valid(resp, config_space):
    """Return ``resp`` unchanged when all of its entries are valid for ``config_space``.

    If any entry is invalid, prints ``fail`` and returns ``None``, so callers
    must be prepared for a ``None`` result.
    """
    if not check_all_list(resp, config_space):
        print("fail")
        return None
    return resp


def generate_init_conf(n_samples):
    """Ask the ``llama3`` Ollama model for ``n_samples`` initial configurations.

    Builds a prompt with ``FullTemplate`` from the notebook globals ``config``
    (template context), ``cs`` (configuration space) and ``task_context``,
    sends it as a single user message, and parses the list found in the reply.

    Returns
    -------
    The parsed configurations if all of them are valid for ``cs``; otherwise
    ``None`` (after ``fail`` has been printed by ``obtain_all_list_valid``).
    """
    prompt = FullTemplate(context=config, provide_ranges=True).add_context(
        config_space=cs,
        num_recommendation=n_samples,
        task_dict=task_context,
    )
    user_message = {'role': 'user', 'content': prompt}
    reply = ollama.chat(model="llama3", messages=[user_message])
    return obtain_all_list_valid(extract_configs_from_response(reply), cs)

#print(generate_init_conf(3))

LLAMBO

In [7]:
# Configure LLAMBO with the task description, the LLM-based warm start
# (generate_init_conf) as initialiser and eval_point as the black-box
# evaluation function; "llama3" is passed as the chat engine.
# NOTE(review): the output captured below this cell ends with
# "ValueError: Unknown format code 'f' for object of type 'str'" and prints a
# categorical search space that differs from the current task_context --
# presumably a string-valued hyperparameter is being formatted as a float;
# re-run and confirm.
llambo = LLAMBO(task_context, sm_mode='discriminative', n_candidates=10, n_templates=2, n_gens=10,
                alpha=0.1, n_initial_samples=5, n_trials=25,
                init_f=generate_init_conf,
                bbox_eval_f=eval_point,
                chat_engine="llama3")
# The seed is assigned as an attribute after construction.
llambo.seed = 0

# run optimization
# "score" is the key under which eval_point stores the benchmark value.
configs, fvals = llambo.optimize(test_metric="score")

[Search settings]: 
	n_candidates: 10, n_templates: 2, n_gens: 10, 
	alpha: 0.1, n_initial_samples: 5, n_trials: 25, 
	using warping: False, ablation: None, shuffle_features: False
[Task]: 
	task type: classification, sm: discriminative, lower is better: True
Hyperparameter search space: 
{'op_0_to_1': ['categorical',
               None,
               ['none',
                'skip_connect',
                'avg_pool_3x3',
                'nor_conv_1x1',
                'nor_conv_3x3']],
 'op_0_to_2': ['categorical',
               None,
               ['none',
                'skip_connect',
                'avg_pool_3x3',
                'nor_conv_1x1',
                'nor_conv_3x3']],
 'op_0_to_3': ['categorical',
               None,
               ['none',
                'skip_connect',
                'avg_pool_3x3',
                'nor_conv_1x1',
                'nor_conv_3x3']],
 'op_1_to_2': ['categorical',
               None,
               ['none',
                'ski

ValueError: Unknown format code 'f' for object of type 'str'