# RCFM Model

## 1. Import modules and helper functions

In [None]:
#THIS VERSION FROM MARCUS PAZ 10.2.23
# Modules
import matplotlib.pyplot as plt
import numpy as np
from math import sqrt
from scipy.optimize import curve_fit
import itertools

# Helper functions from DataAid.py and DataImporter.py
import DataAid
import DataImporter

# Numerically stable class of functions from Neros.py
import Neros

## 2. Load Galaxy Data

In [None]:
# Load the galaxy rotation-curve samples: one DataAid.GetGalaxyData call per
# data directory, made in the order listed below.
# NOTE(review): sparc128Galaxies is loaded from "SparcSubset135" - confirm which
# subset size the name should reflect.
sparcGalaxies, sparc128Galaxies, sparcTset, littleDataGalaxies, lcmGalaxies = (
    DataAid.GetGalaxyData(data_dir)
    for data_dir in (
        "data/Sparc/Rotmod_LTG/",
        "data/Sparc/SparcSubset135/",
        "data/Sparc/TrainingSet/",
        "data/little-data-things/data/",
        "data/LCMFits/data/",
    )
)

# Load the Milky Way model data sets.
xueSofueGalaxies, mcGaughMW = (
    DataAid.GetGalaxyData(data_dir)
    for data_dir in ("data/XueSofue/", "data/McGaugh/")
)

# Convert each Milky Way model table (radius / vLum data) to a NumPy array so
# it can be handed to Neros.Neros.
MWXueSofue, MWMcGaugh, MWMcGaugh_small_r = (
    np.array(model_set[model_key])
    for model_set, model_key in (
        (xueSofueGalaxies, 'MW_lum'),
        (mcGaughMW, 'MW_lumGAIAMcGaugh'),
        (mcGaughMW, 'MW_lumMcGaughGAIA_small_r'),
    )
)


## 3. Create Neros class instance

In [None]:
# Create a Neros instance that performs its calculations with the selected
# Milky Way model as comparison.
# To change the Milky Way model, set MW_name to another key of MW_models.
# MW_name also goes into the results/graph file names, so selecting the model
# through it keeps the file names and the model actually used in sync.
MW_models = {
    "XueSofueGaia": MWXueSofue,
    "McGaugh": MWMcGaugh_small_r,
}
MW_name = "XueSofueGaia"
neros_fns = Neros.Neros(MW_models[MW_name])

# Milky Way quantities exposed by the Neros instance
MW_rad = neros_fns.mw_rad
MW_vLum = neros_fns.mw_vLum
#MW_phi = neros_fns.mw_phi

MW_vLum_interp_func = neros_fns.mw_vLum_interp

## 4. Designate galaxy sample (the outfile name and loop variables are initialized in section 5)

In [None]:
# This designates which galaxy sample to fit.
# The fitting loop in section 5 iterates over this sample by galaxy name and
# reads each galaxy's data table with galaxies[galaxyName].
# Presumably any other sample loaded in section 2 (e.g. sparc128Galaxies,
# sparcTset, littleDataGalaxies, lcmGalaxies) can be substituted here.
galaxies = sparcGalaxies


## 5. Fit galaxies, print and save results, and save graphs

In [None]:
#THIS VERSION FROM MARCUS PAZ 10.2.23
"""
This is the main body of the model.

It loops through the galaxies in the designated sample and, for each one:
  - grid-fits the galaxy with neros_fns.grid_fit over the initial alpha,
    disk_scale and bulge_scale values,
  - prints the fitted parameters and chi-squared to the console,
  - appends one row of fitted parameters to the results CSV,
  - if make_graphs is True, saves a graph of the fit under graphs/.
A summary of failed fits and poorly fitting galaxies is printed at the end.

Note: the results/ and graphs/ directories must already exist; open() and
savefig() do not create them.
"""


def _plot_galaxy_fit(fitter, galaxy_name, chi_squared, mw_name):
    """Save a rotation-curve graph of the most recent fit held by `fitter`.

    Plots vNeros and the scaled vLum against radius, draws the vObs error
    bars, and saves the figure as graphs/<galaxy_name>_<mw_name>.
    """
    trimmed_rad = fitter.get_rad()
    trimmed_vLum_updated = fitter.get_vLum_scaled()
    # np.asarray makes the arithmetic below element-wise even if the getters
    # return plain lists (for lists, "+" would concatenate instead of add).
    trimmed_vObs = np.asarray(fitter.get_vObs())
    trimmed_error = np.asarray(fitter.get_vObsError())
    vNeros = fitter.get_vNeros()

    # The y-axis extends to the larger of max(vObs + error) and max(vNeros).
    y_max = max(np.max(trimmed_vObs + trimmed_error), np.max(vNeros))

    MED_SIZE = 24
    LG_SIZE = 30

    fig, ax = plt.subplots(1, figsize=(15, 15))
    ax.set_ylim(bottom=0, top=y_max + 15)
    ax.set_xlabel("radius (kpc)", fontsize=LG_SIZE)
    ax.set_ylabel("velocity (km/sec)", fontsize=LG_SIZE)
    # Font size of both major and minor tick labels
    ax.tick_params(axis='both', which='both', labelsize=MED_SIZE)

    # plot vNeros and the updated (scaled) vLum
    ax.plot(trimmed_rad, vNeros, label="{}_vNeros".format(galaxy_name),
            color="red", linewidth=3)
    ax.plot(trimmed_rad, trimmed_vLum_updated, label="{}_new_vLum".format(galaxy_name),
            color="purple", linewidth=3, linestyle="dashed")
    # Empty, invisible line that only carries chi-squared as a legend entry.
    # Raw string: "\c" is an invalid escape sequence in a normal literal.
    ax.plot([], [], ' ', label=r"$\chi^2$ = {}".format(chi_squared))

    # vObs error bars: one vertical segment per radius, vObs - error to vObs + error
    ax.vlines(trimmed_rad, trimmed_vObs - trimmed_error,
              trimmed_vObs + trimmed_error, linewidth=2)

    # Uncomment to draw the legend:
    #ax.legend(loc="upper right", fontsize=LG_SIZE)
    graph_file_name = "graphs/" + str(galaxy_name) + "_" + str(mw_name)
    fig.savefig(graph_file_name)
    plt.close(fig)


# The fitted parameters for each galaxy are written here, one row per galaxy.
out_file = "results/results_" + str(MW_name) + ".csv"
# NOTE: the header says "chi_square" while the fit-result key is "chi_squared";
# the header is kept unchanged so existing readers of the CSV keep working.
columns = ["Galaxy", "chi_square", "alpha", "disk_scale", "bulge_scale", "phi_zero"]
# Fit-result keys written after the galaxy name, in the same order as `columns`.
result_keys = ['chi_squared', 'alpha', 'disk_scale', 'bulge_scale', 'phi_zero']
with open(out_file, 'w') as csv_file:
    csv_file.write(','.join(columns))
    csv_file.write('\n')

# Fits with chi-squared above this are listed as "bad" (arbitrary cutoff).
bad_chi2_cutoff = 10

total_chi_squared = []

failures = []
bad_chi2 = []

make_graphs = False

for galaxyName in galaxies:

    # Extract the needed galaxy components (columns of the galaxy's data table)
    galaxy = np.array(galaxies[galaxyName])
    galaxy_rad = galaxy[:, 0]
    galaxy_vObs = galaxy[:, 1]
    galaxy_error = galaxy[:, 2]
    galaxy_gas = galaxy[:, 3]
    galaxy_disk = galaxy[:, 4]
    galaxy_bulge = galaxy[:, 5]

    print(f"galaxyName is: {galaxyName}")

    # Fit, then extract the relevant pieces. Any exception is recorded as a
    # failed fit so that one bad galaxy does not stop the whole run.
    try:
        # Try different initial alphas, find best fit
        # building this to allow flexibility later
        alphas = [0.001, 0.01, 0.1, 1, 10, 100]
        disk_scales = [1.0]
        bulge_scales = [1.0]

        neros_fns.grid_fit(galaxy_rad, galaxy_gas, galaxy_disk, galaxy_bulge, galaxy_vObs, galaxy_error,
                           alpha=alphas, disk_scale=disk_scales, bulge_scale=bulge_scales)
        fit_results = neros_fns.get_fit_results(galaxy_rad)
    except Exception as e:
        print("--------------------------")
        print(f'ERROR! Fit for {galaxyName} failed with error {e}')
        failures.append(galaxyName)
        # Placeholder results so the CSV row and printout below still work
        fit_results = dict.fromkeys(result_keys)
        print("--------------------------")

    for param in fit_results:
        print(f"{param} is: {fit_results[param]}")
    print("--------------------------")

    chi_squared = fit_results['chi_squared']
    if chi_squared is not None:
        total_chi_squared.append(chi_squared)
        if chi_squared > bad_chi2_cutoff:
            bad_chi2.append(galaxyName)

    # Guard against division by zero when no fit has succeeded yet
    # (e.g. the very first galaxy failed).
    if total_chi_squared:
        print("Running average chi_squared: {}".format(sum(total_chi_squared) / len(total_chi_squared)))
    else:
        print("Running average chi_squared: n/a (no successful fits yet)")
    print("--------------------------\n")

    # Append this galaxy's row right away, so results written so far survive
    # if the run is interrupted.
    row = [f"{galaxyName}"] + [f"{fit_results[key]}" for key in result_keys]
    with open(out_file, 'a') as csv_file:
        csv_file.write(','.join(row) + '\n')

    if make_graphs and chi_squared is not None:
        _plot_galaxy_fit(neros_fns, galaxyName, chi_squared, MW_name)


print("*" * 30)
print("Summary:")
print("Galaxies where fit failed entirely:")
print(*['    ' + gal for gal in failures], sep='\n')
print(f'Galaxies with chi2 greater than {bad_chi2_cutoff} (arbitrary cutoff for "bad")')
print(*['    ' + gal for gal in bad_chi2], sep='\n')