# Introduction

In this notebook we aim to reproduce the results reported in the [DELIGHT](https://arxiv.org/pdf/2208.04310) paper.

To do so, we port the proposed model to PyTorch and measure its performance with the metrics used in the paper.

# Metrics

We evaluate the network on the test set with 4 metrics. Each one is reported together with its standard deviation, estimated by bootstrap resampling:

- $$ \text{RMSE} = \sqrt{\text{MSE}} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \lVert Y_i - \hat{Y}_i \rVert^2} $$
- $$ \text{Mean Deviation} = \frac{1}{N} \sum_{i=1}^{N} \lVert Y_i - \hat{Y}_i \rVert $$
- $$ \text{Median Deviation} = \operatorname{median}\left(\lVert Y_i - \hat{Y}_i \rVert\right) $$
- $$ \text{Mode Deviation} = \operatorname{mode}\left(\lVert Y_i - \hat{Y}_i \rVert\right) $$

Where:
- $N$: the total number of samples in the test set.
- $Y_i$: a 2D vector with the true position of the host galaxy, $Y_i = (x_i, y_i)$.
- $\hat{Y}_i$: a 2D vector with the predicted position of the host galaxy, $\hat{Y}_i = (\hat{x}_i, \hat{y}_i)$.
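A small numerical check of these definitions, using made-up positions (not data from the paper). Because RMSE squares the full 2D distance before averaging, a single large miss pulls it well above the mean deviation. The mode is left out here because three distinct continuous values have no meaningful mode.

```python
import numpy as np

# Toy example: N = 3 true and predicted host positions (x, y)
y_true = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])
y_pred = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]])

# Per-sample Euclidean deviation ||Y_i - Y_hat_i||
deviation = np.linalg.norm(y_true - y_pred, axis=1)  # [0., 5., 1.]

rmse_value = np.sqrt(np.mean(deviation**2))  # sqrt((0 + 25 + 1) / 3)
mean_dev = np.mean(deviation)                # (0 + 5 + 1) / 3 = 2.0
median_dev = np.median(deviation)            # 1.0
print(rmse_value, mean_dev, median_dev)
```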

# Baseline

The values reported in the paper for each metric are:
- RMSE: 1.836 ± 0.05100
- Mean Deviation: 0.783 ± 0.00900
- Median Deviation: 0.468 ± 0.00800
- Mode Deviation: 0.427 ± 0.05100


In [1]:
%load_ext autoreload
%load_ext tensorboard
%autoreload 2

In [2]:
import numpy as np
import numpy.typing as npt
from typing import Callable
from scipy import stats  # type: ignore
from sklearn.utils import resample  # type: ignore

StatisticFunction = Callable[[npt.NDArray[np.float32]], float]


def bootstrap_statistic(
    data: npt.NDArray[np.float32],
    statistic: StatisticFunction,
    n_iterations: int = 1000,
) -> float:
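    # Standard deviation of `statistic` over bootstrap resamples of `data`.
    # (`values` avoids shadowing the `scipy.stats` module imported above.)
    values = np.zeros(n_iterations)
    for i in range(n_iterations):
        sample: npt.NDArray[np.float32] = resample(data)  # type: ignore
        values[i] = statistic(sample)
    return np.std(values).item()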


def rmse(
    y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32]
) -> tuple[float, float]:
    has_shape_2 = len(y_true.shape) == len(y_pred.shape) == 2
    are_points = y_true.shape[1] == y_pred.shape[1] == 2
    assert (
        has_shape_2 and are_points
    ), f"Expected vectors of dim (N, 2): y_true={y_true.shape} y_pred={y_pred.shape}"

    sum_distance_squared: npt.NDArray[np.float32] = np.sum(
        (y_true - y_pred) ** 2, axis=1
    )
    # np.sqrt returns a NumPy scalar (np.float32 is not a Python float), so convert explicitly
    value = float(np.sqrt(np.mean(sum_distance_squared)))
    return value, bootstrap_statistic(
        sum_distance_squared, lambda x: np.sqrt(np.mean(x))
    )


def mean_deviation(
    y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32]
) -> tuple[float, float]:
    has_shape_2 = len(y_true.shape) == len(y_pred.shape) == 2
    are_points = y_true.shape[1] == y_pred.shape[1] == 2
    assert (
        has_shape_2 and are_points
    ), f"Expected vectors of dim (N, 2): y_true={y_true.shape} y_pred={y_pred.shape}"

    deviation: npt.NDArray[np.float32] = np.linalg.norm(y_true - y_pred, axis=1)  # type: ignore
    return np.mean(deviation).item(), bootstrap_statistic(deviation, np.mean)


def median_deviation(
    y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32]
) -> tuple[float, float]:
    has_shape_2 = len(y_true.shape) == len(y_pred.shape) == 2
    are_points = y_true.shape[1] == y_pred.shape[1] == 2
    assert (
        has_shape_2 and are_points
    ), f"Expected vectors of dim (N, 2): y_true={y_true.shape} y_pred={y_pred.shape}"

    deviation: npt.NDArray[np.float32] = np.linalg.norm(y_true - y_pred, axis=1)  # type: ignore
    return np.median(deviation).item(), bootstrap_statistic(deviation, np.median)


def mode_deviation(
    y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32]
) -> tuple[float, float]:
    has_shape_2 = len(y_true.shape) == len(y_pred.shape) == 2
    are_points = y_true.shape[1] == y_pred.shape[1] == 2
    assert (
        has_shape_2 and are_points
    ), f"Expected vectors of dim (N, 2): y_true={y_true.shape} y_pred={y_pred.shape}"

    deviation: npt.NDArray[np.float32] = np.linalg.norm(y_true - y_pred, axis=1)  # type: ignore
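    # stats.mode may return a scalar or a 1-element array depending on the SciPy
    # version, so normalise it to a Python float.
    mode = float(np.ravel(stats.mode(deviation, axis=None).mode)[0])  # type: ignore
    return mode, bootstrap_statistic(
        deviation,
        lambda x: float(np.ravel(stats.mode(x, axis=None).mode)[0]),  # type: ignore
    )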
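`scipy.stats.mode` counts exact repeats. The deviations are continuous floats, so nearly all of them are distinct, and in that case `stats.mode` returns the smallest value; `mode_deviation` may therefore report something close to the minimum deviation rather than the peak of the distribution. Below is a minimal sketch of one common alternative, assuming the mode is estimated as the centre of the most populated histogram bin. The function name `histogram_mode` and the bin count are our own choices, and we have not confirmed which estimator the DELIGHT paper used.

```python
import numpy as np
import numpy.typing as npt


def histogram_mode(values: npt.NDArray[np.float64], bins: int = 50) -> float:
    """Estimate the mode of continuous data as the centre of the fullest histogram bin."""
    counts, edges = np.histogram(values, bins=bins)
    k = int(np.argmax(counts))  # index of the most populated bin
    return float((edges[k] + edges[k + 1]) / 2)


rng = np.random.default_rng(0)
# Synthetic deviations: the norm of an isotropic 2D Gaussian error follows a
# Rayleigh distribution, whose mode equals its scale parameter (0.5 here).
deviation = rng.rayleigh(scale=0.5, size=100_000)

# Every value is distinct, so an exact-count mode carries no information
print(np.unique(deviation).size == deviation.size)
print(histogram_mode(deviation))  # close to 0.5
```

Since `histogram_mode` takes an array and returns a float, it can be passed directly as the `statistic` argument of `bootstrap_statistic` to obtain its uncertainty.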

In [3]:
import math
from collections import OrderedDict
from typing import TypedDict
from functools import reduce

import torch


class DelightCnnParameters(TypedDict):
    nconv1: int
    nconv2: int
    nconv3: int
    ndense: int
    levels: int
    dropout: float
    rot: bool
    flip: bool


class RotationAndFlipLayer(torch.nn.Module):
    def __init__(self, rot: bool = True, flip: bool = True):
        super().__init__()  # type: ignore
        self.rot = rot
        self.flip = flip
        self.n_transforms = (int(flip) + 1) * (3 * int(rot) + 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stacked = reduce(lambda x, y: x * y, x.shape[:-3], 1)

        if self.rot is False and self.flip is False:
            x = x.reshape(stacked, x.shape[-3], x.shape[-2], x.shape[-1])
            return x

        w_dim = len(x.shape) - 2
        h_dim = len(x.shape) - 1
        transforms: tuple[torch.Tensor, ...]

        if self.rot is False:
            flipped = x.flip(dims=(h_dim,))
            transforms = (x, flipped)

        elif self.flip is False:
            rot90 = x.rot90(k=1, dims=(w_dim, h_dim))
            rot180 = x.rot90(k=2, dims=(w_dim, h_dim))
            rot270 = x.rot90(k=3, dims=(w_dim, h_dim))
            transforms = (x, rot90, rot180, rot270)

        else:
            rot90 = x.rot90(k=1, dims=(w_dim, h_dim))
            rot180 = x.rot90(k=2, dims=(w_dim, h_dim))
            rot270 = x.rot90(k=3, dims=(w_dim, h_dim))
            flipped = x.flip(dims=(h_dim,))
            flipped_rot90 = flipped.rot90(k=1, dims=(w_dim, h_dim))
            flipped_rot180 = flipped.rot90(k=2, dims=(w_dim, h_dim))
            flipped_rot270 = flipped.rot90(k=3, dims=(w_dim, h_dim))
            transforms = (
                x,
                rot90,
                rot180,
                rot270,
                flipped,
                flipped_rot90,
                flipped_rot180,
                flipped_rot270,
            )

        x = torch.cat(transforms, dim=1)
        return x.reshape(
            stacked * self.n_transforms, x.shape[-3], x.shape[-2], x.shape[-1]
        )


class DelightCnn(torch.nn.Module):
    def __init__(self, options: DelightCnnParameters):
        super().__init__()  # type: ignore
        bottleneck: OrderedDict[str, torch.nn.Module] = OrderedDict(
            [
                ("conv1", torch.nn.Conv2d(1, options["nconv1"], 3)),
                ("relu1", torch.nn.ReLU()),
                ("mp1", torch.nn.MaxPool2d(2)),
                ("conv2", torch.nn.Conv2d(options["nconv1"], options["nconv2"], 3)),
                ("relu2", torch.nn.ReLU()),
                ("mp2", torch.nn.MaxPool2d(2)),
                ("conv3", torch.nn.Conv2d(options["nconv2"], options["nconv3"], 3)),
                ("relu3", torch.nn.ReLU()),
                ("flatten", torch.nn.Flatten()),
            ]
        )
        linear_in = self._compute_dense_features(
            levels=options["levels"], bottleneck=bottleneck
        )
        self.fc1 = torch.nn.Linear(
            in_features=linear_in, out_features=options["ndense"]
        )
        self.tanh = torch.nn.Tanh()
        self.dropout = torch.nn.Dropout(p=options["dropout"])
        self.fc2 = torch.nn.Linear(in_features=options["ndense"], out_features=2)
        self.rot_and_flip = RotationAndFlipLayer(
            rot=options["rot"], flip=options["flip"]
        )
        self.bottleneck = torch.nn.Sequential(bottleneck)

    def _compute_dense_features(
        self,
        *,
        bottleneck: OrderedDict[str, torch.nn.Module],
        levels: int,
    ) -> int:
        w = 30
        h = 30
        conv_out = 0
        for layer in bottleneck.values():
            k: int
            if isinstance(layer, torch.nn.Conv2d):
                k = layer.kernel_size[0]
                w = w - k + 1
                h = h - k + 1
                conv_out = layer.out_channels
            if isinstance(layer, torch.nn.MaxPool2d):
                k = layer.kernel_size  # type: ignore
                w = math.floor((w - k) / 2 + 1)
                h = math.floor((h - k) / 2 + 1)

        return w * h * conv_out * levels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]  # TODO: Remove batch dependency

        # Apply flips and rotations over level (L) dimension
        x = self.rot_and_flip(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Undo transformations
        x = x.reshape(batch, self.rot_and_flip.n_transforms, -1)

        # Linear
        x = self.fc1(x)
        x = self.tanh(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x.reshape(batch, self.rot_and_flip.n_transforms * 2)

    def derotate(self, y_pred: torch.Tensor) -> npt.NDArray[np.float32]:
        # Maps each of the 8 per-view predictions back to the original frame.
        # Assumes rot=True and flip=True (8 views in the order produced by
        # RotationAndFlipLayer); returns an array of shape (N, 8, 2).
        y_pred_numpy: npt.NDArray[np.float32] = y_pred.cpu().numpy()
        views = y_pred_numpy.reshape((y_pred_numpy.shape[0], 8, 2))
        derotated = [
            views[:, 0],
            views[:, 1, ::-1] * [1, -1],
            views[:, 2] * [-1, -1],
            views[:, 3, ::-1] * [-1, 1],
            views[:, 4] * [1, -1],
            views[:, 5, ::-1],
            views[:, 6] * [-1, 1],
            views[:, 7, ::-1] * [-1, -1],
        ]
        return np.stack(derotated, axis=1)
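The batch layout is the subtle part of `RotationAndFlipLayer` and `DelightCnn.forward`. The sketch below mirrors that bookkeeping in NumPy with toy sizes (an assumption for illustration: `np.rot90`/`np.flip` stand in for the `torch` calls, and a plain flatten stands in for the convolutional bottleneck). It checks that after the final reshape, row `(b, t)` holds exactly the `L` levels of view `t` of sample `b`, which is the input layout `fc1` expects.

```python
import numpy as np

B, L, W, H = 2, 5, 4, 4  # toy batch, levels, width, height
x = np.arange(B * L * W * H, dtype=np.float32).reshape(B, L, 1, W, H)

# The 8 views in RotationAndFlipLayer's order: 4 rotations, then the
# flipped image (flip along the last axis) and its 3 rotations.
rotations = [np.rot90(x, k=k, axes=(3, 4)) for k in range(4)]
flipped = np.flip(x, axis=4)
transforms = rotations + [np.rot90(flipped, k=k, axes=(3, 4)) for k in range(4)]

# Concatenate over the level axis, then fold everything into one image axis
stacked = np.concatenate(transforms, axis=1)   # (B, 8 * L, 1, W, H)
images = stacked.reshape(B * 8 * L, 1, W, H)   # what the bottleneck sees

# Stand-in for the CNN bottleneck: flatten each image into F = W * H features
features = images.reshape(images.shape[0], -1)  # (B * 8 * L, F)

# Regroup as in DelightCnn.forward: one row per (sample, view)
grouped = features.reshape(B, 8, -1)            # (B, 8, L * F)

ok = all(
    np.array_equal(grouped[b, t], transforms[t][b].reshape(-1))
    for b in range(B)
    for t in range(8)
)
print(grouped.shape, ok)
```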

In [4]:
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, cast

import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp  # type: ignore
import torch
from torch.utils.data import Dataset


class DelightDatasetType(Enum):
    TRAIN = "TRAIN"
    TEST = "TEST"
    VALIDATION = "VALIDATION"


@dataclass
class DelightDatasetOptions:
    source: str
    n_levels: int
    fold: int
    mask: bool
    object: bool
    rot: bool
    flip: bool
    balance: bool = True


class DelightDataset(Dataset[tuple[torch.Tensor, torch.Tensor]]):
    def __init__(
        self,
        options: DelightDatasetOptions,
        datatype: DelightDatasetType,
        transform_y: bool = True,
    ):
        X, y = self.get_data(options, datatype)

        self.X = torch.Tensor(X).permute(0, 3, 1, 2)

        self.y = (
            self.transform(
                y,
                options.rot,
                options.flip,
            )
            if transform_y
            else torch.from_numpy(y)  # type: ignore
        )

    def get_data(
        self, options: DelightDatasetOptions, datatype: DelightDatasetType
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        enum = {
            DelightDatasetType.TRAIN: self.get_train_data,
            DelightDatasetType.VALIDATION: self.get_val_data,
            DelightDatasetType.TEST: self.get_test_data,
        }

        return enum[datatype](options)

    def get_train_data(
        self, options: DelightDatasetOptions
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        nlevels = options.n_levels
        ifold = options.fold
        domask = options.mask
        doobject = options.object
        source = options.source
        balance = options.balance

        oid_train: npt.NDArray[np.str_] = np.load(
            os.path.join(
                source,
                f"oid_train_nlevels{nlevels}_fold{ifold}_mask{domask}_objects{doobject}.npy",
            ),
            allow_pickle=True,
        )
        y_train: npt.NDArray[np.float32] = np.load(
            os.path.join(
                source,
                f"y_train_nlevels{nlevels}_fold{ifold}_mask{domask}_objects{doobject}.npy",
            )
        )
        X_train: npt.NDArray[np.float32] = np.load(
            os.path.join(
                source,
                f"X_train_nlevels{nlevels}_fold{ifold}_mask{domask}_objects{doobject}.npy",
            )
        )

        if balance is False:
            return X_train, y_train

        # create balanced training set
        idxAsiago = np.array(
            [i for i in range(oid_train.shape[0]) if oid_train[i][:2] == "SN"]
        )
        idxZTF = np.array(
            [i for i in range(oid_train.shape[0]) if oid_train[i][:3] == "ZTF"]
        )
        nimb = int(idxZTF.shape[0] / idxAsiago.shape[0])

        idxbal = np.array([], dtype=int)
        for i in range(nimb + 1):
            idxbal = np.concatenate([idxbal, idxAsiago])
            idxbal = np.concatenate(
                [
                    idxbal,
                    idxZTF[
                        i * idxAsiago.shape[0] : min(
                            idxZTF.shape[0], (i + 1) * idxAsiago.shape[0]
                        )
                    ],
                ]
            )
        # shuffle inplace
        np.random.shuffle(idxbal)

        oid_train = oid_train[idxbal]
        X_train = X_train[idxbal]
        y_train = y_train[idxbal]

        return X_train, y_train

    def get_val_data(
        self, options: DelightDatasetOptions
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        nlevels = options.n_levels
        ifold = options.fold
        domask = options.mask
        doobject = options.object
        source = options.source
        pixscale = 0.25

        oid_val: npt.NDArray[np.str_] = np.load(
            os.path.join(
                source,
                f"oid_val_nlevels{nlevels}_fold{ifold}_mask{domask}_objects{doobject}.npy",
            ),
            allow_pickle=True,
        )
        y_val: npt.NDArray[np.float32] = np.load(
            os.path.join(
                source,
                f"y_val_nlevels{nlevels}_fold{ifold}_mask{domask}_objects{doobject}.npy",
            )
        )
        X_val: npt.NDArray[np.float32] = np.load(
            os.path.join(
                source,
                f"X_val_nlevels{nlevels}_fold{ifold}_mask{domask}_objects{doobject}.npy",
            )
        )

        # mask only the validation set (having difficult cases in the training set helps the validation)
        distance = np.sqrt(np.sum(y_val**2, axis=1))
        mask = (distance * pixscale) < 60
        X_val = X_val[mask]
        y_val = y_val[mask]
        oid_val = oid_val[mask]

        return X_val, y_val

    def get_test_data(
        self, options: DelightDatasetOptions
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        nlevels = options.n_levels
        domask = options.mask
        doobject = options.object
        source = options.source

        # oid_test = np.load(os.path.join(source, f"oid_test_nlevels{nlevels}_mask{domask}_objects{doobject}.npy"), allow_pickle=True)
        y_test = np.load(
            os.path.join(
                source, f"y_test_nlevels{nlevels}_mask{domask}_objects{doobject}.npy"
            )
        )
        X_test = np.load(
            os.path.join(
                source, f"X_test_nlevels{nlevels}_mask{domask}_objects{doobject}.npy"
            )
        )

        return X_test, y_test

    @staticmethod
    def transform(
        y: np.ndarray[Any, np.dtype[np.float32]], rot: bool, flip: bool
    ) -> torch.Tensor:
        transformed: tuple[np.ndarray[Any, np.dtype[np.float32]], ...]

        if rot is False and flip is False:
            return torch.Tensor(y)

        yflip = cast(np.ndarray[Any, np.dtype[np.float32]], [1, -1] * y)
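        if rot is False:
            # Flip only: return early, otherwise the 8-view branch below overwrites this
            return torch.Tensor(np.concatenate((y, yflip), axis=1))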

        y90 = cast(np.ndarray[Any, np.dtype[np.float32]], [-1, 1] * y[:, ::-1])
        y180 = cast(np.ndarray[Any, np.dtype[np.float32]], [-1, 1] * y90[:, ::-1])
        y270 = cast(np.ndarray[Any, np.dtype[np.float32]], [-1, 1] * y180[:, ::-1])
        yflip90 = cast(np.ndarray[Any, np.dtype[np.float32]], [-1, 1] * yflip[:, ::-1])
        yflip180 = cast(
            np.ndarray[Any, np.dtype[np.float32]], [-1, 1] * yflip90[:, ::-1]
        )
        yflip270 = cast(
            np.ndarray[Any, np.dtype[np.float32]], [-1, 1] * yflip180[:, ::-1]
        )

        if flip is False:
            transformed = (y, y90, y180, y270)
        else:
            transformed = (y, y90, y180, y270, yflip, yflip90, yflip180, yflip270)

        return torch.Tensor(np.concatenate(transformed, axis=1))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx: int):
        x = self.X[idx]
        y = self.y[idx]

        if len(x.shape) == 3:  # has no channel information
            levels, width, height = x.shape
            x = x.reshape(levels, 1, width, height)  # assume single-channel images
        return x, y

    def to_tf_dataset(self) -> tuple[tf.Tensor, tf.Tensor]:
        X = cast(np.ndarray[Any, np.dtype[np.float32]], self.X.numpy())
        y = cast(np.ndarray[Any, np.dtype[np.float32]], self.y.numpy())

        return tnp.copy(X.transpose((0, 2, 3, 1))), tnp.copy(y)  # type: ignore

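`DelightCnn.derotate` must invert the label mapping applied by `DelightDataset.transform`, where each 90° rotation maps $(x, y) \mapsto (-y, x)$ and the flip maps $(x, y) \mapsto (x, -y)$. Below is a quick NumPy check of that claim with made-up offsets: build the 8 transformed labels in the same order as `transform`, apply the per-view inverses used in `derotate`, and confirm that all 8 recover the original position. This only checks labels against labels; whether `torch.rot90` on the images rotates in the same direction as this label mapping is not verified here.

```python
import numpy as np

y = np.array([[3.0, -1.0], [0.5, 2.0]])  # toy (x, y) offsets, shape (N, 2)


def rot90_label(v: np.ndarray) -> np.ndarray:
    # Label mapping used by DelightDataset.transform: (x, y) -> (-y, x)
    return np.array([-1, 1]) * v[:, ::-1]


# Forward labels in transform's order: y, y90, y180, y270, yflip, yflip90, ...
forward = [y]
for _ in range(3):
    forward.append(rot90_label(forward[-1]))
forward.append(np.array([1, -1]) * y)  # yflip
for _ in range(3):
    forward.append(rot90_label(forward[-1]))
views = np.stack(forward, axis=1)  # (N, 8, 2), laid out like the model's output

# Inverse mappings, in the same order as DelightCnn.derotate
inverse = [
    views[:, 0],
    views[:, 1, ::-1] * [1, -1],
    views[:, 2] * [-1, -1],
    views[:, 3, ::-1] * [-1, 1],
    views[:, 4] * [1, -1],
    views[:, 5, ::-1],
    views[:, 6] * [-1, 1],
    views[:, 7, ::-1] * [-1, -1],
]
recovered = np.stack(inverse, axis=1)  # (N, 8, 2)
print(np.allclose(recovered, y[:, None, :]))
```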


In [5]:
import tempfile
from typing import TypedDict
import datetime

import torch
from ray import train
from ray.train import Checkpoint
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
import logging


class HyperParameters(TypedDict):
    lr: float
    batch_size: int | float
    nconv1: int | float
    nconv2: int | float
    nconv3: int | float
    ndense: int | float
    dropout: float
    epochs: int


class EvaluationResult(TypedDict):
    rmse: tuple[float, float]
    mean_deviation: tuple[float, float]
    median_deviation: tuple[float, float]
    mode_deviation: tuple[float, float]


class EarlyStopper:
    def __init__(self, patience: int = 1, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter: int = 0
        self.min_validation_loss = float("inf")

    def early_stop(self, validation_loss: float, logger: logging.Logger) -> bool:
        if validation_loss < self.min_validation_loss:
            logger.info(
                f"Validation loss improved: {self.min_validation_loss} -> {validation_loss}"
            )
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            logger.info(
                f"Validation loss is not improving. Best val loss={self.min_validation_loss}"
            )
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


def _get_value_from_parameter(parameter: int | float, base: int = 2) -> int:
    # Floats are treated as exponents (e.g. 5.0 -> 2**5 = 32); ints are used as-is.
    return int(base**parameter) if isinstance(parameter, float) else parameter


def get_delight_cnn_parameters(
    params: HyperParameters, options: DelightDatasetOptions
) -> DelightCnnParameters:
    return {
        "nconv1": _get_value_from_parameter(params["nconv1"]),
        "nconv2": _get_value_from_parameter(params["nconv2"]),
        "nconv3": _get_value_from_parameter(params["nconv3"]),
        "ndense": _get_value_from_parameter(params["ndense"]),
        "levels": options.n_levels,
        "dropout": params["dropout"],
        "rot": options.rot,
        "flip": options.flip,
    }


def _train_one_epoch(
    *,
    device: str,
    train_dl: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    optimizer: torch.optim.Optimizer,
    model: DelightCnn,
    criterion: torch.nn.MSELoss,
    writer: SummaryWriter,
    is_ray: bool,
    epoch: int,
):
    running_loss = 0.0
    inputs: torch.Tensor
    positions: torch.Tensor
    outputs: torch.Tensor
    loss: torch.Tensor

    model.train()
    for i, (inputs, positions) in enumerate(train_dl):
        inputs, positions = inputs.to(device), positions.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)

        loss = criterion(outputs, positions)
        loss.backward()  # type: ignore

        optimizer.step()
        loss_value = loss.item()

        if is_ray is False:
            t = epoch * len(train_dl) + i  # type: ignore
            writer.add_scalar("[MSE Loss]: Train", loss_value, t)  # type: ignore

        running_loss += loss_value * inputs.size(0)

    return running_loss / len(train_dl.dataset)  # type: ignore


def _validate_train(
    *,
    device: str,
    val_dl: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    model: DelightCnn,
    criterion: torch.nn.MSELoss,
):
    running_loss = 0.0
    data: tuple[torch.Tensor, torch.Tensor]
    outputs: torch.Tensor
    loss: torch.Tensor

    model.eval()
    with torch.no_grad():
        for _, data in enumerate(val_dl):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)

            loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)

    return running_loss / len(val_dl.dataset)  # type: ignore


def _train(
    *,
    start_epoch: int,
    num_epochs: int,
    device: str,
    train_dl: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    val_dl: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    optimizer: torch.optim.Optimizer,
    model: DelightCnn,
    criterion: torch.nn.MSELoss,
    logger: logging.Logger,
    early_stop: bool,
    is_ray: bool = False,
):
    model.to(device)
    writer = SummaryWriter(
        comment=datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%SZ")
    )
    early_stopper = EarlyStopper(patience=3, min_delta=0)
    for epoch in range(start_epoch, num_epochs):
        train_loss = _train_one_epoch(
            device=device,
            train_dl=train_dl,
            optimizer=optimizer,
            model=model,
            criterion=criterion,
            is_ray=is_ray,
            writer=writer,
            epoch=epoch,
        )

        val_loss = _validate_train(
            device=device, val_dl=val_dl, model=model, criterion=criterion
        )

        logger.info(
            f"[EPOCH {epoch+1}] train loss = {train_loss} | val_loss = {val_loss}"
        )
        metrics = {"val_loss": val_loss, "train_loss": train_loss}

        if is_ray is False:
            writer.add_scalars("[MSE Loss]: Train / Validation", metrics, epoch)  # type: ignore
        else:
            with tempfile.TemporaryDirectory() as tempdir:
                torch.save(  # type: ignore
                    {
                        "epoch": epoch,
                        "net_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    os.path.join(tempdir, "checkpoint.pt"),
                )
                train.report(
                    metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir)
                )  # type: ignore

        if early_stop and early_stopper.early_stop(
            validation_loss=val_loss, logger=logger
        ):
            logger.info(f"Early stopping triggered after epoch {epoch + 1}")
            break

    writer.close()


def train_delight_cnn_model(
    params: HyperParameters, options: DelightDatasetOptions, early_stop: bool = True
) -> DelightCnn:
    device = "cpu" if torch.cuda.is_available() is False else "cuda"
    batch_size = _get_value_from_parameter(params["batch_size"])
    net_options = get_delight_cnn_parameters(params, options)
    net = DelightCnn(net_options)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=params["lr"], weight_decay=1e-4)
    checkpoint = cast(Checkpoint | None, train.get_checkpoint())  # type: ignore
    start_epoch = 0

    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))  # type: ignore
            start_epoch = int(checkpoint_dict["epoch"]) + 1
            net.load_state_dict(checkpoint_dict["net_state_dict"])
            optimizer.load_state_dict(checkpoint_dict["optimizer_state_dict"])

    train_dataset = DelightDataset(options=options, datatype=DelightDatasetType.TRAIN)
    val_dataset = DelightDataset(
        options=options, datatype=DelightDatasetType.VALIDATION
    )
    # Reshuffle the training batches every epoch
    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    logging.basicConfig()
    logger = logging.getLogger("Training")
    logger.setLevel(logging.DEBUG)

    logger.info(
        "Starting: epochs=%s,batch_size=%s,lr=%s,nconv1=%s,nconv2=%s,nconv3=%s,ndense=%s,dropout=%s"
        % (
            params["epochs"],
            batch_size,
            params["lr"],
            net_options["nconv1"],
            net_options["nconv2"],
            net_options["nconv3"],
            net_options["ndense"],
            net_options["dropout"],
        )
    )

    _train(
        start_epoch=start_epoch,
        num_epochs=params["epochs"],
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        optimizer=optimizer,
        model=net,
        criterion=criterion,
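        # train.get_checkpoint() is None on a trial's first run, so this is only
        # True for restored Ray Tune trials; fresh trials never call train.report.
        is_ray=checkpoint is not None,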
        logger=logger,
        early_stop=early_stop,
    )

    return net


def evaluate_delight_cnn_model(
    model: DelightCnn, options: DelightDatasetOptions
) -> EvaluationResult:
    device = "cpu" if torch.cuda.is_available() is False else "cuda"
    dataset = DelightDataset(
        options=options, datatype=DelightDatasetType.TEST, transform_y=False
    )
    dl = DataLoader(dataset, batch_size=16, shuffle=False)

    print("Evaluating model...")
    predictions: list[tuple[float, float]] = []

    model.to(device)
    model.eval()

    inputs: torch.Tensor
    outputs: torch.Tensor
    with torch.no_grad():
        for _, data in enumerate(dl):
            inputs, _ = data
            inputs = inputs.to(device)
            outputs = model(inputs)
            derotated = model.derotate(outputs)
            y_hat: npt.NDArray[np.float32] = np.mean(derotated, axis=1)
            predictions.extend(y_hat.tolist())

    y_true: npt.NDArray[np.float32] = dataset.y.cpu().numpy()
    y_pred: npt.NDArray[np.float32] = np.array(predictions)

    return {
        "rmse": rmse(y_true, y_pred),
        "mean_deviation": mean_deviation(y_true, y_pred),
        "median_deviation": median_deviation(y_true, y_pred),
        "mode_deviation": mode_deviation(y_true, y_pred),
    }

In [None]:
%tensorboard --logdir=runs

options = DelightDatasetOptions(
    source=os.path.join(os.getcwd(), "data"),
    n_levels=5,
    fold=0,
    mask=False,
    object=True,
    rot=True,
    flip=True,
    balance=True,
)

params_paper: HyperParameters = {
    "nconv1": 52,
    "nconv2": 57,
    "nconv3": 41,
    "ndense": 685,
    "dropout": 0.06,
    "epochs": 50,
    "batch_size": 40,
    "lr": 0.0014,
}

params: HyperParameters = {
    "nconv1": 16,
    "nconv2": 32,
    "nconv3": 32,
    "ndense": 128,
    "dropout": 0,
    "epochs": 50,
    "batch_size": 64,
    "lr": 0.0014,
}

model = train_delight_cnn_model(params_paper, options)

In [None]:
evaluation = evaluate_delight_cnn_model(model, options)

evaluation

## Hyperparameter search

In [6]:
import asyncio
import os
import socket

from ray import tune
from ray.tune.schedulers import ASHAScheduler
from telegram import Bot


class TelegramNotifier:
    def __init__(self, token: str, chat_id: int):
        self._bot = Bot(token)
        self._chat_id = chat_id

    def notify(self, message: str) -> None:
        async def notify(bot: Bot, message: str, chat_id: int):
            async with bot:
                await bot.send_message(
                    text=message, parse_mode="Markdown", chat_id=chat_id
                )  # type: ignore

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:  # 'RuntimeError: There is no current event loop...'
            loop = None

        if loop and loop.is_running():
            print(
                "Async event loop already running. Adding coroutine to the event loop."
            )
            tsk = loop.create_task(notify(self._bot, message, self._chat_id))
            # ^-- https://docs.python.org/3/library/asyncio-task.html#task-object
            # Optionally, a callback function can be executed when the coroutine completes
            tsk.add_done_callback(lambda t: print("Message sent"))
        else:
            print("Starting new event loop")
            asyncio.run(notify(self._bot, message, self._chat_id))


def run_ray_tune(*, name: str, num_samples: int, gpus_per_trial: float, source: str):
    params = {
        "nconv1": tune.lograndint(16, 64 + 1),
        "nconv2": tune.lograndint(16, 64 + 1),
        "nconv3": tune.lograndint(16, 64 + 1),
        "ndense": tune.lograndint(256, 2048 + 1),
        "dropout": tune.uniform(0, 0.4),
        "batch_size": tune.lograndint(16, 64 + 1),
        "lr": tune.loguniform(1e-4, 1e-2),
        "epochs": 100,
    }

    options = DelightDatasetOptions(
        source=source, n_levels=5, fold=0, mask=False, object=True, rot=True, flip=True
    )

    scheduler = ASHAScheduler(
        grace_period=20,  # minimum epochs (reported iterations) before a trial can be stopped
        reduction_factor=3,  # at each rung, only the top 1/3 of trials keep training
        brackets=1,  # a single bracket, i.e. plain asynchronous successive halving
    )

    def train_fn(params: HyperParameters) -> None:
        train_delight_cnn_model(params, options, early_stop=False)

    tuner = tune.Tuner(
        tune.with_resources(train_fn, resources={"gpu": gpus_per_trial}),  # type: ignore
        tune_config=tune.TuneConfig(
            metric="val_loss", mode="min", scheduler=scheduler, num_samples=num_samples
        ),
        run_config=train.RunConfig(name=name),
        param_space=params,
    )
    return tuner.fit()
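With `grace_period=20`, `reduction_factor=3` and `ASHAScheduler`'s default `max_t=100`, the stopping decisions happen only at a few fixed epochs (rungs). The sketch below assumes the standard ASHA rule that rungs sit at `grace_period * reduction_factor**k`, strictly below `max_t`. The survivor counts are an idealised, synchronous approximation; the real scheduler decides asynchronously. ASHA counts `training_iteration`, which only advances when a trial calls `train.report`, so these rungs are reached only if each epoch is actually reported.

```python
def asha_milestones(grace_period: int, reduction_factor: int, max_t: int) -> list[int]:
    """Epochs at which a single-bracket ASHA compares trials (hypothetical helper)."""
    milestones = []
    t = grace_period
    while t < max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


rungs = asha_milestones(grace_period=20, reduction_factor=3, max_t=100)
print(rungs)  # [20, 60]

# Idealised count of the 200 trials that keep training past each rung
survivors = 200
for rung in rungs:
    survivors //= 3
    print(f"after epoch {rung}: ~{survivors} trials continue")
```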

In [7]:
import sys

now = datetime.datetime.now()
name = f"ray_experiment_{now.strftime('%d_%m_%Y-%H_%M_%S')}"
num_samples = 200
machine = socket.gethostname()
chat_id = -4049822363
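# Read the bot token from the environment instead of hard-coding the secret
# (TELEGRAM_BOT_TOKEN is a variable name of our own choosing).
token = os.environ["TELEGRAM_BOT_TOKEN"]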
notifier = TelegramNotifier(token=token, chat_id=chat_id)
sources = {
    "quimal-gpu.alerce.online": "/home/kpinochet/delight/data",
    "LAPTOP-CUH9J3SR": "/home/keviinplz/universidad/tesis/astro-delight/data",
}

default_source = "/home/keviinplz/universidad/tesis/astro-delight/data"

message = f"""
**Experiment `{name}` started on {now.strftime('%d-%m-%Y at %H:%M:%S')} UTC**

Experiment details:

```
Trials: {num_samples}
Machine: {machine}
```
"""

notifier.notify(message)

try:
    result = run_ray_tune(
        name=name,
        num_samples=num_samples,
        gpus_per_trial=1,
        source=sources.get(machine, default_source),
    )
except Exception as e:
    finish = datetime.datetime.now(datetime.timezone.utc)
    # Keep the message flush-left so this block's indentation does not leak into it
    message = f"""
**Experiment `{name}` failed**

Reason: {e!r}

This experiment ran on machine {machine}.
It started on {now.strftime('%d-%m-%Y at %H:%M:%S')} UTC
and finished on {finish.strftime('%d-%m-%Y at %H:%M:%S')} UTC.
"""
    notifier.notify(message)
    sys.exit(1)

finish = datetime.datetime.now(datetime.timezone.utc)

df = result.get_dataframe()

df_folder = os.path.join(os.getcwd(), "ray_results_df")
os.makedirs(df_folder, exist_ok=True)

df_filename = os.path.join(df_folder, name + ".pkl")
# Parent folder of the best trial, i.e. the experiment directory
result_path = os.path.dirname(result.get_best_result().path)
df.to_pickle(df_filename)

best_quantity = 10
data = (
    df.sort_values(by=["val_loss"])[["val_loss", "train_loss"]]
    .head(best_quantity)
    .to_dict(orient="records")
)  # type: ignore

# One fixed-width row per trial, aligned with the header below (no leading indentation)
rows = "\n".join(
    f"| {i:^4} | {d['val_loss']:^8.3f} | {d['train_loss']:^10.3f} |"
    for i, d in enumerate(data, start=1)
)

message = f"""
**Experiment `{name}` finished**

Top {len(data)} results:

```
| Rank | val_loss | train_loss |
|------|:--------:|:----------:|
{rows}
```

This experiment ran on machine {machine}.
It started on {now.strftime('%d-%m-%Y at %H:%M:%S')} UTC
and finished on {finish.strftime('%d-%m-%Y at %H:%M:%S')} UTC.

A dataframe with the results was saved to `{df_filename}`.

The experiment itself is stored in `{result_path}`.
"""
notifier.notify(message)

Current time: 2024-06-15 02:01:00
Running for: 00:10:06.74
Memory: 2.6/7.5 GiB

| Trial name | status | loc | batch_size | dropout | lr | nconv1 | nconv2 | nconv3 | ndense |
|---|---|---|---|---|---|---|---|---|---|
| train_fn_3a650_00000 | RUNNING | 172.17.177.163:39754 | 20 | 0.235285 | 0.00233303 | 57 | 55 | 27 | 668 |
| train_fn_3a650_00001 | PENDING | | 22 | 0.227047 | 0.000148812 | 22 | 44 | 16 | 1218 |
| train_fn_3a650_00002 | PENDING | | 27 | 0.377357 | 0.000384833 | 46 | 40 | 24 | 794 |
| train_fn_3a650_00003 | PENDING | | 17 | 0.0631622 | 0.00981663 | 59 | 44 | 17 | 828 |
| train_fn_3a650_00004 | PENDING | | 19 | 0.0746885 | 0.00205198 | 20 | 31 | 21 | 416 |
| train_fn_3a650_00005 | PENDING | | 21 | 0.0522794 | 0.000291489 | 34 | 46 | 21 | 628 |
| train_fn_3a650_00006 | PENDING | | 24 | 0.149872 | 0.00030532 | 24 | 20 | 16 | 1233 |
| train_fn_3a650_00007 | PENDING | | 24 | 0.290728 | 0.00537451 | 40 | 57 | 41 | 568 |
| train_fn_3a650_00008 | PENDING | | 16 | 0.150494 | 0.00281034 | 35 | 56 | 34 | 709 |
| train_fn_3a650_00009 | PENDING | | 35 | 0.0554766 | 0.00619164 | 43 | 23 | 35 | 955 |


[36m(pid=39754)[0m 2024-06-15 01:50:56.267668: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=39754)[0m 2024-06-15 01:50:56.267716: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=39754)[0m 2024-06-15 01:50:56.267741: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(pid=39754)[0m 2024-06-15 01:50:56.273533: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
[36m(pid=39754)[0m To enable the following instructions: AVX2 FMA, in other operations, r

Async event loop already running. Adding coroutine to the event loop.


AttributeError: 'tuple' object has no attribute 'tb_frame'

Message sent
Message sent


[33m(raylet)[0m [2024-06-15 02:01:52,691 E 39205 39205] (raylet) node_manager.cc:3035: 2 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: 227f6473ba966ed8118157c67cf731153a2b7d9b1e6a97d25c210079, IP: 172.17.177.163) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 172.17.177.163`
[33m(raylet)[0m 
[33m(raylet)[0m Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.
Task exception was never retrieved
future: <Task finished name='Task-6' coro=<Telegram