In [4]:
import os, io
import glob
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import pywt

In [5]:
# Collect every primary-data CSV under the Elastic results folder and list it with its index.
# The list is sorted because glob's order is filesystem-dependent, and later cells pick
# runs by position (file_paths[i]) -- an unsorted list could silently select the wrong runs.
file_paths = sorted(glob.glob("../results/Elastic/**/*.csv", recursive=True))
for idx, file in enumerate(file_paths):
    print(idx, file)

0 ../results/Elastic\Elastic_001_primary_data.csv
1 ../results/Elastic\Elastic_002_primary_data.csv
2 ../results/Elastic\Elastic_003_primary_data.csv
3 ../results/Elastic\Elastic_004_primary_data.csv
4 ../results/Elastic\Elastic_005_primary_data.csv
5 ../results/Elastic\Elastic_006_primary_data.csv
6 ../results/Elastic\Elastic_007_primary_data.csv
7 ../results/Elastic\Elastic_008_primary_data.csv
8 ../results/Elastic\Elastic_009_primary_data.csv
9 ../results/Elastic\Elastic_010_primary_data.csv
10 ../results/Elastic\Elastic_011_primary_data.csv
11 ../results/Elastic\Elastic_012_primary_data.csv
12 ../results/Elastic\Elastic_019\Elastic_019_primary_data.csv
13 ../results/Elastic\Elastic_020\Elastic_020_primary_data.csv
14 ../results/Elastic\Elastic_021\Elastic_021_primary_data.csv


In [7]:
# Compare the incident pulse for three pulse-shaper thicknesses (12 in striker, 60 psi,
# no momentum trap) and annotate each pulse's 10-85% rise time.
legends = ["0.025in", "0.020in", "0.015in"]
selected = [file_paths[i] for i in [2, 3, 4]]
rise_times = []
line_mode = ["solid", "longdash", "dashdot"]
color = ["blue", "red", "limegreen"]

WINDOW = slice(10000, 50000)      # positional sample range containing the incident pulse
PLOT_START = 15000                # plot only from this position within WINDOW onward
RISE_LOW, RISE_HIGH = 0.10, 0.85  # rise-time thresholds as fractions of the peak

strain_raw_fig_1 = go.Figure()

for idx, file in enumerate(selected):
    print(file)
    data = pd.read_csv(file, usecols=["Time", "Incident Raw"])

    incident_pulse = data["Incident Raw"].iloc[WINDOW]
    incident_time = data["Time"].iloc[WINDOW]

    strain_raw_fig_1.add_trace(go.Scatter(x=incident_time.iloc[PLOT_START:], y=incident_pulse.iloc[PLOT_START:],
                                          mode='lines', name=legends[idx],
                                          line=dict(width=3, dash=line_mode[idx], color=color[idx])))

    # Peak = whichever extreme has the larger magnitude (pulse may be negative or positive)
    pulse_min, pulse_max = np.min(incident_pulse), np.max(incident_pulse)
    peak_val = pulse_min if np.abs(pulse_min) > pulse_max else pulse_max

    val_low = RISE_LOW * peak_val
    val_high = RISE_HIGH * peak_val

    # First sample at/beyond each threshold, compared in the direction of the peak's sign
    if peak_val < 0:
        idx_low = np.where(incident_pulse <= val_low)[0][0]
        idx_high = np.where(incident_pulse <= val_high)[0][0]
    else:
        idx_low = np.where(incident_pulse >= val_low)[0][0]
        idx_high = np.where(incident_pulse >= val_high)[0][0]

    # np.where returns *positions* within the window, so look times up with .iloc; plain []
    # on this integer-labelled slice is a label lookup (KeyError for positions < WINDOW.start).
    incident_rise_time = incident_time.iloc[idx_high] - incident_time.iloc[idx_low]
    rise_times.append(incident_rise_time)

# Customize layout
strain_raw_fig_1.update_layout(width=1100, height=600, plot_bgcolor="#F5F5F5", paper_bgcolor="#FFFFFF",
                             title=dict(text="Strain vs Time (12in Striker)", x=0.5, y=0.95, xanchor="center",
                                        font=dict(size=20, color="black", family="Arial")),
                             xaxis=dict(title=dict(text="Time (ms)", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridcolor="lightgrey", gridwidth=1),
                             yaxis=dict(title=dict(text="Strain", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridwidth=1, gridcolor="lightgrey", zeroline=True, zerolinewidth=2, zerolinecolor="grey"),
                             legend=dict(title="Strain Pulses", x=1.0, y=1.0, font=dict(size=12, color="black", family="Arial"),
                                         bgcolor="#FFFFFF", bordercolor="black", borderwidth=2))

strain_raw_fig_1.add_annotation(
        text="Striker Size = 12 in <br>Striker Pressure = 60 psi <br>Momentum Trap = None",
        xref="paper", yref="paper", x=0, y=1, showarrow=False,
        font=dict(size=14, color="black", family="Arial"), align="left", bgcolor="white", bordercolor="black", borderwidth=2)

# Build the rise-time annotation text (Time is in ms, so *1000 gives microseconds)
rise_text = f"Pulse Rise Times {RISE_LOW * 100:.0f}-{RISE_HIGH * 100:.0f}%:<br>"
for leg, rt in zip(legends, rise_times):
    rise_text += f"{leg}: {rt*1000:.3f} us<br>"

# Add the annotation near the bottom-left corner (showarrow=False: Plotly draws an arrow by default)
strain_raw_fig_1.add_annotation(
    text=rise_text,
    xref="paper", yref="paper",
    x=.11, y=.03,
    showarrow=False,
    font=dict(size=14, color="black", family="Arial"),
    align="left",
    bgcolor="white",
    bordercolor="black",
    borderwidth=2
)

# Save the interactive plot as standalone HTML (plotly.js loaded from the CDN)
fig_path = "../results/Elastic"
os.makedirs(fig_path, exist_ok=True)
strain_raw_fig_1.write_html(os.path.join(fig_path, "Striker12_signal_PS_study.html"), include_plotlyjs='cdn')

../results/Elastic\Elastic_003_primary_data.csv
../results/Elastic\Elastic_004_primary_data.csv
../results/Elastic\Elastic_005_primary_data.csv


In [93]:
# Compare no grease vs. grease vs. a 0.015 in pulse shaper (12 in striker, 60 psi,
# no momentum trap) and annotate each pulse's 10-85% rise time.
legends = ["No Grease", "Grease", "PS 0.015in"]
selected = [file_paths[i] for i in [1, 0, 4]]
rise_times = []
line_mode = ["solid", "longdash", "dashdot"]
color = ["blue", "red", "limegreen"]

WINDOW = slice(10000, 50000)      # positional sample range containing the incident pulse
PLOT_START = 15000                # plot only from this position within WINDOW onward
RISE_LOW, RISE_HIGH = 0.10, 0.85  # rise-time thresholds as fractions of the peak

strain_raw_fig_1 = go.Figure()

for idx, file in enumerate(selected):
    print(file)
    data = pd.read_csv(file, usecols=["Time", "Incident Raw"])

    incident_pulse = data["Incident Raw"].iloc[WINDOW]
    incident_time = data["Time"].iloc[WINDOW]

    strain_raw_fig_1.add_trace(go.Scatter(x=incident_time.iloc[PLOT_START:], y=incident_pulse.iloc[PLOT_START:],
                                          mode='lines', name=legends[idx],
                                          line=dict(width=3, dash=line_mode[idx], color=color[idx])))

    # Peak = whichever extreme has the larger magnitude (pulse may be negative or positive)
    pulse_min, pulse_max = np.min(incident_pulse), np.max(incident_pulse)
    peak_val = pulse_min if np.abs(pulse_min) > pulse_max else pulse_max

    val_low = RISE_LOW * peak_val
    val_high = RISE_HIGH * peak_val

    # First sample at/beyond each threshold, compared in the direction of the peak's sign
    if peak_val < 0:
        idx_low = np.where(incident_pulse <= val_low)[0][0]
        idx_high = np.where(incident_pulse <= val_high)[0][0]
    else:
        idx_low = np.where(incident_pulse >= val_low)[0][0]
        idx_high = np.where(incident_pulse >= val_high)[0][0]

    # np.where returns *positions* within the window, so look times up with .iloc; plain []
    # on this integer-labelled slice is a label lookup (KeyError for positions < WINDOW.start).
    incident_rise_time = incident_time.iloc[idx_high] - incident_time.iloc[idx_low]
    rise_times.append(incident_rise_time)

# Customize layout
strain_raw_fig_1.update_layout(width=1100, height=600, plot_bgcolor="#F5F5F5", paper_bgcolor="#FFFFFF",
                             title=dict(text="Strain vs Time (12in Striker)", x=0.5, y=0.95, xanchor="center",
                                        font=dict(size=20, color="black", family="Arial")),
                             xaxis=dict(title=dict(text="Time (ms)", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridcolor="lightgrey", gridwidth=1),
                             yaxis=dict(title=dict(text="Strain", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridwidth=1, gridcolor="lightgrey", zeroline=True, zerolinewidth=2, zerolinecolor="grey"),
                             legend=dict(title="Strain Pulses", x=1.0, y=1.0, font=dict(size=12, color="black", family="Arial"),
                                         bgcolor="#FFFFFF", bordercolor="black", borderwidth=2))

strain_raw_fig_1.add_annotation(
        text="Striker Size = 12 in <br>Striker Pressure = 60 psi <br>Momentum Trap = None",
        xref="paper", yref="paper", x=0, y=1, showarrow=False,
        font=dict(size=14, color="black", family="Arial"), align="left", bgcolor="white", bordercolor="black", borderwidth=2)

# Build the rise-time annotation text (Time is in ms, so *1000 gives microseconds)
rise_text = f"Pulse Rise Times {RISE_LOW * 100:.0f}-{RISE_HIGH * 100:.0f}%:<br>"
for leg, rt in zip(legends, rise_times):
    rise_text += f"{leg}: {rt*1000:.3f} us<br>"

# Add the annotation near the bottom-left corner (showarrow=False: Plotly draws an arrow by default)
strain_raw_fig_1.add_annotation(
    text=rise_text,
    xref="paper", yref="paper",
    x=.11, y=.03,
    showarrow=False,
    font=dict(size=14, color="black", family="Arial"),
    align="left",
    bgcolor="white",
    bordercolor="black",
    borderwidth=2
)

# Save the interactive plot as standalone HTML (plotly.js loaded from the CDN)
fig_path = "../results/Elastic"
os.makedirs(fig_path, exist_ok=True)
strain_raw_fig_1.write_html(os.path.join(fig_path, "Striker12_signal_study.html"), include_plotlyjs='cdn')

../results/Elastic\Elastic_002_primary_data.csv
../results/Elastic\Elastic_001_primary_data.csv
../results/Elastic\Elastic_005_primary_data.csv


In [94]:
# Compare no grease vs. grease vs. a 0.015 in pulse shaper (12 in striker, 60 psi,
# full momentum trap) and annotate each pulse's 10-85% rise time.
legends = ["No Grease", "Grease", "PS 0.015in"]
selected = [file_paths[i] for i in [7, 6, 9]]
rise_times = []
line_mode = ["solid", "longdash", "dashdot"]
color = ["blue", "red", "limegreen"]

WINDOW = slice(10000, 50000)      # positional sample range containing the incident pulse
PLOT_START = 15000                # plot only from this position within WINDOW onward
RISE_LOW, RISE_HIGH = 0.10, 0.85  # rise-time thresholds as fractions of the peak

strain_raw_fig_2 = go.Figure()

for idx, file in enumerate(selected):
    print(file)
    data = pd.read_csv(file, usecols=["Time", "Incident Raw"])

    incident_pulse = data["Incident Raw"].iloc[WINDOW]
    incident_time = data["Time"].iloc[WINDOW]

    strain_raw_fig_2.add_trace(go.Scatter(x=incident_time.iloc[PLOT_START:], y=incident_pulse.iloc[PLOT_START:],
                                          mode='lines', name=legends[idx],
                                          line=dict(width=3, dash=line_mode[idx], color=color[idx])))

    # Peak = most negative sample. NOTE(review): the 12in no-MT comparison cells pick the
    # larger-magnitude extreme instead; confirm that using the minimum here is intentional.
    peak_val = np.min(incident_pulse)

    val_low = RISE_LOW * peak_val
    val_high = RISE_HIGH * peak_val

    # First sample at/beyond each threshold, compared in the direction of the peak's sign
    if peak_val < 0:
        idx_low = np.where(incident_pulse <= val_low)[0][0]
        idx_high = np.where(incident_pulse <= val_high)[0][0]
    else:
        idx_low = np.where(incident_pulse >= val_low)[0][0]
        idx_high = np.where(incident_pulse >= val_high)[0][0]

    # np.where returns *positions* within the window, so look times up with .iloc; plain []
    # on this integer-labelled slice is a label lookup (KeyError for positions < WINDOW.start).
    incident_rise_time = incident_time.iloc[idx_high] - incident_time.iloc[idx_low]
    rise_times.append(incident_rise_time)

# Customize layout
strain_raw_fig_2.update_layout(width=1100, height=600, plot_bgcolor="#F5F5F5", paper_bgcolor="#FFFFFF",
                             title=dict(text="Strain vs Time (12in Striker)", x=0.5, y=0.95, xanchor="center",
                                        font=dict(size=20, color="black", family="Arial")),
                             xaxis=dict(title=dict(text="Time (ms)", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridcolor="lightgrey", gridwidth=1),
                             yaxis=dict(title=dict(text="Strain", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridwidth=1, gridcolor="lightgrey", zeroline=True, zerolinewidth=2, zerolinecolor="grey"),
                             legend=dict(title="Strain Pulses", x=1.0, y=1.0, font=dict(size=12, color="black", family="Arial"),
                                         bgcolor="#FFFFFF", bordercolor="black", borderwidth=2))

strain_raw_fig_2.add_annotation(
        text="Striker Size = 12 in <br>Striker Pressure = 60 psi <br>Momentum Trap = Full",
        xref="paper", yref="paper", x=0, y=1, showarrow=False,
        font=dict(size=14, color="black", family="Arial"), align="left", bgcolor="white", bordercolor="black", borderwidth=2)

# Build the rise-time annotation text (Time is in ms, so *1000 gives microseconds)
rise_text = f"Pulse Rise Times {RISE_LOW * 100:.0f}-{RISE_HIGH * 100:.0f}%:<br>"
for leg, rt in zip(legends, rise_times):
    rise_text += f"{leg}: {rt*1000:.3f} us<br>"

# Add the annotation near the bottom-left corner (showarrow=False: Plotly draws an arrow by default)
strain_raw_fig_2.add_annotation(
    text=rise_text,
    xref="paper", yref="paper",
    x=0.11, y=.03,
    showarrow=False,
    font=dict(size=14, color="black", family="Arial"),
    align="left",
    bgcolor="white",
    bordercolor="black",
    borderwidth=2
)

# Save the interactive plot as standalone HTML (plotly.js loaded from the CDN)
fig_path = "../results/Elastic"
os.makedirs(fig_path, exist_ok=True)
strain_raw_fig_2.write_html(os.path.join(fig_path, "Striker12_MT_signal_study.html"), include_plotlyjs='cdn')

../results/Elastic\Elastic_008_primary_data.csv
../results/Elastic\Elastic_007_primary_data.csv
../results/Elastic\Elastic_010_primary_data.csv


In [95]:
# Compare no grease vs. grease vs. a 0.015 in pulse shaper (18 in striker, 90 psi,
# full momentum trap) and annotate each pulse's 10-85% rise time.
legends = ["No Grease", "Grease", "PS 0.015in"]
selected = [file_paths[i] for i in [12, 13, 14]]
rise_times = []

line_mode = ["solid", "longdash", "dashdot"]
color = ["blue", "red", "limegreen"]

WINDOW = slice(10000, 53000)      # positional sample range containing the (longer) incident pulse
PLOT_START = 15000                # plot only from this position within WINDOW onward
RISE_LOW, RISE_HIGH = 0.10, 0.85  # rise-time thresholds as fractions of the peak

strain_raw_fig_3 = go.Figure()

for idx, file in enumerate(selected):
    print(file)
    data = pd.read_csv(file, usecols=["Time", "Incident Raw"])

    incident_pulse = data["Incident Raw"].iloc[WINDOW]
    incident_time = data["Time"].iloc[WINDOW]

    strain_raw_fig_3.add_trace(go.Scatter(x=incident_time.iloc[PLOT_START:], y=incident_pulse.iloc[PLOT_START:],
                                          mode='lines', name=legends[idx],
                                          line=dict(width=3, dash=line_mode[idx], color=color[idx])))

    # Peak = most negative sample. NOTE(review): the 12in no-MT comparison cells pick the
    # larger-magnitude extreme instead; confirm that using the minimum here is intentional.
    peak_val = np.min(incident_pulse)

    val_low = RISE_LOW * peak_val
    val_high = RISE_HIGH * peak_val

    # First sample at/beyond each threshold, compared in the direction of the peak's sign
    if peak_val < 0:
        idx_low = np.where(incident_pulse <= val_low)[0][0]
        idx_high = np.where(incident_pulse <= val_high)[0][0]
    else:
        idx_low = np.where(incident_pulse >= val_low)[0][0]
        idx_high = np.where(incident_pulse >= val_high)[0][0]

    # np.where returns *positions* within the window, so look times up with .iloc; plain []
    # on this integer-labelled slice is a label lookup (KeyError for positions < WINDOW.start).
    incident_rise_time = incident_time.iloc[idx_high] - incident_time.iloc[idx_low]
    rise_times.append(incident_rise_time)

# Customize layout
strain_raw_fig_3.update_layout(width=1100, height=600, plot_bgcolor="#F5F5F5", paper_bgcolor="#FFFFFF",
                             title=dict(text="Strain vs Time (18in Striker)", x=0.5, y=0.95, xanchor="center",
                                        font=dict(size=20, color="black", family="Arial")),
                             xaxis=dict(title=dict(text="Time (ms)", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridcolor="lightgrey", gridwidth=1),
                             yaxis=dict(title=dict(text="Strain", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridwidth=1, gridcolor="lightgrey", zeroline=True, zerolinewidth=2, zerolinecolor="grey"),
                             legend=dict(title="Strain Pulses", x=1.0, y=1.0, font=dict(size=12, color="black", family="Arial"),
                                         bgcolor="#FFFFFF", bordercolor="black", borderwidth=2))

strain_raw_fig_3.add_annotation(
        text="Striker Size = 18 in <br>Striker Pressure = 90 psi <br>Momentum Trap = Full",
        xref="paper", yref="paper", x=0, y=1, showarrow=False,
        font=dict(size=14, color="black", family="Arial"), align="left", bgcolor="white", bordercolor="black", borderwidth=2)

# Build the rise-time annotation text (Time is in ms, so *1000 gives microseconds)
rise_text = f"Pulse Rise Times {RISE_LOW * 100:.0f}-{RISE_HIGH * 100:.0f}%:<br>"
for leg, rt in zip(legends, rise_times):
    rise_text += f"{leg}: {rt*1000:.3f} us<br>"

# Add the annotation near the bottom-left corner (showarrow=False: Plotly draws an arrow by default)
strain_raw_fig_3.add_annotation(
    text=rise_text,
    xref="paper", yref="paper",
    x=0.11, y=.03,
    showarrow=False,
    font=dict(size=14, color="black", family="Arial"),
    align="left",
    bgcolor="white",
    bordercolor="black",
    borderwidth=2
)

# Save the interactive plot as standalone HTML (plotly.js loaded from the CDN)
fig_path = "../results/Elastic"
os.makedirs(fig_path, exist_ok=True)
strain_raw_fig_3.write_html(os.path.join(fig_path, "Striker18_signal_study.html"), include_plotlyjs='cdn')

../results/Elastic\Elastic_019\Elastic_019_primary_data.csv
../results/Elastic\Elastic_020\Elastic_020_primary_data.csv
../results/Elastic\Elastic_021\Elastic_021_primary_data.csv


In [96]:
# Compare the 0.015 in pulse-shaper runs across striker / momentum-trap configurations
# and annotate each pulse's 10-85% rise time.
# NOTE(review): file_paths[14] (the 18in run below) is annotated "Momentum Trap = Full" in the
# 18in striker cell, so the "18in No MT" label may be wrong -- confirm before publishing.
legends = ["12in No MT", "12in MT", "18in No MT"]
selected = [file_paths[i] for i in [4, 9, 14]]
rise_times = []

line_mode = ["solid", "longdash", "dashdot"]
color = ["blue", "red", "limegreen"]

WINDOW = slice(10000, 53000)      # positional sample range containing the incident pulse
PLOT_START = 15000                # plot only from this position within WINDOW onward
RISE_LOW, RISE_HIGH = 0.10, 0.85  # rise-time thresholds as fractions of the peak

strain_raw_fig_4 = go.Figure()

for idx, file in enumerate(selected):
    print(file)
    data = pd.read_csv(file, usecols=["Time", "Incident Raw"])

    incident_pulse = data["Incident Raw"].iloc[WINDOW]
    incident_time = data["Time"].iloc[WINDOW]

    strain_raw_fig_4.add_trace(go.Scatter(x=incident_time.iloc[PLOT_START:], y=incident_pulse.iloc[PLOT_START:],
                                          mode='lines', name=legends[idx],
                                          line=dict(width=3, dash=line_mode[idx], color=color[idx])))

    # Peak = most negative sample. NOTE(review): the 12in no-MT comparison cells pick the
    # larger-magnitude extreme instead; confirm that using the minimum here is intentional.
    peak_val = np.min(incident_pulse)

    val_low = RISE_LOW * peak_val
    val_high = RISE_HIGH * peak_val

    # First sample at/beyond each threshold, compared in the direction of the peak's sign
    if peak_val < 0:
        idx_low = np.where(incident_pulse <= val_low)[0][0]
        idx_high = np.where(incident_pulse <= val_high)[0][0]
    else:
        idx_low = np.where(incident_pulse >= val_low)[0][0]
        idx_high = np.where(incident_pulse >= val_high)[0][0]

    # np.where returns *positions* within the window, so look times up with .iloc; plain []
    # on this integer-labelled slice is a label lookup (KeyError for positions < WINDOW.start).
    incident_rise_time = incident_time.iloc[idx_high] - incident_time.iloc[idx_low]
    rise_times.append(incident_rise_time)

# Customize layout
strain_raw_fig_4.update_layout(width=1100, height=600, plot_bgcolor="#F5F5F5", paper_bgcolor="#FFFFFF",
                             title=dict(text="Strain vs Time (0.015in PS)", x=0.5, y=0.95, xanchor="center",
                                        font=dict(size=20, color="black", family="Arial")),
                             xaxis=dict(title=dict(text="Time (ms)", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridcolor="lightgrey", gridwidth=1),
                             yaxis=dict(title=dict(text="Strain", font=dict(family="Arial", size=16, color="black")),
                                        tickfont=dict(size=16, color="black", family="Arial"), showgrid=True,
                                        gridwidth=1, gridcolor="lightgrey", zeroline=True, zerolinewidth=2, zerolinecolor="grey"),
                             legend=dict(title="Strain Pulses", x=1.0, y=1.0, font=dict(size=12, color="black", family="Arial"),
                                         bgcolor="#FFFFFF", bordercolor="black", borderwidth=2))

# Build the rise-time annotation text (Time is in ms, so *1000 gives microseconds)
rise_text = f"Pulse Rise Times {RISE_LOW * 100:.0f}-{RISE_HIGH * 100:.0f}%:<br>"
for leg, rt in zip(legends, rise_times):
    rise_text += f"{leg}: {rt*1000:.3f} us<br>"

# Add the annotation near the bottom-left corner (showarrow=False: Plotly draws an arrow by default)
strain_raw_fig_4.add_annotation(
    text=rise_text,
    xref="paper", yref="paper",
    x=0.11, y=.03,
    showarrow=False,
    font=dict(size=14, color="black", family="Arial"),
    align="left",
    bgcolor="white",
    bordercolor="black",
    borderwidth=2
)

# Save the interactive plot as standalone HTML (plotly.js loaded from the CDN)
fig_path = "../results/Elastic"
os.makedirs(fig_path, exist_ok=True)
strain_raw_fig_4.write_html(os.path.join(fig_path, "Striker_PS_study.html"), include_plotlyjs='cdn')

../results/Elastic\Elastic_005_primary_data.csv
../results/Elastic\Elastic_010_primary_data.csv
../results/Elastic\Elastic_021\Elastic_021_primary_data.csv


In [100]:
from plotly.subplots import make_subplots

# Quick 2x2 overview built from the *traces* of the four comparison figures
# (their annotations are not carried over).
fig_path = "../results/Elastic"  # set here so saving does not depend on state left by an earlier cell
os.makedirs(fig_path, exist_ok=True)

fig = make_subplots(rows=2, cols=2, shared_xaxes=True,
                    subplot_titles=("12in Striker", "12in Striker MT", "18in Striker", "0.015in PS"))

existing_figures = [strain_raw_fig_1, strain_raw_fig_2, strain_raw_fig_3, strain_raw_fig_4]

# Add each figure's traces to its grid cell (row-major order). Traces are grouped by legend
# name so each name is listed once and clicking it toggles that condition in every panel.
legend_names_shown = set()
for i, single_fig in enumerate(existing_figures):
    row = i // 2 + 1
    col = i % 2 + 1
    for trace in single_fig.data:
        trace_copy = type(trace)(trace)  # copy so the source figure's trace is left unchanged
        trace_copy.legendgroup = trace.name
        trace_copy.showlegend = trace.name not in legend_names_shown
        legend_names_shown.add(trace.name)
        fig.add_trace(trace_copy, row=row, col=col)

# Final layout customization
fig.update_layout(height=800, width=1600, title_text="SHPB Strain vs Time - Multi View")

# Save as HTML
fig.write_html(os.path.join(fig_path, "all_study.html"), include_plotlyjs='cdn')

In [113]:
from plotly.graph_objects import Scatter, Figure
import plotly.graph_objects as go

def add_figure_to_subplot(fig, source_fig: Figure, row: int, col: int, subplot_index: int, showlegend=True):
    """
    Copy traces, annotations, and primary-axis styling from a single-panel figure into one subplot.

    Parameters:
    - fig           : Target subplot figure from make_subplots (modified in place)
    - source_fig    : Existing single-panel Plotly figure to copy content from (left unchanged)
    - row, col      : Target subplot location (1-based)
    - subplot_index : Axis number of the target subplot: 1 -> "x"/"y", 2 -> "x2"/"y2", ...
                      Must be the axis pair make_subplots assigned to (row, col).
    - showlegend    : If False, the copied traces are hidden from the legend; if True each
                      trace keeps its own showlegend setting
    """
    axis_suffix = str(subplot_index) if subplot_index > 1 else ""

    # --- 1. Copy all traces into the subplot ---
    # type(trace)(trace) makes an independent copy of the same trace type, so the source
    # figure is not mutated and non-Scatter traces keep their type.
    for trace in source_fig.data:
        trace_copy = type(trace)(trace)
        if not showlegend:
            trace_copy.showlegend = False
        fig.add_trace(trace_copy, row=row, col=col)

    # --- 2. Copy annotations, re-anchored to the target subplot ---
    # "paper" positions are fractions of the whole source figure, so they map to the target
    # axis *domain* (fractions of this panel). Pointing them at the data axes instead would
    # reinterpret e.g. x=0.11 as a data value and misplace the box.
    for ann in source_fig.layout.annotations:
        ann_json = ann.to_plotly_json()
        for letter in ("x", "y"):
            ref_key = f"{letter}ref"
            ref = ann_json.get(ref_key)
            target_axis = f"{letter}{axis_suffix}"
            if ref == "paper" or (isinstance(ref, str) and ref.endswith(" domain")):
                ann_json[ref_key] = f"{target_axis} domain"
            else:
                ann_json[ref_key] = target_axis
        fig.add_annotation(ann_json)

    # --- 3. Copy axis titles and grid/zeroline settings (but do NOT force ranges!) ---
    # A single-panel source figure keeps its settings on its primary "xaxis"/"yaxis";
    # looking up "xaxis{subplot_index}" there finds nothing for subplots 2, 3, ...
    for src_axis, update_axes in ((source_fig.layout.xaxis, fig.update_xaxes),
                                  (source_fig.layout.yaxis, fig.update_yaxes)):
        update_axes(
            title_text=src_axis.title.text,
            showgrid=src_axis.showgrid,
            gridcolor=src_axis.gridcolor,
            zeroline=src_axis.zeroline,
            zerolinecolor=src_axis.zerolinecolor,
            row=row, col=col,
        )

    # ⚠️ Do NOT set `range=` so autoscale keeps working


In [114]:
from plotly.subplots import make_subplots

# 2x2 multiview that copies traces, annotations and axis styling of the four comparison figures.
titles = ["12in Striker", "12in MT", "18in", "0.015in PS"]
source_figs = [strain_raw_fig_1, strain_raw_fig_2, strain_raw_fig_3, strain_raw_fig_4]  # Your pre-made figures

fig = make_subplots(rows=2, cols=2, subplot_titles=titles)

# Panels reuse the same colours/dashes for *different* conditions, so a panel's legend entries
# are shown whenever its trace names have not been listed yet -- not only for the first panel
# (otherwise the "12in No MT"/"12in MT"/... traces of the last panel have no legend at all).
legend_names_shown = set()
for i, f in enumerate(source_figs):
    row = i // 2 + 1
    col = i % 2 + 1
    subplot_index = i + 1  # make_subplots numbers axes row-major: x, x2, x3, x4
    trace_names = {trace.name for trace in f.data}
    show_legend = not trace_names <= legend_names_shown
    legend_names_shown |= trace_names
    add_figure_to_subplot(fig, f, row, col, subplot_index, showlegend=show_legend)

fig.update_layout(
    height=1000,
    width=1200,
    title="SHPB Pulse Multi-View",
    font=dict(family="Arial", size=14),
    margin=dict(t=80, b=50, l=50, r=50)
)

fig.write_html("final_multiview.html", include_plotlyjs="cdn")
