# Complexity vs MSE and Rank Graphs


In [1]:
# Install every third-party package imported below so Restart & Run All works on a fresh kernel.
# TODO: pin exact versions (pkg==X.Y.Z) once the working environment is recorded.
%pip install -q safetensors numpy loguru orjson rich plotly faker

Note: you may need to restart the kernel to use updated packages.


In [69]:
from safetensors.numpy import load_file
from pathlib import Path
from loguru import logger
import orjson
from rich import pretty
import plotly.graph_objects as go
from faker import Faker

# from IPython.display import display
import numpy as np

# Compact numpy printing: 3 decimals, fixed-point for small values, and summarize
# any array with more than 5 elements instead of printing it in full.
np.set_printoptions(
    precision=3,
    suppress=True,
    threshold=5,
)

# One seed for every source of randomness used in this notebook.
SEED = 0
rng = np.random.default_rng(SEED)
# Faker draws from its own shared random source; without seeding it, the fake demo
# labels differ on every Restart & Run All even though `rng` is seeded.
Faker.seed(SEED)
faker = Faker()

In [12]:
# Repository root is taken to be the parent of the kernel's working directory,
# i.e. this notebook is expected to be launched from a subfolder of the repo.
PROJ_ROOT = Path.cwd().resolve().parent
logger.info(PROJ_ROOT)
DATA_DIR = PROJ_ROOT.joinpath("data")
logger.info(DATA_DIR)

# Validation-MSE artifacts: a safetensors tensor file plus its JSON metadata.
MLP_MSE_VALIDATION = DATA_DIR.joinpath("mlp_mse_validation")
MLP_MSE_VALIDATION_TENSORS = MLP_MSE_VALIDATION.joinpath("validation_mses.safetensors")
MLP_MSE_VALIDATION_META = MLP_MSE_VALIDATION.joinpath("validation_mses.json")

2024-12-13 14:18:35.910 | INFO     | __main__:<cell line: 0>:2 - /Users/gat/work/FA2024/embedding_translation
2024-12-13 14:18:35.911 | INFO     | __main__:<cell line: 0>:4 - /Users/gat/work/FA2024/embedding_translation/data


In [26]:
# Load the safetensors file (a mapping of tensor name -> numpy array) and keep the
# one entry this notebook analyses.
tensors = load_file(MLP_MSE_VALIDATION_TENSORS)
logger.info(list(tensors.keys()))
validation_mse = tensors["validation_mses"]
logger.info(validation_mse.shape)
# Later cells index this as [dataset, :, :, architecture]; fail fast if the layout changes.
assert validation_mse.ndim == 4, f"expected a 4-D tensor, got shape {validation_mse.shape}"

[12/13/24 14:33:27] INFO     [3396536193.py:2]                   3396536193.py:2
                             ['validation_mses']
                    INFO     [3396536193.py:4] (6, 17, 17, 6)    3396536193.py:4

In [28]:
# Read raw bytes and let orjson decode them as UTF-8; opening the file in text mode
# would decode with the platform's locale encoding, which is not UTF-8 everywhere.
meta = orjson.loads(MLP_MSE_VALIDATION_META.read_bytes())
pretty.pprint(meta)

# Axis layout of `validation_mse` (shape (6, 17, 17, 6)):
#   axis 0: 6 datasets (ArguAna, FiQA, SciDocs, NFCorpus, HotPotQA, Trec-COVID)
#   axes 1, 2: 17 native / 17 stitched embedding spaces (order as originally noted — confirm against `meta`)
#   axis 3: 6 architectures (2, 3, 4, 5, 6, or 7 layers)

In [56]:
# Select a single dataset and a single architecture to inspect one slice of the MSE tensor.
dataset_i = 0
architecture_i = 0

# Same view as validation_mse[dataset_i, :, :, architecture_i]: the 2-D matrix over
# the two embedding-space axes.
example_validation_mse = validation_mse[dataset_i][:, :, architecture_i]
logger.info(example_validation_mse.shape)
pretty.pprint(example_validation_mse)

[12/13/24 14:53:57] INFO     [2716447018.py:8] (17, 17)          2716447018.py:8


In [77]:
# Re-seed so the demo matrix that heatmap() draws below is identical every time this cell runs.
rng = np.random.default_rng(seed=0)


def heatmap(
    matrix: np.ndarray | None = None,
    labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: str = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
):
    if labels is None and matrix is None:
        faker = Faker()
        shape = (10, 10)
        labels = [faker.name() for _ in range(shape[0])]
        matrix = rng.random(shape)
        # logger.info("Using example labels and matrix")
        logger.info(labels)
        # logger.info(matrix)
    col_labels = labels
    row_labels = labels
    logger.info("Creating heatmap visualization...")

    # Convert input to numpy array and handle None values
    matrix = np.array(matrix, dtype=object)
    none_mask = matrix is None
    matrix = matrix.astype(float)
    matrix[none_mask] = np.nan

    # Format cell values if needed
    text_vals = None
    if show_values:
        text_vals = np.vectorize(lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}")(matrix)

    fig = go.Figure(
        data=[
            go.Heatmap(
                z=matrix,
                x=col_labels,
                y=row_labels,
                colorscale="Viridis",
                text=text_vals,
                texttemplate="%{text}" if show_values else None,
                textfont={"color": "black"},
                reversescale=False,
                showscale=True,
            )
        ],
        layout={
            "title": title,
            "width": width,
            "height": height,
            "xaxis": {
                "title": xaxis_title,
                "nticks": len(col_labels) if col_labels else None,
                "fixedrange": True,  # Disable zoom/pan
            },
            "yaxis": {
                "title": yaxis_title,
                "nticks": len(row_labels) if row_labels else None,
                "fixedrange": True,  # Disable zoom/pan
            },
            "dragmode": False,  # Disable dragging/panning
            "modebar": {
                "remove": [
                    "zoomIn2d",
                    "zoomOut2d",
                    "autoScale2d",
                    "resetScale2d",
                    "hoverClosestCartesian",
                    "hoverCompareCartesian",
                    "toggleSpikelines",
                ]
            },
        },
    )
    return fig


# Smoke test: with neither `matrix` nor `labels`, heatmap() renders its random 10x10 demo.
heatmap(matrix=None, labels=None)

[12/13/24 15:11:57] INFO     [2694299926.py:21] ['Laura         2694299926.py:21
                             Palmer', 'Timothy Fowler', 'David
                             Torres', 'Michael Church',
                             'Patrick Nguyen', 'Jordan Burns',
                             'Zachary Miller', 'Michael