# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from IPython.display import display, HTML, Markdown, Math

# Define the quadratic model
def quadratic_model(T, a, b, c):
    """Evaluate the second-order polynomial ``a + b*T + c*T**2``.

    Operates element-wise when ``T`` is a NumPy array.

    Parameters
    ----------
    T : float or numpy.ndarray
        Independent variable.
    a, b, c : float
        Constant, linear and quadratic coefficients.

    Returns
    -------
    float or numpy.ndarray
        Polynomial value(s) at ``T``.
    """
    linear_part = b * T
    quadratic_part = c * T ** 2
    return a + linear_part + quadratic_part

# Function to calculate and return goodness-of-fit parameters
def calculate_fit_parameters(T, Diff, sigma=None):
    """Fit ``quadratic_model`` to ``(T, Diff)`` and summarise the goodness of fit.

    Parameters
    ----------
    T : array_like
        Independent variable; converted to a 1-D float array so plain lists
        work as well as NumPy arrays.
    Diff : array_like
        Observed values, one per entry of ``T``.
    sigma : array_like, optional
        Per-point standard uncertainties of ``Diff``. When given, they weight
        the fit and the reduced chi-squared is ``sum((residual/sigma)**2)/dof``.
        When omitted (the default) every point has unit weight, so the value
        reported as "Reduced chi-squared" is the residual variance
        ``RSS / (N - 3)``.

    Returns
    -------
    output : str
        Text block listing the optimized parameters, R-squared and reduced
        chi-squared.
    popt : numpy.ndarray
        Best-fit coefficients ``[a, b, c]``.

    Notes
    -----
    Undefined statistics are reported as NaN instead of triggering a
    division by zero: R-squared when ``Diff`` is constant (zero total sum of
    squares), and the reduced chi-squared when there are no spare degrees of
    freedom (``len(T) <= 3``).
    """
    # Convert up front: the residual computation below evaluates the model on
    # T directly, and T**2 fails on a plain Python list.
    T = np.asarray(T, dtype=float)
    Diff = np.asarray(Diff, dtype=float)

    popt, _ = curve_fit(quadratic_model, T, Diff, sigma=sigma)

    residuals = Diff - quadratic_model(T, *popt)
    ss_res = np.sum(residuals ** 2)                    # residual sum of squares
    ss_tot = np.sum((Diff - np.mean(Diff)) ** 2)       # total sum of squares

    # R^2 is undefined for constant data (ss_tot == 0).
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else np.nan

    # Without uncertainties every point has unit weight, so chi^2 == RSS.
    if sigma is None:
        chi_squared = ss_res
    else:
        chi_squared = np.sum((residuals / np.asarray(sigma, dtype=float)) ** 2)

    # Degrees of freedom: number of observations - number of fitted parameters.
    degrees_of_freedom = len(T) - len(popt)
    if degrees_of_freedom > 0:
        reduced_chi_squared = chi_squared / degrees_of_freedom
    else:
        reduced_chi_squared = np.nan

    fit_summary = {
        'Optimized parameters [a, b, c]': popt,
        "R-squared": r_squared,
        "Reduced chi-squared": reduced_chi_squared,
    }

    output = "# Goodness-of-fit parameters\n" + "".join(
        f'\n{key}: {value}\n' for key, value in fit_summary.items()
    )

    return output, popt

# Function to format coefficients in scientific notation for LaTeX
def format_coefficient(value, digits=5):
    """Format *value* as a signed LaTeX scientific-notation term.

    The result always starts with an explicit sign followed by a space, so it
    can be appended after another term, e.g. ``"+ 1.23457 \\times 10^{4}"``.

    Parameters
    ----------
    value : float
        Number to format.
    digits : int, optional
        Digits after the decimal point in the mantissa (default 5).

    Returns
    -------
    str
        ``"<sign> <mantissa> \\times 10^{<exponent>}"`` for finite values.
        Non-finite values, which ``"{:e}"`` renders without an exponent part,
        become ``"+ \\infty"``, ``"- \\infty"`` or ``"+ \\mathrm{NaN}"``.
    """
    if np.isnan(value):
        return "+ \\mathrm{NaN}"
    if np.isinf(value):
        return f"{'-' if value < 0 else '+'} \\infty"

    mantissa, _, exponent = f"{value:.{digits}e}".partition('e')
    sign = "-" if mantissa.startswith('-') else "+"
    return f"{sign} {mantissa.lstrip('-')} \\times 10^{{{int(exponent)}}}"

# Function to display fitting parameters and LaTeX equation
def display_fitting_parameters(material_name, output, popt):
    """Render a fit summary and the fitted quadratic equation in a notebook cell.

    Parameters
    ----------
    material_name : str
        Name used in the bold heading and in the equation caption.
    output : str
        Pre-formatted goodness-of-fit text; printed verbatim.
    popt : sequence of three floats
        Coefficients ``[a, b, c]`` of ``a + b*T + c*T**2``.

    Raises
    ------
    ValueError
        If ``popt`` does not contain exactly three values.
    """
    display(HTML("<hr>"))
    display(Markdown(f'**Fitting parameters for {material_name}** \n'))
    print(output)

    a, b, c = popt
    a_term = format_coefficient(a)
    b_term = format_coefficient(b)
    c_term = format_coefficient(c)

    # format_coefficient prefixes every term with "+ " or "- "; drop a leading
    # "+ " (sign and space) so the equation does not start with a plus sign.
    # A leading "- " is kept.
    a_term = a_term.lstrip('+ ')

    equation = r'\alpha = {} {}T {}T^2'.format(a_term, b_term, c_term)

    print(f'The equation for {material_name} diffusivity is:\n')
    display(Math(equation))
    print('\n\n')

def custom_plot(x_data, y_data, x_fit, y_fit,
                x_label='X-axis', y_label='Y-axis',
                title='Plot',
                scale='linear', font_size=16,
                xlim=None, ylim=None,
                grid=True, legend=True,
                data_label='Data', fit_label='Fit',
                data_color='blue', fit_color='red',
                data_marker_size=12, fit_line_width=3, x_label_font_size=None, y_label_font_size=None,
                title_font_size=None, tick_font_size=18):
    """Plot measured data as markers and a fitted curve as a line, then show it.

    Parameters
    ----------
    x_data, y_data : array_like
        Measured points, drawn as circle markers.
    x_fit, y_fit : array_like
        Fitted curve, drawn as a solid line.
    x_label, y_label, title : str
        Axis labels and plot title.
    scale : {'linear', 'log-log', 'log-x', 'log-y'}
        Axis scaling.
    font_size : int
        Shared font size for the labels and title. It is used for any of
        ``x_label_font_size``, ``y_label_font_size`` or ``title_font_size``
        that is left as None.
    xlim, ylim : tuple, optional
        Axis limits; applied only when truthy.
    grid, legend : bool
        Whether to draw a dashed grid and a legend.
    data_label, fit_label, data_color, fit_color : str
        Legend labels and colours for the data and the fit.
    data_marker_size, fit_line_width : float
        Marker size for the data and line width for the fit.
    x_label_font_size, y_label_font_size, title_font_size : int, optional
        Per-element font sizes overriding ``font_size``.
    tick_font_size : int
        Font size of the tick labels on both axes (default 18).

    Raises
    ------
    ValueError
        If ``scale`` is not one of the supported values. The check runs before
        any figure is created, so no empty figure is left behind.
    """
    axis_scales = {
        'linear': ('linear', 'linear'),
        'log-log': ('log', 'log'),
        'log-x': ('log', 'linear'),
        'log-y': ('linear', 'log'),
    }
    try:
        xscale, yscale = axis_scales[scale]
    except (KeyError, TypeError):
        raise ValueError("Scale must be 'linear', 'log-log', 'log-x', or 'log-y'") from None

    # Per-element sizes were previously ignored. They fall back to font_size,
    # so existing callers that only set font_size see no change.
    if x_label_font_size is None:
        x_label_font_size = font_size
    if y_label_font_size is None:
        y_label_font_size = font_size
    if title_font_size is None:
        title_font_size = font_size

    plt.figure(figsize=(10, 6))
    plt.xscale(xscale)
    plt.yscale(yscale)

    plt.plot(x_data, y_data, 'o', label=data_label, color=data_color, markersize=data_marker_size)
    plt.plot(x_fit, y_fit, '-', label=fit_label, color=fit_color, linewidth=fit_line_width)

    plt.xlabel(x_label, fontsize=x_label_font_size)
    plt.ylabel(y_label, fontsize=y_label_font_size)
    plt.title(title, fontsize=title_font_size)
    plt.xticks(fontsize=tick_font_size)
    plt.yticks(fontsize=tick_font_size)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)

    if grid:
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    else:
        plt.grid(False)

    if legend:
        plt.legend()

    plt.show()




