# Training DelightCnn Model

This notebook shows an example of how to use DelightCnn package to train and use the model

In [11]:
import os
import sys
import logging
from functools import partial

import gdown
import torch
import numpy as np
import numpy.typing as npt
from torch.utils.tensorboard.writer import SummaryWriter

from delightcnn.utils.stoppers import EarlyStopper
from delightcnn.model import DelightCnnParameters
from delightcnn.dataset import DelightDataset, DelightDatasetOptions
from delightcnn.training import execute_train_model

logging.basicConfig(
    format="[%(asctime)s %(levelname)s]: %(message)s",
    level=logging.INFO,
    stream=sys.stderr,
)

### Downloading dataset used in [Delight Paper](https://arxiv.org/pdf/2208.04310)


In [15]:
url = "https://drive.google.com/drive/u/2/folders/1UkHvXq2oNySMN2Hv2K1H9ptygvi2KgdM"
source = os.path.join(os.getcwd(), "data")
gdown.download_folder(url, output=source, quiet=False)

Retrieving folder contents


Processing file 1TQDnl-cmb5rfW3VGKKTwXnTkX6LqVK_c id_train.npy
Processing file 1IIqPqScappd8_TLEBY-v0iv4UNPGyFFk id_validation.npy
Processing file 1zYShTB3llnZI5DTvsApUyYk4LZR-x_5m X_test.npy
Processing file 1BFzBRClUJH9xqtijm3-N3pDzCzdX3q2R X_train.npy
Processing file 1VCjSuwFxTgJTOHe4RvbWmqqNJolL0BcV X_validation.npy
Processing file 15j0HXm_bPcnR6GuZuG2Jf2gm2CRsjPqK y_test.npy
Processing file 1EuJLDgyXH8lbrxi0UjdVku2HOq52P5QQ y_train.npy
Processing file 1Zezs_TxyFgaaLUZ57tleFV1h_R3nPl7G y_validation.npy


Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From: https://drive.google.com/uc?id=1TQDnl-cmb5rfW3VGKKTwXnTkX6LqVK_c
To: /Users/keviinplz/thesis/data/id_train.npy
100%|██████████| 206k/206k [00:00<00:00, 6.02MB/s]
Downloading...
From: https://drive.google.com/uc?id=1IIqPqScappd8_TLEBY-v0iv4UNPGyFFk
To: /Users/keviinplz/thesis/data/id_validation.npy
100%|██████████| 51.3k/51.3k [00:00<00:00, 5.10MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1zYShTB3llnZI5DTvsApUyYk4LZR-x_5m
From (redirected): https://drive.google.com/uc?id=1zYShTB3llnZI5DTvsApUyYk4LZR-x_5m&confirm=t&uuid=64fd7d71-8213-4790-89c2-1f0a2f44022b
To: /Users/keviinplz/thesis/data/X_test.npy
100%|██████████| 172M/172M [00:05<00:00, 28.9MB/s] 
Downloading...
From (original): https://drive.google.com/uc?id=1BFzBRClUJH9xqtijm3-N3pDzCzdX3q2R
From (redirected): https://drive.google.com/uc?id=1BFzBRClUJH9xqtijm3-N3pDzCzdX3q2R&confirm=t&uu

['/Users/keviinplz/thesis/data/id_train.npy',
 '/Users/keviinplz/thesis/data/id_validation.npy',
 '/Users/keviinplz/thesis/data/X_test.npy',
 '/Users/keviinplz/thesis/data/X_train.npy',
 '/Users/keviinplz/thesis/data/X_validation.npy',
 '/Users/keviinplz/thesis/data/y_test.npy',
 '/Users/keviinplz/thesis/data/y_train.npy',
 '/Users/keviinplz/thesis/data/y_validation.npy']

# Processors

A processor is an object that defines where and how data has to be retrieved from a given `source`.
It has to follow the `delightcnn.dataset.Processor` protocol defining `X` and `y` properties.

This approach allows user to be flexible in where and how data has to be processed.

In [16]:
class TrainingSetProcessor:
    def __init__(self, source: str, balance: bool = False):
        self._source = source
        self._balanced_indexes: npt.NDArray[np.int32] | None = None
        if balance:
            self._balanced_indexes = np.random.shuffle(self._get_balanced_indexes())

    def _get_balanced_indexes(self) -> npt.NDArray[np.int32]:
        id_train_filepath = os.path.join(self._source, "id_train.npy")
        id_train: npt.NDArray[np.str_] = np.load(id_train_filepath, allow_pickle=True)
        idxAsiago = np.array(
            [i for i in range(id_train.shape[0]) if id_train[i][:2] == "SN"]
        )
        idxZTF = np.array(
            [i for i in range(id_train.shape[0]) if id_train[i][:3] == "ZTF"]
        )
        nimb = int(idxZTF.shape[0] / idxAsiago.shape[0])

        idxbal = np.array([], dtype=int)
        for i in range(nimb + 1):
            idxbal = np.concatenate([idxbal, idxAsiago])
            idxbal = np.concatenate(
                [
                    idxbal,
                    idxZTF[
                        i * idxAsiago.shape[0] : min(
                            idxZTF.shape[0], (i + 1) * idxAsiago.shape[0]
                        )
                    ],
                ]
            )

        return idxbal

    @property
    def X(self) -> npt.NDArray[np.float32]:
        x_train_filepath = os.path.join(self._source, "X_train.npy")
        X_train: npt.NDArray[np.float32] = np.load(x_train_filepath)

        if self._balanced_indexes is not None:
            X_train = X_train[self._balanced_indexes]

        return X_train.swapaxes(3, 1).swapaxes(2, 3)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        y_train_filepath = os.path.join(self._source, "y_train.npy")
        y_train: npt.NDArray[np.float32] = np.load(y_train_filepath)

        if self._balanced_indexes is not None:
            y_train = y_train[self._balanced_indexes]

        return y_train


class ValidationSetProcessor:
    def __init__(self, source: str, pixscale_mask_value: float | None = None):
        self._source = source
        self._pixscale_mask: npt.NDArray[np.int32] | None = None
        if pixscale_mask_value is not None:
            self._pixscale_mask = self._get_distance_mask(pixscale_mask_value)

    def _get_distance_mask(self, pixscale: float) -> npt.NDArray[np.int32]:
        y_validation_filepath = os.path.join(self._source, "y_validation.npy")
        y_validation: npt.NDArray[np.float32] = np.load(y_validation_filepath)

        distance = np.sqrt(np.sum(y_validation**2, axis=1))
        return (distance * pixscale) < 60

    @property
    def X(self) -> npt.NDArray[np.float32]:
        x_validation_filepath = os.path.join(self._source, "X_validation.npy")
        X_validation: npt.NDArray[np.float32] = np.load(x_validation_filepath)

        if self._pixscale_mask is not None:
            X_validation = X_validation[self._pixscale_mask]

        return X_validation.swapaxes(3, 1).swapaxes(2, 3)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        y_validation_filepath = os.path.join(self._source, "y_validation.npy")
        y_validation: npt.NDArray[np.float32] = np.load(y_validation_filepath)

        if self._pixscale_mask is not None:
            y_validation = y_validation[self._pixscale_mask]

        return y_validation


class TestingSetProcessor:
    def __init__(self, source: str):
        self._source = source

    @property
    def X(self) -> npt.NDArray[np.float32]:
        x_test_filepath = os.path.join(self._source, "X_test.npy")
        x_test: npt.NDArray[np.float32] = np.load(x_test_filepath)
        return x_test.swapaxes(3, 1).swapaxes(2, 1)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        y_test_filepath = os.path.join(self._source, "y_test.npy")
        return np.load(y_test_filepath)


class ProductionTrainingSetProcessor:
    def __init__(self, source: str):
        self._source = source
        self._training_set = TrainingSetProcessor(source)
        self._validation_set = ValidationSetProcessor(source)

    @property
    def X(self) -> npt.NDArray[np.float32]:
        return np.concatenate((self._training_set.X, self._validation_set.X))

    @property
    def y(self) -> npt.NDArray[np.float32]:
        return np.concatenate((self._training_set.y, self._validation_set.y))

# Training

The follows suggests some settings to train a DelightCnn model.

It uses `delightcnn.training.execute_train_model` function to train it.

In [17]:
# Dataset settigns
source = os.path.join(os.getcwd(), "data")
dataset_options = DelightDatasetOptions(channels=1, levels=5, rot=True, flip=True)
balance_training_set = True
validation_pixscale_mask_value = 0.25

# Model settings
model_parameters = DelightCnnParameters(
    nconv1=16,
    nconv2=16,
    nconv3=32,
    ndense=128,
    dropout=0.06,
    channels=dataset_options.channels,
    levels=dataset_options.levels,
    rot=dataset_options.rot,
    flip=dataset_options.flip,
)

# Training settings
device: torch.device = torch.device("mps")
epochs = 10
batch_size = 32
adam_learning_rate = 0.0014
adam_weight_decay = 1e-4
criterion = torch.nn.MSELoss()
optimizer = partial(
    torch.optim.Adam,  # type: ignore
    lr=adam_learning_rate,
    weight_decay=adam_weight_decay,
)

# You can use other Stopper strategy
# Following `delightcnn.utils.stoppers.Stopper` protocol
stopper = EarlyStopper(patience=3, min_delta=0)
writter = SummaryWriter()

train_dataset = DelightDataset(
    processor=TrainingSetProcessor(source, balance=balance_training_set),
    options=dataset_options,
)
val_dataset = DelightDataset(
    processor=ValidationSetProcessor(
        source, pixscale_mask_value=validation_pixscale_mask_value
    ),
    options=dataset_options,
)

In [18]:
model = execute_train_model(
    model_parameters=model_parameters,
    criterion=criterion,
    optimizer=optimizer,  # type: ignore
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    stopper=stopper,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
)

[2024-10-29 00:23:00,844 INFO]: [EPOCH 1] train loss = 2408.1670989990234 | val_loss = 296.975034383138
[2024-10-29 00:23:00,845 INFO]: Validation loss has been improved from inf -> 296.975034383138
[2024-10-29 00:23:10,895 INFO]: [EPOCH 2] train loss = 1716.3822231292725 | val_loss = 213.41258761088054
[2024-10-29 00:23:10,896 INFO]: Validation loss has been improved from 296.975034383138 -> 213.41258761088054
[2024-10-29 00:23:20,803 INFO]: [EPOCH 3] train loss = 1336.0316133499146 | val_loss = 186.77573875427245
[2024-10-29 00:23:20,804 INFO]: Validation loss has been improved from 213.41258761088054 -> 186.77573875427245
[2024-10-29 00:23:30,832 INFO]: [EPOCH 4] train loss = 1082.6466007232666 | val_loss = 154.44075614929199
[2024-10-29 00:23:30,833 INFO]: Validation loss has been improved from 186.77573875427245 -> 154.44075614929199
[2024-10-29 00:23:40,934 INFO]: [EPOCH 5] train loss = 914.9907398223877 | val_loss = 152.44388626098632
[2024-10-29 00:23:40,935 INFO]: Validation l