In [1]:
%load_ext tensorboard

In [2]:
import os
import sys
import logging
from functools import partial

import torch
import numpy as np
import numpy.typing as npt
from torch.utils.tensorboard.writer import SummaryWriter

from utils.stoppers import EarlyStopper
from models.delightcnn.model import DelightCnnParameters
from models.delightcnn.dataset import DelightDataset, DelightDatasetOptions
from models.delightcnn.training import execute_train_model

logging.basicConfig(
    format="[%(asctime)s %(levelname)s]: %(message)s",
    level=logging.INFO,
    stream=sys.stderr,
)

## Creating Dataset to be used for training

In [3]:
class TrainingSetProcessor:
    def __init__(self, source: str, balance: bool = False):
        self._source = source
        self._balanced_indexes: npt.NDArray[np.int32] | None = None
        if balance:
            self._balanced_indexes = np.random.shuffle(self._get_balanced_indexes())

    def _get_balanced_indexes(self) -> npt.NDArray[np.int32]:
        id_train_filepath = os.path.join(self._source, "id_train.npy")
        id_train: npt.NDArray[np.str_] = np.load(id_train_filepath, allow_pickle=True)
        idxAsiago = np.array(
            [i for i in range(id_train.shape[0]) if id_train[i][:2] == "SN"]
        )
        idxZTF = np.array(
            [i for i in range(id_train.shape[0]) if id_train[i][:3] == "ZTF"]
        )
        nimb = int(idxZTF.shape[0] / idxAsiago.shape[0])

        idxbal = np.array([], dtype=int)
        for i in range(nimb + 1):
            idxbal = np.concatenate([idxbal, idxAsiago])
            idxbal = np.concatenate(
                [
                    idxbal,
                    idxZTF[
                        i * idxAsiago.shape[0] : min(
                            idxZTF.shape[0], (i + 1) * idxAsiago.shape[0]
                        )
                    ],
                ]
            )

        return idxbal

    @property
    def X(self) -> npt.NDArray[np.float32]:
        x_train_filepath = os.path.join(self._source, "X_train.npy")
        X_train: npt.NDArray[np.float32] = np.load(x_train_filepath)

        if self._balanced_indexes is not None:
            X_train = X_train[self._balanced_indexes]

        return X_train.swapaxes(3, 1).swapaxes(2, 3)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        y_train_filepath = os.path.join(self._source, "y_train.npy")
        y_train: npt.NDArray[np.float32] = np.load(y_train_filepath)

        if self._balanced_indexes is not None:
            y_train = y_train[self._balanced_indexes]

        return y_train


class ValidationSetProcessor:
    def __init__(self, source: str, pixscale_mask_value: float | None = None):
        self._source = source
        self._pixscale_mask: npt.NDArray[np.int32] | None = None
        if pixscale_mask_value is not None:
            self._pixscale_mask = self._get_distance_mask(pixscale_mask_value)

    def _get_distance_mask(self, pixscale: float) -> npt.NDArray[np.int32]:
        y_validation_filepath = os.path.join(self._source, "y_validation.npy")
        y_validation: npt.NDArray[np.float32] = np.load(y_validation_filepath)

        distance = np.sqrt(np.sum(y_validation**2, axis=1))
        return (distance * pixscale) < 60

    @property
    def X(self) -> npt.NDArray[np.float32]:
        x_validation_filepath = os.path.join(self._source, "X_validation.npy")
        X_validation: npt.NDArray[np.float32] = np.load(x_validation_filepath)

        if self._pixscale_mask is not None:
            X_validation = X_validation[self._pixscale_mask]

        return X_validation.swapaxes(3, 1).swapaxes(2, 3)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        y_validation_filepath = os.path.join(self._source, "y_validation.npy")
        y_validation: npt.NDArray[np.float32] = np.load(y_validation_filepath)

        if self._pixscale_mask is not None:
            y_validation = y_validation[self._pixscale_mask]

        return y_validation


class TestingSetProcessor:
    def __init__(self, source: str):
        self._source = source

    @property
    def X(self) -> npt.NDArray[np.float32]:
        x_test_filepath = os.path.join(self._source, "X_test.npy")
        x_test: npt.NDArray[np.float32] = np.load(x_test_filepath)
        return x_test.swapaxes(3, 1).swapaxes(2, 1)

    @property
    def y(self) -> npt.NDArray[np.float32]:
        y_test_filepath = os.path.join(self._source, "y_test.npy")
        return np.load(y_test_filepath)


class ProductionTrainingSetProcessor:
    def __init__(self, source: str):
        self._source = source
        self._training_set = TrainingSetProcessor(source)
        self._validation_set = ValidationSetProcessor(source)

    @property
    def X(self) -> npt.NDArray[np.float32]:
        return np.concatenate((self._training_set.X, self._validation_set.X))

    @property
    def y(self) -> npt.NDArray[np.float32]:
        return np.concatenate((self._training_set.y, self._validation_set.y))

In [4]:
# Dataset settigns
source = os.path.join(os.getcwd(), "data")
dataset_options = DelightDatasetOptions(channels=1, levels=5, rot=True, flip=True)
balance_training_set = True
validation_pixscale_mask_value = 0.25

# Model settings
model_parameters = DelightCnnParameters(
    nconv1=16,
    nconv2=16,
    nconv3=32,
    ndense=128,
    dropout=0.06,
    channels=dataset_options.channels,
    levels=dataset_options.levels,
    rot=dataset_options.rot,
    flip=dataset_options.flip,
)

# Training settings
device: torch.device = torch.device("mps")
epochs = 10
batch_size = 32
adam_learning_rate = 0.0014
adam_weight_decay = 1e-4
criterion = torch.nn.MSELoss()
optimizer = partial(
    torch.optim.Adam,  # type: ignore
    lr=adam_learning_rate,
    weight_decay=adam_weight_decay,
)
stopper = EarlyStopper(patience=3, min_delta=0)
writter = SummaryWriter()

train_dataset = DelightDataset(
    processor=TrainingSetProcessor(source, balance=balance_training_set),
    options=dataset_options,
)
val_dataset = DelightDataset(
    processor=ValidationSetProcessor(
        source, pixscale_mask_value=validation_pixscale_mask_value
    ),
    options=dataset_options,
)

In [5]:
model = execute_train_model(
    model_parameters=model_parameters,
    criterion=criterion,
    optimizer=optimizer,  # type: ignore
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    stopper=stopper,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
)

[2024-10-28 17:36:31,300 INFO]: [EPOCH 1] train loss = 2415.5647735595703 | val_loss = 309.1363386027018
[2024-10-28 17:36:31,300 INFO]: Validation loss has been improved from inf -> 309.1363386027018
[2024-10-28 17:36:41,528 INFO]: [EPOCH 2] train loss = 1727.4650440216064 | val_loss = 216.06668968200682
[2024-10-28 17:36:41,530 INFO]: Validation loss has been improved from 309.1363386027018 -> 216.06668968200682
[2024-10-28 17:36:51,558 INFO]: [EPOCH 3] train loss = 1347.4812803268433 | val_loss = 184.02650428771972
[2024-10-28 17:36:51,558 INFO]: Validation loss has been improved from 216.06668968200682 -> 184.02650428771972
[2024-10-28 17:37:01,588 INFO]: [EPOCH 4] train loss = 1110.7901306152344 | val_loss = 161.72652524312338
[2024-10-28 17:37:01,588 INFO]: Validation loss has been improved from 184.02650428771972 -> 161.72652524312338
[2024-10-28 17:37:11,490 INFO]: [EPOCH 5] train loss = 938.5985555648804 | val_loss = 145.32816317240398
[2024-10-28 17:37:11,490 INFO]: Validatio