# Notes

I am cutting out/excluding:

- parts of isochrones that cannot be fitted with a polynomial
- cluster stars that do not fall within the isochrone colour range
- cluster stars that do not fall within the colour range of the isochrones used when interpolating between isochrones
- empty bins when fitting functions to the IMFs
- Kroupa intervals that contain one or no bins
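A minimal sketch of how the last two exclusions could be implemented. It assumes the Kroupa (2001) break masses of 0.08 and 0.5 M$_{\odot}$ and log-spaced bins; `kroupa_slopes` and `KROUPA_EDGES` are hypothetical names, not functions from this notebook, and the actual fitting function used here may differ. With equal-width bins in $\log m$ and $dN/dm \propto m^{-\alpha}$, the counts per bin scale as $m^{1-\alpha}$, so the fitted slope is $1-\alpha$.

```python
import numpy as np

# Hypothetical: mass intervals of the Kroupa (2001) IMF in solar masses (breaks at 0.08 and 0.5)
KROUPA_EDGES = [0.01, 0.08, 0.5, np.inf]

def kroupa_slopes(masses, nbins=15):
    """Hypothetical helper: histograms the masses in log-spaced bins and fits a straight
    line to log10(counts) vs log10(mass) separately in each Kroupa interval. Empty bins
    are left out, and intervals with fewer than two non-empty bins are skipped."""
    masses = np.asarray(masses, dtype=float)
    edges = np.logspace(np.log10(masses.min()), np.log10(masses.max()), nbins + 1)
    counts, _ = np.histogram(masses, bins=edges)
    centres = np.sqrt(edges[:-1] * edges[1:])  # geometric bin centres for log-spaced bins

    slopes = {}
    for low, high in zip(KROUPA_EDGES[:-1], KROUPA_EDGES[1:]):
        usable = (centres >= low) & (centres < high) & (counts > 0)  # drop empty bins
        if np.count_nonzero(usable) < 2:  # one or no bins: no slope can be fitted
            continue
        slope, _ = np.polyfit(np.log10(centres[usable]), np.log10(counts[usable]), 1)
        slopes[(low, high)] = slope
    return slopes

# Made-up test sample: dN/dm ~ m^-2.3 between 0.5 and 5 Msun, drawn with deterministic quantiles
alpha, m1, m2 = 2.3, 0.5, 5.0
u = (np.arange(200_000) + 0.5) / 200_000
masses = (m1**(1 - alpha) + u * (m2**(1 - alpha) - m1**(1 - alpha)))**(1 / (1 - alpha))
slopes = kroupa_slopes(masses)  # only the (0.5, inf) interval, slope close to 1 - alpha
```

Returning a dictionary keyed by interval makes it explicit which intervals were skipped.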


Other notes:

- The slopes are most reliable for $0.5$ to $1.0 \,$M$_{\odot}$, because both isochrone models are most complete in that mass range

**MIST**

Challenges with MIST:

- Contains the giant branch, which makes a polynomial fit to the isochrone impossible $\rightarrow$ I am cutting off the isochrone at magnitude 0
- Does not reach red enough colours (G$_{BP}$ - G$_{RP}$ only goes slightly past 3, whereas Baraffe goes past 5)
- The missing red colours are because MIST does not contain masses below $0.1 \,$M$_{\odot}$, whereas Baraffe goes down to $0.01 \,$M$_{\odot}$
- This colour limit removes many of the low-mass stars, which means that the age is underestimated and the IMF covers a smaller mass range

Benefits with MIST:

- Reaches higher stellar masses
- Reaches lower (brighter) magnitudes

**Baraffe**

Challenges with Baraffe:

- Contains only a few isochrones. This problem can be partly mitigated by interpolating between isochrones
- Only has low magnitudes
- Has a short colour range
- Only contains stellar masses below $1.4 \,$M$_{\odot}$
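- Because of the upper mass limit, all stars more massive than $1.4 \,$M$_{\odot}$ are assigned an interpolated mass of $1.4 \,$M$_{\odot}$ (`np.interp` returns the end-point value outside the isochrone range), which is misleading and distorts the IMF slope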


Benefits with Baraffe:

- So far gives similar ages to the pre-determined ones
- Contains low-mass stars
- Contains very red stars

# Imports

In [117]:
# General imports
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize as scopt
from scipy import stats
import sympy as sp


# Astropy imports
import astropy.units as u
import astropy.constants as c
from astropy.table import QTable
from astropy.coordinates import Distance
from astropy.coordinates import SkyCoord
from astropy.io import votable

# Special import
from sklearn.cluster import KMeans


np.set_printoptions(threshold=5000)

In [119]:
plt.rcParams.update({'xtick.labelsize':15, 'ytick.labelsize':15, 'axes.titlesize':18, 
                     'axes.grid':True, 'axes.labelsize':14, 'legend.fontsize':14})

# Functions

## Importing the data

The distance modulus is calculated with astropy's `Distance.distmod`; with the distance $d$ in parsecs, it is given by

\begin{equation}
    \mu = 5 \log_{10} d - 5 = m - M \iff M = m - \mu.
\end{equation}
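The same conversion can be written out with plain NumPy, which makes the chain parallax $\rightarrow$ distance $\rightarrow$ $\mu$ $\rightarrow$ $M$ explicit. This is a sketch with a hypothetical helper `absolute_magnitude`, assuming parallaxes in mas so that $d = 1000/\varpi$ pc. Non-positive parallaxes are turned into NaN here; recent astropy versions instead raise a `ValueError` for negative parallaxes in `Distance(parallax=...)` unless `allow_negative=True` is passed (worth checking against the installed version).

```python
import numpy as np

def absolute_magnitude(m_apparent, parallax_mas):
    """Hypothetical helper: M = m - mu, with mu = 5 log10(d) - 5 and d = 1000 / parallax[mas] pc.
    Non-positive parallaxes have no meaningful distance and give NaN."""
    m_apparent = np.asarray(m_apparent, dtype=float)
    parallax_mas = np.asarray(parallax_mas, dtype=float)
    # Distance in parsecs, only defined for positive parallaxes
    with np.errstate(divide='ignore', invalid='ignore'):
        d_pc = np.where(parallax_mas > 0, 1000.0 / parallax_mas, np.nan)
    mu = 5 * np.log10(d_pc) - 5  # distance modulus
    return m_apparent - mu

# Parallax 10 mas -> d = 100 pc -> mu = 5*log10(100) - 5 = 5, so M = 12 - 5 = 7
M_single = absolute_magnitude(12.0, 10.0)
M_both = absolute_magnitude([12.0, 12.0], [10.0, -0.5])  # second star has no valid distance
```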

In [2]:
def cluster_import(file_name):
    """
    Imports the cluster data file and adds a distance column, an absolute magnitude column
    and a G_BP-G_RP column
    ---------------------------------------------------------------------------------------
    Parameters:
    
    file_name: str
        The name of the data file that contains the data
        
    ---------------------------------------------------------------------------------------
    Output:
    
    data: astropy QTable
        Returns an astropy QTable with all the data including the units
    """
    
    data = QTable.read(f'Data/Clusters/{file_name}.txt', 
                         names=['GaiaID', 'gal_long', 'gal_lat', 'parallax', 'e_parallax', 'RA_pm', 
                                'e_RA_pm', 'DE_pm', 'e_DE_pm', 'M_apparent', 'G_BP', 'G_RP', 
                                'Flag', 'Cluster_id', 'RA_icrs', 'DE_icrs'], 
                         units = [u.m/u.m, u.deg, u.deg, u.mas, u.mas, u.mas/u.yr, u.mas/u.yr, 
                                  u.mas/u.yr, u.mas/u.yr, u.mag, u.mag, u.mag, u.m/u.m, u.m/u.m, 
                                  u.deg, u.deg], delimiter=' ', format = 'ascii')
    
    # Adds a distance column and converts the parallaxes to distance in pc
    data['dist'] = Distance(parallax=data['parallax'])


    # Calculates and adds an absolute magnitude column calculated from the photometric mean magnitude 
    # and the distance modulus, described above
    data['M_V'] = data['M_apparent'] - data['dist'].distmod

    # Calculating and adding G_bp - G_rp
    data['bp_rp'] = data['G_BP'] - data['G_RP']

    
    return data

## Importing and separating isochrones from one file

In [3]:
def separate_isochrones(model, metallicity, survey):
    """
    Parameters:
    -----------
    model: str
        Name of isochrone model
        
    metallicity: str
        String with metallicity value
        
    survey: str
        Which survey was used, which affects the magnitude units
    """
    all_iso = [] # List to gather all data for all isochrones in
    models = ['Baraffe', 'MIST', 'Marigo'] # List of models
    # List of filenames
    file_names = [f'Baraffe_{survey}.txt', 
                  f'MIST_v1.2_feh_{metallicity}_vvcrit0.00.txt']
    
    list_position = models.index(model) # Finding list position for model
    file_name = file_names[list_position] # Extracting filename
    
    # Opens file to read content
    with open(f'Data/Isochrones/{model}/{file_name}', "r") as iso_file: 
        age = [] # List to put isochrone data divided by age (final output of function)
        lines = [] # List to put all the line information into
        
        # Loops over every line and stores the data for all lines (all isochrones) in a list
        for line in iso_file: # loop through each line
            # Skipping lines without data: comments/headers ('#', '!', ' !') and empty lines
            if line.startswith(('#', '!', ' !')) or len(line)<2:
                continue
            val = line.split() # split each line and store the values in a list
            lines.append(val) # Store list of line values in another list
        
        
        if model=='Baraffe':
            iso_data_list = [] # List to fill with isochrone data points
            iso_ages_list = []
            # Looping over all lines and sorting them into isochrones
            for l_vals in lines: # l_vals = list of line/column values
                # A line with a single value holds an age and marks the start of a new isochrone
                if len(l_vals)==1:
                    if iso_data_list: # the previous isochrone is finished
                        # Converting to array with shape: (n_points, parameters)
                        all_iso.append(np.array(iso_data_list).astype(float))
                    
                    # Starting a new isochrone
                    iso_data_list = []
                    iso_ages_list.append(l_vals[0]) # age value of the new isochrone
                    
                else: # same age: keep filling the current isochrone
                    iso_data_list.append(l_vals)
            
            # Appending the last isochrone (not followed by a new age line). Its last data line
            # is already in iso_data_list, so it is not appended a second time
            if iso_data_list:
                all_iso.append(np.array(iso_data_list).astype(float))
            
            iso_ages_array = np.array(iso_ages_list, dtype=float) # 1D array of ages
            
              
        if model=='MIST':
            iso_data_list = [] # List to fill with isochrone data points
            iso_ages_list = []
            iso_age = lines[0][1] # Age of the first isochrone (log10(age/yr))
            iso_ages_list.append(iso_age)
            # Looping over all lines and sorting them into isochrones. The first line is
            # handled inside the loop, so it is not added twice
            for l_vals in lines: # l_vals = list of line/column values
                line_age = l_vals[1]
                
                # A new age means that the previous isochrone is finished
                if line_age!=iso_age: 
                    # Converting to array with shape: (n_points, parameters)
                    all_iso.append(np.array(iso_data_list).astype(float))
                    
                    # Starting a new isochrone with the new age
                    iso_age = line_age
                    iso_ages_list.append(iso_age)
                    iso_data_list = []
                
                iso_data_list.append(l_vals) # Adding the line to the current isochrone
            
            # Appending the last isochrone
            all_iso.append(np.array(iso_data_list).astype(float))
            
            iso_ages_arr = np.array(iso_ages_list, dtype=float) # log10(age/yr), MIST unit
            iso_ages_array = 10**(iso_ages_arr)*10**(-9) # Converts to Gyr
    
    # Only keeping the relevant information from the isochrones
    for i, iso in enumerate(all_iso):
        
        if survey=='gaia':
            new_iso_array = np.empty((len(iso), 4)) # Contains: time, mass, M_V, bp_rp
            
            if model=='Baraffe':
                # Adding the age to the new array
                new_iso_array[:, 0] = iso_ages_array[i]
                new_iso_array[:, 1] = iso[:, 0] # Adding the mass to the new array 
                new_iso_array[:, 2] = iso[:, 22] # Adding M_V
                    
                bp = iso[:, 23]
                rp = iso[:, 24]
                new_iso_array[:, 3] = bp-rp # Adding bp_rp
                
            elif model=='MIST':
                # Adding the age to the new array and converting to Gyr
                new_iso_array[:, 0] = 10**(iso[:, 1])*10**(-9) 
                new_iso_array[:, 1] = iso[:, 3] # Adding the mass: MIST column 3 is star_mass (current mass), column 2 is initial_mass
                new_iso_array[:, 2] = iso[:, 30] # Adding M_V
                
                bp = iso[:, 31]
                rp = iso[:, 32]
                new_iso_array[:, 3] = bp-rp # Adding bp_rp to array
                #new_iso_array = new_iso_array[new_iso_array[:, 3].argsort()] # sort after colour
                magnitude_mask = 0<new_iso_array[:, 2]
                new_iso_array = new_iso_array[magnitude_mask]
                
                
        elif survey=='2mass':
            new_iso_array = np.empty((len(iso), 5)) # Contains: time, mass, J, H, K
            
            if model=='Baraffe':
                # Adding the age to the new array
                new_iso_array[:, 0] = iso_ages_array[i]
                new_iso_array[:, 1] = iso[:, 0] # Adding the mass to the new array 
                new_iso_array[:, 2] = iso[:, 6] # Adding J
                new_iso_array[:, 3] = iso[:, 7] # Adding H
                new_iso_array[:, 4] = iso[:, 8] # Adding K
                
            elif model=='MIST':
                # Adding the age to the new array and converting to Gyr
                new_iso_array[:, 0] = 10**(iso[:, 1])*10**(-9) 
                new_iso_array[:, 1] = iso[:, 3] # Adding the mass: MIST column 3 is star_mass (current mass), column 2 is initial_mass
                new_iso_array[:, 2] = iso[:, 14] # Adding J
                new_iso_array[:, 3] = iso[:, 15] # Adding H
                new_iso_array[:, 4] = iso[:, 16] # Adding K
                
                #magnitude_mask = 0<new_iso_array[:, 2]
                #new_iso_array = new_iso_array[magnitude_mask]
                
        all_iso[i] = new_iso_array
    
    return all_iso, iso_ages_array
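The MIST branch above starts a new isochrone whenever the value in the age column changes. A compact sketch of the same grouping uses `itertools.groupby` on the already-parsed rows; `split_by_age` is a hypothetical helper, and it only applies to MIST-style files where every row carries its own age (in the Baraffe files the age sits on a separate line).

```python
from itertools import groupby
import numpy as np

def split_by_age(rows, age_col=1):
    """Hypothetical helper: splits parsed rows (lists of strings) into one float array per
    isochrone by grouping consecutive rows that share the same value in the age column."""
    isochrones, ages = [], []
    for age, group in groupby(rows, key=lambda row: row[age_col]):
        isochrones.append(np.array(list(group)).astype(float))
        ages.append(float(age))
    return isochrones, np.array(ages)

# Made-up MIST-like rows: EEP, log10(age/yr), mass
rows = [['1', '6.0', '0.10'], ['2', '6.0', '0.20'], ['1', '6.5', '0.10']]
isochrones, log_ages = split_by_age(rows)
ages_gyr = 10**log_ages * 1e-9  # same conversion to Gyr as in separate_isochrones
```

Because `groupby` only groups consecutive rows, it matches the notebook's logic of closing an isochrone when the age changes, and each row ends up in exactly one isochrone.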
            
            
            

## Importing the isochrone data and extracting fitting parameters

In [4]:
def isochrone_params(isos_data, model, survey):
    """
    Imports the isochrone data file
    ---------------------------------------------------------------------------------------
    Parameters:
    
    file_names: lst
        List of the file names of the data files that contains each isochrone's data
        
    model: str or list
        The name of the model isochrone to use
        
    ---------------------------------------------------------------------------------------
    Output:
    
    data: array
        Returns an array with all the isochrone data for plotting with [G, G_BP-G_RP]
    """
    fit_params = np.empty((6, len(isos_data)))
    
    for j, isochrone in enumerate(isos_data):
        # Only fitting the part of the isochrone with 0 < magnitude < 16
        # (M_V for 'gaia', J for '2mass'; both are stored in column 2)
        magn_mask = (isochrone[:, 2]<16) & (isochrone[:, 2]>0)
        
        if survey=='gaia':
            iso_x = isochrone[:, 3][magn_mask] # Colour, G_BP - G_RP
            iso_y = isochrone[:, 2][magn_mask] # Magnitude
            
        elif survey=='2mass':
            iso_x = isochrone[:, 2][magn_mask] - isochrone[:, 4][magn_mask] # Colour, J-K
            iso_y = isochrone[:, 2][magn_mask] # J
            
        else:
            raise ValueError(f"Unknown survey '{survey}', expected 'gaia' or '2mass'")
        
        # np.polyfit returns the coefficients highest power first: k5, k4, k3, k2, k1, c
        fit_params[:, j] = np.polyfit(iso_x, iso_y, 5)
    
    return fit_params
    

## Plotting isochrones with data

In [5]:
def plotting_iso_and_data(all_data, all_iso_data, cluster_names, model, survey, data_alpha=0.5,
                          CMD_or_mass='CMD'): # n_plots=1
    """
    Plots the data and/or isochrones
    ----------------------------------
    Parameters:
    
    all_data: list of arrays
        List of data for different clusters
        
    all_iso_data: list of arrays
        List of data from one or several different isochrones
        
    cluster_names: list
        List or string with names of the used clusters
        
    n_plots: int
        The number of plots to plot
        
    Output:
    --------
    Plot
    """
    # 16 different colours to plot
    colours = ['b', 'r', 'deepskyblue', 'firebrick', 'cyan', 'crimson', 'teal', 'peru', 'orange',  
               'blueviolet', 'silver', 'purple', 'dimgray', 'magenta', 'g', 'lawngreen']
    
    rev_colours = colours[::-1]
    
    if all_data is None: # Only plotting isochrones
        fig, ax = plt.subplots(figsize=(5, 4))
        if survey=='gaia':
            if CMD_or_mass == 'CMD':
                for i, iso in enumerate(all_iso_data):
                    ax.plot(iso[:, 3], iso[:, 2], color=colours[i], label=f't = {iso[0, 0]:.1e} Gyr')
            
                ax.set_xlabel(r'G$_{BP}$ - G$_{RP}$ [mag]')
                ax.set_ylabel(r'M$_{V}$ [mag]')
                ax.set_title(f'CMD for {len(all_iso_data)} {model} isochrones')
                ax.invert_yaxis()
            
            elif CMD_or_mass == 'mass':
                for i, iso in enumerate(all_iso_data):
                    ax.plot(iso[:, 3], iso[:, 1], color=colours[i], label=f't = {iso[0, 0]:.1e} Gyr')
            
                ax.set_xlabel(r'G$_{BP}$ - G$_{RP}$ [mag]')
                ax.set_ylabel(r'Mass [M$_{\odot}$]')
                ax.set_title(f'Mass vs colour for {len(all_iso_data)} {model} isochrones')
                
        if survey=='2mass':
            if CMD_or_mass == 'CMD':
                for i, iso in enumerate(all_iso_data):
                    iso_colour = iso[:, 2] - iso[:, 4]
                    ax.plot(iso_colour, iso[:, 2], color=colours[i], label=f't = {iso[0, 0]:.1e} Gyr')
            
                ax.set_xlabel(r'J - K [mag]')
                ax.set_ylabel(r'J [mag]')
                ax.set_title(f'CMD for {len(all_iso_data)} {model} isochrones')
                ax.invert_yaxis()
            
            elif CMD_or_mass == 'mass':
                for i, iso in enumerate(all_iso_data):
                    iso_colour = iso[:, 2] - iso[:, 4]
                    ax.plot(iso_colour, iso[:, 1], color=colours[i], label=f't = {iso[0, 0]:.1e} Gyr')
            
                ax.set_xlabel(r'J - K [mag]')
                ax.set_ylabel(r'Mass [M$_{\odot}$]')
                ax.set_title(f'Mass vs colour for {len(all_iso_data)} {model} isochrones')
            
        ax.legend()
            
        ax.grid(True)
            
        plt.show()
        
    elif all_iso_data is None: # Only plotting data
        fig, ax = plt.subplots(figsize=(5, 4))
        for i, data in enumerate(all_data):
            ax.scatter(data['bp_rp'].value, data['M_V'].value, c=rev_colours[i], 
                       alpha=data_alpha, s=10, label=f'Cluster {cluster_names[i]}')
        
        ax.invert_yaxis()
        ax.set_xlabel(r'G$_{BP}$ - G$_{RP}$ [mag]')
        ax.set_ylabel(r'M$_{V}$ [mag]')
        ax.set_title(f'Cluster data for {len(all_data)} clusters')
        ax.legend()
            
        ax.grid(True)
            
        plt.show()
        
        
    else: # Plotting both
        fig, ax = plt.subplots(figsize=(5, 4))
            
        # Plotting data
        for i, data in enumerate(all_data):
            ax.scatter(data['bp_rp'].value, data['M_V'].value, c=rev_colours[i], 
                   alpha=data_alpha, s=10, label=f'Cluster {cluster_names[i]}')
            
        # Plotting isochrones
        for i, iso in enumerate(all_iso_data):
            ax.plot(iso[:, 3], iso[:, 2], color=colours[i], label=f't = {iso[0, 0]:.1e} Gyr')
                
        ax.invert_yaxis()
        ax.set_xlabel(r'G$_{BP}$ - G$_{RP}$ [mag]')
        ax.set_ylabel(r'M$_{V}$ [mag]')
        ax.set_title(f'Cluster data for {len(all_data)} clusters and isochrones')
        ax.legend()
            
        ax.grid(True)
            
        plt.show()
            

## Fitting function for isochrones

In [6]:
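def fit_fcn(x, args):
    """
    5th-degree polynomial used for the isochrone fits. args = [k5, k4, k3, k2, k1, c] is
    ordered highest power first, as returned by np.polyfit, so this equals np.polyval(args, x)
    """
    k5, k4, k3, k2, k1, c = args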
    return k5*x**5 + k4*x**4 + k3*x**3 + k2*x**2 + k1*x + c
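`np.polyfit` returns the coefficients highest power first, which is the order `fit_fcn` unpacks, so `fit_fcn(x, params)` should agree with `np.polyval(params, x)`. A quick check on a made-up quintic (the coefficients below are invented for the test, not taken from any isochrone):

```python
import numpy as np

def fit_fcn(x, args):
    k5, k4, k3, k2, k1, c = args
    return k5*x**5 + k4*x**4 + k3*x**3 + k2*x**2 + k1*x + c

# Synthetic "isochrone": magnitude as an exact quintic in colour (made-up coefficients)
true_coeffs = np.array([0.01, -0.1, 0.3, -0.5, 2.0, 1.0])
colour = np.linspace(0.0, 3.0, 50)
magnitude = np.polyval(true_coeffs, colour)

fitted = np.polyfit(colour, magnitude, 5)  # highest power first, like fit_params[:, j]
```

Since the two agree, `np.polyval` could stand in for `fit_fcn` if a different polynomial degree is ever tried.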

## Chi-square fitting

In [7]:
def chi_fitting(model_fcn, data, iso_params, isochrones, survey):
    """
    Parameters:
    -----------
    model_fcn: fcn
        The model function for the isochrone fit
        
    data: astropy QTable
        The cluster data
        
    iso_params: array
        The fitted isochrone parameters
        
    isochrones: array
        Isochrone data
    ----------------------------------------------
    Output:
    -------
    chisq_value: array
        The sum of all differences for each star for every isochrone shape:(len(isochrones))
    """
    n_isochrones = np.shape(iso_params)[1]
    chisq_value = np.empty((n_isochrones))
    
    # Observed colours and magnitudes of the cluster stars
    if survey=='gaia':
        data_col = data['bp_rp'].value # G_BP - G_RP
        data_mag = data['M_V'].value
    elif survey=='2mass':
        data_col = data['J-K'].value
        data_mag = data['J'].value
    else:
        raise ValueError(f"Unknown survey '{survey}', expected 'gaia' or '2mass'")
    
    for i in range(n_isochrones):
        # Colour of isochrone i
        if survey=='gaia':
            iso_col = isochrones[i][:, 3]
        else:
            iso_col = isochrones[i][:, 2] - isochrones[i][:, 4] # J - K
        
        # Only using the stars inside this isochrone's colour range. The mask is rebuilt for
        # every isochrone instead of overwriting data, so the cut made for one isochrone
        # does not carry over to the next
        colour_mask = (np.min(iso_col)<=data_col)&(data_col<=np.max(iso_col))
        
        sigma = np.ones(np.count_nonzero(colour_mask)) # Unit uncertainties for now
        
        # Magnitudes that the isochrone fit predicts at the observed colours
        model_data = model_fcn(data_col[colour_mask], iso_params[:, i])
        
        diff = (data_mag[colour_mask] - model_data)**2 / sigma**2
        chisq_value[i] = np.nansum(diff)
    
    return chisq_value
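The selection that `chi_fitting` feeds into can be tried on made-up data: for each isochrone, sum $\left((M_\mathrm{obs} - f_j(\mathrm{colour}))/\sigma\right)^2$ over the stars inside that isochrone's colour range and pick the smallest sum. This is a sketch with a hypothetical helper `chi_square_per_isochrone` and invented linear "isochrones"; it mirrors the per-isochrone mask used above, not the full notebook pipeline.

```python
import numpy as np

def chi_square_per_isochrone(colour, mag, iso_coeffs, colour_ranges, sigma=1.0):
    """Hypothetical helper: for each isochrone j, sums ((mag - poly_j(colour)) / sigma)**2
    over the stars inside that isochrone's colour range. The mask is rebuilt for every
    isochrone, so no cut carries over to the next one."""
    chisq = np.empty(iso_coeffs.shape[1])
    for j in range(iso_coeffs.shape[1]):
        lo, hi = colour_ranges[j]
        inside = (colour >= lo) & (colour <= hi)
        residuals = (mag[inside] - np.polyval(iso_coeffs[:, j], colour[inside])) / sigma
        chisq[j] = np.sum(residuals**2)
    return chisq

# Three made-up linear "isochrones", M = 4*colour + offset, coefficients highest power first
offsets = np.array([1.0, 2.0, 3.0])
iso_coeffs = np.vstack([np.full(3, 4.0), offsets])  # shape (2, 3): one column per isochrone
colour_ranges = [(0.0, 3.0)] * 3

colour = np.linspace(0.5, 2.5, 20)
mag = 4 * colour + 2.0 + 0.05 * np.sin(7 * colour)  # stars scattered around the middle isochrone
chisq = chi_square_per_isochrone(colour, mag, iso_coeffs, colour_ranges)
best = np.argmin(chisq)  # index of the best-fitting isochrone
```

Note that when the colour ranges differ, the sums run over different numbers of stars, which is worth keeping in mind when comparing them.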

## Getting IMF

To do: maybe make this work for several clusters at the same time?

In [163]:
def IMF(colour_data, model_data, cluster_name, model_name, nbins, time, check=True, 
        save_check=False, plot=True, save_plot=True):
    """
    Interpolates the stellar masses from isochrone colours
    
    Parameters:
    -----------
    colour_data: array
        Colour from cluster stars
        
    model_data: array
        Isochrone data containing time, mass, M_V, bp_rp
        
    cluster_name: str
        name of cluster for plots
        
    model_name: str
        Name of ispchrone models
        
    nbins:
    
    time:
        
    check: bool
        True if you want to check the interpolation
        
    save_check: bool
        True if the plot of the interpolation is supposed to be saved
        
    plot: bool
        True if the histograms should be plotted
        
    save_plot: bool
        True if the plots of the histograms should be saved
    """
    
    # Sorting all columns in data according to the colour
    sorted_model_data = model_data[model_data[:, 3].argsort()]
    
    stellar_masses = np.array(np.interp(colour_data, sorted_model_data[:, 3], sorted_model_data[:, 1]))
    
    if check:
        fig1, ax1 = plt.subplots(figsize=(5, 4))
    
        ax1.plot(sorted_model_data[:, 3], sorted_model_data[:, 1], color='b', label='Model data')
        ax1.scatter(colour_data, stellar_masses, c='r', s=5, label='Interpolation', zorder=20)
        
        ax1.set_xlabel(r'G$_{BP}$ - G$_{RP}$ [mag]')
        ax1.set_ylabel(r'Mass [M$_{\odot}$]')
        ax1.set_title(f'Interpolation check, {model_name} model, age = {time:.1e} Gyr')
        
        ax1.legend()
        
        if save_check:
            plt.savefig(f'Plots/Interpolation_check_{cluster_name}_{model_name}_model.png', bbox_inches='tight')
        
        plt.show()
    
    if plot:
        # Plotting histogram in normal scale
        min_mass = np.min(stellar_masses)
        max_mass = np.max(stellar_masses)
        
        fig2, ax2 = plt.subplots(1, 2, figsize=(12, 4))
        
        ax2[0].hist(stellar_masses, bins=nbins,
                     histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5)

        ax2[0].set_xlabel(r'Mass [M$_{\odot}$]')
        ax2[0].set_ylabel('Counts')
        ax2[0].set_title(f'{model_name} IMF for {cluster_name}, age = {time:.1e} Gyr: normal scale')
        
        
        # Plotting histogram in log scale
        ax2[1].hist(stellar_masses, 
                     bins=np.logspace(np.log10(min_mass), np.log10(max_mass), nbins+1),
                     histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5)

        ax2[1].set_xlabel(r'Mass [M$_{\odot}$] (log scale)')
        ax2[1].set_ylabel('Counts')
        ax2[1].set_title(f'{model_name} IMF for {cluster_name}, age = {time:.1e} Gyr: log scale')

        ax2[1].set_xscale('log')
        
        if save_plot:
            plt.savefig(f'Plots/{cluster_name}_{model_name}_IMF_log_and_normal.png', bbox_inches='tight')
        plt.show()
    
    
    return stellar_masses
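`np.interp` does not extrapolate: for colours bluer or redder than the isochrone it returns the mass at the nearest end point (unless `left`/`right` are given). This is why stars outside the isochrone range pile up at the end-point masses, such as the $1.4 \,$M$_{\odot}$ ceiling of Baraffe. A small demonstration on a made-up isochrone (the colour and mass values are invented):

```python
import numpy as np

# Made-up isochrone, sorted by colour as in IMF(): mass decreases towards redder colours
iso_colour = np.array([0.5, 1.0, 2.0, 3.0])
iso_mass = np.array([1.4, 1.0, 0.5, 0.1])

star_colour = np.array([0.2, 1.5, 3.5])  # first and last stars fall outside the isochrone
masses = np.interp(star_colour, iso_colour, iso_mass)
# Out-of-range stars are clamped to the end-point masses 1.4 and 0.1

# Keeping only stars inside the isochrone's colour range avoids the clamped values
in_range = (star_colour >= iso_colour.min()) & (star_colour <= iso_colour.max())
masses_in_range = np.interp(star_colour[in_range], iso_colour, iso_mass)

# Alternatively, flag out-of-range stars as NaN instead of clamping them
masses_nan = np.interp(star_colour, iso_colour, iso_mass, left=np.nan, right=np.nan)
```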

## Age interpolation from in between isochrones

In [162]:
def new_age_interpolation(chi_values, isochrone_data, iso_ages, cluster_name, plot=True):
    """
    Parameters:
    -----------
    chi_values: int
        Values of sum(chi**2) for each isochrone
        
    isochrone_data: array
        The data from all isochrones
        
    iso_ages: array/list
        list/array of isochrone ages
    """
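    pos_min_chi = np.argmin(chi_values) # Position of the best-fitting isochrone
    min_chi = chi_values[pos_min_chi]
    
    # The parabola needs one isochrone on each side of the minimum, so no interpolation is
    # possible when the best isochrone is the youngest or the oldest one
    if pos_min_chi==0 or pos_min_chi==len(chi_values)-1:
        new_cluster_age = iso_ages[pos_min_chi] # Keeping the age of the best isochrone
        
        younger_isochrone = isochrone_data[pos_min_chi]
        younger_age = iso_ages[pos_min_chi]
        older_isochrone = isochrone_data[pos_min_chi]  
        older_age = iso_ages[pos_min_chi]
        
        plot=False
        
        print(f'Could not interpolate age for cluster {cluster_name}.')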
        
    else:
        # Using the best-fitting isochrone and one neighbour on each side
        pnt_below = pos_min_chi-1
        pnt_above = pos_min_chi+2
        
        closest_ages = iso_ages[pnt_below : pnt_above] # isochrone ages for the three points
        closest_chis = chi_values[pnt_below : pnt_above] # chi values for the three points
    
        # Fitting a quadratic function to chi(age) over the interval
        k2, k1, c = np.polyfit(closest_ages, closest_chis, deg=2)
    
        # Evaluating the fit on 10000 points within the age interval
        ages = np.linspace(np.min(closest_ages), np.max(closest_ages), 10000)
        chis = k2*ages**2 + k1*ages + c # Corresponding chi values from the fit
        new_cluster_age = ages[np.argmin(chis)] # New age at the minimum fitted chi value
    
        # Finding the two isochrones that bracket the new age
        if new_cluster_age<iso_ages[pos_min_chi]:
            younger_isochrone = isochrone_data[pos_min_chi-1]
            younger_age = iso_ages[pos_min_chi-1]
            older_isochrone = isochrone_data[pos_min_chi]
            older_age = iso_ages[pos_min_chi]
            
        else: # new age equal to or older than the age of the best isochrone
            younger_isochrone = isochrone_data[pos_min_chi]
            younger_age = iso_ages[pos_min_chi]
            older_isochrone = isochrone_data[pos_min_chi+1]  
            older_age = iso_ages[pos_min_chi+1]
        
    if plot:
        # Fixing data to plot
        ages_fit = np.linspace(np.min(closest_ages), np.max(closest_ages), 100)
        chi_vals_fit = k2*ages_fit**2 + k1*ages_fit + c
        
        fig, ax = plt.subplots(figsize=(7, 6))
        
        # Plotting chi for each isochrone
        ax.scatter(iso_ages, chi_values, c='b', s=10, label=r'$\chi^2$ per isoc.')
        # Plotting minimum chi from isochrones
        ax.scatter(iso_ages[pos_min_chi], min_chi, c='r', s=20, label=r'Age$_{min}$ isoc.')
        # Plotting chi-age fit
        ax.plot(ages_fit, chi_vals_fit, color='orange', label=r'Age fit')
        
        # Marking the newly determined age
        ax.axvline(new_cluster_age, linestyle='dashed', color='g', 
                   label=f'Age={new_cluster_age:.3} Gyr')
        
        ax.set_xlabel('Ages [Gyr]')
        ax.set_ylabel(r'$\chi^2$')
        ax.set_title(f'Age fit {cluster_name}')
        
        ax.set_xlim(xmin=np.min(closest_ages)-0.005, xmax=np.max(closest_ages)+0.005)
        ax.set_ylim(ymin=np.min(chi_vals_fit)-50, ymax=np.max(chi_vals_fit)+50)
        
        ax.legend(loc='lower right')
        ax.set_xticks(np.linspace(np.min(closest_ages)-0.005, np.max(closest_ages)+0.005, 6))
        #plt.savefig('Plots/Age_fit_plot.png', bbox_inches='tight')
        plt.show()
        
    return new_cluster_age, younger_isochrone, younger_age, older_isochrone, older_age
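For a parabola $\chi^2(t) = k_2 t^2 + k_1 t + c$ with $k_2 > 0$, the minimum lies at $t = -k_1/(2k_2)$, so the 10000-point grid could be replaced by the vertex formula. A sketch with a hypothetical helper `parabola_minimum_age`; the clipping to the sampled interval and the fallback for $k_2 \le 0$ are my own design choices, not part of the function above.

```python
import numpy as np

def parabola_minimum_age(ages, chis):
    """Hypothetical helper: fits chi2(age) = k2*age**2 + k1*age + c and returns the vertex
    age -k1/(2*k2), clipped to the sampled age interval. If the parabola does not open
    upwards (k2 <= 0), it falls back to the sampled age with the lowest chi2."""
    ages = np.asarray(ages, dtype=float)
    chis = np.asarray(chis, dtype=float)
    k2, k1, _ = np.polyfit(ages, chis, deg=2)
    if k2 <= 0:
        return float(ages[np.argmin(chis)])
    return float(np.clip(-k1 / (2 * k2), ages.min(), ages.max()))

# Three made-up isochrone ages (Gyr) whose chi2 values lie exactly on a parabola
ages = np.array([0.01, 0.02, 0.03])
chis = 1e5 * (ages - 0.018)**2 + 100
best_age = parabola_minimum_age(ages, chis)  # vertex at 0.018 Gyr
```

The grid approach gives the same age to within the grid spacing, so this mainly removes the dependence on the number of grid points.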

## Interpolated mass from in between isochrones
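For every star, the loop in `interpolated_mass` linearly interpolates in age between the mass read off the younger isochrone and the mass read off the older one. Because all stars share the same two ages, this can also be done for all stars at once. A sketch with a hypothetical helper `masses_at_age`; unlike `np.interp`, it would extrapolate outside $[t_\mathrm{young}, t_\mathrm{old}]$, which should not happen here because the new age always lies between the two bracketing isochrones.

```python
import numpy as np

def masses_at_age(new_age, young_masses, old_masses, young_age, old_age):
    """Hypothetical helper: linearly interpolates each star's mass between the masses from
    the younger and older isochrones, for all stars at once. If both isochrones have the
    same age (no interpolation was possible), the younger masses are returned unchanged."""
    young_masses = np.asarray(young_masses, dtype=float)
    old_masses = np.asarray(old_masses, dtype=float)
    if old_age == young_age:
        return young_masses.copy()
    weight = (new_age - young_age) / (old_age - young_age)  # 0 at young_age, 1 at old_age
    return young_masses + weight * (old_masses - young_masses)

# Made-up masses for two stars on isochrones of 0.01 and 0.02 Gyr
masses = masses_at_age(0.015, [1.0, 0.5], [0.8, 0.3], 0.01, 0.02)
```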

In [161]:
def interpolated_mass(cluster_data, cluster_name, new_cluster_age, younger_data, older_data, 
                      younger_age, older_age, model_name, bin_width, plot=True, save_plot=False):
    """
    Parameters:
    ------------
    cluster_data:
    
    cluster_names:
    
    younger_data: array
        data for younger isochrone
        
    older_data: array
        data for older isochrone
        
    younger_age: array
        age for younger isochrone
        
    older_age: array
        age for older isochrone
        
    model_name:
    """
    #print(np.min(younger_data[:, 3]))
    #print(np.min(younger_data[:, 3]))
    colour_mask = (np.min(younger_data[:, 3])<=cluster_data['bp_rp'].value)&(cluster_data['bp_rp'].value<=np.max(older_data[:, 3]))
    
    cluster_data = cluster_data[colour_mask]
    #print(cluster_data)
    
    # Interpolating masses from model isochrones above and below new age
    young_stellar_masses = IMF(cluster_data['bp_rp'].value, younger_data, cluster_name, 
                               model_name, 15, younger_age, check=False, save_check=False, 
                               plot=False, save_plot=False)[:, np.newaxis]
    
    old_stellar_masses = IMF(cluster_data['bp_rp'].value, older_data, cluster_name, 
                               model_name, 15, older_age, check=False, save_check=False, 
                               plot=False, save_plot=False)[:, np.newaxis]
    
    
    # Putting masses into same array, shape: (n_stars, age=2)
    model_masses = np.concatenate([young_stellar_masses, old_stellar_masses], axis=1)
    # Sort for the interpolation
    #sorted_model_masses = model_masses[model_masses[:, 0].argsort()] # Sort according to young masses
    
    # Creating arrays of ages from model isochrones above and below new age
    # The age is constant along each isochrone, so every star gets the same two ages
    young_age_array = np.full((len(cluster_data), 1), younger_age)
    old_age_array = np.full((len(cluster_data), 1), older_age)
    
    # Putting ages into same array, shape: (n_stars, models=2)
    model_ages = np.concatenate((young_age_array, old_age_array), axis=1)
    
    
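    cluster_stellar_masses = np.empty(len(cluster_data))
    
    # Looping over the stars (rows): each star has one mass per isochrone, and its mass at the
    # new age is found by linear interpolation in age. np.interp requires younger_age < older_age
    # and returns the end-point mass if the new age lies outside the two isochrone ages
    for i in range(len(cluster_data)):
        cluster_stellar_masses[i] = np.interp(new_cluster_age, model_ages[i, :], 
                                              model_masses[i, :])
    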
    if plot:
        # Plotting histogram in normal scale
        min_mass = np.min(cluster_stellar_masses)
        max_mass = np.max(cluster_stellar_masses)
        
        nbins = int(np.round(2*(np.log10(max_mass) - np.log10(min_mass))/bin_width))
        
        
        fig2, ax2 = plt.subplots(1, 1, figsize=(8, 6))
        
        #ax2[0].hist(cluster_stellar_masses, bins=nbins,
        #             histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5)

        #ax2[0].set_xlabel(r'Mass [M$_{\odot}$]')
        #ax2[0].set_ylabel('Counts')
        #ax2[0].set_title(f'{model_name} IMF for {cluster_name}, age = {new_cluster_age:.2e} Gyr: normal scale')
        #
        #if model_name=='MIST':
        #    ax2[0].set_xlim(xmin=0, xmax=5)
        #        
        #elif model_name=='Baraffe':
        #    ax2[0].set_xlim(xmin=0, xmax=1.5)
        
        
        # Plotting histogram in log scale
        ax2.hist(cluster_stellar_masses, 
                 bins=np.logspace(np.log10(min_mass), np.log10(max_mass), nbins+1),
                 histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5)

        ax2.set_xlabel(r'Mass [M$_{\odot}$]')
        ax2.set_ylabel('Counts')
        ax2.set_title(f'{model_name} IMF for {cluster_name}, age = {new_cluster_age:.2e} Gyr: log scale')

        ax2.set_xscale('log')
        ax2.set_yscale('log')
        if model_name=='MIST':
            ax2.set_xlim(xmin=0.08, xmax=5)
                
        elif model_name=='Baraffe':
            ax2.set_xlim(xmin=0.01, xmax=1.5)
        
        if save_plot:
            plt.savefig(f'Plots/{cluster_name}_{model_name}_IMF_log_and_normal.png', bbox_inches='tight')
        plt.show()
        
    return cluster_stellar_masses
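The per-star loop in `interpolated_mass` is a linear interpolation in age between two masses per star. A vectorised sketch of the same step is shown below. `interpolate_masses_in_age` is a hypothetical helper, not part of the notebook; it assumes `young_age < old_age` and that the two mass arrays are aligned star by star. Like `np.interp`, it clamps ages outside the interval to the nearest isochrone.

```python
import numpy as np


def interpolate_masses_in_age(new_age, young_age, old_age, young_masses, old_masses):
    """Linearly interpolate each star's mass between two isochrone ages (hypothetical helper).

    Equivalent to calling np.interp(new_age, [young_age, old_age],
    [young_masses[i], old_masses[i]]) once per star, assuming young_age < old_age.
    """
    young_masses = np.asarray(young_masses, dtype=float)
    old_masses = np.asarray(old_masses, dtype=float)

    # Fractional position of the new age between the two isochrones, clipped to [0, 1]
    # so that ages outside the interval are clamped to the nearest isochrone, as np.interp does
    weight = np.clip((new_age - young_age) / (old_age - young_age), 0.0, 1.0)
    return young_masses + weight * (old_masses - young_masses)


# Made-up example: three stars whose masses were read off isochrones of age 10 and 20
young = np.array([0.5, 1.0, 2.0])
old = np.array([0.6, 1.2, 2.4])
masses = interpolate_masses_in_age(15.0, 10.0, 20.0, young, old)
```

Replacing the Python loop by array arithmetic gives the same result for any number of stars, and makes the clamping behaviour explicit rather than implicit in `np.interp`.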

## All in one

**Check** whether it is better to use a different number of bins for each cluster. In that case, nbins has to be a list with one entry per cluster, looped over together with the clusters, and the counts and edges have to be stored in lists rather than arrays, since their sizes will differ.

**Fix:** the interpolation when data fall outside the interpolation interval.

In [160]:
def final_IMFs(cluster_data, cluster_names, model, metallicity, bin_width, age_fit_plot=True, 
               chi_plot=True, save_chi_plot=False, plot_hists=True, save_plot_hists=True): 
    #check_interp=True, save_check_interp=False,  cluster_tmass_data, 
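    """
    Finds the best fitting isochrone for each cluster, interpolates the cluster age and the
    stellar masses, and bins the masses into logarithmic IMF histograms
    -----------------------------------------------------------------------------------------
    Parameters:
    -----------
    cluster_data: list
        List of cluster data for different clusters
        
    cluster_names: list
        List of the names of all clusters
        
    model: str
        Isochrone model ('MIST' or 'Baraffe')
    
    metallicity:
        Metallicity of the isochrones to load
    
    bin_width: float
        Width of the logarithmic mass bins
    
    age_fit_plot, chi_plot, plot_hists: bool
        Whether to show the age interpolation, chi^2 and IMF histogram plots
    
    save_chi_plot, save_plot_hists: bool
        Whether to save the chi^2 and IMF histogram plots
    
    Output:
    -------
    log_counts, log_edges: list
        Counts and bin edges of the logarithmic mass histogram of each kept cluster
        
    all_masses: list
        Stellar masses above 0.3 solar masses for each kept cluster
        
    all_ages: list
        Interpolated age of each kept cluster
        
    same_as_first_iso: list
        Cluster numbers whose best fitting isochrone is unchanged from the first chi^2 fit
        
    all_cluster_names: array
        Names of the kept clusters
    """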
    # Extracting isochrone data for the Gaia magnitudes
    iso_data, iso_ages = separate_isochrones(model, metallicity, 'gaia')
    # Flattening the isochrone ages
    iso_ages = iso_ages.flatten()
    # Extracting the isochrone parameters
    isochrone_parameters_gaia = isochrone_params(iso_data, model, 'gaia') # A (6, n_isochrones) array
    
    # Extracting the isochrone data for the 2MASS magnitudes
    #iso_tmass_data, iso_tmass_ages = separate_isochrones(model, metallicity, '2mass')
    
    
    # Initialising output lists
    log_counts = []
    log_edges = []
    all_masses = []
    all_ages = []
    same_as_first_iso = []
    
    # Array (not list), so that clusters with too few stars can be removed by name below
    all_cluster_names = np.array(cluster_names)
    
    for i, cl_data in enumerate(cluster_data): # Loops over all clusters in list
        print(f'i new = {i}')
        
        # Removing stars with missing colour values
        data_mask = ~cl_data['bp_rp'].mask
        cl_data = cl_data[data_mask]
        #cl_tmass_data = cluster_tmass_data[i]
        
        
        # Finding common stars for Gaia and 2MASS data 
        #gaia_2mass_inters, gaia_indices, tmass_indices = np.intersect1d(cl_data['GaiaID'].value, 
        #                                                                cl_tmass_data['GaiaID'].value,
        #                                                                return_indices=True)
        #
        
       # if len(gaia_indices)!=len(cl_tmass_data):
       #     cl_member_indices = np.linspace(0, len(cl_tmass_data)-1, len(cl_tmass_data), dtype=int)
       #     cl_member_indices = np.delete(cl_member_indices, tmass_indices)
       #     cl_tmass_data.remove_rows(cl_member_indices)
            
        
        # Finding best fitting isochrone (smallest chi^2)
        chi_values = chi_fitting(fit_fcn, cl_data, isochrone_parameters_gaia, iso_data, 'gaia')
        pos_min_chi = np.argmin(chi_values)
        
        # Best fit before the (commented-out) extinction iteration below
        min_chis = np.array([0., pos_min_chi])
        first_iso = pos_min_chi
        
        ##################################################################################
        #min_col = np.min(cl_data[])
        #nbins = int(np.round(2*(np.log10(max_mass) - np.log10(min_mass))/bin_width))
        #cl_name = cluster_names[i]
        #best_isochrone = iso_tmass_data[pos_min_chi]
        #best_iso_col = best_isochrone[:, 2]-best_isochrone[:, 4]
        #best_iso_mag = best_isochrone[:, 2]
        #fig, ax = plt.subplots(1, 2, figsize=(10, 4))
        
        #ax.hist(np.array(cl_data['bp_rp'].value), bins=15,
        #        align='left', histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5)
        #ax[0].plot(best_iso_col, best_iso_mag, 'r')
        #ax[0].scatter(cl_tmass_data['J-K'], cl_tmass_data['J'], c='b', s=10, alpha=0.5)
        #ax[0].set_xlim(-5, 6)
        #ax[0].set_ylim(-5, 15)
        #ax[0].set_title(f'CMD Beforecluster  {i}')
        #ax[0].invert_yaxis()
        #ax[0].grid(True)
        
        #ax.grid(True)
        #plt.show()
        ######################################################################################
        
        #while min_chis[0]!=min_chis[1]:
        #    
        #    #print(min_chis)
        #    min_chis[0] = pos_min_chi
        #    
        #    best_isochrone = iso_tmass_data[pos_min_chi]
        #    
        #    iso_jh = best_isochrone[:, 2] - best_isochrone[:, 3] # Isochrone CCD values
        #    iso_hk = best_isochrone[:, 3] - best_isochrone[:, 4] # Isochrone CCD values
        #    
        #    iso_interval_fit_mask = (0.18<=iso_hk)&(iso_hk<=0.28) 
        #    
        #    iso_hk_interval = iso_hk[iso_interval_fit_mask]
        #    iso_jh_interval = iso_jh[iso_interval_fit_mask]
        #    
        #    k_iso_fit, m_iso_fit = np.polyfit(iso_hk_interval, iso_jh_interval, deg=1)
        #    
        #    
        #    # Correcting 2MASS data
        #    cl_tmass_data = extinction(k_iso_fit, m_iso_fit, cl_tmass_data)
        #    #print(cl_tmass_data)
        #    mean_A_G = np.mean(cl_tmass_data['A_G'].value)*u.mag
        #    mean_A_BP = np.mean(cl_tmass_data['A_BP'].value)*u.mag
        #    mean_A_RP = np.mean(cl_tmass_data['A_RP'].value)*u.mag
        #    #print(mean_extinction)
        #    #print(np.min(cl_tmass_data['Extinction'].value), np.max(cl_tmass_data['Extinction'].value))
        #    
        #    # Adding extinction column to Gaia data
        #    cl_data['A_G'] = np.zeros((len(cl_data)))
        #    cl_data['A_BP'] = np.zeros((len(cl_data)))
        #    cl_data['A_RP'] = np.zeros((len(cl_data)))
        #    
        #    # Adds extracted values from 2MASS data
        #    cl_data['A_G'][gaia_indices] = cl_tmass_data['A_G']
        #    cl_data['A_BP'][gaia_indices] = cl_tmass_data['A_BP']
        #    cl_data['A_RP'][gaia_indices] = cl_tmass_data['A_RP']
        #    
        #    # Finding the non-xmatched star positions
        #    star_positions = np.linspace(0, len(cl_data)-1, len(cl_data), dtype=int)
        #    non_xmatched_pos = np.delete(star_positions, gaia_indices)
        #    
        #    # Assigning mean extintion to the non-xmatched stars in the Gaia data
        #    cl_data['A_G'][non_xmatched_pos] = mean_A_G
        #    cl_data['A_BP'][non_xmatched_pos] = mean_A_BP
        #    cl_data['A_RP'][non_xmatched_pos] = mean_A_RP
        #    
        #    ########### Change according to table in paper!!!!!!!!####################
        #    #extinction_G = 0.789*cl_data['Extinction']*u.mag #(+-0.005) #extinction_per_band('G', cl_data['bp_rp'].value, cl_data['Extinction'].value)
        #    #extinction_BP = 1.002*cl_data['Extinction']*u.mag #(+-0.007) #extinction_per_band('BP', cl_data['bp_rp'].value, cl_data['Extinction'].value)
        #    #extinction_RP = 0.589*cl_data['Extinction']*u.mag #(+-0.004) #extinction_per_band('RP', cl_data['bp_rp'].value, cl_data['Extinction'].value)
        #    
        #    extinction_bp_rp = cl_data['A_BP'].value - cl_data['A_RP'].value
        #    
        #    # Correcting the absolute magnitude in Gaia data
        #    cl_data['M_V'] = (cl_data['M_V'].value - cl_data['A_G'].value)*u.mag #(cl_data['M_apparent'].value - 5*np.log10(cl_data['dist'].value) + 5 - cl_data['Extinction'].value)*u.mag
        #    cl_data['bp_rp'] = (cl_data['bp_rp'].value - extinction_bp_rp)*u.mag
        #    #print(cl_data['M_V'])
        #
        #    
        #    chi_values = chi_fitting(fit_fcn, cl_data, isochrone_parameters_gaia, iso_data, 'gaia')
        #    #print(chi_values)
        #    min_chi = np.min(chi_values)
        #    #print(min_chi)
        #    pos_min_chi = np.where(chi_values==min_chi)[0][0]
        #    
        #    min_chis[1] =pos_min_chi
        #    #print(min_chis)
        #    #print()
        
        #best_isochrone = iso_tmass_data[pos_min_chi]
        
        #ax[1].plot(best_isochrone[:, 2] - best_isochrone[:, 4], best_isochrone[:, 2], 'r')
        #ax[1].scatter(cl_tmass_data['J-K'], cl_tmass_data['J'], c='b', s=10, alpha=0.5)
        #ax[1].set_xlim(-5, 6)
        #ax[1].set_ylim(-5, 15)
        #ax[1].set_title('CMD After')
        #ax[1].invert_yaxis()
        #ax[1].grid(True)
        
        #plt.tight_layout()
        #plt.show()
        #fig2, ax2 = plt.subplots(figsize=(5, 4))
        
        #ax2.hist(np.array(cl_data['bp_rp'].value), bins=80,
        #        align='left', histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5)
            #ax.scatter(cl_data['bp_rp'], cl_data['M_V'], s=10, alpha=0.5)
        #ax2.set_xlim(xmin=0)
        #ax.set_ylim(-5, 15)
        #ax.set_title('Initially')
        #ax.set_xscale('log')
        #ax.set_yscale('log')
        #ax2.grid(True)
        #ax.invert_yaxis()
        #plt.show()
        
        
        if first_iso == pos_min_chi:
            same_as_first_iso.append(cl_data['Cluster_number'][0])
        
        cl_age, iso_young, age_young, iso_old, age_old = new_age_interpolation(chi_values, iso_data, 
                                                                            iso_ages, cluster_names[i],
                                                                            plot=age_fit_plot)
        
        all_ages.append(cl_age)
        #min_chi = np.min(chi_values)
        #min_iso_pos = int(np.where(chi_values==min_chi)[0])
        
        if chi_plot:
            fig, ax = plt.subplots(figsize=(8, 6))
            
            ax.scatter(iso_ages, chi_values, c='b', s=10, label=r'$\chi^2$ values')
            #ax.scatter(iso_ages[min_iso_pos], min_chi, c='r', s=15, label=r'Minimum $\chi^2$')
            
            ax.axvline(cl_age, linestyle='dashed', c='g')
            
            ax.set_xlabel('Age [Gyr]')
            ax.set_ylabel(r'$\chi^2$ values')
            ax.set_title(r'$\chi^2$ values for isochrone ages')
            ax.set_xscale('log')
            ax.legend()
            
            if save_chi_plot:
                plt.savefig(f'Plots/Chi_plot_{cluster_names[i]}.png', bbox_inches='tight')
            
            plt.show()
        
        
        # 
        #best_iso_params = isochrone_parameters[:, min_iso_pos]
        #best_iso_data = iso_data[min_iso_pos] # Contains (time, M_V, bp_rp, Mass)
        
        #colour_cluster_mask = (np.min(best_iso_data[:, 3])<=cl_data['bp_rp'].value)&(cl_data['bp_rp'].value<=np.max(best_iso_data[:, 3]))
        
        #cl_data = cl_data[colour_cluster_mask]
        
        #cluster_masses = IMF(cl_data['bp_rp'].value, best_iso_data, cluster_names[i], model, 
        #                     nbins=nbins, time=iso_ages[min_iso_pos], check=check_interp, 
        #                     save_check=save_check_interp, plot=plot_hists, 
        #                     save_plot=save_plot_hists)
        
        cluster_masses = interpolated_mass(cl_data, cluster_names[i], cl_age, iso_young, iso_old,
                                           age_young, age_old, model, bin_width, plot_hists, 
                                           save_plot_hists)
        
        
        # Keeping only stars above 0.3 solar masses
        mass_mask = cluster_masses>0.3
        
        cluster_masses = cluster_masses[mass_mask]
        cl_name = cluster_names[i]
        
        # Skipping clusters with fewer than 10 stars left, and removing their age and name
        if len(cluster_masses)<10:
            del all_ages[-1]
            name_pos = np.where(all_cluster_names==cl_name)[0]
            all_cluster_names = np.delete(all_cluster_names, name_pos)
            continue
        
        
        all_masses.append(cluster_masses)
        
        min_mass = np.min(cluster_masses)
        max_mass = np.max(cluster_masses)
        
        nbins = int(np.round(2*(np.log10(max_mass)-np.log10(min_mass))/(bin_width)))
        
        # Logarithmically spaced bin edges. The outer edges are set to the exact extreme masses,
        # since 10**log10(m) can differ from m by rounding and leave the end stars unbinned
        bin_edges = np.logspace(np.log10(min_mass), np.log10(max_mass), nbins+1)
        bin_edges[0], bin_edges[-1] = min_mass, max_mass
        
        l_counts, l_edges = np.histogram(cluster_masses, bins=bin_edges)
        log_counts.append(l_counts)
        log_edges.append(l_edges)
    return log_counts, log_edges, all_masses, all_ages, same_as_first_iso, all_cluster_names
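The logarithmic binning used for the IMF histograms can be sketched on its own. The helper `log_histogram` below is hypothetical; it follows the notebook's bin-count rule `nbins = round(2*Δlog10(m)/bin_width)`, returns the geometric bin centres and the bin widths in ln(mass) that `IMF_slopes` works with, and pins the outer edges to the extreme masses so that rounding in `np.logspace` cannot leave the lightest or heaviest star outside the histogram. It assumes at least two distinct, positive masses.

```python
import numpy as np


def log_histogram(masses, bin_width):
    """Histogram masses in bins equally spaced in log10(mass) (hypothetical helper)."""
    masses = np.asarray(masses, dtype=float)
    log_min, log_max = np.log10(masses.min()), np.log10(masses.max())

    # Same bin-count rule as in final_IMFs, but never fewer than one bin
    nbins = max(int(np.round(2 * (log_max - log_min) / bin_width)), 1)

    edges = np.logspace(log_min, log_max, nbins + 1)
    # Pin the outer edges so that 10**log10(m) rounding cannot exclude the end stars
    edges[0], edges[-1] = masses.min(), masses.max()

    counts, edges = np.histogram(masses, bins=edges)

    # Midpoints in log space (geometric means of neighbouring edges) and widths in ln(mass)
    centres = np.sqrt(edges[:-1] * edges[1:])
    ln_widths = np.diff(np.log(edges))
    return counts, edges, centres, ln_widths


# Made-up masses in solar masses
masses = np.array([0.31, 0.35, 0.4, 0.45, 0.6, 0.8, 1.0, 1.5, 2.0, 3.0])
counts, edges, centres, ln_widths = log_histogram(masses, bin_width=0.2)
```

Using the geometric mean as the bin centre places each centre in the middle of its bin on a logarithmic axis, which is where a power law fitted to the counts should be evaluated.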
    

## Separating cluster data into separate arrays for each cluster

In [12]:
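def cluster_list(cluster_table, N_limit, survey, check_effect, dist_cut=600):
    """
    Splits the cluster table into one table per cluster and adds distance and magnitude columns
    ---------------------------------------------------------------------------------------------
    Parameters:
    -----------
    cluster_table: table
        Data for all clusters, with a 'Cluster_number' column
        
    N_limit: int
        Minimum number of stars for a cluster to be kept
        
    survey: str
        'gaia' or '2mass'
        
    check_effect: str
        '+' or '-' shifts the Gaia distances by plus or minus their error; any other value
        leaves them unchanged
        
    dist_cut: float
        Maximum cluster distance in pc
    
    Output:
    -------
    clusters: list
        One table per kept cluster
        
    names: list
        Cluster numbers as strings
    """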
    
    clusters = []
    names = []
    
    cluster_table = cluster_table[cluster_table['Cluster_number'].argsort()]
    cluster_table['dist'] = np.zeros((len(cluster_table))) * u.pc
    cluster_table['dist_error'] = np.zeros((len(cluster_table))) * u.pc
    
    for i in range(1, cluster_table['Cluster_number'][-1]+1):
        # Creates a mask that only leaves one cluster
        cluster_mask = cluster_table['Cluster_number']==i
        
        # Creates a table with only one cluster
        cluster = cluster_table[cluster_mask]
        
        # Checking that the cluster is not too small
        if len(cluster)<N_limit:
            continue
        
        else:        
            if survey=='gaia':
        
                cluster['bp_rp'] = cluster['G_bp'] - cluster['G_rp']
            
                # Sorting cluster according to brightest stars
                cluster = cluster[cluster['M_apparent'].argsort()]
                # 10 brightest stars in the cluster have the lowest magnitudes
                brightest_stars = cluster[:10]
                # Distances of brightest stars
                distance = 1/(brightest_stars['Parallax'].value*1e-3) * u.pc
                # Mean distance of the cluster based on the 10 brightest stars
                mean_distance = np.mean(distance)
                #print(mean_distance)
                #print(type(mean_distance))
                
                if mean_distance>dist_cut*u.pc:
                    continue
                
                cluster['dist'] = mean_distance #Distance(parallax=cluster['Parallax'])
                
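                # Propagated distance error of each bright star: |d(1/p)/dp| * sigma_p = sigma_p/p^2
                expr = (brightest_stars['Parallax_error'].value * 1e-3)/(brightest_stars['Parallax'].value * 1e-3)**2
                # Root mean square of the individual distance errors (not the standard error of the mean)
                distance_error = np.sqrt(np.mean(expr**2)) * u.pc
                cluster['dist_error'] = distance_error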
                
                if check_effect=='+':
                    cluster['dist'] = cluster['dist'] + cluster['dist_error']
                    
                elif check_effect=='-':
                    cluster['dist'] = cluster['dist'] - cluster['dist_error']
                
                # Calculates and adds an absolute magnitude column calculated from the photometric mean magnitude 
                # and the distance modulus, described above
                cluster['M_V'] = (cluster['M_apparent'].value - 5*np.log10(cluster['dist'].value) + 5)*u.mag
            
            
            elif survey=='2mass':
                cluster['J-H'] = (cluster['j_m'].value - cluster['h_m'].value)*u.mag
                cluster['H-K'] = (cluster['h_m'].value - cluster['k_m'].value)*u.mag
            
                cluster['J'] = cluster['j_m'].value * u.mag
                cluster['H'] = cluster['h_m'].value * u.mag
                cluster['K'] = cluster['k_m'].value * u.mag
                cluster['J-K'] = (cluster['j_m'].value - cluster['k_m'].value)*u.mag
        
                # Sorting cluster according to brightest stars
                cluster = cluster[cluster['M_apparent'].argsort()]
                # 10 brightest stars in the cluster have the lowest magnitudes
                brightest_stars = cluster[:10]
                # Distances of brightest stars
                distance = 1/(brightest_stars['Parallax'].value*1e-3) * u.pc
                # Mean distance of the cluster based on the 10 brightest stars
                mean_distance = np.mean(distance)
                #print(mean_distance)
                #print(type(mean_distance))
            
                cluster['dist'] = mean_distance #Distance(parallax=cluster['Parallax'])
            
                if mean_distance>dist_cut*u.pc:
                    continue
            
            
            clusters.append(cluster)
            names.append(f"{cluster['Cluster_number'][0]}")
            
    return clusters, names
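The distance step in `cluster_list` (mean distance of the brightest stars, error propagation from the parallax errors, and the distance modulus) can be sketched with plain NumPy arrays. `cluster_distance` and `absolute_magnitude` are hypothetical helpers, and the input values are made up. The sketch returns both the root mean square of the individual distance errors, which is what `cluster_list` stores in `'dist_error'` for ten stars, and the standard error of the mean distance, which is smaller by a factor of sqrt(N); which one is appropriate depends on what `'dist_error'` is meant to describe.

```python
import numpy as np


def cluster_distance(apparent_mag, parallax_mas, parallax_error_mas, n_brightest=10):
    """Mean distance (pc) of a cluster from its n brightest stars (hypothetical helper)."""
    # Brightest stars have the lowest apparent magnitudes
    brightest = np.argsort(apparent_mag)[:n_brightest]
    parallax = np.asarray(parallax_mas, dtype=float)[brightest] * 1e-3        # mas -> arcsec
    parallax_err = np.asarray(parallax_error_mas, dtype=float)[brightest] * 1e-3

    distances = 1.0 / parallax                  # d [pc] = 1 / parallax [arcsec]
    dist_errors = parallax_err / parallax**2    # |d(1/p)/dp| * sigma_p

    mean_distance = distances.mean()
    rms_error = np.sqrt(np.mean(dist_errors**2))
    error_of_mean = np.sqrt(np.sum(dist_errors**2)) / len(distances)
    return mean_distance, rms_error, error_of_mean


def absolute_magnitude(apparent_mag, distance_pc):
    """Distance modulus: M = m - 5 log10(d / pc) + 5."""
    return apparent_mag - 5 * np.log10(distance_pc) + 5


mags = np.array([12.0, 8.0, 9.0, 15.0])
plx = np.array([5.0, 10.0, 10.0, 2.0])        # mas
plx_err = np.array([0.1, 0.1, 0.1, 0.1])      # mas
d, d_rms, d_sem = cluster_distance(mags, plx, plx_err, n_brightest=2)
M = absolute_magnitude(mags, d)
```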

## Fitting IMF slopes function(s)

In [108]:
def IMF_fit_fcn(x, a, C):
    """Power law C*x**a, fitted to the IMF histogram counts"""
    return C*x**a

def IMF_slopes(log_edges, log_counts, all_cluster_masses, model, intervals='Kroupa', 
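               plot=True):
    """
    Fits power laws to the logarithmic IMF histograms in the Kroupa mass intervals and compares
    the fitted slopes with the Kroupa (2001) slopes
    ---------------------------------------------------------------------------------------------
    Parameters:
    -----------
    log_edges: list
        List with edges of the logarithmic bins
        
    log_counts: list
        List with logarithmic bin counts
        
    all_cluster_masses: list
        List of all clusters' stellar mass arrays
    
    model: str
        Isochrone model ('MIST' or 'Baraffe'), sets the plot limits
    
    intervals: str
        Mass intervals to fit. Only 'Kroupa' is implemented
    
    plot: bool
        Whether to plot each histogram with its fitted slopes
        
    Output:
    -------
    bin_widths: array
        Bin width in ln(mass) for each cluster
    
    all_cluster_params: array, shape (n_clusters, 2, 4)
        Fitted slope and normalisation per cluster and interval (NaN where no fit was made)
    
    kroupa_diff: array, shape (n_clusters, 4)
        Fitted slope minus Kroupa slope per cluster and interval
    
    counter: int
        Number of intervals removed for the last cluster for having at most one non-zero bin
    """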
    
    bin_widths = np.zeros((len(all_cluster_masses)))
    
    # shape(n_clusters, 2 parameters, 4 intervals)
    all_cluster_params = np.full((len(all_cluster_masses), 2, 4), fill_value=np.nan) 
    
    kroupa_diff = np.full((len(all_cluster_masses), 4), fill_value=np.nan)
    
    it=0
    # Loops over each cluster
    for i, cluster_masses in enumerate(all_cluster_masses):
        print(i)
        #print()
        min_masses = np.min(cluster_masses)
        max_masses = np.max(cluster_masses)
        edges = log_edges[i]#.flatten()
        
        n_bins = len(edges)-1
        
        # Bin positions (geometric mean of neighbouring edges, i.e. the midpoint in log space)
        # and bin widths in ln(mass)
        bin_mid = np.sqrt(edges[:-1]*edges[1:])
        bin_width = np.diff(np.log(edges))
        
        if np.allclose(bin_width, bin_width[0]):
            bin_widths[i] = bin_width[0]
        
        else:
            print('Not constant bin width!')
        
        cluster_counts = log_counts[i]
        
        if intervals=='Kroupa':
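            # Creating interval mass masks. The second Kroupa interval is 0.08-0.5 solar masses,
            # but here it starts at 0.3, the same limit as the mass cut in final_IMFs
            interval_1_mask = (0.01<=bin_mid)&(bin_mid<=0.08)
            interval_2_mask = (0.3<=bin_mid)&(bin_mid<=0.5)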
            interval_3_mask = (0.5<=bin_mid)&(bin_mid<=1.0)
            interval_4_mask = (1.0<=bin_mid)
            
            all_intervals = np.array([1, 2, 3, 4])
            
            interval_masks = [interval_1_mask, interval_2_mask, interval_3_mask, interval_4_mask]
            interval_edges = [(0.01, 0.08), (0.3, 0.5), (0.5, 1.0), (1.0, np.max(cluster_masses))]
            
            all_interval_masses = [bin_mid[mask] for mask in interval_masks]
            all_interval_counts = [cluster_counts[mask] for mask in interval_masks]
            
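            # Kroupa (2001) exponents of dN/dm, xi(m) ∝ m^-alpha, with their errors.
            # Counts in logarithmic bins scale as m^(1-alpha)
            model_slopes = np.array([-0.3, -1.3, -2.3, -2.3])
            model_slope_errors = np.array([0.7, 0.5, 0.3, 0.7])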
            
            # Number of bins in each interval
            lengths = np.array([np.count_nonzero(int_mask) for int_mask in interval_masks])
            
            # Removing all intervals that contain one or no bins, keeping the lists aligned
            if np.any(lengths<=1):
                keep = lengths>1
                
                interval_masks = [mask for mask, k in zip(interval_masks, keep) if k]
                interval_edges = [edge for edge, k in zip(interval_edges, keep) if k]
                
                all_intervals = all_intervals[keep]
                
                all_interval_masses = [masses for masses, k in zip(all_interval_masses, keep) if k]
                all_interval_counts = [counts for counts, k in zip(all_interval_counts, keep) if k]
                
                model_slopes = model_slopes[keep]
                model_slope_errors = model_slope_errors[keep]
            
            counter = 0
            # Ignoring intervals with at most one non-zero bin, one interval at a time
            while any([len(int_counts[int_counts>0])<=1 for int_counts in all_interval_counts]):
                counter += 1
                # Finds indices of the intervals with at most one non-zero bin
                failed_int_pos = np.array([index for index, int_counts in enumerate(all_interval_counts) if len(int_counts[int_counts>0])<=1])
                
                del interval_masks[failed_int_pos[0]]                
                del interval_edges[failed_int_pos[0]]
                
                all_intervals = np.delete(all_intervals, failed_int_pos[0])
                
                del all_interval_masses[failed_int_pos[0]]
                del all_interval_counts[failed_int_pos[0]]
                
                model_slopes = np.delete(model_slopes, failed_int_pos[0])
                model_slope_errors = np.delete(model_slope_errors, failed_int_pos[0])
                
            
            # shape(fit_length, n intervals)
            x_fit_values = np.empty((500, len(interval_masks))) 
            y_fit_values = np.empty((500, len(interval_masks)))
            
            # shape(2 params, n intervals)
            #cluster_params = np.zeros((2, 4))  #len(interval_masks)
               
            it = it+1
        
        #print(it)
        #print(interval_edges)
        # Loop over mass intervals to fit parameters
        for j, interval_masses in enumerate(all_interval_masses):
            #print(j)
            interval_counts = all_interval_counts[j].flatten()
            #print(interval_edges[j])
            #print(interval_counts)
            #print(interval_masses)
            
            fitted_params, covariance = scopt.curve_fit(IMF_fit_fcn, xdata=interval_masses, 
                                                        ydata=interval_counts, p0=[model_slopes[j], 30],
                                                        maxfev=5000) # , full_output=True
            
            # Adding to the final arrays (intervals are numbered 1-4, array columns 0-3)
            all_cluster_params[i, :, all_intervals[j]-1] = fitted_params
            kroupa_diff[i, all_intervals[j]-1] = fitted_params[0] - model_slopes[j]
            
            # Unpacking the fitted parameters
            slope_fit, C_fit = fitted_params
            
            x_fit_values[:, j] = np.linspace(interval_edges[j][0], interval_edges[j][1], 500)
            y_fit_values[:, j] = IMF_fit_fcn(x_fit_values[:, j], slope_fit, C_fit)
        
        
        if plot:
            colours = ['r', 'orange', 'g', 'm', 'skyblue', 'limegreen']
            #min_masses = np.min(cluster_masses)
            #max_masses = np.max(cluster_masses)
            
            
            fig, ax = plt.subplots(figsize=(7, 6))
            
            # Plotting with the same bin edges as the fitted histogram
            ax.hist(cluster_masses, bins=edges,
                    histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5, 
                    label='IMF histogram')
            
            
            
            # Loop over mass intervals to fit parameters
            for j, interval_masses in enumerate(all_interval_masses):
                ax.plot(x_fit_values[:, j], y_fit_values[:, j], color=colours[j], 
                        label=f'{interval_edges[j][0]}'+r'$\leq$m/'+r'M$_{\odot}$<'+f'{interval_edges[j][1]:.3}')
                
            ax.set_xlabel(r'Mass [M$_{\odot}$]')
            ax.set_ylabel(r'Counts')
            ax.set_title('Histogram with fitted IMF slopes')
            
            ax.legend(loc='lower right')
            ax.set_xscale('log')
            ax.set_yscale('log')
            
            if model=='MIST':
                ax.set_xlim(xmin=0.08, xmax=5)
                
            elif model=='Baraffe':
                ax.set_xlim(xmin=1e-2, xmax=2)
                ax.set_ylim(ymin=8e-1, ymax=2e2)
            
            plt.show()
            
    
    
    return bin_widths, all_cluster_params, kroupa_diff, counter
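The slopes in `IMF_slopes` are fitted to counts in logarithmic bins. For a power-law IMF xi(m) = dN/dm ∝ m^-alpha, the number of stars in a bin from m_k to r*m_k is proportional to the geometric bin centre raised to the power 1 - alpha, so the fitted exponent is 1 - alpha rather than -alpha. The sketch below checks this on noise-free expected counts; `power_law` mirrors `IMF_fit_fcn`, while the normalisation `A` and the mass range are arbitrary choices for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit


def power_law(x, a, C):
    """Same form as IMF_fit_fcn: C * x**a."""
    return C * x**a


alpha = 2.3      # Kroupa (2001) exponent above 0.5 Msun, xi(m) = dN/dm propto m**-alpha
A = 1000.0       # arbitrary normalisation of xi(m)

# Logarithmic bins between 0.5 and 5 solar masses, with geometric bin centres
edges = np.logspace(np.log10(0.5), np.log10(5.0), 11)
centres = np.sqrt(edges[:-1] * edges[1:])

# Expected counts per bin: integral of A * m**-alpha over each bin
counts = A * (edges[1:]**(1 - alpha) - edges[:-1]**(1 - alpha)) / (1 - alpha)

# Fitting the counts with curve_fit, as in IMF_slopes, and as a straight line in log-log space
(a_fit, C_fit), _ = curve_fit(power_law, centres, counts, p0=[-alpha, counts[0]])
a_loglog = np.polyfit(np.log10(centres), np.log10(counts), deg=1)[0]
```

Both fits recover 1 - alpha = -1.3, so comparing slopes fitted to log-binned counts with the Kroupa exponents requires shifting one of them by +1.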

## Assigning cluster numbers to 2MASS data

In [14]:
def sorting_2mass_data(tmass_data):
    """
    Assigns cluster numbers to the 2MASS data
    """
    tmass_data_new = tmass_data
    
    star_ids_data = votable.parse_single_table('Filter_containing_tmassID.vot').to_table()
    cluster_data = votable.parse_single_table('Vizier_data_filtered.vot').to_table()
    
    cluster_data.rename_column('cluster_number', 'Cluster_number')
    cluster_data.rename_column('gaiaid', 'GaiaID')
    cluster_data.rename_column('g_bp', 'G_bp')
    cluster_data.rename_column('g_rp', 'G_rp')
    cluster_data.rename_column('parallax', 'Parallax')
    cluster_data.rename_column('parallax_error', 'Parallax_error')
    cluster_data.rename_column('m_apparent', 'M_apparent')
   
    
    # Sorting data according to Gaia ID
    cluster_order = cluster_data['GaiaID'].argsort()
    cluster_data = cluster_data[cluster_order]
    
    id_order = star_ids_data['source_id'].argsort()
    star_ids_data = star_ids_data[id_order]
    tmass_data_new = tmass_data_new[id_order]
    
    km_mask = tmass_data_new['k_m']!='NULL'
    
    tmass_data_new = tmass_data_new[km_mask]
    tmass_data_new['k_m'] = tmass_data_new['k_m'].astype(float)
    
    cluster_data = cluster_data[km_mask]
    star_ids_data = star_ids_data[km_mask]
    
    
    
    
    # Assigning cluster numbers and the Gaia IDs
    tmass_data_new['Cluster_number'] = cluster_data['Cluster_number']
    tmass_data_new['GaiaID'] = cluster_data['GaiaID']
    tmass_data_new['Parallax'] = cluster_data['Parallax']
    tmass_data_new['Parallax_error'] = cluster_data['Parallax_error']
    tmass_data_new['M_apparent'] = cluster_data['M_apparent']
    
    unique_tmass_ids, n_id_occurrences = np.unique(tmass_data_new['tmass_oid'], return_counts=True)
    
    duplicates = unique_tmass_ids[n_id_occurrences>1]
    
    # Removing all rows with a duplicated 2MASS ID from every table, so that they stay aligned
    unique_mask = ~np.isin(tmass_data_new['tmass_oid'], duplicates)
    
    tmass_data_new = tmass_data_new[unique_mask]
    star_ids_data = star_ids_data[unique_mask]
    cluster_data = cluster_data[unique_mask]
    
    return tmass_data_new, cluster_data
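Keeping two row-aligned tables in sync when dropping duplicated IDs comes down to building one boolean mask and applying it to both. A minimal NumPy sketch with made-up IDs follows; `drop_duplicated_ids` is a hypothetical helper and, like the duplicate handling in `sorting_2mass_data`, it removes every row whose ID occurs more than once rather than keeping one copy. Astropy tables accept the same kind of boolean mask as an index.

```python
import numpy as np


def drop_duplicated_ids(ids, *aligned_columns):
    """Remove every row whose ID occurs more than once (hypothetical helper).

    Each array in aligned_columns must have one entry per ID, in the same order.
    """
    ids = np.asarray(ids)
    unique_ids, n_occurrences = np.unique(ids, return_counts=True)
    duplicated = unique_ids[n_occurrences > 1]

    # One mask, applied to every aligned array, keeps the rows matched
    keep = ~np.isin(ids, duplicated)
    return (ids[keep],) + tuple(np.asarray(col)[keep] for col in aligned_columns)


tmass_ids = np.array([101, 102, 102, 103, 104, 104])
gaia_ids = np.array([11, 12, 13, 14, 15, 16])
clean_tmass_ids, clean_gaia_ids = drop_duplicated_ids(tmass_ids, gaia_ids)
```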
        

## Determining extinction function
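In the J-H versus H-K colour-colour diagram, extinction moves a star along a straight line with slope E(J-H)/E(H-K), which is the slope `k_ext_vector` used in `extinction`. A minimal sketch of dereddening stars back onto a straight intrinsic line J-H = k_iso (H-K) + m_iso, assuming that line is known: `deredden` is a hypothetical helper, the isochrone-line parameters in the example are made up, and only the Cardelli et al. (1989) A_lambda/A_V ratios are taken from the notebook. The selection of stars between two extinction-vector lines done in `extinction` is not reproduced here, and the sketch assumes the two slopes differ.

```python
import numpy as np

# Cardelli et al. (1989) A_lambda / A_V ratios, as in the extinction function
A_J, A_H, A_K = 0.282, 0.190, 0.114
E_JH = A_J - A_H   # E(J-H) per magnitude of A_V
E_HK = A_H - A_K   # E(H-K) per magnitude of A_V


def deredden(hk, jh, k_iso, m_iso):
    """Move stars back along the reddening vector onto J-H = k_iso*(H-K) + m_iso.

    Returns the intrinsic colours (H-K)_0, (J-H)_0 and the visual extinction A_V.
    """
    hk = np.asarray(hk, dtype=float)
    jh = np.asarray(jh, dtype=float)
    k_ext = E_JH / E_HK

    # Intersection of the reddening line through each star with the intrinsic line
    hk0 = (m_iso - jh + k_ext * hk) / (k_ext - k_iso)
    jh0 = k_iso * hk0 + m_iso

    # Each magnitude of A_V reddens H-K by E_HK
    a_v = (hk - hk0) / E_HK
    return hk0, jh0, a_v


# Made-up star at (H-K)_0 = 0.2 on the line J-H = 0.5*(H-K) + 0.5, reddened by A_V = 2 mag
hk_obs = 0.2 + 2 * E_HK
jh_obs = (0.5 * 0.2 + 0.5) + 2 * E_JH
hk0, jh0, a_v = deredden(hk_obs, jh_obs, k_iso=0.5, m_iso=0.5)
A_J_star = A_J * a_v     # extinction in J for this star
```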

In [15]:
def extinction(k_iso_fit, m_iso_fit, cluster_data):
    """
    Corrects the extincted stars in the cluster data
    -------------------------------------------------
    Parameters:
    -----------
    k_iso_fit: float
        Slope of the isochrone line fit
        
    m_iso_fit: float
        y-axis crossing of the isochrone line fit
        
    cluster_data: table
        2MASS data for the cluster
    
    Output:
    -------
    The corrected cluster data
    """
    # Cardelli et al. (1989) table 3 values of A_lambda/A_V
    A_J = 0.282
    A_H = 0.190
    A_K = 0.114
    
    # Colour excesses E(J-H) and E(H-K) per magnitude of A_V
    A_jh = A_J - A_H
    A_hk = A_H - A_K
    
    # Length of the A_V=1 extinction vector
    len_ext_vector = np.sqrt(A_jh**2 + A_hk**2)
    
    # Slope and intercept of the extinction vector, which passes through the origin
    k_ext_vector = A_jh/A_hk
    m_ext_vector = 0.
    
    # Adding extinction column to data table
    cluster_data['A_G'] = np.zeros((len(cluster_data)))
    cluster_data['A_BP'] = np.zeros((len(cluster_data)))
    cluster_data['A_RP'] = np.zeros((len(cluster_data)))
    cluster_data['A_J'] = np.zeros((len(cluster_data)))
    cluster_data['A_H'] = np.zeros((len(cluster_data)))
    cluster_data['A_K'] = np.zeros((len(cluster_data)))
    #cluster_data['A_JK'] = np.zeros((len(cluster_data)))
    
    
    # Create x value range
    min_x_val = np.min(cluster_data['H-K'].value)
    max_x_val = np.max(cluster_data['H-K'].value)
    #print(any(np.isnan(cluster_data['H-K'].value)))
    #print(max_x_val)
    x_range = np.linspace(-5*max_x_val, max_x_val, int(1e6))
    
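    # Intercepts of the two extinction-vector lines that cross the isochrone line at
    # H-K = 0.16 and H-K = 0.33 (m = y_isochrone_line - k_ext_vector*x)
    m_upper_line = (k_iso_fit*0.16 + m_iso_fit) - k_ext_vector*0.16
    m_lower_line = (k_iso_fit*0.33 + m_iso_fit) - k_ext_vector*0.33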
    
    ######################################################################################################
    #fig, ax = plt.subplots(1, 2, figsize=(10, 8))
    
    #ax[0].scatter(cluster_data['H-K'], cluster_data['J-H'], c='b', s=10, alpha=0.5)
    #ax[0].plot(x_range, k_iso_fit*x_range + m_iso_fit, linestyle='dashed', color='g')
    #ax[0].plot(x_range, k_ext_vector*x_range + m_upper_line, linestyle='dashed', color='m')
    #ax[0].plot(x_range, k_ext_vector*x_range + m_lower_line, linestyle='dashed', color='m')
    
    #ax[0].grid()
    #ax[0].set_xlim(-2, 3)
    #ax[0].set_ylim(-2, 5)
    ######################################################################################################
    
    
    # Stars above the isochrone line (reddened)
    iso_line_mask = cluster_data['J-H'].value>(k_iso_fit*cluster_data['H-K'].value + m_iso_fit)
    
    # Stars between the two reddening-vector boundary lines
    upper_line_mask = cluster_data['J-H'].value < (k_ext_vector*cluster_data['H-K'].value + m_upper_line)
    lower_line_mask = cluster_data['J-H'].value > (k_ext_vector*cluster_data['H-K'].value + m_lower_line)
    extincted_stars_mask = iso_line_mask & upper_line_mask & lower_line_mask
    
    extincted_stars = cluster_data[extincted_stars_mask]
    #print(len(extincted_stars))
    
    if len(extincted_stars)==0:
        return cluster_data
    
    else:
        
        # Reddened stars above the upper boundary line (left of the band);
        # they are corrected with the mean extinction of the band stars
        upper_line_mask_2 = cluster_data['J-H'].value > (k_ext_vector*cluster_data['H-K'].value + m_upper_line)
        
        left_stars_mask = upper_line_mask_2 & iso_line_mask
        
        left_stars = cluster_data[left_stars_mask]
        
        left_stars['A_G'] = np.empty((len(left_stars)))
        left_stars['A_BP'] = np.empty((len(left_stars)))
        left_stars['A_RP'] = np.empty((len(left_stars)))
        left_stars['A_J'] = np.empty((len(left_stars)))
        left_stars['A_H'] = np.empty((len(left_stars)))
        left_stars['A_K'] = np.empty((len(left_stars)))
        left_stars['A_JK'] = np.empty((len(left_stars)))
        
        # Reddened stars below the lower boundary line (right of the band)
        lower_line_mask_2 = cluster_data['J-H'].value < (k_ext_vector*cluster_data['H-K'].value + m_lower_line)
    
        right_stars_mask = lower_line_mask_2 & iso_line_mask
    
        right_stars = cluster_data[right_stars_mask]
    
        right_stars['A_G'] = np.empty((len(right_stars)))
        right_stars['A_BP'] = np.empty((len(right_stars)))
        right_stars['A_RP'] = np.empty((len(right_stars)))
        right_stars['A_J'] = np.empty((len(right_stars)))
        right_stars['A_H'] = np.empty((len(right_stars)))
        right_stars['A_K'] = np.empty((len(right_stars)))
        right_stars['A_JK'] = np.empty((len(right_stars)))
        
        # Create new line for each point by finding its m_value
        # m = y - kx
        m_values = extincted_stars['J-H'].value - k_ext_vector*extincted_stars['H-K'].value
        
        extinctions = np.empty(len(extincted_stars)) # A_V for each extincted star
        
        # Looping over every extincted star in cluster
        for i, m_value in enumerate(m_values):
            # Find where the star's reddening line crosses the isochrone line:
            # |extinction_vector_star_line - isochrone_line_fit| is smallest there
            
            #x_var = sp.symbols('x')
            #to_solve = (k_ext_vector*x_var+m_value) - (k_iso_fit*x_var+m_iso_fit)
            #x_coord = float(sp.solve(to_solve)[0])
            
            diff = np.abs((k_ext_vector*x_range+m_value) - (k_iso_fit*x_range+m_iso_fit))
            
            # Crossing is where the difference is closest to zero
            crossing = np.argmin(diff)
            
            # Crossing coordinate
            x_coord = x_range[crossing] # H-K colour
            y_coord = k_ext_vector*x_coord+m_value # J-H colour
            
            # Step down the grid until the crossing point lies just below the isochrone line
            while y_coord>=k_iso_fit*x_coord + m_iso_fit:
                crossing = crossing - 1
                x_coord = x_range[crossing] # H-K colour
                y_coord = k_ext_vector*x_coord+m_value # J-H colour
            
            # Distance between the observed data point and the crossing point
            # (computed before the star's colours are overwritten below)
            dx = extincted_stars[i]['H-K'].value - x_coord
            dy = extincted_stars[i]['J-H'].value - y_coord
            length = np.sqrt(dx**2 + dy**2)
            
            # Divide by A_V = 1 vector length to get extinction
            extinctions[i] = length/len_ext_vector
            
            # Move the star onto the isochrone line
            extincted_stars[i]['H-K'] = x_coord*u.mag
            extincted_stars[i]['J-H'] = y_coord*u.mag
        
    
        extincted_stars['A_G'] = 0.789*extinctions * u.mag
        extincted_stars['A_BP'] = 1.002*extinctions * u.mag
        extincted_stars['A_RP'] = 0.589*extinctions * u.mag
    
        #extincted_stars['A_J'] = 0.243*extinctions * u.mag
        #extincted_stars['A_H'] = 0.131*extinctions * u.mag
        #extincted_stars['A_K'] = 0.078*extinctions * u.mag
        #extincted_stars['A_JK'] = (0.243*extinctions - 0.078*extinctions)* u.mag
    
        # Mean extinction per band of the corrected stars; used below for the
        # reddened stars outside the band (left_stars and right_stars)
        mean_A_G = np.mean(0.789*extinctions) * u.mag
        mean_A_BP = np.mean(1.002*extinctions) * u.mag
        mean_A_RP = np.mean(0.589*extinctions) * u.mag
    
        mean_A_J = np.mean(0.243*extinctions) * u.mag
        mean_A_H = np.mean(0.131*extinctions) * u.mag
        mean_A_K = np.mean(0.078*extinctions) * u.mag
        mean_A_JK = mean_A_J-mean_A_K
    
        left_stars['A_G'] = mean_A_G
        left_stars['A_BP'] = mean_A_BP
        left_stars['A_RP'] = mean_A_RP
    
        left_stars['A_J'] = mean_A_J
        left_stars['A_H'] = mean_A_H
        left_stars['A_K'] = mean_A_K
        #left_stars['A_JK'] = mean_A_JK
    
        left_stars['J'] = (left_stars['J'].value - left_stars['A_J'].value)*u.mag
        left_stars['H'] = (left_stars['H'].value - left_stars['A_H'].value)*u.mag
        left_stars['K'] = (left_stars['K'].value - left_stars['A_K'].value)*u.mag
        #left_stars['J-K'] = (left_stars['J'].value - left_stars['K'].value)*u.mag
        left_stars['J-H'] = (left_stars['J'].value - left_stars['H'].value)*u.mag
        left_stars['H-K'] = (left_stars['H'].value - left_stars['K'].value)*u.mag

    
    
        right_stars['A_G'] = mean_A_G
        right_stars['A_BP'] = mean_A_BP
        right_stars['A_RP'] = mean_A_RP
        
        right_stars['A_J'] = mean_A_J
        right_stars['A_H'] = mean_A_H
        right_stars['A_K'] = mean_A_K
        #right_stars['A_JK'] = mean_A_JK
    
        right_stars['J'] = (right_stars['J'].value - right_stars['A_J'].value)*u.mag
        right_stars['H'] = (right_stars['H'].value - right_stars['A_H'].value)*u.mag
        right_stars['K'] = (right_stars['K'].value - right_stars['A_K'].value)*u.mag
        #right_stars['J-K'] = (right_stars['J'].value - right_stars['K'].value)*u.mag
        right_stars['J-H'] = (right_stars['J'].value - right_stars['H'].value)*u.mag
        right_stars['H-K'] = (right_stars['H'].value - right_stars['K'].value)*u.mag
        
        
        # Replacing all extincted stars' values with the corrected values
        cluster_data[extincted_stars_mask] = extincted_stars
    
        cluster_data[left_stars_mask] = left_stars
        cluster_data[right_stars_mask] = right_stars
    
        ######################################################################################################
        
        #ax[1].scatter(cluster_data['H-K'], cluster_data['J-H'], c='b', s=10, alpha=0.5)
        #ax[1].plot(x_range, k_iso_fit*x_range + m_iso_fit, linestyle='dashed', color='g')
        #ax[1].plot(x_range, k_ext_vector*x_range + m_upper_line, linestyle='dashed', color='m')
        #ax[1].plot(x_range, k_ext_vector*x_range + m_lower_line, linestyle='dashed', color='m')
        
        #ax[1].grid()
        #ax[1].set_xlim(-2, 3)
        #ax[1].set_ylim(-2, 5)
        
        #plt.tight_layout()
        #plt.show()
        ######################################################################################################
    
        return cluster_data
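Since both the straight-line isochrone fit and each star's reddening line are straight, the crossing point that the loop in `extinction()` searches for on a $10^6$-point H-K grid can also be written in closed form. Below is a minimal sketch of that alternative, assuming plain float colours in mag; the function name `deredden_onto_line` is hypothetical and not part of this notebook. Unlike the grid search it returns the exact crossing rather than the nearest grid point just below the line, and because $A_V$ is computed as a distance it is always non-negative, so it should only be applied to stars above the line, as `extinction()` does.

```python
import numpy as np


def deredden_onto_line(hk, jh, k_iso, m_iso, e_hk, e_jh):
    """Move points back along the reddening vector onto jh = k_iso*hk + m_iso.

    hk, jh       : observed H-K and J-H colours
    k_iso, m_iso : slope and intercept of the straight-line isochrone fit
    e_hk, e_jh   : colour excesses E(H-K), E(J-H) for A_V = 1
    Returns the dereddened colours and A_V for each point.
    """
    hk = np.asarray(hk, dtype=float)
    jh = np.asarray(jh, dtype=float)

    # Slope of the reddening vector; it must differ from k_iso for a crossing to exist
    k_ext = e_jh / e_hk
    if np.isclose(k_ext, k_iso):
        raise ValueError("reddening vector is parallel to the isochrone line")

    # Intercept of the reddening line through each point (m = y - kx)
    m_star = jh - k_ext * hk

    # Solve k_ext*x + m_star = k_iso*x + m_iso for the crossing point
    hk0 = (m_star - m_iso) / (k_iso - k_ext)
    jh0 = k_iso * hk0 + m_iso

    # Distance travelled along the reddening vector, in units of the A_V = 1 length
    a_v = np.hypot(hk - hk0, jh - jh0) / np.hypot(e_hk, e_jh)
    return hk0, jh0, a_v


# Example: a point on the line jh = -hk + 1 at hk = 0.2, reddened by A_V = 2
e_hk, e_jh = 0.190 - 0.114, 0.282 - 0.190
hk0, jh0, a_v = deredden_onto_line(0.2 + 2 * e_hk, 0.8 + 2 * e_jh,
                                   k_iso=-1.0, m_iso=1.0, e_hk=e_hk, e_jh=e_jh)
print(hk0, jh0, a_v)  # approximately 0.2, 0.8, 2.0
```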
    

## Calculating extinction values for the different passbands

In [16]:
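def Av_calc(Rv):
    """
    Calculates A_lambda/A_V = a + b/R_V for the J, H and K bands
    (Cardelli et al. 1989 infrared law)
    -------------------------------------------------
    Parameters:
    -----------
    Rv: float
        Ratio of total to selective extinction, R_V = A_V/E(B-V)
    
    Output:
    -------
    Array with A_lambda/A_V for J, H and K, rounded to three decimals
    """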
    # J, H, K
    a = np.array([0.4008, 0.2693, 0.1615])
    b = np.array([-0.3679, -0.2473, -0.1483])
    A_Av = a + b/Rv
    return np.round(A_Av, decimals=3)
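A quick consistency check, sketched under the assumption that `Av_calc` is copied unchanged: at $R_V = 3.1$ it reproduces the `A_J`, `A_H` and `A_K` values hard-coded in `extinction()`, and the resulting colour excesses give the slope of the reddening vector, $E(J-H)/E(H-K)$, that `extinction()` obtains with `np.polyfit`.

```python
import numpy as np


def Av_calc(Rv):
    # Copy of the notebook's Av_calc: A_lambda/A_V = a + b/R_V for J, H, K
    a = np.array([0.4008, 0.2693, 0.1615])
    b = np.array([-0.3679, -0.2473, -0.1483])
    return np.round(a + b / Rv, decimals=3)


A_J, A_H, A_K = Av_calc(3.1)
print(A_J, A_H, A_K)  # 0.282 0.19 0.114

# Colour excesses per unit A_V and the slope of the reddening vector in (H-K, J-H)
E_JH = A_J - A_H
E_HK = A_H - A_K
print(round(E_JH / E_HK, 3))  # 1.211
```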

## Extinction coefficient function

In [17]:
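def extinction_per_band(band, X, Av):
    """
    Extinction in a Gaia EDR3 passband
    -------------------------------------------------
    Parameters:
    -----------
    band: str
        'G', 'BP' or 'RP'
        
    X: float or array
        BP-RP colour
        
    Av: float or array
        Extinction A_V
    
    Output:
    -------
    A_band = k*A_V [mag], where k is the polynomial in X and A_V with the
    coefficients from Fitz19_EDR3_MainSequence.csv
    """
    coefficients = QTable.read('Data/Extinction_law_coeff/Fitz19_EDR3_MainSequence.csv', 
                               format='csv', delimiter=',', 
                               names=['a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'a10', 
                                      'Xname', 'Band'])
    
    # Only use the coefficients parametrised by the BP-RP colour
    coefficients = coefficients[coefficients['Xname']=='BPRP']
    
    band_names = {'G': 'kG', 'BP': 'kBP', 'RP': 'kRP'}
    if band not in band_names:
        raise ValueError(f"band must be 'G', 'BP' or 'RP', not {band!r}")
    
    coeff = coefficients[coefficients['Band']==band_names[band]]
    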
    k = coeff['a1'] + coeff['a2']*X + coeff['a3']*X**2 + coeff['a4']*X**3 + \
        coeff['a5']*Av + coeff['a6']*Av**2 + coeff['a7']*Av**3 + \
        coeff['a8']*Av*X + coeff['a9']*Av*X**2 + coeff['a10']*X*Av**2
    
    extinction = k*Av
    
    return extinction*u.mag
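The coefficient $k$ in `extinction_per_band` is a cubic polynomial in $X$ and $A_V$ with ten coefficients. The sketch below evaluates the same ten terms, in the same order, for an arbitrary coefficient row and works on arrays. The function name `gaia_k_polynomial` and the coefficient values are hypothetical placeholders, not the values in `Fitz19_EDR3_MainSequence.csv`.

```python
import numpy as np


def gaia_k_polynomial(a, X, Av):
    """Evaluate k = A_band/A_V from coefficients a1..a10 for colour X and extinction Av."""
    a = np.asarray(a, dtype=float)
    X = np.asarray(X, dtype=float)
    Av = np.asarray(Av, dtype=float)
    # Terms multiplying a1..a10, in the same order as in extinction_per_band
    terms = [np.ones_like(X * Av), X, X**2, X**3,
             Av, Av**2, Av**3,
             Av * X, Av * X**2, X * Av**2]
    return sum(ai * t for ai, t in zip(a, terms))


# Hypothetical coefficients (NOT the Fitz19 values): only a1 and a2 are non-zero
a_demo = [0.9, -0.1, 0, 0, 0, 0, 0, 0, 0, 0]
k = gaia_k_polynomial(a_demo, X=np.array([0.5, 1.0, 2.0]), Av=1.0)
A_band = k * 1.0  # A_band = k * A_V
print(k)  # approximately [0.85, 0.8, 0.7]
```

Reading the coefficient file once and passing the selected row in would also avoid re-reading the CSV on every call of `extinction_per_band`.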
    

## Estimating density

In [157]:
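def density(ra, dec, cluster_distance, cluster_number, 
            plot_centre=True, save_plot_centre=False):
    """
    Estimates the projected stellar density of a cluster
    -------------------------------------------------
    Parameters:
    -----------
    ra, dec: Quantity arrays
        Star coordinates [deg]
        
    cluster_distance: Quantity array
        Distances of the cluster stars [pc]; the first entry is used as the cluster distance
        
    cluster_number: int or float
        Cluster name/number used in the plot title and file name
        
    plot_centre, save_plot_centre: bool
        Plot (and save) the stars, the K-means centroid and the 90/75/50% radii
    
    Output:
    -------
    Array with the surface densities [stars/pc^2] inside the radii enclosing
    90%, 75% and 50% of the stars
    """
    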
    ra = ra[:, np.newaxis]
    dec = dec[:, np.newaxis]
    pos = np.concatenate([ra, dec], axis=1)
    
    d_cl = cluster_distance[0]
    
    estimate = KMeans(1, n_init=100) # Only want centroid of one cluster
    estimate.fit(pos)
    col_kmeans = estimate.predict(pos) # Predicts closest cluster that each sample in pos belongs to
    centroid_x = estimate.cluster_centers_[0][0]
    centroid_y = estimate.cluster_centers_[0][1]
    centroid = SkyCoord(ra=centroid_x, dec=centroid_y, frame='icrs', unit='deg')
    pos_stars = SkyCoord(ra=pos[:, 0], dec=pos[:, 1], frame='icrs', unit='deg')
        
    
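    # Separations from the centroid in the flat-sky approximation
    # (the RA offset is not multiplied by cos(dec))
    diffs_x = np.abs(pos[:, 0] - centroid_x*u.deg)
    diffs_y = np.abs(pos[:, 1] - centroid_y*u.deg)
    r_diffs = np.sqrt(diffs_x**2 + diffs_y**2) # deg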
    max_r = np.max(r_diffs) # deg
    max_r_rad = max_r.value * np.pi/180 * u.rad
    max_r_pc = cluster_distance*np.tan(max_r_rad)
    
    r_range = np.linspace(0, max_r, 1000) # deg
    
    n_stars = len(ra)
    
    n_90percent = np.round(0.90*n_stars)
    n_75percent = np.round(0.75*n_stars)
    n_50percent = np.round(0.50*n_stars)
    
    n_stars_per_radius = np.empty(len(r_range))
    
    for i, r in enumerate(r_range):
        n_stars_mask = r_diffs<r
        n_stars_per_radius[i] = len(r_diffs[n_stars_mask])
        
    
    # 90%: radius at which the enclosed star count is closest to 90% of the stars
    # (the largest such radius if several are equally close)
    _90_diff = np.abs(n_stars_per_radius - n_90percent)
    _90_percent_pos = np.where(_90_diff==np.min(_90_diff))[0][-1]
    
    _90_n = n_stars_per_radius[_90_percent_pos]
    r_90 = r_range[_90_percent_pos] #deg
    r_90_pc = d_cl*np.tan(r_90.value * np.pi/180 * u.rad)
    
    
    # 75%
    _75_diff = np.abs(n_stars_per_radius - n_75percent)
    _75_percent_pos = np.where(_75_diff==np.min(_75_diff))[0][-1]
    
    _75_n = n_stars_per_radius[_75_percent_pos]
    r_75 = r_range[_75_percent_pos] #deg
    r_75_pc = d_cl*np.tan(r_75.value * np.pi/180 * u.rad)
    
    
    # 50%
    _50_diff = np.abs(n_stars_per_radius - n_50percent)
    _50_percent_pos = np.where(_50_diff==np.min(_50_diff))[0][-1]
    
    _50_n = n_stars_per_radius[_50_percent_pos]
    r_50 = r_range[_50_percent_pos] #deg
    r_50_pc = d_cl*np.tan(r_50.value * np.pi/180 * u.rad)
    
    
    
    # Surface density inside the 90% radius
    _90percent_area = np.pi*r_90_pc**2 # pc**2
    _90percent_density = _90_n/_90percent_area
    
    
    # Surface density inside the 75% radius
    _75percent_area = np.pi*r_75_pc**2 # pc**2
    _75percent_density = _75_n/_75percent_area
    
    
    # Surface density inside the 50% (half-number) radius
    _50percent_area = np.pi*r_50_pc**2 # pc**2
    _50percent_density = _50_n/_50percent_area
    
    
    if plot_centre:        
        fig, ax = plt.subplots(figsize=(7, 7))
        
        plt.minorticks_on()
        
        ax.scatter(pos[:, 0], pos[:, 1], c=col_kmeans, cmap='rainbow', s=10, alpha=0.5, 
                   label='Cluster stars')
        
        ax.scatter(centroid_x, centroid_y, c='r', marker='+', s=40, label='Centroid')
        full_circle = plt.Circle((centroid_x, centroid_y), radius=r_90.value, fill=False, 
                                 lw=2, color='darkgreen', label='90% radius')
        ax.add_artist(full_circle)
        
        circle_3_4 = plt.Circle((centroid_x, centroid_y), radius=r_75.value, fill=False, 
                                 lw=2, color='limegreen', label='75% radius')
        ax.add_artist(circle_3_4)
        
        half_circle = plt.Circle((centroid_x, centroid_y), radius=r_50.value, fill=False, 
                                 lw=2, color='lime', label='50% radius')
        ax.add_artist(half_circle)
        
        ax.set_xlabel('RA [deg]')
        ax.set_ylabel('Dec [deg]')
        ax.set_title(f'Cluster {cluster_number} with centroid')
        
        ax.set_xlim(xmin=centroid_x-max_r.value-1, xmax=centroid_x+max_r.value+1)
        ax.set_ylim(ymin=centroid_y-max_r.value-1, ymax=centroid_y+max_r.value+1)
        ax.grid(True, which='both')
        
        ax.legend()
        
        if save_plot_centre:
            plt.savefig(f'Plots/Finding_cluster_centre_{cluster_number}.png', bbox_inches='tight')
        
        plt.show()
    
    #if (tot_cluster_density.unit==1/u.pc**2)&(density_3_4.unit==1/u.pc**2)&(half_radius_cluster_density.unit==1/u.pc**2):
    #    return np.array([tot_cluster_density.value, density_3_4.value, half_radius_cluster_density.value])
    #
    #else:
    #    print('Not right and/or same unit!')
    
    densities_array = np.array([_90percent_density.value, 
                                _75percent_density.value,
                                _50percent_density.value], dtype=object)
    
    return densities_array.flatten()
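`density()` measures separations as $\sqrt{\Delta\mathrm{RA}^2 + \Delta\mathrm{Dec}^2}$ on the raw coordinates, which overestimates RA offsets away from the equator, and finds the 90/75/50% radii by scanning 1000 trial radii. A minimal sketch of an alternative, assuming plain float arrays in degrees and a distance in pc: great-circle separations via the haversine formula, and the radius read directly from the sorted separations. All function names are hypothetical; astropy's `SkyCoord.separation`, applied to the `pos_stars` and `centroid` objects that `density()` already builds, would give the same great-circle separations. The sorted-separation radius is that of the n-th closest star, so it can differ slightly from the grid-based value.

```python
import numpy as np


def angular_sep_deg(ra1, dec1, ra2, dec2):
    """Great-circle separation in degrees (haversine formula); inputs in degrees."""
    ra1, dec1, ra2, dec2 = map(np.radians, (ra1, dec1, ra2, dec2))
    hav = (np.sin((dec2 - dec1) / 2) ** 2
           + np.cos(dec1) * np.cos(dec2) * np.sin((ra2 - ra1) / 2) ** 2)
    return np.degrees(2 * np.arcsin(np.sqrt(hav)))


def radius_enclosing(sep_deg, fraction):
    """Separation of the n-th closest star, n = round(fraction * N); returns (radius, n)."""
    sep_sorted = np.sort(np.asarray(sep_deg, dtype=float))
    n = max(int(np.round(fraction * len(sep_sorted))), 1)
    return sep_sorted[n - 1], n


def surface_density(n, r_deg, distance_pc):
    """Number of stars per pc^2 inside the projected radius r_deg at distance_pc."""
    r_pc = distance_pc * np.tan(np.radians(r_deg))
    return n / (np.pi * r_pc ** 2)


# At dec = 60 deg, an RA offset of 1 deg is only about 0.5 deg on the sky
print(angular_sep_deg(10.0, 60.0, 11.0, 60.0))

# Ten stars at 0.1, 0.2, ..., 1.0 deg from the centre
seps = np.arange(1, 11) * 0.1
r_50, n_50 = radius_enclosing(seps, 0.5)
print(r_50, n_50)  # 0.5 5
print(surface_density(n_50, r_50, distance_pc=150.0))
```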

--------------------------------------------------------------------------------------------------------------

# Testing

## Query testing

from astroquery.vizier import Vizier

catalog = 'J/A+A/664/A175/table4'
columns = ['Plx', 'e_Plx', 'Gmag', 'BPmag', 'RPmag']
Vizier.ROW_LIMIT = -1 # No row limit: download the whole table
catalogue = Vizier.get_catalogs(catalog)[0]

cluster_table = QTable([catalogue['GaiaEDR3'], catalogue['Plx'], catalogue['e_Plx'], catalogue['Gmag'], 
                        catalogue['BPmag'], catalogue['RPmag'], catalogue['Cluster'], catalogue['_RA.icrs'], 
                        catalogue['_DE.icrs']], 
                        names = ['GaiaID', 'Parallax', 'Parallax_error', 'M_apparent', 'G_bp', 'G_rp', 
                               'Cluster_number', 'RA_ICRS', 'DE_ICRS'])

cluster_table['bp_rp'] = cluster_table['G_bp'] - cluster_table['G_rp'] 

cluster_table

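# Fractional parallax errors and the stars with errors above 20%
error_sizes = cluster_table['Parallax_error']/cluster_table['Parallax']

bigger_errors_mask = error_sizes>0.2

big_errors = error_sizes[bigger_errors_mask]

# Percentage and number of stars with fractional parallax errors above 20%
print(len(big_errors)/len(error_sizes) * 100)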
print(len(big_errors))
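The cell above counts stars whose fractional parallax error exceeds 20%. The same numbers follow directly from the boolean mask; a minimal sketch with made-up parallaxes in mas. Inverting the parallax to get a distance is only a reasonable estimate when the fractional error is small, which is one common motivation for a cut like this.

```python
import numpy as np

# Hypothetical parallaxes and errors in mas (illustration only)
parallax = np.array([5.0, 2.0, 1.0, 0.5])
parallax_error = np.array([0.1, 0.2, 0.3, 0.2])

frac_err = parallax_error / parallax
large = frac_err > 0.2

# np.mean of a boolean mask is the fraction of True values
percent_large = 100 * np.mean(large)
print(percent_large, np.count_nonzero(large))  # 50.0 2

# Naive inverse-parallax distances [pc], kept only where the error is small
dist_pc = 1000.0 / parallax[~large]
print(dist_pc)  # [200. 500.]
```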

clusters_sep, clusters_names = cluster_list(cluster_table, 1, 'gaia', None, 5000)
print(len(clusters_sep))
print(len(clusters_names))

#del [clusters_sep[499]]
#del [clusters_names[499]]

#del [clusters_sep[310]] # Problems for MIST
#del [clusters_names[310]]

#del [clusters_sep[76]] # Problems for MIST
#del [clusters_names[76]]

#del [clusters_sep[71]] # Problems for Baraffe
#del [clusters_names[71]]
clusters_names = np.array([float(name) for name in clusters_names])

print(clusters_names)

%%time

test_log_counts, test_log_edges, test_masses, test_ages, same_as_first_iso, cluster_names_new = final_IMFs(cluster_data=clusters_sep[181:182], cluster_names=clusters_names[181:182], model='Baraffe', 
                                                                                                         metallicity='0.00', bin_width=0.2, age_fit_plot=False, 
                                                                                                         chi_plot=False, save_chi_plot=False, plot_hists=False, save_plot_hists=False) 

print('Done')

#cluster_tmass_data=tmass_clusters_sep

%%time
tmass_cluster_data = QTable.read('Filtered_tmass_cluster_data.csv', format='csv', delimiter=',')

print(len(tmass_cluster_data))

tmass_sorted_data, cluster_data_short = sorting_2mass_data(tmass_cluster_data)

tmass_clusters_sep, tmass_cluster_names = cluster_list(tmass_sorted_data, 1, '2mass', None, 5000)


#del [tmass_clusters_sep[310]] # Problems for MIST

#del [tmass_clusters_sep[76]] # Problems for MIST

#del [tmass_clusters_sep[71]] # Problems for Baraffe

lupus_tmass_data = tmass_clusters_sep[181]

test_iso_sep, test_iso_ages = separate_isochrones('Baraffe', metallicity='0.00', survey='2mass')

# Best fitting isochrone for Lupus
iso_j_h = test_iso_sep[7][:, 2] - test_iso_sep[7][:, 3]
iso_h_k = test_iso_sep[7][:, 3] - test_iso_sep[7][:, 4]

iso_hk_mask = (0.18<=iso_h_k)&(iso_h_k<=0.28)

iso_hk = iso_h_k[iso_hk_mask]
iso_jh = iso_j_h[iso_hk_mask]

k_fit, m_fit = np.polyfit(iso_hk, iso_jh, deg=1)

# Straight line approximating the isochrone (the "isochrone line" used in extinction())

def iso_line(x, k, m):
    #m=1.085
    #k=-1.9
    return k*x+m

x_vals = np.linspace(-1, 2, 1000)
y_vals = iso_line(x_vals, k_fit, m_fit)

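# A_lambda/A_V for J, H, K (same values as in extinction())
A_J = 0.282
A_H = 0.190
A_K = 0.114

# Colour excesses E(J-H) and E(H-K) for A_V = 1
j_h_ext = A_J - A_H
h_k_ext = A_H - A_K

len_ext_vector = np.sqrt(j_h_ext**2 + h_k_ext**2)

# Slope (and ~zero intercept) of the reddening vector
k_vec, m_vec = np.polyfit(np.array([0, h_k_ext]), np.array([0, j_h_ext]), deg=1)

# Intercepts of the reddening lines through the isochrone line at H-K = 0.16 and 0.33
m_upper_line = (k_fit*0.16 + m_fit) - k_vec*0.16
m_lower_line = (k_fit*0.33 + m_fit) - k_vec*0.33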

#x_vec_vals = np.linspace(0.18, 1, 100)

#ext_vector_1magn = k_vec*x_vec_vals + m_vec

upper_line = x_vals*k_vec + m_upper_line

lower_line = x_vals*k_vec + m_lower_line
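The two boundary lines split the stars above the isochrone line into the three groups that `extinction()` treats differently: stars inside the band are dereddened individually, while stars to the left or right of it get the band's mean extinction. A minimal sketch of that classification with the same inequalities; `k_iso` and `m_iso` are made-up example values (not the Lupus fit) and the function `classify` is hypothetical.

```python
import numpy as np

# Hypothetical straight-line isochrone fit and the A_V = 1 reddening slope
k_iso, m_iso = -1.0, 1.0
k_ext = (0.282 - 0.190) / (0.190 - 0.114)

# Intercepts of the reddening lines through the isochrone line at H-K = 0.16 and 0.33
m_upper = (k_iso * 0.16 + m_iso) - k_ext * 0.16
m_lower = (k_iso * 0.33 + m_iso) - k_ext * 0.33


def classify(hk, jh):
    """Label each star 'band', 'left', 'right' or 'unreddened'."""
    hk = np.asarray(hk, dtype=float)
    jh = np.asarray(jh, dtype=float)
    above_iso = jh > k_iso * hk + m_iso
    upper_y = k_ext * hk + m_upper
    lower_y = k_ext * hk + m_lower

    labels = np.full(hk.shape, 'unreddened', dtype=object)
    labels[above_iso & (jh < upper_y) & (jh > lower_y)] = 'band'
    labels[above_iso & (jh > upper_y)] = 'left'
    labels[above_iso & (jh < lower_y)] = 'right'
    return labels


# Points on the line at H-K = 0.25, 0.05 and 0.5, each reddened by A_V = 1,
# plus one point below the line
labels = classify([0.326, 0.126, 0.576, 0.3], [0.842, 1.042, 0.592, 0.5])
print(list(labels))  # ['band', 'left', 'right', 'unreddened']
```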

print(tmass_clusters_sep[181].columns)

lupus_corrected = extinction(k_fit, m_fit, lupus_tmass_data)

print(len(lupus_corrected))
print(len(tmass_clusters_sep[181]))

fig, ax = plt.subplots(figsize=(7, 6))

plt.minorticks_on()

#ax.scatter(lupus_corrected['H-K'], lupus_corrected['J-H'], c='b', s=10, alpha=0.5, 
#           label='Cluster data')

# extinction() modifies lupus_tmass_data in place, so these are the corrected colours
ax.scatter(lupus_tmass_data['H-K'], lupus_tmass_data['J-H'], c='b', s=10, alpha=0.5, 
           label='Cluster data')

ax.plot(iso_h_k, iso_j_h, color='r', label='Best isochrone', lw=2, marker='*')

ax.plot(x_vals, y_vals, color='g', label='Isochrone line fit', lw=2, linestyle='dashed')

ax.plot(x_vals, upper_line, color='m', lw=2, linestyle='dashed', label='Reddening band boundaries')
ax.plot(x_vals, lower_line, color='m', lw=2, linestyle='dashed')

ax.arrow(0.18, 0.78, h_k_ext, j_h_ext, color='orange', width=0.01, label=r'$A_V$=1 mag')

#ax.scatter(exts_lupus['H-K'], exts_lupus['J-H'], c='m', s=20, alpha=0.5, label='Corrected stars')

ax.set_xlabel('H - K')
ax.set_ylabel('J - H')
ax.set_title('Extinction Correction After')

ax.grid(True, which='both')
ax.legend()

ax.set_xlim(-1, 2) # -1, 2
ax.set_ylim(-0.5, 2) # -0.5, 2

#plt.savefig('Plots/Extinction_correction_demonstration_After.png')
plt.show()

print(cluster_names_new)

import warnings
warnings.filterwarnings('ignore')

%%time
densities = np.empty((len(clusters_sep), 3))

for i, cluster in enumerate(clusters_sep):
    print(i)
    densities_i = density(clusters_sep[i]['RA_ICRS'], clusters_sep[i]['DE_ICRS'],
                              clusters_sep[i]['dist'], clusters_names[i], False, False)
    
    
    #print(type(densities_i))
    #print(densities_i)
    densities[i, :] = densities_i

print(densities[13, 0])

densities_test = density(clusters_sep[181]['RA_ICRS'], clusters_sep[181]['DE_ICRS'], clusters_sep[181]['dist'], 
                         clusters_names[181], plot_centre=True, save_plot_centre=False)

print(densities_test)

print(len(clusters_sep))
#print(clusters_sep[181]['dist'][0])
#print(clusters_sep[181]['dist_error'][0])

cluster1_dist = clusters_sep[603]['dist']
mean_cl1_dist = np.mean(clusters_sep[603]['dist'])


#print(cluster1_dist[cluster1_dist.argsort()])
print(mean_cl1_dist)

### Testing first function

In [16]:
#times_str = ['0_0005', '0_0010', '0_0020', '0_0030', '0_0040', '0_0050', '0_0080', '0_01', '0_02', 
#         '0_03', '0_04', '0_05', '0_08', '0_1', '0_2', '0_3', '0_4', '0_5', '0_625', '0_8', 
#         '1', '2', '3', '4', '5', '8', '10']
#
#times_num = [0.0005, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0080, 0.01, 0.02, 0.03, 0.04, 0.05, 
#             0.08, 0.1, 0.2, 0.3, 0.4, 0.5, 0.625, 0.8, 1, 2, 3, 4, 5, 8, 10]
#
#isochrones = []
#for i, time in enumerate(times_str):
#    isochrones.append(isochrone_import(f'Isochrone_Baraffe_{time}.txt', 'Baraffe'))
#    
#
#print(len(isochrones))

In [17]:
#lupus_data = cluster_import('Lupus_data')
#NGC2264_data = cluster_import('NGC2264_data.txt', 'Clusters')
#
#data_tot = [lupus_data, NGC2264_data]

In [18]:
#plotting_iso_and_data(data_tot, isochrones[5:7], times_num[5:7], ['Lupus','NGC2264'])

### Testing IMF from interpolation of colours

In [19]:
#lupus_data = cluster_import('Lupus_data')


#times_str = ['0_0005', '0_0010', '0_0020', '0_0030', '0_0040', '0_0050', '0_0080', '0_01', '0_02', 
#         '0_03', '0_04', '0_05', '0_08', '0_1', '0_2', '0_3', '0_4', '0_5', '0_625', '0_8', 
#         '1', '2', '3', '4', '5', '8', '10']

#times_num = [0.0005, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0080, 0.01, 0.02, 0.03, 0.04, 0.05, 
#             0.08, 0.1, 0.2, 0.3, 0.4, 0.5, 0.625, 0.8, 1, 2, 3, 4, 5, 8, 10]

#isochrones = []
#for i, time in enumerate(times_str):
#    isochrones.append(isochrone_import(f'Isochrone_Baraffe_{time}.txt', 'Baraffe'))
#    
#fit_params = iso_fit(isochrones)
#print(fit_params.shape)
#print(len(isochrones))
#print(isochrones[0].shape)

In [20]:
#isochrone_parameters = isochrone_params(times_str, 'Baraffe')

#print(isochrone_parameters.shape)

In [21]:
#to_interpolate = lupus_data['bp_rp'].value

#masses = IMF(to_interpolate, isochrones[6], cluster_name='Lupus', model_name='Baraffe', 
#             check=True, save_plot=False)

#min_mass = np.min(masses)
#max_mass = np.max(masses)

## Trying to fit the IMF and get parameters

In [85]:
# Lupus has number 181
#test_data = clusters_sep[0:10] #cluster_import('Lupus_data') 
#test_names = clusters_names[0:10]

#print(type(clusters_sep[0:2]))
#print(test_data)

#print(test_data[0]['Cluster_number']) # Works!

%%time
tmass_cluster_data = QTable.read('Filtered_tmass_cluster_data.csv', format='csv', delimiter=',')

print(len(tmass_cluster_data))

tmass_sorted_data, cluster_data_short = sorting_2mass_data(tmass_cluster_data)

tmass_clusters_sep, tmass_cluster_names = cluster_list(tmass_sorted_data, 1, '2mass', None, 5000)


#del [tmass_clusters_sep[310]] # Problems for MIST

#del [tmass_clusters_sep[76]] # Problems for MIST

#del [tmass_clusters_sep[71]] # Problems for Baraffe

print(len(tmass_clusters_sep))
print(len(clusters_sep))

print(len(tmass_clusters_sep[17]))
print(len(clusters_sep[17]))
#print(tmass_clusters_sep[17].columns)

inters, gids, tids = np.intersect1d(clusters_sep[17]['GaiaID'].value, tmass_clusters_sep[17]['GaiaID'].value,
                                   return_indices=True)

print(len(gids))

# Check that the Gaia and 2MASS cluster lists are in the same order
for i in range(len(clusters_sep)):
    num_g = clusters_sep[i]['Cluster_number'][0]
    num_t = tmass_clusters_sep[i]['Cluster_number'][0]
    if num_g!=num_t:
        print(f'Not aligned: cluster {i+1}')
        
    intersection, ginds, tinds = np.intersect1d(clusters_sep[i]['GaiaID'].value, 
                                              tmass_clusters_sep[i]['GaiaID'].value,
                                              return_indices=True)
    
    # ginds and tinds point to the same common IDs, so g_ids == t_ids by construction
    g_ids = clusters_sep[i]['GaiaID'][ginds]
    t_ids = tmass_clusters_sep[i]['GaiaID'][tinds]
    
    if all(g_ids==t_ids):
        #print(f'It is fine: cluster {i+1}')
        continue
    else:
        print(f'Something is wrong: cluster {i+1}')
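`np.intersect1d(..., return_indices=True)` returns, for each input, the indices of the common values, so indexing both inputs with them always yields identical IDs; in the loop above, only the cluster-number comparison can actually fail. A sketch of a more informative membership check with `np.setdiff1d`, using hypothetical Gaia IDs for one cluster:

```python
import numpy as np

# Hypothetical Gaia IDs of one cluster in the Gaia table and in the 2MASS table
gaia_ids = np.array([101, 102, 103, 104, 105])
tmass_ids = np.array([105, 101, 103])

common, g_idx, t_idx = np.intersect1d(gaia_ids, tmass_ids, return_indices=True)

# The indexed IDs always agree, so comparing them cannot reveal a mismatch
assert np.array_equal(gaia_ids[g_idx], tmass_ids[t_idx])

# More informative: which IDs appear in only one of the tables
only_gaia = np.setdiff1d(gaia_ids, tmass_ids)
only_tmass = np.setdiff1d(tmass_ids, gaia_ids)
print(common, only_gaia, only_tmass)  # [101 103 105] [102 104] []
```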

print(len(clusters_sep))
print(len(clusters_names))

%%time

test_log_counts, test_log_edges, test_masses, test_ages, same_as_first_iso, cluster_names = final_IMFs(cluster_data=clusters_sep, cluster_names=clusters_names, model='MIST', 
                                                                                                         metallicity='0.00', bin_width=0.2, age_fit_plot=False, 
                                                                                                         chi_plot=False, save_chi_plot=False, plot_hists=False, save_plot_hists=False) 

print('Done')

#cluster_tmass_data=tmass_clusters_sep

print(len(test_masses))
print(len(test_log_counts))
print(len(test_log_edges))
print(len(test_ages))
print()

print(test_ages)

extinctions_data = QTable.read('Extinctions_file.csv', format='csv', delimiter=',')

extinctions_data

#print(test_ages) # Baraffe isochrone number 7
print(len(same_as_first_iso))
#print(same_as_first_iso)

## IMPORTANT 

extinctions_data = extinctions_data[extinctions_data['source_id'].argsort()]
cluster_table = cluster_table[cluster_table['GaiaID'].argsort()]

same_stars, gaia_inds, ext_inds = np.intersect1d(cluster_table['GaiaID'].value, 
                                               extinctions_data['source_id'].value, 
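                                               return_indices=True)

# Keep only the extinction rows that have a Gaia match, so that they line up with gaia_inds
extinctions_data = extinctions_data[ext_inds]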
extinctions_data['Cluster_number'] = cluster_table[gaia_inds]['Cluster_number']
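The `Cluster_number` column taken from `cluster_table[gaia_inds]` has one entry per matched star, so it only lines up row by row with extinction rows that have been selected with the matching `ext_inds`. A sketch of that pattern, with plain NumPy arrays as hypothetical stand-ins for the two tables:

```python
import numpy as np

# Hypothetical stand-ins for the two tables, both sorted by ID
gaia_id = np.array([10, 20, 30, 40])
gaia_cluster = np.array([1, 1, 2, 2])
ext_id = np.array([5, 20, 30, 50])        # 5 and 50 have no Gaia match
ext_ag = np.array([0.9, 0.1, 0.2, 0.8])

_, g_idx, e_idx = np.intersect1d(gaia_id, ext_id, return_indices=True)

# Subset both sides with the returned indices so that the rows line up
ext_id_matched = ext_id[e_idx]
ext_ag_matched = ext_ag[e_idx]
ext_cluster = gaia_cluster[g_idx]

print(ext_id_matched, ext_cluster)  # [20 30] [1 2]
```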

unique_numbers = np.unique(extinctions_data['Cluster_number'].value)
#print(unique_numbers)
print(len(unique_numbers))


extinctions_data_sep, names = cluster_list(extinctions_data, 1, None, None)

#print(tmass_clusters_sep[405].columns)

print(len(extinctions_data_sep))

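clusters = []

for extinction_stars in extinctions_data_sep:
    clusters.append(extinction_stars['Cluster_number'][0])

# Array of cluster numbers, so that clusters==which_cluster compares element-wise below
clusters = np.array(clusters)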

In [158]:
#del [tmass_clusters_sep[405]]

## IMPORTANT 

A_G_diffs = []
A_BP_diffs = []
A_RP_diffs = []

for i, cluster in enumerate(tmass_clusters_sep):
    which_cluster = cluster['Cluster_number'][0]
    pos = np.where(clusters==which_cluster)[0][0]
    
    # Extinction table of the same cluster
    extinction_data = extinctions_data_sep[pos]
    
    common_stars, gaia_inds, ext_inds = np.intersect1d(cluster['GaiaID'].value, 
                                                       extinction_data['source_id'].value, 
                                                       return_indices=True)
    
    # Our extinctions minus the catalogue's ag50/abp50/arp50 (median) values
    A_G_diffs.append(cluster[gaia_inds]['A_G'] - extinction_data[ext_inds]['ag50'])
    A_BP_diffs.append(cluster[gaia_inds]['A_BP'] - extinction_data[ext_inds]['abp50'])
    A_RP_diffs.append(cluster[gaia_inds]['A_RP'] - extinction_data[ext_inds]['arp50'])

print(len(A_G_diffs[200]))
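The loop above locates each cluster's extinction table with `np.where(clusters==which_cluster)`, which scans all clusters once per cluster and only compares element-wise when `clusters` is a NumPy array (or `which_cluster` is a NumPy scalar). A sketch of a dictionary lookup built once, with hypothetical cluster numbers; float cluster names work as keys too, since a NumPy float hashes like the equal Python float.

```python
import numpy as np

# Hypothetical cluster numbers, one per table in a list like extinctions_data_sep
clusters = [12, 3, 7]

# Build the lookup once: cluster number -> position in the list
position_of = {number: pos for pos, number in enumerate(clusters)}
print(position_of[7], position_of[12])  # 2 0

# NumPy equivalent; clusters must be an array for == to compare element-wise
clusters_arr = np.asarray(clusters)
print(np.flatnonzero(clusters_arr == 7)[0])  # 2
```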

clusters_names_float = [float(x) for x in clusters_names]

#print(clusters_names_float)

fig10, ax10 = plt.subplots(3, 1, figsize=(15, 15))

band_labels = [r'$A_G$', r'$A_{BP}$', r'$A_{RP}$']
band_diffs = [A_G_diffs, A_BP_diffs, A_RP_diffs]

for axis, band_label in zip(ax10, band_labels):
    axis.set_xlabel('Cluster name/number')
    axis.set_ylabel('Extinction difference')
    axis.set_title(f'Extinction difference {band_label}')
    axis.grid(True)
    axis.set_xlim(-5, 680)

# One column of points per cluster, placed at the cluster's name/number
for i in range(len(A_G_diffs)):
    name_array = np.full(len(A_G_diffs[i]), clusters_names_float[i])
    
    for axis, diffs in zip(ax10, band_diffs):
        axis.scatter(name_array, diffs[i], s=20, alpha=0.5)


plt.tight_layout()
plt.show()

In [105]:
#print(A_G_diffs[0])
#print(np.min(A_G_diffs[0]))
#print('-----------------------------------------------------------------------')
#print(A_BP_diffs[0])
#print(np.min(A_BP_diffs[0]))
#print('-----------------------------------------------------------------------')
#print(A_RP_diffs[0])
#print(np.min(A_RP_diffs[0]))

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

ind = np.arange(len(arr))

print(arr)
print(ind)

In [84]:
#for i in range(len(lupus_masses)):
#    print(len(lupus_masses[i]))

#print(test_ages)

In [443]:
#print(len(log_lupus_edges[1]))

### Saving files for testing

In [364]:
#np.save('failed_cluster_counts', log_lupus_counts[:, 2])
#np.save('failed_cluster_edges', log_lupus_edges[:, 2])

In [346]:
#print(np.shape(log_lupus_counts))

#print(type(log_lupus_counts[:, -1]))

#if any(log_lupus_counts[:, -1]==0.):
#    print('Yes')
    
#new_log_counts = [counts for counts in log_lupus_counts[:, -1] if counts!=0.]

#print(new_log_counts)

## IMPORTANT 

%%time

bin_widths_test, all_cluster_params_test, kroupa_diff_test, counter = IMF_slopes(test_log_edges, test_log_counts, 
                                                                         test_masses, model='Baraffe', intervals='Kroupa', 
                                                                         plot=True)

print(kroupa_diff_test)

In [36]:
#print(len(bin_widths_test))
#print(len(all_cluster_params_test))
#print(len(kroupa_diff_test))
#print(counter)

In [248]:
#print()
#print(f'Widths = {bin_widths_test}')
#print()
#print(f'Slopes for all clusters for all intervals = {all_cluster_params_test[:, 0, :]}')
#print()
#print(f'Difference from Kroupa = {kroupa_diff_test}')
#print()
#print(f'Number of clusters that had to cut interval bc of zero-bins = {counter}')

### Setting up middle of bins and checking bin widths

In [146]:
# Method for calculating the widths without a for-loop

#arr1 = np.array([1, 2, 3, 4, 5])
#arr2 = np.roll(arr1, shift=1)
#arr = arr2-arr1
#arr_tot = np.delete(arr, arr==arr[0])

#if all(arr_tot==-1):
#    print('yes')

#print(arr)
#print(arr_tot)



# Method for finding true values in mask:

#arr1 = np.array([1, 2, 2, 3, 4, 5, 2, 2, 6, 7])

#arr_mask = arr1==2

#if len(arr_mask[arr_mask==True])==4:
#    print('Yes')

In [19]:
#n_bins = len(log_lupus_edges) - 1 

#bin_width = np.empty((n_bins)) # In log(solar_masses)
#bin_pos = np.empty((n_bins)) # In solar_masses

#for i in range(n_bins):
#    bin_pos[i] = (log_lupus_edges[i] + log_lupus_edges[i+1])/2
#    bin_width[i] = np.abs(np.log10(log_lupus_edges[i]) - np.log10(log_lupus_edges[i+1]))



#print(bin_pos, 'solMass')
#print()
#print(bin_width, 'dex')
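For logarithmically spaced bins, both quantities in the two cells above can be computed without a loop: widths in dex with `np.diff` on the log10 edges (which also makes the equal-width check a single `np.allclose`), and centres as the geometric mean of neighbouring edges, i.e. the midpoint in log space. A sketch with illustrative 0.2-dex edges, not the Lupus edges:

```python
import numpy as np

# Logarithmically spaced mass bin edges [M_sun]; 16 edges over 3 dex give 0.2-dex bins
edges = np.logspace(np.log10(0.01), np.log10(10.0), 16)

# Bin widths in dex without a for-loop, and a check that they are all equal
widths_dex = np.diff(np.log10(edges))
print(np.allclose(widths_dex, widths_dex[0]), round(widths_dex[0], 3))  # True 0.2

# Bin centres in log space (geometric mean of neighbouring edges)
centres = np.sqrt(edges[:-1] * edges[1:])
print(centres[:3])
```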


### Dividing up into mass intervals

In [343]:
#interval1_mask = (0.01<=bin_pos)&(bin_pos<=0.08) # bin_pos<=0.1
#interval2_mask = (0.08<=bin_pos)&(bin_pos<=0.5) #0.1<=bin_pos  
#interval3_mask = (0.5<=bin_pos)#&(bin_pos<1)
#interval4_mask = 1<=bin_pos

#mass_int1 = bin_pos[interval1_mask]
#mass_int2 = bin_pos[interval2_mask]
#mass_int3 = bin_pos[interval3_mask]
#mass_int4 = bin_pos[interval4_mask]

#print(mass_int2)

#count_int1 = log_lupus_counts[interval1_mask]
#count_int2 = log_lupus_counts[interval2_mask]
#count_int3 = log_lupus_counts[interval3_mask]
#count_int4 = log_lupus_counts[interval4_mask]


#int_masses = [mass_int1, mass_int2, mass_int3]#, mass_int4] #  
#int_counts = [count_int1, count_int2, count_int3]#, count_int4] # 

#print(len(bin_pos))
#print(len(mass_int1) + len(mass_int2) )
#print(len(count_int1) + len(count_int2) )

#print(int_masses[0])
#print(type(int_masses[0][0]))
#print()
#print(np.shape(int_counts[0]))

#flat_arr = int_counts[0].flatten()
#print(flat_arr)
#print(type(flat_arr[0]))
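The interval limits in the masks above (0.01, 0.08 and 0.5 M$_{\odot}$) are the break masses of the Kroupa (2001) IMF, $\xi(m) = dN/dm \propto m^{-\alpha}$ with $\alpha = 0.3$, $1.3$ and $2.3$; in Kroupa (2001) the segment above 1 M$_{\odot}$ also has $\alpha = 2.3$, so it is folded into the last interval here. A sketch of the continuous piecewise IMF and of assigning bin centres to intervals with `np.searchsorted`; all names are hypothetical. Since counts per logarithmic mass bin follow $dN/d\log m \propto m\,\xi(m)$, the slopes expected in a log-binned histogram are $1 - \alpha$, which may be worth keeping in mind when comparing fitted slopes with Kroupa values.

```python
import numpy as np

# Kroupa (2001): xi(m) = dN/dm proportional to m**(-alpha) on each interval [M_sun]
breaks = np.array([0.01, 0.08, 0.5])   # lower edge of each interval
alphas = np.array([0.3, 1.3, 2.3])

# Constants that make the segments join continuously at 0.08 and 0.5 M_sun
C = np.ones(len(alphas))
for i in range(1, len(alphas)):
    C[i] = C[i - 1] * breaks[i] ** (alphas[i] - alphas[i - 1])


def interval_index(m):
    """Index of the Kroupa interval containing m (masses below 0.01 go to interval 0)."""
    idx = np.searchsorted(breaks, np.asarray(m, dtype=float), side='right') - 1
    return np.clip(idx, 0, len(alphas) - 1)


def kroupa_xi(m):
    """Unnormalised, continuous Kroupa (2001) IMF dN/dm."""
    m = np.asarray(m, dtype=float)
    idx = interval_index(m)
    return C[idx] * m ** (-alphas[idx])


# Assigning bin centres to intervals without separate masks
bin_pos = np.array([0.03, 0.1, 0.3, 0.7, 2.0])
print(interval_index(bin_pos))  # [0 1 1 2 2]

# Slopes expected for counts per logarithmic mass bin
print(1 - alphas)  # approximately [0.7, -0.3, -1.3]
```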

### Making fit

In [344]:
#def fcn(x, a, C):
#    return C*x**a


#params = np.empty((2, len(int_counts))) # (parameters, intervals)
#covariance = []
#x_fit_vals = np.empty((500, len(int_counts)))
#fit_vals = np.empty((500, len(int_counts)))
#fit_vals_kroupa = np.empty((500, len(int_counts)))
#a_kroupa = np.array([0.3, 1.3, -2.3, -2.3])
#c_kroupa = np.array([20, 600, 50, 50])
#for i, masses in enumerate(int_masses):
#    flat_counts = int_counts[i].flatten()
#    p, pcov = scopt.curve_fit(fcn, xdata=masses, ydata=flat_counts) # 'pcov' avoids overwriting astropy.constants imported as c
#    params[:, i] = p
#    covariance.append(pcov)
    
#    a_fit, C_fit = p
#    x_fit_vals[:, i] = np.linspace(np.min(masses), np.max(masses), 500)
#    fit_vals[:, i] = fcn(x_fit_vals[:, i], a_fit, C_fit)
#    fit_vals_kroupa[:, i] = fcn(x_fit_vals[:, i], a_kroupa[i], c_kroupa[i])

    
#print(np.log10(np.abs(params[0, :]))) # Prints log10 of the absolute a-values
#print(params[0, :])
#print(params[1, :]) # Prints all C-values

#print(np.shape(fit_vals))
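An alternative to running `curve_fit` on the raw counts is a straight-line fit in log-log space. This approach cannot use empty bins (log10(0) is undefined), which fits with the note at the top about leaving empty bins out. Fitting in log space weights relative deviations rather than absolute counts, so the resulting slopes can differ from the `curve_fit` ones. The sketch below is not the notebook's `IMF_slopes`, and the function name `fit_power_law` is hypothetical.

```python
import numpy as np

def fit_power_law(masses, counts):
    """Fit counts = C * masses**a with a straight line in log10-log10 space.

    Empty bins are dropped first, since log10(0) is undefined.
    Returns (a, C).
    """
    masses = np.asarray(masses, dtype=float)
    counts = np.asarray(counts, dtype=float)
    keep = counts > 0
    if np.count_nonzero(keep) < 2:
        raise ValueError('Need at least two non-empty bins to fit a slope')
    a, log_C = np.polyfit(np.log10(masses[keep]), np.log10(counts[keep]), deg=1)
    return a, 10 ** log_C

# Synthetic counts following 50 * m**-1.3, with one empty bin
m = np.array([0.1, 0.15, 0.2, 0.3, 0.4])
n = 50 * m ** -1.3
n[2] = 0
a, C = fit_power_law(m, n)
print(a, C)
```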

In [345]:
#fig1, ax1 = plt.subplots(1, 2, figsize=(12, 4))

#min_mass = np.min(lupus_masses)
#max_mass = np.max(lupus_masses)

#ax1[0].hist(lupus_masses, bins=n_bins, histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5, 
#            label='IMF hist')
#ax1[0].plot(x_fit_vals[:, 0], fit_vals[:, 0], color='r', label='Fitted Interval 1')
#ax1[0].plot(x_fit_vals[:, 1], fit_vals[:, 1], color='orange', label='Fitted Interval 2')
#ax1[0].plot(x_fit_vals[:, 2], fit_vals[:, 2], color='g', label='Fitted Interval 3')
#ax1[0].plot(x_fit_vals[:, 3], fit_vals[:, 3], color='m', label='Fitted Interval 4')

#ax1[0].set_xlabel(r'Mass [M$_{\odot}$]')
#ax1[0].set_ylabel('Counts')
#ax1[0].set_title(f'Baraffe IMF for Lupus: normal scale')
#ax1[0].grid(True)
#ax1[0].set_xscale('log')



#ax1[1].hist(lupus_masses, range=(np.log10(min_mass), np.log10(max_mass)), 
#            bins=np.logspace(np.log10(min_mass), np.log10(max_mass), n_bins+1), 
#            histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5, 
#            label='IMF hist')
#ax1[1].plot(x_fit_vals[:, 0], fit_vals[:, 0], color='r', label='Fitted Interval 1')
#ax1[1].plot(x_fit_vals[:, 1], fit_vals[:, 1], color='orange', label='Fitted Interval 2')
#ax1[1].plot(x_fit_vals[:, 2], fit_vals[:, 2], color='g', label='Fitted Interval 3')
#ax1[1].plot(x_fit_vals[:, 3], fit_vals[:, 3], color='m', label='Fitted Interval 4')

#ax1[1].plot(x_fit_vals[:, 0], fit_vals_kroupa[:, 0], color='r', linestyle='dashed', label='Kroupa Interval 1')
#ax1[1].plot(x_fit_vals[:, 1], fit_vals_kroupa[:, 1], color='orange', linestyle='dashed', label='Kroupa Interval 2')
#ax1[1].plot(x_fit_vals[:, 2], fit_vals_kroupa[:, 2], color='g', linestyle='dashed', label='Kroupa Interval 3')
#ax1[1].plot(x_fit_vals[:, 3], fit_vals_kroupa[:, 3], color='m', linestyle='dashed', label='Kroupa Interval 4')

#ax1[1].set_xlabel(r'Mass [M$_{\odot}$]')
#ax1[1].set_ylabel('Counts')
#ax1[1].set_title(f'Baraffe IMF for Lupus: log scale')

#ax1[1].set_xscale('log')
#ax1[1].set_yscale('log')
#ax1[1].set_ylim(ymin=0)
#ax1[1].grid(True)
#ax1[1].legend()

#plt.show()
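For the dashed comparison curves, a continuous piecewise Kroupa (2001) IMF can be built from the break masses alone, instead of hand-picked constants such as `c_kroupa`. The sketch below assumes the usual form ξ(m) ∝ m^(-α), with α = 0.3 below 0.08 M$_{\odot}$, α = 1.3 from 0.08 to 0.5 M$_{\odot}$, and α = 2.3 above 0.5 M$_{\odot}$. It fixes each segment's constant so that ξ is continuous at the breaks. The overall normalisation is arbitrary, and the 0.01 M$_{\odot}$ lower limit is not enforced. Written in the `C*x**a` form of `fcn`, these slopes correspond to exponents a = -α for dN/dm. For counts per logarithmic mass bin they become a = 1 - α, because dN/dlog m ∝ m ξ(m).

```python
import numpy as np

# Kroupa (2001) break masses [M_sun] and slopes alpha, with xi(m) ∝ m**(-alpha)
BREAKS = np.array([0.08, 0.5])
ALPHAS = np.array([0.3, 1.3, 2.3])

def kroupa_imf(m):
    """Unnormalised, continuous piecewise power-law Kroupa IMF xi(m)."""
    m = np.asarray(m, dtype=float)
    # Constants chosen so that neighbouring segments agree at each break mass
    k = np.ones(len(ALPHAS))
    for i, mb in enumerate(BREAKS):
        k[i + 1] = k[i] * mb ** (ALPHAS[i + 1] - ALPHAS[i])
    # Segment index: 0 for m < 0.08, 1 for 0.08 <= m < 0.5, 2 for m >= 0.5
    idx = np.searchsorted(BREAKS, m, side='right')
    return k[idx] * m ** (-ALPHAS[idx])

masses = np.logspace(-2, 1, 200)
xi = kroupa_imf(masses)
```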

### Fitting isochrones test

In [211]:
#isos = isochrones[7][isochrones[7][:, 0].argsort()] # Sort according to magnitude

#magnitude_mask = isos[:, 0]<16

#iso_x = isos[:, 1][magnitude_mask]
#iso_y = isos[:, 0][magnitude_mask]

#print(iso_x)
#print()

#results, cov = np.polyfit(iso_x, iso_y, 5, cov=True)
#print(results)
#print()
#print(cov)

#k5, k4, k3, k2, k1, k0 = results # k0 rather than c, so astropy.constants (imported as c) is not overwritten

#def iso_fit(x):
#    return k5*x**5 + k4*x**4 + k3*x**3 + k2*x**2 + k1*x**1 + k0

#x_vals = np.linspace(np.min(iso_x), np.max(iso_x), 1000)
#y_vals = iso_fit(x_vals)
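Instead of unpacking the `np.polyfit` result into named coefficients and writing the polynomial out by hand, `np.polyval` can evaluate the coefficient array directly. It also keeps working if the degree changes. The isochrone in this minimal sketch is synthetic (a known cubic in colour), used only to check that the fit reproduces it.

```python
import numpy as np

# Synthetic isochrone: magnitude as a known cubic in colour (hypothetical data)
iso_x = np.linspace(0.5, 3.0, 40)
true_coeffs = [0.1, -0.8, 3.0, 2.0]  # highest power first, as np.polyfit returns them
iso_y = np.polyval(true_coeffs, iso_x)

# Degree-5 fit; np.polyval evaluates a polynomial with any number of coefficients
coeffs = np.polyfit(iso_x, iso_y, deg=5)
x_vals = np.linspace(iso_x.min(), iso_x.max(), 1000)
y_vals = np.polyval(coeffs, x_vals)
```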

In [212]:
#fig2, ax2 = plt.subplots(figsize=(5, 4))

#ax2.scatter(iso_x, iso_y, c='r', s=10, label='Isochrone data')
#ax2.scatter(lupus_data['bp_rp'], lupus_data['M_V'], c='g', s=5, label='Lupus data')
#ax2.plot(isochrones[6][:, 1], isochrones[6][:, 0], color='orange')
#ax2.scatter(isochrones[0][:, 1], isochrones[0][:, 0], c='m', s=10)
#ax2.plot(x_vals, y_vals, 'b', label='Fitted isochrone')

#ax2.set_xlabel('Colour')
#ax2.set_ylabel('Magnitude')

#ax2.axhline(16.5, color='k', linestyle='dashed')
#ax2.legend()

#ax2.grid(True)

#ax2.invert_yaxis()

#plt.show()

### Testing chi-square estimate

In [213]:
#chis = chi_fitting(fit_fcn, lupus_data, isochrone_parameters)

#min_chi = np.min(chis)
#min_pos = np.where(chis==min_chi)[0]

#print(chis)
#print(min_chi)
#print(min_pos)
#print(times_num[7])
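`chi_fitting` is not shown in this section. The sketch below shows one possible way to define a chi-square age estimate: each star's observed magnitude is compared with the isochrone magnitude at that star's own colour, and the isochrone with the smallest sum is chosen. `np.argmin` returns that position directly. The function name `best_isochrone` and the uniform uncertainty `mag_err` are hypothetical.

```python
import numpy as np

def best_isochrone(obs_colour, obs_mag, iso_coeffs, mag_err=0.1):
    """Pick the isochrone with the smallest chi-square against the stars.

    iso_coeffs: list of np.polyfit coefficient arrays, one per isochrone,
    each giving magnitude as a polynomial in colour.
    mag_err: assumed uniform magnitude uncertainty (hypothetical value).
    Returns (index of best isochrone, array of chi-square values).
    """
    obs_colour = np.asarray(obs_colour, dtype=float)
    obs_mag = np.asarray(obs_mag, dtype=float)
    chi2 = np.array([
        np.sum(((obs_mag - np.polyval(coeffs, obs_colour)) / mag_err) ** 2)
        for coeffs in iso_coeffs
    ])
    return int(np.argmin(chi2)), chi2

# Three hypothetical straight-line isochrones, with stars lying on the second one
colour = np.array([1.0, 1.5, 2.0, 2.5])
isos = [[1.0, 4.0], [1.2, 4.5], [0.8, 5.0]]
mag = np.polyval(isos[1], colour)
best_idx, chi2 = best_isochrone(colour, mag, isos)
```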

In [214]:
#fig5, ax5 = plt.subplots(figsize=(5, 4))

#ax5.scatter(np.array(times_num), chis, color='b')

#ax5.set_xlabel('Ages [Gyr]')
#ax5.set_ylabel(r'$\chi^2$ values')
#ax5.grid()
#ax5.set_xscale('log')

#plt.show()

In [215]:
#import scipy.stats as scstat

# x-axis is the colour
#isochrone_x = np.linspace(np.min(lupus_data['bp_rp'].value), 
#                          np.max(lupus_data['bp_rp'].value), 846)

#params = fit_params[:, 10]
#isochrone_y = fit_fcn(isochrone_x, args=params) 

#iso_y = isochrone_y/np.sum(isochrone_y)

# Normalized absolute magnitude
#normalized_lupus_data = lupus_data['M_V'].value/(np.sum(lupus_data['M_V'].value))

#chi_vals = np.empty((27))

#for i in range(27):
#    params = fit_params[:, i]
    
    # y-axis is the absolute magnitude, getting absolute magnitude from colour
#    isochrone_y = fit_fcn(isochrone_x, args=params) 
#    iso_y = isochrone_y/np.sum(isochrone_y) # normalize model data
    
    # Chi-square made for absolute magnitude
#    chi_vals[i], p_value = scstat.chisquare(normalized_lupus_data, iso_y)


    
#min_chi = np.min(chi_vals)
#min_pos = np.where(chi_vals==min_chi)[0]

#print(chi_vals)
#print(min_chi)
#print(min_pos)
#print(chi_vals[6])
#result, p_val = scstat.chisquare(normalized_lupus_data, iso_y)

#print(result)
#print(p_val)

Columns in the Vizier tables and their units:

| Column | Unit |
|---|---|
| GLON | deg |
| GLAT | deg |
| Plx | mas |
| e_Plx | mas |
| pmRA | mas/yr |
| e_pmRA | mas/yr |
| pmDE | mas/yr |
| e_pmDE | mas/yr |
| Gmag | mag |
| BPmag | mag |
| RPmag | mag |
| Flag | |
| Cluster | |
| _RA.icrs | deg |
| _DE.icrs | deg |

### Testing final IMF function

In [85]:
#times_str = ['0_0005', '0_0010', '0_0020', '0_0030', '0_0040', '0_0050', '0_0080', '0_01', '0_02', 
#         '0_03', '0_04', '0_05', '0_08', '0_1', '0_2', '0_3', '0_4', '0_5', '0_625', '0_8', 
#         '1', '2', '3', '4', '5', '8', '10']
#
#times_num = [0.0005, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0080, 0.01, 0.02, 0.03, 0.04, 0.05, 
#             0.08, 0.1, 0.2, 0.3, 0.4, 0.5, 0.625, 0.8, 1, 2, 3, 4, 5, 8, 10]
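The numeric ages can be derived from the file-name strings instead of being typed out twice. This sketch assumes, as holds for every entry in this list, that each underscore stands for a decimal point.

```python
times_str = ['0_0005', '0_0010', '0_0020', '0_0030', '0_0040', '0_0050', '0_0080', '0_01', '0_02',
             '0_03', '0_04', '0_05', '0_08', '0_1', '0_2', '0_3', '0_4', '0_5', '0_625', '0_8',
             '1', '2', '3', '4', '5', '8', '10']

# '0_0005' -> 0.0005, '10' -> 10.0
times_num = [float(s.replace('_', '.')) for s in times_str]
```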

In [363]:
#normal_cluster_counts, normal_cluster_edges, log_cluster_counts, log_cluster_edges = final_IMFs(cluster_data=clusters_sep[181:182], cluster_names=clusters_names[181:182], 
#                                                                                                model='Baraffe', metallicity='0.00', nbins=15, age_fit_plot=True, 
#                                                                                                chi_plot=True, save_chi_plot=False) 
# check_interp=True, save_check_interp=False, plot_hists=True, save_plot_hists=False,

### Testing new interpolation

In [360]:
#iso_data, iso_ages = separate_isochrones('Baraffe', '0.00')
#iso_ages = iso_ages.flatten()
#iso_params = isochrone_params(iso_data, 'Baraffe')
#print(len(iso_data))

#lupus_data = cluster_import('Lupus_data')
#print(len(lupus_data['M_V']))
#lupus_magn_mask = lupus_data['M_V'].value>0

#lupus_magn_data = lupus_data[lupus_magn_mask]

#print(len(lupus_magn_data['M_V']))

In [361]:
#chi_values_lupus = chi_fitting(fit_fcn, lupus_data, iso_params, iso_data)

# Age interpolation checked and considered correct
#new_lupus_age, younger_iso, younger_age, older_iso, older_age = new_age_interpolation(chi_values_lupus, iso_data, iso_ages, 'Lupus', plot=True)


#lupus_stellar_masses = interpolated_mass(lupus_data, 'Lupus', new_lupus_age[0], younger_iso, 
#                                         older_iso, younger_age, older_age, 'Baraffe')

#print(len(lupus_stellar_masses))

In [362]:
#fig10, ax10 = plt.subplots(1, 2, figsize=(10, 4))

#ax10[0].hist(lupus_stellar_masses, 15, fc=(0, 0, 1, 0.25), ec='b', histtype='stepfilled')

#ax10[0].grid(True)


#min_mass = np.min(lupus_stellar_masses)
#max_mass = np.max(lupus_stellar_masses)

#ax10[1].hist(lupus_stellar_masses, range=(np.log10(min_mass), np.log10(max_mass)), 
#                     bins=np.logspace(np.log10(min_mass), np.log10(max_mass), 15+1),
#                     histtype='stepfilled', fc=(0, 0, 1, 0.25), ec='b', lw=1.5)

#ax10[1].grid(True)
#ax10[1].set_xscale('log')


#plt.tight_layout()
#plt.show()
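Later calls to `final_IMFs` take `bin_width=0.2` rather than a number of bins. How that function builds its edges is not shown here. The sketch below is one hypothetical way to histogram masses in bins of a fixed width in dex. Binning in log10 space means the smallest and largest masses cannot drop out because of a `10**` round trip. The function name `log_histogram` is an illustration, not part of the notebook.

```python
import numpy as np

def log_histogram(masses, bin_width=0.2):
    """Histogram masses in bins of constant width in log10(mass), i.e. in dex.

    Returns (counts, edges) with edges in solar masses.
    """
    log_m = np.log10(np.asarray(masses, dtype=float))
    lo, hi = log_m.min(), log_m.max()
    # Enough bins of the requested width to cover [lo, hi]
    n_bins = max(1, int(np.ceil((hi - lo) / bin_width)))
    log_edges = lo + bin_width * np.arange(n_bins + 1)
    # Guard against round-off leaving the largest mass just outside the last edge
    log_edges[-1] = max(log_edges[-1], hi)
    # Bin in log10 space, then convert the edges back to solar masses
    counts, _ = np.histogram(log_m, bins=log_edges)
    return counts, 10 ** log_edges

masses = np.array([0.05, 0.1, 0.2, 0.5, 1.0])
counts, edges = log_histogram(masses, bin_width=0.2)
```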

## Test: dividing up isochrones

How to structure the code?

- loop over all lines in file
- gather all lines with same age into a list
- extract list and make into one array with important data like with Baraffe isochrones
- return a list of arrays for every isochrone
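The steps above can be sketched in vectorised form once the file has been loaded into a numeric array, for example with `np.loadtxt`. This sketch assumes a hypothetical layout in which one column holds the age and all rows of one isochrone share exactly the same age value. The real Baraffe and MIST file layouts are not shown here, and `split_by_age` is a hypothetical name.

```python
import numpy as np

def split_by_age(table, age_col=0):
    """Split a 2D array of isochrone rows into one array per age.

    Returns (list_of_arrays, ages) with ages in ascending order.
    """
    table = np.asarray(table, dtype=float)
    # Stable sort keeps the original row order within each age
    order = np.argsort(table[:, age_col], kind='stable')
    table = table[order]
    # First row index of every distinct age in the sorted table
    ages, starts = np.unique(table[:, age_col], return_index=True)
    return np.split(table, starts[1:]), ages

# Hypothetical rows: [age, some quantity]
rows = [[2, 10], [1, 20], [2, 30], [1, 40], [3, 50]]
groups, ages = split_by_age(rows)
```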

test_iso_sep, test_iso_ages = separate_isochrones('Baraffe', metallicity='0.00', survey='2mass')

print(test_iso_sep[7])
print(len(test_iso_sep))
#print(test_iso_sep[0].shape)
#print(test_iso_sep[0][:, 1])

test_iso_sep_Baraffe, test_iso_ages_Baraffe = separate_isochrones('Baraffe', metallicity='0.00')

In [359]:
#print(len(test_iso_sep_Baraffe))
#print(test_iso_sep_Baraffe[-1].shape)
#print(test_iso_sep_Baraffe[0][:, 0])

### Testing extinction analysis

from astroquery.vizier import Vizier

catalog = 'J/A+A/664/A175/table4'
columns = ['Plx', 'e_Plx', 'Gmag', 'BPmag', 'RPmag']
Vizier.ROW_LIMIT = -1  # -1 removes the default row limit so the full table is downloaded
catalogue = Vizier.get_catalogs(catalog)[0]

cluster_table = QTable([catalogue['GaiaEDR3'], catalogue['Plx'], catalogue['e_Plx'], catalogue['Gmag'],
                        catalogue['BPmag'], catalogue['RPmag'], catalogue['Cluster'], catalogue['_RA.icrs'],
                        catalogue['_DE.icrs']],
                       names=['GaiaID', 'Parallax', 'Parallax_error', 'M_apparent', 'G_bp', 'G_rp',
                              'Cluster_number', 'RA_ICRS', 'DE_ICRS'])

cluster_table['bp_rp'] = cluster_table['G_bp'] - cluster_table['G_rp'] 

clusters_sep, clusters_names = cluster_list(cluster_table, 1, 'gaia')
#print(len(clusters_sep))

lupus_data = clusters_sep[603]

%%time

test_log_counts, test_log_edges, test_masses, test_ages = final_IMFs(cluster_data=[lupus_data], cluster_names=['Lupus'], model='Baraffe', 
                                                                     metallicity='0.00', bin_width=0.2, age_fit_plot=False, 
                                                                     chi_plot=True, save_chi_plot=False, plot_hists=True, save_plot_hists=False) 

print('Done')

print(test_ages)

%%time

bin_widths_test, all_cluster_params_test, kroupa_diff_test, counter = IMF_slopes(test_log_edges, test_log_counts, 
                                                                         test_masses, model='Baraffe', intervals='Kroupa', 
                                                                         plot=False)

print(kroupa_diff_test)

tmass_cluster_data = QTable.read('Filtered_tmass_cluster_data.csv', format='csv', delimiter=',')

%%time

tmass_sorted_data, cluster_data_short = sorting_2mass_data(tmass_cluster_data)

tmass_clusters_sep, cluster_names = cluster_list(tmass_sorted_data, 1, '2mass')

un_arr, counts = np.unique(tmass_cluster_data['tmass_oid'], return_counts=True)

#print(len(tmass_cluster_data))
#print(len(un_arr))
#print(len(counts))

duplicates_mask = counts>1
duplicates = un_arr[duplicates_mask]



#print(len(duplicates))
#print(duplicates)
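The same `np.unique` call can also return the index of each ID's first occurrence. Those indices give one row per ID. Keeping the first occurrence is an arbitrary choice, and the right rule depends on why an ID appears more than once in the crossmatch. The IDs in this minimal sketch are hypothetical.

```python
import numpy as np

# Hypothetical crossmatch IDs (stand-in for tmass_cluster_data['tmass_oid'])
ids = np.array([101, 205, 101, 330, 205, 101])

# Unique IDs, the index of their first occurrence and how often each one appears
un_ids, first_idx, counts = np.unique(ids, return_index=True, return_counts=True)
duplicates = un_ids[counts > 1]

# Row indices that keep exactly one entry per ID, in the original table order
keep_rows = np.sort(first_idx)
```

Indexing the table with `keep_rows` then gives a duplicate-free version.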

%%time

tmass_sep_data, cluster_data_short = sorting_2mass_data(tmass_cluster_data)




print(tmass_sep_data.columns)
print()
print(cluster_data_short.columns)
print()

#print(tmass_sep_data['GaiaID'])
#print(cluster_data_short['GaiaID'])

In [None]:
#data_intersection = np.intersect(cluster_table, )

#tmass_sep_data = tmass_sep_data[tmass_sep_data['Cluster_number'].argsort()]

tmass_clusters_sep, cluster_names = cluster_list(tmass_sorted_data, 1, '2mass')

lupus_tmass_data = tmass_clusters_sep[603]

print(len(tmass_clusters_sep))
print(len(cluster_names))

In [30]:
#def extinction_ratio(x, Rv):
    
#    if (0.3<=np.min(x))&(np.max(x)<1.1):
#        a = 0.574*x**1.61
#        b = -0.527*x**1.61
        
#        return a + b/Rv
        
#    elif (1.1<=np.min(x))&(np.max(x)<=3.3):
#        y = x - 1.82
#        a = 1 + 0.17699*y - 0.50447*y**2 - 0.02427*y**3 + 0.72085*y**4 \
#            + 0.01979*y**5 - 0.77530*y**6 + 0.32999*y**7
#        b = 1.41338*y + 2.28305*y**2 + 1.07233*y**3 - 5.38434*y**4 \
#            - 0.62251*y**5 + 5.30260*y**6 - 2.09002*y**7
        
#        return a + b/Rv
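The commented-out `extinction_ratio` picks one branch from the minimum and maximum of the whole array. An array that spans both the infrared and the optical regime therefore matches neither condition and returns `None`. The sketch below evaluates the same Cardelli, Clayton & Mathis (1989) coefficients element by element with `np.where`, with `x` in inverse microns. With R_V = 3.1 it gives A_J/A_V ≈ 0.282 at 1.25 μm, which agrees with the `A_J` value used further down. The function name `ccm89_ratio` is hypothetical.

```python
import numpy as np

def ccm89_ratio(x, Rv=3.1):
    """A_lambda / A_V from the Cardelli, Clayton & Mathis (1989) fit.

    x: inverse wavelength in 1/micron, valid for 0.3 <= x <= 3.3.
    Each element is evaluated in its own regime.
    """
    x = np.asarray(x, dtype=float)
    if np.any((x < 0.3) | (x > 3.3)):
        raise ValueError('x outside 0.3-3.3 1/micron')
    # Infrared branch, used for 0.3 <= x < 1.1
    a_ir = 0.574 * x**1.61
    b_ir = -0.527 * x**1.61
    # Optical/near-infrared branch (1.1 <= x <= 3.3), polynomials in y = x - 1.82,
    # coefficients listed from the y**7 term down to the constant term
    y = x - 1.82
    a_opt = np.polyval([0.32999, -0.77530, 0.01979, 0.72085, -0.02427, -0.50447, 0.17699, 1.0], y)
    b_opt = np.polyval([-2.09002, 5.30260, -0.62251, -5.38434, 1.07233, 2.28305, 1.41338, 0.0], y)
    ir = x < 1.1
    return np.where(ir, a_ir, a_opt) + np.where(ir, b_ir, b_opt) / Rv

ratios = ccm89_ratio([1 / 1.25, 1.82])  # J band (1.25 micron) and x = 1.82 (V band)
```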
        
    

In [31]:
#J_wavelength = np.linspace(1.1, 1.4, 1000) # micro meters
#J_mid = 1.25 # micro meters
#H_wavelength = np.linspace(1.5, 1.8, 1000) # micro meters
#H_mid = 
#K_wavelength = np.linspace(2.0, 2.4, 1000) # micro meters

test_iso_sep, test_iso_ages = separate_isochrones('Baraffe', metallicity='0.00', survey='2mass')

# Best fitting isochrone for Lupus
iso_j_h = test_iso_sep[7][:, 2] - test_iso_sep[7][:, 3]
iso_h_k = test_iso_sep[7][:, 3] - test_iso_sep[7][:, 4]

iso_hk_mask = (0.18<=iso_h_k)&(iso_h_k<=0.28)

iso_hk = iso_h_k[iso_hk_mask]
iso_jh = iso_j_h[iso_hk_mask]

k_fit, m_fit = np.polyfit(iso_hk, iso_jh, deg=1)

# Straight line with slope k and intercept m (linear fit to the isochrone segment)

def iso_line(x, k, m):
    #m=1.085
    #k=-1.9
    return k*x+m

x_vals = np.linspace(-1, 2, 1000)
y_vals = iso_line(x_vals, k_fit, m_fit)

# Extinction in the J, H and K bands for A_V = 1 mag
A_J = 0.282
A_H = 0.190
A_K = 0.114

j_h_ext = A_J - A_H
h_k_ext = A_H - A_K
#print(j_h_ext)
#print(h_k_ext)

len_ext_vector = np.sqrt(j_h_ext**2 + h_k_ext**2)

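# The reddening vector starts at the origin of the colour-excess plane, so its slope is the
# ratio of the colour excesses (a degree-1 np.polyfit through a single point is underdetermined)
k_vec = j_h_ext/h_k_ext
m_vec = 0.0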

#x_vec_vals = np.linspace(0.18, 1, 100)

#ext_vector_1magn = k_vec*x_vec_vals + m_vec
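The notebook's `extinction` function is not shown, so the sketch below is not necessarily what it does. It shows one way to combine the fitted isochrone line J-H = k(H-K) + m with the reddening vector (`h_k_ext`, `j_h_ext`) per magnitude of A_V. A star is moved back along the vector until it lies on the line, which means solving (J-H) - A_V·j_h_ext = k·((H-K) - A_V·h_k_ext) + m for A_V. Stars lying below the line come out with a negative A_V. The helper name `av_to_line` is hypothetical.

```python
import numpy as np

def av_to_line(h_k, j_h, k, m, h_k_ext, j_h_ext):
    """A_V that shifts stars along the reddening vector onto J-H = k*(H-K) + m.

    (h_k_ext, j_h_ext) are the colour excesses E(H-K), E(J-H) per magnitude of A_V.
    """
    h_k = np.asarray(h_k, dtype=float)
    j_h = np.asarray(j_h, dtype=float)
    denom = j_h_ext - k * h_k_ext
    if np.isclose(denom, 0.0):
        raise ValueError('Reddening vector is parallel to the line')
    return (j_h - k * h_k - m) / denom

# Example: a star on the line at H-K = 0.2, reddened by A_V = 2
k, m = -1.9, 1.085
h_k_ext, j_h_ext = 0.076, 0.092
h_k_star = 0.2 + 2 * h_k_ext
j_h_star = (k * 0.2 + m) + 2 * j_h_ext
av = av_to_line(h_k_star, j_h_star, k, m, h_k_ext, j_h_ext)
```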

print(tmass_clusters_sep[181].columns)

lupus_corrected = extinction(k_fit, m_fit, tmass_clusters_sep[181])

print(len(lupus_corrected))
print(len(tmass_clusters_sep[181]))

fig, ax = plt.subplots(figsize=(5, 4))

plt.minorticks_on()

ax.scatter(lupus_corrected['H-K'], lupus_corrected['J-H'], c='m', s=10, alpha=0.5, 
           label='Cluster data')

#ax.scatter(lupus_tmass_data['H-K'], lupus_tmass_data['J-H'], c='b', s=10, alpha=0.5, 
#           label='Cluster data')

ax.plot(iso_h_k, iso_j_h, color='r', label='Best isochrone', lw=2, marker='*')

ax.plot(x_vals, y_vals, color='g', label='Extinction line', lw=2, linestyle='dashed')

ax.arrow(0.18, 0.78, h_k_ext, j_h_ext, color='orange', width=0.01, label=r'$A_V$=1 magn')

#ax.scatter(exts_lupus['H-K'], exts_lupus['J-H'], c='m', s=20, alpha=0.5, label='Corrected stars')

ax.set_xlabel('H - K')
ax.set_ylabel('J - H')
ax.set_title('Colour-colour diagram after extinction correction')

ax.grid(True, which='both')
ax.legend()

ax.set_xlim(-1, 2) # -1, 2
ax.set_ylim(-0.5, 2) # -0.5, 2

#plt.savefig('Plots/Extinction_correction_demonstration_After.png')
plt.show()

#cluster_table
#cluster_data_short
print(len(lupus_data))
print(len(lupus_tmass_data))

# Intersections between full vizier table and crossmatch table
intersections, intersections_arr1, intersections_arr2 = np.intersect1d(lupus_data['GaiaID'].value, 
                                                                       lupus_tmass_data['GaiaID'].value, 
                                                                       return_indices=True)

#print(intersections)
print(len(intersections_arr1))
print(len(intersections_arr2))

lupus_data['Extinction'] = np.zeros((len(lupus_data)))

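# Index both tables with the matched positions so the extinction values line up star by star
lupus_data['Extinction'][intersections_arr1] = lupus_tmass_data['Extinction'][intersections_arr2]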


print(lupus_data['Extinction'])

indices = np.arange(len(lupus_data))

pos_left = np.delete(indices, intersections_arr1)

print(len(indices))
print(len(intersections_arr1))
print(len(pos_left))

mean_extinction = np.mean(lupus_tmass_data['Extinction'].value)
print(mean_extinction)
#print(lupus_data['Extinction'][pos_left])
lupus_data['Extinction'][pos_left] = mean_extinction

#print(lupus_data['Extinction'][pos_left])

lupus_data['M_V'] = (lupus_data['M_apparent'].value - 5*np.log10(lupus_data['dist'].value) + 5 - lupus_data['Extinction'].value)*u.mag

#lupus_data['M_V'] = lupus_data['M_V']/u.mag
print(lupus_data['M_V'])
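The matching and distance-modulus steps above can be condensed into two small functions. This sketch assumes that the distances are in parsecs and that the extinction is in the same band as the apparent magnitude. The function names `fill_extinction` and `absolute_mag` are hypothetical, and the IDs and values in the example are made up.

```python
import numpy as np

def fill_extinction(all_ids, matched_ids, matched_ext):
    """Per-star extinction: the crossmatched value where available, else the mean."""
    matched_ext = np.asarray(matched_ext, dtype=float)
    ext = np.full(len(all_ids), np.mean(matched_ext))
    # Positions of the common IDs in each array, so values are copied star by star
    _, idx_all, idx_matched = np.intersect1d(all_ids, matched_ids, return_indices=True)
    ext[idx_all] = matched_ext[idx_matched]
    return ext

def absolute_mag(m_app, dist_pc, ext):
    """Distance modulus with extinction: M = m - 5 log10(d / pc) + 5 - A."""
    return m_app - 5 * np.log10(dist_pc) + 5 - ext

ext = fill_extinction([10, 20, 30, 40], [40, 10], [1.0, 3.0])
M = absolute_mag(10.0, 100.0, 0.0)
```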

%%time

new_log_counts, new_log_edges, new_masses, new_ages = \
    final_IMFs(cluster_data=[lupus_data], cluster_names=['Lupus'], model='Baraffe', 
    metallicity='0.00', bin_width=0.2, age_fit_plot=False, 
    chi_plot=True, save_chi_plot=False, plot_hists=True, save_plot_hists=False) 

print('Done')

print(new_ages)

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

indexes = [2, 4, 6]


print(arr[indexes])

ext_stars_mask = (lupus_data['Extinction']!=0)&(lupus_data['Extinction']!=mean_extinction)

ext_stars = lupus_data[ext_stars_mask]

print(len(ext_stars))

new_log_counts, new_log_edges, new_masses, new_ages

%%time

bin_widths_test, all_cluster_params_test, kroupa_diff_test, counter = IMF_slopes(new_log_edges, new_log_counts, 
                                                                         new_masses, model='Baraffe', intervals='Kroupa', 
                                                                         plot=True)


print(all_cluster_params_test)
print(kroupa_diff_test)

----------------------------------------------------------------------------------------------------------------

# Best fitting isochrone for Lupus
iso_j_h_new = test_iso_sep[1][:, 2] - test_iso_sep[1][:, 3]
iso_h_k_new = test_iso_sep[1][:, 3] - test_iso_sep[1][:, 4]

iso_hk_mask = (0.18<=iso_h_k_new)&(iso_h_k_new<=0.28)

iso_hk_new = iso_h_k_new[iso_hk_mask]
iso_jh_new = iso_j_h_new[iso_hk_mask]

k_fit_new, m_fit_new = np.polyfit(iso_hk_new, iso_jh_new, deg=1)

# Straight line with slope k and intercept m (linear fit to the isochrone segment)

def iso_line(x, k, m):
    #m=1.085
    #k=-1.9
    return k*x+m

x_vals_new = np.linspace(-1, 2, 1000)
y_vals_new = iso_line(x_vals_new, k_fit_new, m_fit_new)

A_J_new = 0.282
A_H_new = 0.190
A_K_new = 0.114

j_h_ext_new = A_J_new - A_H_new
h_k_ext_new = A_H_new - A_K_new
#print(j_h_ext)
#print(h_k_ext)

len_ext_vector_new = np.sqrt(j_h_ext_new**2 + h_k_ext_new**2)

# Slope of the reddening vector, which starts at the origin of the colour-excess plane
k_vec_new = j_h_ext_new/h_k_ext_new
m_vec_new = 0.0

#x_vec_vals = np.linspace(0.18, 1, 100)

#ext_vector_1magn = k_vec*x_vec_vals + m_vec

lupus_corrected = extinction(k_fit_new, m_fit_new, k_vec_new, len_ext_vector_new, lupus_tmass_data)

print(len(lupus_corrected))
print(len(lupus_tmass_data))

fig1, ax1 = plt.subplots(figsize=(8, 6))

plt.minorticks_on()

ax1.scatter(lupus_corrected['H-K'], lupus_corrected['J-H'], c='m', s=20, alpha=0.5, 
            label='Corrected cluster data')

ax1.scatter(lupus_tmass_data['H-K'], lupus_tmass_data['J-H'], c='b', s=10, alpha=0.5, 
            label='Uncorrected cluster data')

ax1.plot(iso_h_k_new, iso_j_h_new, color='r', label='Best isochrone', lw=2, marker='*')

ax1.plot(x_vals_new, y_vals_new, color='g', label='Extinction line', lw=2, linestyle='dashed')

ax1.arrow(0.18, 0.78, h_k_ext_new, j_h_ext_new, color='orange', width=0.01, label=r'$A_V$=1 magn')

#ax1.scatter(exts_lupus['H-K'], exts_lupus['J-H'], c='m', s=20, alpha=0.5, label='Corrected stars')

ax1.set_xlabel('H - K')
ax1.set_ylabel('J - H')
ax1.set_title('Colour-colour diagram for Lupus')

ax1.grid(True, which='both')
ax1.legend()

ax1.set_xlim(-1, 2) # -1, 2
ax1.set_ylim(-0.5, 2) # -0.5, 2

plt.show()

#cluster_table
#cluster_data_short
print(len(lupus_data))
print(len(lupus_tmass_data))

# Intersections between full vizier table and crossmatch table
intersections, intersections_arr1, intersections_arr2 = np.intersect1d(lupus_data['GaiaID'].value, 
                                                                       lupus_tmass_data['GaiaID'].value, 
                                                                       return_indices=True)

#print(intersections)
print(len(intersections_arr1))
print(len(intersections_arr2))

lupus_data['Extinction'] = np.zeros((len(lupus_data)))

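# Index both tables with the matched positions so the extinction values line up star by star
lupus_data['Extinction'][intersections_arr1] = lupus_tmass_data['Extinction'][intersections_arr2]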


print(lupus_data['Extinction'])

indices = np.arange(len(lupus_data))

pos_left = np.delete(indices, intersections_arr1)

print(len(indices))
print(len(intersections_arr1))
print(len(pos_left))

mean_extinction = np.mean(lupus_tmass_data['Extinction'].value)
print(mean_extinction)
#print(lupus_data['Extinction'][pos_left])
lupus_data['Extinction'][pos_left] = mean_extinction

#print(lupus_data['Extinction'][pos_left])

lupus_data['M_V'] = (lupus_data['M_apparent'].value - 5*np.log10(lupus_data['dist'].value) + 5 - lupus_data['Extinction'].value)*u.mag

#lupus_data['M_V'] = lupus_data['M_V']/u.mag
print(lupus_data['M_V'])

%%time

new_log_counts, new_log_edges, new_masses, new_ages = \
    final_IMFs(cluster_data=[lupus_data], cluster_names=['Lupus'], model='Baraffe', 
    metallicity='0.00', bin_width=0.2, age_fit_plot=False, 
    chi_plot=True, save_chi_plot=False, plot_hists=True, save_plot_hists=False) 

print('Done')

%%time

bin_widths_test, all_cluster_params_test, kroupa_diff_test, counter = IMF_slopes(new_log_edges, new_log_counts, 
                                                                         new_masses, model='Baraffe', intervals='Kroupa', 
                                                                         plot=True)


print(all_cluster_params_test)
print(kroupa_diff_test)