In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import t
# from Utilities_models import fit_poly, display_fitting, polynomial_model,fit_poly_with_confidence
from Utilities_models import  polynomial_model

def custom_plot(x_data, y_data, x_fit, y_fit, 
                x_label='X-axis', y_label='Y-axis', 
                title='Plot', 
                scale='linear', font_size=16, 
                xlim=None, ylim=None, 
                grid=True, legend=True, 
                data_label='Data', fit_label='Fit',
                data_color='blue', fit_color='red', 
                data_marker_size=12, fit_line_width=3, x_label_font_size=None, y_label_font_size=None,
                title_font_size=None, tick_font_size=18):
    """Plot data points and a fitted curve on one new figure, then show it.

    Parameters
    ----------
    x_data, y_data : array-like
        Measured points, drawn as circular markers.
    x_fit, y_fit : array-like
        Fitted curve, drawn as a solid line.
    x_label, y_label, title : str
        Axis labels and figure title.
    scale : {'linear', 'log-log', 'log-x', 'log-y'}
        Axis scaling; 'log-x' / 'log-y' put only that axis on a log scale.
    font_size : int
        Fallback font size for the axis labels and the title.
    xlim, ylim : (float, float) or None
        Axis limits; not applied when falsy.
    grid : bool
        Draw a dashed grid on major and minor ticks when True.
    legend : bool
        Draw a legend when True.
    data_label, fit_label : str
        Legend entries for the data and the fit.
    data_color, fit_color : str
        Matplotlib colours for the data and the fit.
    data_marker_size, fit_line_width : float
        Marker size of the data points and line width of the fit.
    x_label_font_size, y_label_font_size, title_font_size : int or None
        Per-element font sizes; ``None`` falls back to ``font_size``.
    tick_font_size : int
        Font size of the tick labels on both axes.

    Raises
    ------
    ValueError
        If ``scale`` is not one of the supported values.
    """
    # (x-axis scale, y-axis scale) for every supported ``scale`` value.
    axis_scales = {
        'linear': ('linear', 'linear'),
        'log-log': ('log', 'log'),
        'log-x': ('log', 'linear'),
        'log-y': ('linear', 'log'),
    }
    # Validate before opening a figure so a bad argument does not leave an
    # empty figure behind.
    if scale not in axis_scales:
        raise ValueError("Scale must be 'linear', 'log-log', 'log-x', or 'log-y'")
    x_scale, y_scale = axis_scales[scale]

    # Per-element sizes fall back to the shared ``font_size``.
    if x_label_font_size is None:
        x_label_font_size = font_size
    if y_label_font_size is None:
        y_label_font_size = font_size
    if title_font_size is None:
        title_font_size = font_size

    plt.figure(figsize=(10, 6))
    plt.xscale(x_scale)
    plt.yscale(y_scale)

    plt.plot(x_data, y_data, 'o', label=data_label, color=data_color, markersize=data_marker_size)
    plt.plot(x_fit, y_fit, '-', label=fit_label, color=fit_color, linewidth=fit_line_width)

    plt.xlabel(x_label, fontsize=x_label_font_size)
    plt.ylabel(y_label, fontsize=y_label_font_size)
    plt.title(title, fontsize=title_font_size)
    plt.xticks(fontsize=tick_font_size)
    plt.yticks(fontsize=tick_font_size)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)

    if grid:
        # which='both' also draws minor grid lines (visible on log axes).
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    else:
        plt.grid(False)

    if legend:
        plt.legend()

    plt.show()

def custom_multi_plot(x_data_list, y_data_list, x_fit_list=None, y_fit_list=None,
                      x_label='X-axis', y_label='Y-axis', title='Plot',
                      scale='linear', font_size=16, xlim=None, ylim=None, 
                      grid=True, legend=True, data_labels=None, fit_labels=None,
                      data_colors=None, fit_colors=None, data_marker_sizes=None, 
                      fit_line_widths=None, x_label_font_size=None, y_label_font_size=None, 
                      title_font_size=None):
    """Scatter several data series, optionally overlay fit lines, and show.

    Parameters
    ----------
    x_data_list, y_data_list : sequence of array-like
        One entry per data series; both must have the same length.
    x_fit_list, y_fit_list : sequence of array-like or None
        One entry per fit line. Fits are drawn only when both are given, in
        which case they must have the same length.
    x_label, y_label, title : str
        Axis labels and figure title.
    scale : str
        'log-log', 'log-x' or 'log-y' select per-axis log scaling. Any other
        value (e.g. 'linear', 'log', 'symlog') is passed to Matplotlib
        unchanged and applied to both axes.
    font_size : int
        Fallback size for the axis labels and the title.
    xlim, ylim : (float, float) or None
        Axis limits; not applied when falsy.
    grid, legend : bool
        Whether to draw a grid and a legend.
    data_labels, data_colors, data_marker_sizes : list or None
        Per-series styling for the data; ``None`` selects 'Data i', blue,
        and marker area 20.
    fit_labels, fit_colors, fit_line_widths : list or None
        Per-line styling for the fits; ``None`` selects 'Fit i', red, and
        line width 3.
    x_label_font_size, y_label_font_size, title_font_size : int or None
        Per-element font sizes; ``None`` falls back to ``font_size``.

    Raises
    ------
    ValueError
        If the data lists, or the fit lists, differ in length.
    """
    has_fits = x_fit_list is not None and y_fit_list is not None

    # Validate before opening a figure so errors don't leave an empty one.
    if len(x_data_list) != len(y_data_list):
        raise ValueError("x_data_list and y_data_list must have the same length.")
    if has_fits and len(x_fit_list) != len(y_fit_list):
        raise ValueError("x_fit_list and y_fit_list must have the same length.")

    # Fill in default per-series styling.
    n_data = len(x_data_list)
    if data_labels is None:
        data_labels = [f'Data {i+1}' for i in range(n_data)]
    if data_colors is None:
        data_colors = ['blue'] * n_data
    if data_marker_sizes is None:
        data_marker_sizes = [20] * n_data
    if has_fits:
        n_fit = len(x_fit_list)
        if fit_labels is None:
            fit_labels = [f'Fit {i+1}' for i in range(n_fit)]
        if fit_colors is None:
            fit_colors = ['red'] * n_fit
        if fit_line_widths is None:
            fit_line_widths = [3] * n_fit

    # Accept the same composite scale names as custom_plot and
    # plot_fit_with_confidence; anything else is forwarded to Matplotlib for
    # both axes, as before.
    composite_scales = {
        'log-log': ('log', 'log'),
        'log-x': ('log', 'linear'),
        'log-y': ('linear', 'log'),
    }
    x_scale, y_scale = composite_scales.get(scale, (scale, scale))

    # Per-element sizes fall back to the shared ``font_size``.
    if x_label_font_size is None:
        x_label_font_size = font_size
    if y_label_font_size is None:
        y_label_font_size = font_size
    if title_font_size is None:
        title_font_size = font_size

    plt.figure(figsize=(10, 6))

    for i, (x_data, y_data) in enumerate(zip(x_data_list, y_data_list)):
        plt.scatter(x_data, y_data, label=data_labels[i], color=data_colors[i], s=data_marker_sizes[i])

    if has_fits:
        for i, (x_fit, y_fit) in enumerate(zip(x_fit_list, y_fit_list)):
            plt.plot(x_fit, y_fit, label=fit_labels[i], color=fit_colors[i], linewidth=fit_line_widths[i])

    plt.xlabel(x_label, fontsize=x_label_font_size)
    plt.ylabel(y_label, fontsize=y_label_font_size)
    plt.title(title, fontsize=title_font_size)
    plt.xscale(x_scale)
    plt.yscale(y_scale)
    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)
    if grid:
        plt.grid(True)
    if legend:
        plt.legend()
    plt.show()


def plot_fit_with_confidence(x, y, fit_data, 
                x_label='X-axis', y_label='Y-axis', 
                title='Plot', 
                scale='linear', font_size=16, 
                xlim=None, ylim=None, 
                grid=True, legend=True, 
                data_label='Data', fit_label='Fit',
                data_color='blue', fit_color='red', 
                data_marker_size=12, fit_line_width=3, x_label_font_size=None, y_label_font_size=None,
                title_font_size=None, tick_font_size=18, confidence=0.95,
                interval='confidence'):
    """Plot data, the fitted polynomial, and a shaded uncertainty band.

    Parameters
    ----------
    x, y : array-like
        Measured points, drawn as circular markers.
    fit_data : dict
        Fit results. Only two keys are read. ``'Mean fit'``: when it is None,
        no curve or band is drawn. ``'Optimized parameters'``: a sequence
        unpacked as ``polynomial_model(x, *params)``.
    x_label, y_label, title : str
        Axis labels and figure title.
    scale : {'linear', 'log-log', 'log-x', 'log-y'}
        Axis scaling.
    font_size : int
        Fallback font size for the axis labels and the title.
    xlim, ylim : (float, float) or None
        Axis limits. When ``xlim`` is truthy, the fitted curve is also
        evaluated over that range instead of the data range.
    grid, legend : bool
        Whether to draw a dashed grid and a legend.
    data_label, fit_label, data_color, fit_color : str
        Legend entries and colours for the data and the fit.
    data_marker_size, fit_line_width : float
        Marker size of the data and line width of the fit.
    x_label_font_size, y_label_font_size, title_font_size : int or None
        Per-element font sizes; ``None`` falls back to ``font_size``.
    tick_font_size : int
        Font size of the tick labels.
    confidence : float
        Two-sided coverage of the band, strictly between 0 and 1.
    interval : {'confidence', 'prediction'}
        'confidence' bands the mean fit. 'prediction' bands a new single
        observation and is wider.

    Notes
    -----
    The band uses the straight-line regression standard-error formula. It is
    exact only for a degree-1 fit and approximate for higher degrees.
    NOTE(review): if ``polynomial_model`` is linear in its parameters, a
    design-matrix based band would be exact -- confirm its parameterization.
    The band is omitted when it is undefined: no residual degrees of freedom,
    or all ``x`` identical.

    Raises
    ------
    ValueError
        If ``scale``, ``interval`` or ``confidence`` is invalid.
    """
    # (x-axis scale, y-axis scale) for every supported ``scale`` value.
    axis_scales = {
        'linear': ('linear', 'linear'),
        'log-log': ('log', 'log'),
        'log-x': ('log', 'linear'),
        'log-y': ('linear', 'log'),
    }
    # Validate everything before opening a figure, so a bad argument does not
    # leave a half-drawn figure behind.
    if scale not in axis_scales:
        raise ValueError("Scale must be 'linear', 'log-log', 'log-x', or 'log-y'")
    if interval not in ('confidence', 'prediction'):
        raise ValueError("interval must be 'confidence' or 'prediction'")
    if not 0.0 < confidence < 1.0:
        raise ValueError("confidence must be strictly between 0 and 1")
    x_scale, y_scale = axis_scales[scale]

    # Per-element sizes fall back to the shared ``font_size``.
    if x_label_font_size is None:
        x_label_font_size = font_size
    if y_label_font_size is None:
        y_label_font_size = font_size
    if title_font_size is None:
        title_font_size = font_size

    # Arrays, so that `x - mean` etc. also work when plain lists are passed.
    x = np.asarray(x)
    y = np.asarray(y)

    plt.figure(figsize=(10, 6))
    plt.plot(x, y, 'o', label=data_label, color=data_color, markersize=data_marker_size)

    if fit_data['Mean fit'] is not None:
        params = fit_data['Optimized parameters']
        # Evaluate the curve over the requested x-limits, else the data range.
        if xlim:
            x_fit = np.linspace(xlim[0], xlim[1], 100)
        else:
            x_fit = np.linspace(np.min(x), np.max(x), 100)
        y_fit = polynomial_model(x_fit, *params)

        plt.plot(x_fit, y_fit, label=fit_label, color=fit_color, linewidth=fit_line_width)

        half_width = _fit_band_half_width(x, y, x_fit, params, confidence, interval)
        if half_width is not None:
            band_label = f'{confidence * 100:g}% {interval.capitalize()} Interval'
            plt.fill_between(x_fit, y_fit - half_width, y_fit + half_width,
                             color='gray', alpha=0.2, label=band_label)

    plt.xscale(x_scale)
    plt.yscale(y_scale)

    plt.xlabel(x_label, fontsize=x_label_font_size)
    plt.ylabel(y_label, fontsize=y_label_font_size)
    plt.title(title, fontsize=title_font_size)
    plt.xticks(fontsize=tick_font_size)
    plt.yticks(fontsize=tick_font_size)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)

    if grid:
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    else:
        plt.grid(False)

    if legend:
        plt.legend()

    plt.show()


def _fit_band_half_width(x, y, x_fit, params, confidence, interval):
    """Return the band half-width at each ``x_fit`` point, or None if undefined.

    Computes ``t * s * sqrt(c + 1/n + (x0 - mean(x))**2 / Sxx)``. Here ``s**2``
    is the residual variance with ``n - len(params)`` degrees of freedom, and
    ``c`` is 0 for a confidence band or 1 for a prediction band.
    Returns None when there are no residual degrees of freedom or all ``x``
    are identical, because both make the formula divide by zero.
    """
    n = x.size
    dof = n - len(params)
    x_mean = np.mean(x)
    s_xx = np.sum((x - x_mean) ** 2)
    if dof <= 0 or s_xx == 0:
        return None

    residuals = y - polynomial_model(x, *params)
    residual_variance = np.sum(residuals ** 2) / dof
    spread = 1.0 / n + (x_fit - x_mean) ** 2 / s_xx
    if interval == 'prediction':
        # A new observation carries its own noise on top of the fit's.
        spread = spread + 1.0
    t_crit = t.ppf((1.0 + confidence) / 2.0, dof)
    return t_crit * np.sqrt(residual_variance * spread)


