In [13]:
import os
import logging
import math
from filelock import FileLock

# __import_lightning_begin__
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
# __import_lightning_end__




# __import_tune_begin__
from pytorch_lightning.loggers import TensorBoardLogger
from ray import tune
from ray.tune import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback
# __import_tune_end__


from LightningMNISTClassifier import LightningMNISTClassifier


In [14]:
# Configure the root logger once so INFO-level messages from this notebook
# (and from the trainable below) are actually emitted.
logging.basicConfig(level=logging.INFO)
# Named logger shared by the training and tuning helpers in later cells.
log = logging.getLogger('App')

In [15]:
def train_mnist_tune_checkpoint(config,
                                checkpoint_dir=None,
                                num_epochs=10,
                                num_gpus=0):
    """Tune trainable: fit one LightningMNISTClassifier trial.

    At every validation end the callback reports ``ptl/val_loss`` as
    ``loss`` and ``ptl/val_accuracy`` as ``mean_accuracy`` to Tune, and
    writes a checkpoint file named ``checkpoint``.

    Args:
        config: Trial hyperparameters. Must contain
            ``progress_bar_refresh_rate``; the whole dict is also handed to
            ``LightningMNISTClassifier``.
        checkpoint_dir: Directory Tune passes when restoring a trial (e.g.
            after a PBT exploit). If truthy, the model is loaded from
            ``<checkpoint_dir>/checkpoint`` with the trial's current config.
        num_epochs: ``max_epochs`` for the Lightning Trainer.
        num_gpus: GPU share for this trial; may be fractional.

    NOTE(review): only the model is restored from the checkpoint; the
    Trainer is built fresh with no resume argument, so it presumably
    restarts at epoch 0 without optimizer state. Confirm this is intended
    (vs. ``pl.Trainer(resume_from_checkpoint=...)``).
    """
    mnist_root = os.path.expanduser("~/data")

    # Lightning expects a whole number of GPUs, while Tune may allot a
    # fractional share, so round up.
    whole_gpus = math.ceil(num_gpus)
    # Log into the trial's own directory so TensorBoard runs line up with trials.
    tb_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version=".")
    refresh_rate = config["progress_bar_refresh_rate"]
    report_and_checkpoint = TuneReportCheckpointCallback(
        metrics={
            "loss": "ptl/val_loss",
            "mean_accuracy": "ptl/val_accuracy",
        },
        filename="checkpoint",
        on="validation_end",
    )
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=whole_gpus,
        logger=tb_logger,
        progress_bar_refresh_rate=refresh_rate,
        num_sanity_val_steps=0,
        callbacks=[report_and_checkpoint],
    )

    if not checkpoint_dir:
        model = LightningMNISTClassifier(config=config, data_dir=mnist_root)
        log.info('Lightning initialized')
    else:
        # Pass the (possibly PBT-mutated) config so new hyperparameters
        # override the ones stored in the checkpoint.
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
        model = LightningMNISTClassifier.load_from_checkpoint(
            checkpoint_path, config=config, data_dir=mnist_root)
        log.info('Lightning loaded from checkpoint')

    trainer.fit(model)

In [16]:
def tune_mnist_pbt(num_samples=20, num_epochs=10, gpus_per_trial=0):
    """Run Population Based Training over the Lightning MNIST classifier.

    Each trial runs ``train_mnist_tune_checkpoint``. Trials become
    candidates for perturbation after every training iteration. Only
    ``batch_size`` is mutated by PBT, and the custom ``explore`` hook then
    adds 10 to the mutated value. A trial stops once ``mean_accuracy``
    reaches 0.99 or after 15 training iterations.

    Args:
        num_samples: Number of trials (population size) to launch.
        num_epochs: ``max_epochs`` given to each trial's Lightning Trainer.
        gpus_per_trial: GPU share reserved per trial; may be fractional.

    Returns:
        The analysis object returned by ``tune.run``. The best config is
        also printed.
    """
    # Shared by the initial search space and PBT's mutation space so the
    # two stay in sync.
    batch_size_choices = [32, 64, 128, 256, 512, 1024, 2048]
    layer_size_choices = [32, 64, 128, 256, 512, 1024]

    def explore(config):
        # Custom PBT hook, applied to the config of a trial being perturbed.
        # It shifts batch_size by +10 on top of PBT's own mutation.
        log.info("======================================= EXPLORE =========================================")
        log.info(config)
        config['batch_size'] = config['batch_size'] + 10
        return config

    # Trials are considered for perturbation every `perturbation_interval`
    # units of `time_attr`, i.e. after every training iteration.
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=1,
        hyperparam_mutations={
            "batch_size": tune.choice(batch_size_choices),
        },
        custom_explore_fn=explore,
        log_config=True,
    )

    reporter_jupyter = JupyterNotebookReporter(
        overwrite=True,
        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"],
    )

    analysis = tune.run(
        tune.with_parameters(
            train_mnist_tune_checkpoint,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial),
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial,
        },
        metric="loss",
        mode="min",
        config={
            "progress_bar_refresh_rate": 0,
            "layer_1_size": tune.choice(layer_size_choices),
            "layer_2_size": tune.choice(layer_size_choices),
            "lr": tune.choice([1e-2, 1e-3, 1e-4, 1e-5, 1e-6]),
            "batch_size": tune.choice(batch_size_choices),
        },
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter_jupyter,
        verbose=1,
        name="MNIST",
        stop={  # Stop a single trial if one of the conditions is met
            "mean_accuracy": 0.99,
            "training_iteration": 15},
        local_dir="./data",
    )

    print("Best hyperparameters found were: ", analysis.best_config)
    return analysis

In [17]:
# Fractional GPU share reserved per trial. Tune can pack several trials onto
# one GPU, while each trial's Lightning Trainer rounds it up to 1 GPU.
GPUS_PER_TRIAL = 1 / 20.0

analysis = tune_mnist_pbt(num_samples=5, num_epochs=5, gpus_per_trial=GPUS_PER_TRIAL)
# tune_mnist_pbt already prints best_config. Only a cell's last expression is
# displayed, so show the per-trial results here.
analysis.results

2021-09-26 21:44:17,054	INFO tune.py:550 -- Total run time: 137.61 seconds (137.49 seconds for the tuning loop).


Best hyperparameters found were:  {'progress_bar_refresh_rate': 0, 'layer_1_size': 1024, 'layer_2_size': 1024, 'lr': 0.0001, 'batch_size': 266}


{'6eb22_00000': {'loss': 0.0715656504034996,
  'mean_accuracy': 0.9789683222770691,
  'time_this_iter_s': 7.12268853187561,
  'should_checkpoint': True,
  'done': True,
  'timesteps_total': None,
  'episodes_total': None,
  'training_iteration': 7,
  'experiment_id': '2f56f0029d1a425fbf3a05df2270f04d',
  'date': '2021-09-26_21-44-16',
  'timestamp': 1632681856,
  'time_total_s': 113.75413274765015,
  'pid': 21753,
  'hostname': 'ml-linux',
  'node_ip': '192.168.1.23',
  'config': {'progress_bar_refresh_rate': 0,
   'layer_1_size': 1024,
   'layer_2_size': 1024,
   'lr': 0.0001,
   'batch_size': 48},
  'time_since_restore': 48.44677019119263,
  'timesteps_since_restore': 0,
  'iterations_since_restore': 5,
  'trial_id': '6eb22_00000',
  'experiment_tag': '0_batch_size=1024,layer_1_size=1024,layer_2_size=32,lr=0.0001@perturbed[batch_size=48]'},
 '6eb22_00001': {'loss': 0.11761678755283356,
  'mean_accuracy': 0.9657266736030579,
  'time_this_iter_s': 9.429386854171753,
  'should_checkpoin