In [8]:
import json
import pandas as pd
import plotly.express as px
import dash
from dash import dcc, html
from dash.dependencies import Input, Output

# Load the JSON data
with open("./output_temp.json") as f:
    data = json.load(f)

# Function to flatten nested dictionary (dataset_params and model_params)
def flatten_dict(d, parent_key='', sep='_'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

# Function to process the JSON data into a DataFrame
def process_data(data):
    rows = []
    for entry in data:
        # Flatten model_params and dataset_params and combine them
        flattened_params = {**flatten_dict(entry["model_params"]), **flatten_dict(entry["dataset_params"])}
        
        # Add train and validation metrics
        train_metrics = entry["train_metrics"]
        val_metrics = entry["val_metrics"]

        for epoch in range(entry["epochs"]):
            row = {
                'epoch': epoch + 1,
                'train_loss': train_metrics['loss'][epoch],
                'train_accuracy': train_metrics['accuracy'][epoch],
                'train_f1': train_metrics['f1'][epoch],
                'val_loss': val_metrics['loss'][epoch],
                'val_accuracy': val_metrics['accuracy'][epoch],
                'val_f1': val_metrics['f1'][epoch],
            }
            row.update(flattened_params)  # Include all flattened parameters in the row
            rows.append(row)
    
    return pd.DataFrame(rows)

# Convert the JSON data to a DataFrame
df = process_data(data)

# Get all unique parameter names (excluding metric columns)
pivot_options = [col for col in df.columns if col not in ['epoch', 'train_loss', 'train_accuracy', 'train_f1', 'val_loss', 'val_accuracy', 'val_f1']]

# Create a Dash app for interactive plotting
app = dash.Dash(__name__)

app.layout = html.Div([
    html.H1("Model Training Metrics"),
    
    # Dropdown for selecting the metric to display
    html.Label("Choose a Metric:"),
    dcc.Dropdown(
        id="metric",
        options=[
            {'label': 'Loss', 'value': 'loss'},
            {'label': 'Accuracy', 'value': 'accuracy'},
            {'label': 'F1 Score', 'value': 'f1'},
        ],
        value='loss'
    ),
    
    # Dropdown for selecting the hyperparameter to pivot over (dynamic based on dataset)
    html.Label("Choose a Hyperparameter to Pivot Over:"),
    dcc.Dropdown(
        id="pivot_by",
        options=[{'label': col.replace('_', ' ').capitalize(), 'value': col} for col in pivot_options],
        value='learning_rate'
    ),

    # Toggle to switch between individual performance and summary view
    dcc.RadioItems(
        id='view_mode',
        options=[
            {'label': 'Individual', 'value': 'individual'},
            {'label': 'Summary', 'value': 'summary'}
        ],
        value='individual',
        labelStyle={'display': 'inline-block'}
    ),
    
    # Plot for training metrics
    dcc.Graph(id='train_metric_graph'),
    
    # Plot for validation metrics
    dcc.Graph(id='val_metric_graph')
])

@app.callback(
    [Output('train_metric_graph', 'figure'),
     Output('val_metric_graph', 'figure')],
    [Input('metric', 'value'),
     Input('pivot_by', 'value'),
     Input('view_mode', 'value')]
)
def update_graph(selected_metric, pivot_by, view_mode):
    # Columns for train and validation metrics
    train_metric_col = f"train_{selected_metric}"
    val_metric_col = f"val_{selected_metric}"

    if view_mode == 'summary':
        # Calculate summary for each epoch (min, max, avg) for each value of pivot_by
        summary_df = df.groupby(['epoch', pivot_by]).agg(
            avg_train_metric=(train_metric_col, 'mean'),
            min_train_metric=(train_metric_col, 'min'),
            max_train_metric=(train_metric_col, 'max'),
            avg_val_metric=(val_metric_col, 'mean'),
            min_val_metric=(val_metric_col, 'min'),
            max_val_metric=(val_metric_col, 'max')
        ).reset_index()

        # Find min/max models for each epoch
        min_train_df = df.loc[df.groupby(['epoch'])[train_metric_col].idxmin()]
        max_train_df = df.loc[df.groupby(['epoch'])[train_metric_col].idxmax()]
        
        min_val_df = df.loc[df.groupby(['epoch'])[val_metric_col].idxmin()]
        max_val_df = df.loc[df.groupby(['epoch'])[val_metric_col].idxmax()]

        # Plot average, min, and max for training metrics
        train_fig = px.line(summary_df, x='epoch', y='avg_train_metric', color=pivot_by,
                            title=f"Train {selected_metric.capitalize()} Summary Over Epochs",
                            labels={'epoch': 'Epoch', 'avg_train_metric': f"Average Train {selected_metric.capitalize()}"})
        
        # Add min and max as dashed lines (using actual models with min/max)
        train_fig.add_scatter(x=min_train_df['epoch'], y=min_train_df[train_metric_col], mode='lines', name='Min Train', line=dict(dash='dash'))
        train_fig.add_scatter(x=max_train_df['epoch'], y=max_train_df[train_metric_col], mode='lines', name='Max Train', line=dict(dash='dash'))

        # Plot average, min, and max for validation metrics
        val_fig = px.line(summary_df, x='epoch', y='avg_val_metric', color=pivot_by,
                          title=f"Validation {selected_metric.capitalize()} Summary Over Epochs",
                          labels={'epoch': 'Epoch', 'avg_val_metric': f"Average Validation {selected_metric.capitalize()}"})

        # Add min and max as dashed lines (using actual models with min/max)
        val_fig.add_scatter(x=min_val_df['epoch'], y=min_val_df[val_metric_col], mode='lines', name='Min Val', line=dict(dash='dash'))
        val_fig.add_scatter(x=max_val_df['epoch'], y=max_val_df[val_metric_col], mode='lines', name='Max Val', line=dict(dash='dash'))

    else:  # Individual mode
        # Plot for training metrics (individual performance)
        train_fig = px.line(df, x="epoch", y=train_metric_col, color=pivot_by, markers=True,
                            labels={
                                "epoch": "Epoch",
                                train_metric_col: f"Train {selected_metric.capitalize()}",
                                pivot_by: pivot_by.replace("_", " ").capitalize()
                            },
                            title=f"Train {selected_metric.capitalize()} Over Epochs")
        
        # Plot for validation metrics (individual performance)
        val_fig = px.line(df, x="epoch", y=val_metric_col, color=pivot_by, markers=True,
                          labels={
                              "epoch": "Epoch",
                              val_metric_col: f"Validation {selected_metric.capitalize()}",
                              pivot_by: pivot_by.replace("_", " ").capitalize()
                          },
                          title=f"Validation {selected_metric.capitalize()} Over Epochs")

    return train_fig, val_fig

if __name__ == '__main__':
    app.run_server(debug=True)


In [1]:
import json
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import dash
from dash import dcc, html
from dash.dependencies import Input, Output

# Load the JSON data
with open("./output_plotting.json") as f:
    data = json.load(f)

# Function to flatten nested dictionary (dataset_params and model_params)
def flatten_dict(d, parent_key='', sep='_'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

# Function to process the JSON data into a DataFrame
def process_data(data):
    rows = []
    for entry in data:
        # Flatten model_params and dataset_params and combine them
        flattened_params = {**flatten_dict(entry["model_params"]), **flatten_dict(entry["dataset_params"])}
        
        # Add train and validation metrics
        train_metrics = entry["train_metrics"]
        val_metrics = entry["val_metrics"]

        for epoch in range(entry["epochs"]):
            row = {
                'epoch': epoch + 1,
                'train_loss': train_metrics['loss'][epoch],
                'train_accuracy': train_metrics['accuracy'][epoch],
                'train_f1': train_metrics['f1'][epoch],
                'val_loss': val_metrics['loss'][epoch],
                'val_accuracy': val_metrics['accuracy'][epoch],
                'val_f1': val_metrics['f1'][epoch],
            }
            row.update(flattened_params)  # Include all flattened parameters in the row
            rows.append(row)
    
    return pd.DataFrame(rows)

# Convert the JSON data to a DataFrame
df = process_data(data)

# Get all unique parameter names (excluding metric columns)
pivot_options = [col for col in df.columns if col not in ['epoch', 'train_loss', 'train_accuracy', 'train_f1', 'val_loss', 'val_accuracy', 'val_f1']]

# Create a Dash app for interactive plotting
app = dash.Dash(__name__)

app.layout = html.Div([
    html.H1("Model Training Metrics"),
    
    # Dropdown for selecting the metric to display
    html.Label("Choose a Metric:"),
    dcc.Dropdown(
        id="metric",
        options=[
            {'label': 'Loss', 'value': 'loss'},
            {'label': 'Accuracy', 'value': 'accuracy'},
            {'label': 'F1 Score', 'value': 'f1'},
        ],
        value='loss'
    ),
    
    # Dropdown for selecting the hyperparameter to pivot over (dynamic based on dataset)
    html.Label("Choose a Hyperparameter to Pivot Over:"),
    dcc.Dropdown(
        id="pivot_by",
        options=[{'label': col.replace('_', ' ').capitalize(), 'value': col} for col in pivot_options],
        value='learning_rate'
    ),

    # Toggle to switch between individual performance and summary view
    dcc.RadioItems(
        id='view_mode',
        options=[
            {'label': 'Individual', 'value': 'individual'},
            {'label': 'Summary', 'value': 'summary'}
        ],
        value='individual',
        labelStyle={'display': 'inline-block'}
    ),
    
    # Plot for training metrics
    dcc.Graph(id='train_metric_graph'),
    
    # Plot for validation metrics
    dcc.Graph(id='val_metric_graph')
])

@app.callback(
    [Output('train_metric_graph', 'figure'),
     Output('val_metric_graph', 'figure')],
    [Input('metric', 'value'),
     Input('pivot_by', 'value'),
     Input('view_mode', 'value')]
)
def update_graph(selected_metric, pivot_by, view_mode):
    # Columns for train and validation metrics
    train_metric_col = f"train_{selected_metric}"
    val_metric_col = f"val_{selected_metric}"

    if view_mode == 'summary':
        # Calculate summary for each epoch (min, max, avg) for each value of pivot_by
        summary_df = df.groupby(['epoch', pivot_by]).agg(
            avg_train_metric=(train_metric_col, 'mean'),
            min_train_metric=(train_metric_col, 'min'),
            max_train_metric=(train_metric_col, 'max'),
            avg_val_metric=(val_metric_col, 'mean'),
            min_val_metric=(val_metric_col, 'min'),
            max_val_metric=(val_metric_col, 'max')
        ).reset_index()

        # Prepare figure for training metrics
        train_fig = go.Figure()

        # For each pivot_by group, add average, min, and max lines with the same color
        unique_pivot_vals = summary_df[pivot_by].unique()
        color_map = px.colors.qualitative.Plotly[:len(unique_pivot_vals)]

        for i, val in enumerate(unique_pivot_vals):
            pivot_group = summary_df[summary_df[pivot_by] == val]

            # Add avg line
            train_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['avg_train_metric'],
                mode='lines',
                name=f'Avg Train ({val})',
                line=dict(color=color_map[i], width=2)
            ))

            # Add min and max as dashed lines
            train_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['min_train_metric'],
                mode='lines',
                name=f'Min Train ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))

            train_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['max_train_metric'],
                mode='lines',
                name=f'Max Train ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))

        train_fig.update_layout(
            title=f"Train {selected_metric.capitalize()} Summary Over Epochs",
            xaxis_title="Epoch",
            yaxis_title=f"Train {selected_metric.capitalize()}"
        )

        # Prepare figure for validation metrics
        val_fig = go.Figure()

        for i, val in enumerate(unique_pivot_vals):
            pivot_group = summary_df[summary_df[pivot_by] == val]

            # Add avg line
            val_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['avg_val_metric'],
                mode='lines',
                name=f'Avg Val ({val})',
                line=dict(color=color_map[i], width=2)
            ))

            # Add min and max as dashed lines
            val_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['min_val_metric'],
                mode='lines',
                name=f'Min Val ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))

            val_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['max_val_metric'],
                mode='lines',
                name=f'Max Val ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))

        val_fig.update_layout(
            title=f"Validation {selected_metric.capitalize()} Summary Over Epochs",
            xaxis_title="Epoch",
            yaxis_title=f"Validation {selected_metric.capitalize()}"
        )

    else:  # Individual mode
        # Plot for training metrics (individual performance)
        train_fig = px.line(df, x="epoch", y=train_metric_col, color=pivot_by, markers=True,
                            labels={
                                "epoch": "Epoch",
                                train_metric_col: f"Train {selected_metric.capitalize()}",
                                pivot_by: pivot_by.replace("_", " ").capitalize()
                            },
                            title=f"Train {selected_metric.capitalize()} Over Epochs")
        
        # Plot for validation metrics (individual performance)
        val_fig = px.line(df, x="epoch", y=val_metric_col, color=pivot_by, markers=True,
                          labels={
                              "epoch": "Epoch",
                              val_metric_col: f"Validation {selected_metric.capitalize()}",
                              pivot_by: pivot_by.replace("_", " ").capitalize()
                          },
                          title=f"Validation {selected_metric.capitalize()} Over Epochs")

    return train_fig, val_fig

if __name__ == '__main__':
    app.run_server(debug=True)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 205, in update_graph(
    selected_metric='loss',
    pivot_by='text_selection_method',
    view_mode='individual'
)
    197     val_fig.update_layout(
    198         title=f"Validation {selected_metric.capitalize()} Summary Over Epochs",
    199         xaxis_title="Epoch",
    200         yaxis_title=f"Validation {selected_metric.capitalize()}"
    201     )
    203 else:  # Individual mode
    204     # Plot for training metrics (individual performance)
--> 205     train_fig = px.line(df, x="epoch", y=train_metric_col, color=pivot_by, markers=True,
        train_metric_col = 'train_loss'
        df =       epoch  train_loss  train_accuracy  train_f1  val_loss  val_accuracy  \
0         1    0.515992        0.507607  0.377019  0.082439      0.477778   
1         2    0.501661        0.510373  0.344923  0.085130      

In [16]:
import json
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px  # For color handling
from plotly.subplots import make_subplots

# Load the JSON data
with open("./output_plotting.json") as f:
    data = json.load(f)

# Function to flatten nested dictionary (dataset_params and model_params)
def flatten_dict(d, parent_key='', sep='_'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

# Function to process the JSON data into a DataFrame
def process_data(data):
    rows = []
    for entry in data:
        # Flatten model_params and dataset_params and combine them
        flattened_params = {**flatten_dict(entry["model_params"]), **flatten_dict(entry["dataset_params"])}
        
        # Add train and validation metrics
        train_metrics = entry["train_metrics"]
        val_metrics = entry["val_metrics"]

        for epoch in range(entry["epochs"]):
            row = {
                'epoch': epoch + 1,
                'train_loss': train_metrics['loss'][epoch],
                'train_accuracy': train_metrics['accuracy'][epoch],
                'train_f1': train_metrics['f1'][epoch],
                'val_loss': val_metrics['loss'][epoch],
                'val_accuracy': val_metrics['accuracy'][epoch],
                'val_f1': val_metrics['f1'][epoch],
            }
            row.update(flattened_params)  # Include all flattened parameters in the row
            rows.append(row)
    
    return pd.DataFrame(rows)

# Convert the JSON data to a DataFrame
df = process_data(data)

# Get all unique parameter names (excluding metric columns)
pivot_options = [col for col in df.columns if col not in ['epoch', 'train_loss', 'train_accuracy', 'train_f1', 'val_loss', 'val_accuracy', 'val_f1']]

# Create figures for both summary and individual views

def create_summary_figures(selected_metric, pivot_by):
    # Columns for train and validation metrics
    train_metric_col = f"train_{selected_metric}"
    val_metric_col = f"val_{selected_metric}"
    
    summary_df = df.groupby(['epoch', pivot_by]).agg(
        avg_train_metric=(train_metric_col, 'mean'),
        min_train_metric=(train_metric_col, 'min'),
        max_train_metric=(train_metric_col, 'max'),
        avg_val_metric=(val_metric_col, 'mean'),
        min_val_metric=(val_metric_col, 'min'),
        max_val_metric=(val_metric_col, 'max')
    ).reset_index()

    # Prepare figure for training metrics
    train_fig = go.Figure()

    # For each pivot_by group, add average, min, and max lines
    unique_pivot_vals = summary_df[pivot_by].unique()
    color_map = px.colors.qualitative.Plotly[:len(unique_pivot_vals)]  # Use plotly express color map

    for i, val in enumerate(unique_pivot_vals):
        pivot_group = summary_df[summary_df[pivot_by] == val]

        # Add avg line
        train_fig.add_trace(go.Scatter(
            x=pivot_group['epoch'],
            y=pivot_group['avg_train_metric'],
            mode='lines',
            name=f'Avg Train ({val})',
            line=dict(color=color_map[i], width=2)
        ))

        # Add min and max as dashed lines
        train_fig.add_trace(go.Scatter(
            x=pivot_group['epoch'],
            y=pivot_group['min_train_metric'],
            mode='lines',
            name=f'Min Train ({val})',
            line=dict(color=color_map[i], dash='dash')
        ))

        train_fig.add_trace(go.Scatter(
            x=pivot_group['epoch'],
            y=pivot_group['max_train_metric'],
            mode='lines',
            name=f'Max Train ({val})',
            line=dict(color=color_map[i], dash='dash')
        ))

    train_fig.update_layout(
        title=f"Train {selected_metric.capitalize()} Summary Over Epochs",
        xaxis_title="Epoch",
        yaxis_title=f"Train {selected_metric.capitalize()}"
    )

    # Prepare figure for validation metrics
    val_fig = go.Figure()

    for i, val in enumerate(unique_pivot_vals):
        pivot_group = summary_df[summary_df[pivot_by] == val]

        # Add avg line
        val_fig.add_trace(go.Scatter(
            x=pivot_group['epoch'],
            y=pivot_group['avg_val_metric'],
            mode='lines',
            name=f'Avg Val ({val})',
            line=dict(color=color_map[i], width=2)
        ))

        # Add min and max as dashed lines
        val_fig.add_trace(go.Scatter(
            x=pivot_group['epoch'],
            y=pivot_group['min_val_metric'],
            mode='lines',
            name=f'Min Val ({val})',
            line=dict(color=color_map[i], dash='dash')
        ))

        val_fig.add_trace(go.Scatter(
            x=pivot_group['epoch'],
            y=pivot_group['max_val_metric'],
            mode='lines',
            name=f'Max Val ({val})',
            line=dict(color=color_map[i], dash='dash')
        ))

    val_fig.update_layout(
        title=f"Validation {selected_metric.capitalize()} Summary Over Epochs",
        xaxis_title="Epoch",
        yaxis_title=f"Validation {selected_metric.capitalize()}"
    )

    return train_fig, val_fig

# Create summary figures for both train and validation data
train_fig, val_fig = create_summary_figures('loss', 'learning_rate')

# Create a subplot layout to contain both figures
combined_fig = make_subplots(rows=2, cols=1, subplot_titles=("Train Metrics", "Validation Metrics"))

# Add the train and validation figures to the combined figure
for trace in train_fig['data']:
    combined_fig.add_trace(trace, row=1, col=1)

for trace in val_fig['data']:
    combined_fig.add_trace(trace, row=2, col=1)

# Update layout of combined figure
combined_fig.update_layout(height=800, title_text="Training and Validation Metrics Summary")

# Export the combined figure to a single HTML file
pio.write_html(combined_fig, file="model_metrics_summary.html", auto_open=True)


In [17]:
import json
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
from plotly.subplots import make_subplots

# Load the JSON data
with open("./output_plotting.json") as f:
    data = json.load(f)

# Function to flatten nested dictionary (dataset_params and model_params)
def flatten_dict(d, parent_key='', sep='_'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

# Function to process the JSON data into a DataFrame
def process_data(data):
    rows = []
    for entry in data:
        # Flatten model_params and dataset_params and combine them
        flattened_params = {**flatten_dict(entry["model_params"]), **flatten_dict(entry["dataset_params"])}
        
        # Add train and validation metrics
        train_metrics = entry["train_metrics"]
        val_metrics = entry["val_metrics"]

        for epoch in range(entry["epochs"]):
            row = {
                'epoch': epoch + 1,
                'train_loss': train_metrics['loss'][epoch],
                'train_accuracy': train_metrics['accuracy'][epoch],
                'train_f1': train_metrics['f1'][epoch],
                'val_loss': val_metrics['loss'][epoch],
                'val_accuracy': val_metrics['accuracy'][epoch],
                'val_f1': val_metrics['f1'][epoch],
            }
            row.update(flattened_params)  # Include all flattened parameters in the row
            rows.append(row)
    
    return pd.DataFrame(rows)

# Convert the JSON data to a DataFrame
df = process_data(data)

# Get all unique parameter names (excluding metric columns)
pivot_options = [col for col in df.columns if col not in ['epoch', 'train_loss', 'train_accuracy', 'train_f1', 'val_loss', 'val_accuracy', 'val_f1']]

# Create figures for both summary and individual views
def create_summary_figures(selected_metric, pivot_by, view_mode):
    # Columns for train and validation metrics
    train_metric_col = f"train_{selected_metric}"
    val_metric_col = f"val_{selected_metric}"
    
    summary_df = df.groupby(['epoch', pivot_by]).agg(
        avg_train_metric=(train_metric_col, 'mean'),
        min_train_metric=(train_metric_col, 'min'),
        max_train_metric=(train_metric_col, 'max'),
        avg_val_metric=(val_metric_col, 'mean'),
        min_val_metric=(val_metric_col, 'min'),
        max_val_metric=(val_metric_col, 'max')
    ).reset_index()

    # Prepare figure for training metrics
    train_fig = go.Figure()

    # For each pivot_by group, add average, min, and max lines
    unique_pivot_vals = summary_df[pivot_by].unique()
    color_map = px.colors.qualitative.Plotly[:len(unique_pivot_vals)]

    if view_mode == 'summary':
        for i, val in enumerate(unique_pivot_vals):
            pivot_group = summary_df[summary_df[pivot_by] == val]

            # Add avg line
            train_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['avg_train_metric'],
                mode='lines',
                name=f'Avg Train ({val})',
                line=dict(color=color_map[i], width=2)
            ))

            # Add min and max as dashed lines
            train_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['min_train_metric'],
                mode='lines',
                name=f'Min Train ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))

            train_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['max_train_metric'],
                mode='lines',
                name=f'Max Train ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))
    else:
        # Plot individual performance (for full view)
        train_fig = px.line(df, x="epoch", y=train_metric_col, color=pivot_by, markers=True,
                            labels={
                                "epoch": "Epoch",
                                train_metric_col: f"Train {selected_metric.capitalize()}",
                                pivot_by: pivot_by.replace("_", " ").capitalize()
                            },
                            title=f"Train {selected_metric.capitalize()} Over Epochs")

    train_fig.update_layout(
        title=f"Train {selected_metric.capitalize()} Summary Over Epochs",
        xaxis_title="Epoch",
        yaxis_title=f"Train {selected_metric.capitalize()}"
    )

    # Prepare figure for validation metrics
    val_fig = go.Figure()

    if view_mode == 'summary':
        for i, val in enumerate(unique_pivot_vals):
            pivot_group = summary_df[summary_df[pivot_by] == val]

            # Add avg line
            val_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['avg_val_metric'],
                mode='lines',
                name=f'Avg Val ({val})',
                line=dict(color=color_map[i], width=2)
            ))

            # Add min and max as dashed lines
            val_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['min_val_metric'],
                mode='lines',
                name=f'Min Val ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))

            val_fig.add_trace(go.Scatter(
                x=pivot_group['epoch'],
                y=pivot_group['max_val_metric'],
                mode='lines',
                name=f'Max Val ({val})',
                line=dict(color=color_map[i], dash='dash')
            ))
    else:
        # Plot individual performance (for full view)
        val_fig = px.line(df, x="epoch", y=val_metric_col, color=pivot_by, markers=True,
                          labels={
                              "epoch": "Epoch",
                              val_metric_col: f"Validation {selected_metric.capitalize()}",
                              pivot_by: pivot_by.replace("_", " ").capitalize()
                          },
                          title=f"Validation {selected_metric.capitalize()} Over Epochs")

    val_fig.update_layout(
        title=f"Validation {selected_metric.capitalize()} Summary Over Epochs",
        xaxis_title="Epoch",
        yaxis_title=f"Validation {selected_metric.capitalize()}"
    )

    return train_fig, val_fig

# Create figures with the chosen metric, pivot, and mode
selected_metric = 'loss'
pivot_by = 'learning_rate'
view_mode = 'summary'
train_fig, val_fig = create_summary_figures(selected_metric, pivot_by, view_mode)

# Create a subplot layout to contain both figures
combined_fig = make_subplots(rows=2, cols=1, subplot_titles=("Train Metrics", "Validation Metrics"))

# Add the train and validation figures to the combined figure
for trace in train_fig['data']:
    combined_fig.add_trace(trace, row=1, col=1)

for trace in val_fig['data']:
    combined_fig.add_trace(trace, row=2, col=1)

# Add dropdowns and buttons
combined_fig.update_layout(
    updatemenus=[
        dict(
            buttons=list([
                dict(label="Loss",
                     method="update",
                     args=[{"y": [df['train_loss'], df['val_loss']]}]),
                dict(label="Accuracy",
                     method="update",
                     args=[{"y": [df['train_accuracy'], df['val_accuracy']]}]),
                dict(label="F1 Score",
                     method="update",
                     args=[{"y": [df['train_f1'], df['val_f1']]}])
            ]),
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.17,
            xanchor="left",
            y=1.2,
            yanchor="top"
        ),
        dict(
            buttons=list([
                dict(label="Summary View",
                     method="relayout",
                     args=[{"visible": [True, True]}]),
                dict(label="Full View",
                     method="relayout",
                     args=[{"visible": [True, True]}]),
            ]),
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.5,
            xanchor="left",
            y=1.2,
            yanchor="top"
        )
    ]
)

# Export the combined figure to a single HTML file
pio.write_html(combined_fig, file="model_metrics_summary_with_controls.html", auto_open=True)
