In [20]:
# Change into the data directory (IPython automagic for %cd). Relative paths used later
# ("data/asaasn", "merge_id.csv") are resolved from here, and presumably the local module
# `preprocess_data` is importable from here too — TODO confirm.
# NOTE(review): not idempotent — re-running this cell descends again into data/data.
cd data

/Users/jbloom/Projects/AstroML/data


In [33]:
import os
import glob
import numpy as np
import logging
import joblib
import zipfile
import tarfile
from pathlib import Path
logging.captureWarnings(True)
from scipy.signal import medfilt
import pandas as pd
pd.options.mode.chained_assignment = None
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

%load_ext autoreload
%autoreload 2

import dask
import dask.dataframe as dd
import torch
from torch.utils.data import Dataset, DataLoader

from astropy.wcs import WCS
from astropy.io import fits
from astroML.time_series import search_frequencies, lomb_scargle, MultiTermFit

from astropy import units as u
import astroML

# plotting
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
sns.set_theme(style="ticks")

import pandas as pd
import dask.dataframe as dd
from dask.delayed import delayed

from preprocess_data import clip_outliers
print(f"sns version: {sns.__version__}")
print(f"astroML version: {astroML.__version__}")
print(f"torch version: {torch.__version__}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
sns version: 0.13.0
astroML version: 1.0.2
torch version: 2.1.0


# ASAS-SN Data

## Preparation:

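   - Create a folder to hold the data, `data/asaasn` (note the spelling: the code below expects exactly this path).
   - Download the V-band data (https://drive.google.com/drive/folders/1IAtztpddDeh5XOiuxmLWdLUaT_quXkug).
     Make sure the data folder contains `asassn_catalog_full.csv` and `asassnvarlc_vband_complete.zip`
     (do not unzip the light-curve archive!).
   - Download the g-band data (https://drive.google.com/drive/folders/1gxcIokRsw1eyPmbPZ0-C8blfRGItSOAu).
     Make sure the data folder contains `asassn_variables_x.csv` and the light-curve archive `g_band_lcs-001.tar.gz`
     (do not extract the individual light curves!). The `raw_data_files` configuration below points to
     `g_band_lcs-001.tar`: either decompress the download to that name (gunzip only, no `tar -x`) or change the `lcs` entry.
   - For spectra, the dataset class also reads, when present, `Spectra/lamost_spec.csv` (written in the
     "Spectra merge" section below) and the LAMOST FITS files in `Spectra/v2/` inside the same folder.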


In [148]:
# Root folder of the ASAS-SN inputs (relative to the current working directory).
asassn_dir = Path("data/asaasn")

# Raw inputs per band:
#   tab        - catalogue CSV, one row per variable star
#   lcs        - archive with one light-curve file per star
#   prefix     - directory prefix of the light-curve files inside the archive
#   filekey    - catalogue column that names each light-curve file
#   keyfill    - replacement for spaces in the `filekey` value when building the file name
#   *_headers  - (time, value, error) columns read for flux / magnitude light curves
raw_data_files = {
    "v": dict(
        tab="asassn_catalog_full.csv",
        lcs="asassnvarlc_vband_complete.zip",
        prefix="vardb_files/",
        filekey="asassn_name",
        keyfill="",
        flux_headers=["HJD", "FLUX", "FLUX_ERR"],
        mag_headers=["HJD", "MAG", "MAG_ERR"],
    ),
    "g": dict(
        tab="asassn_variables_x.csv",
        lcs="g_band_lcs-001.tar",
        prefix="g_band_lcs/",
        filekey="ID",
        keyfill="_",
        flux_headers=["HJD", "flux", "flux_err"],
        mag_headers=["HJD", "mag", "mag_err"],
    ),
}

# Numeric catalogue columns returned as per-source metadata (order is preserved).
metadata_cols = {
    "g": [
        "Mean_gmag", "Amplitude", "Period",
        "parallax", "parallax_error", "parallax_over_error",
        "pm", "pmra", "pmra_error", "pmdec", "pmdec_error", "ruwe",
        "phot_g_mean_mag", "e_phot_g_mean_mag",
        "phot_bp_mean_mag", "e_phot_bp_mean_mag",
        "phot_rp_mean_mag", "e_phot_rp_mean_mag", "bp_rp",
        "FUVmag", "e_FUVmag", "NUVmag", "e_NUVmag",
        "W1mag", "W2mag", "W3mag", "W4mag", "Jmag", "Hmag", "Kmag",
        "e_W1mag", "e_W2mag", "e_W3mag", "e_W4mag", "e_Jmag", "e_Hmag", "e_Kmag",
    ],
    "v": [
        "mean_vmag", "amplitude", "period",
        "phot_g_mean_mag", "e_phot_g_mean_mag",
        "lksl_statistic", "rfr_score",
        "phot_bp_mean_mag", "e_phot_bp_mean_mag",
        "phot_rp_mean_mag", "e_phot_rp_mean_mag", "bp_rp",
        "parallax", "parallax_error", "parallax_over_error",
        "pmra", "pmra_error", "pmdec", "pmdec_error",
        "j_mag", "e_j_mag", "h_mag", "e_h_mag", "k_mag", "e_k_mag",
        "w1_mag", "e_w1_mag", "w2_mag", "e_w2_mag",
        "w3_mag", "e_w3_mag", "w4_mag", "e_w4_mag",
        "j_k", "w1_w2", "w3_w4",
        "apass_vmag", "e_apass_vmag", "apass_bmag", "e_apass_bmag",
        "apass_gpmag", "e_apass_gpmag", "apass_rpmag", "e_apass_rpmag",
        "apass_ipmag", "e_apass_ipmag",
        "FUVmag", "e_FUVmag", "NUVmag", "e_NUVmag",
        "pm", "ruwe",
    ],
}

# Identifier / coordinate columns carried along for bookkeeping only (not for learning).
bookkeeping_cols = {
    "v": [
        "id", "source_id", "asassn_name", "other_names",
        "raj2000", "dej2000", "l", "b", "epoch_hjd",
        "gdr2_id", "allwise_id", "apass_dr9_id", "edr3_source_id", "galex_id", "tic_id",
    ],
    "g": [
        "ID", "RAJ2000", "DEJ2000", "l", "b", "EpochHJD",
        "EDR3_source_id", "GALEX_ID", "TIC_ID", "AllWISE_ID",
        "ML_probability", "class_probability",
    ],
}

# Class-label column, catalogue period column and cross-band join key, per band.
target_cols = {
    "g": ["ML_classification"],
    "v": ["variable_type"],
}
period_col = {"g": "Period", "v": "period"}
merge_key = {"g": "EDR3_source_id", "v": "edr3_source_id"}


class ASASSNVarStarDataset(Dataset):
    """Multi-modal ASAS-SN variable-star dataset.

    Each item bundles, for every selected source:

    * ``bookkeeping_data`` - identifier / coordinate columns (not meant for learning),
    * ``classes``          - integer-encoded class labels (one column per band in ``use_bands``),
    * ``lcs``              - one ``(N, 3)`` light curve per band: (HJD, flux|mag, error),
    * ``spectra``          - LAMOST spectra as ``(M, 3)`` arrays: (wavelength, flux, ivar),
    * ``phased``           - (optional) the light curves phase-folded on the chosen period,
    * ``metadata``         - the catalogue columns listed in ``metadata_cols``.

    Relies on the module-level configuration dicts ``raw_data_files``, ``metadata_cols``,
    ``bookkeeping_cols``, ``target_cols``, ``period_col`` and ``merge_key``.
    """

    def __init__(self, data_root, prediction_length, mode='train', use_errors=True,
                 use_bands=["v","g"],merge_type="inner",lc_type="flux",
                 window_length=None,
                 rng=None, return_phased = True, lock_phase=None, clean=True, recalc_period=False,
                 verbose=False, lamost_spec_file="Spectra/lamost_spec.csv",
                 lamost_spec_dir="Spectra/v2",only_sources_with_spectra=True, prime=True,
                 initial_clean_clip=[20,5],only_periodic=True, period_cache = "periods.csv",
                 return_items_as_list=True
                ):
        """
        Parameters
        ----------
        data_root : str or Path
            Directory holding the catalogue CSVs, light-curve archives and spectra.
        prediction_length, use_errors, window_length :
            Stored as attributes; not used by this class itself.
        mode : str
            Accepted for API compatibility; currently unused.
        use_bands : list of str
            Bands to load (``"v"`` and/or ``"g"``). Must be a list.
        merge_type : str
            ``how`` argument of ``DataFrame.merge`` when both bands are combined.
        lc_type : str
            ``"flux"`` or ``"mag"``; selects which light-curve columns are read.
        rng : numpy.random.Generator, optional
            Random state for priming and shuffling; defaults to ``np.random.default_rng(42)``.
        return_phased : bool
            Also return phase-folded light curves.
        lock_phase : str, optional
            Band whose period folds every band. ``None`` (or a band not in ``use_bands``)
            folds each band on its own period.
        clean : bool
            Pass light curves with more than 20 points through ``clip_outliers(clean_only=True)``.
        recalc_period : bool
            Re-derive periods with ``clip_outliers`` and cache them in ``period_cache``.
            A NaN catalogue period is always re-derived.
        verbose : bool
            Print progress messages.
        lamost_spec_file, lamost_spec_dir : str
            Spectra table and FITS directory, relative to ``data_root``.
        only_sources_with_spectra : bool
            Drop sources whose id does not appear in the spectra table.
        prime : bool
            Read one random source's light curves up front so later archive access is fast.
        initial_clean_clip : list
            ``initial_clip`` argument forwarded to ``clip_outliers``.
        only_periodic : bool
            Keep only rows flagged in a ``periodic`` column, when that column exists.
        period_cache : str, optional
            CSV (relative to ``data_root``) caching recalculated periods.
        return_items_as_list : bool
            Make ``__getitem__`` return a tuple instead of a dict.
        """
        if not isinstance(use_bands, list):
            raise Exception("`use_bands` must be a list like ['v', 'g']")

        # Path() so plain strings work with the `/` joins used throughout
        self.data_root = Path(data_root)
        self.prediction_length = prediction_length
        self.use_errors = use_errors
        self.window_length = window_length
        self.return_phased = return_phased
        self.recalc_period = recalc_period
        self.clean = clean
        self.verbose = verbose
        # private copy so the shared mutable default list can never be aliased
        self.use_bands = list(use_bands)
        self.merge_type = merge_type
        self.lc_type = lc_type
        self.lamost_spec_file = lamost_spec_file
        self.lamost_spec_dir = lamost_spec_dir
        self.only_sources_with_spectra = only_sources_with_spectra
        self.lock_phase = lock_phase
        self.initial_clean_clip = initial_clean_clip
        self.only_periodic = only_periodic
        self.period_cache = period_cache
        self.return_items_as_list = return_items_as_list

        # set the random seed if need be
        self.rng = np.random.default_rng(42) if rng is None else rng

        self._check_and_open_data_files()
        self._merge_bands()

        if self.recalc_period and self.period_cache is not None:
            fname = self.data_root / self.period_cache
            try:
                self.period_recalc_df = pd.read_csv(fname)
                if self.verbose:
                    print("Opened period cache file")
            except (OSError, ValueError):
                # missing, empty or unparsable cache: start a fresh one
                self.period_recalc_df = pd.DataFrame(columns=["id", "p", "band"])

        if prime:
            self._prime()

        # shuffle
        self.df = self.df.sample(frac=1, random_state=self.rng)

    def __del__(self):
        """Persist newly recalculated periods to ``period_cache`` and close the archives."""
        try:
            if (getattr(self, "recalc_period", False)
                    and getattr(self, "period_cache", None) is not None
                    and hasattr(self, "period_recalc_df")):
                cache_path = self.data_root / self.period_cache
                # read in what's on disk just in case it changed by another process.
                if cache_path.exists():
                    on_disk = pd.read_csv(cache_path)
                    merged = pd.concat([on_disk, self.period_recalc_df], ignore_index=True).drop_duplicates(
                        keep="last", ignore_index=True)
                else:
                    merged = self.period_recalc_df
                merged.to_csv(cache_path, index=False)
                if self.verbose:
                    print("Wrote period cache file.")
        finally:
            self.close()

    def close(self):
        """Close the open light-curve archives (safe to call more than once)."""
        for archive, _ in getattr(self, "lcs", {}).values():
            archive.close()
        self.lcs = {}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return one or more sources.

        ``idx`` may be an int, a list of ints or a tensor of indices. With
        ``return_items_as_list`` False a dict is returned with keys ``bookkeeping_data``,
        ``classes``, ``lcs``, ``spectra``, ``phased`` (only when ``return_phased``) and
        ``metadata``; otherwise the tuple
        ``(phased or lcs, metadata, spectra, classes, bookkeeping_data)``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if not isinstance(idx, list):
            idx = [idx]

        # explicit copy: period columns are overwritten below and must not touch self.df
        sources = self.df.iloc[idx].copy()

        return_dict = {}

        # bookkeeping - not to be used in learning
        return_dict["bookkeeping_data"] = sources[self.bookkeeping_all].values.tolist()

        # classes: label -> integer code via target_lookup
        targets = sources[self.target_all].values
        for k, v in self.target_lookup.items():
            targets[targets == v] = k
        return_dict["classes"] = targets.astype(np.int32)

        return_dict["lcs"] = self.get_light_curves(sources)
        return_dict["spectra"] = self.get_spectra(sources)

        if self.return_phased:
            phased_all = []
            folded_periods = {band: [] for band in self.use_bands}
            for (_, source), source_lcs in zip(sources.iterrows(), return_dict["lcs"]):
                band_periods, used_cache = self._resolve_periods(source, source_lcs)
                source_id = source[merge_key[self.use_bands[0]]]
                phased = []
                for band_num, band in enumerate(self.use_bands):
                    P_best = self._fold_period(band, band_periods)
                    lc = source_lcs[band_num]
                    t, y, dy = lc[:, 0], lc[:, 1], lc[:, 2]
                    phased.append(np.vstack(((t % P_best) / P_best, y, dy)).T)
                    folded_periods[band].append(P_best)
                    if self.period_cache is not None and self.recalc_period and not used_cache[band]:
                        # append the recalculated period to the cache
                        self.period_recalc_df.loc[len(self.period_recalc_df)] = [source_id, P_best, band]
                phased_all.append(phased)

            # write the periods actually used back into every band's period column
            for band in self.use_bands:
                sources[period_col[band]] = folded_periods[band]
            return_dict["phased"] = phased_all

        # metadata (taken after the period update so it reflects the folding periods)
        return_dict["metadata"] = sources[self.metadata_all].values

        if not self.return_items_as_list:
            return return_dict
        lcs = return_dict["phased"] if self.return_phased else return_dict["lcs"]
        return lcs, return_dict["metadata"], return_dict["spectra"], return_dict["classes"], return_dict["bookkeeping_data"]

    def _other_band(self, band):
        """Partner of ``band`` in a two-band setup, or ``band`` itself when it is alone."""
        others = [b for b in self.use_bands if b != band]
        return others[0] if others else band

    def _locked_band(self):
        """``lock_phase`` if it names a band in use, else ``None`` (each band keeps its own period)."""
        return self.lock_phase if self.lock_phase in self.use_bands else None

    def _fold_period(self, band, band_periods):
        """Period used to fold ``band``: the locked band's period if locking, else its own."""
        lock = self._locked_band()
        return band_periods[lock] if lock is not None else band_periods[band]

    @staticmethod
    def _fix_period_alias(P_best, catalogue_period):
        """Correct a freshly fitted period against the catalogue period.

        A fit more than 1% below the catalogue value is treated as a harmonic and doubled;
        a fit within 0.01 of 1, 2 or 3 falls back to the catalogue period.
        """
        if abs((P_best - catalogue_period) / catalogue_period) > 0.01 and P_best < catalogue_period:
            # we may have found a harmonic. Fix it.
            return P_best * 2
        if any(abs(P_best - d) < 0.01 for d in (1, 2, 3)):
            # keep the original P if we're too close to 1 or 2 or 3 day Periods
            return catalogue_period
        return P_best

    def _resolve_periods(self, source, source_lcs):
        """Work out the period of every band for one source.

        Returns ``(band_periods, used_cache)``: band -> period, and band -> whether the
        value must NOT be appended to the period cache (it came from the cache or from
        the locked band).
        """
        source_id = source[merge_key[self.use_bands[0]]]
        band_periods = {band: source[period_col[band]] for band in self.use_bands}
        used_cache = {band: False for band in self.use_bands}
        lock = self._locked_band()

        for band_num, band in enumerate(self.use_bands):
            other_band = self._other_band(band)
            if lock is not None:
                if band == lock:
                    if np.isnan(band_periods[band]):
                        # switch to the other band
                        band_periods[band] = band_periods[other_band]
                else:
                    # folded on the locked band's period; nothing to compute here
                    band_periods[band] = band_periods[other_band]
                    used_cache[band] = True
                    continue

            # failsafe: a NaN period is always recalculated
            if not (self.recalc_period or np.isnan(band_periods[band])):
                continue

            catalogue_period = band_periods[band]
            # default when no cache hit and too few points to refit
            P_best = catalogue_period
            from_cache = False
            if self.period_cache is not None and self.recalc_period:
                cache = self.period_recalc_df
                hits = cache.loc[(cache["id"] == source_id) & (cache["band"] == band), "p"].values
                if len(hits) > 0:
                    P_best = hits[-1]
                    from_cache = True
                    used_cache[band] = True
                    if self.verbose:
                        print(" got period from cache ")

            if not from_cache:
                lc = source_lcs[band_num]
                t, y, dy = lc[:, 0], lc[:, 1], lc[:, 2]
                if len(t) > 20:
                    t, y, dy, _, P_best, _, _, _ = clip_outliers(
                        t, y, dy, measurements_in_flux_units=self.lc_type == "flux",
                        initial_clip=self.initial_clean_clip, clean_only=False, max_iter=2)
                # cached values were stored after this correction, so only fresh fits get it
                P_best = self._fix_period_alias(P_best, catalogue_period)

            if np.isnan(P_best):
                P_best = band_periods[other_band]
            band_periods[band] = P_best
        return band_periods, used_cache

    def _prime(self):
        """ This takes about 1 minute. After that getting light curves is fast """
        print("Priming tarballs by doing initial scan...", flush=True, end="")
        self.get_light_curves(self.df.sample(random_state=self.rng))
        print("done.", flush=True)

    def _read_band_light_curve(self, band, name):
        """Read ``<prefix><name>.dat`` for ``band`` from its archive; ``None`` if it is absent."""
        cfg = raw_data_files[band]
        member = f"{cfg['prefix']}{name}.dat"
        columns = cfg[f"{self.lc_type}_headers"]
        archive, kind = self.lcs[band]
        if kind == "tar" and band == "g":
            try:
                info = archive.getmember(member)
            except KeyError:
                print(f"Cannot find {cfg['prefix']}{name}")
                return None
            with archive.extractfile(info) as handle:
                return pd.read_csv(handle, sep="\t")[columns].values
        if kind == "zip" and band == "v":
            try:
                handle = archive.open(member)
            except KeyError:
                print(f"Cannot find {cfg['prefix']}{name}")
                return None
            with handle:
                return pd.read_csv(handle, sep=" ", skiprows=1)[columns].values
        raise Exception("Dont know how to get data from such files")

    def get_light_curves(self, rows):
        """Given df row(s), return the light curves as numpy arrays.

        One list per row, holding one ``(N, 3)`` array per band in ``use_bands`` with the
        columns of ``raw_data_files[band][f"{lc_type}_headers"]`` (time, value, error).
        A band whose file is missing yields an empty ``(0, 3)`` array so position ``k``
        always corresponds to ``use_bands[k]``.
        """
        light_curves = []
        for _, row in rows.iterrows():
            row_lc = []
            for band in self.use_bands:
                cfg = raw_data_files[band]
                name = row[cfg["filekey"]].replace(" ", cfg["keyfill"])
                lc = self._read_band_light_curve(band, name)
                if lc is None:
                    # placeholder keeps band positions aligned with use_bands
                    row_lc.append(np.empty((0, 3)))
                    continue
                if self.clean and len(lc[:, 0]) > 20:
                    t, y, yerr = lc[:, 0], lc[:, 1], lc[:, 2]
                    t, y, yerr, _, _, _, _, _ = clip_outliers(
                        t, y, yerr,
                        measurements_in_flux_units=self.lc_type == "flux",
                        initial_clip=self.initial_clean_clip, clean_only=True)
                    lc = np.vstack((t, y, yerr)).T
                row_lc.append(lc)
            light_curves.append(row_lc)
        return light_curves

    def get_spectra(self, rows):
        """Given df row(s), return the spectra as numpy arrays.

        For each row, the list of LAMOST spectra (``(M, 3)`` arrays from ``_readLRSFits``)
        whose ``spec_df["edr3_source_id"]`` matches the row's first-band ``merge_key``
        column and whose FITS file exists on disk. Rows without matches, or every row
        when no spectra table was loaded, get an empty list.
        """
        spec_df = getattr(self, "spec_df", None)
        if spec_df is None:
            return [[] for _ in range(len(rows))]
        key = merge_key[self.use_bands[0]]
        spectra = []
        for _, row in rows.iterrows():
            matches = spec_df[spec_df["edr3_source_id"] == row[key]]
            row_spectra = []
            for spec_filename in matches["spec_filename"]:
                filename = self.data_root / self.lamost_spec_dir / spec_filename
                if os.path.exists(filename):
                    row_spectra.append(self._readLRSFits(filename))
            spectra.append(row_spectra)
        return spectra

    def _resolve_columns(self, columns_by_band, df_cols):
        """Map configured per-band column names onto ``df_cols``.

        A name is used as-is when present, otherwise with the ``_<band>band`` suffix that
        the band merge adds to clashing names; names found in neither form are skipped.
        """
        resolved = []
        for band in self.use_bands:
            for col in columns_by_band[band]:
                if col in df_cols:
                    resolved.append(col)
                elif f"{col}_{band}band" in df_cols:
                    resolved.append(f"{col}_{band}band")
        return resolved

    def _merge_bands(self):
        """Build ``self.df`` and the column lists / label lookup used by ``__getitem__``.

        Sets ``bookkeeping_all``, ``metadata_all``, ``target_all`` and ``target_lookup``
        (integer code -> class label).
        """
        if len(self.use_bands) == 1:
            self.df = self.dfs[self.use_bands[0]]
        elif len(self.use_bands) == 2 and "v" in self.use_bands and "g" in self.use_bands:
            if self.verbose:
                print("Merging bands...", flush=True, end="")
            self.df = self.dfs["v"].merge(self.dfs["g"], how=self.merge_type,
                                          left_on=merge_key["v"], right_on=merge_key["g"],
                                          suffixes=('_vband', '_gband'))
            if self.verbose:
                print("done.", flush=True)
        else:
            raise Exception("Dont know how to merge these bands")

        if self.only_periodic:
            if "periodic" in self.df.columns:
                self.df = self.df[self.df["periodic"]]
            else:
                print("No `periodic` column in the dataframe. Proceeding.")

        df_cols = self.df.columns
        self.bookkeeping_all = self._resolve_columns(bookkeeping_cols, df_cols)
        self.metadata_all = self._resolve_columns(metadata_cols, df_cols)
        self.target_all = [col for band in self.use_bands for col in target_cols[band]]

        # integer code -> label over every label in the first target column of each band
        first_targets = [target_cols[band][0] for band in self.use_bands]
        self.target_lookup = {i: x for i, x in enumerate(np.unique(self.df[first_targets].values.ravel()))}

    def _readLRSFits(self, filename, z_corr=True):
        """Read a LAMOST low-resolution spectrum FITS file.

        Adapted from https://github.com/fandongwei/pylamost. Handles the single-HDU layout
        (primary data rows 0/1 = flux/ivar, log-linear wavelength from COEFF0/COEFF1) and
        the two-HDU layout (table in HDU 1). With ``z_corr`` wavelengths are multiplied by
        ``1 - Z`` using the primary header ``Z`` keyword (0 when absent).

        Returns an ``(M, 3)`` array of (wavelength, flux, ivar).
        Raises ValueError for any other number of HDUs.
        """
        # context manager closes the file in every branch
        with fits.open(filename) as hdulist:
            head = hdulist[0].header
            if len(hdulist) == 1:
                scidata = hdulist[0].data
                coeff0 = head['COEFF0']
                coeff1 = head['COEFF1']
                pixel_num = head['NAXIS1']
                specflux = scidata[0,]
                ivar = scidata[1,]
                wavelength = np.linspace(0, pixel_num - 1, pixel_num)
                wavelength = np.power(10, (coeff0 + wavelength * coeff1))
            elif len(hdulist) == 2:
                scidata = hdulist[1].data
                wavelength = scidata[0][2]
                ivar = scidata[0][1]
                specflux = scidata[0][0]
            else:
                raise ValueError(f"Unsupported LAMOST FITS layout ({len(hdulist)} HDUs): {filename}")

            if z_corr:
                # correct for radial velocity of star
                redshift = head.get('Z', 0.0)
                wavelength = wavelength - redshift * wavelength

            # build the result while the file is still open
            return np.vstack((wavelength, specflux, ivar)).T

    def _check_and_open_data_files(self):
        """Validate inputs and open the data files.

        Loads each band's catalogue into ``self.dfs``, opens the light-curve archives into
        ``self.lcs`` (band -> (archive, "zip"|"tar")) and, when present, the spectra table
        into ``self.spec_df`` (``None`` otherwise), optionally keeping only sources with spectra.
        """
        if len(self.use_bands) == 0:
            raise Exception("Need a least one bandpass to use")
        if not os.path.isdir(self.data_root):
            raise Exception(f"{self.data_root} is not a valid data directory.")

        dfs = {}
        lcs = {}
        # bound up front so close() can release archives even if a later band fails
        self.dfs = dfs
        self.lcs = lcs
        self.spec_df = None
        for band in self.use_bands:
            tab_path = self.data_root / raw_data_files[band]["tab"]
            lcs_path = self.data_root / raw_data_files[band]["lcs"]
            if not os.path.exists(tab_path):
                raise Exception(f"Missing tabular data for {band}.")
            if not os.path.exists(lcs_path):
                raise Exception(f"Missing light curve data for {band}.")

            if self.verbose:
                print(f"Opening {band} data files...", flush=True, end="")

            dfs[band] = pd.read_csv(tab_path)
            if self.verbose:
                print(f" Found {len(dfs[band])} sources. ", end="")

            lcs_file_type = "".join(lcs_path.suffixes)
            if lcs_file_type == ".zip":
                lcs[band] = (zipfile.ZipFile(lcs_path), "zip")
            elif lcs_file_type in [".tar.gz", ".tgz", ".tar"]:
                lcs[band] = (tarfile.open(lcs_path, "r"), "tar")
            else:
                raise Exception(f"Dont know how to open {lcs_path}")

            if self.verbose:
                print("done.", flush=True)

        if self.lamost_spec_file is not None and (self.data_root / self.lamost_spec_file).exists():
            if self.verbose:
                print("Opening spectra csv...", flush=True, end="")
            self.spec_df = pd.read_csv(self.data_root / self.lamost_spec_file)
            if self.verbose:
                print("done.", flush=True)
            if self.only_sources_with_spectra:
                sources_with_spectra = pd.unique(self.spec_df["edr3_source_id"])
                for band in self.use_bands:
                    if self.verbose:
                        print(f"Keeping only {band} band sources with spectra...", flush=True, end="")
                    dfs[band] = dfs[band][dfs[band][merge_key[band]].isin(sources_with_spectra)]
                    if self.verbose:
                        print(f" Left with {len(dfs[band])} sources. ", end="")
                        print("done.", flush=True)
            

In [172]:
# Two-band (V + g) dataset rooted at `asassn_dir`, restricted to periodic sources;
# catalogue periods are not re-fitted (recalc_period=False).
a = ASASSNVarStarDataset(
    asassn_dir,
    10,
    use_bands=["v", "g"],
    only_periodic=True,
    recalc_period=False,
    prime=True,
    verbose=True,
)

Opening v data files... Found 687695 sources. done.
Opening g data files... Found 378861 sources. done.
Opening spectra csv...done.
Keeping only v band sources with spectra... Left with 26412 sources. done.
Keeping only g band sources with spectra... Left with 25965 sources. done.
Merging bands...done.
Priming tarballs by doing initial scan...done.


In [173]:
## what's the structure of what we just made?
# Temporarily switch the dataset to dict output. try/finally restores the previous
# setting even if indexing the dataset raises, so `a` is never left in dict mode.
previous_return_mode = a.return_items_as_list
a.return_items_as_list = False
try:
    for k, v in a[0].items():
        s = v[0]  # entry for the first (and only) requested source
        if isinstance(s, (np.integer, np.floating, int, float)):
            rez = (1,)
        elif isinstance(s, np.ndarray):
            rez = s.shape
        elif isinstance(s, list):
            if len(s) == 0:
                rez = "None"
            elif isinstance(s[0], tuple):
                rez = ", ".join(str(x.shape) for x in s[0])
            elif isinstance(s[0], (str, float, int)):
                rez = f"[{len(s)}]"
            else:
                rez = ", ".join(str(x.shape) for x in s)
        else:
            rez = "?"
        print(k, rez)
finally:
    a.return_items_as_list = previous_return_mode

bookkeeping_data [27]
classes (2,)
lcs (167, 3), (209, 3)
spectra (3879, 3), (3909, 3)
phased (167, 3), (209, 3)
metadata (89,)


In [215]:
# Fetch four sources at once. With return_items_as_list=True the first tuple element is the
# per-source list of (phased, when return_phased) light curves, so its length is 4.
rez = a[[0,1,3,4]]
len(rez[0])

4

In [154]:
# Unpack the tuple form (return_items_as_list=True) for two sources.
lcs, metadata, spectra, targets, bookkeeping = a[[0,1]]

In [171]:
# Presumably intended for a future collate_fn that pads variable-length light curves — TODO.
from torch.nn.utils.rnn import pad_sequence
# NOTE(review): `batch` is only created by the DataLoader loop in a later cell, so this line
# fails on a fresh kernel (Restart & Run All). With the current `collate_fn` (which returns an
# int) `batch[0]` would raise as well. Consider moving or removing this line.
len(batch[0][0])

1

In [224]:
def collate_fn(batch):
    """Placeholder collate function: report how many dataset items were grouped.

    It returns the batch size instead of building tensors. TODO: real collation is still
    needed; the light curves differ in length (see the structure printout above).
    """
    batch_size = len(batch)
    return batch_size

In [225]:
# A dedicated, seeded generator makes the shuffled batch order reproducible across runs.
# Without it, shuffle=True draws from torch's unseeded global RNG.
train_dataloader = DataLoader(a, batch_size=7, shuffle=True, collate_fn=collate_fn,
                              generator=torch.Generator().manual_seed(42))

In [226]:
# Grab only the first mini-batch (binds `idx` and `batch`) for inspection.
for idx, batch in enumerate(train_dataloader):
    break

In [227]:
# With the placeholder collate_fn this is just the batch size.
batch

7

In [204]:
# NOTE(review): stale cell (executed before the current collate_fn was defined). `batch` is
# now an int (see the cell printing 7), so indexing it raises TypeError on Restart & Run All.
batch[3][0][0].shape

torch.Size([1])

## Spectra merge

In [260]:
# NOTE(review): merge_id.csv is written by a cell further down this notebook; on a fresh
# kernel this cell only works if that file already exists on disk.
merge_id_df = pd.read_csv("merge_id.csv").assign(index=lambda df: df.index)
lamost_spec_df = pd.read_csv("data/asaasn/Spectra/27976.csv", sep="|")
# shift by one so inputobjs_input_id equals the row position (iloc) in merge_id_df
lamost_spec_df["inputobjs_input_id"] = lamost_spec_df["inputobjs_input_id"].sub(1)

In [261]:
# Attach the merge_id rows (raj2000, dej2000, edr3_source_id) to each spectrum by matching the
# shifted input id against the stored row position.
lamost_spec_df = pd.merge(lamost_spec_df, merge_id_df, how="left",
                          left_on="inputobjs_input_id", right_on="index")

def make_spec_name(r):
    """Return the LAMOST FITS file name for one row of the spectra table.

    Pattern: ``spec-<combined_lmjd>-<combined_planid>_sp<combined_spid>-<combined_fiberid>.fits.gz``
    with ``combined_spid`` zero-padded to 2 digits and ``combined_fiberid`` to 3.
    """
    lmjd, planid = r["combined_lmjd"], r["combined_planid"]
    spid, fiberid = r["combined_spid"], r["combined_fiberid"]
    return "spec-{}-{}_sp{:02}-{:03}.fits.gz".format(lmjd, planid, spid, fiberid)

# Add the FITS file name of every spectrum and save the table at the path the dataset class
# reads by default (Spectra/lamost_spec.csv under data/asaasn).
# NOTE(review): to_csv keeps the DataFrame index, which appears as an unnamed first column
# in the file (see the `head` output below).
lamost_spec_df['spec_filename'] = lamost_spec_df.apply(make_spec_name, axis=1)
lamost_spec_df.to_csv("data/asaasn/Spectra/lamost_spec.csv")


In [262]:
# Peek at the first lines of the spectra table written above.
!head data/asaasn/Spectra/lamost_spec.csv

,inputobjs_input_id,inputobjs_input_ra,inputobjs_input_dec,inputobjs_dist,combined_obsid,combined_obsdate,combined_lmjd,combined_mjd,combined_planid,combined_spid,combined_fiberid,combined_class,combined_subclass,combined_z,combined_ra,combined_dec,combined_feh,combined_logg,combined_rv,combined_teff,raj2000,dej2000,edr3_source_id,index,spec_filename
0,8,280.33355,44.91092,1.6667680014096964,457215016,2016-05-07,57516,57515,HD184435N434959B01,15,16,STAR,F6,1.44433e-05,280.3341642,44.9110786,-0.221,4.223,4.33,6113.43,280.33355,44.91092,EDR3 2117581746386552576,8,spec-57516-HD184435N434959B01_sp15-016.fits.gz
1,25,352.58827,21.71119,0.2740173706497962,260211063,2014-11-05,56967,56966,EG232823N195308V01,11,63,STAR,G3,1.93467e-05,352.588347,21.711164,0.132,4.169,5.8,5849.0,352.58827,21.71119,EDR3 2827427930745553152,25,spec-56967-EG232823N195308V01_sp11-063.fits.gz
2,43,119.68567,15.35703,1.9002163538280932,864113050,2020-12-08,59192,59191,HD074946N150623V01,13,50,STAR,A6,0.0001126112,119.

In [231]:
# Write raj2000 / dej2000 / edr3_source_id of every source in 5 roughly equal chunks
# (merge_a.0.csv ... merge_a.4.csv). Row positions are split instead of the DataFrame itself:
# np.array_split on a DataFrame goes through the deprecated DataFrame.swapaxes.
merge_input = a.df[["raj2000", "dej2000", "edr3_source_id"]]
for idx, row_positions in enumerate(np.array_split(np.arange(len(merge_input)), 5)):
    merge_input.iloc[row_positions].to_csv(f'merge_a.{idx}.csv', index=False)

In [235]:
# Full (unsplit) list written to merge_id.csv.
# NOTE(review): row order matters. The spectra-merge cells treat `inputobjs_input_id - 1` as a
# row position in this file, so do not regenerate it from a differently ordered or filtered
# `a.df` (the dataset shuffles its rows at construction) once Spectra/27976.csv has presumably
# been produced from it.
merge_id_input = a.df[["raj2000", "dej2000", "edr3_source_id"]]
merge_id_input.to_csv("merge_id.csv", index=False)

In [236]:
# Peek at the first lines of the merge_id.csv written above.
!head merge_id.csv

raj2000,dej2000,edr3_source_id
258.64478,30.29155,EDR3 1333085560085564544
293.81352,23.63857,EDR3 2021010643745700608
310.79987,67.24633,EDR3 2246124001521354368
311.80428,19.52859,EDR3 1814028193934473728
278.76235,-36.44085,EDR3 6733359322482079744
145.10705,-41.66644,EDR3 5425390701060256640
328.50332,45.58807,EDR3 1973594032944176640
128.78229,-46.19009,EDR3 5521855975772074240
280.33355,44.91092,EDR3 2117581746386552576
