In [None]:
import copy
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Export PYTORCH_ENABLE_MPS_FALLBACK=1 into the kernel's environment.
# NOTE(review): presumably this lets PyTorch fall back to CPU for ops that are
# unsupported on Apple's MPS backend — confirm against the PyTorch docs.
%env PYTORCH_ENABLE_MPS_FALLBACK=1

In [None]:
from star_analysis.model.types import ModelTypes
from star_analysis.runner.sdss_runner import SdssRunner
from star_analysis.data.augmentations import Augmentations
from star_analysis.runner.sdss_runner import SdssRunConfig, SdssModelConfig
from star_analysis.model.neural_networks.losses.types import LossType
from star_analysis.runner.run import Run

In [None]:
# Single runner instance shared by the cells below: runs are registered on it
# (`add_run`), tuned through it (`tune`), and the best model is saved through
# it (`save_model`).
runner = SdssRunner()

In [None]:
def _variant_with_loss(base_config, loss_type):
    """Return an independent deep copy of ``base_config`` using ``loss_type``.

    The base configuration is left untouched; only the copy's
    ``model_config.loss_type`` is overridden.
    """
    variant = copy.deepcopy(base_config)
    variant.model_config.loss_type = loss_type
    return variant


# Baseline configuration; every other run differs from it only in the loss.
run_config0 = SdssRunConfig(
    model_config=SdssModelConfig(
        learning_rate=1e-4,
        batch_size=32,
        model_type=ModelTypes.UNET,
        loss_type=LossType.DA_MSE,
    ),
    augmentation=Augmentations.NONE,
    shuffle_train=True,
)

# One variant per remaining loss function under comparison.
run_config1 = _variant_with_loss(run_config0, LossType.DICE)
run_config2 = _variant_with_loss(run_config0, LossType.DA_DICE)
run_config3 = _variant_with_loss(run_config0, LossType.FOCAL)
run_config4 = _variant_with_loss(run_config0, LossType.DA_FOCAL)

configs = [run_config0, run_config1, run_config2, run_config3, run_config4]
# Wrap each configuration in a Run, preserving the order of `configs`.
runs = list(map(Run, configs))

In [None]:
# Hand every run to the shared runner via `add_run`, in list order, before the
# tuning cell below is executed.
for run in runs:
    runner.add_run(run)

In [None]:
from star_analysis.runner.run import TrainerConfig

# Tune all runs with a shared trainer configuration (no logger, 10 epochs max).
# NOTE(review): assumes `tune` yields one result per run, in the same order as
# `runs` — confirm against SdssRunner.tune.
results = runner.tune(
    runs=runs,
    trainer_config=TrainerConfig(
        logger=None,
        max_epochs=10,
    )
)
# Materialise the (run, result) pairs as a list. A bare `zip` is a one-shot
# iterator: the first full pass over it (e.g. `max(...)`) would leave nothing
# for any later loop, and re-running a downstream cell would see it empty.
run_results = list(zip(runs, results))

In [None]:
# Make sure the (run, result) pairs can be traversed more than once: if
# `run_results` is still a lazy zip iterator, `max()` below would exhaust it
# and the summary loop would print nothing. Rebinding to a list is idempotent,
# so this cell can also be re-run safely.
run_results = list(run_results)

# Select the pair whose result has the highest `test_f1` value.
best_run, result_best = max(run_results, key=lambda pair: pair[1]['test_f1'])
# `best_run` is the Run itself and `result_best` its result, so read the name
# and score from them directly rather than indexing into `best_run`.
print(f"Best run, {best_run.name}, achieved {result_best['test_f1']} test_f1")

# Per-run summary; `{i}` interpolates the run's position in the list.
for i, (run, result) in enumerate(run_results):
    print(f"Run {i}, {run.name}, achieved {result['test_f1']} test_f1")

# Persist the model belonging to the best-scoring run.
runner.save_model(best_run.model)