In [1]:
#!gsutil -m cp -r gs://neurips-adc-bucket/raw/train/100468857 ../data/raw/train
#!gsutil -m cp -r gs://neurips-adc-bucket/raw/train/1005054328 ../data/raw/train
#!gsutil cp gs://neurips-adc-bucket/raw/* ../data/raw/

# IMPORTS

In [2]:
from google.cloud import storage
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
import apache_beam.runners.interactive.interactive_beam as ibeam
import os
import subprocess
import glob
import shutil
from typing import Union
from pathlib import Path
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import List, Tuple, Any, Dict
from astropy.stats import sigma_clip
import itertools
from gcsfs import GCSFileSystem
import tensorflow as tf

2024-10-01 15:20:27.552411: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-01 15:20:27.555811: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-10-01 15:20:27.599560: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-01 15:20:27.599602: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-01 15:20:27.601022: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to

# CONTENTS

## EVALUATE WAYS TO LIST THE DATA TO PROCESS

### Using storage client

In [3]:
storage_client = storage.Client()

In [4]:
# Ask the storage client for the objects of the bucket matching the glob
# "raw/**/*", using "/" as the delimiter.
blobs = storage_client.list_blobs(
    bucket_or_name="neurips-adc-bucket", match_glob="raw/**/*", delimiter="/"
)

In [5]:
# NOTE(review): `blobs` has not been iterated yet at this point; the empty result
# may simply reflect that — presumably `prefixes` is only filled while the listing
# is consumed. TODO confirm against the google-cloud-storage documentation.
[blob for blob in blobs.prefixes]

[]

In [6]:
[blob.name for blob in blobs]

['raw/',
 'raw/axis_info.parquet',
 'raw/sample_submission.csv',
 'raw/test_adc_info.csv',
 'raw/train_adc_info.csv',
 'raw/train_labels.csv',
 'raw/wavelengths.csv']

In [7]:
def list_blobs(bucket_name: str, folder: str | None = None) -> list[str]:
    """List the object in the bucket (folder).

    Args:
        bucket_name (str): Bucket to consider.
        folder (str | None, optional): Folder to whose elements to list. Defaults to None.

    Returns:
        list[str]: Listed objects' uris.
    """
    storage_client = storage.Client()
    blobs = storage_client.list_blobs(bucket_name)
    if folder:
        folder_objects = []
        for blob in blobs:
            if blob.name.startswith(folder):
                folder_objects.append("gs://" + bucket_name + "/" + blob.name)
        return folder_objects
    else:
        return ["gs://" + bucket_name + "/" + blob.name for blob in blobs]

In [8]:
%%timeit
list_blobs("neurips-adc-bucket", folder="raw/test/499191466/")

1.99 s ± 55.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Using the GCSFS Python API

In [9]:
# gcsfs exposes the bucket through a filesystem-like API; `ls` returns the entries
# found directly under the given path.
fs = GCSFileSystem()
fs.ls("neurips-adc-bucket/raw")

['neurips-adc-bucket/raw/',
 'neurips-adc-bucket/raw/axis_info.parquet',
 'neurips-adc-bucket/raw/sample_submission.csv',
 'neurips-adc-bucket/raw/test',
 'neurips-adc-bucket/raw/test_adc_info.csv',
 'neurips-adc-bucket/raw/train',
 'neurips-adc-bucket/raw/train_adc_info.csv',
 'neurips-adc-bucket/raw/train_labels.csv',
 'neurips-adc-bucket/raw/wavelengths.csv']

In [10]:
%%timeit
fs.glob("neurips-adc-bucket/raw/**")

673 ms ± 44.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Using gsutil cli tool

In [11]:
def ls_gcp_bucket(args: list[str]) -> list[str]:
    """List providers, buckets, or objects through the `gsutil ls` command.

    Args:
        args (list[str]): Arguments and options supported by `gsutil ls`.

    Returns:
        list[str]: Lines printed by `gsutil ls`. An empty list is returned when the
            command exits with a non-zero status (its stderr is printed in that case).
    """
    command = ["gsutil", "ls"] + args
    result = subprocess.run(args=command, capture_output=True, text=True)
    if result.returncode != 0:
        # TODO: logging
        print(f"{result.stderr}")
        # Honour the declared return type so callers can still slice or iterate.
        return []
    return result.stdout.splitlines()

In [12]:
%%timeit
ls_gcp_bucket(["gs://neurips-adc-bucket/**"])[:5]

2.66 s ± 55.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
def ls_filesystem(
    path: Union[str, Path], files_only: bool = False, folders_only: bool = False
) -> list[str]:
    """List the entries of a local directory, joined with the directory path.

    Args:
        path (Union[str, Path]): Directory whose entries are listed.
        files_only (bool, optional): Keep regular files only. Defaults to False.
        folders_only (bool, optional): Keep directories only. Ignored when
            `files_only` is also set. Defaults to False.

    Returns:
        list[str]: Paths of the selected entries.
    """
    # When both filters are requested, the file filter wins.
    keep_folders = folders_only and not files_only
    entries = [os.path.join(path, name) for name in os.listdir(path)]
    if files_only:
        entries = [entry for entry in entries if os.path.isfile(entry)]
    if keep_folders:
        entries = [entry for entry in entries if os.path.isdir(entry)]
    return entries

### Conclusion
**For bucket ls**: it is better to use the storage client, as it is more portable, customisable and straightforward, and provides a good balance in terms of efficiency.

gsutil needs to be installed at the system level.

With GCSFS, one needs to deal with glob patterns.

## CREATE PLANET_ID, DATA PAIRS

In [14]:
l = list_blobs("neurips-adc-bucket", folder="raw")

In [15]:
l

['gs://neurips-adc-bucket/raw/',
 'gs://neurips-adc-bucket/raw/axis_info.parquet',
 'gs://neurips-adc-bucket/raw/sample_submission.csv',
 'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/dark.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/dead.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/flat.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/linear_corr.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/read.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_signal.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/dark.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/dead.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/flat.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/linear_corr.parquet',
 'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/rea

In [16]:
def group_by_key(key_value: List[Tuple[Any, Any]]) -> List[List[Any]]:
    """
    Gather the values of a list of (key, value) pairs under their key.

    Args:
        key_value (List[Tuple[Any, Any]]): Pairs to group; each tuple holds a key and a value.

    Returns:
        List[List[Any]]: One ``[key, values]`` entry per distinct key, in order of first
            appearance, where ``values`` lists the values seen for that key.
    """
    buckets = {}
    for pair_key, pair_value in key_value:
        buckets.setdefault(pair_key, []).append(pair_value)
    return [[pair_key, buckets[pair_key]] for pair_key in buckets]

In [17]:
def split_key_value(path: str):
    """Turn an object URI into a (key, value) pair based on its "/"-separated parts.

    For URIs shaped like ``gs://<bucket>/<folder>/<file>`` (at most five parts) the key
    is the fourth part (the folder) and the value is the URI itself, or None when the
    fifth part is empty (URI ending with "/"). For deeper URIs such as
    ``gs://<bucket>/raw/test/<planet_id>/...`` the key is the sixth part and the value
    is the URI.

    Args:
        path (str): URI to split.

    Returns:
        tuple: ``(key, uri_or_None)``.
    """
    parts = path.split("/")
    if len(parts) > 5:
        return (parts[5], path)
    folder_name, leaf_name = parts[3], parts[4]
    return (folder_name, None if leaf_name == "" else path)

In [18]:
def get_raw_data_uris(
    bucket_name: str, folder: str | None = None
) -> List[Tuple[str, List[str]]]:
    """
    Get the data URIs in the Google Cloud Storage bucket.

    Args:
        bucket_name (str): Bucket to consider.
        floder (str): Folder to whose elements to list. Defaults to None.

    Returns:
        List[Tuple[str, List[str]]]: List of tuples where each tuple has the planet ID as the key and its data as the value.
    """
    paths = list_blobs(bucket_name, folder)
    key_value_pairs = [split_key_value(path=path) for path in paths]
    raw_data = [
        item[1] for item in key_value_pairs if item[0] == "raw" and item[1] is not None
    ]
    non_raw_data = [item for item in key_value_pairs if item[0] != "raw"]
    grouped_data = group_by_key(non_raw_data)
    result = [(key, value + raw_data) for key, value in grouped_data]
    return result

In [19]:
def copy_to(element, dest: str):
    """Copy the directory tree `element` under `dest`, keeping its last two path parts.

    Args:
        element: "/"-separated path of the directory to copy.
        dest (str): Directory under which the copy is created.
    """
    tail_parts = element.split("/")[-2:]
    target_dir = os.path.join(dest, "/".join(tail_parts))
    shutil.copytree(element, target_dir)

In [20]:
l = get_raw_data_uris("neurips-adc-bucket", folder="raw")

In [21]:
l

[('499191466',
  ['gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/dark.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/dead.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/flat.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/linear_corr.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/read.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_signal.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/dark.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/dead.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/flat.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/linear_corr.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/read.parquet',
   'gs://neurips-adc-bucket/raw/test/499191466/FGS1_signal.parquet',
   'gs://neurips-ad

## EVALUATE COLLECTION CREATION AND BEAM.MAP WITH LAMBDA

In [22]:
# Smoke-test `beam.Create` + `beam.Map` with a lambda: every folder listed under
# ../data/raw is handed to `copy_to`, which copies it under ../data/output.
with beam.Pipeline() as pipeline:
    outputs = (
        pipeline
        | "Create initial values"
        >> beam.Create(ls_filesystem(path="../data/raw", folders_only=True))
        # NOTE(review): despite its label, this step copies the folders.
        | "Ls filesystem" >> beam.Map(lambda x: copy_to(x, "../data/output"))
    )
    # `copy_to` has no return statement, so None is printed for each element.
    outputs | beam.Map(print)



None


## BUILD A DOFN LOGIC FOR DATA CALIBRATION

In [23]:
class CalibrationFn(beam.DoFn):
    """Performs the raw data calibration of one planet's AIRS-CH0 and FGS1 signals.

    Each input element is a ``(planet_id, uris)`` pair. The files behind the URIs are
    loaded, then each instrument goes through: ADC reversion, optional hot/dead pixel
    masking, optional linearity correction, optional dark-current removal, correlated
    double sampling (CDS), optional time binning and optional flat-field correction.
    """

    def __init__(
        self,
        cut_inf: int,
        cut_sup: int,
        mask: bool,
        corr: bool,
        dark: bool,
        flat: bool,
        binning: int | None = None,
    ):
        """Store the calibration options.

        Args:
            cut_inf (int): Images cropping range lower limit (AIRS last axis).
            cut_sup (int): Images cropping range upper limit (AIRS last axis).
            mask (bool): Whether to mask dead and hot pixels.
            corr (bool): Whether to apply linear correction.
            dark (bool): Whether to apply dark current correction.
            flat (bool): Whether to apply flat pixels correction.
            binning (int | None): Number of images to bin together. A falsy value
                disables binning. Defaults to None.
        """

        self.CUT_INF = cut_inf
        self.CUT_SUP = cut_sup
        self.BINNING = binning
        self.MASK = mask
        self.CORR = corr
        self.DARK = dark
        self.FLAT = flat

    def process(self, element):
        """Calibrate the data of one planet.

        Args:
            element: ``(planet_id, uris)`` pair, where `uris` lists the planet's files
                together with the shared info files.

        Returns:
            list: A single ``(planet_id, [airs_signal, fgs_signal])`` entry.
        """
        planet_id, uris = element
        airs_data, fgs_data, info_data = self._load_data(uris)
        airs_signal = self._calibrate_airs_data(planet_id, airs_data, info_data)
        fgs_signal = self._calibrate_fgs_data(planet_id, fgs_data, info_data)
        return [(planet_id, [airs_signal, fgs_signal])]

    def _calibrate_airs_data(
        self, id: str, data: dict[str, np.ndarray], info: dict[str, pd.DataFrame]
    ) -> np.ndarray:
        """Run the calibration chain on the AIRS-CH0 signal.

        Args:
            id (str): Planet ID, converted to int to look up the ADC gain and offset.
            data (dict[str, np.ndarray]): AIRS arrays keyed by type ("signal", "dead",
                "dark", "linear_corr", "flat").
            info (dict[str, pd.DataFrame]): Frames keyed "adc_info" and "axis_info".

        Returns:
            np.ndarray: Calibrated signal, with its last two axes swapped with respect
                to the loaded signal.
        """
        signal = data["signal"]
        gain = info["adc_info"].set_index("planet_id")["AIRS-CH0_adc_gain"].loc[int(id)]
        offset = (
            info["adc_info"].set_index("planet_id")["AIRS-CH0_adc_offset"].loc[int(id)]
        )
        signal = self._adc_revert(signal, gain, offset)
        dt = info["axis_info"]["AIRS-CH0-integration_time"].dropna().values
        # Every second frame gets 0.1 added to its integration time — presumably the
        # delay of the second read of each pair; TODO confirm.
        dt[1::2] += 0.1
        # Crop the last axis like the calibration frames cropped in `_load_data`.
        signal = signal[:, :, self.CUT_INF : self.CUT_SUP]
        if self.MASK:
            signal = self._mask_hot_dead(signal, data["dead"], data["dark"])
        if self.CORR:
            signal = self._apply_linear_corr(data["linear_corr"], signal)
        if self.DARK:
            signal = self._clean_dark_current(signal, data["dead"], data["dark"], dt)
        signal = self._get_cds(signal)
        if self.BINNING:
            signal = self._bin_obs(signal, self.BINNING)
        else:
            # `_bin_obs` also swaps the last two axes; keep both branches consistent.
            signal = signal.transpose(0, 2, 1)
        if self.FLAT:
            signal = self._correct_flat_field(data["flat"], data["dead"], signal)
        return signal

    def _calibrate_fgs_data(
        self, id: str, data: dict[str, np.ndarray], info: dict[str, pd.DataFrame]
    ) -> np.ndarray:
        """Run the calibration chain on the FGS1 signal.

        Args:
            id (str): Planet ID, converted to int to look up the ADC gain and offset.
            data (dict[str, np.ndarray]): FGS arrays keyed by type ("signal", "dead",
                "dark", "linear_corr", "flat").
            info (dict[str, pd.DataFrame]): Frames; only "adc_info" is read here.

        Returns:
            np.ndarray: Calibrated signal, with its last two axes swapped with respect
                to the loaded signal.
        """
        signal = data["signal"]
        gain = info["adc_info"].set_index("planet_id")["FGS1_adc_gain"].loc[int(id)]
        offset = info["adc_info"].set_index("planet_id")["FGS1_adc_offset"].loc[int(id)]
        signal = self._adc_revert(signal, gain, offset)
        # 0.1 for even frames and 0.2 for odd frames.
        dt = np.ones(len(signal)) * 0.1
        dt[1::2] += 0.1
        if self.MASK:
            signal = self._mask_hot_dead(signal, data["dead"], data["dark"])
        if self.CORR:
            signal = self._apply_linear_corr(data["linear_corr"], signal)
        if self.DARK:
            signal = self._clean_dark_current(signal, data["dead"], data["dark"], dt)
        signal = self._get_cds(signal)
        if self.BINNING:
            # NOTE(review): the factor 12 presumably aligns the FGS1 time axis with the
            # AIRS one (FGS1 sampled 12 times more often) — TODO confirm.
            signal = self._bin_obs(signal, self.BINNING * 12)
        else:
            signal = signal.transpose(0, 2, 1)
        if self.FLAT:
            signal = self._correct_flat_field(data["flat"], data["dead"], signal)
        return signal

    def _adc_revert(self, signal: np.ndarray, gain: float, offset: float) -> np.ndarray:
        """Revert pixel voltage from ADC.

        Args:
            signal (np.ndarray): ADC converted signal integer.
            gain (float): ADC gain error.
            offset (float): ADC offset error.

        Returns:
            np.ndarray: Pixel voltages, as a float64 array.
        """
        # `astype` returns a new array, so the in-place operations below do not touch
        # the caller's array.
        signal = signal.astype(np.float64)
        signal /= gain
        signal += offset
        return signal

    def _mask_hot_dead(
        self, signal: np.ndarray, dead: np.ndarray, dark: np.ndarray
    ) -> np.ndarray:
        """Mask dead and hot pixels so that they won't be taken into account in corrections.

        Args:
            signal (np.ndarray): Pixel voltage signal.
            dead (np.ndarray): Dead pixels.
            dark (np.ndarray): Dark pixels.

        Returns:
            np.ndarray: Masked array of pixel voltages with dead and hot pixels masked.
        """
        # Hot pixels are the ones flagged by a 5-sigma clipping of the dark frame.
        hot = sigma_clip(dark, sigma=5, maxiters=5).mask
        # Repeat both 2-D maps along the first axis to match the signal.
        hot = np.tile(hot, (signal.shape[0], 1, 1))
        dead = np.tile(dead, (signal.shape[0], 1, 1))
        signal = np.ma.masked_where(dead, signal)
        signal = np.ma.masked_where(hot, signal)
        return signal

    def _apply_linear_corr(self, corr: np.ndarray, signal: np.ndarray) -> np.ndarray:
        """Fix non-linearity due to capacity leakage in the detector.

        Args:
            corr (np.ndarray): Correction coefficients, one polynomial per pixel along
                the first axis.
            signal (np.ndarray): Signal to correct. It is modified in place.

        Returns:
            np.ndarray: Corrected signal
        """
        # `np.poly1d` expects the highest-degree coefficient first, hence the flip.
        linear_corr = np.flip(corr, axis=0)
        for x, y in itertools.product(range(signal.shape[1]), range(signal.shape[2])):
            poli = np.poly1d(linear_corr[:, x, y])
            signal[:, x, y] = poli(signal[:, x, y])
        return signal

    def _clean_dark_current(
        self, signal: np.ndarray, dead: np.ndarray, dark: np.ndarray, dt: np.ndarray
    ) -> np.ndarray:
        """Remove the accumulated charge due to dark current.

        Args:
            signal (np.ndarray): Signal to clean. It is modified in place.
            dead (np.ndarray): Dead pixels.
            dark (np.ndarray): Dark pixels.
            dt (np.ndarray): Per-frame time factor applied to the dark frame.

        Returns:
            np.ndarray: Cleaned signal.
        """
        dark = np.ma.masked_where(dead, dark)
        dark = np.tile(dark, (signal.shape[0], 1, 1))

        # Broadcast the per-frame factor over the two pixel axes.
        signal -= dark * dt[:, np.newaxis, np.newaxis]
        return signal

    def _get_cds(self, signal: np.ndarray) -> np.ndarray:
        """Return the actual accumulated charge (a delta) due to the transit.

        Args:
            signal (np.ndarray): Signal.

        Returns:
            np.ndarray: Difference between each odd-indexed frame and the even-indexed
                frame preceding it (one image per observation in time).
        """
        cds = signal[1::2, :, :] - signal[::2, :, :]
        return cds

    def _bin_obs(self, cds: np.ndarray, binning: int) -> np.ndarray:
        """Bin the CDS time series together at the specified frequency.

        Args:
            cds (np.ndarray): CDS signal.
            binning (int): Number of consecutive frames summed into one bin.

        Returns:
            np.ndarray: Binned signal with its last two axes swapped. Trailing frames
                that do not fill a whole bin are dropped.
        """
        cds_transposed = cds.transpose(0, 2, 1)
        cds_binned = np.zeros(
            (
                cds_transposed.shape[0] // binning,
                cds_transposed.shape[1],
                cds_transposed.shape[2],
            )
        )
        for i in range(cds_transposed.shape[0] // binning):
            cds_binned[i, :, :] = np.sum(
                cds_transposed[i * binning : (i + 1) * binning, :, :], axis=0
            )
        return cds_binned

    def _correct_flat_field(
        self, flat: np.ndarray, dead: np.ndarray, signal: np.ndarray
    ):
        """Correction by calibrating on a uniform signal.

        Args:
            flat (np.ndarray): Flat signal
            dead (np.ndarray): Dead pixels
            signal (np.ndarray): CDS signal

        Returns:
            The signal divided by the flat frame, dead pixels being masked in the flat.
        """

        # The signal's last two axes were swapped before this step, so the 2-D
        # calibration frames are transposed to match.
        flat = flat.transpose(1, 0)
        dead = dead.transpose(1, 0)
        flat = np.ma.masked_where(dead, flat)
        flat = np.tile(flat, (signal.shape[0], 1, 1))
        signal = signal / flat
        return signal

    def _load_data(self, uris: List[str]):
        """
        Load data from a list of URIs, reading the files as CSV or Parquet,
        and organize them by type ('dark', 'dead', etc.) for AIRS and FGS.

        URIs that match none of the known data types are skipped without being read.

        Args:
            uris (List[str]): List of URIs to load data from.

        Returns:
            tuple: ``(airs_data, fgs_data, info_data)``. The first two map a data type
                to a float64 array; `info_data` maps "axis_info" and "adc_info" to
                DataFrames.

        Raises:
            KeyError: If neither the train nor the test ADC info file is among the URIs.
        """
        airs_data = {}
        fgs_data = {}
        info_data = {}
        # NOTE: "read" frames are loaded but not used by the calibration methods.
        calib_data_types = [
            "dark",
            "dead",
            "flat",
            "linear_corr",
            "read",
            "signal",
            "axis_info",
            "test_adc_info",
            "train_adc_info",
        ]

        def read_data(uri: str) -> pd.DataFrame:
            # Choose the reader from the extension instead of attempting a CSV parse
            # of every Parquet file and relying on it failing with UnicodeDecodeError.
            if uri.endswith(".parquet"):
                return pd.read_parquet(uri)
            if uri.endswith(".csv"):
                return pd.read_csv(uri)
            # Unknown extension: keep the CSV-first fallback.
            try:
                return pd.read_csv(uri)
            except UnicodeDecodeError:
                return pd.read_parquet(uri)

        for uri in uris:
            matched_types = [
                data_type for data_type in calib_data_types if data_type in uri
            ]
            if not matched_types:
                # Not needed for calibration: avoid fetching it.
                continue
            df = read_data(uri)
            for data_type in matched_types:
                if "AIRS" in uri:
                    if "signal" in uri:
                        airs_data[data_type] = df.values.astype(np.float64).reshape(
                            (df.shape[0], 32, 356)
                        )
                    elif "linear_corr" in uri:
                        airs_data[data_type] = df.values.astype(np.float64).reshape(
                            (6, 32, 356)
                        )[:, :, self.CUT_INF : self.CUT_SUP]
                    else:
                        airs_data[data_type] = df.values.astype(np.float64).reshape(
                            (32, 356)
                        )[:, self.CUT_INF : self.CUT_SUP]
                elif "FGS" in uri:
                    if "signal" in uri:
                        fgs_data[data_type] = df.values.astype(np.float64).reshape(
                            (df.shape[0], 32, 32)
                        )
                    elif "linear_corr" in uri:
                        fgs_data[data_type] = df.values.astype(np.float64).reshape(
                            (6, 32, 32)
                        )
                    else:
                        fgs_data[data_type] = df.values.astype(np.float64).reshape(
                            (32, 32)
                        )
                else:
                    info_data[data_type] = df
        # Merge whichever ADC info files were provided (train first, then test) into a
        # single frame, and drop the per-split entries.
        adc_frames = [
            info_data.pop(name)
            for name in ("train_adc_info", "test_adc_info")
            if name in info_data
        ]
        if not adc_frames:
            raise KeyError(
                "No 'train_adc_info' or 'test_adc_info' file found among the URIs."
            )
        info_data["adc_info"] = pd.concat(adc_frames)
        return airs_data, fgs_data, info_data

In [24]:
# Keyword arguments, because `binning` comes last in the `CalibrationFn` signature
# and a row of positional booleans is easy to misalign. Every optional correction
# and the binning are disabled here.
processor = CalibrationFn(
    cut_inf=39,
    cut_sup=321,
    mask=False,
    corr=False,
    dark=False,
    flat=False,
    binning=None,
)

In [25]:
# Call the DoFn directly (outside any pipeline) on one hand-written element:
# a (planet_id, uris) pair listing the planet's files and the shared raw files.
result = processor.process(
    (
        "499191466",
        [
            "gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/dark.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/dead.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/flat.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/linear_corr.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_calibration/read.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/AIRS-CH0_signal.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/dark.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/dead.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/flat.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/linear_corr.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/FGS1_calibration/read.parquet",
            "gs://neurips-adc-bucket/raw/test/499191466/FGS1_signal.parquet",
            "gs://neurips-adc-bucket/raw/axis_info.parquet",
            "gs://neurips-adc-bucket/raw/sample_submission.csv",
            "gs://neurips-adc-bucket/raw/test_adc_info.csv",
            "gs://neurips-adc-bucket/raw/train_adc_info.csv",
            "gs://neurips-adc-bucket/raw/train_labels.csv",
            "gs://neurips-adc-bucket/raw/wavelengths.csv",
        ],
    )
)

In [26]:
result[0][1][0].shape

(5625, 282, 32)

In [27]:
with beam.Pipeline() as pipeline:
    # Only the first two planets, to keep the interactive run short.
    uris = get_raw_data_uris("neurips-adc-bucket", "raw")[:2]
    outputs = (
        pipeline
        | "Create initial values" >> beam.Create(uris)
        | "Calibrate data"
        # Keyword arguments, because `binning` comes last in the signature; every
        # optional correction and the binning are disabled.
        >> beam.ParDo(
            CalibrationFn(
                cut_inf=39,
                cut_sup=321,
                mask=False,
                corr=False,
                dark=False,
                flat=False,
                binning=None,
            )
        )
    )
    # NOTE(review): this prints the full calibrated arrays of each element.
    outputs | beam.Map(print)

('499191466', [array([[[ 28.78636262,  29.8525242 ,   0.        , ...,  26.65403946,
           3.19848474,  -5.33080789],
        [  0.        ,   6.39696947,  23.45555473, ...,   7.46313105,
           4.26464631,  14.9262621 ],
        [  4.26464631,  -8.52929263,  28.78636262, ...,  54.3742405 ,
           2.13232316,  14.9262621 ],
        ...,
        [ 31.98484735,  57.57272524,  51.17575576, ...,  63.96969471,
          51.17575576,  49.04343261],
        [ 37.31565525,  37.31565525,  21.32323157, ...,  47.97727103,
          10.66161578,  28.78636262],
        [ 27.72020104,   1.06616158,  50.10959419, ..., -24.5217163 ,
          27.72020104,  33.05100893]],

       [[ 27.72020104, -33.05100893,  26.65403946, ...,  23.45555473,
           8.52929263,  19.19090841],
        [ 11.72777736,  10.66161578,  26.65403946, ...,   0.        ,
           3.19848474,  35.18333209],
        [ -5.33080789,  -2.13232316, 107.68231942, ...,  50.10959419,
         -11.72777736, -13.86010052]

In [28]:
# Duplicate the calibrated AIRS array to get a small two-item batch to experiment with.
c = [result[0][1][0], result[0][1][0]]

In [29]:
np.array(c).shape

(2, 5625, 282, 32)

## BUILD THE COMBINEFN LOGIC THAT WILL EXPORT THE RESULT AS A TFRECORD

### UNDERSTAND TFRECORD AND EXAMPLES 

In [30]:
# Write each item of the batch as one TFRecord entry holding the serialized tensor.
dataset = tf.data.Dataset.from_tensor_slices(np.array(c))
ds_bytes = dataset.map(tf.io.serialize_tensor)
with tf.io.TFRecordWriter("../data/output/test.tfrecords") as file_writer:
    for record in ds_bytes:
        # `record` is a scalar string tensor; `.numpy()` yields the raw bytes to write.
        file_writer.write(record.numpy())

In [31]:
def decode_fn(proto):
    """Parse one serialized tensor back into a float64 tensor."""
    return tf.io.parse_tensor(proto, out_type=tf.float64)


# Read back the file written in the previous cell.
tfrecord_file = "../data/output/test.tfrecords"
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)

parsed_dataset = raw_dataset.map(decode_fn)


for parsed_record in parsed_dataset:
    print(parsed_record.numpy())  # Convert tensor back to NumPy array for inspection

[[[ 28.78636262  29.8525242    0.         ...  26.65403946   3.19848474
    -5.33080789]
  [  0.           6.39696947  23.45555473 ...   7.46313105   4.26464631
    14.9262621 ]
  [  4.26464631  -8.52929263  28.78636262 ...  54.3742405    2.13232316
    14.9262621 ]
  ...
  [ 31.98484735  57.57272524  51.17575576 ...  63.96969471  51.17575576
    49.04343261]
  [ 37.31565525  37.31565525  21.32323157 ...  47.97727103  10.66161578
    28.78636262]
  [ 27.72020104   1.06616158  50.10959419 ... -24.5217163   27.72020104
    33.05100893]]

 [[ 27.72020104 -33.05100893  26.65403946 ...  23.45555473   8.52929263
    19.19090841]
  [ 11.72777736  10.66161578  26.65403946 ...   0.           3.19848474
    35.18333209]
  [ -5.33080789  -2.13232316 107.68231942 ...  50.10959419 -11.72777736
   -13.86010052]
  ...
  [  1.06616158  41.58030156  42.64646314 ...  50.10959419  -1.06616158
    39.4479784 ]
  [  7.46313105  21.32323157  14.9262621  ...  45.84494787  54.3742405
    30.91868577]
  [ 25.5

In [36]:
# Tensors used to build a sample `tf.train.Example`.
airs = tf.convert_to_tensor(result[0][1][0])
fgs = tf.convert_to_tensor(result[0][1][1])
# Relative path, consistent with the other cells of this notebook (../data/...),
# instead of a path tied to one user's home directory.
# NOTE(review): the label row is taken for planet 785834, while `result` was
# computed for planet 499191466 — presumably a placeholder target; TODO confirm.
target = tf.convert_to_tensor(
    pd.read_csv("../data/raw/train_labels.csv")
    .set_index("planet_id")
    .loc[785834]
    .values
)

In [38]:
def _tensor_feature(tensor) -> tf.train.Feature:
    """Wrap a tensor, serialized with `tf.io.serialize_tensor`, in a bytes feature.

    Args:
        tensor: Tensor to serialize.

    Returns:
        tf.train.Feature: Feature whose `bytes_list` holds the serialized tensor.
    """
    serialized = tf.io.serialize_tensor(tensor).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized]))


# One bytes feature per tensor, all built the same way.
airs_ft = _tensor_feature(airs)
fgs_ft = _tensor_feature(fgs)
target_ft = _tensor_feature(target)

In [39]:
# Group the three serialized tensors under the names used as record keys.
features = tf.train.Features(
    feature={
        "airs": airs_ft,
        "fgs": fgs_ft,
        "target": target_ft,
    }
)

In [40]:
# Wrap the named features into a single `tf.train.Example` message.
example = tf.train.Example(features=features)

In [None]:
class CombineDataFn(beam.CombineFn):
    def create_accumulator(self):
        return ([], [], [])