# Model Emphasis and Analysis Demo

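This notebook demonstrates the key features of the model emphasis, analysis, and generation capabilities. It begins by loading a Mistral-7B-Instruct model, collecting the last-token hidden state of day-of-week prompts at every layer, projecting those states to three dimensions with PCA, and plotting the projections as static and animated 3D scatter plots.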

In [1]:
import mlx.core as mx
import mlx.nn as nn

from mi_experiments.utils.loading import load

import re
from dataclasses import dataclass
from typing import Dict, Optional, Union

from mlx_lm.tokenizer_utils import TokenizerWrapper

import pandas as pd
import numpy as np
from sklearn.decomposition import PCA

import plotly.graph_objects as go

from mi_experiments.core.cache import BatchedKVCache

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Load the model and its tokenizer with the project's `load` helper.
# The repo id names a 4-bit Mistral-7B-Instruct v0.3 checkpoint.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Utility functions
def create_additive_causal_mask(N: int, offset: int = 0):
    """Build an additive attention mask that hides future positions.

    Rows correspond to the ``N`` query positions ``offset .. offset + N - 1``
    and columns to all ``offset + N`` key positions. An entry is ``-1e9`` when
    the key position lies after the query position and zero otherwise.
    """
    total = offset + N
    key_positions = mx.arange(total)
    if offset:
        query_positions = mx.arange(offset, total)
    else:
        # With no offset the queries cover exactly the same positions as the keys.
        query_positions = key_positions
    is_future = query_positions[:, None] < key_positions[None]
    return is_future * -1e9

@dataclass
class ModelArgs:
    """Architecture hyper-parameters captured from a loaded model.

    The first seven fields are required; the rest carry defaults that are used
    whenever the source model does not provide a value.
    """
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    @classmethod
    def from_model(cls, model: nn.Module):
        """Create a ``ModelArgs`` snapshot from a loaded model.

        Args:
            model: Object exposing ``model_type``, ``head_dim``, ``n_kv_heads``
                and an ``args`` namespace holding the remaining fields.

        Returns:
            A ``ModelArgs`` instance. Optional settings (RoPE parameters, bias
            flags, ``max_position_embeddings``) are copied from ``model.args``
            when present and not ``None``; otherwise the dataclass default is
            kept.

        Raises:
            AttributeError: If a required attribute is missing on ``model`` or
                ``model.args``.
        """
        cfg = model.args
        kwargs = dict(
            model_type=model.model_type,
            hidden_size=cfg.hidden_size,
            num_hidden_layers=cfg.num_hidden_layers,
            intermediate_size=cfg.intermediate_size,
            num_attention_heads=cfg.num_attention_heads,
            rms_norm_eps=cfg.rms_norm_eps,
            vocab_size=cfg.vocab_size,
            head_dim=model.head_dim,
            num_key_value_heads=model.n_kv_heads,
            tie_word_embeddings=cfg.tie_word_embeddings,
        )
        # Without this pass the snapshot would silently report the dataclass
        # defaults (e.g. rope_theta=10000) even when the model uses other values.
        optional_fields = (
            "max_position_embeddings",
            "attention_bias",
            "mlp_bias",
            "rope_theta",
            "rope_traditional",
            "rope_scaling",
        )
        for field_name in optional_fields:
            value = getattr(cfg, field_name, None)
            if value is not None:
                kwargs[field_name] = value
        return cls(**kwargs)

# Snapshot the loaded model's hyper-parameters into a ModelArgs instance
model_args = ModelArgs.from_model(model)

# Wrap the tokenizer in TokenizerWrapper unless it already is one
tokenizer = (
    tokenizer
    if isinstance(tokenizer, TokenizerWrapper)
    else TokenizerWrapper(tokenizer)
)

def format_prompt(prompt: str) -> str:
    """Wrap ``prompt`` in ``[INST]`` / ``[/INST]`` instruction markers."""
    wrapped = f"[INST] {prompt} [/INST]"
    return wrapped

def extract_generated_text(full_text: str) -> str:
    """Return the text following the first ``[/INST]`` marker, stripped.

    An empty string is returned when ``full_text`` contains no ``[/INST]``.
    """
    found = re.search(r'\[/INST\]\s*(.*)', full_text, re.DOTALL)
    if not found:
        return ""
    # DOTALL lets the capture span newlines, so multi-line generations are kept.
    return found.group(1).strip()

Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 97218.97it/s]


In [3]:
def _last_token_positions(tokenized, seq_len, batch_size):
    """Return, per batch row, the column index of its last non-padding token.

    Uses ``tokenized['attention_mask']`` when the tokenizer output provides one,
    which makes the result independent of the padding side. Without a mask the
    final column is used for every row.
    """
    if 'attention_mask' not in tokenized:
        return np.full(batch_size, seq_len - 1, dtype=int)
    mask = np.asarray(tokenized['attention_mask'])
    # argmax over the reversed mask is the distance of the last 1 from the right edge.
    return seq_len - 1 - np.argmax(mask[:, ::-1], axis=1)


def _standardized_pca(layer_output, last_positions):
    """Standardize the selected hidden states and project them onto 3 PCA axes.

    Returns ``(projected, explained_variance_ratio)``.
    """
    # .tolist() goes through Python floats before building the NumPy array.
    rows = [
        layer_output[i, int(pos), :].tolist()
        for i, pos in enumerate(last_positions)
    ]
    hidden_states = np.array(rows)
    # Per-feature z-score across the batch; the epsilon guards constant features.
    hidden_states = (hidden_states - np.mean(hidden_states, axis=0)) / (np.std(hidden_states, axis=0) + 1e-8)
    pca = PCA(n_components=3, svd_solver='full')
    projected = pca.fit_transform(hidden_states)
    return projected, pca.explained_variance_ratio_


def get_pca_values_df_aligned_batched(model, tokenizer, entries=None, prompt_template="The day is {entry}"):
    """Project each layer's last-token hidden states to 3D with PCA.

    For every layer, the hidden state at the last real token of each prompt is
    standardized across the batch and reduced with a 3-component PCA. The sign
    of each component is flipped where needed so it correlates positively with
    the same component of the last layer's projection.

    Args:
        model: Model exposing ``layers``, ``n_kv_heads``, ``head_dim`` and
            ``get_layer_output(layer_num, tokens, cache=...)``.
        tokenizer: Wrapper whose ``_tokenizer`` attribute is callable on a list
            of prompts and provides ``decode``.
        entries: Strings substituted into ``prompt_template``. Defaults to the
            seven "<weekday> is" phrases. At least three are required because
            three PCA components are fitted.
        prompt_template: ``str.format`` template with an ``{entry}`` field.

    Returns:
        ``pandas.DataFrame`` with one row per (layer, entry) and columns
        ``Layer``, ``Entry``, ``Token``, ``PCA1``-``PCA3``,
        ``PCA1_var``-``PCA3_var`` (percent) and ``Cumulative_var`` (percent).

    Raises:
        ValueError: If fewer than three entries are supplied.

    Side effects:
        Sets ``tokenizer.pad_token`` / ``pad_token_id`` to the EOS token when no
        pad token is configured.
    """
    # Set padding token if not already set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if entries is None:
        entries = ["Monday is", "Tuesday is", "Wednesday is", "Thursday is", "Friday is", "Saturday is", "Sunday is"]
    entries = list(entries)
    if len(entries) < 3:
        raise ValueError("at least 3 entries are required to fit a 3-component PCA")

    prompts = [prompt_template.format(entry=entry) for entry in entries]

    # Tokenize all prompts at once
    tokenized = tokenizer._tokenizer(prompts, padding=True, return_tensors="np")
    input_ids = tokenized['input_ids']
    tokens = mx.array(input_ids)

    num_layers = len(model.layers)
    batch_size = len(entries)

    # Locate each prompt's last real token rather than assuming column -1, so a
    # right-padding tokenizer does not make us read pad positions.
    # NOTE(review): no attention mask is passed to get_layer_output, so pad
    # tokens may still influence shorter prompts — confirm how it handles padding.
    last_positions = _last_token_positions(tokenized, input_ids.shape[1], batch_size)

    token_values = [
        tokenizer._tokenizer.decode(input_ids[i][last_positions[i]])
        for i in range(batch_size)
    ]

    # n_kv_heads may be a single int or a per-layer sequence.
    kv_heads = ([model.n_kv_heads] * num_layers
                if isinstance(model.n_kv_heads, int)
                else model.n_kv_heads)

    def project_layer(layer_num):
        # A fresh cache per call keeps the passes independent of each other.
        cache = [BatchedKVCache(model.head_dim, n, batch_size) for n in kv_heads]
        layer_output = model.get_layer_output(layer_num, tokens, cache=cache)
        return _standardized_pca(layer_output, last_positions)

    # The last layer is the alignment reference; it is computed once and reused
    # for its own rows below instead of being run a second time.
    last_layer = num_layers - 1
    reference_proj, reference_variance = project_layer(last_layer)

    pca_values = []
    for layer_num in range(num_layers):
        if layer_num == last_layer:
            projected, explained_variance = reference_proj, reference_variance
        else:
            projected, explained_variance = project_layer(layer_num)
            # PCA axes have arbitrary sign; flip any that anti-correlate with
            # the reference so trajectories across layers stay comparable.
            for i in range(3):
                corr = np.corrcoef(reference_proj[:, i], projected[:, i])[0, 1]
                if corr < 0:
                    projected[:, i] *= -1

        cumulative_variance = np.cumsum(explained_variance)

        for i, entry in enumerate(entries):
            pca_values.append({
                'Layer': layer_num,
                'Entry': entry,
                'Token': token_values[i],
                'PCA1': round(projected[i, 0], 4),
                'PCA2': round(projected[i, 1], 4),
                'PCA3': round(projected[i, 2], 4),
                'PCA1_var': round(explained_variance[0] * 100, 2),
                'PCA2_var': round(explained_variance[1] * 100, 2),
                'PCA3_var': round(explained_variance[2] * 100, 2),
                'Cumulative_var': round(cumulative_variance[2] * 100, 2)
            })

    return pd.DataFrame(pca_values)

In [4]:

# Build the per-layer PCA table (one row per layer and entry) for the
# day-of-week prompts.
df_aligned = get_pca_values_df_aligned_batched(model, tokenizer)
print("\nFull DataFrame:")
# Show the table through a Styler with max-height 400px and vertical scrolling.
display(df_aligned.style.set_table_styles([{'selector': '', 'props': [('max-height', '400px'), ('overflow-y', 'scroll'), ('display', 'block')]}]))


Full DataFrame:


Unnamed: 0,Layer,Entry,Token,PCA1,PCA2,PCA3,PCA1_var,PCA2_var,PCA3_var,Cumulative_var
0,0,Monday is,is,-19.0723,-48.1088,-27.6753,31.77,18.32,17.1,67.19
1,0,Tuesday is,is,-32.2956,5.5349,16.7518,31.77,18.32,17.1,67.19
2,0,Wednesday is,is,-34.9434,2.4121,27.2842,31.77,18.32,17.1,67.19
3,0,Thursday is,is,-23.1915,10.9835,11.6487,31.77,18.32,17.1,67.19
4,0,Friday is,is,7.0455,14.3527,-48.0048,31.77,18.32,17.1,67.19
5,0,Saturday is,is,33.2801,42.5242,-5.341,31.77,18.32,17.1,67.19
6,0,Sunday is,is,69.1773,-27.6986,25.3364,31.77,18.32,17.1,67.19
7,1,Monday is,is,-25.4688,-39.4841,-28.0822,28.04,20.08,17.41,65.54
8,1,Tuesday is,is,-25.6166,-5.838,20.5783,28.04,20.08,17.41,65.54
9,1,Wednesday is,is,-26.429,-5.617,17.8833,28.04,20.08,17.41,65.54


In [5]:
def plot_pca_from_df(df):
    """Draw every (layer, entry) PCA point in a single static 3D scatter plot.

    Args:
        df: DataFrame with columns ``Layer``, ``Entry``, ``PCA1``, ``PCA2``,
            ``PCA3`` and ``Cumulative_var`` (percent).

    Each entry gets its own colour and legend group. All three axes share one
    symmetric range derived from the largest absolute coordinate. The figure is
    shown with ``fig.show()``; nothing is returned.
    """
    # Use the layer ids actually present rather than assuming 0..n-1, so a
    # filtered or re-indexed table still plots.
    layers = sorted(df['Layer'].unique())
    entries = df['Entry'].unique()

    # One evenly spaced hue per entry.
    colors = [f'hsl({h},70%,50%)' for h in np.linspace(0, 360, len(entries), endpoint=False)]

    fig = go.Figure()

    # Find the max absolute value across all PCA dimensions to set balanced axis ranges
    max_abs_val = max(
        abs(df['PCA1']).max(),
        abs(df['PCA2']).max(),
        abs(df['PCA3']).max()
    )
    axis_range = [-max_abs_val, max_abs_val]

    # Entries that already have a legend item, so each entry is listed once
    # even if it is absent from the first layer.
    in_legend = set()

    for layer_num in layers:
        layer_data = df[df['Layer'] == layer_num]

        for entry, color in zip(entries, colors):
            entry_data = layer_data[layer_data['Entry'] == entry]
            if entry_data.empty:
                # No row for this (layer, entry) pair; nothing to draw.
                continue
            # Opacity spans 0.2-0.9 and decreases as cumulative variance rises.
            # NOTE(review): confirm that fading high-variance layers is intended.
            opacity = 0.2 + (0.7 * (1 - entry_data['Cumulative_var'].iloc[0]/100))
            fig.add_trace(
                go.Scatter3d(
                    x=[entry_data['PCA1'].iloc[0]],
                    y=[entry_data['PCA2'].iloc[0]],
                    z=[entry_data['PCA3'].iloc[0]],
                    mode='markers',
                    marker=dict(color=color, size=5, opacity=opacity),
                    name=entry,  # Use entry as the name for legend grouping
                    legendgroup=entry,  # Group by entry for filtering
                    showlegend=entry not in in_legend
                )
            )
            in_legend.add(entry)

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text="3D PCA Projection of Day Representations Across All Layers",
        scene=dict(
            xaxis=dict(range=axis_range),
            yaxis=dict(range=axis_range),
            zaxis=dict(range=axis_range),
            xaxis_title="PCA1",
            yaxis_title="PCA2",
            zaxis_title="PCA3",
            aspectmode='cube',
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        showlegend=True,
    )

    fig.show()

# Render the static 3D scatter of all layers for the PCA table
plot_pca_from_df(df_aligned)


In [6]:
def plot_pca_from_df(df):
    """Show the PCA points one layer at a time as an animated 3D scatter plot.

    NOTE: this reuses the name of the static plotting function defined in the
    previous cell, so running this cell replaces that definition.

    Args:
        df: DataFrame with columns ``Layer``, ``Entry``, ``PCA1``, ``PCA2``
            and ``PCA3``.

    Raises:
        ValueError: If ``df`` contains no rows.

    One frame is built per layer, with a "Play" button and a layer slider.
    Axis ranges are fixed to the min/max over all layers. The figure is shown
    with ``fig.show()``; nothing is returned.
    """
    # Use the layer ids actually present rather than assuming 0..n-1, so a
    # filtered or re-indexed table still plots.
    layers = sorted(df['Layer'].unique())
    if not layers:
        raise ValueError("df has no rows to plot")
    entries = df['Entry'].unique()

    # One evenly spaced hue per entry.
    colors = [f'hsl({h},70%,50%)' for h in np.linspace(0, 360, len(entries), endpoint=False)]

    fig = go.Figure()

    # Calculate axis ranges across all layers
    x_min, x_max = df['PCA1'].min(), df['PCA1'].max()
    y_min, y_max = df['PCA2'].min(), df['PCA2'].max()
    z_min, z_max = df['PCA3'].min(), df['PCA3'].max()

    # Create frames for each layer
    frames = []
    for layer_num in layers:
        layer_data = df[df['Layer'] == layer_num]
        frame_traces = []

        for entry, color in zip(entries, colors):
            entry_data = layer_data[layer_data['Entry'] == entry]
            # .iloc[:1] yields an empty list when the pair is missing, which
            # keeps one trace per entry in every frame, in the same order.
            frame_traces.append(
                go.Scatter3d(
                    x=entry_data['PCA1'].iloc[:1].tolist(),
                    y=entry_data['PCA2'].iloc[:1].tolist(),
                    z=entry_data['PCA3'].iloc[:1].tolist(),
                    mode='markers',
                    marker=dict(color=color, size=5, opacity=0.7),
                    name=f"{entry} - Layer {layer_num}",
                    showlegend=True
                )
            )
        frames.append(go.Frame(data=frame_traces, name=str(layer_num)))

    # Add frames to figure
    fig.frames = frames

    # Add first frame's traces to the figure
    for trace in frames[0].data:
        fig.add_trace(trace)

    # Update layout
    fig.update_layout(
        height=800,
        width=800,
        title_text="3D PCA Projection of Day Representations Across All Layers",
        scene=dict(
            xaxis=dict(range=[x_min, x_max], title="PCA1"),
            yaxis=dict(range=[y_min, y_max], title="PCA2"),
            zaxis=dict(range=[z_min, z_max], title="PCA3"),
            aspectmode='cube',
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        showlegend=True,
        updatemenus=[{
            'type': 'buttons',
            'showactive': False,
            'buttons': [{
                'label': 'Play',
                'method': 'animate',
                'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True}]
            }]
        }],
        sliders=[{
            'currentvalue': {'prefix': 'Layer: '},
            'steps': [
                {
                    'method': 'animate',
                    'label': str(layer_num),
                    # The step targets the frame whose name matches this layer id.
                    'args': [[str(layer_num)], {
                        'frame': {'duration': 0, 'redraw': True},
                        'mode': 'immediate',
                        'transition': {'duration': 0}
                    }]
                }
                for layer_num in layers
            ]
        }]
    )

    fig.show()

# Render the per-layer animated view. The name resolves to the definition in
# this cell, which replaces the earlier function of the same name.
plot_pca_from_df(df_aligned)
