# Using Ray Tune

This notebook shows an example of how to use Ray Tune with this package to train several configurations of DelightCnn models.

In [1]:
import os
import sys
import logging
import datetime
from functools import partial

import gdown
import torch
import numpy as np
import numpy.typing as npt
from ray import tune, train
from ray.tune.schedulers import ASHAScheduler

from delightcnn.dataset import DelightDataset, DelightDatasetOptions
from delightcnn.training import (
    ray_wrapper_training_function,
    TrainingOptions,
)

# Route INFO-and-above records to stderr, prefixed with a timestamp and the
# level name.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    format="[%(asctime)s %(levelname)s]: %(message)s",
)

### Downloading dataset used in [Delight Paper](https://arxiv.org/pdf/2208.04310)

In [None]:
url = "https://drive.google.com/drive/u/2/folders/1UkHvXq2oNySMN2Hv2K1H9ptygvi2KgdM"
source = os.path.join(os.getcwd(), "data")

# Files that the dataset processors in this notebook read from `source`.
required_files = (
    "id_train.npy",
    "X_train.npy",
    "y_train.npy",
    "X_validation.npy",
    "y_validation.npy",
    "X_test.npy",
    "y_test.npy",
)
missing_files = [
    filename
    for filename in required_files
    if not os.path.isfile(os.path.join(source, filename))
]

# Only hit Google Drive when something is actually missing, so that
# "Restart & Run All" does not re-download the whole dataset every time.
# Delete the `data` directory to force a fresh download.
if missing_files:
    logging.info(
        "Downloading dataset into %s (missing: %s)", source, ", ".join(missing_files)
    )
    gdown.download_folder(url, output=source, quiet=False)
else:
    logging.info("Dataset already present in %s, skipping download", source)

In [2]:
class TrainingSetProcessor:
    """Loads the training split stored as ``.npy`` files under ``source``.

    Reads ``X_train.npy`` and ``y_train.npy`` on every access of ``X`` / ``y``.
    With ``balance=True`` it also reads ``id_train.npy`` once, at construction,
    and builds a shuffled index array in which the samples whose id starts with
    ``"SN"`` are repeated next to every equally sized chunk of samples whose id
    starts with ``"ZTF"``. ``X`` and ``y`` are then both indexed with that same
    array, so they stay aligned.

    Args:
        source: Directory containing the ``*_train.npy`` files.
        balance: Whether to oversample and shuffle as described above.
    """

    def __init__(self, source: str, balance: bool = False):
        self._source = source
        self._balanced_indexes: npt.NDArray[np.int_] | None = None
        if balance:
            # np.random.shuffle works in place and returns None, which would
            # silently turn balancing off; permutation returns the shuffled
            # copy while drawing from the same global NumPy random state.
            self._balanced_indexes = np.random.permutation(
                self._get_balanced_indexes()
            )

    def _get_balanced_indexes(self) -> npt.NDArray[np.int_]:
        """Return the (unshuffled) oversampled index array.

        Raises:
            ZeroDivisionError: If no id in ``id_train.npy`` starts with ``"SN"``.
        """
        id_train_filepath = os.path.join(self._source, "id_train.npy")
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading; only point `source` at files you trust.
        id_train: npt.NDArray[np.str_] = np.load(id_train_filepath, allow_pickle=True)

        # Explicit integer dtype so an empty selection still yields a valid
        # index array (np.array([]) would default to float64).
        idxAsiago = np.array(
            [i for i, sample_id in enumerate(id_train) if sample_id[:2] == "SN"],
            dtype=int,
        )
        idxZTF = np.array(
            [i for i, sample_id in enumerate(id_train) if sample_id[:3] == "ZTF"],
            dtype=int,
        )
        n_asiago = idxAsiago.shape[0]
        # Number of full "SN"-sized chunks that fit in the "ZTF" selection.
        nimb = idxZTF.shape[0] // n_asiago

        # Each round contributes every "SN" index followed by the next chunk of
        # "ZTF" indexes; slicing past the end simply yields a shorter (or
        # empty) last chunk.
        pieces = []
        for i in range(nimb + 1):
            pieces.append(idxAsiago)
            pieces.append(idxZTF[i * n_asiago : (i + 1) * n_asiago])

        return np.concatenate(pieces)

    @property
    def X(self) -> npt.NDArray[np.float32]:
        """Contents of ``X_train.npy`` with the last axis moved to position 1.

        The file is expected to hold a 4-D array; the returned axis order is
        ``(0, 3, 1, 2)`` relative to the stored one.
        """
        x_train_filepath = os.path.join(self._source, "X_train.npy")
        X_train: npt.NDArray[np.float32] = np.load(x_train_filepath)

        if self._balanced_indexes is not None:
            X_train = X_train[self._balanced_indexes]

        return X_train.swapaxes(3, 1).swapaxes(2, 3)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        """Contents of ``y_train.npy``, indexed like ``X`` when balancing."""
        y_train_filepath = os.path.join(self._source, "y_train.npy")
        y_train: npt.NDArray[np.float32] = np.load(y_train_filepath)

        if self._balanced_indexes is not None:
            y_train = y_train[self._balanced_indexes]

        return y_train


class ValidationSetProcessor:
    def __init__(self, source: str, pixscale_mask_value: float | None = None):
        self._source = source
        self._pixscale_mask: npt.NDArray[np.int32] | None = None
        if pixscale_mask_value is not None:
            self._pixscale_mask = self._get_distance_mask(pixscale_mask_value)

    def _get_distance_mask(self, pixscale: float) -> npt.NDArray[np.int32]:
        y_validation_filepath = os.path.join(self._source, "y_validation.npy")
        y_validation: npt.NDArray[np.float32] = np.load(y_validation_filepath)

        distance = np.sqrt(np.sum(y_validation**2, axis=1))
        return (distance * pixscale) < 60

    @property
    def X(self) -> npt.NDArray[np.float32]:
        x_validation_filepath = os.path.join(self._source, "X_validation.npy")
        X_validation: npt.NDArray[np.float32] = np.load(x_validation_filepath)

        if self._pixscale_mask is not None:
            X_validation = X_validation[self._pixscale_mask]

        return X_validation.swapaxes(3, 1).swapaxes(2, 3)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        y_validation_filepath = os.path.join(self._source, "y_validation.npy")
        y_validation: npt.NDArray[np.float32] = np.load(y_validation_filepath)

        if self._pixscale_mask is not None:
            y_validation = y_validation[self._pixscale_mask]

        return y_validation


class TestingSetProcessor:
    """Loads the test split stored as ``.npy`` files under ``source``.

    Args:
        source: Directory containing ``X_test.npy`` and ``y_test.npy``.
    """

    def __init__(self, source: str):
        self._source = source

    @property
    def X(self) -> npt.NDArray[np.float32]:
        """Contents of ``X_test.npy`` with the last axis moved to position 1.

        The returned axis order is ``(0, 3, 1, 2)`` relative to the stored 4-D
        array, the same reordering the training and validation processors in
        this notebook apply.
        """
        x_test_filepath = os.path.join(self._source, "X_test.npy")
        x_test: npt.NDArray[np.float32] = np.load(x_test_filepath)
        # The second swap must be (2, 3): swapping (2, 1) instead leaves the
        # axes in order (0, 2, 3, 1), which does not match the other splits.
        return x_test.swapaxes(3, 1).swapaxes(2, 3)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        """Contents of ``y_test.npy``, unmodified."""
        y_test_filepath = os.path.join(self._source, "y_test.npy")
        return np.load(y_test_filepath)


class ProductionTrainingSetProcessor:
    """Exposes the training and validation splits as one combined set.

    Both underlying processors are built with their default options (no
    balancing, no distance mask); ``X`` and ``y`` concatenate the training
    arrays followed by the validation arrays along the first axis.

    Args:
        source: Directory containing the training and validation ``.npy`` files.
    """

    def __init__(self, source: str):
        self._source = source
        self._training_set = TrainingSetProcessor(source)
        self._validation_set = ValidationSetProcessor(source)

    @property
    def X(self) -> npt.NDArray[np.float32]:
        """Training inputs followed by validation inputs."""
        parts = [self._training_set.X, self._validation_set.X]
        return np.concatenate(parts)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        """Training labels followed by validation labels."""
        parts = [self._training_set.y, self._validation_set.y]
        return np.concatenate(parts)

In [3]:
# Dataset settings
# Directory holding the .npy files; same path the download cell writes to.
source = os.path.join(os.getcwd(), "data")
dataset_options = DelightDatasetOptions(channels=1, levels=5, rot=True, flip=True)
# Passed as `balance` to TrainingSetProcessor: asks it to oversample the
# "SN"-prefixed ids against the "ZTF"-prefixed ones and shuffle the result.
balance_training_set = True
# Passed to ValidationSetProcessor: a validation sample is kept only when the
# Euclidean norm of its label, multiplied by this value, is below 60.
validation_pixscale_mask_value = 0.25

# Training settings
# Prefer Apple's MPS backend (the original target of this notebook), but fall
# back to CUDA and then CPU so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device: torch.device = torch.device(device_name)

epochs = 10
# NOTE(review): `batch_size` is not passed to TrainingOptions in this cell; the
# per-trial batch size is sampled from the "batch_size" entry of the Ray Tune
# search space instead.
batch_size = 32
adam_learning_rate = 0.0014
adam_weight_decay = 1e-4
criterion = torch.nn.MSELoss()
# Optimizer factory: the model parameters are supplied later, when the
# training code instantiates the optimizer for a concrete model.
optimizer = partial(
    torch.optim.Adam,  # type: ignore
    lr=adam_learning_rate,
    weight_decay=adam_weight_decay,
)

# Build each dataset from its processor; both share the same dataset options.
train_processor = TrainingSetProcessor(source, balance=balance_training_set)
train_dataset = DelightDataset(processor=train_processor, options=dataset_options)

val_processor = ValidationSetProcessor(
    source, pixscale_mask_value=validation_pixscale_mask_value
)
val_dataset = DelightDataset(processor=val_processor, options=dataset_options)

# Everything that stays fixed across trials is bundled into a
# `delightcnn.training.TrainingOptions` object, keeping these training
# specifications separate from the hyperparameters that Ray Tune searches over.
training_options = TrainingOptions(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    dataset_options=dataset_options,
    criterion=criterion,
    optimizer=optimizer,  # type: ignore
    epochs=epochs,
    device=device,
)

# Defining `run_ray_tune` function

We have to define `run_ray_tune` with the hyperparameters defined in `delightcnn.training.HyperParameters`.

This package comes with `ray_wrapper_training_function`, which is responsible for converting `HyperParameters` into DelightCnn parameters and uses `TrainingOptions` to create all the objects required to start a training session.

In [4]:
def run_ray_tune(
    *,
    name: str,
    num_samples: int,
    gpus_per_trial: float,
    training_options: TrainingOptions,
    grace_period: int = 20,
    reduction_factor: float = 3,
):
    """Run a Ray Tune hyperparameter search over DelightCnn configurations.

    Args:
        name: Experiment name handed to ``RunConfig``.
        num_samples: Number of configurations sampled from the search space.
        gpus_per_trial: GPU resources requested for each trial; fractional
            values are passed through to Ray unchanged.
        training_options: Fixed training specification bound to
            ``ray_wrapper_training_function`` for every trial.
        grace_period: Minimum number of reported iterations a trial runs
            before ASHA may stop it. Should be smaller than the number of
            results each trial reports, otherwise no trial is ever stopped.
        reduction_factor: ASHA halving rate; at every rung only about the top
            ``1 / reduction_factor`` of the trials keep running.

    Returns:
        Whatever ``tune.Tuner.fit()`` returns (the result grid of the search).
    """
    # Integer ranges use `+ 1` because Ray's upper bounds are exclusive.
    param_space = {
        "nconv1": tune.lograndint(16, 64 + 1),
        "nconv2": tune.lograndint(16, 64 + 1),
        "nconv3": tune.lograndint(16, 64 + 1),
        "ndense": tune.lograndint(256, 2048 + 1),
        "dropout": tune.uniform(0, 0.4),
        "batch_size": tune.lograndint(16, 64 + 1),
    }

    # NOTE(review): assumes TrainingOptions exposes `epochs` and that the
    # training function reports one result per epoch - confirm. getattr keeps
    # this check harmless if the attribute does not exist.
    epochs = getattr(training_options, "epochs", None)
    if isinstance(epochs, int) and grace_period >= epochs:
        logging.warning(
            "ASHA grace_period=%d is not smaller than epochs=%d: if one result "
            "is reported per epoch, no trial will ever be stopped early",
            grace_period,
            epochs,
        )

    scheduler = ASHAScheduler(
        grace_period=grace_period,  # iterations before early stopping is considered
        reduction_factor=reduction_factor,  # only the top 1/factor continue at each rung
        brackets=1,  # a single bracket, i.e. one halving schedule for all trials
    )

    # Bind the fixed options so Ray only has to supply the sampled config.
    train_fn = partial(ray_wrapper_training_function, training_options=training_options)

    tuner = tune.Tuner(
        tune.with_resources(train_fn, resources={"gpu": gpus_per_trial}),  # type: ignore
        tune_config=tune.TuneConfig(
            metric="val_loss", mode="min", scheduler=scheduler, num_samples=num_samples
        ),
        run_config=train.RunConfig(name=name),
        param_space=param_space,
    )
    return tuner.fit()

In [5]:
# Number of hyperparameter configurations to sample for this experiment.
num_samples = 200

# Put the launch time in the experiment name so repeated runs get distinct names.
now = datetime.datetime.now()
name = f"ray_experiment_{now.strftime('%d_%m_%Y-%H_%M_%S')}"

result = run_ray_tune(
    training_options=training_options,
    gpus_per_trial=0.2,
    num_samples=num_samples,
    name=name,
)

