Imports

In [None]:
import os
import warnings
from typing import Callable
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
from src.utils import get_project_root, load, save
from src.lsdata.LatentStates import LSData
from src.analysis.visualize import vis_2d_dim_reduct, stitch_images_grid
from src.lsdata.metadata import StandardTransformerMeta
from src.analysis.visualize import plot_vector_norm_by_layer, plot_vector_norm_by_axis

Arguments

In [None]:
""" Data Args """

# Exactly one (data_name, sequence_selection) pair below should be active.
# data_name is the sub-directory of "<project root>/processed_data" that is loaded;
# sequence_selection is forwarded to LSData unchanged.
# The GPT-2 / PG-19 pair is the one currently enabled.

### Uncomment for GPT-2 on PG-19 (excluding initial token)
data_name = "gpt2_latents-text-pg19-128_samples-1024_sequence_length-identity"
sequence_selection = slice(1, None)

### Uncomment for LLaMa on PG-19 (excluding initial token)
# data_name = "huggyllama-llama-7b_latents-text-pg19-64_samples-2048_sequence_length-identity"
# sequence_selection = slice(1, None)

### Uncomment for GPT-2 Singular
# data_name = "gpt2_latents-singular-full_samples-None_sequence_length-identity"
# sequence_selection = None

### Uncomment for LLaMa Singular
# data_name = "huggyllama-llama-7b_latents-singular-full_samples-None_sequence_length-identity"
# sequence_selection = None


# Predicate over layer metadata: keeps only entries whose `pre_add` is falsy.
# Passed to LSData as `layer_filter`.
layer_filter = lambda x : not x.pre_add
# Passed to LSData as `sample_selection`; None presumably means "all samples" - confirm against LSData.
sample_selection = None


""" Visualization Args """

# Figure size and resolution forwarded to the plotting helper
# (presumably inches and dots-per-inch, as in matplotlib - confirm in src.analysis.visualize).
figsize = (6.5, 1)
dpi = 300

# Font sizes forwarded to the plotting helper.
fontsize: float | int | str = 10
labelsize: float | int | str = 10
# Margins forwarded to the plotting helper as `subplots_adjust_kwargs`
# (presumably handed to matplotlib's subplots_adjust - confirm). May be None.
subplots_adjust_kwargs: dict[str, float | None] | None = {
    'left': 0.1,
    'right': 0.98,
    'top': 0.98,
    'bottom': 0.21
}
# Function to format y-axis labels. If None, then y-axis labels will not be formatted.
# The active formatter renders each tick value with no decimals, right-aligned to width 4,
# and swaps the ASCII hyphen for the Unicode minus sign (U+2212).
# yaxis_formatter = None
yaxis_formatter: plt.FuncFormatter | str | Callable | None = plt.FuncFormatter(lambda x, p: f'{x:>4.0f}'.replace('-', '\u2212'))

# Forwarded to the plotting helper as `show_image`.
show_image: bool = False

# Sets global matplotlib fonts
plt.rcParams["font.family"] = 'Nimbus Roman'

""" Other Features """

# max workers for parallelized loading. If None defaults to os.cpu_count()
max_workers = 20

Load Data

In [None]:
# Locate the processed latents for the configured dataset under the project root.
data_dir_path = os.path.join(get_project_root(), "processed_data", data_name)

# Load the latent states, applying the layer / sample / sequence selections
# chosen in the arguments cell.
data = LSData(
    dir_path=data_dir_path,
    layer_filter=layer_filter,
    sample_selection=sample_selection,
    sequence_selection=sequence_selection,
    max_workers=max_workers,
)

# Report the shape of the loaded array as a quick sanity check.
print("Shape of data:\t", data.data.shape)

Colorization

In [None]:
def colorize(meta: "StandardTransformerMeta") -> None:
    """Assign a color key to ``meta.vis_color`` in place.

    Rules are evaluated in order and the first match wins:

    * ``'A'`` -- ``is_mlp`` and not ``is_norm``
    * ``'B'`` -- ``is_attn`` and not ``is_norm``
    * ``'C'`` -- ``is_embed``
    * ``'D'`` -- ``is_norm`` and ``is_mlp``
    * ``'E'`` -- ``is_norm`` and ``is_attn``
    * ``'F'`` -- ``is_norm`` and not ``pre_add``

    The letters are opaque keys here; presumably the plotting helpers map
    them to concrete colors (confirm in ``src.analysis.visualize``).

    Args:
        meta: Layer metadata exposing the boolean attributes ``is_mlp``,
            ``is_attn``, ``is_embed``, ``is_norm`` and ``pre_add``.

    Raises:
        ValueError: If none of the rules matches ``meta``. ``meta.vis_color``
            is left untouched in that case.
    """
    if meta.is_mlp and not meta.is_norm:
        color = 'A'
    elif meta.is_attn and not meta.is_norm:
        color = 'B'
    elif meta.is_embed:
        color = 'C'
    elif meta.is_norm and meta.is_mlp:
        color = 'D'
    elif meta.is_norm and meta.is_attn:
        color = 'E'
    elif meta.is_norm and not meta.pre_add:
        color = 'F'
    else:
        # Build one formatted message. Passing `meta` as a second positional
        # argument would make the exception print as a tuple repr.
        raise ValueError(
            f"Meta colorization was not defined under current rules! Meta:\n{meta}"
        )

    meta.vis_color = color

# Tag every loaded layer's metadata with its color key (mutates each meta in place).
for meta in data.metas:
    colorize(meta)

Norm Plots

In [None]:
# Shared formatting options, collected once so they can be splatted into the
# plotting helper below.
plot_formatting_kwargs = dict(
    figsize=figsize,
    dpi=dpi,
    fontsize=fontsize,
    labelsize=labelsize,
    return_image=True,
    show_image=show_image,
    subplots_adjust_kwargs=subplots_adjust_kwargs,
    yaxis_formatter=yaxis_formatter,
)

print("Mean Norm vs Layer:")

# Untitled plot with x ticks disabled; the helper is asked to return the
# rendered image (return_image=True) so it can be displayed inline.
img = plot_vector_norm_by_layer(data, title=None, disable_x_ticks=True, **plot_formatting_kwargs)

display(img)