# MPPT Analysis

In [None]:
%matplotlib ipympl
%load_ext autoreload
%autoreload 2
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML, Latex, clear_output
import os
import sys
import pandas as pd
import numpy as np
import itertools
import plotly.io as pio
import plotly.graph_objects as go
import plotly.express as px
sys.path.append(os.path.dirname(os.getcwd()))
from api_calls import get_ids_in_batch, get_sample_description, get_all_mppt, get_batch_ids
import batch_selection
import plotting_utils
import access_token
# Record that the notebook was used, then obtain an API token for the
# NOMAD Oasis instance hosted at url_base.
access_token.log_notebook_usage()
url_base = "http://elnserver.lti.kit.edu"
url = "/".join([url_base, "nomad-oasis/api/v1"])
token = access_token.get_token(url)

# Global data storage (note: MPPTAnalysisApp keeps its own self.data dict)
data = dict()

In [None]:


from fitting_tools import *

# Create available fit models list using your existing classes
# Fit models offered for the degradation curves. Positional arguments to
# fit_model are, in order: a model key, the fitting routine imported from
# fitting_tools, a short label (presumably shown in the UI), and the names of
# the fit-result columns. `description` is the model formula in LaTeX; raw
# strings keep the backslashes without double escaping.
available_fit_model_list = [
    fit_model(
        "linear", linear_params, "linear",
        ["a", "b", r"$R^2$", "t80", "LE"],
        description=r"$P(t)=at+b$",
    ),
    fit_model(
        "exponential", exponential_params, "exponential",
        [r"$A$", r"$\tau$", r"$R^2$", "t80", "LE"],
        description=r"$P(t)=A\exp{(-\frac{t}{\tau})}$",
    ),
    fit_model(
        "biexponential", biexponential_params, "biexponential",
        [r"$A_1$", r"$\tau_1$", r"$A_2$", r"$\tau_2$", r"$R^2$", "tS", "Ts80", "LE"],
        description=r"$P(t)=A_1\exp{(-\frac{t}{\tau_1})}+A_2\exp{(-\frac{t}{\tau_2})}$",
    ),
    fit_model(
        "logexp", logistic_params, "logistic + exp",
        ["A", r"$\tau$", "L", "k", "t0", r"$R^2$", "tS", "Ts80", "LE"],
        description=r"$P(t)=A\exp{(-\frac{t}{\tau})}+\frac{L}{1+\exp{(-k(t-t_0)})}$",
    ),
    fit_model(
        "stretched_exponential", stretched_exponential_params, "stretched exp",
        ["A", r"$\tau$", r"$\beta$", r"$R^2$", "T80", "LE"],
        description=r"$P(t)=A\exp{(-(\frac{t}{\tau})^{\beta})}$",
    ),
    fit_model(
        "errorfunctionXlinear", erfc_params, "errorfunction x linear",
        [r"$P_0$", "k", r"$t_0$", "b", r"$R^2$", "tS", "Ts80", "LE"],
        description=r"$P(t)=0.5(1-\text{erf}(\frac{t-t_0}{b}))(P_0-kt)$",
    ),
]

def get_mppt_data_working(url, token, try_sample_ids):
    """Take a list of sample ids and return their MPPT data as data frames.

    Parameters
    ----------
    url, token :
        API base URL and access token, passed straight to ``get_all_mppt``.
    try_sample_ids :
        Candidate sample ids; samples without any MPPT entry are dropped.

    Returns
    -------
    (curves, sample_ids, descriptions) or (None, None, None)
        ``curves`` is indexed by (sample id, entry index, original row). It has
        the columns time, power_density, voltage and current_density.
        ``sample_ids`` is a Series of the sample ids that actually contributed
        curves, in the same order as the outer index level of ``curves``.
        ``descriptions`` is indexed by (sample id, entry index). It has the
        columns entry_names and entry_description.
        All three are None when no MPPT curves were found.
    """
    all_mppt = get_all_mppt(url, token, try_sample_ids)
    # Covers both an empty mapping and a None result.
    if not all_mppt:
        return None, None, None

    curve_frames = []
    description_frames = []
    kept_sample_ids = []
    for sample_id, mppt_entries in all_mppt.items():
        sample_curves = []
        entry_names = []
        entry_descriptions = []
        for mppt_entry in mppt_entries or []:
            # The first element of each entry holds the measured arrays plus
            # "name"/"description" metadata (as read by the lines below).
            entry = mppt_entry[0]
            sample_curves.append(pd.DataFrame(entry, columns=["time", "power_density", "voltage", "current_density"]))
            entry_names.append(entry["name"])
            entry_descriptions.append(entry.get("description", ""))

        if not sample_curves:
            # Record the id only for samples that produce frames. Otherwise the
            # concat keys below would no longer line up with the frames.
            continue
        kept_sample_ids.append(sample_id)
        curve_frames.append(pd.concat(sample_curves, keys=np.arange(len(sample_curves))))
        description_frames.append(pd.DataFrame({"entry_names": entry_names, "entry_description": entry_descriptions}))

    if not curve_frames:
        return None, None, None

    existing_sample_ids = pd.Series(kept_sample_ids)
    return (
        pd.concat(curve_frames, keys=existing_sample_ids),
        existing_sample_ids,
        pd.concat(description_frames, keys=existing_sample_ids),
    )

def _prepare_power_curve(curve_data, time_range):
    """Return (time, power_density) arrays windowed to ``time_range`` with NaN rows removed."""
    t_data = curve_data['time'].values
    y_data = curve_data['power_density'].values

    if time_range is not None:
        t_min, t_max = time_range
        in_window = (t_data >= t_min) & (t_data <= t_max)
        t_data = t_data[in_window]
        y_data = y_data[in_window]

    valid = ~(np.isnan(t_data) | np.isnan(y_data))
    return t_data[valid], y_data[valid]


def _fit_one_curve(model, curve_data, time_range):
    """Fit a single curve; return (param_row, curve_record), or None if it is skipped.

    A curve is skipped when fewer than 3 usable points remain, or when
    preparing the data, fitting, or reading the parameters raises.
    """
    import warnings

    try:
        t_data, y_data = _prepare_power_curve(curve_data, time_range)
        if len(t_data) < 3:
            return None

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fit_params, fitted_curve = model.parfunc(y_data, t_data)

        param_row = {}
        for param_name, param_value in zip(model.columns, fit_params):
            # Values with uncertainties (UFloat-like) are split into value + error.
            if hasattr(param_value, 'nominal_value'):
                param_row[param_name] = param_value.nominal_value
                param_row[f"{param_name}_error"] = param_value.std_dev
            else:
                param_row[param_name] = param_value
    except Exception:
        # Best effort: one unfittable curve must not abort the whole batch.
        return None

    curve_record = {
        'time': t_data,
        'fitted_power': fitted_curve,
        'original_power': y_data,
    }
    return param_row, curve_record


def fit_all_samples_lmfit(curves_data, sample_ids, selected_samples, model, time_range=None):
    """Fit every curve of the selected samples with ``model``.

    Parameters
    ----------
    curves_data : pandas.DataFrame
        Curves indexed by sample id on the first level, with 'time' and
        'power_density' columns. If a sample has a further multi-level index,
        its first remaining level is treated as the curve number.
    sample_ids : iterable or None
        Sample ids present in ``curves_data``. Selected samples not in it are ignored.
    selected_samples : iterable
        Sample ids to fit.
    model :
        Object with ``parfunc(y, t) -> (params, fitted_curve)`` and
        ``columns`` (parameter names in the order ``parfunc`` returns them).
    time_range : (float, float) or None
        Inclusive time window to fit. None fits the whole curve.

    Returns
    -------
    (pandas.DataFrame, dict)
        The DataFrame has one row per fitted curve. Each row holds
        'sample_id', 'curve_id', the parameters, and '<name>_error' for
        parameters with uncertainties.
        The dict maps ``(sample_id, curve_id)`` to the time, fitted and
        original power arrays.
        Curves that cannot be fitted are skipped silently.
    """
    import warnings

    # Note: this filter is process-wide and stays active after the call.
    warnings.filterwarnings("ignore", message="Using UFloat objects with std_dev==0 may give unexpected results.")

    # Set lookup instead of scanning a list for every selected sample.
    available_samples = set(sample_ids) if sample_ids is not None else set()
    results = []
    fitted_curves_data = {}  # Store fitted curve data separately

    for sample_id in selected_samples:
        if sample_id not in available_samples:
            continue
        try:
            sample_data = curves_data.loc[sample_id]
            if sample_data.index.nlevels > 1:
                # Multiple curves per sample: the first index level numbers them.
                curve_items = [
                    (curve_idx, sample_data.loc[curve_idx])
                    for curve_idx in sample_data.index.get_level_values(0).unique()
                ]
            else:
                # Single curve per sample
                curve_items = [(0, sample_data)]
        except Exception:
            # The sample's rows cannot be selected or split; skip the sample.
            continue

        for curve_id, curve_data in curve_items:
            outcome = _fit_one_curve(model, curve_data, time_range)
            if outcome is None:
                continue
            param_row, curve_record = outcome
            fitted_curves_data[(sample_id, curve_id)] = curve_record
            results.append({'sample_id': sample_id, 'curve_id': curve_id, **param_row})

    results_df = pd.DataFrame(results) if results else pd.DataFrame()

    # Return both the results DataFrame and the fitted curves data
    return results_df, fitted_curves_data

# Main application class
class MPPTAnalysisApp:
    def __init__(self):
        """Create the app state, bind the API credentials and render the UI.

        ``url`` and ``token`` are read from the notebook's module-level
        globals. ``setup_ui`` builds and displays the tabs immediately.
        """
        # Tab containers; setup_ui() fills them in.
        for attr_name in ('tab_widget', 'batch_tab', 'sample_tab', 'fitting_tab', 'plotting_tab'):
            setattr(self, attr_name, None)

        self.data = {}
        self.sample_selectors = {}
        self.fit_results = None  # populated once fitting has run
        self.fitted_curves_data = {}  # (sample_id, curve_id) -> fitted curve arrays

        # Reading module globals needs no `global` declaration.
        self.url, self.token = url, token
        self.setup_ui()
    
    def setup_ui(self):
        """Create the five application tabs and display them below the manual.

        Every tab except 'Batch Selection' is built by its ``*_disabled``
        factory method. Each such tab title gets a lock prefix.
        """
        self.tab_widget = widgets.Tab()

        # Creation order matters: create_batch_tab queries the batch list.
        self.batch_tab = self.create_batch_tab()
        self.sample_tab = self.create_sample_tab_disabled()
        self.fitting_tab = self.create_fitting_tab_disabled()
        self.plotting_tab = self.create_plotting_tab_disabled()
        self.download_tab = self.create_download_tab_disabled()

        pages = [self.batch_tab, self.sample_tab, self.fitting_tab, self.plotting_tab, self.download_tab]
        page_titles = ('Batch Selection', 'Sample Selection', 'Curve Fitting', 'Plotting', 'Download Results')
        self.tab_widget.children = pages
        self.tab_widget.titles = page_titles

        # Lock every tab after the first one.
        for index in range(1, len(pages)):
            self.tab_widget.set_title(index, f"🔒 {self.tab_widget.get_title(index)}")

        display(plotting_utils.create_manual("mppt_manual.md"))
        display(self.tab_widget)
    
    def create_batch_tab(self):
        """Build the 'Batch Selection' tab.

        The tab offers four things:
        - A batch list. Ids whose parent batch is also listed are hidden; the
          parent is the id without its last ``_``-separated part.
        - A substring search.
        - A filter that keeps only batches containing MPPT measurements.
        - A "Load Data" button. It stores the MPPT curves, sample ids, entry
          descriptions and sample properties in ``self.data`` and then
          unlocks the sample tab.

        Returns
        -------
        widgets.VBox
            The tab content.
        """
        filter_status = widgets.Output(layout={'border': '1px solid #ccc', 'padding': '10px'})

        # Use the credentials bound in __init__ for every API call in this tab.
        api_url, api_token = self.url, self.token

        raw_batch_ids = list(get_batch_ids(api_url, api_token))
        raw_batch_set = set(raw_batch_ids)  # O(1) parent lookup
        batch_ids_list = [
            b for b in raw_batch_ids
            if "_".join(b.split("_")[:-1]) not in raw_batch_set
        ]
        # Batches the search box filters. The MPPT filter narrows this list in
        # place, so a later search does not bring back filtered-out batches.
        searchable_batches = list(batch_ids_list)

        batch_selector = widgets.SelectMultiple(
            options=batch_ids_list,
            description='Batches',
            layout=widgets.Layout(width='400px', height='300px')
        )

        filter_button = widgets.Button(
            description="Filter MPPT Batches",
            button_style='info',
            tooltip="Show only batches containing MPPT measurements",
            layout=widgets.Layout(width='200px')
        )

        search_field = widgets.Text(description="Search Batch")

        load_data_button = widgets.Button(
            description="Load Data",
            button_style='primary',
            layout=widgets.Layout(width='150px')
        )

        load_status = widgets.Output(layout={'border': '1px solid #ccc', 'padding': '10px'})

        def on_search_enter(change):
            needle = search_field.value.strip().lower()
            batch_selector.options = [d for d in searchable_batches if needle in d.lower()]

        def start_filtering(b):
            filter_button.disabled = True
            filter_button.description = "🔄 Filtering in progress..."
            try:
                with filter_status:
                    filter_status.clear_output(wait=True)
                    print("Finding batches with MPPT data...")

                    all_batch_ids = batch_ids_list.copy()
                    total_batches = len(all_batch_ids)
                    print(f"Testing {total_batches} batches...")
                    valid_batches = []

                    for i, batch_id in enumerate(all_batch_ids):
                        if i % 10 == 0 or i == total_batches - 1:
                            filter_status.clear_output(wait=True)
                            print(f"Progress: {i+1}/{total_batches} - Found {len(valid_batches)} valid batches")
                            print(f"Currently testing: {batch_id}")

                        try:
                            sample_ids = get_ids_in_batch(api_url, api_token, [batch_id])
                            if sample_ids:
                                mppt_data = get_all_mppt(api_url, api_token, sample_ids)
                                if mppt_data:
                                    valid_batches.append(batch_id)
                                    filter_status.clear_output(wait=True)
                                    print(f"✅ Found valid batch: {batch_id} ({len(mppt_data)} samples)")
                                    print(f"Total found so far: {len(valid_batches)}")
                        except Exception:
                            # A batch that cannot be queried counts as "no MPPT data".
                            continue

                    searchable_batches[:] = valid_batches
                    batch_selector.options = valid_batches

                    filter_status.clear_output(wait=True)
                    print("=" * 60)
                    print("FILTERING COMPLETE")
                    print("=" * 60)
                    print(f"✅ Found {len(valid_batches)} batches with MPPT data out of {total_batches} total")

                    if len(valid_batches) > 0:
                        print(f"Valid batches: {valid_batches}")
                    else:
                        print("⚠️  No batches with MPPT data found!")

                    filter_button.description = f"✅ Found {len(valid_batches)} MPPT batches"
            finally:
                # Never leave the button stuck in the disabled state.
                filter_button.disabled = False

        def load_data_clicked(b):
            if not batch_selector.value:
                with load_status:
                    load_status.clear_output()
                    print("⚠️ Please select at least one batch")
                return

            with load_status:
                load_status.clear_output()
                print("🔄 Loading MPPT data...")

                try:
                    try_sample_ids = get_ids_in_batch(api_url, api_token, batch_selector.value)
                    mppt_result = get_mppt_data_working(api_url, api_token, try_sample_ids)

                    if mppt_result is None or mppt_result[0] is None:
                        print("❌ The selected batches don't contain any MPPT measurements")
                        return

                    curves, sample_ids, entries = mppt_result

                    self.data["curves"] = curves
                    self.data["sample_ids"] = sample_ids
                    self.data["entries"] = entries

                    # Flip sign convention of power/current and convert time to hours.
                    self.data["curves"].loc[:, "power_density"] *= -1
                    self.data["curves"].loc[:, "current_density"] *= -1
                    self.data["curves"].loc[:, "time"] *= 1/3600

                    identifiers = get_sample_description(api_url, api_token, list(sample_ids))
                    self.data["properties"] = pd.DataFrame({
                        "description": pd.Series(identifiers),
                        "name": pd.Series()
                    })

                    print(f"✅ Data loaded successfully! Found {len(sample_ids)} samples with MPPT data")

                    self.enable_sample_tab()

                except Exception as e:
                    print(f"❌ Error loading data: {str(e)}")

        search_field.observe(on_search_enter, names='value')
        filter_button.on_click(start_filtering)
        load_data_button.on_click(load_data_clicked)

        controls = widgets.VBox([
            widgets.HTML("<h3>Batch Selection</h3>"),
            filter_button,
            search_field,
            batch_selector,
            load_data_button,
            filter_status,
            load_status
        ])

        return controls
    
    def get_fitted_curve_data(self, sample_id, curve_id, variable):
        """Generate fitted curve data for a specific sample and curve"""
        if not self.fitted_curves_data:
            return None, None
        
        # Look for the fitted curve data
        curve_key = (sample_id, curve_id)
        if curve_key not in self.fitted_curves_data:
            return None, None
        
        fitted_data = self.fitted_curves_data[curve_key]
        
        if variable == 'power_density':
            return fitted_data['time'], fitted_data['fitted_power']
        else:
            # For other variables (voltage, current_density), we don't have fitted curves
            # so return None to indicate no fitted data available
            return None, None
    
    def create_plotting_tab(self):
        """Build the 'Plotting' tab with curve plots and t80/ts histograms.

        The tab has a variable selector, a plot-style selector, a
        "Show fitting lines" toggle and a "Generate Plots" button. The button
        renders ``plot_curves`` and ``plot_histograms`` into two separate output
        areas. It does nothing but warn while ``self.fit_results`` is missing
        or empty.
        """
        import traceback

        plot_variable = widgets.Dropdown(
            options=[
                ('Power Density', 'power_density'),
                ('Voltage', 'voltage'),
                ('Current Density', 'current_density'),
            ],
            value='power_density',
            description='Variable:',
            layout=widgets.Layout(width='200px'),
        )
        plot_style = widgets.Dropdown(
            options=[
                ('Individual (each curve separate)', 'individual'),
                ('All together (one plot)', 'together'),
                ('By sample (grouped by sample)', 'by_sample'),
                ('By area (median + quartiles)', 'area_quartiles'),
                ('By area (mean + std dev)', 'area_std'),
            ],
            value='individual',
            description='Plot style:',
            layout=widgets.Layout(width='300px'),
        )
        show_fits_checkbox = widgets.Checkbox(
            value=False,
            description='Show fitting lines',
            tooltip='Overlay fitted curves from the selected model',
        )
        plot_button = widgets.Button(
            description="Generate Plots",
            button_style='primary',
            layout=widgets.Layout(width='200px'),
        )
        curves_output = widgets.Output()
        histograms_output = widgets.Output()

        def run_in(output, what, action):
            """Run *action* inside *output*, reporting start, success or the error."""
            with output:
                output.clear_output()
                print(f"🔄 Generating {what}...")
                try:
                    action()
                    print(f"✅ {what.capitalize()} generated successfully!")
                except Exception as e:
                    print(f"❌ Error generating {what}: {str(e)}")
                    traceback.print_exc()

        def generate_plots(b):
            have_results = self.fit_results is not None and len(self.fit_results) != 0
            if not have_results:
                with curves_output:
                    curves_output.clear_output()
                    print("⚠️ No fitting results available. Please complete curve fitting first.")
                return
            run_in(
                curves_output,
                "curve plots",
                lambda: self.plot_curves(plot_variable.value, plot_style.value, show_fits_checkbox.value),
            )
            run_in(histograms_output, "histograms", lambda: self.plot_histograms())

        plot_button.on_click(generate_plots)

        n_selected = len(self.data.get('selected_samples', []))
        n_fitted = len(self.fit_results) if self.fit_results is not None else 0
        return widgets.VBox([
            widgets.HTML("<h3>MPPT Curve Plotting</h3>"),
            widgets.HTML(f"<p>Plot analysis for {n_selected} selected samples with {n_fitted} fitted curves.</p>"),
            widgets.HBox([plot_variable, plot_style]),
            show_fits_checkbox,
            plot_button,
            widgets.HTML("<h4>Curve Plots</h4>"),
            curves_output,
            widgets.HTML("<h4>Parameter Histograms</h4>"),
            histograms_output,
        ])
    
    def plot_curves(self, variable, style, show_fits=False):
        """Generate curve plots based on selected variable and style"""
        # Get the curve data for selected samples
        selected_data = []
        
        for sample_id in self.data["selected_samples"]:
            try:
                if sample_id in list(self.data["sample_ids"]):
                    sample_data = self.data["curves"].loc[sample_id]
                    
                    if hasattr(sample_data.index, 'nlevels') and sample_data.index.nlevels > 1:
                        # Multiple curves per sample
                        for curve_idx in sample_data.index.get_level_values(0).unique():
                            curve_data = sample_data.loc[curve_idx]
                            if variable in curve_data.columns:
                                selected_data.append({
                                    'sample_id': sample_id,
                                    'curve_id': curve_idx,
                                    'time': curve_data['time'].values,
                                    'data': curve_data[variable].values
                                })
                    else:
                        # Single curve per sample
                        if variable in sample_data.columns:
                            selected_data.append({
                                'sample_id': sample_id,
                                'curve_id': 0,
                                'time': sample_data['time'].values,
                                'data': sample_data[variable].values
                            })
            except Exception as e:
                continue
        
        if not selected_data:
            print("⚠️ No curve data found for selected samples")
            return
        
        # Generate plots based on style
        if style == 'individual':
            self.plot_individual_curves(selected_data, variable, show_fits)
        elif style == 'together':
            self.plot_all_together(selected_data, variable, show_fits)
        elif style == 'by_sample':
            self.plot_by_sample(selected_data, variable, show_fits)
        elif style == 'area_quartiles':
            self.plot_area_quartiles(selected_data, variable)
        elif style == 'area_std':
            self.plot_area_std(selected_data, variable)
    
    def plot_individual_curves(self, data, variable, show_fits=False):
        """Display one figure per curve, optionally with its fitted curve dashed on top."""
        label = variable.replace('_', ' ').title()
        for curve in data:
            sample_id, curve_id = curve['sample_id'], curve['curve_id']
            traces = [go.Scatter(
                x=curve['time'],
                y=curve['data'],
                mode='lines',
                name="Data",
                line=dict(width=2, color='blue'),
            )]

            if show_fits:
                fit_time, fit_data = self.get_fitted_curve_data(sample_id, curve_id, variable)
                if not (fit_time is None or fit_data is None):
                    traces.append(go.Scatter(
                        x=fit_time,
                        y=fit_data,
                        mode='lines',
                        name="Fit",
                        line=dict(width=2, color='red', dash='dash'),
                    ))

            fig = go.Figure(data=traces)
            fig.update_layout(
                title=f"{label} - {sample_id} Curve {curve_id}",
                xaxis_title="Time (hours)",
                yaxis_title=label,
                width=800,
                height=500,
            )
            display(fig)
    
    def plot_all_together(self, data, variable, show_fits=False):
        """Draw every curve, plus its fit when requested, in one shared figure."""
        label = variable.replace('_', ' ').title()
        traces = []
        for curve in data:
            tag = f"{curve['sample_id']}_curve_{curve['curve_id']}"
            traces.append(go.Scatter(
                x=curve['time'],
                y=curve['data'],
                mode='lines',
                name=tag,
                line=dict(width=1.5),
                opacity=0.7,
            ))
            if not show_fits:
                continue
            fit_time, fit_data = self.get_fitted_curve_data(curve['sample_id'], curve['curve_id'], variable)
            if fit_time is None or fit_data is None:
                continue
            traces.append(go.Scatter(
                x=fit_time,
                y=fit_data,
                mode='lines',
                name=f"Fit_{tag}",
                line=dict(width=1.5, dash='dash'),
                opacity=0.7,
            ))

        fig = go.Figure(data=traces)
        fig.update_layout(
            title=f"{label} - All Curves",
            xaxis_title="Time (hours)",
            yaxis_title=label,
            width=1000,
            height=600,
        )
        display(fig)
    
    def plot_by_sample(self, data, variable, show_fits=False):
        """Display one figure per sample that contains all of that sample's curves."""
        grouped = {}
        for curve in data:
            grouped.setdefault(curve['sample_id'], []).append(curve)

        label = variable.replace('_', ' ').title()
        for sample_id, sample_curves in grouped.items():
            traces = []
            for curve in sample_curves:
                curve_id = curve['curve_id']
                traces.append(go.Scatter(
                    x=curve['time'],
                    y=curve['data'],
                    mode='lines',
                    name=f"Data Curve {curve_id}",
                    line=dict(width=2),
                ))
                if not show_fits:
                    continue
                fit_time, fit_data = self.get_fitted_curve_data(curve['sample_id'], curve_id, variable)
                if fit_time is None or fit_data is None:
                    continue
                traces.append(go.Scatter(
                    x=fit_time,
                    y=fit_data,
                    mode='lines',
                    name=f"Fit Curve {curve_id}",
                    line=dict(width=2, dash='dash'),
                ))

            fig = go.Figure(data=traces)
            fig.update_layout(
                title=f"{label} - {sample_id}",
                xaxis_title="Time (hours)",
                yaxis_title=label,
                width=800,
                height=500,
            )
            display(fig)
    
    def plot_area_quartiles(self, data, variable):
        """Plot, per sample, the median curve with a shaded 25th-75th percentile band.

        The sample's curves are resampled onto a common 200-point time grid
        spanning all of them. Grid points outside a curve's own time range are
        treated as missing (NaN). They are not filled with the curve's end
        value, so shorter measurements do not distort the statistics. NaN points
        in the raw data are dropped, and each curve is sorted by time because
        ``np.interp`` needs increasing sample points. Samples with fewer than
        two usable curves are skipped.

        Parameters
        ----------
        data : list of dict
            Curve records with 'sample_id', 'time' and 'data' entries.
        variable : str
            Name of the plotted quantity, used for the title and axis label.
        """
        import warnings

        # Group data by sample
        samples = {}
        for curve in data:
            samples.setdefault(curve['sample_id'], []).append(curve)

        for sample_id, curves in samples.items():
            cleaned = []
            for curve in curves:
                t = np.asarray(curve['time'], dtype=float)
                y = np.asarray(curve['data'], dtype=float)
                keep = ~(np.isnan(t) | np.isnan(y))
                t, y = t[keep], y[keep]
                if t.size == 0:
                    continue
                order = np.argsort(t, kind='stable')
                cleaned.append((t[order], y[order]))

            if len(cleaned) < 2:
                continue  # Need at least 2 curves for quartiles

            t_start = min(ts[0] for ts, _ in cleaned)
            t_end = max(ts[-1] for ts, _ in cleaned)
            time_grid = np.linspace(t_start, t_end, 200)
            interpolated_data = np.array([
                np.interp(time_grid, ts, ys, left=np.nan, right=np.nan)
                for ts, ys in cleaned
            ])

            with warnings.catch_warnings():
                # Grid points covered by no curve give all-NaN columns; leave them as gaps.
                warnings.simplefilter("ignore", RuntimeWarning)
                median = np.nanmedian(interpolated_data, axis=0)
                q25 = np.nanpercentile(interpolated_data, 25, axis=0)
                q75 = np.nanpercentile(interpolated_data, 75, axis=0)

            fig = go.Figure()

            # Add quartile area
            fig.add_trace(go.Scatter(
                x=np.concatenate([time_grid, time_grid[::-1]]),
                y=np.concatenate([q75, q25[::-1]]),
                fill='toself',
                fillcolor='rgba(0,100,80,0.2)',
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=True,
                name='25th-75th percentile'
            ))

            # Add median line
            fig.add_trace(go.Scatter(
                x=time_grid,
                y=median,
                mode='lines',
                name='Median',
                line=dict(color='blue', width=3)
            ))

            fig.update_layout(
                title=f"{variable.replace('_', ' ').title()} - {sample_id} (Median + Quartiles)",
                xaxis_title="Time (hours)",
                yaxis_title=variable.replace('_', ' ').title(),
                width=800,
                height=500
            )

            display(fig)
    
    def plot_area_std(self, data, variable):
        """Plot, per sample, the mean curve with a shaded ±1 standard deviation band.

        The sample's curves are resampled onto a common 200-point time grid
        spanning all of them. Grid points outside a curve's own time range are
        treated as missing (NaN). They are not filled with the curve's end
        value, so shorter measurements do not bias the mean. NaN points in the
        raw data are dropped, and each curve is sorted by time because
        ``np.interp`` needs increasing sample points. The standard deviation is
        the population value (ddof=0). Samples with fewer than two usable
        curves are skipped.

        Parameters
        ----------
        data : list of dict
            Curve records with 'sample_id', 'time' and 'data' entries.
        variable : str
            Name of the plotted quantity, used for the title and axis label.
        """
        import warnings

        # Group data by sample
        samples = {}
        for curve in data:
            samples.setdefault(curve['sample_id'], []).append(curve)

        for sample_id, curves in samples.items():
            cleaned = []
            for curve in curves:
                t = np.asarray(curve['time'], dtype=float)
                y = np.asarray(curve['data'], dtype=float)
                keep = ~(np.isnan(t) | np.isnan(y))
                t, y = t[keep], y[keep]
                if t.size == 0:
                    continue
                order = np.argsort(t, kind='stable')
                cleaned.append((t[order], y[order]))

            if len(cleaned) < 2:
                continue  # Need at least 2 curves for std dev

            t_start = min(ts[0] for ts, _ in cleaned)
            t_end = max(ts[-1] for ts, _ in cleaned)
            time_grid = np.linspace(t_start, t_end, 200)
            interpolated_data = np.array([
                np.interp(time_grid, ts, ys, left=np.nan, right=np.nan)
                for ts, ys in cleaned
            ])

            with warnings.catch_warnings():
                # Grid points covered by no curve give all-NaN columns; leave them as gaps.
                warnings.simplefilter("ignore", RuntimeWarning)
                mean = np.nanmean(interpolated_data, axis=0)
                std = np.nanstd(interpolated_data, axis=0)

            fig = go.Figure()

            # Add standard deviation area
            fig.add_trace(go.Scatter(
                x=np.concatenate([time_grid, time_grid[::-1]]),
                y=np.concatenate([mean + std, (mean - std)[::-1]]),
                fill='toself',
                fillcolor='rgba(0,100,80,0.2)',
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=True,
                name='±1 Standard Deviation'
            ))

            # Add mean line
            fig.add_trace(go.Scatter(
                x=time_grid,
                y=mean,
                mode='lines',
                name='Mean',
                line=dict(color='red', width=3)
            ))

            fig.update_layout(
                title=f"{variable.replace('_', ' ').title()} - {sample_id} (Mean ± Std Dev)",
                xaxis_title="Time (hours)",
                yaxis_title=variable.replace('_', ' ').title(),
                width=800,
                height=500
            )

            display(fig)
    
    def plot_histograms(self):
        """Display histograms of the lifetime-type fit parameters.

        Looks up the columns 't80', 'T80', 'tS', 'ts' and 'Ts80' (in that
        order) in ``self.fit_results``. Draws one 20-bin histogram per column
        found, two per row, each ignoring NaN values. Prints a warning and
        returns early when there are no fit results or none of those columns
        exist.
        """
        results = self.fit_results
        if results is None or len(results) == 0:
            print("⚠️ No fitting results available for histograms")
            return

        from plotly.subplots import make_subplots

        present = set(results.columns)
        hist_params = [name for name in ('t80', 'T80', 'tS', 'ts', 'Ts80') if name in present]
        if not hist_params:
            print("⚠️ No time parameters (t80, T80, tS, Ts80) found in fitting results")
            return

        n_params = len(hist_params)
        n_cols = min(2, n_params)
        n_rows = (n_params + 1) // 2
        # (row, col) of every subplot, 1-based, filled row by row.
        positions = [(k // n_cols + 1, k % n_cols + 1) for k in range(n_params)]

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            subplot_titles=[f"{name} Distribution" for name in hist_params],
        )

        for name, (row, col) in zip(hist_params, positions):
            finite_values = results[name].dropna()
            if len(finite_values) == 0:
                continue
            fig.add_trace(
                go.Histogram(x=finite_values, name=name, opacity=0.7, nbinsx=20),
                row=row, col=col,
            )

        fig.update_layout(
            title="Parameter Distributions from Curve Fitting",
            height=400 * n_rows,
            width=800,
            showlegend=False,
        )

        for name, (row, col) in zip(hist_params, positions):
            fig.update_xaxes(title_text=f"{name} (hours)", row=row, col=col)
            fig.update_yaxes(title_text="Count", row=row, col=col)

        display(fig)

    def _get_selected_curve_data(self, variable):
        """Get curve data for selected samples"""
        selected_data = []
        
        for sample_id in self.data["selected_samples"]:
            try:
                if sample_id in list(self.data["sample_ids"]):
                    sample_data = self.data["curves"].loc[sample_id]
                    
                    if hasattr(sample_data.index, 'nlevels') and sample_data.index.nlevels > 1:
                        for curve_idx in sample_data.index.get_level_values(0).unique():
                            curve_data = sample_data.loc[curve_idx]
                            if variable in curve_data.columns:
                                selected_data.append({
                                    'sample_id': sample_id,
                                    'curve_id': curve_idx,
                                    'time': curve_data['time'].values,
                                    'data': curve_data[variable].values
                                })
                    else:
                        if variable in sample_data.columns:
                            selected_data.append({
                                'sample_id': sample_id,
                                'curve_id': 0,
                                'time': sample_data['time'].values,
                                'data': sample_data[variable].values
                            })
            except:
                continue
        
        return selected_data
    
    def create_sample_tab_disabled(self):
        """Return the locked placeholder shown in the Sample Selection slot until MPPT data is loaded."""
        placeholder_html = "".join([
            "<div style='text-align: center; padding: 50px; color: #888;'>",
            "<h3>🔒 Sample Selection</h3>",
            "<p>This tab will be enabled after loading MPPT data from the Batch Selection tab.</p>",
            "</div>",
        ])
        return widgets.VBox([widgets.HTML(value=placeholder_html)])
    
    def create_fitting_tab_disabled(self):
        """Return the locked placeholder shown in the Curve Fitting slot until the sample selection is confirmed."""
        placeholder_html = "".join([
            "<div style='text-align: center; padding: 50px; color: #888;'>",
            "<h3>🔒 Curve Fitting</h3>",
            "<p>This tab will be enabled after confirming sample selection.</p>",
            "</div>",
        ])
        return widgets.VBox([widgets.HTML(value=placeholder_html)])
    
    def create_plotting_tab_disabled(self):
        """Return the locked placeholder shown in the Plotting slot until curve fitting is done."""
        placeholder_html = "".join([
            "<div style='text-align: center; padding: 50px; color: #888;'>",
            "<h3>🔒 Plotting</h3>",
            "<p>This tab will be enabled after completing curve fitting.</p>",
            "</div>",
        ])
        return widgets.VBox([widgets.HTML(value=placeholder_html)])
    
    def enable_sample_tab(self):
        """Replace the placeholder in tab slot 1 with the live Sample Selection tab and switch to it."""
        slot = 1
        self.sample_tab = self.create_sample_tab()
        tabs = self.tab_widget

        new_children = list(tabs.children)
        new_children[slot] = self.sample_tab
        tabs.children = new_children

        tabs.set_title(slot, 'Sample Selection')
        tabs.selected_index = slot
    
    def enable_fitting_tab(self):
        """Replace the placeholder in tab slot 2 with the live Curve Fitting tab and switch to it."""
        slot = 2
        self.fitting_tab = self.create_fitting_tab()
        tabs = self.tab_widget

        new_children = list(tabs.children)
        new_children[slot] = self.fitting_tab
        tabs.children = new_children

        tabs.set_title(slot, 'Curve Fitting')
        tabs.selected_index = slot
    
    def enable_plotting_tab(self):
        """Replace the placeholder in tab slot 3 with the live Plotting tab, switch to it, then unlock downloads."""
        slot = 3
        self.plotting_tab = self.create_plotting_tab()
        tabs = self.tab_widget

        new_children = list(tabs.children)
        new_children[slot] = self.plotting_tab
        tabs.children = new_children

        tabs.set_title(slot, 'Plotting')
        tabs.selected_index = slot

        # The download tab becomes available at the same time as plotting.
        self.enable_download_tab()

    def create_download_tab_disabled(self):
        """Return the locked placeholder shown in the Download Results slot until curve fitting is done."""
        placeholder_html = "".join([
            "<div style='text-align: center; padding: 50px; color: #888;'>",
            "<h3>🔒 Download Results</h3>",
            "<p>This tab will be enabled after completing curve fitting.</p>",
            "</div>",
        ])
        return widgets.VBox([widgets.HTML(value=placeholder_html)])

    def create_download_tab(self):
        """Build the "Download Results" tab.

        The tab offers a plot-format choice and raw/fitted-data toggles. The
        "Generate Download Package" button builds an in-memory zip with:
        - an Excel workbook;
        - up to three power-density plots;
        - a histogram of the lifetime parameters;
        - a README.
        The zip is then shown as a base64 ``data:`` download link. The work
        is split into the ``_download_*`` helpers below.

        Returns:
            widgets.VBox: the controls of the tab.
        """
        import io
        import zipfile
        import base64
        from datetime import datetime

        # File format options
        excel_format = widgets.Checkbox(
            value=True,
            description='Excel file with multiple sheets',
            disabled=True,
            tooltip='Main results file with curve data, fit results, and statistics'
        )

        plots_format = widgets.Dropdown(
            options=[
                ('HTML (Interactive)', 'html'),
                ('PNG (Static Images)', 'png'),
                ('Both HTML and PNG', 'both')
            ],
            value='html',
            description='Plot format:',
            layout=widgets.Layout(width='300px')
        )

        include_raw_data = widgets.Checkbox(
            value=True,
            description='Include raw curve data',
            tooltip='Include the original MPPT curve measurements'
        )

        include_fitted_data = widgets.Checkbox(
            value=True,
            description='Include fitted curve data',
            tooltip='Include the fitted curves from mathematical models'
        )

        download_button = widgets.Button(
            description="📦 Generate Download Package",
            button_style='success',
            layout=widgets.Layout(width='250px')
        )

        download_status = widgets.Output()
        download_link = widgets.Output()

        def generate_download_package(b):
            if self.fit_results is None or len(self.fit_results) == 0:
                with download_status:
                    download_status.clear_output()
                    print("⚠️ No fitting results available. Please complete curve fitting first.")
                return

            with download_status:
                download_status.clear_output()
                print("🔄 Generating download package...")

                try:
                    plot_format = plots_format.value
                    # One timestamp so the README and the file name agree.
                    generated_at = datetime.now()

                    zip_buffer = io.BytesIO()
                    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                        # 1. Excel workbook; a failure here aborts the whole package.
                        zip_file.writestr(
                            'MPPT_Analysis_Results.xlsx',
                            self._download_excel_bytes(include_raw_data.value, include_fitted_data.value),
                        )

                        # 2. Per-curve plots (best-effort, reports its own errors)
                        print("📊 Generating basic plots...")
                        plot_counter = self._download_write_curve_plots(zip_file, plot_format)

                        # 3. Histograms (best-effort, reports its own errors)
                        histogram_written = self._download_write_histograms(zip_file, plot_format)

                        # 4. README
                        zip_file.writestr('README.txt', self._download_readme_text(plot_format, generated_at))

                    zip_data = zip_buffer.getvalue()
                    filename = f"MPPT_Analysis_Results_{generated_at.strftime('%Y%m%d_%H%M%S')}.zip"
                    b64_data = base64.b64encode(zip_data).decode()
                    size_mb = len(zip_data) / 1024 / 1024
                    # Only advertise histograms that actually made it into the zip.
                    histogram_item = "<li>Parameter histograms</li>" if histogram_written else ""

                    with download_link:
                        download_link.clear_output()
                        display(HTML(f'''
                        <div style="padding: 20px; border: 2px solid #28a745; border-radius: 10px; background-color: #d4edda;">
                            <h3 style="color: #155724; margin-top: 0;">✅ Download Package Ready!</h3>
                            <p><strong>File size:</strong> {size_mb:.2f} MB</p>
                            <p><strong>Contents:</strong></p>
                            <ul>
                                <li>Excel file with {len(self.fit_results)} fitted curves</li>
                                <li>{plot_counter} power-density plots ({plot_format} format)</li>
                                {histogram_item}
                                <li>README with analysis details</li>
                            </ul>
                            <a href="data:application/zip;base64,{b64_data}" 
                               download="{filename}" 
                               style="background-color: #28a745; color: white; padding: 10px 20px; 
                                      text-decoration: none; border-radius: 5px; font-weight: bold;">
                                📥 Download {filename}
                            </a>
                        </div>
                        '''))

                    print(f"✅ Package generated successfully! ({size_mb:.2f} MB)")

                except Exception as e:
                    print(f"❌ Error generating download package: {str(e)}")
                    import traceback
                    traceback.print_exc()

        download_button.on_click(generate_download_package)

        controls = widgets.VBox([
            widgets.HTML("<h3>📦 Download Analysis Results</h3>"),
            widgets.HTML("<p>Create a comprehensive zip file containing all analysis results.</p>"),
            widgets.HTML("<h4>📋 Package Contents:</h4>"),
            excel_format,
            widgets.HTML("<h4>📊 Plot Options:</h4>"),
            plots_format,
            widgets.HTML("<h4>📈 Data Options:</h4>"),
            include_raw_data,
            include_fitted_data,
            download_button,
            download_status,
            download_link
        ])

        return controls

    @staticmethod
    def _download_add_figure(zip_file, fig, plot_name, plot_format, width, height):
        """Write ``fig`` into ``plots/`` of the zip as HTML and/or PNG.

        Returns False only if a requested PNG could not be rendered or written.
        """
        if plot_format in ('html', 'both'):
            zip_file.writestr(f'plots/{plot_name}.html', fig.to_html(include_plotlyjs='cdn'))
        if plot_format in ('png', 'both'):
            try:
                # Static export depends on plotly's optional image renderer; treat it as best-effort.
                img_bytes = fig.to_image(format="png", width=width, height=height)
                zip_file.writestr(f'plots/{plot_name}.png', img_bytes)
            except Exception:
                return False
        return True

    @staticmethod
    def _download_padded_frame(columns):
        """Build a DataFrame from 1-D sequences of unequal length, NaN-padding short ones; None if empty."""
        if not columns:
            return None
        max_length = max(len(values) for values in columns.values())
        padded = {}
        for key, values in columns.items():
            if len(values) < max_length:
                filler = np.full(max_length, np.nan)
                filler[:len(values)] = values
                values = filler
            padded[key] = values
        return pd.DataFrame(padded)

    def _download_raw_columns(self):
        """Collect time/power/voltage/current columns of each selected sample, keyed '<sample>_curve_<n>_<quantity>'."""
        columns = {}
        for sample_id in self.data.get("selected_samples", []):
            try:
                sample_data = self.data["curves"].loc[sample_id]
                if hasattr(sample_data.index, 'nlevels') and sample_data.index.nlevels > 1:
                    parts = [
                        (curve_idx, sample_data.loc[curve_idx])
                        for curve_idx in sample_data.index.get_level_values(0).unique()
                    ]
                else:
                    parts = [(0, sample_data)]
                # Stage per sample so a missing column cannot leave half-written columns behind.
                sample_columns = {
                    f"{sample_id}_curve_{curve_idx}_{quantity}": curve_data[quantity].values
                    for curve_idx, curve_data in parts
                    for quantity in ('time', 'power_density', 'voltage', 'current_density')
                }
            except Exception as e:
                print(f"⚠️ Skipping raw data for {sample_id}: {e}")
                continue
            columns.update(sample_columns)
        return columns

    def _download_excel_bytes(self, include_raw, include_fitted):
        """Return the bytes of the results workbook.

        Sheets: raw curves, fitted curves, fit results, statistics and
        sample info. Each sheet is written only when its data is available
        and its option is enabled.
        """
        import io

        excel_buffer = io.BytesIO()
        with pd.ExcelWriter(excel_buffer, engine='openpyxl') as writer:
            # Sheet 1: raw curve data, one column per (sample, curve, quantity)
            if include_raw and "curves" in self.data:
                raw_df = self._download_padded_frame(self._download_raw_columns())
                if raw_df is not None:
                    raw_df.to_excel(writer, sheet_name='Raw_Curve_Data', index=False)

            # Sheet 2: fitted curve data
            if include_fitted and self.fitted_curves_data:
                fitted_columns = {}
                for (sample_id, curve_id), fitted in self.fitted_curves_data.items():
                    prefix = f"{sample_id}_curve_{curve_id}"
                    fitted_columns[f"{prefix}_time"] = fitted['time']
                    fitted_columns[f"{prefix}_fitted_power_density"] = fitted['fitted_power']
                    fitted_columns[f"{prefix}_original_power_density"] = fitted.get('original_power', fitted['fitted_power'])
                fitted_df = self._download_padded_frame(fitted_columns)
                if fitted_df is not None:
                    fitted_df.to_excel(writer, sheet_name='Fitted_Curve_Data', index=False)

            if self.fit_results is not None and len(self.fit_results) > 0:
                # Sheet 3: fit results
                self.fit_results.to_excel(writer, sheet_name='Fit_Results', index=False)
                # Sheet 4: descriptive statistics of the numeric fit columns
                numerical_cols = self.fit_results.select_dtypes(include=[np.number]).columns
                if len(numerical_cols) > 0:
                    self.fit_results[numerical_cols].describe().to_excel(writer, sheet_name='Statistical_Summary')

            # Sheet 5: sample information
            if "selected_samples" in self.data and "properties" in self.data:
                properties = self.data["properties"]
                custom_names = self.data.get("custom_names", {})
                sample_info = [
                    {
                        'sample_id': sample_id,
                        'description': properties.loc[sample_id, "description"] if sample_id in properties.index else "",
                        'custom_name': custom_names.get(sample_id, ""),
                    }
                    for sample_id in self.data["selected_samples"]
                ]
                pd.DataFrame(sample_info).to_excel(writer, sheet_name='Sample_Information', index=False)

        return excel_buffer.getvalue()

    def _download_write_curve_plots(self, zip_file, plot_format, max_plots=3):
        """Add data+fit power-density plots for the first ``max_plots`` selected curves; return how many were built."""
        plot_counter = 0
        try:
            selected_data = self._get_selected_curve_data('power_density')
            fitted_curves = self.fitted_curves_data or {}
        except Exception as e:
            print(f"⚠️ Error in plot generation: {str(e)}")
            return plot_counter

        for i, curve in enumerate(selected_data[:max_plots]):
            try:
                fig = go.Figure()
                fig.add_trace(go.Scatter(
                    x=curve['time'],
                    y=curve['data'],
                    mode='lines',
                    name="Data",
                    line=dict(width=2, color='blue')
                ))

                curve_key = (curve['sample_id'], curve['curve_id'])
                if curve_key in fitted_curves:
                    fitted_data = fitted_curves[curve_key]
                    fig.add_trace(go.Scatter(
                        x=fitted_data['time'],
                        y=fitted_data['fitted_power'],
                        mode='lines',
                        name="Fit",
                        line=dict(width=2, color='red', dash='dash')
                    ))

                fig.update_layout(
                    title=f"Power Density - {curve['sample_id']} Curve {curve['curve_id']}",
                    xaxis_title="Time (hours)",
                    yaxis_title="Power Density",
                    width=800,
                    height=500
                )

                plot_counter += 1
                plot_name = f"{plot_counter:02d}_power_density_{curve['sample_id']}_curve_{curve['curve_id']}"
                # Export at the layout's own size so the PNG is not rescaled.
                if not self._download_add_figure(zip_file, fig, plot_name, plot_format, width=800, height=500):
                    print(f"⚠️ Could not generate PNG for plot {i+1}")
                print(f"Generated plot {i+1}")
            except Exception as e:
                print(f"⚠️ Error generating plot {i+1}: {str(e)}")

        if selected_data:
            print(f"Generated {plot_counter} individual plots")
        return plot_counter

    def _download_write_histograms(self, zip_file, plot_format):
        """Add a histogram figure of the lifetime parameters to the zip; return True if a file was written."""
        try:
            hist_params = [p for p in ('t80', 'T80', 'tS', 'ts', 'Ts80') if p in self.fit_results.columns]
            if not hist_params:
                return False

            from plotly.subplots import make_subplots

            n_params = len(hist_params)
            cols = min(2, n_params)
            rows = (n_params + 1) // 2

            fig = make_subplots(
                rows=rows,
                cols=cols,
                subplot_titles=[f"{param} Distribution" for param in hist_params]
            )
            for i, param in enumerate(hist_params):
                row, col = i // cols + 1, i % cols + 1
                values = self.fit_results[param].dropna()
                if len(values) > 0:
                    fig.add_trace(
                        go.Histogram(x=values, name=param, opacity=0.7, nbinsx=20),
                        row=row, col=col
                    )
                fig.update_xaxes(title_text=f"{param} (hours)", row=row, col=col)
                fig.update_yaxes(title_text="Count", row=row, col=col)

            fig.update_layout(
                title="Parameter Distributions from Curve Fitting",
                height=400 * rows,
                width=800,
                showlegend=False
            )

            # A fixed 600 px export height used to squash layouts with 2+ rows.
            png_ok = self._download_add_figure(zip_file, fig, "histograms_1", plot_format, width=800, height=400 * rows)
            if not png_ok:
                print("⚠️ Could not generate PNG histogram")
            # With 'png' only, a failed render means nothing was written.
            return png_ok or plot_format != 'png'
        except Exception as e:
            print(f"⚠️ Error generating histograms: {str(e)}")
            return False

    def _download_readme_text(self, plot_format, generated_at):
        """Return the README.txt content describing the download package."""
        import textwrap

        # dedent strips the source indentation that previously ended up in the file.
        return textwrap.dedent(f"""
            MPPT Analysis Results Package
            Generated: {generated_at.strftime('%Y-%m-%d %H:%M:%S')}

            CONTENTS:
            =========
            1. MPPT_Analysis_Results.xlsx - Excel file with multiple sheets containing raw data, fitted curves, fit results, and statistics
            2. plots/ folder - {plot_format.upper()} plots of the analysis results
            3. README.txt - This file

            ANALYSIS DETAILS:
            ================
            Selected Samples: {len(self.data.get('selected_samples', []))}
            Total Fitted Curves: {len(self.fit_results)}
            Variables Analyzed: Power Density, Voltage, Current Density

            For detailed information about the analysis parameters and methods,
            please refer to the original MPPT analysis notebook.
            """)

    def enable_download_tab(self):
        """Put the live Download Results tab into slot 4 without changing which tab is selected."""
        slot = 4
        self.download_tab = self.create_download_tab()
        tabs = self.tab_widget

        new_children = list(tabs.children)
        new_children[slot] = self.download_tab
        tabs.children = new_children

        tabs.set_title(slot, 'Download Results')
    
    def create_sample_tab(self):
        """Build the Sample Selection tab.

        Shows one selector per id in ``self.data["sample_ids"]``, built by
        ``create_sample_selector`` for the chosen naming preset. Confirming:
        - stores the checked ids in ``self.data["selected_samples"]``;
        - for the 'custom' preset, stores non-empty typed names in
          ``self.data["custom_names"]``;
        - enables the fitting tab.

        Changing the preset rebuilds the selectors. The check-box choices the
        user already made are carried over instead of being reset.

        Returns:
            widgets.VBox: the controls of the tab.
        """
        name_preset = widgets.Dropdown(
            options=[
                ('Sample Name', 'sample_name'), 
                ('Batch', 'batch'),
                ('Sample Description', 'sample_description'), 
                ('Custom', 'custom')
            ],
            value='sample_name',
            description='Name preset:',
            tooltip="Presets for how the samples will be named"
        )

        selection_status = widgets.Output()
        selectors_container = widgets.VBox()

        confirm_button = widgets.Button(
            description="Confirm Selection",
            button_style='primary',
            layout=widgets.Layout(width='200px')
        )

        # Selectors currently shown by *this* tab. It starts empty so a fresh tab
        # never inherits check states from a previously loaded batch.
        current_selectors = {}

        def create_sample_selectors():
            nonlocal current_selectors
            previous_checks = {
                sample_id: selector['checkbox'].value
                for sample_id, selector in current_selectors.items()
            }
            current_selectors = {}
            selector_widgets = []

            for sample_id in self.data["sample_ids"]:
                selector = self.create_sample_selector(sample_id, name_preset.value)
                if sample_id in previous_checks:
                    # Keep the user's choice across preset changes.
                    selector['checkbox'].value = previous_checks[sample_id]
                current_selectors[sample_id] = selector
                selector_widgets.append(selector['container'])

            self.sample_selectors = current_selectors
            selectors_container.children = selector_widgets

            with selection_status:
                selection_status.clear_output()
                print(f"⚠️ Selection not confirmed - {len(self.data['sample_ids'])} samples available")

        def confirm_selection(b):
            selected_samples = []
            custom_names = {}

            for sample_id, selector in self.sample_selectors.items():
                if selector['checkbox'].value:
                    selected_samples.append(sample_id)
                    if name_preset.value == 'custom' and selector['text'].value.strip():
                        custom_names[sample_id] = selector['text'].value.strip()

            if not selected_samples:
                with selection_status:
                    selection_status.clear_output()
                    print("⚠️ Please select at least one sample")
                return

            self.data["selected_samples"] = selected_samples
            self.data["custom_names"] = custom_names

            with selection_status:
                selection_status.clear_output()
                print(f"✅ Selection confirmed - {len(selected_samples)} samples selected")
                if custom_names:
                    print("Custom names applied:")
                    for sample_id, name in custom_names.items():
                        print(f"  {sample_id} → {name}")

            self.enable_fitting_tab()

        def on_preset_change(change):
            create_sample_selectors()

        name_preset.observe(on_preset_change, names='value')
        confirm_button.on_click(confirm_selection)

        create_sample_selectors()

        controls = widgets.VBox([
            widgets.HTML("<h3>Sample Selection</h3>"),
            widgets.HTML(f"<p>Found {len(self.data['sample_ids'])} samples with MPPT data.</p>"),
            name_preset,
            selectors_container,
            confirm_button,
            selection_status
        ])

        return controls
    
    def create_sample_selector(self, sample_id, preset_type):
        """Create the selection row (checkbox + display name) for one sample.

        Parameters
        ----------
        sample_id : str
            Sample identifier; used as the checkbox label and to derive the
            default display name.
        preset_type : str
            Naming preset chosen in the selection tab:
            - 'batch': the part before the first '&'; if there is no '&', the
              id with its last '_'-separated field removed;
            - 'sample_name': everything after the first '&' (the whole id if
              there is no '&');
            - 'sample_description': the sample's "description" entry in
              ``self.data["properties"]``, falling back to the id when the
              properties table, the row, or the value is missing;
            - anything else (e.g. 'custom'): an empty string.

        Returns
        -------
        dict
            ``'checkbox'`` (ticked by default), ``'text'`` and ``'container'``
            (the HBox to display). For 'custom' the Text is the visible,
            editable name field; for the other presets it is hidden and a
            read-only Label shows the name instead.
        """
        if preset_type == 'batch':
            item_split = sample_id.split("&")
            if len(item_split) >= 2:
                default_name = item_split[0]
            else:
                default_name = "_".join(sample_id.split("_")[:-1])
        elif preset_type == 'sample_name':
            item_split = sample_id.split("&")
            if len(item_split) >= 2:
                default_name = "&".join(item_split[1:])
            else:
                default_name = sample_id
        elif preset_type == 'sample_description':
            properties = self.data.get("properties")
            description = None
            if properties is not None and sample_id in properties.index:
                description = properties.loc[sample_id, "description"]
                if isinstance(description, pd.Series):
                    # Duplicate index entries make .loc return a Series;
                    # use the first entry.
                    description = description.iloc[0] if not description.empty else None
            # Label/Text values must be strings, and pandas reports a missing
            # description as NaN, so fall back to the id and coerce to str.
            if description is None or pd.isna(description):
                default_name = sample_id
            else:
                default_name = str(description)
        else:
            default_name = ""

        checkbox = widgets.Checkbox(
            value=True,
            description=sample_id,
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        if preset_type == 'custom':
            text_input = widgets.Text(
                value=default_name,
                placeholder='Enter custom name',
                layout=widgets.Layout(width='200px')
            )
            container = widgets.HBox([checkbox, text_input])
        else:
            name_label = widgets.Label(
                value=default_name,
                layout=widgets.Layout(width='200px')
            )
            # Hidden Text keeps the returned dict shape identical across
            # presets, so callers can always read selector['text'].
            text_input = widgets.Text(value=default_name, layout=widgets.Layout(display='none'))
            container = widgets.HBox([checkbox, name_label])

        return {
            'checkbox': checkbox,
            'text': text_input,
            'container': container
        }
    
    def create_fitting_tab(self):
        """Create the curve-fitting tab.

        The tab lets the user pick a model from ``available_fit_model_list``,
        restrict the fit to a time window, and run ``fit_all_samples_lmfit``
        on the samples confirmed in the selection tab. After a fit that
        returns results, it renders the results table and a ``describe()``
        summary of the numeric result columns, then enables the plotting tab.

        Side effects of a completed fit call: ``self.last_fitted_model``,
        ``self.fit_results`` and ``self.fitted_curves_data`` are assigned
        together, so the stored model always matches the stored results.

        Returns
        -------
        widgets.VBox
            The assembled tab content.
        """
        model_options = [(f"{model.abbreviated_name}", i)
                         for i, model in enumerate(available_fit_model_list)]

        model_selector = widgets.Dropdown(
            options=model_options,
            value=0,
            description='Model:',
            layout=widgets.Layout(width='400px'),
            style={'description_width': 'initial'}
        )

        # Placeholder window used when no curve data is available.
        slider_min, slider_max = 0.0, 1000.0
        slider_value = (0.0, 100.0)
        curves = self.data.get("curves")
        if curves is not None:
            t_min = curves["time"].min()
            t_max = curves["time"].max()
            # An empty or all-NaN time column yields NaN bounds, which the
            # slider cannot accept; keep the placeholder window then.
            if pd.notna(t_min) and pd.notna(t_max):
                slider_min, slider_max = float(t_min), float(t_max)
                slider_value = (slider_min, slider_max)

        # Pass min/max/value together. Assigning .min and then .max on an
        # existing slider validates each step separately and raises
        # "setting min > max" when the data start exceeds the placeholder max.
        time_range_selector = widgets.FloatRangeSlider(
            value=slider_value,
            min=slider_min,
            max=slider_max,
            step=0.1,
            description='Time Range (h):',
            layout=widgets.Layout(width='400px'),
            style={'description_width': 'initial'}
        )

        fit_button = widgets.Button(
            description="Fit All Curves",
            button_style='primary',
            layout=widgets.Layout(width='200px')
        )

        fit_status = widgets.Output()
        formula_display = widgets.Output()

        results_toggle = widgets.Accordion(children=[widgets.Output()], titles=('Show all fitting results',))
        results_toggle.selected_index = None  # collapsed by default

        stats_toggle = widgets.Accordion(children=[widgets.Output()], titles=('Statistical Summary',))
        stats_toggle.selected_index = 0  # expanded by default

        def update_formula(change):
            """Show the selected model's formula and parameter names."""
            model = available_fit_model_list[model_selector.value]
            with formula_display:
                formula_display.clear_output()
                display(Latex(f"Selected Model: {model.description}"))
                display(HTML(f"<b>Parameters:</b> {', '.join(model.columns)}"))

        def perform_fitting(b):
            """Fit the selected model to all selected samples (button callback)."""
            if "selected_samples" not in self.data:
                with fit_status:
                    fit_status.clear_output()
                    print("⚠️ No samples selected. Please complete sample selection first.")
                return

            with fit_status:
                fit_status.clear_output()
                print("🔄 Fitting curves...")

                try:
                    model = available_fit_model_list[model_selector.value]
                    time_range = time_range_selector.value

                    # Get both results and fitted curves data
                    fit_results, fitted_curves_data = fit_all_samples_lmfit(
                        self.data["curves"],
                        self.data["sample_ids"],
                        self.data["selected_samples"],
                        model,
                        time_range
                    )

                    # Store the model only once the fit call has returned, and
                    # together with its results: if the fit raises, the previous
                    # model/results pair used for plotting stays consistent.
                    self.last_fitted_model = model
                    self.fit_results = fit_results
                    self.fitted_curves_data = fitted_curves_data

                    if self.fit_results is not None and len(self.fit_results) > 0:
                        print(f"✅ Fitting completed! {len(self.fit_results)} curves fitted successfully")

                        with results_toggle.children[0]:
                            results_toggle.children[0].clear_output()
                            display(HTML("<h4>Detailed Fit Results</h4>"))
                            display(HTML(self.fit_results.to_html(index=False, float_format='%.4f')))

                        with stats_toggle.children[0]:
                            stats_toggle.children[0].clear_output()
                            display(HTML("<h4>Statistical Summary</h4>"))

                            numerical_cols = self.fit_results.select_dtypes(include=[np.number]).columns
                            if len(numerical_cols) > 0:
                                stats_df = self.fit_results[numerical_cols].describe()
                                display(HTML(stats_df.to_html(float_format='%.4f')))
                            else:
                                print("No numerical parameters to summarize")

                        # Enable plotting tab after successful fitting
                        self.enable_plotting_tab()

                    else:
                        # Remove tables from an earlier fit; they no longer
                        # match the (empty) results stored above.
                        results_toggle.children[0].clear_output()
                        stats_toggle.children[0].clear_output()
                        print("❌ Fitting failed. No curves could be fitted successfully.")
                        print("This might be due to insufficient data points or numerical issues.")

                except Exception as e:
                    # UI boundary: report the error in the status pane
                    # instead of letting the button callback fail silently.
                    print(f"❌ Error during fitting: {str(e)}")
                    import traceback
                    traceback.print_exc()

        update_formula(None)

        model_selector.observe(update_formula, names='value')
        fit_button.on_click(perform_fitting)

        controls = widgets.VBox([
            widgets.HTML("<h3>Curve Fitting</h3>"),
            widgets.HTML(f"<p>Fit mathematical models to {len(self.data.get('selected_samples', []))} selected samples.</p>"),
            model_selector,
            formula_display,
            time_range_selector,
            fit_button,
            fit_status,
            stats_toggle,
            results_toggle
        ])

        return controls

# Initialize the MPPT analysis app as the module-level instance `app`.
# NOTE(review): this line only constructs the app; confirm that displaying it
# happens inside the constructor or in a following statement/cell.
app = MPPTAnalysisApp()