In [14]:
import sys
sys.path.insert(0, '..')

import os
import datasets
from datasets import load_dataset
import pandas as pd
from astropy.io import fits
import numpy as np
from io import BytesIO
from util.parallelzipfile import ParallelZipFile as ZipFile
import csv
import json

In [7]:
# Load the dataset through `datasets.load_dataset` using the 'MeriDK/AstroM3' identifier.
dataset = load_dataset('MeriDK/AstroM3')

In [11]:
# Number of examples in the 'train' split of the loaded dataset.
len(dataset['train'])

In [None]:
# BibTeX entry for the paper behind the dataset; passed to DatasetInfo(citation=...).
_CITATION = """\
@inproceedings{rizhko2024self,
  title={Self-supervised Multimodal Model for Astronomy},
  author={Rizhko, Mariia and Bloom, Joshua S},
  booktitle={Neurips 2024 Workshop Foundation Models for Science: Progress, Opportunities, and Challenges}
}
"""

# Human-readable summary of the dataset; passed to DatasetInfo(description=...).
_DESCRIPTION = """\
This dataset includes 21,440 objects with time-series photometry, spectra, and metadata. \
It is designed for building and testing next-generation multi-modal self-supervised models for astronomy.
"""

# Project homepage and license string reported in the dataset info.
_HOMEPAGE = "https://github.com/MeriDK/AstroM3/"
_LICENSE = "CC BY 4.0"

# Base URL of the dataset repository on the Hugging Face Hub.
# NOTE(review): open question left by the author — should this point at `/resolve/main`
# rather than `/tree/main`? On the Hub, `/tree/` is the repository browser view, whereas
# direct file access goes through `/resolve/`; confirm before using this for downloads.
_URL = "https://huggingface.co/datasets/MeriDK/AstroM3/tree/main"

# Split CSV locations keyed by configuration name, then by split name.
# Only "full" is populated; the "50", "25" and "10" entries are empty placeholders.
_URLS = {
    "full": {
        "train": "./splits/spectra_and_v_train.csv",
        "val": "./splits/spectra_and_v_val.csv",
        "test": "./splits/spectra_and_v_test.csv"
    },
    "50": {},
    "25": {},
    "10": {}
}


# TODO: Name of the dataset usually matches the script name with CamelCase instead of snake_case
class AstroM3Dataset(datasets.GeneratorBasedBuilder):
    """Hugging Face `datasets` builder for the AstroM3 dataset.

    Every example bundles the photometry, photometry mask, spectrum, metadata
    vector and label of one object. Four configurations are declared: ``full``
    (the default) and the ``50`` / ``25`` / ``10`` subsamples.

    Input files are read from ``self.config.data_dir``; nothing is downloaded
    by this builder.
    """

    VERSION = datasets.Version("1.1.0")

    # Selectable with e.g. datasets.load_dataset(<script>, "50").
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="full", version=VERSION, description="The full dataset"),
        datasets.BuilderConfig(name="50", version=VERSION, description="Subsample of the dataset, contains 50% of all data"),
        datasets.BuilderConfig(name="25", version=VERSION, description="Subsample of the dataset, contains 25% of all data"),
        datasets.BuilderConfig(name="10", version=VERSION, description="Subsample of the dataset, contains 10% of all data"),
    ]

    DEFAULT_CONFIG_NAME = "full"

    def _info(self):
        """Return the dataset description, feature schema, homepage, license and citation."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            # Column names and types of every yielded example.
            # NOTE(review): "spectra" declares the same (200, 7) shape as "photometry",
            # while readLRSFits stacks three columns — confirm the intended shape.
            features=datasets.Features(
                {
                    "photometry": datasets.Array2D(shape=(200, 7), dtype="float32"),
                    "photometry_mask": datasets.Sequence(datasets.Value("float32")),
                    "spectra": datasets.Array2D(shape=(200, 7), dtype="float32"),
                    "metadata": datasets.Sequence(datasets.Value("float32")),
                    "label": datasets.Value("int32")
                }
            ),
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        """Declare the train / validation / test splits.

        Args:
            dl_manager: Download manager supplied by `datasets`. It is not used:
                the files are read from ``self.config.data_dir`` in
                ``_generate_examples``.

        Returns:
            One ``SplitGenerator`` per split. Each passes the split tag
            (``"train"``, ``"val"`` or ``"test"``) to ``_generate_examples``.
        """
        split_tags = (
            (datasets.Split.TRAIN, "train"),
            (datasets.Split.VALIDATION, "val"),
            (datasets.Split.TEST, "test"),
        )
        return [
            # gen_kwargs are unpacked as keyword arguments of _generate_examples.
            datasets.SplitGenerator(name=split_name, gen_kwargs={"split": tag})
            for split_name, tag in split_tags
        ]

    def _generate_examples(self, split):
        """Yield ``(key, example)`` pairs for one split.

        Reads ``spectra_and_v_{split}.csv`` and the V-band light-curve archive
        from ``self.config.data_dir`` and yields one example per CSV row.

        Args:
            split: Split tag used in the CSV file name and in the spectrum path.

        Yields:
            Tuples of the DataFrame index and a dict with the keys declared in
            ``_info``.
        """
        data_dir = self.config.data_dir
        csv_file = os.path.join(data_dir, f'spectra_and_v_{split}.csv')
        vband_file = os.path.join(data_dir, 'asassnvarlc_vband_complete.zip')

        df = pd.read_csv(csv_file)
        reader_v = ZipFile(vband_file)

        # The key only has to be unique per example; the DataFrame index is used.
        for idx, row in df.iterrows():
            # NOTE(review): 'target' is also used as a directory name below while the
            # "label" feature is int32 — confirm the column holds integer class ids.
            label = row['target']
            # NOTE(review): get_vlc, as defined in this file, returns a single array;
            # the two-value unpacking assumes it will return (photometry, mask).
            photometry, photometry_mask = self.get_vlc(reader_v, data_dir, row['file_name'])
            spectra = self.readLRSFits(data_dir, f'{split}/{label}/{row["spec_filename"]}')
            # NOTE(review): a plain datasets.BuilderConfig is not subscriptable and has
            # no "meta_cols"; a custom config class is presumably needed — confirm.
            metadata = np.array(row[self.config["meta_cols"]], dtype=np.float32)

            yield idx, {
                "photometry": photometry,
                "photometry_mask": photometry_mask,
                "spectra": spectra,
                "metadata": metadata,
                "label": label
            }

In [13]:

        
        # NOTE(review): this cell starts mid-method — the loop has no enclosing `def`
        # here and relies on `df`, `reader_v` and `data_dir` being defined by the
        # surrounding code. It looks like a draft of the `_generate_examples` loop — confirm.
        for idx, row in df.iterrows():
            # Gather the light curve, spectrum, metadata vector and label for this row.
            photometry, photometry_mask = self.get_vlc(reader_v, data_dir, row['file_name'])
            spectra = self.readLRSFits(data_dir, row["spec_filename"])
            metadata = np.array(row[self.config["meta_cols"]], dtype=np.float32)
            label = row["label"]

            # The DataFrame index serves as the example key.
            yield idx, {
                "photometry": photometry,
                "photometry_mask": photometry_mask,
                "spectra": spectra,
                "metadata": metadata,
                "label": label
            }

    def get_vlc(self, reader_v, data_dir, file_name):
        """Read one light curve from the archive and return its time and flux columns.

        Args:
            reader_v: Open archive reader exposing ``read(member_path) -> bytes``.
            data_dir: Not used by this method; kept so existing callers keep working.
            file_name: Object name. Spaces are stripped before it is turned into
                the archive member ``vardb_files/<name>.dat``.

        Returns:
            A 2-D array with one row per observation and the columns
            ``HJD``, ``FLUX`` and ``FLUX_ERR``, in that order.
        """
        member = f"vardb_files/{file_name.replace(' ', '')}.dat"
        # Wrap the raw bytes directly; no separate write + seek is needed.
        buffer = BytesIO(reader_v.read(member))

        columns = ['HJD', 'MAG', 'MAG_ERR', 'FLUX', 'FLUX_ERR']
        # The first two lines of each file are skipped (presumably a header — confirm).
        # The separator is a raw string so that `\s` is not treated as a string escape.
        lc = pd.read_csv(buffer, sep=r'\s+', skiprows=2, names=columns,
                         dtype=dict.fromkeys(columns, float))

        return lc[['HJD', 'FLUX', 'FLUX_ERR']].values

    def readLRSFits(self, data_dir, file_name):
        """Load a spectrum from a FITS file as a (wavelength, flux, ivar) table.

        Two file layouts are handled, distinguished by the number of HDUs:

        * one HDU: flux and ivar are rows 0 and 1 of the primary data array, and
          the wavelength grid is rebuilt as ``10 ** (COEFF0 + i * COEFF1)`` for
          pixel index ``i`` in ``0 .. NAXIS1 - 1``;
        * two HDUs: flux, ivar and wavelength are fields 0, 1 and 2 of the first
          row of the second HDU's data.

        Args:
            data_dir: Directory that ``file_name`` is relative to.
            file_name: Path of the FITS file inside ``data_dir``.

        Returns:
            The transpose of the stacked ``(wavelength, specflux, ivar)`` arrays,
            i.e. one row per pixel when the three inputs are 1-D.

        Raises:
            ValueError: If the file contains neither one nor two HDUs.
        """
        path = os.path.join(data_dir, file_name)

        # The context manager closes the file on every path, including the
        # two-HDU branch and the error branch, which previously leaked the handle.
        with fits.open(path) as hdulist:
            len_list = len(hdulist)

            if len_list == 1:
                head = hdulist[0].header
                scidata = hdulist[0].data
                pixel_num = head['NAXIS1']
                pixel_index = np.linspace(0, pixel_num - 1, pixel_num)
                # Wavelengths are stored as a log10-linear grid in the header.
                wavelength = np.power(10, (head['COEFF0'] + pixel_index * head['COEFF1']))
                specflux = scidata[0,]
                ivar = scidata[1,]
            elif len_list == 2:
                scidata = hdulist[1].data
                wavelength = scidata[0][2]
                ivar = scidata[0][1]
                specflux = scidata[0][0]
            else:
                raise ValueError(f'Wrong number of fits files. {len_list} should be 1 or 2')

            # Stack while the file is still open: vstack copies the values, so the
            # result does not depend on the file's buffers after it is closed.
            return np.vstack((wavelength, specflux, ivar)).T
