# Complexity vs MSE and Rank Graphs


## Setup


In [72]:
# Installs
%pip install safetensors


Using fork() can cause Polars to deadlock in the child process.
In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.

The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.

See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.

or by setting POLARS_ALLOW_FORKING_THREAD=1.




Note: you may need to restart the kernel to use updated packages.


In [73]:
# Imports
from safetensors.numpy import load_file
from pathlib import Path
from loguru import logger
import orjson
from rich import pretty
import plotly.graph_objects as go
import plotly.io as pio
import plotly
from faker import Faker
from typing import Literal
import polars as pl

# from IPython.display import display
import numpy as np

np.set_printoptions(
    precision=3,
    suppress=True,
    threshold=5,
)
rng = np.random.default_rng(0)
faker = Faker()
pio.templates.default = "plotly"

In [74]:
# Directories
PROJ_ROOT = Path.cwd().resolve().parent
logger.info(PROJ_ROOT)
DATA_DIR = PROJ_ROOT / "data"
logger.info(DATA_DIR)
CKA_DIR = DATA_DIR / "cka_centered_natives"
CKA_TENSORS = CKA_DIR / "table.safetensors"
CKA_META = CKA_DIR / "info.json"

[32m2024-12-14 02:29:57.873[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1m/Users/4gate/git/embedding_translation[0m
[32m2024-12-14 02:29:57.873[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m/Users/4gate/git/embedding_translation/data[0m


## Helper Viz


In [75]:
def get_default_heatmap_fig() -> go.Figure:
    """Create a default heatmap figure with standard configuration.

    Returns:
        go.Figure: A plotly figure with default heatmap settings
    """
    # General Viz Config
    heatmap_layout = go.Layout(
        modebar=go.layout.Modebar(
            remove=[
                "zoomIn2d",
                "zoomOut2d",
                "autoScale2d",
                "resetScale2d",
                "hoverClosestCartesian",
                "hoverCompareCartesian",
                "toggleSpikelines",
            ]
        ),
        dragmode=False,
        width=800,
        height=600,
    )

    heatmap_trace = go.Heatmap(
        colorscale="Viridis",
        textfont={"color": "black"},
        reversescale=False,
        showscale=True,
        showlegend=False,
    )

    return go.Figure(data=heatmap_trace, layout=heatmap_layout)


# Create example data
x_labels = ["A", "B", "C"]  # <--- [NEW] Added example x labels
y_labels = ["X", "Y", "Z"]  # <--- [NEW] Added example y labels
z_values = [
    [1, 2, 3],  # <--- [NEW] Added 2D array for heatmap values
    [4, 5, 6],
    [7, 8, 9],
]

# Create and configure figure
fig = get_default_heatmap_fig()
heatmap_trace: go.Heatmap = fig.data[0]
heatmap_trace.x = x_labels  # <--- [CHANGED] Set x axis labels
heatmap_trace.y = y_labels  # <--- [CHANGED] Set y axis labels
heatmap_trace.z = z_values  # <--- [CHANGED] Set z values as 2D array

# Update layout with titles
fig.update_layout(  # <--- [NEW] Added titles
    title="Example Heatmap", xaxis_title="X Axis", yaxis_title="Y Axis"
)

fig  # Display the figure

In [81]:
def heatmap(
    matrix: np.ndarray | None = None,
    row_labels: list | None = None,
    col_labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
) -> go.Figure:
    """Create a heatmap visualization using plotly.

    Args:
        matrix: Input matrix to visualize. If None, generates random demo data
        labels: Labels for rows and columns. If None, generates fake names
        title: Title of the heatmap
        show_values: Whether to show values in cells
        value_format: Format string for cell values
        width: Width of the figure in pixels
        height: Height of the figure in pixels
        xaxis_title: Title for x-axis
        yaxis_title: Title for y-axis
        log: Whether to log10 transform the values
        nan_diagnal: Whether to set diagonal to NaN

    Returns:
        go.Figure: Plotly figure object
    """
    faker = Faker()
    if matrix is None:
        shape = (10, 10)
        matrix = rng.random(shape)
    else:
        shape = matrix.shape

    if row_labels is None:
        raise NotImplementedError # XXX
        row_labels = [faker.name() for _ in range(shape[0])]
        logger.info(row_labels)
    logger.info("Creating heatmap visualization...")

    if col_labels is None:
        raise NotImplementedError # XXX
        col_labels = [faker.name() for _ in range(shape[0])]
        logger.info(col_labels)
    logger.info("Creating heatmap visualization...")

    # Convert input to numpy array and handle None values
    matrix = np.array(matrix, dtype=object)
    none_mask = matrix is None
    matrix = matrix.astype(float)
    if log:
        matrix = np.log10(matrix)
    matrix[none_mask] = np.nan
    if nan_diagnal:
        # Replace this with a method that will work for non-square
        assert len(matrix.shape) == 2
        for i in range(max(matrix.shape)):
            matrix[i, i] = np.nan
        # matrix[np.identity(matrix.shape[0], dtype=bool)] = np.nan

    # Format cell values if needed
    text_vals = None
    if show_values:
        text_vals = np.vectorize(lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}")(matrix)

    print(f"Row labels length: {len(row_labels)}")
    print(f"Col labels length: {len(col_labels)}")

    # Create figure
    fig = get_default_heatmap_fig()
    heatmap_trace: go.Heatmap = fig.data[0]
    heatmap_trace.update(
        go.Heatmap(
            z=matrix,
            x=col_labels,
            y=row_labels,
            text=text_vals,
            texttemplate="%{text}" if show_values else None,
        ),
        overwrite=True,
    )

    # Update layout using plotly objects
    fig.update_layout(
        title=title,
        height=height,
        width=width,
    )

    # Update axes
    fig.update_xaxes(
        title=xaxis_title,
        nticks=len(col_labels) if col_labels else None,
        fixedrange=True,  # Disable zoom/pan
    )

    fig.update_yaxes(
        title=yaxis_title,
        nticks=len(row_labels) if row_labels else None,
        fixedrange=True,  # Disable zoom/pan
    )

    return fig


# heatmap()

In [88]:
import einops
def animated_heatmap(
    matrices: np.ndarray | None = None,
    row_labels: list | None = None,
    col_labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
    frame_duration: int = 500,
    slider_labels: list[str] | None = None,
    slider_prefix: str = "Frame: ",
):
    # Generate demo data if matrices is None
    if matrices is None:
        faker = Faker()
        shape = (5, 10, 10)  # (time, rows, cols)
        matrices = rng.random(shape)
        # logger.info(matrices)
    elif matrices.ndim != 3:
        raise ValueError("Expected 3D array (time, rows, cols)")

    n_frames, n_rows, n_cols = matrices.shape

    if col_labels is None:
        assert col_labels is not None # XXX
        faker = Faker()
        col_labels = [faker.name() for _ in range(n_cols)]
    if row_labels is None:
        assert row_labels is not None
        faker = Faker()
        row_labels = [faker.name() for _ in range(n_rows)]

    # Process all matrices
    processed_matrices = []
    for matrix in matrices:
        # Convert and process each time slice
        matrix = np.array(matrix, dtype=object)
        none_mask = matrix is None
        matrix = matrix.astype(float)
        if log:
            matrix = np.log10(matrix)
        matrix[none_mask] = np.nan
        if nan_diagnal:
            # Replace with a shittier method that will work with non-square
            assert len(matrix.shape) == 2
            for i in range(max(matrix.shape)):
                matrix[i, i] = np.nan
            # matrix[np.identity(matrix.shape[0], dtype=bool)] = np.nan
        processed_matrices.append(matrix)
    # print([x.shape for x in processed_matrices])
    # raise NotADirectoryError # [DEBUG => yes not square]

        # Update the sliders configuration with custom labels
    if slider_labels is not None:
        if len(slider_labels) != n_frames:
            raise ValueError(f"Expected {n_frames} slider labels, got {len(slider_labels)}")
        frame_labels = slider_labels
    else:
        frame_labels = [str(k) for k in range(n_frames)]

    layout_sliders = [
        go.layout.Slider(
            active=0,
            yanchor="bottom",
            xanchor="left",
            currentvalue=go.layout.slider.Currentvalue(
                font=dict(size=20), prefix=slider_prefix, visible=True, xanchor="right"
            ),
            transition=go.layout.slider.Transition(duration=frame_duration, easing="cubic-in-out"),
            pad=dict(b=10, t=50),
            len=0.9,
            x=0.1,
            y=1.1,
            steps=[
                go.layout.slider.Step(
                    args=[
                        [f"frame_{k}"],
                        {
                            "frame": {"duration": frame_duration, "redraw": True},
                            "mode": "immediate",
                            "transition": {"duration": frame_duration},
                        },
                    ],
                    label=label,
                    method="animate",
                )
                for k, label in enumerate(frame_labels)
            ],
        )
    ]

    # Create frames for animation
    assert col_labels is not None
    assert row_labels is not None
    frames = [
        go.Frame(
            data=[heatmap(
                matrix=matrix,
                row_labels=row_labels,
                col_labels=col_labels,
                show_values=show_values).data[0]
            ],
            name=f"frame_{i}",  # name is required for frames
        )
        for i, matrix in enumerate(processed_matrices)
    ]

    # Create the figure with the initial frame
    fig = go.Figure(
        data=[heatmap(processed_matrices[0], row_labels=row_labels, col_labels=col_labels, show_values=show_values).data[0]],
        frames=frames,
        layout=go.Layout(
            width=width,
            height=height,
            title=dict(
                text=title,
                y=0.95,  # Move title down slightly from the top
                x=0.5,   # Center the title
                xanchor='center',
                yanchor='top'
            ),
            margin=dict(
                t=150,  # Increase top margin to make room for slider
                b=80,   # Bottom margin
                l=80,   # Left margin
                r=80    # Right margin
            ),
            xaxis=go.layout.XAxis(
                title=xaxis_title,
                nticks=len(col_labels) if col_labels else None,
                fixedrange=True,
            ),
            yaxis=go.layout.YAxis(
                title=yaxis_title,
                nticks=len(row_labels) if row_labels else None,
                fixedrange=True,
            ),
            dragmode=False,
            updatemenus=[
                go.layout.Updatemenu(
                    buttons=[
                        # go.layout.updatemenu.Button(
                        #     args=[
                        #         None,
                        #         {
                        #             "frame": {"duration": frame_duration, "redraw": True},
                        #             "fromcurrent": True,
                        #         },
                        #     ],
                        #     label="Play",
                        #     method="animate",
                        # ),
                        # go.layout.updatemenu.Button(
                        #     args=[
                        #         [None],
                        #         {
                        #             "frame": {"duration": 0, "redraw": True},
                        #             "mode": "immediate",
                        #             "transition": {"duration": 0},
                        #         },
                        #     ],
                        #     label="Pause",
                        #     method="animate",
                        # ),
                    ],
                    direction="left",
                    pad={"r": 10, "t": 87},
                    showactive=False,
                    type="buttons",
                    x=0.1,
                    xanchor="right",
                    y=1.1,
                    yanchor="bottom",
                )
            ],
        ),
    )

    # Update the layout dictionary to include the new sliders
    fig.update_layout(sliders=layout_sliders)

    return fig


path2mses_l2plus = Path.cwd().parent / "dl_final_project_tables_tensors" / "all_mses_table_l2plus_yes_ctrl"
assert path2mses_l2plus.exists()
info_json_path = path2mses_l2plus / "info.json"
tables_path = path2mses_l2plus / "table.safetensors"
info_json = orjson.loads(info_json_path.open().read())
tables = load_file(tables_path)["table"]
print(tables)

# NOTE we can use this to make sure that the dataset weighted average is honest
DATASET2SIZES = {
    "arguana": {
        "train": 8928,
        "validation": 2227
    },
    "fiqa": {
        "train": 24900,
        "validation": 6265
    },
    "scidocs": {
        "train": 25157,
        "validation": 6280
    },
    "nfcorpus": {
        "train": 5227,
        "validation": 1306
    },
    "hotpotqa": {
        "train": 20070,
        "validation": 5017
    },
    "trec-covid": {
        "train": 24617,
        "validation": 6137
    }
}
# print(info_json.keys())
# IDX2MODEL =info_json["idx2model"] # Buggy
MODEL2IDX = info_json["model2idx"]
IDX2DATASET = info_json["idx2dataset"]
DATASET2IDX = info_json["dataset2idx"]
print(tables.shape)
# print(IDX2MODEL.keys()) # skips 17...
# print(IDX2DATASET)
# print(MODEL2IDX) <---- OK
IDX2MODEL = {v: k for k, v in MODEL2IDX.items()}
IDX2DATASET = {int(k): v for k, v in IDX2DATASET.items()}
import sys
sys.path.append(Path.cwd().parent.as_posix())
from src.viz.save_figure import save_figure
import torch
dataset_sizes = np.array([DATASET2SIZES[k]["validation"] for k in ["arguana", "fiqa", "scidocs", "nfcorpus", "hotpotqa", "trec-covid"]]).astype(np.float32)
dataset_probs = dataset_sizes / dataset_sizes.sum()
assert np.isclose(dataset_probs.sum(), 1.0)
assert DATASET2IDX["arguana"] == 0
assert DATASET2IDX["fiqa"] == 1
assert DATASET2IDX["scidocs"] == 2
assert DATASET2IDX["nfcorpus"] == 3
assert DATASET2IDX["hotpotqa"] == 4
assert DATASET2IDX["trec-covid"] == 5
arguana_table = tables[0]
fiqa_table = tables[1]
scidocs_table = tables[2]
nfcorpus_table = tables[3]
hotpotqa_table = tables[4]
trec_covid_table = tables[5]
weighted_average_table = (
    arguana_table * dataset_probs[0] +
    fiqa_table * dataset_probs[1] +
    scidocs_table * dataset_probs[2] +
    nfcorpus_table * dataset_probs[3] +
    hotpotqa_table * dataset_probs[4] +
    trec_covid_table * dataset_probs[5]
)
tables = np.concatenate([weighted_average_table[np.newaxis, ...], tables], axis=0)
print(tables.shape)
DATASET2IDX = {k: i for i, k in enumerate(["weighted-average", "arguana", "fiqa", "scidocs", "nfcorpus", "hotpotqa", "trec-covid"])}
IDX2DATASET = {i: k for k, i in DATASET2IDX.items()}
# raise NotImplementedError

for dataset in ["weighted-average", "arguana", "fiqa", "scidocs", "nfcorpus", "hotpotqa", "trec-covid"]:
    t0 = tables[DATASET2IDX[dataset]]
    t0 = einops.rearrange(t0, "x y depth -> depth x y")
    # print("T0 shape:", t0.shape)
    fig = animated_heatmap(
        matrices=t0,
        row_labels=[IDX2MODEL[i].split("/")[-1] for i in range(len(IDX2MODEL))],
        col_labels=[IDX2MODEL[i].split("/")[-1] for i in range(len(IDX2MODEL) - 2)],
        title=f"Validation MSE (log) ({dataset})",
        show_values=True,
        value_format=".2e",
        width=800,
        height=600,
        xaxis_title="Target Embedding Space",
        yaxis_title="Native Embedding Space",
        log=True,
        # nan_diagnal=True,
        frame_duration=500,
        slider_labels=list(range(2, 8)),
        slider_prefix="Depth: "
    ) # fetch for first dataset
    # fig.show()
    save_figure(
        fig,
        f"{dataset}_lord_farquad_was_heren",
        output_dir=PROJ_ROOT / "data" / "figs",
    )

[[[[0.    0.    0.    0.    0.    0.   ]
   [0.002 0.002 0.002 0.002 0.002 0.002]
   [0.001 0.001 0.001 0.001 0.001 0.001]
   ...
   [0.002 0.002 0.002 0.002 0.002 0.002]
   [0.001 0.    0.    0.    0.    0.   ]
   [0.001 0.001 0.001 0.001 0.001 0.001]]

  [[0.001 0.001 0.001 0.001 0.001 0.001]
   [0.    0.    0.    0.    0.    0.   ]
   [0.001 0.001 0.001 0.001 0.001 0.001]
   ...
   [0.002 0.002 0.002 0.002 0.002 0.002]
   [0.001 0.    0.    0.    0.    0.   ]
   [0.001 0.001 0.001 0.001 0.001 0.001]]

  [[0.001 0.001 0.001 0.001 0.001 0.001]
   [0.002 0.002 0.002 0.002 0.002 0.002]
   [0.    0.    0.    0.    0.    0.   ]
   ...
   [0.002 0.002 0.002 0.002 0.002 0.002]
   [0.001 0.    0.    0.    0.    0.   ]
   [0.001 0.001 0.001 0.001 0.001 0.001]]

  ...

  [[0.001 0.001 0.001 0.001 0.001 0.001]
   [0.002 0.002 0.002 0.002 0.002 0.002]
   [0.001 0.001 0.001 0.001 0.001 0.001]
   ...
   [0.002 0.002 0.002 0.002 0.002 0.002]
   [0.001 0.    0.    0.    0.    0.   ]
   [0.    0.    


divide by zero encountered in log10



Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


Row labels length: 19
Col labels length: 17


## Import Data


In [23]:
# Load Metadata
"""
# 6 Datasets (ArguAna, FiQA, SciDocs, NFCorpus, HotPotQA, Trec-COVID)
# 17 Native Embedding Spaces
# 17 Stitched Embedding Spaces
# 6 Architectures (2, 3, 4, 5, 6, or 7 layers)
"""
with CKA_META.open() as f:
    meta = orjson.loads(f.read())

# pretty.pprint(meta)

meta_df = pl.read_json(CKA_META)


# def create_index_df(meta: dict) -> pl.DataFrame:
#     """Create an index DataFrame for easy tensor lookup using np.indices.

#     Args:
#         meta: Dictionary containing dataset2idx and model2idx mappings

#     Returns:
#         pl.DataFrame: DataFrame with hierarchical index for tensor lookup
#     """
#     # Get shape from metadata
#     n_datasets = len(meta["dataset2idx"])
#     n_models = len(meta["model2idx"])
#     n_layers = len(meta["num_layers"])

#     # Create indices arrays for each dimension
#     datasets, sources, targets, layers = np.indices((n_datasets, n_models, n_models, n_layers))

#     # Create reverse mapping dictionaries
#     dataset_map = {v: k for k, v in meta["dataset2idx"].items()}
#     model_map = {v: k for k, v in meta["model2idx"].items()}

#     # Create DataFrame
#     return pl.DataFrame(
#         {
#             "dataset": pl.Series(datasets.flatten()).cast_dict(dataset_map),
#             "source_model": pl.Series(sources.flatten()).cast_dict(model_map),
#             "target_model": pl.Series(targets.flatten()).cast_dict(model_map),
#             "num_layers": pl.Series(layers.flatten()) + 2,  # Convert 0-based to 2-based
#             "index": list(
#                 zip(datasets.flatten(), sources.flatten(), targets.flatten(), layers.flatten())
#             ),
#         }
#     )


# create_index_df(meta)
# why am I being silly, work on this later.

In [24]:
# Load Tensors
tensors = load_file(CKA_TENSORS)
logger.info(list(tensors.keys()))
ckas = tensors["table"]
# f"{tensors["validation_mses"].shape=}"
# f"{tensors["table"].shape=}"
print(ckas.shape)

[32m2024-12-13 22:38:53.426[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1m['table'][0m


(6, 17, 17, 1)


In [25]:
# from safetensors.numpy import load_file


# def safetensor_to_polars(
#     tensor_path: Path,
#     meta_path: Path,
# ) -> pl.DataFrame:
#     """Convert safetensor validation MSEs to a polars dataframe.

#     Args:
#         tensor_path: Path to .safetensors file
#         meta_path: Path to metadata JSON file

#     Returns:
#         pl.DataFrame: Long-format dataframe with validation MSEs
#     """
#     # Load metadata
#     with meta_path.open() as f:
#         meta = orjson.loads(f.read())

#     # Load tensor
#     tensor_dict = load_file(tensor_path)  # (6, 17, 17, 6)
#     validation_mses = tensor_dict["validation_mses"]

#     # Create mapping dictionaries
#     dataset_map = {v: k for k, v in meta["dataset2idx"].items()}
#     model_map = {v: k for k, v in meta["model2idx"].items()}

#     # Create index arrays for each dimension
#     datasets, source_models, target_models, architectures = np.indices(validation_mses.shape)

#     # Create dataframe
#     mse_df = pl.DataFrame(
#         {
#             "dataset": datasets.flatten(),
#             "source_model": source_models.flatten(),
#             "target_model": target_models.flatten(),
#             "num_layers": architectures.flatten(),
#             "mse": validation_mses.flatten(),
#         }
#     )

#     # Map indices to names
#     mse_df = mse_df.with_columns(
#         [
#             pl.col("dataset").map_dict(dataset_map).alias("dataset_name"),
#             pl.col("source_model").map_dict(model_map).alias("source_model_name"),
#             pl.col("target_model").map_dict(model_map).alias("target_model_name"),
#             pl.col("num_layers").map_dict(dict(enumerate(meta["num_layers"]))).alias("num_layers"),
#         ]
#     ).drop("dataset", "source_model", "target_model")

#     return mse_df


# # Usage
# mse_df = safetensor_to_polars(
#     tensor_path=MLP_MSE_VALIDATION_TENSORS, meta_path=MLP_MSE_VALIDATION_META
# )
# not working, also fucked around here.

In [26]:
# Print Imported Data (One Dataset, One Architecture)
dataset_i = 0
architecture_i = 0

example_cka = ckas[dataset_i, :, :, architecture_i]
logger.info(example_cka.shape)
pretty.pprint(example_cka)

[32m2024-12-13 22:38:53.432[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1m(17, 17)[0m


In [27]:
# Visualize Imported Data (One Dataset, All Architectures)

from src.viz.save_figure import save_figure


labels = sorted(meta["model2idx"], key=lambda k: meta["model2idx"][k])
num_layers = sorted(meta["num_layers"])

for dataset_name, dataset_i in meta["dataset2idx"].items():
    matrix = ckas[dataset_i, :, :, :].transpose(2, 0, 1)
    fig = animated_heatmap(
        matrices=matrix,
        log=True,
        title=f"Validation MSE (log) ({dataset_name})",
        labels=labels,
        nan_diagnal=True,
        slider_prefix="Number of Layers: ",
        slider_labels=num_layers,
    )
    save_figure(
        fig,
        f"{dataset_name}_all_layers_mse_withlog_validation",
        output_dir=PROJ_ROOT / "data" / "figs",
    )
    save_figure(
        fig,
        f"{dataset_name}_all_layers_mse_withlog_validation",
        output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
    )

ModuleNotFoundError: No module named 'src'

In [None]:
# Lastly get the mean score (not weighted)
dataset_name = "weightedmean"
average_validation_mse = ckas.mean(axis=0)  # Calculate mean across datasets (axis=0)
fig = animated_heatmap(
    matrices=average_validation_mse.transpose(2, 0, 1),  # Reshape to (n_layers, n_models, n_models)
    log=True,
    title=f"Validation MSE (log) ({dataset_name})",
    labels=labels,
    nan_diagnal=True,
    slider_prefix="Number of Layers: ",
    slider_labels=num_layers,
)
save_figure(
    fig,
    f"{dataset_name}_all_layers_mse_withlog_validation",
    output_dir=PROJ_ROOT / "data" / "figs",
)
save_figure(
    fig,
    f"{dataset_name}_all_layers_mse_withlog_validation",
    output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
)

## Get Loss vs Layers


In [None]:
average_validation_mse.shape  # (17, 17, 6)

In [None]:
# Create figure
fig = go.Figure()

# Get model labels for the legend
model_labels = sorted(meta["model2idx"], key=lambda k: meta["model2idx"][k])

# Create traces grouped by source model
for i in range(17):
    # Create a legendgroup for each source model
    legendgroup = model_labels[i]

    for j in range(17):
        if i != j:  # Skip diagonal entries since they're NaN
            y_values = average_validation_mse[i, j, :]
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=y_values,
                    mode="lines",
                    name=f"{model_labels[j]}",  # Only show target model in name
                    legendgroup=legendgroup,  # Group by source model
                    legendgrouptitle_text=f"From: {legendgroup}",  # Add group title
                    showlegend=True,  # Show all traces in legend
                    line=dict(width=1),
                    opacity=0.5,
                )
            )

# Update layout with smaller legend text and grouping
fig.update_layout(
    title="MSE vs Number of Layers for All Model Pairs",
    xaxis_title="Number of Layers",
    yaxis_title="MSE",
    width=1000,
    height=600,
    showlegend=True,
    yaxis_type="log",
    legend=dict(font=dict(size=7), groupclick="toggleitem"),
)

# Add gridlines
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")

# Show figure
fig

In [None]:
# Save figure
save_figure(
    fig,
    "mse_vs_layers_all_pairs",
    output_dir=PROJ_ROOT / "data" / "figs",
)
save_figure(
    fig,
    "mse_vs_layers_all_pairs",
    output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
)

In [None]:
# Create line plot showing average MSE vs number of layers
fig = go.Figure()

# Calculate mean MSE across all model pairs (excluding diagonal)
mean_mse = []
for layer_idx in range(6):  # For each layer configuration
    layer_mse = average_validation_mse[:, :, layer_idx]  # Get MSE matrix for this layer
    # Create mask to exclude diagonal elements
    mask = ~np.eye(17, dtype=bool)
    # Calculate mean excluding diagonal elements
    mean_mse.append(layer_mse[mask].mean())

# Add trace for mean MSE
fig.add_trace(
    go.Scatter(
        x=x_values,
        y=mean_mse,
        mode="lines+markers",
        name="Mean MSE across all model pairs",
        line=dict(width=3, color="blue"),
        marker=dict(size=10),
    )
)

# Update layout
fig.update_layout(
    title="Mean MSE vs Number of Layers (Averaged Across All Model Pairs)",
    xaxis_title="Number of Layers",
    yaxis_title="Mean MSE",
    width=800,
    height=500,
    showlegend=False,
    yaxis_type="log",  # Use log scale for y-axis
    template="plotly_white",  # Clean template
)

# Add gridlines
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")

# Show figure
fig

In [None]:
# Save figure
save_figure(
    fig,
    "mean_mse_vs_layers",
    output_dir=PROJ_ROOT / "data" / "figs",
)
save_figure(
    fig,
    "mean_mse_vs_layers",
    output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
)