In [49]:
from dataclasses import dataclass, asdict
from tango.common import Params
from tango.common import Registrable
from tango.common.testing import run_experiment
from collections import OrderedDict
import itertools

In [51]:
# Directory holding the sweep test fixtures. This is a machine-specific absolute path;
# change this single line when running the notebook from another checkout.
FIXTURES_DIR = "/Users/sabhyac/Desktop/sabhya/tango/test_fixtures/sweeps/basic_test"

# Sweep definition (the hyperparameter grid), base experiment config, and the module
# that is handed to the experiment runner as an include package.
sweeps_config_path = f"{FIXTURES_DIR}/sweeps-config.jsonnet"
main_config_path = f"{FIXTURES_DIR}/config.jsonnet"
components = f"{FIXTURES_DIR}/basic_arithmetic.py"

In [82]:
class Sweeper(Registrable):
    """Runs the main experiment config once per hyperparameter combination.

    The sweep config is read through ``load_config``; its ``config`` attribute must
    contain a ``"hyperparameters"`` mapping from an override key (for example
    ``"steps.add_numbers.num1"``) to the list of values to try for that key.
    """

    def __init__(self, main_config_path: str, sweeps_config_path: str, components: str):
        """Store the paths and load the sweep config.

        :param main_config_path: path of the experiment config every run starts from.
        :param sweeps_config_path: path of the sweep config, read with ``load_config``.
        :param components: passed to ``run_experiment`` as the sole ``include_package`` entry.
        """
        # Zero-argument super() follows the full MRO. The previous
        # ``super(Registrable, self)`` began the lookup *after* Registrable, skipping it.
        super().__init__()
        # ``config_path`` duplicates ``main_config_path``; kept so existing readers of
        # either attribute keep working.
        self.config_path = main_config_path
        self.sweep_config = load_config(sweeps_config_path)
        self.main_config_path = main_config_path
        self.components = components

    def get_combinations(self) -> list:
        """Return every combination of hyperparameter values as a list of tuples.

        Each tuple holds one value per hyperparameter, in the iteration order of the
        ``"hyperparameters"`` mapping (the same order ``override_hyperparameters`` uses).
        """
        value_lists = self.sweep_config.config["hyperparameters"].values()
        return list(itertools.product(*value_lists))

    # TODO: trying to figure the best path forward? should i use tests?
    def run_experiments(self) -> None:
        """Run one experiment for each combination from ``get_combinations``."""
        for combination in self.get_combinations():
            main_config = self.override_hyperparameters(combination)
            # NOTE(review): run_experiment comes from tango.common.testing; the value it
            # yields (previously bound to an unused ``run_dir``) is not consumed yet.
            with run_experiment(main_config, include_package=[self.components]):
                # TODO: fill in something here?
                pass

    # TODO: wondering if this function should be here or in a test_file?
    def override_hyperparameters(self, experiment_tuple: tuple):
        """Load the main config with one combination of hyperparameters applied.

        :param experiment_tuple: one entry of ``get_combinations()``; position ``i``
            supplies the value for the ``i``-th hyperparameter key.
        :return: the result of ``Params.from_file`` for the main config with the
            overrides applied.
        :raises IndexError: if ``experiment_tuple`` has fewer entries than there are
            hyperparameter keys.
        """
        hyperparam_keys = self.sweep_config.config["hyperparameters"].keys()
        # Pair each key with the value at the same position in the combination.
        overrides = {key: experiment_tuple[i] for i, key in enumerate(hyperparam_keys)}
        return Params.from_file(self.main_config_path, params_overrides=overrides)
        

# function that loads the config from a specified yaml or jsonnet file
# TODO: how do I read "wandb" from the config and call the appropriate class
def load_config(config_path: str):
    """Build a ``SweepConfig`` from the file located at ``config_path``.

    :param config_path: path of the sweep configuration file.
    :return: whatever ``SweepConfig.from_file`` produces for that path.
    """
    sweep_config = SweepConfig.from_file(config_path)
    return sweep_config

# data class that loads the parameters
# TODO: unsure about how to specify a default here?
@dataclass(frozen=True)
class SweepConfig(Params):
    """Frozen dataclass, derived from ``Params``, with a single ``config`` field.

    NOTE(review): ``load_config`` creates instances via ``SweepConfig.from_file``.
    Presumably the inherited ``Params.from_file`` passes the parsed file contents as the
    first constructor argument, which is how they end up in ``config`` - confirm
    against the tango ``Params`` implementation.
    """
    # Parsed sweep configuration; ``Sweeper`` reads its "hyperparameters" entry.
    config: OrderedDict

In [78]:
# Build the sweeper from the paths defined in the configuration cell.
# The sweep config path variable is `sweeps_config_path`; `config_path` is never
# defined in this notebook and only resolved through leftover kernel state.
sw = Sweeper(
    main_config_path=main_config_path,
    sweeps_config_path=sweeps_config_path,
    components=components,
)

In [79]:
# Inspect the hyperparameter grid loaded from the sweep config.
sw.sweep_config.config["hyperparameters"]

{'steps.add_numbers.num1': [8, 16],
 'steps.add_numbers.num2': [2, 4],
 'steps.add_x.num2': [1, 2],
 'steps.divide_result.factor': [5, 10],
 'steps.multiply_result.factor': [1, 10]}

In [80]:
# Show the first combination generated from the hyperparameter grid.
sw.get_combinations()[0]

(8, 2, 1, 5, 1)

In [81]:
# Launch one experiment per hyperparameter combination.
# The recorded output below ends in a KeyboardInterrupt.
sw.run_experiments()

Starting new run proper-osprey
● Starting step "add_numbers" ...
✓ Finished step "add_numbers"
✓ Found output for step "add_numbers" in cache (needed by "multiply_result") ...
● Starting step "multiply_result" (needed by "divide_result") ...
✓ Finished step "multiply_result"
● Starting step "divide_result" (needed by "add_x") ...
✓ Finished step "divide_result"
● Starting step "add_x" ...
✓ Finished step "add_x"
✓ Found output for step "divide_result" in cache ...
✓ Found output for step "multiply_result" in cache ...
✓ Found output for step "add_x" in cache (needed by "print") ...
● Starting step "print" ...
3.0
✓ Finished step "print"
✓ The output for "add_numbers" is in /var/folders/st/xl2yk49d6v50p42_9spxgtlr0000gp/T/tango_testsngn5yg16/workspace/cache/AdditionStep-5SnWM66M3F7uD2UpCqGAbwLfzHMxmkqA
✓ The output for "add_x" is in /var/folders/st/xl2yk49d6v50p42_9spxgtlr0000gp/T/tango_testsngn5yg16/workspace/cache/AdditionStep-3jJeYfWuajwTjmC2LRLmvQKkg7dDgbVD
✓ The output for "divide_

✓ The output for "add_x" is in /var/folders/st/xl2yk49d6v50p42_9spxgtlr0000gp/T/tango_testss0_bni68/workspace/cache/AdditionStep-3WmejhakQcjZ6s1Lpd9Pju2E3keswSr5
✓ The output for "divide_result" is in /var/folders/st/xl2yk49d6v50p42_9spxgtlr0000gp/T/tango_testss0_bni68/workspace/cache/ScaleDown-3FuQAfg28T9ooWBNyza3UKLJhmMizmKk
✓ The output for "multiply_result" is in /var/folders/st/xl2yk49d6v50p42_9spxgtlr0000gp/T/tango_testss0_bni68/workspace/cache/ScaleUp-5Fsy9AbNHQZTREHfunhfwbpt6AjFP7T4
✓ The output for "print" is in /var/folders/st/xl2yk49d6v50p42_9spxgtlr0000gp/T/tango_testss0_bni68/workspace/cache/Print-3cu2d5W4bFu6PkxbNC3cCQrCo8LAdwTb
Finished run on-joey
Starting new run deep-stud
● Starting step "add_numbers" ...
✓ Finished step "add_numbers"
✓ Found output for step "add_numbers" in cache (needed by "multiply_result") ...
● Starting step "multiply_result" (needed by "divide_result") ...
✓ Finished step "multiply_result"
● Starting step "divide_result" (needed by "add_x") ...


KeyboardInterrupt: 