In [1]:
import re
from astropy.table import Table, vstack
import os
import lcdata
import numpy as np



In [11]:
"""
def extract_data(kind):
    if kind == 'ps':
        root_directory = './data/Pan-STARRS/'
    elif kind == 'sdss': 
        root_directory = './data/SDSS/'
    elif kind == 'uvot':
        root_directory = './data/SOUSA'
    elif kind == 'csp':
        root_directory = './data/CSP'
    else: 
        print('Kind not defined, please enter one of ps, sdss, uvot, or csp')
    light_curves = []
    # Loop through each directory and its subdirectories
    for dirpath, dirnames, filenames in os.walk(root_directory):
        # Process files based on their names

        # Save object_id from dirpath
        object_id = dirpath.split("/")[-1]


        lc_tot = []
        
        for filename in filenames:
            telescope_type = None
            # Check if the file has a specific name pattern
            if filename.startswith(f'lc2fit') and filename.endswith('.dat'):
                if 'SWOPE' in filename:
                    telescope_type = 'cps'
                elif 'UVOT' in filename or kind == 'uvot':
                    telescope_type = 'uvot'

                # Construct the full path to the file
                file_path = os.path.join(dirpath, filename)

                phase = []
                flux = []
                fluxerr = []
                zp = []
                csp_bands = []

                with open(file_path, 'r') as file:
                    for line in file:
                        band_match = re.search(r'@BAND\s+(.+)', line)
                        if band_match:
                            band_name = band_match.group(1)
                            if kind == 'ps':
                                band = f'ps1::{band_name[-1]}'


                            elif kind == 'sdss':
                                if band_name.split("_")[0].lower() =='sdss':
                                    band = f'sdss::{band_name[-1]}'
                                elif band_name.split("_")[0].lower() == 'swope':
                                    if 'LC' in band_name:
                                        # Handle case like 'SWOPE_V-LC-3014'
                                        formatted_band_name = band_name.replace('SWOPE_', '').replace('-LC', '').replace('-', '')
                                        band = f'csp{formatted_band_name.lower()}'
                                    else:
                                        # Handle case like 'SWOPE_r'
                                        formatted_band_name = band_name.replace('SWOPE_', '').lower()
                                        band = f'csp{formatted_band_name}'

                            elif kind == 'uvot':
                                band = f'uvot::{band_name.split("_")[-1]}'


                            elif kind =='csp':
                                if 'SWOPE' in band_name:
                                    if 'LC' in band_name:
                                        # Handle case like 'SWOPE_V-LC-3014'
                                        formatted_band_name = band_name.replace('SWOPE_', '').replace('-LC', '').replace('-', '')
                                        band = f'csp{formatted_band_name.lower()}'
                                    else:
                                        # Handle case like 'SWOPE_r'
                                        formatted_band_name = band_name.replace('SWOPE_', '').lower()
                                        band = f'csp{formatted_band_name}'
                                    csp_bands.append(band)
                                if 'UVOT' in band_name:
                                    band = f'uvot::{band_name.split("_")[-1]}'
                                

                                
                        match = re.match(r'\s*([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)', line)
                        if match:
                            phase.append(float(match.group(1)))
                            flux.append(float(match.group(2)))
                            fluxerr.append(float(match.group(3)))
                            zp.append(float(match.group(4)))
                    lc = Table({
                        'time': phase,
                        'flux': flux,
                        'fluxerr': fluxerr,
                        'zp': zp,
                        'zpsys': [str('ab').lower()]*len(phase),
                        'band': [str(band).lower()]*len(phase)
                    })
                   
                    lc_tot.append(lc)
        if lc_tot:
            # Ensure 'band' column is always a string
            for lc in lc_tot:
                lc['band'] = lc['band'].astype(str)
                lc['zpsys'] = lc['zpsys'].astype(str)

            # Now try to stack them
            try:
                combined_lc = vstack(lc_tot)
            except Exception as e:
                print(f"Error in stacking tables: {e}")
                continue  # or handle the error as needed
            combined_lc = vstack(lc_tot)
            light_curves.append(combined_lc)

            lightfile_path = os.path.join(dirpath, 'lightfile')
            metadata_dict = {}

            with open(lightfile_path, 'r') as lightfile:
                for line in lightfile:
                    # Split the line and check the length
                    parts = line.strip().split()
                    if len(parts) == 2:
                        key, value = parts

                        if key == 'z_heliocentric':
                            metadata_dict['redshift'] = float(value)
                        elif key != 'z_CMB':
                            metadata_dict['object_id'] = object_id
                            metadata_dict[key] = float(value)
                            
                        
            # Add metadata to the combined table's meta
            combined_lc.meta.update(metadata_dict)

    return light_curves
"""

In [5]:
import os
import re
from astropy.table import Table, vstack

# Helper function to set the root directory
def set_root_directory(kind):
    root_directories = {
        'ps': './data/Pan-STARRS/',
        'sdss': './data/SDSS/',
        'uvot': './data/SOUSA/',
        'csp': './data/CSP/',
        'roman': './data/RomanLCs/'
    }
    return root_directories.get(kind, None)

# Helper function to parse metadata and apply the label map
def parse_metadata(file_path, label_map):
    metadata = {}
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.split()
            key = parts[0]
            values = parts[1:]

            if key in ['Mass', 'z_cmb']:
                continue
            if key == 'z_heliocentric':
                key = 'redshift'
            if key == 'SNTYPE':
                key = 'type'
                sntype_value = int(values[0])
                metadata[key] = label_map.get(sntype_value, "Unknown")
            else:
                metadata[key] = float(values[0]) if len(values) == 1 else [float(v) for v in values]
    return metadata

# Function to create a label map from SNTYPE to the corresponding transient type
def create_label_map():
    return {
        10: "SNIa",
        11: "91bg-like",
        12: "SNIax",
        30: "CCSN",
        40: "SLSN",
        42: "TDE",
        45: "ILOT",
        50: "KNa",
        59: "PISNb"
    }

def determine_band(band_name, kind):
    band = None
    if kind == 'ps':
        band = f'ps1::{band_name[-1]}'
    elif kind == 'sdss':
        if band_name.split("_")[0].lower() == 'sdss':
            band = f'sdss::{band_name[-1]}'
        elif band_name.split("_")[0].lower() == 'swope':
            if 'LC' in band_name:
                # Handle case like 'SWOPE_V-LC-3014'
                formatted_band_name = band_name.replace('SWOPE_', '').replace('-LC', '').replace('-', '')
                band = f'csp{formatted_band_name.lower()}'
            else:
                # Handle case like 'SWOPE_r'
                formatted_band_name = band_name.replace('SWOPE_', '').lower()
                band = f'csp{formatted_band_name}'
    elif kind == 'uvot':
        band = f'uvot::{band_name.split("_")[-1]}'
    elif kind == 'csp':
        if 'SWOPE' in band_name:
            if 'LC' in band_name:
                # Handle case like 'SWOPE_V-LC-3014'
                formatted_band_name = band_name.replace('SWOPE_', '').replace('-LC', '').replace('-', '')
                band = f'csp{formatted_band_name.lower()}'
            else:
                # Handle case like 'SWOPE_r'
                formatted_band_name = band_name.replace('SWOPE_', '').lower()
                band = f'csp{formatted_band_name}'
    elif kind == 'roman':
        band = f'f{band_name[1:]}'
    
    return band

# Main function for extracting light curves and their metadata

def extract_data(kind):
    root_directory = set_root_directory(kind)
    if not root_directory:
        print('Kind not defined, please enter one of ps, sdss, uvot, csp, or roman')
        return []

    label_map = create_label_map()
    light_curves = []
    #dropped_types = ['ILOT', 'PISNb', 'TDE', 'SLSN']

    #Debugging
    #type_exclusions = {type_: 0 for type_ in dropped_types}
    phase_exclusions = 0


    for dirpath, dirnames, filenames in os.walk(root_directory):
        lc_tot = []
        valid_light_curve = True

        for filename in filenames:
            if filename.startswith('lc2fit') and filename.endswith('.dat'):
                file_path = os.path.join(dirpath, filename)
                with open(file_path, 'r') as file:
                    phase, flux, fluxerr, zp = [], [], [], []
                    band = None
                    for line in file:
                        band_match = re.search(r'@BAND\s+(.+)', line)
                        if band_match:
                            band_name = band_match.group(1)
                            band = determine_band(band_name, kind)
                        match = re.match(r'\s*([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)', line)
                        if match:
                            phase.append(float(match.group(1)))
                            flux.append(float(match.group(2)))
                            fluxerr.append(float(match.group(3)))
                            zp.append(float(match.group(4)))

                    # If any band has less than 4 phases, mark the entire light curve as invalid
                    if len(phase) < 4:
                        valid_light_curve = False
                        phase_exclusions += 1
                        break
                    else:
                        lc = Table({
                            'time': phase,
                            'flux': flux,
                            'fluxerr': fluxerr,
                            'zp': zp,
                            'zpsys': ['ab' for _ in phase],
                            'band': [str(band).lower() for _ in phase]
                        })
                        lc_tot.append(lc)

        if valid_light_curve and lc_tot:
            try:
                combined_lc = vstack(lc_tot)
                lightfile_path = os.path.join(dirpath, 'lightfile')
                metadata = parse_metadata(lightfile_path, label_map)
                """if kind == 'roman' and metadata.get('type') in dropped_types:
                    type_exclusions[metadata.get('type')] += 1
                    continue"""
                combined_lc.meta.update(metadata)
                light_curves.append(combined_lc)
                # Check the type for roman telescopes
                
            except Exception as e:
                print(f"Error in stacking tables: {e}")
                continue
    """print(f"Light curves excluded due to type:")
    for type_, count in type_exclusions.items():
        print(f"{type_}: {count}")"""
    print(f"Light curves excluded due to insufficient phases: {phase_exclusions}")

    return light_curves





In [6]:
#Usage Example
roman_lcs = extract_data('roman')

Light curves excluded due to insufficient phases: 254


In [3]:
len(roman_lcs)

65833

In [7]:
import lcdata
import numpy as np
roman_train = lcdata.from_light_curves(roman_lcs[:60000])
roman_test = lcdata.from_light_curves(roman_lcs[60000:])


In [15]:
roman_test.write_hdf5('roman_test.h5', overwrite='True')
roman_train.write_hdf5('roman_train.h5', overwrite='True')

In [10]:
ps_lcs = extract_data(kind = 'ps')
sdss_lcs = extract_data(kind='sdss')
sousa_lcs = extract_data(kind='uvot')
csp_lcs = extract_data(kind = 'csp')

Error in stacking tables: The 'band' columns have incompatible types: ['str128', 'object', 'str128', 'str256', 'object', 'str128', 'str128', 'object', 'object', 'object', 'str128', 'object']
Error in stacking tables: The 'band' columns have incompatible types: ['str128', 'object', 'str128', 'str256', 'object', 'str128', 'str128', 'object', 'object', 'object', 'str128', 'object']
Error in stacking tables: The 'band' columns have incompatible types: ['str128', 'str128', 'str256', 'str128', 'object', 'object', 'object', 'object', 'str128', 'object', 'str128']
Error in stacking tables: The 'band' columns have incompatible types: ['str128', 'object', 'str128', 'str256', 'object', 'str128', 'str128', 'object', 'object', 'object', 'str128', 'object']
Error in stacking tables: The 'band' columns have incompatible types: ['str128', 'str128', 'str128', 'str256', 'str128', 'object', 'object', 'str128', 'object']
Error in stacking tables: The 'band' columns have incompatible types: ['str128', 'obj

In [18]:
roman_lcs = extract_data('roman')

Error in stacking tables: The 'zpsys' columns have incompatible types: ['str64', 'float64', 'str64']


In [19]:
len(roman_lcs)

66238

In [7]:
ps_dataset = lcdata.from_light_curves(ps_lcs)
sdss_dataset = lcdata.from_light_curves(sdss_lcs)
swift_dataset = lcdata.from_light_curves(sousa_lcs)
csp_dataset = lcdata.from_light_curves(csp_lcs)

In [59]:
ps_dataset.write_hdf5('ps_data.h5', overwrite=True)
sdss_dataset.write_hdf5('sdss_data.h5', overwrite=True)
swift_dataset.write_hdf5('swift_data.h5', overwrite=True)
csp_dataset.write_hdf5('csp_data.h5', overwrite=True)