# Complexity vs MSE and Rank Graphs


In [1]:
# Install safetensors into the running kernel's environment; the import cell
# below needs it for `safetensors.numpy.load_file`.
%pip install safetensors

Note: you may need to restart the kernel to use updated packages.


In [82]:
from safetensors.numpy import load_file
from pathlib import Path
from loguru import logger
import orjson
from rich import pretty
import plotly.graph_objects as go
from faker import Faker
from typing import Literal

# from IPython.display import display
import numpy as np

# Keep printed arrays compact: 3 decimal places, fixed-point notation for
# small values, and summarization once an array has more than 5 elements.
np.set_printoptions(threshold=5, precision=3, suppress=True)

# Notebook-wide helpers: a Faker instance for placeholder labels and a NumPy
# generator seeded with 0 for the demo matrices.
faker = Faker()
rng = np.random.default_rng(seed=0)

In [12]:
# The project root is the parent of the current working directory; every data
# path below is derived from it.
PROJ_ROOT = Path.cwd().resolve().parent
logger.info(PROJ_ROOT)
DATA_DIR = PROJ_ROOT.joinpath("data")
logger.info(DATA_DIR)

# Validation-MSE artifacts: the tensor file and its JSON metadata live side by side.
MLP_MSE_VALIDATION = DATA_DIR.joinpath("mlp_mse_validation")
MLP_MSE_VALIDATION_TENSORS = MLP_MSE_VALIDATION.joinpath("validation_mses.safetensors")
MLP_MSE_VALIDATION_META = MLP_MSE_VALIDATION.joinpath("validation_mses.json")

[32m2024-12-13 14:18:35.910[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m2[0m - [1m/Users/gat/work/FA2024/embedding_translation[0m
[32m2024-12-13 14:18:35.911[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m4[0m - [1m/Users/gat/work/FA2024/embedding_translation/data[0m


In [26]:
# Read the safetensors file, log which arrays it contains, then pull out the
# validation-MSE array and log its shape.
tensors = load_file(MLP_MSE_VALIDATION_TENSORS)
logger.info([*tensors])
validation_mse = tensors["validation_mses"]
logger.info(validation_mse.shape)

[2;36m[12/13/24 14:33:27][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m3396536193.[0mpy:[1;36m2[0m[1m][0m                   ]8;id=747558;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/3396536193.py\[1;36m3396536193.py[0m]8;;\[1;36m:[0m]8;id=363556;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/3396536193.py#2\[1;36m2[0m]8;;\
[2;36m                    [0m[1;32m         [0m[1m[[0m[32m'validation_mses'[0m[1m][0m                 [1;36m               [0m
[2;36m                   [0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m3396536193.[0mpy:[1;36m4[0m[1m][0m [1m([0m[1;36m6[0m, [1;36m17[0m, [1;36m17[0m, [1;36m6[0m[1m)[0m    ]8;id=420402;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/3396536193.py\[1;36m3396536193.py[0m]8;;\[1;36m:[0m]8;id=307068;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/3396536193

In [28]:
# Parse the JSON metadata that accompanies the tensor file. orjson accepts raw
# bytes, so reading bytes avoids decoding with the platform's default text
# encoding and leaves no file handle behind in the notebook namespace.
meta = orjson.loads(MLP_MSE_VALIDATION_META.read_bytes())

pretty.pprint(meta)
# Contents of the validation-MSE tensor (presumably listed in axis order,
# matching the logged shape (6, 17, 17, 6) -- confirm against the metadata):
# 6 Datasets (ArguAna, FiQA, SciDocs, NFCorpus, HotPotQA, Trec-COVID)
# 17 Native Embedding Spaces
# 17 Stitched Embedding Spaces
# 6 Architectures (2, 3, 4, 5, 6, or 7 layers)

In [56]:
# Pick a single (dataset, architecture) pair and take the corresponding 2D
# slice of the validation-MSE tensor; it is plotted as a heatmap further down.
dataset_i, architecture_i = 0, 0

example_validation_mse = validation_mse[dataset_i, :, :, architecture_i]
logger.info(example_validation_mse.shape)
pretty.pprint(example_validation_mse)

[2;36m[12/13/24 14:53:57][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m2716447018.[0mpy:[1;36m8[0m[1m][0m [1m([0m[1;36m17[0m, [1;36m17[0m[1m)[0m          ]8;id=34588;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/2716447018.py\[1;36m2716447018.py[0m]8;;\[1;36m:[0m]8;id=443200;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/2716447018.py#8\[1;36m8[0m]8;;\


In [99]:
# Re-create the generator with seed 0 so the demo matrix drawn by `heatmap()`
# (when called without a matrix) starts from the same state each time this
# cell is run.
rng = np.random.default_rng(0)


def heatmap(
    matrix: np.ndarray | None = None,
    labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
) -> go.Figure:
    """Build a static Plotly heatmap for a 2D matrix.

    Args:
        matrix: 2D array-like of numbers. ``None`` entries are treated as
            missing and shown as "N/A". When omitted, a random 10x10 demo
            matrix is drawn from the module-level ``rng``.
        labels: Tick labels, used for BOTH the x and y axes. When omitted,
            one Faker-generated name per matrix row is created and logged.
        title: Figure title.
        show_values: If true, print each cell's (formatted) value in the cell.
        value_format: Format spec applied to each displayed cell value.
        width: Figure width in pixels.
        height: Figure height in pixels.
        xaxis_title: Optional x-axis title.
        yaxis_title: Optional y-axis title.
        log: If true, plot ``log10`` of the values instead of the raw values.
        nan_diagnal: If true, blank out the main diagonal (set it to NaN).

    Returns:
        The configured ``go.Figure`` with zoom/pan interactions disabled.

    Raises:
        ValueError: If ``matrix`` is not two-dimensional.
    """
    if matrix is None:
        # Demo mode: no data supplied, so plot random values.
        matrix = rng.random((10, 10))

    # Go through an object array so that None entries survive long enough to
    # be located; the mask must be computed element-wise (a plain
    # `matrix is None` only tests the array object itself).
    raw = np.array(matrix, dtype=object)
    if raw.ndim != 2:
        raise ValueError("Expected 2D array (rows, cols)")
    none_mask = np.frompyfunc(lambda cell: cell is None, 1, 1)(raw).astype(bool)
    values = np.where(none_mask, np.nan, raw).astype(float)
    n_rows, n_cols = values.shape

    if labels is None:
        # Create the Faker instance here, where it is used, so that passing a
        # matrix without labels works.
        name_source = Faker()
        labels = [name_source.name() for _ in range(n_rows)]
        logger.info(labels)
    else:
        labels = list(labels)
    logger.info("Creating heatmap visualization...")

    if log:
        values = np.log10(values)  # NaN (missing) entries stay NaN
    if nan_diagnal:
        # np.eye handles non-square matrices, unlike np.identity.
        values[np.eye(n_rows, n_cols, dtype=bool)] = np.nan

    # Per-cell annotation text, only built when it will be displayed.
    text_vals = None
    if show_values:
        text_vals = np.vectorize(lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}")(values)

    fig = go.Figure(
        data=[
            go.Heatmap(
                z=values,
                x=labels,
                y=labels,
                colorscale="Viridis",
                text=text_vals,
                texttemplate="%{text}" if show_values else None,
                textfont={"color": "black"},
                reversescale=False,
                showscale=True,
            )
        ],
        layout={
            "title": title,
            "width": width,
            "height": height,
            "xaxis": {
                "title": xaxis_title,
                "nticks": len(labels) if labels else None,
                "fixedrange": True,  # Disable zoom/pan
            },
            "yaxis": {
                "title": yaxis_title,
                "nticks": len(labels) if labels else None,
                "fixedrange": True,  # Disable zoom/pan
            },
            "dragmode": False,  # Disable dragging/panning
            "modebar": {
                "remove": [
                    "zoomIn2d",
                    "zoomOut2d",
                    "autoScale2d",
                    "resetScale2d",
                    "hoverClosestCartesian",
                    "hoverCompareCartesian",
                    "toggleSpikelines",
                ]
            },
        },
    )
    return fig


# Smoke test: with no arguments, `heatmap` falls back to a random 10x10 matrix
# and Faker-generated names as labels.
heatmap()

[2;36m[12/13/24 15:23:46][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m1064263312.[0mpy:[1;36m28[0m[1m][0m [1m[[0m[32m'Gerald [0m       ]8;id=904166;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/1064263312.py\[1;36m1064263312.py[0m]8;;\[1;36m:[0m]8;id=943911;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/1064263312.py#28\[1;36m28[0m]8;;\
[2;36m                    [0m[1;32m         [0m[32mDavila'[0m, [32m'Christopher Nunez'[0m,      [1;36m                [0m
[2;36m                    [0m[1;32m         [0m[32m'Richard Archer'[0m, [32m'Jessica [0m        [1;36m                [0m
[2;36m                    [0m[1;32m         [0m[32mBaldwin'[0m, [32m'Andrea Williams'[0m,       [1;36m                [0m
[2;36m                    [0m[1;32m         [0m[32m'Elizabeth Pratt'[0m, [32m'Jeremy [0m        [1;36m                [0m
[2;36m                    [0m[1;32m  

In [110]:
def animated_heatmap(
    matrices: np.ndarray | None = None,
    labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
    frame_duration: int = 500,
    slider_labels: list[str] | None = None,
    slider_prefix: str = "Frame:",
) -> go.Figure:
    """Build an animated Plotly heatmap with one frame per leading-axis slice.

    Args:
        matrices: 3D array-like shaped (time, rows, cols). ``None`` entries
            are treated as missing and shown as "N/A". When omitted, a random
            (5, 10, 10) demo stack is drawn from the module-level ``rng``.
        labels: Tick labels, used for BOTH the x and y axes. When omitted,
            one Faker-generated name per matrix row is created.
        title: Figure title.
        show_values: If true, print each cell's (formatted) value in the cell.
        value_format: Format spec applied to each displayed cell value.
        width: Figure width in pixels.
        height: Figure height in pixels.
        xaxis_title: Optional x-axis title.
        yaxis_title: Optional y-axis title.
        log: If true, plot ``log10`` of the values instead of the raw values.
        nan_diagnal: If true, blank out the main diagonal of every frame.
        frame_duration: Duration in milliseconds used for frames and for
            slider/animation transitions.
        slider_labels: Optional label per frame for the slider steps;
            defaults to the frame indices "0", "1", ...
        slider_prefix: Text shown before the current slider value.

    Returns:
        A ``go.Figure`` with frames, Play/Pause buttons and a frame slider.

    Raises:
        ValueError: If ``matrices`` is not 3D, contains no frames, or
            ``slider_labels`` does not have exactly one entry per frame.
    """
    if matrices is None:
        # Demo mode: no data supplied, so animate random values.
        matrices = rng.random((5, 10, 10))  # (time, rows, cols)

    # Convert up front so nested lists are accepted and None entries survive
    # long enough to be located element-wise.
    raw = np.asarray(matrices, dtype=object)
    if raw.ndim != 3:
        raise ValueError("Expected 3D array (time, rows, cols)")
    n_frames, n_rows, n_cols = raw.shape
    if n_frames == 0:
        raise ValueError("Expected at least one frame, got an empty array")

    if slider_labels is None:
        frame_labels = [str(k) for k in range(n_frames)]
    elif len(slider_labels) != n_frames:
        raise ValueError(f"Expected {n_frames} slider labels, got {len(slider_labels)}")
    else:
        frame_labels = slider_labels

    if labels is None:
        name_source = Faker()
        labels = [name_source.name() for _ in range(n_rows)]
    else:
        labels = list(labels)

    # Clean the whole stack at once: None -> NaN, optional log10, optional
    # blanked diagonal (np.eye also covers non-square frames).
    none_mask = np.frompyfunc(lambda cell: cell is None, 1, 1)(raw).astype(bool)
    stack = np.where(none_mask, np.nan, raw).astype(float)
    if log:
        stack = np.log10(stack)  # NaN (missing) entries stay NaN
    if nan_diagnal:
        stack[:, np.eye(n_rows, n_cols, dtype=bool)] = np.nan
    processed_matrices = list(stack)

    def cell_text(frame_matrix):
        """Format every cell of one frame for display, or None if text is off."""
        if not show_values:
            return None
        return np.vectorize(lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}")(
            frame_matrix
        )

    def make_trace(frame_matrix):
        """Build the heatmap trace shared by the initial view and every frame."""
        return go.Heatmap(
            z=frame_matrix,
            x=labels,
            y=labels,
            colorscale="Viridis",
            text=cell_text(frame_matrix),
            texttemplate="%{text}" if show_values else None,
            textfont={"color": "black"},
            reversescale=False,
            showscale=True,
        )

    # One named frame per time slice; slider steps refer to frames by name.
    frames = [
        go.Frame(data=[make_trace(frame_matrix)], name=f"frame_{i}")
        for i, frame_matrix in enumerate(processed_matrices)
    ]

    layout_sliders = [
        {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": slider_prefix,
                "visible": True,
                "xanchor": "right",
            },
            "transition": {"duration": frame_duration, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    "args": [
                        [f"frame_{k}"],
                        {
                            "frame": {"duration": frame_duration, "redraw": True},
                            "mode": "immediate",
                            "transition": {"duration": frame_duration},
                        },
                    ],
                    "label": label,
                    "method": "animate",
                }
                for k, label in enumerate(frame_labels)
            ],
        }
    ]

    play_button = {
        "args": [
            None,
            {
                "frame": {"duration": frame_duration, "redraw": True},
                "fromcurrent": True,
            },
        ],
        "label": "Play",
        "method": "animate",
    }
    pause_button = {
        "args": [
            [None],
            {
                "frame": {"duration": 0, "redraw": True},
                "mode": "immediate",
                "transition": {"duration": 0},
            },
        ],
        "label": "Pause",
        "method": "animate",
    }

    # The figure starts on the first frame; the remaining frames are attached
    # for the animation controls.
    fig = go.Figure(
        data=[make_trace(processed_matrices[0])],
        frames=frames,
        layout={
            "title": title,
            "width": width,
            "height": height,
            "xaxis": {
                "title": xaxis_title,
                "nticks": len(labels) if labels else None,
                "fixedrange": True,
            },
            "yaxis": {
                "title": yaxis_title,
                "nticks": len(labels) if labels else None,
                "fixedrange": True,
            },
            "dragmode": False,
            "updatemenus": [
                {
                    "buttons": [play_button, pause_button],
                    "direction": "left",
                    "pad": {"r": 10, "t": 87},
                    "showactive": False,
                    "type": "buttons",
                    "x": 0.1,
                    "xanchor": "right",
                    "y": 0,
                    "yanchor": "top",
                }
            ],
        },
    )
    fig.update_layout(sliders=layout_sliders)

    return fig


# Smoke test: with no arguments, `animated_heatmap` falls back to a random
# (5, 10, 10) stack and Faker-generated names as labels.
animated_heatmap()

In [111]:
# Model names ordered by their index in the metadata (presumably the same
# order as the model axes of the validation-MSE tensor -- confirm).
labels = sorted(meta["model2idx"], key=meta["model2idx"].get)
matrix = example_validation_mse

# Plot the selected slice on a log10 scale.
heatmap(matrix=matrix, labels=labels, title="Example Validation MSE (log)", log=True)

[2;36m[12/13/24 15:39:44][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m1064263312.[0mpy:[1;36m31[0m[1m][0m Creating        ]8;id=212122;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/1064263312.py\[1;36m1064263312.py[0m]8;;\[1;36m:[0m]8;id=882043;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_17291/1064263312.py#31\[1;36m31[0m]8;;\
[2;36m                    [0m[1;32m         [0mheatmap visualization[33m...[0m           [1;36m                [0m
