In [13]:
import base64
from xml.etree import ElementTree as ET
from io import BytesIO
from PIL import Image
from IPython.display import SVG, display
from pathlib import Path

In [None]:
def compress_svg_images(svg_input_path, quality=85):
    """Re-encode large embedded raster images of an SVG file as JPEG.

    Every ``<image>`` element whose ``xlink:href`` (or plain ``href``)
    holds a base64 ``data:image`` URI is decoded, converted to RGB and
    re-encoded as JPEG. The element is rewritten only when the new base64
    payload is smaller than 90 % of the old one minus a further 10 000
    characters, so small images and images that barely shrink are left
    untouched. The file on disk is never modified.

    Note: this registers the SVG and xlink namespace prefixes with
    ``ElementTree`` globally so the output keeps the conventional names.

    Args:
        svg_input_path: Path (or file object) of the SVG to read.
        quality: JPEG quality passed to ``Image.save``.

    Returns:
        The (possibly modified) SVG document as a ``str``, including an
        XML declaration.
    """
    svg_ns = 'http://www.w3.org/2000/svg'
    xlink_ns = 'http://www.w3.org/1999/xlink'
    # Without registration ElementTree would serialise with ns0:/ns1: prefixes.
    ET.register_namespace('', svg_ns)
    ET.register_namespace('xlink', xlink_ns)
    tree = ET.parse(svg_input_path)
    root = tree.getroot()

    # SVG 1.1 uses xlink:href; SVG 2 additionally allows an un-namespaced href.
    href_attrs = (f'{{{xlink_ns}}}href', 'href')

    for elem in root.findall(f'.//{{{svg_ns}}}image'):
        for href_attr in href_attrs:
            href = elem.get(href_attr)
            if href and href.startswith('data:image'):
                break
        else:
            # No embedded image on this element (e.g. an external file link).
            continue

        header, sep, b64data = href.partition(',')
        if not sep or not header.endswith(';base64'):
            # Not a base64 payload (e.g. URL-encoded SVG) - nothing to decode.
            continue
        raw_data = base64.b64decode(b64data)

        with BytesIO(raw_data) as img_io:
            img = Image.open(img_io)
            # Inspect the alpha band BEFORE converting: after convert('RGB')
            # the band is gone and the check could never trigger.
            if 'A' in img.getbands():
                print("Warning: Image contains alpha channel")
                # The minimum alpha value tells whether any pixel is not
                # fully opaque.
                if img.getchannel('A').getextrema()[0] < 255:
                    print("Warning: Image contains transparency that will be lost in JPEG conversion")
            compressed_io = BytesIO()
            # JPEG has no alpha channel, so flatten to RGB first.
            img.convert('RGB').save(compressed_io, format='JPEG', quality=quality)

        new_data = base64.b64encode(compressed_io.getvalue()).decode()
        # Compare encoded size with encoded size: that is what is stored in
        # the SVG, and mixing raw bytes with base64 characters skews both the
        # threshold and the reported saving.
        if len(new_data) < len(b64data) * 0.9 - 10_000:
            print(f'-> compressing image - saving {len(b64data) - len(new_data)} bytes')
            elem.set(href_attr, f'data:image/jpeg;base64,{new_data}')

    output = BytesIO()
    tree.write(output, encoding='utf-8', xml_declaration=True)
    return output.getvalue().decode('utf-8')


# Run the compression over every cached SVG plot and preview the result
# inline. Nothing is written back to disk here; the sizes printed are
# character counts of the SVG text before and after.
base_path = Path('thesis/assets/cached_plots/')
for svg_input_path in base_path.glob('*.svg'):
    # read_text opens and closes the file itself (no leaked handle) and the
    # explicit encoding keeps the count independent of the platform locale.
    svg_file = svg_input_path.read_text(encoding='utf-8')
    print(f'original size: {len(svg_file)}')
    svg_file_compressed = compress_svg_images(svg_input_path)
    print(f'compressed size: {len(svg_file_compressed)}')
    display(SVG(svg_file_compressed))