In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
from ray.train.torch import TorchTrainer
from ray.train import RunConfig, ScalingConfig, CheckpointConfig


from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix
)

from data_loader import Cifar10DataModule
from model import SimpleCIFAR10Classifier

In [None]:
# Directory passed to Cifar10DataModule as `data_dir`.
# Python file APIs do not expand a leading '~'; unless the consumer expands it
# itself, the path would be treated as a literal '~' directory relative to the
# current working directory. Resolve it once here to an absolute string.
DATA_DIR = os.path.expanduser('~/autodl-tmp/A3_CNN Image Classifier/data')

# Hyperparameter search space; passed to the tuner as `train_loop_config`.
SEARCH_SPACE = {
    "batch_size": tune.choice([32, 48, 64]),
    # The range covers two orders of magnitude, so sample on a log scale:
    # a linear uniform draw would put roughly 90% of trials in [1e-2, 1e-1]
    # and almost never explore the small learning rates.
    "lr": tune.loguniform(1e-3, 1e-1),
    # Single-decade range, so linear sampling is kept unchanged.
    "weight_decay": tune.uniform(1e-7, 1e-6),
}


def train_func(config: dict):
    """Training loop for a single trial; handed to ``TorchTrainer``.

    Args:
        config: Hyperparameters for this trial. ``config['batch_size']`` is
            used for the data module, and the whole dict is forwarded to
            ``SimpleCIFAR10Classifier``.
    """
    datamodule = Cifar10DataModule(
        data_dir=DATA_DIR,
        batch_size=config['batch_size'],
    )
    classifier = SimpleCIFAR10Classifier(config)

    # Ray-provided strategy / callback / environment plugin for Lightning.
    ray_hooks = {
        'strategy': RayDDPStrategy(),
        'callbacks': [RayTrainReportCallback()],
        'plugins': [RayLightningEnvironment()],
    }
    lightning_trainer = prepare_trainer(
        pl.Trainer(
            devices='auto',
            accelerator='auto',
            enable_progress_bar=False,
            **ray_hooks,
        )
    )
    lightning_trainer.fit(classifier, datamodule=datamodule)


def tune_cifar_asha(num_epochs: int = 5, num_samples: int = 10):
    """Run a Ray Tune search over ``SEARCH_SPACE`` with an ASHA scheduler.

    Args:
        num_epochs: Forwarded to ``ASHAScheduler`` as ``max_t``.
        num_samples: Forwarded to ``tune.TuneConfig`` as ``num_samples``.

    Returns:
        The value returned by ``tune.Tuner.fit()``.
    """
    # The same metric/mode pair ranks checkpoints and drives the search.
    target_metric = 'logs/val_accuracy'
    objective = 'max'

    checkpointing = CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute=target_metric,
        checkpoint_score_order=objective,
    )
    trial_trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(
            num_workers=1,
            use_gpu=True,
            resources_per_worker={'CPU': 4, 'GPU': 1},
        ),
        run_config=RunConfig(checkpoint_config=checkpointing),
    )

    search_settings = tune.TuneConfig(
        metric=target_metric,
        mode=objective,
        num_samples=num_samples,
        scheduler=ASHAScheduler(
            max_t=num_epochs,
            grace_period=1,
            reduction_factor=2,
        ),
    )
    return tune.Tuner(
        trial_trainer,
        param_space={'train_loop_config': SEARCH_SPACE},
        tune_config=search_settings,
    ).fit()

In [None]:
# Launch the search: ASHA max_t of 10 and 15 sampled configurations.
results = tune_cifar_asha(
    num_samples=15,
    num_epochs=10,
)

In [None]:
# Display the trial that maximises the logged validation accuracy.
results.get_best_result(
    mode="max",
    metric="logs/val_accuracy",
)