# Complexity vs MSE and Rank Graphs


## Setup


In [1]:
# Installs
%pip install safetensors

Note: you may need to restart the kernel to use updated packages.


In [3]:
# Imports
from pathlib import Path
from typing import Literal

# from IPython.display import display
import numpy as np
import orjson
import plotly.graph_objects as go
import plotly.io as pio
import polars as pl
from faker import Faker
from loguru import logger
from rich import pretty
from safetensors.numpy import load_file


np.set_printoptions(
    precision=3,
    suppress=True,
    threshold=5,
)
rng = np.random.default_rng(0)
faker = Faker()
pio.templates.default = "plotly"

In [4]:
# Directories
PROJ_ROOT = Path.cwd().resolve().parent
logger.info(PROJ_ROOT)
DATA_DIR = PROJ_ROOT / "data"
logger.info(DATA_DIR)
MLP_MSE_VALIDATION = DATA_DIR / "mlp_mse_validation"
MLP_MSE_VALIDATION_TENSORS = MLP_MSE_VALIDATION / "validation_mses.safetensors"
MLP_MSE_VALIDATION_META = MLP_MSE_VALIDATION / "validation_mses.json"

[32m2025-01-07 20:06:22.873[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m3[0m - [1m/Users/gat/work/FA2024/embedding_translation[0m
[32m2025-01-07 20:06:22.874[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m5[0m - [1m/Users/gat/work/FA2024/embedding_translation/data[0m


## Helper Viz


In [40]:
def get_default_heatmap_fig() -> go.Figure:
    """Create a default heatmap figure with standard configuration.

    Returns:
        go.Figure: A plotly figure with default heatmap settings
    """
    # General Viz Config
    heatmap_layout = go.Layout(
        modebar=go.layout.Modebar(
            remove=[
                "zoomIn2d",
                "zoomOut2d",
                "autoScale2d",
                "resetScale2d",
                "hoverClosestCartesian",
                "hoverCompareCartesian",
                "toggleSpikelines",
            ]
        ),
        dragmode=False,
        width=800,
        height=600,
    )

    heatmap_trace = go.Heatmap(
        colorscale="Viridis",
        textfont={"color": "black"},
        reversescale=False,
        showscale=True,
        showlegend=False,
    )

    return go.Figure(data=heatmap_trace, layout=heatmap_layout)


# Create example data
x_labels = ["A", "B", "C"]  # <--- [NEW] Added example x labels
y_labels = ["X", "Y", "Z"]  # <--- [NEW] Added example y labels
z_values = [
    [1, 2, 3],  # <--- [NEW] Added 2D array for heatmap values
    [4, 5, 6],
    [7, 8, 9],
]

# Create and configure figure
fig = get_default_heatmap_fig()
heatmap_trace: go.Heatmap = fig.data[0]
heatmap_trace.x = x_labels  # <--- [CHANGED] Set x axis labels
heatmap_trace.y = y_labels  # <--- [CHANGED] Set y axis labels
heatmap_trace.z = z_values  # <--- [CHANGED] Set z values as 2D array

# Update layout with titles
fig.update_layout(  # <--- [NEW] Added titles
    title="Example Heatmap", xaxis_title="X Axis", yaxis_title="Y Axis"
)

fig  # Display the figure

In [41]:
def heatmap(
    matrix: np.ndarray | None = None,
    labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
) -> go.Figure:
    """Create a heatmap visualization using plotly.

    Args:
        matrix: Input matrix to visualize. If None, generates random demo data
        labels: Labels for rows and columns. If None, generates fake names
        title: Title of the heatmap
        show_values: Whether to show values in cells
        value_format: Format string for cell values
        width: Width of the figure in pixels
        height: Height of the figure in pixels
        xaxis_title: Title for x-axis
        yaxis_title: Title for y-axis
        log: Whether to log10 transform the values
        nan_diagnal: Whether to set diagonal to NaN

    Returns:
        go.Figure: Plotly figure object
    """
    if matrix is None:
        faker = Faker()
        shape = (10, 10)
        matrix = rng.random(shape)
    else:
        shape = matrix.shape

    if labels is None:
        labels = [faker.name() for _ in range(shape[0])]
        logger.info(labels)
    logger.info("Creating heatmap visualization...")

    # Convert input to numpy array and handle None values
    matrix = np.array(matrix, dtype=object)
    none_mask = matrix is None
    matrix = matrix.astype(float)
    if log:
        matrix = np.log10(matrix)
    matrix[none_mask] = np.nan
    if nan_diagnal:
        matrix[np.identity(matrix.shape[0], dtype=bool)] = np.nan

    # Format cell values if needed
    text_vals = None
    if show_values:
        text_vals = np.vectorize(
            lambda x: "N/A" if np.isnan(x) else f"{x:{value_format}}"
        )(matrix)

    # Create figure
    fig = get_default_heatmap_fig()
    heatmap_trace: go.Heatmap = fig.data[0]
    heatmap_trace.update(
        go.Heatmap(
            z=matrix,
            x=labels,
            y=labels,
            text=text_vals,
            texttemplate="%{text}" if show_values else None,
        ),
        overwrite=True,
    )

    # Update layout using plotly objects
    fig.update_layout(
        title=title,
        height=height,
        width=width,
    )

    # Update axes
    fig.update_xaxes(
        title=xaxis_title,
        nticks=len(labels) if labels else None,
        fixedrange=True,  # Disable zoom/pan
    )

    fig.update_yaxes(
        title=yaxis_title,
        nticks=len(labels) if labels else None,
        fixedrange=True,  # Disable zoom/pan
    )

    return fig


heatmap()

[32m2024-12-13 18:55:47.752[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m41[0m - [1m['Charles Henderson', 'Melinda Trujillo', 'Sarah Blake', 'Sandra Gonzales', 'Tammy Carson', 'Paul York', 'Patricia Torres', 'Rachel Tanner', 'James Thompson', 'Micheal Lucas'][0m
[32m2024-12-13 18:55:47.752[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m


In [53]:
def animated_heatmap(
    matrices: np.ndarray | None = None,
    labels: list | None = None,
    title: str = "DEFAULT TITLE",
    show_values: bool = True,
    value_format: Literal[".2f", ".2e"] = ".2f",
    width: int = 800,
    height: int = 600,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    log: bool = False,
    nan_diagnal: bool = False,
    frame_duration: int = 500,
    slider_labels: list[str] | None = None,
    slider_prefix: str = "Frame: ",
):
    # Generate demo data if matrices is None
    if matrices is None:
        faker = Faker()
        shape = (5, 10, 10)  # (time, rows, cols)
        matrices = rng.random(shape)
        # logger.info(matrices)
    elif matrices.ndim != 3:
        raise ValueError("Expected 3D array (time, rows, cols)")

    n_frames, n_rows, n_cols = matrices.shape

    if labels is None:
        faker = Faker()
        labels = [faker.name() for _ in range(n_rows)]
    col_labels = labels
    row_labels = labels

    # Process all matrices
    processed_matrices = []
    for matrix in matrices:
        # Convert and process each time slice
        matrix = np.array(matrix, dtype=object)
        none_mask = matrix is None
        matrix = matrix.astype(float)
        if log:
            matrix = np.log10(matrix)
        matrix[none_mask] = np.nan
        if nan_diagnal:
            matrix[np.identity(matrix.shape[0], dtype=bool)] = np.nan
        processed_matrices.append(matrix)

        # Update the sliders configuration with custom labels
    if slider_labels is not None:
        if len(slider_labels) != n_frames:
            raise ValueError(
                f"Expected {n_frames} slider labels, got {len(slider_labels)}"
            )
        frame_labels = slider_labels
    else:
        frame_labels = [str(k) for k in range(n_frames)]

    layout_sliders = [
        go.layout.Slider(
            active=0,
            yanchor="bottom",
            xanchor="left",
            currentvalue=go.layout.slider.Currentvalue(
                font=dict(size=20), prefix=slider_prefix, visible=True, xanchor="right"
            ),
            transition=go.layout.slider.Transition(
                duration=frame_duration, easing="cubic-in-out"
            ),
            pad=dict(b=10, t=50),
            len=0.9,
            x=0.1,
            y=1.1,
            steps=[
                go.layout.slider.Step(
                    args=[
                        [f"frame_{k}"],
                        {
                            "frame": {"duration": frame_duration, "redraw": True},
                            "mode": "immediate",
                            "transition": {"duration": frame_duration},
                        },
                    ],
                    label=label,
                    method="animate",
                )
                for k, label in enumerate(frame_labels)
            ],
        )
    ]

    # Create frames for animation
    frames = [
        go.Frame(
            data=[heatmap(matrix, row_labels, show_values=show_values).data[0]],
            name=f"frame_{i}",  # name is required for frames
        )
        for i, matrix in enumerate(processed_matrices)
    ]

    # Create the figure with the initial frame
    fig = go.Figure(
        data=[
            heatmap(processed_matrices[0], row_labels, show_values=show_values).data[0]
        ],
        frames=frames,
        layout=go.Layout(
            title=title,
            width=width,
            height=height,
            xaxis=go.layout.XAxis(
                title=xaxis_title,
                nticks=len(col_labels) if col_labels else None,
                fixedrange=True,
            ),
            yaxis=go.layout.YAxis(
                title=yaxis_title,
                nticks=len(row_labels) if row_labels else None,
                fixedrange=True,
            ),
            dragmode=False,
            updatemenus=[
                go.layout.Updatemenu(
                    buttons=[
                        # go.layout.updatemenu.Button(
                        #     args=[
                        #         None,
                        #         {
                        #             "frame": {"duration": frame_duration, "redraw": True},
                        #             "fromcurrent": True,
                        #         },
                        #     ],
                        #     label="Play",
                        #     method="animate",
                        # ),
                        # go.layout.updatemenu.Button(
                        #     args=[
                        #         [None],
                        #         {
                        #             "frame": {"duration": 0, "redraw": True},
                        #             "mode": "immediate",
                        #             "transition": {"duration": 0},
                        #         },
                        #     ],
                        #     label="Pause",
                        #     method="animate",
                        # ),
                    ],
                    direction="left",
                    pad={"r": 10, "t": 87},
                    showactive=False,
                    type="buttons",
                    x=0.1,
                    xanchor="right",
                    y=1.1,
                    yanchor="bottom",
                )
            ],
        ),
    )

    # Update the layout dictionary to include the new sliders
    fig.update_layout(sliders=layout_sliders)

    return fig


animated_heatmap()

[32m2024-12-13 19:44:18.596[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m
[32m2024-12-13 19:44:18.603[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m
[32m2024-12-13 19:44:18.609[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m
[32m2024-12-13 19:44:18.615[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m
[32m2024-12-13 19:44:18.621[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m
[32m2024-12-13 19:44:18.626[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m


## Import Data


In [5]:
# Load Metadata
"""
# 6 Datasets (ArguAna, FiQA, SciDocs, NFCorpus, HotPotQA, Trec-COVID)
# 17 Native Embedding Spaces
# 17 Stitched Embedding Spaces
# 6 Architectures (2, 3, 4, 5, 6, or 7 layers)
"""
with MLP_MSE_VALIDATION_META.open() as f:
    meta = orjson.loads(f.read())

# pretty.pprint(meta)

meta_df = pl.read_json(MLP_MSE_VALIDATION_META)


def create_index_df(meta: dict) -> pl.DataFrame:
    """Create an index DataFrame for easy tensor lookup using np.indices.

    Args:
        meta: Dictionary containing dataset2idx and model2idx mappings

    Returns:
        pl.DataFrame: DataFrame with hierarchical index for tensor lookup
    """
    # Get shape from metadata
    n_datasets = len(meta["dataset2idx"])
    n_models = len(meta["model2idx"])
    n_layers = len(meta["num_layers"])

    # Create indices arrays for each dimension
    datasets, sources, targets, layers = np.indices((n_datasets, n_models, n_models, n_layers))

    # Create reverse mapping dictionaries
    dataset_map = {v: k for k, v in meta["dataset2idx"].items()}
    model_map = {v: k for k, v in meta["model2idx"].items()}

    # Create DataFrame
    return pl.DataFrame(
        {
            "dataset": pl.Series(datasets.flatten()).cast_dict(dataset_map),
            "source_model": pl.Series(sources.flatten()).cast_dict(model_map),
            "target_model": pl.Series(targets.flatten()).cast_dict(model_map),
            "num_layers": pl.Series(layers.flatten()) + 2,  # Convert 0-based to 2-based
            "index": list(
                zip(datasets.flatten(), sources.flatten(), targets.flatten(), layers.flatten()),
            ),
        }
    )


create_index_df(meta)
# why am I being silly, work on this later.

AttributeError: 'Series' object has no attribute 'cast_dict'

In [32]:
# Load Tensors
tensors = load_file(MLP_MSE_VALIDATION_TENSORS)
logger.info(list(tensors.keys()))
validation_mse = tensors["validation_mses"]
# f"{tensors["validation_mses"].shape=}"
f"{tensors["validation_mses"].shape=}"

[32m2024-12-13 18:50:24.839[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m3[0m - [1m['validation_mses'][0m


'tensors["validation_mses"].shape=(6, 17, 17, 6)'

In [33]:
from safetensors.numpy import load_file


def safetensor_to_polars(
    tensor_path: Path,
    meta_path: Path,
) -> pl.DataFrame:
    """Convert safetensor validation MSEs to a polars dataframe.

    Args:
        tensor_path: Path to .safetensors file
        meta_path: Path to metadata JSON file

    Returns:
        pl.DataFrame: Long-format dataframe with validation MSEs
    """
    # Load metadata
    with meta_path.open() as f:
        meta = orjson.loads(f.read())

    # Load tensor
    tensor_dict = load_file(tensor_path)  # (6, 17, 17, 6)
    validation_mses = tensor_dict["validation_mses"]

    # Create mapping dictionaries
    dataset_map = {v: k for k, v in meta["dataset2idx"].items()}
    model_map = {v: k for k, v in meta["model2idx"].items()}

    # Create index arrays for each dimension
    datasets, source_models, target_models, architectures = np.indices(validation_mses.shape)

    # Create dataframe
    mse_df = pl.DataFrame(
        {
            "dataset": datasets.flatten(),
            "source_model": source_models.flatten(),
            "target_model": target_models.flatten(),
            "num_layers": architectures.flatten(),
            "mse": validation_mses.flatten(),
        }
    )

    # Map indices to names
    mse_df = mse_df.with_columns(
        [
            pl.col("dataset").map_dict(dataset_map).alias("dataset_name"),
            pl.col("source_model").map_dict(model_map).alias("source_model_name"),
            pl.col("target_model").map_dict(model_map).alias("target_model_name"),
            pl.col("num_layers").map_dict(dict(enumerate(meta["num_layers"]))).alias("num_layers"),
        ]
    ).drop("dataset", "source_model", "target_model")

    return mse_df


# Usage
mse_df = safetensor_to_polars(
    tensor_path=MLP_MSE_VALIDATION_TENSORS, meta_path=MLP_MSE_VALIDATION_META
)
# not working, also fucked around here.

In [34]:
# Print Imported Data (One Dataset, One Architecture)
dataset_i = 0
architecture_i = 0

example_validation_mse = validation_mse[dataset_i, :, :, architecture_i]
logger.info(example_validation_mse.shape)
pretty.pprint(example_validation_mse)

[32m2024-12-13 18:50:26.756[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 0>[0m:[36m6[0m - [1m(17, 17)[0m


In [35]:
# Visualize Imported Data (One Dataset, One Architecture)
labels = sorted(meta["model2idx"], key=lambda k: meta["model2idx"][k])
matrix = example_validation_mse
heatmap(
    matrix,
    log=True,
    title="Example Validation MSE (log)",
    labels=labels,
    nan_diagnal=True,
)

[32m2024-12-13 18:50:27.515[0m | [1mINFO    [0m | [36m__main__[0m:[36mheatmap[0m:[36m42[0m - [1mCreating heatmap visualization...[0m


In [63]:
# Visualize Imported Data (One Dataset, All Architectures)

from src.viz.save_figure import save_figure


save_figure(fig, "complexity_vs_rank")
labels = sorted(meta["model2idx"], key=lambda k: meta["model2idx"][k])
num_layers = sorted(meta["num_layers"])

for dataset_name, dataset_i in meta["dataset2idx"].items():
    matrix = validation_mse[dataset_i, :, :, :].transpose(2, 0, 1)
    fig = animated_heatmap(
        matrices=matrix,
        log=True,
        title=f"Validation MSE (log) ({dataset_name})",
        labels=labels,
        nan_diagnal=True,
        slider_prefix="Number of Layers: ",
        slider_labels=num_layers,
    )
    save_figure(
        fig,
        f"{dataset_name}_all_layers_mse_withlog_validation",
        output_dir=PROJ_ROOT / "data" / "figs",
    )
    save_figure(
        fig,
        f"{dataset_name}_all_layers_mse_withlog_validation",
        output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
    )

[2;36m[12/13/24 20:05:15][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m Saving HTML figure to            ]8;id=899573;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py\[1;36msave_figure.py[0m]8;;\[1;36m:[0m]8;id=490907;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py#32\[1;36m32[0m]8;;\
[2;36m                    [0m[1;32m         [0mdata/figs/html/complexity_vs_rank [1;36m                 [0m
[2;36m                    [0m[1;32m         [0m.html                             [1;36m                 [0m


[2;36m                   [0m[2;36m [0m[1;32mSUCCESS [0m[1;32m [0m Successfully saved HTML to       ]8;id=482313;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py\[1;36msave_figure.py[0m]8;;\[1;36m:[0m]8;id=275185;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py#40\[1;36m40[0m]8;;\
[2;36m                    [0m[1;32m         [0mdata/figs/html/complexity_vs_rank [1;36m                 [0m
[2;36m                    [0m[1;32m         [0m.html                             [1;36m                 [0m
[2;36m                   [0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m Saving PNG figure to             ]8;id=823599;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py\[1;36msave_figure.py[0m]8;;\[1;36m:[0m]8;id=30733;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py#67\[1;36m67[0m]8;;\
[2;36m                    [0m[1;32m         [0mdata/figs/im

In [68]:
# Lastly get the mean score (not weighted)
dataset_name = "weightedmean"
average_validation_mse = validation_mse.mean(
    axis=0
)  # Calculate mean across datasets (axis=0)
fig = animated_heatmap(
    matrices=average_validation_mse.transpose(
        2, 0, 1
    ),  # Reshape to (n_layers, n_models, n_models)
    log=True,
    title=f"Validation MSE (log) ({dataset_name})",
    labels=labels,
    nan_diagnal=True,
    slider_prefix="Number of Layers: ",
    slider_labels=num_layers,
)
save_figure(
    fig,
    f"{dataset_name}_all_layers_mse_withlog_validation",
    output_dir=PROJ_ROOT / "data" / "figs",
)
save_figure(
    fig,
    f"{dataset_name}_all_layers_mse_withlog_validation",
    output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
)

[2;36m[12/13/24 20:17:06][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m220785227.[0mpy:[1;36m42[0m[1m][0m Creating heatmap  ]8;id=641840;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py\[1;36m220785227.py[0m]8;;\[1;36m:[0m]8;id=469499;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py#42\[1;36m42[0m]8;;\
[2;36m                    [0m[1;32m         [0mvisualization[33m...[0m                    [1;36m               [0m
[2;36m                   [0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m220785227.[0mpy:[1;36m42[0m[1m][0m Creating heatmap  ]8;id=290397;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py\[1;36m220785227.py[0m]8;;\[1;36m:[0m]8;id=260346;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py#42\[1;36m42[0m]8;;\
[2;36m                    [0m[1;32m         [0m

[2;36m                   [0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m220785227.[0mpy:[1;36m42[0m[1m][0m Creating heatmap  ]8;id=789049;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py\[1;36m220785227.py[0m]8;;\[1;36m:[0m]8;id=889282;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py#42\[1;36m42[0m]8;;\
[2;36m                    [0m[1;32m         [0mvisualization[33m...[0m                    [1;36m               [0m
[2;36m                   [0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m[1m[[0m[1;36m220785227.[0mpy:[1;36m42[0m[1m][0m Creating heatmap  ]8;id=643182;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py\[1;36m220785227.py[0m]8;;\[1;36m:[0m]8;id=590298;file:///var/folders/6f/f4k9qbkd5nqfywwdhvwdtl5m0000gn/T/ipykernel_39579/220785227.py#42\[1;36m42[0m]8;;\
[2;36m                    [0m[1;32m         [0m

(PosixPath('/Users/gat/work/FA2024/embedding_translation/src/blog/figs/gatlen/html/weightedmean_all_layers_mse_withlog_validation.html'),
 PosixPath('/Users/gat/work/FA2024/embedding_translation/src/blog/figs/gatlen/imgs/weightedmean_all_layers_mse_withlog_validation.png'))

## Get Loss vs Layers


In [70]:
average_validation_mse.shape  # (17, 17, 6)

(17, 17, 6)

In [80]:
# Create figure
fig = go.Figure()

# Get model labels for the legend
model_labels = sorted(meta["model2idx"], key=lambda k: meta["model2idx"][k])

# Create traces grouped by source model
for i in range(17):
    # Create a legendgroup for each source model
    legendgroup = model_labels[i]

    for j in range(17):
        if i != j:  # Skip diagonal entries since they're NaN
            y_values = average_validation_mse[i, j, :]
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=y_values,
                    mode="lines",
                    name=f"{model_labels[j]}",  # Only show target model in name
                    legendgroup=legendgroup,  # Group by source model
                    legendgrouptitle_text=f"From: {legendgroup}",  # Add group title
                    showlegend=True,  # Show all traces in legend
                    line=dict(width=1),
                    opacity=0.5,
                )
            )

# Update layout with smaller legend text and grouping
fig.update_layout(
    title="MSE vs Number of Layers for All Model Pairs",
    xaxis_title="Number of Layers",
    yaxis_title="MSE",
    width=1000,
    height=600,
    showlegend=True,
    yaxis_type="log",
    legend=dict(font=dict(size=7), groupclick="toggleitem"),
)

# Add gridlines
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")

# Show figure
fig

In [81]:
# Save figure
save_figure(
    fig,
    "mse_vs_layers_all_pairs",
    output_dir=PROJ_ROOT / "data" / "figs",
)
save_figure(
    fig,
    "mse_vs_layers_all_pairs",
    output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
)

[2;36m[12/13/24 20:37:17][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m Saving HTML figure to            ]8;id=526655;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py\[1;36msave_figure.py[0m]8;;\[1;36m:[0m]8;id=624328;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py#32\[1;36m32[0m]8;;\
[2;36m                    [0m[1;32m         [0m[35m/Users/gat/work/FA2024/embedding_[0m [1;36m                 [0m
[2;36m                    [0m[1;32m         [0m[35mtranslation/data/figs/html/[0m[95mmse_vs[0m [1;36m                 [0m
[2;36m                    [0m[1;32m         [0m[95m_layers_all_pairs.html[0m            [1;36m                 [0m
[2;36m                   [0m[2;36m [0m[1;32mSUCCESS [0m[1;32m [0m Successfully saved HTML to       ]8;id=730259;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py\[1;36msave_figure.py[0m]8;;\[1;36m:[0m]8;id=550151;file:///User

(PosixPath('/Users/gat/work/FA2024/embedding_translation/src/blog/figs/gatlen/html/mse_vs_layers_all_pairs.html'),
 PosixPath('/Users/gat/work/FA2024/embedding_translation/src/blog/figs/gatlen/imgs/mse_vs_layers_all_pairs.png'))

In [74]:
# Create line plot showing average MSE vs number of layers
fig = go.Figure()

# Calculate mean MSE across all model pairs (excluding diagonal)
mean_mse = []
for layer_idx in range(6):  # For each layer configuration
    layer_mse = average_validation_mse[:, :, layer_idx]  # Get MSE matrix for this layer
    # Create mask to exclude diagonal elements
    mask = ~np.eye(17, dtype=bool)
    # Calculate mean excluding diagonal elements
    mean_mse.append(layer_mse[mask].mean())

# Add trace for mean MSE
fig.add_trace(
    go.Scatter(
        x=x_values,
        y=mean_mse,
        mode="lines+markers",
        name="Mean MSE across all model pairs",
        line=dict(width=3, color="blue"),
        marker=dict(size=10),
    )
)

# Update layout
fig.update_layout(
    title="Mean MSE vs Number of Layers (Averaged Across All Model Pairs)",
    xaxis_title="Number of Layers",
    yaxis_title="Mean MSE",
    width=800,
    height=500,
    showlegend=False,
    yaxis_type="log",  # Use log scale for y-axis
    template="plotly_white",  # Clean template
)

# Add gridlines
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")

# Show figure
fig

In [75]:
# Save figure
save_figure(
    fig,
    "mean_mse_vs_layers",
    output_dir=PROJ_ROOT / "data" / "figs",
)
save_figure(
    fig,
    "mean_mse_vs_layers",
    output_dir=PROJ_ROOT / "src" / "blog" / "figs" / "gatlen",
)

[2;36m[12/13/24 20:25:08][0m[2;36m [0m[1;34mINFO    [0m[1;32m [0m Saving HTML figure to            ]8;id=266039;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py\[1;36msave_figure.py[0m]8;;\[1;36m:[0m]8;id=967827;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py#32\[1;36m32[0m]8;;\
[2;36m                    [0m[1;32m         [0m[35m/Users/gat/work/FA2024/embedding_[0m [1;36m                 [0m
[2;36m                    [0m[1;32m         [0m[35mtranslation/data/figs/html/[0m[95mmean_m[0m [1;36m                 [0m
[2;36m                    [0m[1;32m         [0m[95mse_vs_layers.html[0m                 [1;36m                 [0m
[2;36m                   [0m[2;36m [0m[1;32mSUCCESS [0m[1;32m [0m Successfully saved HTML to       ]8;id=769225;file:///Users/gat/work/FA2024/embedding_translation/src/viz/save_figure.py\[1;36msave_figure.py[0m]8;;\[1;36m:[0m]8;id=185062;file:///User

(PosixPath('/Users/gat/work/FA2024/embedding_translation/src/blog/figs/gatlen/html/mean_mse_vs_layers.html'),
 PosixPath('/Users/gat/work/FA2024/embedding_translation/src/blog/figs/gatlen/imgs/mean_mse_vs_layers.png'))