# IMPORTS

In [1]:
from google.cloud import storage
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
import apache_beam.runners.interactive.interactive_beam as ibeam
import os
import subprocess
import glob
import shutil
from typing import Union
from pathlib import Path
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import List, Tuple, Any, Dict
from astropy.stats import sigma_clip
import itertools
from gcsfs import GCSFileSystem
import tensorflow as tf

2024-10-10 10:58:34.313078: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-10 10:58:34.315301: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-10 10:58:34.346901: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-10 10:58:34.346930: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-10 10:58:34.347754: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to

# CONTENT

## FUNCTIONS

In [2]:
def list_blobs(bucket_name: str, folder: str | None = None) -> list[str]:
    """List the objects in a bucket, optionally restricted to a name prefix.

    Args:
        bucket_name (str): Bucket to consider.
        folder (str | None, optional): Object-name prefix (typically a folder
            such as "raw") used to filter the listing. This is a plain prefix
            match, so "raw" also matches "raw_backup/...". ``None`` or ``""``
            lists the whole bucket. Defaults to None.

    Returns:
        list[str]: URIs of the matching objects, formatted as
        ``gs://<bucket_name>/<object name>``.
    """
    storage_client = storage.Client()
    # Let GCS filter by prefix server-side instead of listing the entire bucket
    # and filtering with str.startswith locally; the selected objects are the same.
    blobs = storage_client.list_blobs(bucket_name, prefix=folder or None)
    return [f"gs://{bucket_name}/{blob.name}" for blob in blobs]

In [3]:
def group_by_key(key_value: List[Tuple[Any, Any]]) -> List[List[Any]]:
    """
    Collect the values of ``(key, value)`` pairs under their key.

    Keys appear in the output in the order they are first seen, and each
    key's values keep their input order.

    Args:
        key_value (List[Tuple[Any, Any]]): Pairs to group; keys must be hashable.

    Returns:
        List[List[Any]]: One ``[key, [values...]]`` entry per distinct key.
    """
    buckets: dict = {}
    for k, v in key_value:
        buckets.setdefault(k, []).append(v)
    return [[k, vals] for k, vals in buckets.items()]

In [4]:
def split_key_value(path: str):
    l = path.split("/")
    if len(l) <= 5:
        return (l[3], "/".join(l) if l[4] != "" else None)
    else:
        return (l[5], "/".join(l))

In [5]:
def get_raw_data_uris(
    bucket_name: str, folder: str | None = None
) -> List[Tuple[str, List[str]]]:
    """
    Group the bucket's object URIs per key and attach the shared "raw" files.

    Each URI is keyed with ``split_key_value``. URIs keyed "raw" (files at
    most one level under the ``raw`` folder) are shared: they are appended to
    every other key's list. Placeholders with an empty name are dropped from
    the shared list.

    Args:
        bucket_name (str): Bucket to consider.
        folder (str | None): Prefix of the objects to list. Defaults to None.

    Returns:
        List[Tuple[str, List[str]]]: ``(key, uris)`` pairs, with key-specific
        URIs first followed by the shared "raw" URIs.
    """
    pairs = [split_key_value(path=uri) for uri in list_blobs(bucket_name, folder)]
    shared_uris = []
    keyed_pairs = []
    for key, uri in pairs:
        if key == "raw":
            if uri is not None:
                shared_uris.append(uri)
        else:
            keyed_pairs.append((key, uri))
    return [(key, uris + shared_uris) for key, uris in group_by_key(keyed_pairs)]

In [62]:
def save_to_tfrecords(dataset: tf.data.Dataset, path: str):
    """Write every element of ``dataset`` as one record of a TFRecord file.

    Elements exposing ``.numpy()`` (eager tensors) are written as their NumPy
    value; anything else is handed to the writer unchanged.

    Args:
        dataset (tf.data.Dataset): Dataset of serialized records.
        path (str): Destination file path.
    """

    def _emit(writer, item):
        try:
            writer.write(item.numpy())
        except AttributeError:
            # No .numpy(): assume the item is already bytes/str.
            writer.write(item)

    with tf.io.TFRecordWriter(path) as writer:
        for item in dataset:
            _emit(writer, item)

In [7]:
def decode_fn(proto):
    """Deserialize a serialized tensor string back into a float64 tensor."""
    decoded = tf.io.parse_tensor(serialized=proto, out_type=tf.float64)
    return decoded

In [8]:
def read_tfrecord(path: str):
    """Open the TFRecord file at ``path`` (e.g. a gs:// URI) as a dataset."""
    dataset = tf.data.TFRecordDataset(filenames=path)
    return dataset

In [9]:
def print_records(dataset: tf.data.Dataset):
    """Print the NumPy value of each element of ``dataset``, one per line."""
    for element in iter(dataset):
        value = element.numpy()
        print(value)

In [10]:
def deserialize_example(ex: tf.Tensor, formats: dict, decoders: dict | None = None):
    """Parse one serialized ``tf.train.Example`` and optionally decode features.

    Args:
        ex: Scalar string tensor (or bytes) holding a serialized
            ``tf.train.Example``.
        formats (dict): Feature spec (feature name -> ``tf.io.FixedLenFeature``
            etc.), as accepted by ``tf.io.parse_single_example``.
        decoders (dict | None): Optional mapping from feature name to a callable
            applied to that parsed feature (e.g. ``decode_fn`` for serialized
            tensors). Features without a decoder are returned as parsed.

    Returns:
        dict: Feature name -> parsed (and, where requested, decoded) tensor.
    """
    # Parse the argument itself; the previous body referenced the local
    # `example` before it was assigned, which raised UnboundLocalError.
    example = tf.io.parse_single_example(ex, formats)
    if decoders:
        example = {
            name: decoders[name](value) if name in decoders else value
            for name, value in example.items()
        }
    return example

In [160]:
def read_labeled_tfrecord(example):
    """Parse one serialized labeled record into ``(id, airs, fgs, target)``.

    The record must hold an int64 ``id`` plus three string features
    (``airs``, ``fgs``, ``target``), each a serialized tensor decoded with
    ``decode_fn`` (float64) — the layout produced by ``make_example``.

    Args:
        example: Scalar string tensor with one serialized ``tf.train.Example``.

    Returns:
        tuple: ``(id, airs, fgs, target)`` tensors.
    """
    tensor_fields = ("airs", "fgs", "target")
    schema = {"id": tf.io.FixedLenFeature([], tf.int64)}
    for field in tensor_fields:
        schema[field] = tf.io.FixedLenFeature([], tf.string)
    parsed = tf.io.parse_single_example(example, schema)
    airs, fgs, target = (decode_fn(parsed[field]) for field in tensor_fields)
    return parsed["id"], airs, fgs, target

## CLASSES

In [153]:
class CalibrationFn(beam.DoFn):
    """Perform the raw-data calibration of one planet's AIRS-CH0 and FGS1 frames.

    Each input element is ``(planet_id, uris)``: the planet id plus every URI
    needed to calibrate it (signal and calibration files together with the
    shared ``adc_info`` / ``axis_info`` / ``labels`` tables, as grouped by
    ``get_raw_data_uris``). ``process`` emits one
    ``(int(planet_id), airs, fgs, labels)`` tuple whose arrays are converted to
    tensors with an extra leading axis of size 1.
    """

    def __init__(
        self,
        cut_inf: int,
        cut_sup: int,
        mask: bool,
        corr: bool,
        dark: bool,
        flat: bool,
        binning: int | None = None,
    ):
        """Store the calibration settings.

        Args:
            cut_inf (int): Lower bound (inclusive) of the crop applied to the
                last (356-wide) axis of the AIRS frames.
            cut_sup (int): Upper bound (exclusive) of that crop.
            mask (bool): Whether to mask dead and hot pixels.
            corr (bool): Whether to apply the non-linearity (linear_corr) correction.
            dark (bool): Whether to subtract the dark current.
            flat (bool): Whether to apply the flat-field correction.
            binning (int | None): Number of consecutive AIRS CDS frames summed
                together; FGS uses ``binning * 12``. ``None``/0 disables binning.
        """
        super().__init__()
        self.CUT_INF = cut_inf
        self.CUT_SUP = cut_sup
        self.BINNING = binning
        self.MASK = mask
        self.CORR = corr
        self.DARK = dark
        self.FLAT = flat

    def process(self, element):
        """Calibrate one ``(planet_id, uris)`` element.

        Returns:
            list: ``[(int(planet_id), airs, fgs, labels)]``. ``labels`` is
            ``[nan]`` when no label row was found for the planet.
        """
        planet_id, uris = element
        airs_data, fgs_data, info_data = self._load_data(planet_id, uris)
        airs_signal = self._to_batched_tensor(
            self._calibrate_airs_data(airs_data, info_data)
        )
        fgs_signal = self._to_batched_tensor(
            self._calibrate_fgs_data(fgs_data, info_data)
        )
        labels = info_data.get("labels")
        if labels is None:
            labels = np.array([np.nan])
        return [
            (int(planet_id), airs_signal, fgs_signal, self._to_batched_tensor(labels))
        ]

    @staticmethod
    def _to_batched_tensor(array: np.ndarray):
        """Prepend an axis of size 1 and convert the array to a tensor."""
        return tf.convert_to_tensor(array.reshape(1, *array.shape))

    def _calibrate_airs_data(
        self, data: Dict[str, np.ndarray], info: Dict[str, Any]
    ) -> np.ndarray:
        """Calibrate the AIRS-CH0 frames (ADC revert, crop, then the shared chain)."""
        signal = self._adc_revert(
            data["signal"], info["airs_gain"], info["airs_offset"]
        )
        # Work on a copy: adding the +0.1 offset in place would mutate `info`.
        dt = np.array(info["airs_it"], dtype=np.float64)
        dt[1::2] += 0.1
        signal = signal[:, :, self.CUT_INF : self.CUT_SUP]
        return self._apply_corrections(signal, data, dt, self.BINNING)

    def _calibrate_fgs_data(
        self, data: Dict[str, np.ndarray], info: Dict[str, Any]
    ) -> np.ndarray:
        """Calibrate the FGS1 frames (ADC revert, then the shared chain)."""
        signal = self._adc_revert(data["signal"], info["fgs_gain"], info["fgs_offset"])
        dt = np.ones(len(signal)) * 0.1
        dt[1::2] += 0.1
        # FGS frames are binned 12x more coarsely than AIRS frames.
        binning = self.BINNING * 12 if self.BINNING else self.BINNING
        return self._apply_corrections(signal, data, dt, binning)

    def _apply_corrections(
        self,
        signal: np.ndarray,
        data: Dict[str, np.ndarray],
        dt: np.ndarray,
        binning: int | None,
    ) -> np.ndarray:
        """Shared AIRS/FGS chain: mask, linear corr, dark, CDS, bin/transpose, flat."""
        if self.MASK:
            signal = self._mask_hot_dead(signal, data["dead"], data["dark"])
        if self.CORR:
            signal = self._apply_linear_corr(data["linear_corr"], signal)
        if self.DARK:
            signal = self._clean_dark_current(signal, data["dead"], data["dark"], dt)
        signal = self._get_cds(signal)
        if binning:
            signal = self._bin_obs(signal, binning)
        else:
            # _bin_obs transposes the last two axes too; keep the layouts consistent.
            signal = signal.transpose(0, 2, 1)
        if self.FLAT:
            signal = self._correct_flat_field(data["flat"], data["dead"], signal)
        return signal

    def _adc_revert(self, signal: np.ndarray, gain: float, offset: float) -> np.ndarray:
        """Revert pixel voltage from ADC.

        Args:
            signal (np.ndarray): ADC converted signal integer.
            gain (float): ADC gain error.
            offset (float): ADC offset error.

        Returns:
            np.ndarray: Pixel voltages (float64 copy of ``signal / gain + offset``).
        """
        signal = signal.astype(np.float64)
        signal /= gain
        signal += offset
        return signal

    def _mask_hot_dead(
        self, signal: np.ndarray, dead: np.ndarray, dark: np.ndarray
    ) -> np.ndarray:
        """Mask dead and hot pixels so that they are ignored by later corrections.

        Hot pixels are those flagged by a 5-sigma clip (5 iterations) of ``dark``.

        Args:
            signal (np.ndarray): Pixel voltage signal, frames on axis 0.
            dead (np.ndarray): Dead-pixel map (truthy where dead).
            dark (np.ndarray): Dark frame.

        Returns:
            np.ndarray: Masked array with dead and hot pixels masked in every frame.
        """
        hot = sigma_clip(dark, sigma=5, maxiters=5).mask
        hot = np.tile(hot, (signal.shape[0], 1, 1))
        dead = np.tile(dead, (signal.shape[0], 1, 1))
        signal = np.ma.masked_where(dead, signal)
        signal = np.ma.masked_where(hot, signal)
        return signal

    def _apply_linear_corr(self, corr: np.ndarray, signal: np.ndarray) -> np.ndarray:
        """Fix non-linearity due to capacity leakage in the detector.

        Each pixel's time series is passed through its own polynomial. The
        coefficients are flipped along axis 0 before building ``np.poly1d``
        (which expects the highest power first) — presumably they are stored
        lowest order first; confirm against the data description.

        Args:
            corr (np.ndarray): Coefficients, shape (n_coeffs, rows, cols).
            signal (np.ndarray): Signal to correct; modified in place.

        Returns:
            np.ndarray: The corrected signal (same object as ``signal``).
        """
        linear_corr = np.flip(corr, axis=0)
        for x, y in itertools.product(range(signal.shape[1]), range(signal.shape[2])):
            poli = np.poly1d(linear_corr[:, x, y])
            signal[:, x, y] = poli(signal[:, x, y])
        return signal

    def _clean_dark_current(
        self, signal: np.ndarray, dead: np.ndarray, dark: np.ndarray, dt: np.ndarray
    ) -> np.ndarray:
        """Remove the accumulated charge due to dark current.

        Subtracts ``dark * dt[frame]`` from every frame, with dead pixels masked
        in the dark frame.

        Args:
            signal (np.ndarray): Signal to clean; modified in place.
            dead (np.ndarray): Dead-pixel map.
            dark (np.ndarray): Dark frame.
            dt (np.ndarray): Per-frame integration time.

        Returns:
            np.ndarray: Cleaned signal.
        """
        dark = np.ma.masked_where(dead, dark)
        dark = np.tile(dark, (signal.shape[0], 1, 1))
        signal -= dark * dt[:, np.newaxis, np.newaxis]
        return signal

    def _get_cds(self, signal: np.ndarray) -> np.ndarray:
        """Correlated double sampling: odd frames minus the preceding even frames.

        Args:
            signal (np.ndarray): Signal with frames on axis 0.

        Returns:
            np.ndarray: One difference image per (even, odd) frame pair.
        """
        cds = signal[1::2, :, :] - signal[::2, :, :]
        return cds

    def _bin_obs(self, cds: np.ndarray, binning: int) -> np.ndarray:
        """Sum consecutive CDS frames in groups of ``binning``.

        The last two axes are transposed first. Trailing frames that do not
        fill a complete group are dropped.

        Args:
            cds (np.ndarray): CDS signal, frames on axis 0.
            binning (int): Number of frames per group.

        Returns:
            np.ndarray: Float64 array of shape
            (n_frames // binning, cds.shape[2], cds.shape[1]).
        """
        cds_transposed = cds.transpose(0, 2, 1)
        cds_binned = np.zeros(
            (
                cds_transposed.shape[0] // binning,
                cds_transposed.shape[1],
                cds_transposed.shape[2],
            )
        )
        for i in range(cds_transposed.shape[0] // binning):
            cds_binned[i, :, :] = np.sum(
                cds_transposed[i * binning : (i + 1) * binning, :, :], axis=0
            )
        return cds_binned

    def _correct_flat_field(
        self, flat: np.ndarray, dead: np.ndarray, signal: np.ndarray
    ) -> np.ndarray:
        """Divide every frame by the flat field (dead pixels masked).

        ``flat`` and ``dead`` are transposed to match the transposed CDS layout.

        Args:
            flat (np.ndarray): Flat-field frame.
            dead (np.ndarray): Dead-pixel map.
            signal (np.ndarray): Transposed/binned CDS signal.

        Returns:
            np.ndarray: Flat-field corrected signal.
        """
        flat = flat.transpose(1, 0)
        dead = dead.transpose(1, 0)
        flat = np.ma.masked_where(dead, flat)
        flat = np.tile(flat, (signal.shape[0], 1, 1))
        signal = signal / flat
        return signal

    def _load_data(self, id: int, uris: List[str]):
        """
        Load and organize one planet's files by type for AIRS, FGS and metadata.

        A file's type is every entry of ``calib_data_types`` that occurs as a
        substring of its URI. Files matching no type are not downloaded. AIRS
        arrays are cropped to ``[CUT_INF:CUT_SUP]`` on their last axis.

        Args:
            id (int): Planet id (converted with ``int()``, so a numeric string works).
            uris (List[str]): URIs to load (CSV or Parquet).

        Returns:
            tuple: ``(airs_data, fgs_data, info_data)``. ``info_data`` holds
            ``fgs_gain``/``fgs_offset``/``airs_gain``/``airs_offset``,
            ``airs_it`` and ``labels`` (array, or None when unavailable).

        Raises:
            KeyError: If no ``adc_info`` row or ``axis_info`` table was found.
        """
        airs_data = {}
        fgs_data = {}
        info_data = {}
        calib_data_types = [
            "dark",
            "dead",
            "flat",
            "linear_corr",
            "read",
            "signal",
            "labels",
            "axis_info",
            "adc_info",
        ]

        def read_data(uri: str) -> pd.DataFrame:
            # Dispatch on extension so parquet files are not first downloaded
            # and parsed as CSV; unknown extensions keep the CSV-then-parquet
            # fallback.
            lowered = uri.lower()
            if lowered.endswith(".parquet"):
                return pd.read_parquet(uri)
            if lowered.endswith(".csv"):
                return pd.read_csv(uri)
            try:
                return pd.read_csv(uri)
            except UnicodeDecodeError:
                return pd.read_parquet(uri)

        for uri in uris:
            matched_types = [t for t in calib_data_types if t in uri]
            if not matched_types:
                # Nothing would be extracted from this file: skip the download.
                continue
            df = read_data(uri)
            for data_type in matched_types:
                if "AIRS" in uri:
                    if "signal" in uri:
                        airs_data[data_type] = df.values.astype(np.float64).reshape(
                            (df.shape[0], 32, 356)
                        )
                    elif "linear_corr" in uri:
                        airs_data[data_type] = df.values.astype(np.float64).reshape(
                            (6, 32, 356)
                        )[:, :, self.CUT_INF : self.CUT_SUP]
                    else:
                        airs_data[data_type] = df.values.astype(np.float64).reshape(
                            (32, 356)
                        )[:, self.CUT_INF : self.CUT_SUP]
                elif "FGS" in uri:
                    if "signal" in uri:
                        fgs_data[data_type] = df.values.astype(np.float64).reshape(
                            (df.shape[0], 32, 32)
                        )
                    elif "linear_corr" in uri:
                        fgs_data[data_type] = df.values.astype(np.float64).reshape(
                            (6, 32, 32)
                        )
                    else:
                        fgs_data[data_type] = df.values.astype(np.float64).reshape(
                            (32, 32)
                        )
                else:
                    # Shared tables: keep only this planet's row when there is one.
                    try:
                        info_data[data_type] = df.set_index("planet_id").loc[int(id)]
                    except KeyError:
                        if "labels" in uri:
                            info_data[data_type] = None
                        elif "axis_info" in uri:
                            # axis_info is not indexed by planet: keep the whole table.
                            info_data[data_type] = df
                        # Any other table without a row for this planet is
                        # ignored, so a match from another file is kept.

        adc_info = info_data.pop("adc_info")
        axis_info = info_data.pop("axis_info")
        info_data["fgs_gain"] = adc_info["FGS1_adc_gain"]
        info_data["fgs_offset"] = adc_info["FGS1_adc_offset"]
        info_data["airs_gain"] = adc_info["AIRS-CH0_adc_gain"]
        info_data["airs_offset"] = adc_info["AIRS-CH0_adc_offset"]
        labels = info_data.get("labels")
        info_data["labels"] = None if labels is None else labels.values
        info_data["airs_it"] = (
            axis_info["AIRS-CH0-integration_time"].dropna().values
        )
        return airs_data, fgs_data, info_data

In [154]:
class CombineDataFn(beam.CombineFn):
    """Gather every element of a PCollection into a single list.

    The accumulator is a ``(None, items)`` pair where only ``items`` carries
    data. The combined output is ``[items]``: the gathered list wrapped in a
    one-element list.
    """

    def create_accumulator(self):
        """Start with an empty bag."""
        return (None, [])

    def add_input(self, accumulator, input):
        """Append ``input`` to the accumulator's bag (the list is reused)."""
        items = accumulator[1]
        if items is None:
            items = []
        items.append(input)
        return (None, items)

    def merge_accumulators(self, accumulators):
        """Concatenate the bags of all ``accumulators`` into a new one."""
        combined = []
        for _, items in accumulators:
            combined += items
        return (None, combined)

    def extract_output(self, merge):
        """Return the gathered items wrapped in a one-element list."""
        _, items = merge
        return [items]

In [155]:
def make_example(id: int, airs: tf.Tensor, fgs: tf.Tensor, target: tf.Tensor):
    """Serialize one planet record as a ``tf.train.Example`` byte string.

    ``id`` is stored as an int64 feature. ``airs``, ``fgs`` and ``target`` are
    each stored as one bytes feature holding ``tf.io.serialize_tensor`` output.

    Returns:
        bytes: The serialized ``tf.train.Example``.
    """

    def tensor_feature(tensor: tf.Tensor) -> tf.train.Feature:
        # Whole tensor -> one serialized blob in a single-item BytesList.
        blob = tf.io.serialize_tensor(tensor).numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[blob]))

    feature_map = {
        "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[id])),
        "airs": tensor_feature(airs),
        "fgs": tensor_feature(fgs),
        "target": tensor_feature(target),
    }
    proto = tf.train.Example(features=tf.train.Features(feature=feature_map))
    return proto.SerializeToString()

In [157]:
def save_dataset_to_tfrecords(element: list, path: str):
    """Write the combined calibration outputs into one TFRecord file.

    Args:
        element (list): Sequence whose first item is the list of
            ``(id, airs, fgs, target)`` tuples to write (``CombineDataFn``
            emits ``[items]``).
        path (str): Destination TFRecord path.
    """
    records = element[0]
    # The output signature matters: without it TF tries to concatenate all
    # members of each tuple into a single tensor, which raises errors.
    signature = (
        tf.TensorSpec(shape=None, dtype=tf.int64),
        *(tf.TensorSpec(shape=None, dtype=tf.float64) for _ in range(3)),
    )
    raw_ds = tf.data.Dataset.from_generator(
        lambda: iter(records), output_signature=signature
    )

    def to_serialized(planet_id, airs, fgs, target):
        # make_example needs eager tensors (.numpy()), hence tf.py_function.
        return tf.py_function(
            func=make_example, inp=[planet_id, airs, fgs, target], Tout=tf.string
        )

    serialized_ds = raw_ds.map(to_serialized)
    save_to_tfrecords(serialized_ds, path)

## PIPELINE

In [159]:
with beam.Pipeline() as pipeline:
    # Only the first two planet groups are processed; drop the slice to run on all.
    uris = get_raw_data_uris("neurips-adc-bucket", "raw")[:2]
    # NOTE(review): every optional correction (mask/corr/dark/flat) is off, so
    # only ADC revert, the AIRS crop [39:321], CDS and binning (30 AIRS CDS
    # frames, 30 * 12 FGS CDS frames) are applied — confirm this is intended.
    calibration = CalibrationFn(
        cut_inf=39,
        cut_sup=321,
        mask=False,
        corr=False,
        dark=False,
        flat=False,
        binning=30,
    )
    outputs = (
        pipeline
        | "Create initial values" >> beam.Create(uris)
        | "Calibrate data" >> beam.ParDo(calibration)
        | "Create a data set from accumulators" >> beam.CombineGlobally(CombineDataFn())
        | "Save dataset"
        >> beam.Map(
            lambda x: save_dataset_to_tfrecords(
                x, "gs://neurips-adc-bucket/primary/beam_output.tfrecords"
            )
        )
    )

In [161]:
# Open the TFRecord file written by the Beam pipeline.
loaded_ds = read_tfrecord(path="gs://neurips-adc-bucket/primary/beam_output.tfrecords")

In [162]:
# Parse each serialized example into (id, airs, fgs, target).
# NOTE(review): rebinding `loaded_ds` makes this cell non-idempotent; re-run the
# previous cell before re-running this one.
loaded_ds = loaded_ds.map(map_func=read_labeled_tfrecord)

In [125]:
# Feature spec matching the records written by `make_example`: an int64 planet
# id plus three serialized tensors (decoded as float64 by `decode_fn`).
# The former "arr1"/"arr2"/"arr3" keys do not exist in those records, so
# parsing with them would fail.
LABELED_TFREC_FORMAT = {
    "id": tf.io.FixedLenFeature([], tf.int64),
    "airs": tf.io.FixedLenFeature([], tf.string),
    "fgs": tf.io.FixedLenFeature([], tf.string),
    "target": tf.io.FixedLenFeature([], tf.string),
}

## PIPELINE OPTIONS