# SNe "Simulations" from NIR data

In [1]:
import os
import glob
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import snpy
import scipy
import george

from multiprocessing import Pool

%config InlineBackend.figure_format = 'retina'
# print the SNooPy version in use
print(f'SNooPy version: v{snpy.__version__}')

# name of the SNooPy model passed to `sn.choose_model()` before every fit
SNOOPY_MODEL = 'max_model'

SNooPy version: v2.6.0


### Auxiliary Functions

In [2]:
def combinations(iterable, r):
    '''Builds every `r`-sized combination of the given items.

    The items of each combination keep the order they have in `iterable`,
    and the same selection is never repeated with its items permuted.

    Parameters
    ==========
    iterable: list-like
        Items from which the combinations are drawn.
    r: int
        Number of items in each combination.

    Returns
    =======
    comb_list: list
        List of tuples, one tuple per combination.
    '''

    return [*itertools.combinations(iterable, r)]
            
def mag2flux(mag, zp, mag_err=0.0):
    """Turns magnitudes into fluxes and propagates the magnitude errors.

    Parameters
    ----------
    mag : array
        Magnitudes to convert.
    zp : float or array
        Zero point(s) subtracted from the magnitudes.
    mag_err : array, default ``0.0``
        Uncertainties of the magnitudes.

    Returns
    -------
    flux : array
        ``10**(-0.4*(mag - zp))``.
    flux_err : array
        Absolute value of ``flux*0.4*ln(10)*mag_err``.
    """

    delta_mag = mag - zp
    flux = 10**(-0.4*delta_mag)

    # first-order error propagation; the absolute value keeps it non-negative
    flux_err = np.abs(flux*0.4*np.log(10)*mag_err)

    return flux, flux_err

### "Simulations" functions

In [3]:
def update_sn(sn, update_dict):
    """Overwrites the photometry of a SN band by band.

    Parameters
    ==========
    sn: SNooPy object
        SN whose light curves are overwritten (modified in place).
    update_dict: dict
        New photometry: bands as keys and sub-dictionaries as values. Each
        sub-dictionary must provide the `mjd`, `mag` and `mag_err` keys.

    Returns
    =======
    sn: SNooPy object
        The same SN, with the listed bands updated.
    """

    for band in update_dict:
        new_data = update_dict[band]
        epochs = new_data['mjd']
        mags = new_data['mag']
        mag_errs = new_data['mag_err']

        lc = sn.data[band]
        lc.MJD = epochs
        lc.magnitude = mags
        lc.mag = mags
        lc.e_mag = mag_errs

        # fluxes are derived with the zero point of the band's filter
        fluxes, flux_errs = mag2flux(mags, lc.filter.zp, mag_errs)

        # both the public attributes and the "hidden" `_flux`/`_eflux`
        # ones are set, as the latter are used by other internal functions
        lc.flux = fluxes
        lc._flux = fluxes
        lc.e_flux = flux_errs
        lc._eflux = flux_errs

        # possibly not strictly required, but reset just in case;
        # `mask` is written straight into the instance dictionary
        lc.sids = np.zeros_like(fluxes)
        lc.__dict__['mask'] = np.array([True]*len(fluxes))

    return sn

def filter_coeval_NIRdata(sn):
    """Filters out NIR data of a SNooPy object that are not
    coeval in J/Jrc2 and H bands.

    When both J and Jrc2 are present, an epoch is kept if it has H together
    with J, or H together with Jrc2. The kept H epochs are ordered as
    "epochs coeval with J" followed by "epochs coeval with Jrc2", i.e. the
    same order obtained by concatenating the J and Jrc2 light curves.
    Otherwise, only epochs where every NIR band (and its error) is
    available are kept.

    Parameters
    ==========
    sn: SNooPy object
        SN in a SNooPy object.

    Returns
    =======
    The updated sn.
    """

    # NIR bands labels
    NIR_bands = [band for band in sn.allbands() if band in ['J', 'Jrc2', 'H']]
    NIR_ebands = [f'e_{band}' for band in NIR_bands]  # for errors

    # Find epochs with coeval J/Jrc2 and H bands; entries equal to 99.9
    # are treated as missing (presumably the table's "no data" placeholder
    # -- confirm against SNooPy's `get_mag_table`)
    sn_df = pd.DataFrame(sn.get_mag_table())
    sn_df = sn_df.replace(99.900000, np.nan)

    if 'J' in NIR_bands and 'Jrc2' in NIR_bands:
        JH_df = sn_df[['MJD', 'J', 'e_J', 'H', 'e_H']].dropna()
        Jrc2H_df = sn_df[['MJD', 'Jrc2', 'e_Jrc2', 'H', 'e_H']].dropna()
        # `ignore_index` gives every row a unique label: an epoch observed
        # in J, Jrc2 and H appears in both sub-tables and would otherwise
        # carry a duplicated label, corrupting any label-based selection
        sn_df = pd.concat([JH_df, Jrc2H_df], axis=0, join="outer",
                          ignore_index=True)
    else:
        sn_df = sn_df[['MJD'] + NIR_bands + NIR_ebands].dropna()

    # update SN J/Jrc2 and H bands data
    NIR_dict = {}
    for band, eband in zip(NIR_bands, NIR_ebands):
        # rows where this band was actually observed (row order preserved)
        band_df = sn_df[sn_df[band].notna()]
        NIR_dict[band] = {'mjd': band_df['MJD'].values,
                          'mag': band_df[band].values,
                          'mag_err': band_df[eband].values}

    sn = update_sn(sn, NIR_dict)

    return sn

def extract_combinatories(values, n_epochs):
    """Builds all the combinations of `values` taken `n_epochs` at a time.

    Parameters
    ==========
    values: list-like
        Values from which the combinations are drawn.
    n_epochs: int
        Number of items in each combination.

    Returns
    =======
    list_comb_values: list
        List with one numpy array per combination.
    """

    return list(map(np.array, itertools.combinations(values, n_epochs)))

def extract_combinatories_dict(sn, n_epochs):
    """Collects, per NIR band, every combination of `n_epochs` epochs.

    If the SN has both J and Jrc2, their light curves are concatenated
    (J first) and stored under the `J` key, while the `Jrc2` entry is
    left as `None`.

    Parameters
    ==========
    sn: SNooPy object
        SN in a SNooPy object.
    n_epochs: int
        Number of items per combinatory.

    Returns
    =======
    NIR_dict: dict
        Dictionary with the NIR bands as keys and sub-dictionaries as values.
        Each sub-dictionary has `mjd`, `mag`, `mag_err` and `band` as keys
        and a list of combinatories (the output of
        `extract_combinatories()`) as values.
    """

    # NIR bands available for this SN
    nir_labels = [label for label in sn.allbands()
                  if label in ['J', 'Jrc2', 'H']]
    combos_per_band = dict.fromkeys(nir_labels)
    merge_j_bands = ('J' in nir_labels) and ('Jrc2' in nir_labels)

    for band in nir_labels:
        if merge_j_bands and ('J' in band):
            if band == 'Jrc2':
                # already merged into the J entry
                continue
            j_lc, jrc2_lc = sn.data['J'], sn.data['Jrc2']
            mjd = np.r_[j_lc.MJD, jrc2_lc.MJD]
            mag = np.r_[j_lc.magnitude, jrc2_lc.magnitude]
            mag_err = np.r_[j_lc.e_mag, jrc2_lc.e_mag]
            # remember which band each merged point came from
            bands = ['J']*len(j_lc.MJD) + ['Jrc2']*len(jrc2_lc.MJD)
        else:
            lc = sn.data[band]
            mjd = lc.MJD.copy()
            mag = lc.magnitude.copy()
            mag_err = lc.e_mag.copy()
            bands = [band]*len(lc.MJD)

        columns = (('mjd', mjd), ('mag', mag),
                   ('mag_err', mag_err), ('band', bands))
        combos_per_band[band] = {key: extract_combinatories(column, n_epochs)
                                 for key, column in columns}

    return combos_per_band

def calculate_metrics(sn, n_epochs):
    """Calculates the metrics for the NIR bands of a SN.

    The phases are the H-band epochs relative to `sn.Tmax`.

    Parameters
    ==========
    sn: SNooPy object
        SN in a SNooPy object.
    n_epochs: int
        Number of epochs per NIR band (1, 2 or 3).

    Returns
    =======
    m1, m2, m3: floats
        Metrics:- m1 is the phase of the closest epoch to T.max.
                - m2 is the mean of the phases.
                - m3 is the time span between the earliest and the
                  latest epoch.
        **Note:** m2 is only calculated if n_epochs >= 2, and m3
        if n_epochs == 3. Otherwise, these are NaNs.

    Raises
    ======
    AssertionError
        If `n_epochs` > 3.
    ValueError
        If `n_epochs` < 1.
    """

    assert n_epochs<=3, 'Metrics are only calculated for `n_epochs` <= 3'
    if n_epochs < 1:
        raise ValueError('Metrics need at least one epoch (`n_epochs` >= 1)')

    phases = np.asarray(sn.data['H'].MJD) - sn.Tmax

    m1 = phases[np.argmin(np.abs(phases))]
    m2 = np.mean(phases) if n_epochs >= 2 else np.nan
    # max - min rather than last - first: the epochs are not guaranteed
    # to be sorted in time
    m3 = np.max(phases) - np.min(phases) if n_epochs == 3 else np.nan

    return m1, m2, m3

def get_parameter(sn, parameter, include_sys=False):
    """Looks up a fitted parameter of a SN and its uncertainty.

    The uncertainty is the statistical error, with the systematic
    one added in quadrature when requested.

    **Note:** NaNs are returned if the parameter is not
    among the SN parameters.

    Parameters
    ==========
    sn: SNooPy object
        SN in a SNooPy object.
    parameter: str
        Parameter to extract.
    include_sys: bool
        If `True`, the systematic uncertainty is
        included.

    Returns
    =======
    value: float
        Value of the parameter.
    err: float
        Uncertainty in the parameter.
    """

    if parameter not in sn.parameters.keys():
        return np.nan, np.nan

    value = sn.parameters[parameter]
    stats_err = sn.errors[parameter]
    # systematics are only queried when asked for
    sys_err = sn.systematics()[parameter] if include_sys else 0.0

    return value, np.sqrt(stats_err**2 + sys_err**2)

def extract_lc_params(sn):
    """Gathers the fitted light-curve parameters of a SN.

    Parameters
    ==========
    sn: SNooPy object
        SN in a SNooPy object.

    Returns
    =======
    lc_dict: dict
        For each of `Tmax`, `st`, `gmax`, `rmax`, `Jmax` and `Hmax`, the
        value under the parameter name and its uncertainty under
        `<name>_err` (NaNs when the parameter is missing).
    """

    lc_dict = {}
    for name in ('Tmax', 'st', 'gmax', 'rmax', 'Jmax', 'Hmax'):
        lc_dict[name], lc_dict[name+'_err'] = get_parameter(sn, name)

    return lc_dict
    
def fit_combinatory_OLD(sn, n_epochs):
    """Fit a SN multiple times for all the combinatories given by `n_epochs`.

    Each fit uses whichever of the g, r, J, Jrc2 and H bands the SN has.
    For every combinatory a plot of the fit is saved, and a CSV file with
    the time metrics and light-curve parameters of all the successful fits
    is written at the end, all under `sim_fits/<SN name>/`.

    **Note:** THIS FUNCTION IS NOT BEING USED, BUT I AM SAVING IT JUST IN CASE

    Parameters
    ==========
    sn: SNooPy object
        SN in a SNooPy object. Its NIR data are overwritten in place.
    n_epochs: int
        Number of items per combinatory.
    """

    assert n_epochs<=3, 'Combinatories are only calculated for `n_epochs` <= 3'
    # one list per output column; a row is appended per successful fit
    results_dict = {'comb':[], 'm1':[], 'm2':[], 'm3':[],
                   'Tmax':[], 'Tmax_err':[], 'st':[], 'st_err':[],
                   'gmax':[], 'gmax_err':[], 'rmax':[], 'rmax_err':[],
                   'Jmax':[], 'Jmax_err':[], 'Hmax':[], 'Hmax_err':[]}

    # output directories
    # NOTE(review): check-then-create is not safe if several processes
    # run this at the same time
    if not os.path.isdir('sim_fits'):
        os.mkdir('sim_fits')

    sn_dir = os.path.join('sim_fits', sn.name)
    if not os.path.isdir(sn_dir):
        os.mkdir(sn_dir)

    sn = filter_coeval_NIRdata(sn)
    # extract dictionary with all the combinatories for the NIR bands
    comb_dict = extract_combinatories_dict(sn, n_epochs)

    NIR_bands = list(comb_dict.keys())
    num_comb = len(comb_dict['H']['mjd'])  # number of combinations

    # update SN with each combinatory
    for i in range(num_comb):
        NIR_dict = {band:{'mjd':[], 'mag':[], 'mag_err':[]}
                                        for band in NIR_bands}

        # Jrc2 is already combined with J at this stage,
        # but split again below
        # NOTE(review): this requires both 'J' and 'H' entries in `comb_dict`
        for band in ['J', 'H']:
            band_info = comb_dict[band]

            mjd = band_info['mjd'][i]
            mag = band_info['mag'][i]
            mag_err = band_info['mag_err'][i]
            bands = band_info['band'][i]

            # this sub_bands part is due to having two J bands,
            # i.e., J and Jrc2, which makes things complicated.
            # Here J and Jrc2 are split again as they were combined.
            for j, sub_band in enumerate(bands):
                NIR_dict[sub_band]['mjd'].append(mjd[j])
                NIR_dict[sub_band]['mag'].append(mag[j])
                NIR_dict[sub_band]['mag_err'].append(mag_err[j])

        # turn lists into arrays for SNooPy
        for band in NIR_bands:
            for key in NIR_dict[band].keys():
                NIR_dict[band][key] = np.array(NIR_dict[band][key])

        sn = update_sn(sn, NIR_dict)

        # As there are J and Jrc2 bands, sometimes one of them can be empty
        # and needs to be removed, but then added again below
        # NOTE(review): only the last empty band found is remembered
        empty_lc = None
        for band in sn.allbands():
            if len(sn.data[band].MJD)==0:
                empty_lc = sn.data.pop(band)
                popped_band = band

        sn.filter_order = None  # this is necessary as we remove bands

        try:
            bands2fit = [band for band in sn.allbands()
                                 if band in ['g', 'r', 'J', 'Jrc2', 'H']]
            sn.choose_model(SNOOPY_MODEL)
            sn.fit(bands2fit)

            # save plot with fits
            outfile = os.path.join(sn_dir,
                                   f'{sn.name}_{n_epochs}epochs_comb{i}.jpg')
            sn.plot(outfile=outfile)

            # get the empty light curve back
            if empty_lc is not None:
                sn.data[popped_band] = empty_lc

            # time metrics
            m1, m2, m3 = calculate_metrics(sn, n_epochs)
            results_dict['comb'].append(i)
            results_dict['m1'].append(m1)
            results_dict['m2'].append(m2)
            results_dict['m3'].append(m3)

            # light-curve parameters
            lc_dict = extract_lc_params(sn)
            for key, value in lc_dict.items():
                results_dict[key].append(value)

        except Exception as message:
            # a failed combinatory is reported and skipped; the empty light
            # curve is still put back so the next iteration finds the band
            print(f'Combinatory number {i} failed for {sn.name}: {message}!')
            if empty_lc is not None:
                sn.data[popped_band] = empty_lc

    # save results
    results_df = pd.DataFrame(results_dict)
    results_file = os.path.join(sn_dir, f'{sn.name}_{n_epochs}epochs_results.csv')
    results_df.to_csv(results_file, index=False)

In [22]:
def fit_combinatory(sn, n_epochs):
    """Fits a SN multiple times for all the combinatories given by `n_epochs`.
    The fits are performed with grJ, grH and grJH bands.

    For every combinatory and set of bands, a plot of the fit is saved
    under `sim_fits/<SN name>/`. Once all the combinatories are processed,
    one CSV file per set of bands (suffixes `J`, `H` and `JH`) is written
    in the same directory with the time metrics and light-curve parameters
    of the successful fits. A combinatory that fails is reported and
    skipped.

    Parameters
    ==========
    sn: SNooPy object
        SN in a SNooPy object. Its NIR data are overwritten in place.
    n_epochs: int
        Number of items per combinatory.
    """

    assert n_epochs<=3, 'Combinatories are only calculated for `n_epochs` <= 3'

    # output parameters
    parameters = ['comb', 'm1', 'm2', 'm3',
                  'Tmax', 'Tmax_err', 'st', 'st_err',
                  'gmax', 'gmax_err', 'rmax', 'rmax_err',
                  'Jmax', 'Jmax_err', 'Hmax', 'Hmax_err']
    suffixes = ['J', 'H', 'JH']  # NIR bands combinations/suffixes for outputs

    all_results_dict = {suffix:{param:[] for param in parameters}
                                                for suffix in suffixes}
    # set of bands to fit
    set_bands2fit = [['g', 'r', 'J', 'Jrc2'],  # NIR J only
                     ['g', 'r', 'H'],  # NIR H only
                     ['g', 'r', 'J', 'Jrc2', 'H']]  # NIR J and H

    # output directories; `exist_ok` avoids the check-then-create race
    # when several workers start at the same time
    sn_dir = os.path.join('sim_fits', sn.name)
    os.makedirs(sn_dir, exist_ok=True)

    sn = filter_coeval_NIRdata(sn)
    # extract dictionary with all the combinatories for the NIR bands
    comb_dict = extract_combinatories_dict(sn, n_epochs)

    NIR_bands = list(comb_dict.keys())
    num_comb = len(comb_dict['H']['mjd'])  # number of combinations

    # update SN with each combinatory
    for i in range(num_comb):
        NIR_dict = {band:{'mjd':[], 'mag':[], 'mag_err':[]}
                                        for band in NIR_bands}

        # Jrc2 is already combined with J at this stage (if there is J and
        # Jrc2), in which case its own entry is None
        for band_info in comb_dict.values():
            if band_info is None:
                continue

            mjd = band_info['mjd'][i]
            mag = band_info['mag'][i]
            mag_err = band_info['mag_err'][i]
            bands = band_info['band'][i]

            # J and Jrc2 were combined, so each point is sent back to the
            # band it originally came from
            for j, sub_band in enumerate(bands):
                NIR_dict[sub_band]['mjd'].append(mjd[j])
                NIR_dict[sub_band]['mag'].append(mag[j])
                NIR_dict[sub_band]['mag_err'].append(mag_err[j])

        # turn lists into arrays for SNooPy
        for band in NIR_bands:
            for key in NIR_dict[band]:
                NIR_dict[band][key] = np.array(NIR_dict[band][key])

        sn = update_sn(sn, NIR_dict)

        # As there are J and Jrc2 bands, one of them can end up empty.
        # Empty bands are removed for the fit and put back afterwards.
        popped_lcs = {}
        for band in list(sn.allbands()):
            if len(sn.data[band].MJD)==0:
                popped_lcs[band] = sn.data.pop(band)

        sn.filter_order = None  # this is necessary as we remove bands

        try:
            for bands, suffix in zip(set_bands2fit, suffixes):
                bands2fit = [band for band in sn.allbands()
                                     if band in bands]
                sn.choose_model(SNOOPY_MODEL)
                sn.fit(bands2fit)

                # save plot with fits
                outfile = f'{sn.name}_{n_epochs}epochs_comb{i}_{suffix}.jpg'
                outfile = os.path.join(sn_dir, outfile)
                sn.plot(outfile=outfile)

                # time metrics and light-curve parameters; the row is
                # assembled first so the columns always keep equal lengths
                m1, m2, m3 = calculate_metrics(sn, n_epochs)
                row = {'comb':i, 'm1':m1, 'm2':m2, 'm3':m3}
                row.update(extract_lc_params(sn))
                for key, value in row.items():
                    all_results_dict[suffix][key].append(value)

        except Exception as message:
            print(f'Combinatory number {i} failed for {sn.name}: {message}!')

        finally:
            # get the empty light curves back, whether the fits worked or not
            for band, empty_lc in popped_lcs.items():
                sn.data[band] = empty_lc

    # save results for J, H and JH
    for keys, results_dict in all_results_dict.items():
        results_df = pd.DataFrame(results_dict)
        output_file = f'{sn.name}_{n_epochs}epochs_results_{keys}.csv'
        results_file = os.path.join(sn_dir, output_file)
        results_df.to_csv(results_file, index=False)

### Parallelisation

In [14]:
def run_fits(input_pair):
    """Function to parallelise `fit_combinatory`.

    The SN is imported from its file and fitted; if the fitting fails,
    the error is printed instead of being raised.

    Parameters
    ==========
    input_pair: list
        List with the first item a SN file (str)
        and the second item `n_epochs` (int).

    Example
    =======
    processes = 8
    Pool(processes).map(run_fits,
                        ([sn_file, n_epochs]
                            for sn_file in sn_files))
    """

    # `n_epochs` must come from the input, not from a global variable
    sn_file, n_epochs = input_pair
    sn = snpy.import_lc(sn_file)
    try:
        fit_combinatory(sn, n_epochs)
    except Exception as message:
        print(f'{sn.name} failed with n_epochs={n_epochs}: {message}')

In [None]:
processes = 4
with open('reference_files.txt') as ref_file:
    sn_files = ref_file.read().splitlines()

for n_epochs in range(1, 4):
    print(f'Fitting SNe with n_epochs={n_epochs}')
    %time Pool(processes).map(run_fits, ([sn_file, n_epochs] for sn_file in sn_files))