In [17]:
import json
import plotly.graph_objects as go
from collections import defaultdict
from typing import Union, List


# Ordered (substring, category) rules for HACL* function names.
# The first rule whose substring occurs in the lower-cased name wins, so the
# order below is significant (e.g. a name containing both "chacha20" and
# "poly1305" is filed under HACL_Chacha20).
_HACL_CATEGORY_RULES = (
    ("md5",      "HACL_MD5"),
    ("sha2",     "HACL_SHA2"),
    ("sha3",     "HACL_SHA3"),
    ("blake2",   "HACL_Blake2"),
    ("chacha20", "HACL_Chacha20"),
    ("poly1305", "HACL_Poly1305"),
)


def _classify_hacl(name: str) -> str:
    """Return the category label for a HACL* function name (case-insensitive)."""
    lowered = name.lower()
    for needle, category in _HACL_CATEGORY_RULES:
        if needle in lowered:
            return category
    return "HACL_Other"


def plot_binary_breakdown(
    json_path: Union[str, List[dict]],
    top_runtime_count: int = 5,
    top_hacl_other_count: int = 0,
    max_height: int = 1000,
    output_pdf: Union[str, None] = None,
    group: Union[List[str], None] = None
):
    """
    Plot a Sankey diagram showing the breakdown of HACL* and runtime functions in a binary,
    with each node labeled by its name and size in KB.

    Parameters
    ----------
    json_path:
        Either a path to a JSON file or the already-loaded data. The data is
        iterated as a sequence of dicts, each read for the keys ``"name"``,
        ``"size_bytes"`` and the optional flags ``"is_hacl"`` / ``"is_edger8r"``.
    top_runtime_count:
        Number of largest ungrouped runtime functions shown as their own node;
        the remainder is summed into an "Other Runtime" node.
    top_hacl_other_count:
        When greater than 0, the largest functions of the "HACL_Other" category
        are shown individually and the remainder becomes "Other HACL_Other".
    max_height:
        Figure height passed to ``fig.update_layout``.
    output_pdf:
        If truthy, the figure is additionally written with
        ``fig.write_image(output_pdf)`` after being shown.
    group:
        Name prefixes. Runtime functions whose name starts with a prefix are
        aggregated into a single node labeled with that prefix (the first
        matching prefix wins). ``None`` disables grouping.

    Returns
    -------
    None. The figure is displayed via ``fig.show`` as a side effect.

    Raises
    ------
    KeyError
        If a record lacks ``"name"`` or ``"size_bytes"`` where it is needed.
    """

    # Load data; the context manager makes sure the file handle is closed.
    if isinstance(json_path, str):
        with open(json_path) as fh:
            data = json.load(fh)
    else:
        data = json_path

    # Sankey bookkeeping: labels are interned to integer node ids and links are
    # kept as three parallel lists, which is the layout go.Sankey consumes.
    labels, node_ids = [], {}
    sources, targets, values = [], [], []

    def node(name):
        """Return the id of the node labeled `name`, creating it on first use."""
        if name not in node_ids:
            node_ids[name] = len(labels)
            labels.append(name)
        return node_ids[name]

    def link(parent, child_name, size):
        """Add a parent -> child link of the given size and return the child id."""
        child = node(child_name)
        sources.append(parent)
        targets.append(child)
        values.append(size)
        return child

    def total(funcs):
        """Sum of size_bytes over a list of function records."""
        return sum(f["size_bytes"] for f in funcs)

    # Split out functions.
    # NOTE: a record flagged both is_hacl and is_edger8r lands in both
    # hacl_funcs and edger8r_funcs and is therefore counted twice.
    hacl_funcs    = [f for f in data if f.get("is_hacl")]
    edger8r_funcs = [f for f in data if f.get("is_edger8r")]
    runtime_funcs = [f for f in data if not f.get("is_hacl") and not f.get("is_edger8r")]

    # Level-0 → HACL* & Runtime (Edger8r is accounted for under Runtime)
    root = node("Executable Sections Binary")
    hn   = node("HACL*")
    rn   = node("Runtime")
    link(root, "HACL*", total(hacl_funcs))
    link(root, "Runtime", total(runtime_funcs) + total(edger8r_funcs))

    # Runtime subtree: bucket by the first matching prefix, else keep ungrouped
    prefixes = group or []
    grouped, ungrouped = defaultdict(list), []
    for f in runtime_funcs:
        matched = next((p for p in prefixes if f["name"].startswith(p)), None)
        if matched is None:
            ungrouped.append(f)
        else:
            grouped[matched].append(f)

    ungrouped.sort(key=lambda x: x["size_bytes"], reverse=True)
    top_rt   = ungrouped[:top_runtime_count]
    other_rt = total(ungrouped[top_runtime_count:])

    for prefix, fs in grouped.items():
        link(rn, prefix, total(fs))

    for f in top_rt:
        link(rn, f["name"], f["size_bytes"])

    if other_rt:
        link(rn, "Other Runtime", other_rt)

    # Now include Edger8r as a child of Runtime
    if edger8r_funcs:
        link(rn, "Edger8r", total(edger8r_funcs))

    # HACL subtree
    by_cat = defaultdict(list)
    for f in hacl_funcs:
        by_cat[_classify_hacl(f["name"])].append(f)

    order = [category for _, category in _HACL_CATEGORY_RULES] + ["HACL_Other"]
    for cat in order:
        fs = by_cat.get(cat, [])
        if not fs:
            continue
        cn = link(hn, cat, total(fs))

        # Optionally expand the catch-all category into its largest members
        if cat == "HACL_Other" and top_hacl_other_count > 0:
            ranked = sorted(fs, key=lambda x: x["size_bytes"], reverse=True)
            for f in ranked[:top_hacl_other_count]:
                link(cn, f["name"], f["size_bytes"])
            rest = total(ranked[top_hacl_other_count:])
            if rest:
                link(cn, "Other HACL_Other", rest)

    # Compute each node's total size in one pass over the links: a node's size
    # is the sum of its outgoing links, or of its incoming links when it has
    # no (non-zero) outgoing flow, i.e. for leaves.
    outgoing = [0] * len(labels)
    incoming = [0] * len(labels)
    for s, t, v in zip(sources, targets, values):
        outgoing[s] += v
        incoming[t] += v
    node_sizes = [out if out else inc for out, inc in zip(outgoing, incoming)]

    # Build labels with size inline
    display_labels = [
        f"{label} ({size/1024:.1f} KB)"
        for label, size in zip(labels, node_sizes)
    ]

    # Plot
    fig = go.Figure(go.Sankey(
        arrangement="perpendicular",
        node=dict(
            label=display_labels,
            pad=15,
            thickness=30,
            line=dict(color="black", width=0.5)
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color="lightgrey"
        )
    ))
    fig.update_layout(
        height=max_height,
        font=dict(size=25)
    )

    # Options for the mode bar's "download image" button
    config = {
        'toImageButtonOptions': {
            'format': 'png',
            'filename': 'custom_image',
            'scale': 6
        }
    }
    fig.show(config=config)
    if output_pdf:
        fig.write_image(output_pdf)


In [18]:
# Breakdown for bare_functions.json: the five largest ungrouped runtime
# functions get their own node; HACL_Other is not expanded (count 0).
plot_binary_breakdown(
    "bare_functions.json",
    top_runtime_count=5,
    top_hacl_other_count=0,
)

In [19]:
# Breakdown for sdk_functions.json; the figure is also written to
# my_sankey_output.pdf after being displayed.
plot_binary_breakdown(
    "sdk_functions.json",
    top_runtime_count=5,
    output_pdf="my_sankey_output.pdf",
)

In [21]:
# Breakdown for oe_functions.json: runtime functions whose name starts with
# "mbedtls" are merged into one node, the four largest others are shown.
# NOTE(review): this writes to the same my_sankey_output.pdf path as the
# sdk_functions.json call, so whichever cell runs last overwrites the other —
# confirm that is intended.
plot_binary_breakdown(
    "oe_functions.json",
    top_runtime_count=4,
    output_pdf="my_sankey_output.pdf",
    group=["mbedtls"],
)

In [16]:
import json

# Input files with the per-function records (edit these if needed)
sdk_path = "sdk_functions.json"
bare_path = "bare_functions.json"

# Parse both files; each handle is closed as soon as its content is read
loaded = []
for records_path in (sdk_path, bare_path):
    with open(records_path) as handle:
        loaded.append(json.load(handle))
sdk_funcs, bare_funcs = loaded

# Names of the records flagged as HACL* in each build
sdk_hacl = set(rec["name"] for rec in sdk_funcs if rec.get("is_hacl"))
bare_hacl = set(rec["name"] for rec in bare_funcs if rec.get("is_hacl"))

# Names exclusive to one side, in sorted order
only_in_sdk = sorted(sdk_hacl.difference(bare_hacl))
only_in_bare = sorted(bare_hacl.difference(sdk_hacl))

# Report: one heading per side followed by its exclusive names
report_sections = (
    ("=== HACL* functions only in SDK ===", only_in_sdk),
    ("\n=== HACL* functions only in Bare ===", only_in_bare),
)
for heading, exclusive_names in report_sections:
    print(heading)
    for fn_name in exclusive_names:
        print(" ", fn_name)

# Totals per build and per exclusive set
print("\n[SUMMARY]")
print(f"Total HACL* in SDK:  {len(sdk_hacl)}")
print(f"Total HACL* in Bare: {len(bare_hacl)}")
print(f"Only in SDK:         {len(only_in_sdk)}")
print(f"Only in Bare:        {len(only_in_bare)}")


=== HACL* functions only in SDK ===
  add_scalar_e
  aes128_key_expansion
  aes128_keyhash_init
  aes256_key_expansion
  aes256_keyhash_init
  check_adx_bmi2
  check_aesni
  check_avx
  check_avx2
  check_avx512
  check_avx512_xcr0
  check_avx_xcr0
  check_movbe
  check_osxsave
  check_rdrand
  check_sha
  check_sse
  compute_iv_stdcall
  cswap2_e
  fadd_e
  fmul2_e
  fmul_e
  fmul_scalar_e
  fsqr2_e
  fsqr_e
  fsub_e
  gcm128_decrypt_opt
  gcm128_encrypt_opt
  gcm256_decrypt_opt
  gcm256_encrypt_opt
  gctr128_bytes
  gctr256_bytes
  x64_poly1305

=== HACL* functions only in Bare ===
  Hacl_Hash_SHA3_keccak_piln
  Hacl_Hash_SHA3_keccak_rndc
  Hacl_Hash_SHA3_keccak_rotc
  Hacl_Impl_Chacha20_Vec_chacha20_constants

[SUMMARY]
Total HACL* in SDK:  203
Total HACL* in Bare: 174
Only in SDK:         33
Only in Bare:        4
