In [1]:
# imports
import pandas as pd
import os
import warnings
warnings.filterwarnings("ignore")
import os.path
import re
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
directory = os.getcwd()

# from dust_extinction.parameter_averages import F99, O94
# from dust_extinction.averages import GCC09_MWAvg
# from dust_extinction.shapes import P92
import math

# from getdist.mcsamples import MCSamplesFromCobaya
# import getdist.plots as gdplt
# from cobaya.run import run
# from scipy.optimize import minimize
# from collections import OrderedDict as odict

In [6]:
# 15 February 2024 (Biagio): This version of the SED fitting code contains the Pei model written by N. Butler and the SED fitting function
# shown in Zaninoni PhD thesis (https://www.research.unipd.it/handle/11577/3422978)

# Directory holding the per-GRB magnitude files. Set the GRB_MAG_DIR environment
# variable to point elsewhere instead of editing this hard-coded local path.
path = os.environ.get("GRB_MAG_DIR", "D:/naoj-grb/sample/mag-30-05-2023/")
grblist = []
filelist = []
for root, dirs, files in os.walk(path):
    # os.walk yields entries in arbitrary order; sorting sub-directories in place
    # (and the files below) makes grblist/filelist reproducible across machines.
    dirs.sort()
    for file in sorted(files):
        filelist.append(os.path.join(root, file))
        # the GRB identifier is the part of the file name before the first '_'
        names = file.split('_')
        grblist.append(names[0])
# one row per file: GRB identifier and full path
filedf = pd.DataFrame({'grb': grblist, 'path': filelist})

In [8]:
# Sanity check: first GRB identifier parsed from the file names (IndexError if no files were found)
grblist[0]

'000131A'

In [None]:
# Lookup tables used by calibration():
#   filters - filter definitions; index labels are split on '.' into observatory,
#             telescope, instrument, ..., filter, with 'lambda_eff' and
#             'mag_fromVega_toAB' columns
#   adps    - extinction coefficients ('Rv' column), fallback source
#   schafly - extinction coefficients ('3.1' column), indexed by lambda_eff
filters = pd.read_csv("filters.txt", sep="\t", header=0, index_col=0, engine='python', encoding='ISO-8859-1')
adps = pd.read_csv("reddening.txt", sep='\t', header=0, index_col=0, engine='python', encoding='ISO-8859-1')
# '\t+' is a regex separator, which only the python engine supports; request it
# explicitly rather than relying on pandas' (warning-emitting) fallback from the C engine.
schafly = pd.read_csv("SF11_conversions.txt", sep='\t+', header=0, index_col='lambda_eff', engine='python')

In [None]:
def count(str1, str2):
    '''
    Percentage similarity between two strings (order-insensitive character overlap).

    The longer string is scanned; a character scores a hit when it is the first
    occurrence of that character in the longer string AND it appears anywhere in
    the shorter string. The score is hits / len(longer) * 100.
    ------
    INPUT:
    ------
    str1, str2: strings (order does not matter)
    ------
    OUTPUT:
    ------
    match: percentage match; 0 when both strings are empty
    '''
    if len(str1) < len(str2):
        longer, shorter = str2, str1
    else:
        longer, shorter = str1, str2

    if not longer:
        return 0

    hits = sum(
        1
        for pos, ch in enumerate(longer)
        if ch in shorter and longer.index(ch) == pos  # first occurrences only
    )
    return hits / len(longer) * 100

def stripcount(str1, str2):
    '''
    Counts how much percentage two strings match excluding special characters.

    Both strings are reduced to their lower-cased alphanumeric characters, the shorter
    one is padded with '-' to the length of the longer one, and count()-style scoring
    is applied: a character of the (padded) first string scores when it is its first
    occurrence and it appears in the second string. The denominator is therefore the
    length of the longer cleaned string, so the result is symmetric in its arguments.
    ------
    INPUT:
    ------
    str1, str2: strings
    ------
    OUTPUT:
    ------
    match: percentage match excluding special characters; 0 when both cleaned strings are empty
    '''

    str1 = ''.join(i for i in str1 if i.isalnum()).lower() ## removes special characters
    str2 = ''.join(i for i in str2 if i.isalnum()).lower()

    # Pad the shorter string with '-' so both have the same length. '-' is never
    # alphanumeric, so padding can never score; it only lowers the match when the
    # lengths differ. (Previously "-"*diff was used, which is "" for a negative diff,
    # so a shorter str1 was never padded and the score depended on argument order.)
    width = max(len(str1), len(str2))
    str1 = str1.ljust(width, '-')
    str2 = str2.ljust(width, '-')

    c = 0
    for j, ch in enumerate(str1):
        if str2.find(ch) >= 0 and j == str1.find(ch):
            c += 1

    if len(str1) > 0:
        match = c/len(str1)*100
    else:
        match = 0

    return match

def count_hst(str1, str2):
    '''
    Match score for long filter names (presumably HST-style names such as "F555W").
    ------
    INPUT:
    ------
    str1, str2: strings
    ------
    OUTPUT:
    ------
    match: 0 if the sequences of digit groups in the two strings differ; otherwise
           count() applied to the letters of each string (percentage match)
    '''
    # Raw strings: "\d" inside a normal string literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    wave1 = re.findall(r"\d+", str1)
    wave2 = re.findall(r"\d+", str2)
    if wave1 != wave2:
        return 0

    letters1 = "".join(re.findall(r"[a-zA-Z]+", str1))
    letters2 = "".join(re.findall(r"[a-zA-Z]+", str2))
    return count(letters1, letters2)

In [None]:
def calibration(band: str, telescope: str):
    '''
    Identify the filter of a photometric point and return the quantities needed to calibrate it.

    The band label is split into bandpass/filter, the filter name is normalised
    (unfiltered/clear and similar labels -> Rc, aliases, capitalisation), and the
    result is matched against the index labels of the module-level `filters` table.
    Among filters of the same observatory the best telescope match and then the best
    instrument match are chosen; if no observatory matches, generic entries are used
    in the priority average > generic > gcpd > catalog.
    ------
    INPUT:
    ------
    band: filter label as written in the data file, possibly "bandpass<sep>filter"
          with <sep> one of '-', '_', '.', ','
    telescope: "observatory/telescope/instrument" string; anything after a '+' is dropped
    ------
    OUTPUT:
    ------
    lam: effective wavelength ('lambda_eff' of the chosen filter)
    shift_toAB: 'mag_fromVega_toAB' of the chosen filter
    coeff: extinction coefficient looked up at lam rounded to the nearest 10, from
           `schafly` (column '3.1') or, if absent there, `adps` (column 'Rv')
    correctfilter: index label of the chosen row of `filters`
    coeff_source: "Schafly+11" or "APDS+02", whichever table supplied coeff
    ------
    RAISES:
    ------
    KeyError: no filter matches, or lam_round is in neither coefficient table
    ValueError: band splits into more than two parts on a separator, or telescope
                does not split into exactly three '/'-separated fields
    ------
    SIDE EFFECT:
    ------
    (Over)writes the 'match_fil' column of the global `filters` DataFrame.
    '''

    ## Step 1: split the band label into bandpass and filter

    bandpass, filt = '', band
    for sep in ['-', '_', '.', ',']:
        if sep in band:
            bandpass, filt = band.split(sep)
            # bug fix: call lower(); comparing the bound method was never true
            if filt.lower() in ['unfiltered', 'clear']:
                filt = 'clear'
            elif len(filt) > len(bandpass):
                # keep the shorter part as the filter name
                filt, bandpass = band.split(sep)

    ## Step 1b: normalise the filter name

    # Labels calibrated as Rc (Gendre's suggestion). Both sides are casefolded:
    # casefolding only the data side meant entries such as 'CR', 'RM', 'N' or 'P'
    # could never match.
    assume_R = ['-', '—', '|', '\\', '/', '35', '145', 'P-', 'P—', 'P|', 'P\\', 'P/', 'P', 'polarised', 'polarized', 'unfiltered', 'clear', 'CR', 'lum', 'N', 'IR-cut', 'TR-rgb', 'RM']
    if filt.casefold() in {label.casefold() for label in assume_R}:
        filt = 'Rc'

    # Aliases and capitalisation fixes (CV: clear calibrated as V). No alias maps to
    # another alias, so a single lookup is enough.
    filter_aliases = {
        'CV': 'V', 'BJ': 'B', 'VJ': 'V', 'UJ': 'U', 'BM': 'B', 'UM': 'U',
        'KS': 'Ks', 'IC': 'Ic', 'RC': 'Rc', 'UB': 'U',
    }
    filt = filter_aliases.get(filt, filt)

    # primes and asterisks are rewritten as 'p' (e.g. r' -> rp)
    filt = filt.replace("'", "p").replace("*", "p")

    ## Step 1c: split the instrument description

    telescope = telescope.split("+")[0]
    observatory, telescope, instrument = telescope.split('/')

    if instrument == 'CCD':
        instrument = 'None'

    # NOTE(review): presumably '.' is replaced because it separates the fields of
    # the `filters` index labels - confirm against filters.txt
    telescope = telescope.replace(".", ",")

    ## Step 2: check whether the filter exists in the grblc filters

    def _filter_match(grblc_fil):
        # long names use the digit-aware comparison; short names exact, then fuzzy
        if len(filt) >= 5:
            return count_hst(grblc_fil, filt)
        if grblc_fil == filt:
            return 100
        return count(grblc_fil, filt)

    filters['match_fil'] = [_filter_match(str(label).split(".")[-1]) for label in filters.index]

    # .copy(): matched_fil gets new columns below; avoid writing into a view of `filters`
    matched_fil = filters.loc[filters['match_fil'] == 100].copy()
    if len(matched_fil) == 0:
        if len(filt) <= 2:
            # short filter names may fall back to partial matches
            matched_fil = filters.loc[filters['match_fil'] >= 50].copy()
        if len(matched_fil) == 0:
            raise KeyError("No matching filters.")

    ## Step 3: rank candidates of the same observatory by telescope, then instrument

    probablefilters = []

    match_obs_col, match_tel_col, match_ins_col = [], [], []
    for label in matched_fil.index:
        grblc_obs, grblc_tel, grblc_ins, *_ = str(label).split(".")

        if grblc_obs.casefold() != observatory.casefold():
            match_obs_col.append(None)
            match_tel_col.append(None)
            match_ins_col.append(None)
            continue

        match_tel = count(grblc_tel.casefold(), telescope.casefold())
        match_ins = count(grblc_ins.casefold(), instrument.casefold())

        # telescope rank: 1 = exact, 2 = partial (>= 50%), 3 = no telescope match
        if match_tel == 100:
            rank, ins = 1, match_ins
        elif match_tel >= 50:
            rank, ins = 2, match_ins
        else:
            rank, ins = 3, None

        match_obs_col.append('found')
        match_tel_col.append(rank)
        match_ins_col.append(ins)

    matched_fil['match_obs'] = match_obs_col
    matched_fil['match_tel'] = match_tel_col
    matched_fil['match_ins'] = match_ins_col

    matched_obs = matched_fil.loc[matched_fil['match_obs'] == 'found']
    if len(matched_obs) != 0:
        # bug fix: rank 1 is the best telescope match, so keep the MINIMUM rank
        best_tel = matched_obs.loc[matched_obs['match_tel'] == matched_obs['match_tel'].min()]
        # bug fix: highest instrument match first (NaN, i.e. unmatched telescope, last)
        best_tel = best_tel.sort_values(by=['match_ins'], ascending=False,
                                        na_position='last', kind='mergesort')
        probablefilters = list(best_tel.index)

    ## Step 4: in case of no match, resort to generics

    if len(probablefilters) == 0:
        generic_priority = {'average': 1, 'generic': 2, 'gcpd': 3, 'catalog': 4}
        matched_fil['match_status'] = [
            generic_priority.get(str(label).split(".")[0].casefold())
            for label in matched_fil.index
        ]
        matched_gen = matched_fil.sort_values(by=['match_status'], na_position='last', kind='mergesort')
        probablefilters = list(matched_gen.index)

    correctfilter = probablefilters[0]

    def _first(value):
        # duplicate index labels make .loc return a Series; use its first entry
        return value.iloc[0] if isinstance(value, pd.Series) else value

    lam = float(_first(matched_fil.loc[correctfilter, 'lambda_eff']))
    lam_round = round(lam, -1)

    shift_toAB = _first(matched_fil.loc[correctfilter, 'mag_fromVega_toAB'])

    try:
        coeff = schafly.loc[lam_round, '3.1']
        coeff_source = "Schafly+11"
    except KeyError:
        coeff = adps.loc[lam_round, 'Rv']
        coeff_source = "APDS+02"

    return lam, shift_toAB, coeff, correctfilter, coeff_source

In [None]:
def maketable(filename, time_tolerance=0.025, max_mag_err=0.5, min_bands=3):
    '''
    Read one GRB magnitude file and group its points into quasi-simultaneous multi-band epochs.
    ------
    INPUT:
    ------
    filename: tab/whitespace separated file whose columns are, in order, time_sec, mag,
              mag_err, band, system, telescope, extcorr, source, flag (the first line is
              treated as a header and replaced by these names)
    time_tolerance: fractional half-width of the time window around each point
                    (default 0.025, the "2.5% method")
    max_mag_err: points with mag_err above this are dropped (default 0.5)
    min_bands: minimum number of distinct bands a window needs to be kept (default 3)
    ------
    OUTPUT:
    ------
    spectral: DataFrame with columns time_index, time_sec (log10 of the input time),
              mag_corr, mag_err, lambda, band, source, telescope, flag. Each kept window
              contributes one block of rows tagged with time_index = position of its
              central point in the time-sorted, filtered table; windows whose rows are
              identical to an already-kept window are skipped. An empty DataFrame is
              returned when no window qualifies.
    ------
    RAISES:
    ------
    Whatever calibration() raises for a row whose band/telescope cannot be matched.
    '''
    dtype = {
        "time_sec": np.float64,
        "mag": np.float64,
        "mag_err": np.float64,
        "band": str,
        "system": str,
        "telescope": str,
        "extcorr": str,
        "source": str,
        "flag": str
    }
    names = list(dtype.keys())
    # import raw data file
    mag_table = pd.read_csv(
        filename,
        delimiter=r"\t+|\s+",
        names=names,
        dtype=dtype,
        index_col=None,
        header=0,
        engine="python",
        encoding="ISO-8859-1"
    )

    columns = ("time_sec", "mag_corr", "mag_err", "lambda", "band", "source", "telescope", "flag")
    records = {k: [] for k in columns}
    for _, row in mag_table.iterrows():
        # calibration() returns (lam, shift_toAB, coeff, correctfilter, coeff_source);
        # only the effective wavelength is needed here
        lambda_x = calibration(row["band"], row["telescope"])[0]
        A_x = 0  # again assuming correction has ALREADY BEEN APPLIED
        mag_corr = row["mag"] - A_x
        mag_err = row["mag_err"]

        # skip points whose error and magnitude have opposite signs
        if mag_err / mag_corr < 0:
            continue

        records["time_sec"].append(np.log10(row["time_sec"]))
        records["mag_corr"].append(mag_corr)
        records["mag_err"].append(mag_err)
        records["lambda"].append(lambda_x)
        records["band"].append(row["band"])
        records["source"].append(row["source"])
        records["telescope"].append(row["telescope"])
        records["flag"].append(row["flag"])

    table = pd.DataFrame(records).sort_values(by=['time_sec'], ascending=True)

    table = table[table["flag"] == "no"]
    # after the 01 November 2023 NAOJ colloquium, we put a cut on the tail of the magerr distribution
    table = table[table["mag_err"] <= max_mag_err]

    log_times = table['time_sec']
    linear_times = 10 ** log_times  # loop-invariant, computed once
    blocks = []  # collected windows; concatenated once at the end
    seen = []    # raw values of kept windows, used to drop exact duplicates
    for position, t in enumerate(log_times):
        window = table[np.abs(linear_times - 10**t) <= (10**t) * time_tolerance]
        if len(set(window["band"].values)) < min_bands:
            continue

        values = window.values
        if any(len(values) == len(prev) and np.all(values == prev) for prev in seen):
            continue

        seen.append(values)
        block = window.copy()
        block.insert(0, 'time_index', position)
        blocks.append(block)

    return pd.concat(blocks) if blocks else pd.DataFrame()

In [None]:
#maketable("/content/drive//MyDrive/gamma-ray-gang/sample/533GRBS_mag_with_telescope&systeminfo/Ready_for_Ridha(completefiles)/mag-sys-extcorr-19-04-2023/970228A_magAB_extcorr.txt")

# Marquardt fitting

In [None]:
# Pei (1992) dust-extinction curves (implementation by Nathaniel Butler)

from numpy import loadtxt,log,hstack
from scipy.interpolate import interp1d

# Tabulated Pei (1992) curves: one column pair per galaxy model, used by pei_av as
# (l1, x1) Milky Way, (l2, x2) LMC and (l3, x3) SMC.
pfile = 'pei_extinct.txt'
l1, x1, l2, x2, l3, x3 = loadtxt(fname=pfile, unpack=True)

def pei_av(lam,A_V=1.0,gal=3,R_V=0.0):
    """
    Extinction A_lambda from the tabulated Pei (1992) curves.

    lam : wavelength(s) in units of Angstroms (scalar or array)
    A_V : V-band extinction used to scale the curve
    gal : 1 = Milky Way, 2 = LMC, any other value = SMC
    R_V : ratio used in A_V*(1 + x/R_V); 0 selects the model default
          (3.08 MW, 3.16 LMC, 2.93 SMC)

    The reversed table is extended by one node on each side, extrapolated linearly
    in log of the tabulated abscissa from its two outermost points; outside that
    extended range the interpolator returns 0.
    """
    if gal == 1:      # Milky Way
        table_l, table_x, rv_default = l1, x1, 3.08
    elif gal == 2:    # LMC
        table_l, table_x, rv_default = l2, x2, 3.16
    else:             # SMC (gal=3 and any other value)
        table_l, table_x, rv_default = l3, x3, 2.93
    if R_V == 0:
        R_V = rv_default

    # reversed float copies of the selected table
    grid = 1. * table_l[::-1]
    curve = 1. * table_x[::-1]

    # extra node before the first point, on the line through the first two points
    pad_front = 1.e4
    curve_front = (curve[1]-curve[0])/log(grid[1]/grid[0])*log(pad_front/grid[0]) + curve[0]
    # extra node after the last point, on the line through the last two points
    pad_back = 0.1
    curve_back = (curve[-1]-curve[-2])/log(grid[-1]/grid[-2])*log(pad_back/grid[-2]) + curve[-2]

    curve = hstack((curve_front, curve, curve_back))
    grid = hstack((pad_front, grid, pad_back))

    wavelength = 1.e4/grid  # in angstroms

    a_lambda = A_V*( 1+curve/R_V )
    interpolator = interp1d(log(wavelength), a_lambda, bounds_error=False, fill_value=0)

    return interpolator(log(lam))

In [None]:
# Quick check: SMC curve (gal=3, R_V=2.93) with A_V = 1 evaluated at 3800 Angstrom
pei_av(lam=3800, A_V=1.0, gal=3, R_V=2.93)

In [None]:
def beta_calc_marquardt_dustext(grb, z, grblist, filelist):
    
    print("GRB = ", grb, ", at z =", z, ", host galaxy model (Pei 1992)")
    plotnumber = 0

    for ll in range(len(grblist)):
        if grblist[ll] == grb:
            filename = filelist[ll]
            gamma = []
            beta_bf_marquardt = []
            beta_bf_marquardt_err = []
            beta_af_marquardt = []
            beta_af_marquardt_err = []

            dataframegrb = []
            dataframebetas = []
            dataframebetaerrors = []
            dataframeAV = []
            dataframeAVerrors = []
            dataframegalmodel = []
            dataframetimes = []
            dataframefilters = []
            dataframeredchi2 = []
            dataframeintercept = []
            dataframeintercepterrors = []
            dataframersquared = []
            dataframeprobability = []
            dataframeoutliersources = []
            dataframeplotnumber = []

            log_lam_marquardt_outlier = []
            mag_marquardt_outlier = []
            mag_err_marquardt_outlier = []
            filters_list_outlier = []
            timebands_list_outlier = []
            source_list_outlier = []
            telescope_list_outlier = []

            gooddata_source = []
            gooddata_telescope = []
            warningslist = []
            
            # create table of matching times
            df = maketable(filename=filename)

            print(df)
            if len(df) != 0:
                df = df[df["mag_err"] != 0]
                iters = [*set(df["time_index"].values)]
                bands = [*set(df["band"].values)]
                sources = [*set(df["source"].values)]
                finaltimeslist = []
                finalspectralist = []
                
                for i in iters:
                    plotnumber = plotnumber + 1
                    spectral = df.loc[df["time_index"] == i]
                    timespectra = spectral["time_sec"].tolist()
                    timebands = spectral["band"].tolist()
                    sources = spectral["source"].tolist()
                    timespectra = list(set(timespectra))
                    timebands = list(set(timebands))
                    finaltimeslist.append(timespectra)
                    finalspectralist.append(timebands)
                    plotnumbersublist = []
                    mag = []
                    mag_err = []
                    log_lam = []
                    filters_list = []
                    timebands_list = []
                    sources_list = []
                    telescope_list = []
                    for band in bands:
                        band_slice = spectral.loc[spectral["band"] == band]
                        log_lams = np.log10(band_slice["lambda"].values)
                        mags = band_slice["mag_corr"].values
                        magerrs = band_slice["mag_err"].values
                        sourcelabels = band_slice["source"].values
                        telescopelabels = band_slice["telescope"].values
                        filters_list.extend(band_slice["band"].values)
                        timebands_list.extend(band_slice["time_sec"].values)
                        for jj in range(len(magerrs)):
                            if magerrs[jj] != 0.0:
                                log_lam.append(log_lams[jj])
                                mag.append(mags[jj])
                                mag_err.append(magerrs[jj])
                                sources_list.append(sourcelabels[jj])
                                telescope_list.append(telescopelabels[jj])
                    
                    if len(set(filters_list)) > 3:

                        X = log_lam
                        y = mag
                        weights_list = [1 / err for err in mag_err] 

                        # Weights are 1/err since lmfit makes residuals=weights*(data-model) and then minimizes the square of the residuals https://stackoverflow.com/questions/58251958/take-errors-on-data-into-account-when-using-lmfit
                        # Look also in https://lmfit.github.io/lmfit-py/model.html
                        
                        def SEDmodelSMC(beta,intercept,AV,x):
                            lam = 10**x
                            return intercept - 2.5*beta*x -2.5*(pei_av(lam,A_V=AV,gal=3,R_V=2.93)-pei_av(lam/(1+z),A_V=AV,gal=3,R_V=2.93))
                        
                        modelSMC = Model(SEDmodelSMC, independent_vars=['x'])
                        parsSMC = Parameters()
                        
                        parsSMC.add('intercept', value=40, min=0, max=100)
                        parsSMC.add('beta', value=0.8, min=-10, max=10)
                        parsSMC.add('AV', value=1, min=0, max=10)

                        print("Fitting SMC model...")
                        resultsSMC = modelSMC.fit(y, parsSMC, x=X, weights=weights_list)                       
                        chisquareSMC = resultsSMC.chisqr
                        rsquaredSMC = resultsSMC.rsquared
                        redchisquareSMC = resultsSMC.redchi
                        slopeSMC = resultsSMC.params.get("beta").value
                        slopeerrSMC = resultsSMC.params.get("beta").stderr
                        avfitSMC = resultsSMC.params.get("AV").value
                        avfiterrSMC = resultsSMC.params.get("AV").stderr
                        interceptfitSMC = resultsSMC.params.get("intercept").value
                        interceptfiterrSMC = resultsSMC.params.get("intercept").stderr

                        nuSMC=len(y)
                        xxSMC=redchisquareSMC*nuSMC
                        probSMC=(2**(-nuSMC/2)/math.gamma(nuSMC/2))*scipy.integrate.quad(lambda x: math.exp(-x/2)*x**(-1+(nuSMC/2)),xxSMC,np.inf)[0]

                        def SEDmodelLMC(beta,intercept,AV,x):
                            lam = 10**x
                            return intercept - 2.5*beta*x -2.5*(pei_av(lam,A_V=AV,gal=2,R_V=3.16)-pei_av(lam/(1+z),A_V=AV,gal=2,R_V=3.16))
                        
                        modelLMC = Model(SEDmodelLMC, independent_vars=['x'])
                        parsLMC = Parameters()
                        
                        parsLMC.add('intercept', value=40, min=0, max=100)
                        parsLMC.add('beta', value=0.8, min=-10, max=10)
                        parsLMC.add('AV', value=1, min=0, max=10)

                        print("Fitting LMC model...")
                        resultsLMC = modelLMC.fit(y, parsLMC, x=X, weights=weights_list)                       
                        chisquareLMC = resultsLMC.chisqr
                        rsquaredLMC = resultsLMC.rsquared
                        redchisquareLMC = resultsLMC.redchi
                        slopeLMC = resultsLMC.params.get("beta").value
                        slopeerrLMC = resultsLMC.params.get("beta").stderr
                        avfitLMC = resultsSMC.params.get("AV").value
                        avfiterrLMC = resultsLMC.params.get("AV").stderr
                        interceptfitLMC = resultsLMC.params.get("intercept").value
                        interceptfiterrLMC = resultsLMC.params.get("intercept").stderr

                        nuLMC=len(y)
                        xxLMC=redchisquareLMC*nuLMC
                        probLMC=(2**(-nuLMC/2)/math.gamma(nuLMC/2))*scipy.integrate.quad(lambda x: math.exp(-x/2)*x**(-1+(nuLMC/2)),xxLMC,np.inf)[0]

                        def SEDmodelMW(beta,intercept,AV,x):
                            lam = 10**x
                            return intercept - 2.5*beta*x -2.5*(pei_av(lam,A_V=AV,gal=1,R_V=3.08)-pei_av(lam/(1+z),A_V=AV,gal=1,R_V=3.08))
                        
                        modelMW = Model(SEDmodelMW, independent_vars=['x'])
                        parsMW = Parameters()
                        
                        parsMW.add('intercept', value=40, min=0, max=100)
                        parsMW.add('beta', value=0.8, min=-10, max=10)
                        parsMW.add('AV', value=1, min=0, max=10)

                        print("Fitting MW model...")
                        resultsMW = modelMW.fit(y, parsMW, x=X, weights=weights_list)                       
                        chisquareMW = resultsMW.chisqr
                        rsquaredMW = resultsMW.rsquared
                        redchisquareMW = resultsMW.redchi
                        slopeMW = resultsMW.params.get("beta").value
                        slopeerrMW = resultsMW.params.get("beta").stderr
                        avfitMW = resultsMW.params.get("AV").value
                        avfiterrMW = resultsMW.params.get("AV").stderr
                        interceptfitMW = resultsMW.params.get("intercept").value
                        interceptfiterrMW = resultsMW.params.get("intercept").stderr

                        nuMW=len(y)
                        xxMW=redchisquareMW*nuMW
                        probMW=(2**(-nuMW/2)/math.gamma(nuMW/2))*scipy.integrate.quad(lambda x: math.exp(-x/2)*x**(-1+(nuMW/2)),xxMW,np.inf)[0]

                        probhostmodels=[1-probMW, 1-probLMC, 1-probSMC]

                        pivot=np.where(probhostmodels == np.min(np.abs(probhostmodels)))[0][0]

                        if pivot==0:
                            hostmodel="MW"
                            galnumber=1
                            RVnumber=3.08
                      
                            chisquare = chisquareMW
                            rsquared = rsquaredMW
                            redchisquare = redchisquareMW
                            slope = slopeMW
                            slopeerr = slopeerrMW
                            avfit = avfitMW
                            avfiterr = avfiterrMW
                            interceptfit = interceptfitMW
                            interceptfiterr = interceptfiterrMW
    
                            prob = probMW

                        if pivot==1:
                            hostmodel="LMC"
                            galnumber=2
                            RVnumber=3.16
                      
                            chisquare = chisquareLMC
                            rsquared = rsquaredLMC
                            redchisquare = redchisquareLMC
                            slope = slopeLMC
                            slopeerr = slopeerrLMC
                            avfit = avfitLMC
                            avfiterr = avfiterrLMC
                            interceptfit = interceptfitLMC
                            interceptfiterr = interceptfiterrLMC
    
                            prob = probMW
                        
                        if pivot==2:
                            hostmodel="SMC"
                            galnumber=3
                            RVnumber=2.93
                      
                            chisquare = chisquareSMC
                            rsquared = rsquaredSMC
                            redchisquare = redchisquareSMC
                            slope = slopeSMC
                            slopeerr = slopeerrSMC
                            avfit = avfitSMC
                            avfiterr = avfiterrSMC
                            interceptfit = interceptfitSMC
                            interceptfiterr = interceptfiterrSMC
    
                            prob = probSMC

                        print("The selected model according to its probability is ",hostmodel)
                        
                        reject_plot = False

                        if slope < 0:
                            reject_plot = False

                        print(plotnumber,prob)
                        
                        # 26feb commented
                        if abs(slopeerr) > abs(slope):
                           reject_plot = True

                        # if abs(avfiterr) > abs(avfit):
                        #     reject_plot = True
                        
                        if avfit < 0:
                            reject_plot = True

                        # if abs(interceptfiterr) > abs(interceptfit):
                        #     reject_plot = True

                        y_marquardtforoutliers = [interceptfit - 2.5*slope*xi -2.5*(pei_av(10**xi,A_V=avfit,gal=galnumber,R_V=RVnumber)-pei_av((10**xi)/(1+z),A_V=avfit,gal=galnumber,R_V=RVnumber))  for xi in X]
                        
                        for i in range(len(mag)):

                            if abs(y[i] - y_marquardtforoutliers[i]) <= 3 * mag_err[i]:
                                gooddata_source.append(sources_list[i])
                                gooddata_telescope.append(telescope_list[i])
                        
                        for i in range(len(mag)):
                            
                            if abs(y[i] - y_marquardtforoutliers[i]) > 3 * mag_err[i]:
                                log_lam_marquardt_outlier.append(log_lam[i])
                                mag_marquardt_outlier.append(mag[i])
                                mag_err_marquardt_outlier.append(mag_err[i])
                                filters_list_outlier.append(filters_list[i])
                                timebands_list_outlier.append(timebands_list[i])
                                source_list_outlier.append(sources_list[i])
                                telescope_list_outlier.append(telescope_list[i])

                                print("Single outlier telescope")
                                print(telescope_list[i])
                                
                                if (telescope_list[i] in gooddata_telescope) and (np.char.isnumeric(sources_list[i])):
                                    warningstring=str(grb)+" "+str(telescope_list[i])+" "+str(sources_list[i])
                                    print("WARNING ",warningstring)

                                    with open(str(grb)+'_warnings.txt', 'a') as warnfile:
                                        warnfile.write(warningstring)
                                        warnfile.write('\n')
                                        warnfile.close()
                                      
                        if len(log_lam_marquardt_outlier)==0:
                            outliersprint="NaN"
                            outlierlabelfilename=""
                        else:
                            outliersprint="_".join(set(source_list_outlier))
                            outlierlabelfilename="_outliers"
                        
                        for band_now in set(bands):
                            if reject_plot:
                                break
                            
                            main_points_log_lam = log_lam[timebands_list == band_now]
                            
                            try:
                                #If there is only 1 point or No points corresponding to band_now, then proceed to the next band
                                if len(main_points_log_lam) <= 1:
                                    continue
                                else:
                                    pass
                                
                            except:
                                continue
                                         
                            main_points_mag = mag[timebands_list == band_now]
                            main_points_mag_err = mag_err[timebands_list == band_now]
                            sorted_indices = np.argsort(main_points_mag)
                            point_set_sorted = main_points_mag[sorted_indices]
                            point_set_err_sorted = main_points_mag_err[sorted_indices]
                            
                            for iii in range(len(point_set_sorted) - 1):
                                point_x = point_set_sorted[iii]
                                point_x_plus_1 = point_set_sorted[iii+1]
                                point_x_err = point_set_err_sorted[iii]
                                point_x_plus_1_err = point_set_err_sorted[iii+1]
                                
                                if abs(point_x - point_x_plus_1) > abs(point_x_err) + abs(point_x_plus_1_err):
                                    # Reject plot and go to next fit
                                    reject_plot = True
                                    break
                            
                            if reject_plot:
                                # Even if one band has points separated beyond 1 - sigma, this fit IS REJECTED
                                break
                            
                        
                        if reject_plot:
                            continue       
                        
                        betaoldmarquardt = slope
                        betaoldmarquardterr = slopeerr if slopeerr is not None else np.inf
                        gamma.append(grb)
                        beta_bf_marquardt.append(np.negative(betaoldmarquardt))
                        beta_bf_marquardt_err.append(betaoldmarquardterr)

                        print(
                            "Marquardt Levenberg Beta:",
                            slope,
                            "Marquardt Levenberg Beta error:",
                            betaoldmarquardterr,
                        )

                        # y_marquardtforoutliers = [interceptfit - 2.5*slope*xi -2.5*(pei_av(10**xi,A_V=avfit,gal=galnumber,R_V=RVnumber)-pei_av((10**xi)/(1+z),A_V=avfit,gal=galnumber,R_V=RVnumber))  for xi in X]
                        
                        # for i in range(len(mag)):

                        #     if abs(y[i] - y_marquardtforoutliers[i]) <= 3 * mag_err[i]:
                        #         gooddata_source.append(sources_list[i])
                        #         gooddata_telescope.append(telescope_list[i])
                        
                        # for i in range(len(mag)):
                            
                        #     if abs(y[i] - y_marquardtforoutliers[i]) > 3 * mag_err[i]:
                        #         log_lam_marquardt_outlier.append(log_lam[i])
                        #         mag_marquardt_outlier.append(mag[i])
                        #         mag_err_marquardt_outlier.append(mag_err[i])
                        #         filters_list_outlier.append(filters_list[i])
                        #         timebands_list_outlier.append(timebands_list[i])
                        #         source_list_outlier.append(sources_list[i])
                        #         telescope_list_outlier.append(telescope_list[i])

                        #         print("Single outlier telescope")
                        #         print(telescope_list[i])
                                
                        #         if (telescope_list[i] in gooddata_telescope) and (np.char.isnumeric(sources_list[i])):
                        #             print("WARNING ",telescope_list[i]," ",sources_list[i])
                        
                        # if len(log_lam_marquardt_outlier)==0:
                        #     outliersprint="NaN"
                        #     outlierlabelfilename=""
                        # else:
                        #     outliersprint="_".join(set(source_list_outlier))
                        #     outlierlabelfilename="_outliers"
                        
                        X = np.sort(X)
                        y_marquardt = [interceptfit - 2.5*slope*xi -2.5*(pei_av(10**xi,A_V=avfit,gal=galnumber,R_V=RVnumber)-pei_av((10**xi)/(1+z),A_V=avfit,gal=galnumber,R_V=RVnumber))  for xi in X]                        
                        fig, ax = plt.subplots()
                        ax.invert_yaxis()
                        plt.plot(X, y_marquardt, label=r"$\beta_{opt}:\,$"+str(round(slope,3))+"+/-"+str(round(betaoldmarquardterr,3))+'\n'+
                                "$A_{V}:\,$"+str(round(avfit,3))+"+/-"+str(round(avfiterr,3))+'\n'+str(hostmodel))
                        plt.errorbar(log_lam, mag, yerr=mag_err, fmt="o")
                        plotnumbersublist.append(plotnumber)
                        plt.xlabel(r'$\log_{10}\lambda (\AA)$')
                        plt.ylabel("Magnitude")
                        
                        if min(timebands_list) == max(timebands_list):
                            plt.title(
                                'Bands=' + str(set(filters_list))
                                + '\n' + 'log10(t)_min-max='
                                + str(round(min(timebands_list), 3))
                                + '-' + str(round(max(timebands_list), 3))
                                + '\n' + r'$\chi^{2}=$' + str(round(chisquare, 2))
                                + ',Red. ' + r'$\chi^{2}=$' + str(round(redchisquare, 2))
                                + ',Prob. ' + str(round(prob, 3)))
                            plt.legend()
                            fig.tight_layout()
                            fig.text(0.1, 0.1, "GRB " + str(grb), fontsize=18, fontweight='bold', horizontalalignment='left', verticalalignment='bottom', transform=ax.transAxes)
                            plt.savefig(
                                str(grb) + "_beta"
                                + str(round(slope, 3))
                                + "_betaerr"
                                + str(round(betaoldmarquardterr, 3))
                                + "_plotn"
                                + str(plotnumber)
                                + outlierlabelfilename
                                + ".png")
                            plt.clf()

                            dataframegrb.append(str(grb))
                            dataframebetas.append(str(round(slope,3)))
                            dataframebetaerrors.append(str(round(betaoldmarquardterr, 3)))
                            dataframeAV.append(str(round(avfit,3)))
                            dataframeAVerrors.append(str(round(avfiterr,3)))
                            dataframegalmodel.append(str(hostmodel))
                            dataframeintercept.append(str(round(interceptfit,3)))
                            dataframeintercepterrors.append(str(round(interceptfiterr,3)))
                            dataframetimes.append(str(round(min(timebands_list), 4)))
                            dataframefilters.append(str(set(filters_list)))
                            dataframeredchi2.append(str(round(redchisquare, 2)))
                            dataframersquared.append(rsquared)
                            dataframeprobability.append(prob)
                            dataframeoutliersources.append(outliersprint)
                            dataframeplotnumber.append(str(plotnumber))
    
                            # print(str(grb)+" "+str(round(slope,3))+" "+str(round(betaoldmarquardterr, 3))+" "+str(round(avfit,3))+" "+str(round(avfiterr,3))
                            #      +" "+str(hostmodel)+" "+str(round(interceptfit,3))+" "+str(round(interceptfiterr,3))+" "+str(round(min(timebands_list), 4))
                            #      +" "+str(set(filters_list))+" "+str(round(redchisquare, 2))+" "+rsquared+" "+prob+" "+outliersprint+" "+str(plotnumber))
                        
                        else:
                            plt.title(
                                'Bands=' + str(set(filters_list))
                                + '\n' + 'log10(t)_min-max='
                                + str(round(min(timebands_list), 3))
                                + '-' + str(round(max(timebands_list), 3))
                                + '\n' + r'$\chi^{2}=$' + str(round(chisquare, 2))
                                + ',Red. ' + r'$\chi^{2}=$' + str(round(redchisquare, 2))
                                + ',Prob. ' + str(round(prob, 3)))
                            plt.legend()
                            fig.tight_layout()
                            fig.text(0.1, 0.1, "GRB " + str(grb), fontsize=18, fontweight='bold', horizontalalignment='left', verticalalignment='bottom', transform=ax.transAxes)
                            plt.savefig(
                                str(grb) + "_beta"
                                + str(round(slope, 3))
                                + "_betaerr"
                                + str(round(betaoldmarquardterr, 3))
                                + "_plotn"
                                + str(plotnumber)
                                + outlierlabelfilename
                                + ".png")
                            plt.clf()

                            dataframegrb.append(str(grb))
                            dataframebetas.append(str(round(slope,3)))
                            dataframebetaerrors.append(str(round(betaoldmarquardterr, 3)))
                            dataframeAV.append(str(round(avfit,3)))
                            dataframeAVerrors.append(str(round(avfiterr,3)))
                            dataframegalmodel.append(str(hostmodel))
                            dataframeintercept.append(str(round(interceptfit,3)))
                            dataframeintercepterrors.append(str(round(interceptfiterr,3)))
                            dataframetimes.append(str(round(min(timebands_list), 4))+"-"+str(round(max(timebands_list), 4)))
                            dataframefilters.append(str(set(filters_list)))
                            dataframeredchi2.append(str(round(redchisquare, 2)))
                            dataframersquared.append(str(rsquared))
                            dataframeprobability.append(prob)
                            dataframeoutliersources.append(outliersprint)
                            dataframeplotnumber.append(str(plotnumber))
    
                            # print(str(grb)+" "+str(round(slope,3))+" "+str(round(betaoldmarquardterr, 3))+" "+str(round(avfit,3))+" "+str(round(avfiterr,3))
                            #      +" "+str(hostmodel)+" "+str(round(interceptfit,3))+" "+str(round(interceptfiterr,3))+" "+str(round(min(timebands_list), 4))
                            #      +" "+str(set(filters_list))+" "+str(round(redchisquare, 2))+" "+rsquared+" "+prob+" "+outliersprint+" "+str(plotnumber))
                    
                    else:
                        print("Not 4 different bands at least for the fitting")

            else:
                print("No matching GRB found in the dataset.")


            dict = {"GRB": dataframegrb, "beta": dataframebetas, "beta_err": dataframebetaerrors, "AV": dataframeAV, "AV_err": dataframeAVerrors,
                    "Gal.model": dataframegalmodel, "intercept": dataframeintercept, "intercept_err": dataframeintercepterrors, "log10t": dataframetimes, 
                        "filters": dataframefilters, "Red.Chi2": dataframeredchi2, "R-squared": dataframersquared, "probability": dataframeprobability, "outliers": dataframeoutliersources, "plotnumb": dataframeplotnumber}
                            
            dfexp = pd.DataFrame(dict, columns=["GRB","beta","beta_err","AV","AV_err","Gal.model","intercept","intercept_err","log10t","filters","Red.Chi2","R-squared","probability","outliers","plotnumb"]) # "redchi2"

            dfexp.to_csv(str(grb)+'-results.csv')
    
    return

In [None]:
# Batch of GRBs to fit, read from a comma-separated file with a header row.
# Columns are accessed by name below, so the header must contain "GRB" (burst
# label) and "z" (redshift); column order does not matter.
# Substitute the filename below with the one of your batch.
lit_betas_all_info = pd.read_csv(
    "betas-to-compute_10.txt",
    sep=',',
    header=0,
    index_col=None,
)
GRBs, redshift = lit_betas_all_info['GRB'].values, lit_betas_all_info['z'].values

In [None]:
def singleGRBsSED(grb):
    """Run the SED fit of a single GRB using the redshift listed in the batch table.

    The redshift is looked up in the module-level ``lit_betas_all_info`` table:
    rows are matched on the ``"GRB"`` column and the value is taken from the
    ``"z"`` column. It is then passed, together with the module-level
    ``grblist`` and ``filelist``, to ``beta_calc_marquardt_dustext``.

    Parameters
    ----------
    grb : label of the GRB, compared against the ``"GRB"`` column of the table.

    Returns
    -------
    Whatever ``beta_calc_marquardt_dustext`` returns.

    Raises
    ------
    ValueError
        If ``grb`` does not appear in the batch table.
    """
    z_matches = lit_betas_all_info.loc[lit_betas_all_info["GRB"] == grb, "z"]
    if z_matches.empty:
        raise ValueError("GRB " + str(grb) + " not found in the batch table")
    # float() on a Series is deprecated (pandas >= 2.1) and fails outright when
    # the GRB is listed more than once, so take the first matching redshift.
    z = float(z_matches.iloc[0])
    return beta_calc_marquardt_dustext(grb, z, grblist, filelist)

In [None]:
# Test the function on a single GRB
# singleGRBsSED("080928A")

In [None]:
# Loop for multiple GRBs: run singleGRBsSED on every GRB of the batch in
# parallel, one joblib job per GRB, using all available CPU cores.

import multiprocessing
from joblib import Parallel, delayed
from tqdm import tqdm

num_cores = multiprocessing.cpu_count()
# Progress bar over the GRBs. It must be the iterable that Parallel actually
# consumes; a tqdm object that is never iterated stays at 0% forever.
# Note: it advances as jobs are dispatched, which tracks completion only
# approximately, since joblib pre-dispatches a few jobs ahead.
inputs = tqdm(GRBs)

if __name__ == "__main__":
    processed_list = Parallel(n_jobs=num_cores)(delayed(singleGRBsSED)(i) for i in inputs)