# Complexity vs MSE and Rank Graphs


## Setup


In [1]:
# Installs
%pip install safetensors

Note: you may need to restart the kernel to use updated packages.


In [43]:
# Imports
from safetensors.numpy import load_file
from pathlib import Path
from loguru import logger
import orjson
from rich import pretty
import plotly.graph_objects as go
import plotly.io as pio
import plotly
from faker import Faker
from typing import Literal
import polars as pl

# from IPython.display import display
import numpy as np

np.set_printoptions(
    precision=3,
    suppress=True,
    threshold=5,
)
rng = np.random.default_rng(0)
faker = Faker()

In [3]:
# Directories
PROJ_ROOT = Path.cwd().resolve().parent
logger.info(PROJ_ROOT)
DATA_DIR = PROJ_ROOT / "data"
logger.info(DATA_DIR)
MLP_MSE_VALIDATION = DATA_DIR / "mlp_mse_validation"
MLP_MSE_VALIDATION_TENSORS = MLP_MSE_VALIDATION / "validation_mses.safetensors"
MLP_MSE_VALIDATION_META = MLP_MSE_VALIDATION / "validation_mses.json"

[32m2024-12-13 16:04:12.847[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m3[0m - [1m/Users/gat/work/FA2024/embedding_translation[0m
[32m2024-12-13 16:04:12.848[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m5[0m - [1m/Users/gat/work/FA2024/embedding_translation/data[0m


## Helper Viz


In [40]:
# General Viz Config
heatmap_modebar = go.layout.Modebar(
    remove=[
        "zoomIn2d",
        "zoomOut2d",
        "autoScale2d",
        "resetScale2d",
        "hoverClosestCartesian",
        "hoverCompareCartesian",
        "toggleSpikelines",
    ]
)
heatmap_layout = go.Layout(
    modebar=heatmap_modebar,
    dragmode=False,
)

heatmap_trace_config = go.Heatmap(
    colorscale="Viridis",
    textfont={"color": "black"},
    reversescale=False,
    showscale=True,
    showlegend=True,
)

pio.templates = []

Heatmap({
    'colorscale': [[0.0, '#440154'], [0.1111111111111111, '#482878'],
                   [0.2222222222222222, '#3e4989'], [0.3333333333333333,
                   '#31688e'], [0.4444444444444444, '#26828e'],
                   [0.5555555555555556, '#1f9e89'], [0.6666666666666666,
                   '#35b779'], [0.7777777777777778, '#6ece58'],
                   [0.8888888888888888, '#b5de2b'], [1.0, '#fde725']],
    'reversescale': False,
    'showlegend': True,
    'showscale': True,
    'textfont': {'color': 'black'}
})

In [35]:
def heatmap(
    matrix: np.ndarray | None = None,
    labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
) -> go.Figure:
    """Create a heatmap visualization using plotly.

    Args:
        matrix: Input matrix to visualize. If None, generates random demo data
        labels: Labels for rows and columns. If None, generates fake names
        title: Title of the heatmap
        show_values: Whether to show values in cells
        value_format: Format string for cell values
        width: Width of the figure in pixels
        height: Height of the figure in pixels
        xaxis_title: Title for x-axis
        yaxis_title: Title for y-axis
        log: Whether to log10 transform the values
        nan_diagnal: Whether to set diagonal to NaN

    Returns:
        go.Figure: Plotly figure object
    """
    if matrix is None:
        faker = Faker()
        shape = (10, 10)
        matrix = rng.random(shape)
    else:
        shape = matrix.shape

    if labels is None:
        labels = [faker.name() for _ in range(shape[0])]
        logger.info(labels)
    col_labels = labels
    row_labels = labels
    logger.info("Creating heatmap visualization...")

    # Convert input to numpy array and handle None values
    matrix = np.array(matrix, dtype=object)
    none_mask = matrix is None
    matrix = matrix.astype(float)
    if log:
        matrix = np.log10(matrix)
    matrix[none_mask] = np.nan
    if nan_diagnal:
        matrix[np.identity(matrix.shape[0], dtype=bool)] = np.nan

    # Format cell values if needed
    text_vals = None
    if show_values:
        text_vals = np.vectorize(lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}")(matrix)

    # Create heatmap trace
    heatmap_trace = heatmap_trace_config.copy
    heatmap_trace = go.Heatmap(
        **dict(heatmap_layout),
        z=matrix,
        x=col_labels,
        y=row_labels,
        text=text_vals,
        texttemplate="%{text}" if show_values else None,
    )

    # Create figure
    fig = go.Figure(
        data=[heatmap_trace],
    )
    fig.layout.update_polars

    # Update layout using plotly objects
    fig.update_layout(
        title=title,
    )

    # Update axes
    fig.update_xaxes(
        title=xaxis_title,
        nticks=len(col_labels) if col_labels else None,
        fixedrange=True,  # Disable zoom/pan
    )

    fig.update_yaxes(
        title=yaxis_title,
        nticks=len(row_labels) if row_labels else None,
        fixedrange=True,  # Disable zoom/pan
    )

    return fig


heatmap()

[32m2024-12-13 16:58:27.668[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m41[0m - [1m['Paula Thompson', 'Scott Brown', 'Laura Petersen', 'Lisa Conrad', 'Matthew Bell', 'Allen Kane', 'Robert Morgan', 'John Horton', 'Michael Montoya', 'Angel Shaffer'][0m
[32m2024-12-13 16:58:27.669[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m44[0m - [1mCreating heatmap visualization...[0m


ValueError: dictionary update sequence element #0 has length 17; 2 is required

In [5]:
def animated_heatmap(
    matrices: np.ndarray | None = None,
    labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
    frame_duration: int = 500,
    slider_labels: list[str] | None = None,
    slider_prefix: str = "Frame:",
):
    # Generate demo data if matrices is None
    if matrices is None:
        faker = Faker()
        shape = (5, 10, 10)  # (time, rows, cols)
        matrices = rng.random(shape)
        # logger.info(matrices)
    elif matrices.ndim != 3:
        raise ValueError("Expected 3D array (time, rows, cols)")

    n_frames, n_rows, n_cols = matrices.shape

    if labels is None:
        faker = Faker()
        labels = [faker.name() for _ in range(n_rows)]
    col_labels = labels
    row_labels = labels

    # Process all matrices
    processed_matrices = []
    for matrix in matrices:
        # Convert and process each time slice
        matrix = np.array(matrix, dtype=object)
        none_mask = matrix is None
        matrix = matrix.astype(float)
        if log:
            matrix = np.log10(matrix)
        matrix[none_mask] = np.nan
        if nan_diagnal:
            matrix[np.identity(matrix.shape[0], dtype=bool)] = np.nan
        processed_matrices.append(matrix)

        # Update the sliders configuration with custom labels
    if slider_labels is not None:
        if len(slider_labels) != n_frames:
            raise ValueError(f"Expected {n_frames} slider labels, got {len(slider_labels)}")
        frame_labels = slider_labels
    else:
        frame_labels = [str(k) for k in range(n_frames)]

    layout_sliders = [
        {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": slider_prefix,  # <--- [CHANGED] Use custom prefix
                "visible": True,
                "xanchor": "right",
            },
            "transition": {"duration": frame_duration, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    "args": [
                        [f"frame_{k}"],
                        {
                            "frame": {"duration": frame_duration, "redraw": True},
                            "mode": "immediate",
                            "transition": {"duration": frame_duration},
                        },
                    ],
                    "label": label,  # <--- [CHANGED] Use custom label
                    "method": "animate",
                }
                for k, label in enumerate(frame_labels)
            ],
        }
    ]

    # Create frames for animation
    frames = [
        go.Frame(
            data=[
                go.Heatmap(
                    z=matrix,
                    x=col_labels,
                    y=row_labels,
                    colorscale="Viridis",
                    text=np.vectorize(lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}")(
                        matrix
                    )
                    if show_values
                    else None,
                    texttemplate="%{text}" if show_values else None,
                    textfont={"color": "black"},
                    reversescale=False,
                    showscale=True,
                )
            ],
            name=f"frame_{i}",  # name is required for frames
        )
        for i, matrix in enumerate(processed_matrices)
    ]

    # Create the figure with the initial frame
    fig = go.Figure(
        data=[
            go.Heatmap(
                z=processed_matrices[0],
                x=col_labels,
                y=row_labels,
                colorscale="Viridis",
                text=np.vectorize(lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}")(
                    processed_matrices[0]
                )
                if show_values
                else None,
                texttemplate="%{text}" if show_values else None,
                textfont={"color": "black"},
                reversescale=False,
                showscale=True,
            )
        ],
        frames=frames,
        layout={
            "title": title,
            "width": width,
            "height": height,
            "xaxis": {
                "title": xaxis_title,
                "nticks": len(col_labels) if col_labels else None,
                "fixedrange": True,
            },
            "yaxis": {
                "title": yaxis_title,
                "nticks": len(row_labels) if row_labels else None,
                "fixedrange": True,
            },
            "dragmode": False,
            "updatemenus": [  # <--- [NEW] Add play and pause buttons
                {
                    "buttons": [
                        {
                            "args": [
                                None,
                                {
                                    "frame": {"duration": frame_duration, "redraw": True},
                                    "fromcurrent": True,
                                },
                            ],
                            "label": "Play",
                            "method": "animate",
                        },
                        {
                            "args": [
                                [None],
                                {
                                    "frame": {"duration": 0, "redraw": True},
                                    "mode": "immediate",
                                    "transition": {"duration": 0},
                                },
                            ],
                            "label": "Pause",
                            "method": "animate",
                        },
                    ],
                    "direction": "left",
                    "pad": {"r": 10, "t": 87},
                    "showactive": False,
                    "type": "buttons",
                    "x": 0.1,
                    "xanchor": "right",
                    "y": 0,
                    "yanchor": "top",
                }
            ],
        },
    )

    # Update the layout dictionary to include the new sliders
    fig.update_layout(sliders=layout_sliders)

    return fig


animated_heatmap()

## Import Data


In [75]:
# Load Metadata
"""
# 6 Datasets (ArguAna, FiQA, SciDocs, NFCorpus, HotPotQA, Trec-COVID)
# 17 Native Embedding Spaces
# 17 Stitched Embedding Spaces
# 6 Architectures (2, 3, 4, 5, 6, or 7 layers)
"""
with MLP_MSE_VALIDATION_META.open() as f:
    meta = orjson.loads(f.read())

# pretty.pprint(meta)

meta_df = pl.read_json(MLP_MSE_VALIDATION_META)


# def create_index_df(meta: dict) -> pl.DataFrame:
#     """Create an index DataFrame for easy tensor lookup using np.indices.

#     Args:
#         meta: Dictionary containing dataset2idx and model2idx mappings

#     Returns:
#         pl.DataFrame: DataFrame with hierarchical index for tensor lookup
#     """
#     # Get shape from metadata
#     n_datasets = len(meta["dataset2idx"])
#     n_models = len(meta["model2idx"])
#     n_layers = len(meta["num_layers"])

#     # Create indices arrays for each dimension
#     datasets, sources, targets, layers = np.indices((n_datasets, n_models, n_models, n_layers))

#     # Create reverse mapping dictionaries
#     dataset_map = {v: k for k, v in meta["dataset2idx"].items()}
#     model_map = {v: k for k, v in meta["model2idx"].items()}

#     # Create DataFrame
#     return pl.DataFrame(
#         {
#             "dataset": pl.Series(datasets.flatten()).cast_dict(dataset_map),
#             "source_model": pl.Series(sources.flatten()).cast_dict(model_map),
#             "target_model": pl.Series(targets.flatten()).cast_dict(model_map),
#             "num_layers": pl.Series(layers.flatten()) + 2,  # Convert 0-based to 2-based
#             "index": list(
#                 zip(datasets.flatten(), sources.flatten(), targets.flatten(), layers.flatten())
#             ),
#         }
#     )


# create_index_df(meta)
# why am I being silly, work on this later.

AttributeError: 'Series' object has no attribute 'cast_dict'

In [7]:
# Load Tensors
tensors = load_file(MLP_MSE_VALIDATION_TENSORS)
logger.info(list(tensors.keys()))
validation_mse = tensors["validation_mses"]
# f"{tensors["validation_mses"].shape=}"
f"{tensors["validation_mses"].shape=}"

[32m2024-12-13 16:04:13.060[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m3[0m - [1m['validation_mses'][0m


'tensors["validation_mses"].shape=(6, 17, 17, 6)'

In [71]:
from safetensors.numpy import load_file


def safetensor_to_polars(
    tensor_path: Path,
    meta_path: Path,
) -> pl.DataFrame:
    """Convert safetensor validation MSEs to a polars dataframe.

    Args:
        tensor_path: Path to .safetensors file
        meta_path: Path to metadata JSON file

    Returns:
        pl.DataFrame: Long-format dataframe with validation MSEs
    """
    # Load metadata
    with meta_path.open() as f:
        meta = orjson.loads(f.read())

    # Load tensor
    tensor_dict = load_file(tensor_path)  # (6, 17, 17, 6)
    validation_mses = tensor_dict["validation_mses"]

    # Create mapping dictionaries
    dataset_map = {v: k for k, v in meta["dataset2idx"].items()}
    model_map = {v: k for k, v in meta["model2idx"].items()}

    # Create index arrays for each dimension
    datasets, source_models, target_models, architectures = np.indices(validation_mses.shape)

    # Create dataframe
    mse_df = pl.DataFrame(
        {
            "dataset": datasets.flatten(),
            "source_model": source_models.flatten(),
            "target_model": target_models.flatten(),
            "num_layers": architectures.flatten(),
            "mse": validation_mses.flatten(),
        }
    )

    # Map indices to names
    mse_df = mse_df.with_columns(
        [
            pl.col("dataset").map_dict(dataset_map).alias("dataset_name"),
            pl.col("source_model").map_dict(model_map).alias("source_model_name"),
            pl.col("target_model").map_dict(model_map).alias("target_model_name"),
            pl.col("num_layers").map_dict(dict(enumerate(meta["num_layers"]))).alias("num_layers"),
        ]
    ).drop("dataset", "source_model", "target_model")

    return mse_df


# Usage
mse_df = safetensor_to_polars(
    tensor_path=MLP_MSE_VALIDATION_TENSORS, meta_path=MLP_MSE_VALIDATION_META
)

AttributeError: 'Expr' object has no attribute 'map_dict'

In [12]:
# Print Imported Data (One Dataset, One Architecture)
dataset_i = 0
architecture_i = 0

example_validation_mse = validation_mse[dataset_i, :, :, architecture_i]
logger.info(example_validation_mse.shape)
pretty.pprint(example_validation_mse)

[32m2024-12-13 16:09:16.309[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m6[0m - [1m(17, 17)[0m


In [22]:
# Visualize Imported Data (One Dataset, One Architecture)
labels = sorted(meta["model2idx"], key=lambda k: meta["model2idx"][k])
matrix = example_validation_mse
heatmap(matrix, log=True, title="Example Validation MSE (log)", labels=labels, nan_diagnal=True)

[32m2024-12-13 16:19:59.063[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m28[0m - [1mCreating heatmap visualization...[0m


In [21]:
# Visualize Imported Data (One Dataset, All Architectures)
labels = sorted(meta["model2idx"], key=lambda k: meta["model2idx"][k])
num_layers = sorted(meta["num_layers"])
matrix = validation_mse[0, :, :, :].transpose(2, 0, 1)
animated_heatmap(
    matrices=matrix,
    log=True,
    title="Example Validation MSE (log) (ArguAna)",
    labels=labels,
    nan_diagnal=True,
    slider_prefix="Number of Layers",
    slider_labels=num_layers,
)