In [4]:
from bs4 import BeautifulSoup
import os
import re

# First unsigned decimal number in a length string, e.g. "700" in "700px".
_NUMBER_RE = re.compile(r"\d*\.?\d+")


def _leading_number(value, default=100):
    """Return the first unsigned decimal number in *value* as a float, or *default*."""
    match = _NUMBER_RE.search(value)
    return float(match.group()) if match else default


def parse_viewbox(svg):
    """Return ``[min_x, min_y, width, height]`` for an SVG element.

    The element's viewBox is used when present and made of exactly four
    numbers separated by whitespace and/or commas. Otherwise the box is
    estimated from the ``width``/``height`` attributes, which default to
    "100". Only the first number in each attribute is kept, so "700px"
    becomes 700.0. Units are not converted, and the origin is ``0, 0``.

    *svg* only needs a dict-like ``get`` method (e.g. a BeautifulSoup Tag).
    """
    # HTML parsers such as lxml and html.parser lowercase attribute names,
    # so SVG's camelCase ``viewBox`` may arrive as ``viewbox``.
    viewbox = svg.get('viewBox') or svg.get('viewbox')
    if viewbox:
        # The SVG spec allows whitespace and/or commas between the numbers.
        parts = [p for p in re.split(r"[\s,]+", viewbox.strip()) if p]
        if len(parts) == 4:
            try:
                return [float(p) for p in parts]
            except ValueError:
                pass  # Malformed viewBox: deliberately fall back to width/height.
    width = _leading_number(svg.get('width', '100'))
    height = _leading_number(svg.get('height', '100'))
    return [0, 0, width, height]

def calculate_canvas_size(svgs):
    """ Calculates the size of the canvas based on the maximum dimensions of the SVGs. """
    max_width = max_height = 0
    for svg in svgs:
        _, _, width, height = parse_viewbox(svg)
        max_width = max(max_width, width)
        max_height = max(max_height, height)
    return [0, 0, max_width, max_height]

def _inside_modebar(svg):
    """Return True if any ancestor of *svg* carries the 'modebar-container' class."""
    # Check every class on each ancestor, not just the first one.
    return any('modebar-container' in (parent.get('class') or []) for parent in svg.parents)


def _merge_svgs(svgs):
    """Return one standalone SVG document with each of *svgs* centred on a shared canvas."""
    canvas = calculate_canvas_size(svgs)
    parts = [
        '<svg xmlns="http://www.w3.org/2000/svg" '
        f'viewBox="{canvas[0]} {canvas[1]} {canvas[2]} {canvas[3]}" style="overflow: hidden;">\n'
    ]
    for svg in svgs:
        min_x, min_y, width, height = parse_viewbox(svg)
        # Centre the plot on the canvas. Also undo any viewBox origin offset so
        # the inlined children keep the coordinates the source SVG drew them at.
        x_translate = (canvas[2] - width) / 2 - min_x
        y_translate = (canvas[3] - height) / 2 - min_y
        # decode_contents() yields the markup between <svg ...> and </svg>.
        # The previous string slicing could not detect a malformed tag.
        parts.append(
            f'<g transform="translate({x_translate}, {y_translate})">\n'
            f'{svg.decode_contents()}\n</g>\n'
        )
    parts.append('</svg>')
    return ''.join(parts)


def save_plots_as_svg(html_file_path, output_directory, div_class, *, parser='lxml'):
    """Export the SVGs inside each matching <div> of an HTML file as one SVG file per div.

    For every ``<div>`` matched by ``find_all('div', class_=div_class)``, all
    ``<svg>`` descendants not nested under a 'modebar-container' element are
    merged onto one canvas. The result is written to
    ``<output_directory>/plot_<n>.svg``, where n is the 1-based index of the
    div. Divs with no qualifying SVG are skipped, but they still consume an
    index.

    Args:
        html_file_path: Path of the UTF-8 HTML file to read.
        output_directory: Directory for the output files; created if missing.
        div_class: Value passed to BeautifulSoup as ``class_``.
        parser: BeautifulSoup parser name. The lxml and html.parser HTML
            parsers lowercase tag/attribute names (e.g. ``clipPath`` becomes
            ``clippath``), which SVG viewers treat as different names. An
            HTML5-compliant parser such as 'html5lib' (if installed) is
            expected to preserve SVG casing; verify for your documents.
    """
    os.makedirs(output_directory, exist_ok=True)  # Ensure the output directory exists
    with open(html_file_path, 'r', encoding='utf-8') as file:
        soup = BeautifulSoup(file.read(), parser)
    divs = soup.find_all('div', class_=div_class)
    if not divs:
        print("No divs found with the class", div_class)
        return
    for idx, div in enumerate(divs, start=1):
        svgs = [svg for svg in div.find_all('svg') if not _inside_modebar(svg)]
        if not svgs:
            continue
        output_file_path = os.path.join(output_directory, f'plot_{idx}.svg')
        with open(output_file_path, 'w', encoding='utf-8') as file:
            file.write(_merge_svgs(svgs))
        print(f"Plot {idx} has been saved to: {output_file_path}")

# Usage example: export the SVGs in each 'plot-container plotly' div of the
# report into ./saved_plots.
save_plots_as_svg(
    html_file_path="SEC Train Test Validation Suite.html",
    output_directory='saved_plots',
    div_class='plot-container plotly',
)


Plot 1 has been saved to: saved_plots\plot_1.svg
Plot 2 has been saved to: saved_plots\plot_2.svg
Plot 3 has been saved to: saved_plots\plot_3.svg
Plot 4 has been saved to: saved_plots\plot_4.svg
Plot 5 has been saved to: saved_plots\plot_5.svg
Plot 6 has been saved to: saved_plots\plot_6.svg
Plot 7 has been saved to: saved_plots\plot_7.svg
Plot 8 has been saved to: saved_plots\plot_8.svg
Plot 9 has been saved to: saved_plots\plot_9.svg
Plot 10 has been saved to: saved_plots\plot_10.svg
Plot 11 has been saved to: saved_plots\plot_11.svg
Plot 12 has been saved to: saved_plots\plot_12.svg
Plot 13 has been saved to: saved_plots\plot_13.svg
Plot 14 has been saved to: saved_plots\plot_14.svg
Plot 15 has been saved to: saved_plots\plot_15.svg
Plot 16 has been saved to: saved_plots\plot_16.svg
