# In [None]:
## Package imports and utility functions written by Indro:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from numpy.polynomial import Polynomial
import statsmodels.formula.api as smf
from sympy import Symbol, Poly, N
from IPython.display import display, Markdown, HTML

# Load the data from the Excel file
def load_data(filename, sheet_name):
    """Load every column of one Excel sheet into module-level NumPy arrays.

    The sheet is read with ``header=1``: the second spreadsheet row supplies
    the column labels (the first row is skipped; presumably it holds the
    material-name header). Each label is cut at its first space, so a label
    such as ``"T (K)"`` becomes the variable name ``T``. A global with that
    name is then bound to the column's values, NaNs dropped, as a NumPy array.

    Parameters
    ----------
    filename : str or path-like
        Excel workbook to read.
    sheet_name : str or int
        Sheet to read, passed straight to ``pandas.read_excel``.

    Side effects
    ------------
    Sets the globals ``temp_data`` (the sheet as a DataFrame), ``var_names``
    (the derived names, in column order) and one global per derived name.
    If two labels share the same first word, the right-most column wins.

    Raises
    ------
    ValueError
        If a derived name is not a usable Python identifier (for example a
        label that starts with a space, or a keyword such as ``for``).
    """
    global temp_data
    global var_names

    from keyword import iskeyword

    # A single read suffices: header=1 makes the second row the column labels
    # and the rows below it the data.
    temp_data = pd.read_excel(filename, sheet_name=sheet_name, header=1)
    col_names = list(temp_data.columns)
    # Variable names mirror the column labels, minus the unit text that
    # follows the first space.
    var_names = [str(col).split(" ")[0] for col in col_names]

    invalid = [name for name in var_names if not name.isidentifier() or iskeyword(name)]
    if invalid:
        raise ValueError(
            f"Column labels in sheet {sheet_name!r} do not yield valid Python names: {invalid}"
        )

    # Bind the globals directly rather than exec-ing generated source: the
    # labels come from the spreadsheet and must not be executed as code. This
    # also binds a one-column sheet to the array itself, not a 1-element list.
    globals().update(
        {name: temp_data[col].dropna().to_numpy() for name, col in zip(var_names, col_names)}
    )

    
# Fit an OLS model and calculate confidence and prediction intervals at the data points.
# Non-cubic models can be fitted by passing a different R_style_formula written in terms of X and Y.
def get_conf_and_pred_intervals(x, y, R_style_formula='Y ~ I(X**3) + I(X**2) + X + 1', alpha=0.05):
    """Fit an OLS model of y on x and compute intervals at every input point.

    Parameters
    ----------
    x, y : array-like
        Independent and dependent observations. Inside the formula they are
        the columns ``X`` and ``Y``.
    R_style_formula : str
        R-style model formula in terms of ``X`` and ``Y``. The default is a
        cubic in ``X`` with an intercept.
    alpha : float
        Significance level of the intervals. The default 0.05 gives 95%
        confidence and prediction intervals.

    Returns
    -------
    fit_df : pandas.DataFrame
        The data as columns ``'X'`` and ``'Y'``.
    results
        The fitted statsmodels OLS results.
    prediction_summary_frame : pandas.DataFrame
        ``summary_frame`` of the predictions at ``fit_df``. It includes the
        mean (confidence) interval columns ``mean_ci_lower``/``mean_ci_upper``
        and the observation (prediction) interval columns
        ``obs_ci_lower``/``obs_ci_upper``.
    """
    fit_df = pd.DataFrame({'X': x, 'Y': y})

    # Ordinary least squares on the design matrix described by the formula.
    results = smf.ols(formula=R_style_formula, data=fit_df).fit()

    # Predict back at the original points to get both interval kinds.
    prediction_summary_frame = results.get_prediction(fit_df).summary_frame(alpha=alpha)

    return fit_df, results, prediction_summary_frame

# TODO: upgrade this to support non-cubic fittings! Just need to allow for variety of Sympy expressions.
def plot_fit_and_print_summary(x, y, legend=True, legend_font_size=16, legend_loc='upper right', legend_num_cols=2,
                               fit_line_color='black', pred_int_fill_color='grey', conf_int_fill_color='blue',
                               material_name=None, property_name=None, eq_digits=6, variable_name="T"):
    """Plot a cubic OLS fit of y on x with 95% bands, then show its summary and equation.

    Draws on the current matplotlib axes:
    - the fitted curve,
    - the 95% prediction interval,
    - the 95% confidence interval.
    All three come from ``get_conf_and_pred_intervals`` with its default cubic
    formula. The function then displays the statsmodels regression summary
    and the fitted polynomial as a SymPy expression.

    Parameters
    ----------
    x, y : array-like
        Observations. They need not be sorted; plotting is done in increasing X.
    legend, legend_font_size, legend_loc, legend_num_cols
        Control ``plt.legend``. The legend is drawn only when ``legend`` is true.
    fit_line_color, pred_int_fill_color, conf_int_fill_color
        Colors of the fitted curve, prediction band and confidence band.
    material_name, property_name
        Used only in the Markdown headings.
    eq_digits : int
        Significant digits passed to ``sympy.N`` when printing the equation.
    variable_name : str
        Symbol name for the independent variable in the printed equation.
    """
    fit_df, results, prediction_summary_frame = get_conf_and_pred_intervals(x, y)

    # Matplotlib joins points in the order given, so unsorted x would draw a
    # zig-zag line and a self-overlapping band. Reorder every plotted series by
    # increasing X. All of them are row-aligned with fit_df, so one positional
    # order works for each.
    x_values = fit_df['X'].to_numpy()
    order = np.argsort(x_values, kind='stable')
    x_plot = x_values[order]
    fitted = np.asarray(results.fittedvalues)[order]
    bands = prediction_summary_frame.iloc[order]

    # Regression line
    plt.plot(x_plot, fitted, color=fit_line_color, label='Data Fit')
    # Prediction interval
    plt.fill_between(x_plot, bands['obs_ci_lower'].to_numpy(), bands['obs_ci_upper'].to_numpy(),
                     color=pred_int_fill_color, alpha=0.2, label='95% Prediction Interval')
    # Confidence interval
    plt.fill_between(x_plot, bands['mean_ci_lower'].to_numpy(), bands['mean_ci_upper'].to_numpy(),
                     color=conf_int_fill_color, alpha=0.2, label='95% Confidence Interval')

    if legend:
        plt.legend(loc=legend_loc, fontsize=legend_font_size, ncol=legend_num_cols)

    display(HTML("<hr>"))
    display(Markdown(f'**Fitting parameters for {material_name} {property_name}** \n'))
    print(results.summary())
    display(HTML("<hr>"))
    display(Markdown(f'**The equation for {material_name} {property_name} is:**\n'))

    # sympy.Poly takes coefficients highest degree first. With the default
    # cubic formula the intercept is the first fitted parameter, followed by
    # the X**3, X**2 and X terms, so rotate the intercept to the end.
    # TODO: this ordering assumption breaks for non-cubic / reordered formulas.
    params = np.asarray(results.params)
    coefficients = list(np.concatenate((params[1:], params[:1])))
    display(N(Poly(coefficients, Symbol(variable_name)).as_expr(), eq_digits))

# Concatenate and sort lists of x and y data
def concatenate_and_sort(x_list, y_list):
    """Join several x/y array pairs and order them by x, largest first.

    Parameters
    ----------
    x_list, y_list : sequence of array-like
        Matching pieces of the x and y data. They are concatenated in order,
        so the i-th x value pairs with the i-th y value.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The concatenated x and y arrays, both reordered by descending x.
    """
    all_x = np.concatenate(x_list)
    all_y = np.concatenate(y_list)

    # argsort is ascending; reversing it gives the descending ordering, which
    # is applied to both arrays so each x keeps its y partner.
    descending_order = np.argsort(all_x)[::-1]
    return all_x[descending_order], all_y[descending_order]

# Slightly modified version of Utilities_plots.ipynb from Professor Ghoniem
def custom_multi_plot(x_data_list, y_data_list, x_fit_list=None, y_fit_list=None,
                      x_label='X-axis', y_label='Y-axis', title='Plot',
                      scale='linear', font_size=16, xlim=None, ylim=None, 
                      grid=True, legend=True, data_labels=None, fit_labels=None,
                      data_colors=None, fit_colors=None, data_marker_sizes=None, 
                      fit_line_widths=None, x_label_font_size=16, y_label_font_size=16, 
                      title_font_size=16, legend_font_size=16, legend_loc='upper right', legend_num_cols=2):
    """Scatter several data series on a new figure, optionally overlaying fit curves.

    A new 6x4 inch, 300 dpi figure is created. Each data series is a scatter
    plot. Fit curves are drawn only when both ``x_fit_list`` and
    ``y_fit_list`` are given. Per-series styling lists that are left as None
    get defaults:
    - data: labels "Data i", blue, marker size 20;
    - fits: labels "Fit i", red, line width 3.
    ``scale`` is applied to both axes. ``font_size`` is accepted but not used.

    Raises
    ------
    ValueError
        If x_data_list and y_data_list differ in length, or, when both fit
        lists are given, x_fit_list and y_fit_list differ in length.
    """
    plt.figure(figsize=(6, 4), dpi=300)

    has_fits = x_fit_list is not None and y_fit_list is not None

    # Default styling for the data series.
    n_series = len(x_data_list)
    if data_labels is None:
        data_labels = [f'Data {k + 1}' for k in range(n_series)]
    if data_colors is None:
        data_colors = ['blue'] * n_series
    if data_marker_sizes is None:
        data_marker_sizes = [20] * n_series

    # Default styling for the fit curves (only meaningful when both are given).
    if has_fits:
        n_fits = len(x_fit_list)
        if fit_labels is None:
            fit_labels = [f'Fit {k + 1}' for k in range(n_fits)]
        if fit_colors is None:
            fit_colors = ['red'] * n_fits
        if fit_line_widths is None:
            fit_line_widths = [3] * n_fits

    # Paired lists must line up one-to-one.
    if len(x_data_list) != len(y_data_list):
        raise ValueError("x_data_list and y_data_list must have the same length.")
    if has_fits and len(x_fit_list) != len(y_fit_list):
        raise ValueError("x_fit_list and y_fit_list must have the same length.")

    for idx, (xs, ys) in enumerate(zip(x_data_list, y_data_list)):
        plt.scatter(xs, ys, label=data_labels[idx], color=data_colors[idx], s=data_marker_sizes[idx])

    if has_fits:
        for idx, (xs, ys) in enumerate(zip(x_fit_list, y_fit_list)):
            plt.plot(xs, ys, label=fit_labels[idx], color=fit_colors[idx], linewidth=fit_line_widths[idx])

    # Axis decoration.
    plt.xlabel(x_label, fontsize=x_label_font_size)
    plt.ylabel(y_label, fontsize=y_label_font_size)
    plt.title(title, fontsize=title_font_size)
    for apply_scale in (plt.xscale, plt.yscale):
        apply_scale(scale)
    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)
    if grid:
        plt.grid(True)
    if legend:
        plt.legend(loc=legend_loc, fontsize=legend_font_size, ncol=legend_num_cols)