In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes, so edits to imported source files take effect
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import json
import os

from shrugger.src.model import load_model

# Location of config.json. NOTE: this is a *relative* path, resolved against the
# kernel's current working directory, so the cell only works when that directory
# sits two levels below the folder containing config.json.
config_path = "../../config.json"

# JSON is UTF-8 by specification; pass the encoding explicitly so the platform's
# locale default is never used to decode the file.
with open(config_path, encoding="utf-8") as f:
    cfg = json.load(f)
model_id = cfg["model_id"]  # identifier handed to load_model
dtype = cfg["dtype"]  # forwarded to load_model's `dtype` keyword

# Tokenizer/model pair that the following cells reuse.
tokenizer, model = load_model(model_id, dtype=dtype)


In [None]:
from shrugger.src.eval_utils import run_likert_probe

# One-off probe of a single question with the tokenizer/model loaded earlier.
# The call stays the final expression so the notebook renders its return value.
probe_question = "Is parity conserved in weak nuclear interactions?"
run_likert_probe(tokenizer, model, probe_question)

In [None]:
# Import only the helper this cell uses. A wildcard import hides where names
# come from and can silently shadow names defined by earlier cells.
from shrugger.src.experiment_utils import get_questions_by_filter

# Query with subject="Physics"; left as the final expression so the notebook
# displays whatever the helper returns.
get_questions_by_filter(subject="Physics")


In [None]:
from shrugger.src.experiment import load_experiment_results, run_likert_experiment
from shrugger.src.model import load_model

# NOTE(review): this rebinds `tokenizer`/`model` using a hard-coded model ID and
# load_model's default dtype, instead of the values read from config.json in the
# earlier cell -- confirm the second load is intentional.
tokenizer, model = load_model("google/gemma-3-4b-it")

# Settings for this run: Physics questions, difficulty 1, "V0_numbers" form.
likert_run_options = {
    "output_file": "results/test_nums.json",
    "subject": "Physics",
    "difficulty": 1,
    "form": "V0_numbers",
}
stats = run_likert_experiment(model=model, tokenizer=tokenizer, **likert_run_options)

# Report the summary, or explain why there is nothing to report.
if not stats["completed_questions"] > 0:
    print("No new questions processed. All questions were already in the results file.")
else:
    print(f"Completed {stats['completed_questions']} questions")
    print(f"Average score: {stats['average_score']:.2f}")
    print(f"Valid predictions: {stats['valid_predictions']}")

In [None]:

# Read a saved results file back in and print its score distribution.
# NOTE(review): the experiment cell writes to "results/test_nums.json", while
# this cell reads "../../results/test_likert.json" -- confirm this is the file
# you mean to analyse.
saved_results_path = "../../results/test_likert.json"
results = load_experiment_results(saved_results_path)
print(f"Score distribution: {results['score_distribution']}")

In [None]:
# Likert probe using caller-supplied answer labels.
custom_labels = ["X", "Y", "Z", "W", "V"]
results = run_likert_probe(
    tokenizer,
    model,
    "Is the Earth round?",
    labels=custom_labels,
)

# Prompt expected with these labels (author's illustration):
# Question: Is the Earth round?
# X) Definitely yes
# Y) Probably yes
# Z) Not sure
# W) Probably no
# V) Definitely no
# Respond with exactly one of: X, Y, Z, W, V.
# Answer:

# Author's note: a model output of "X" is converted to "YY" in the results, so
#   pred_label      -> "YY" (universal format)
#   canonical_label -> "X"  (original model output)
for result_field in ("pred_label", "canonical_label"):
    print(results[result_field])

In [None]:
from shrugger.src.probes import analyze_residual_stream, get_all_residual_streams

# Collect the residual streams for one prompt.
residual_prompt = "Is the meaning of life the answer to the ultimate question of life, the universe, and everything?"
all_streams = get_all_residual_streams(model, tokenizer, residual_prompt)

# Author's note on the expected layout: [num_layers, seq_len, hidden_dim]
print(f"Shape: {all_streams.shape}")

# Slice out the final position along the second axis for every entry of the first.
last_token = all_streams[:, -1, :]

# PCA reduction of the streams, for the visualisations that follow.
reduced = analyze_residual_stream(all_streams, reduction="pca")

In [None]:
# Explicit imports instead of `import *`: the origin of every plotting helper is
# visible, and no unrelated names leak into the notebook namespace.
from shrugger.src.residual_viz import (
    plot_heatmap_LT,
    plot_layer_scatter_from_pca,
    plot_layer_series_from_LT,
    plot_token_contributions_from_LT,
    plot_token_path_from_pca,
    plot_token_series_from_LT,
)

# 1) Norm heatmap
norm_LT = analyze_residual_stream(all_streams, reduction="norm")
plot_heatmap_LT(norm_LT, title="Residual L2 norm per (layer, token)", transform="log")

# 2) Token path in PCA plane
pca_L_T_3 = analyze_residual_stream(all_streams, reduction="pca")
plot_token_path_from_pca(pca_L_T_3, token_indices=[-1, -2])

# 3) Layer scatter in PCA plane (every second token index along axis 1)
plot_layer_scatter_from_pca(pca_L_T_3, layer_index=10, token_indices=range(0, pca_L_T_3.shape[1], 2))

# 4) Per-token series across layers
plot_token_series_from_LT(norm_LT, token_indices=[-1, -2], ylabel="L2 norm")

# 5) Per-layer series across tokens
plot_layer_series_from_LT(norm_LT, layer_indices=[10, 20, 30], ylabel="L2 norm")

# 6) Per-layer contributions for a token (bar chart)
plot_token_contributions_from_LT(norm_LT, token_indices=[-1, -2], ylabel="L2 norm")

# # 7) Δ heatmap (compare two prompts)
# # (import plot_delta_heatmap_LT from shrugger.src.residual_viz before enabling)
# norm_A = analyze_residual_stream(residual_stream_A, reduction="norm")
# norm_B = analyze_residual_stream(residual_stream_B, reduction="norm")
# plot_delta_heatmap_LT(norm_A, norm_B, title="Δ L2 norm (B − A)")

# 8) If you started with reduction="none"
# (import summarize_to_norm first -- presumably from shrugger.src.residual_viz; verify)
# rs_none = analyze_residual_stream(all_streams, reduction="none")
# norm_from_none = summarize_to_norm(rs_none)
# plot_heatmap_LT(norm_from_none, title="Norm (from raw)")


In [None]:
from shrugger.src.residual_viz import plot_token_path_from_pca

# PCA-reduce the residual streams captured in an earlier cell.
# (The variable name suggests 3 components -- TODO confirm.)
pca_L_T_3 = analyze_residual_stream(all_streams, reduction="pca")

# Draw the layer-by-layer path of the last two tokens in PCA space.
token_path_options = dict(
    token_indices=[-1, -2],
    annotate_layers=True,
    title="Token Path Through Layers in PCA Space",
    show_arrows=True,
    arrow_spacing=3,
)
fig = plot_token_path_from_pca(pca_L_T_3, **token_path_options)


In [None]:
from shrugger.src.residual_viz import plot_layer_scatter_from_pca

# Layer to visualise, held in one variable so the figure title can never drift
# away from the layer actually plotted (the title previously read "Layer 10"
# while layer 33 was drawn).
# NOTE(review): this was annotated "Middle layer" -- confirm against the model depth.
scatter_layer_index = 33

# Scatter the PCA-reduced token representations at the chosen layer.
fig = plot_layer_scatter_from_pca(
    pca_L_T_3,
    layer_index=scatter_layer_index,
    token_indices=slice(None),  # All tokens
    title=f"Token Representations at Layer {scatter_layer_index}",
    color_by="index",  # Color by token index
    label_tokens=True,
    add_convex_hull=True,
    add_centroid=True,
)


In [None]:
from shrugger.src.residual_viz import plot_token_series_from_LT

# Per-token norm series across layers, built from the `norm_LT` summary computed
# in an earlier cell, with every optional overlay turned on.
token_series_overlays = dict(
    show_legend=True,
    show_stats=True,
    show_min_max=True,
    show_average=True,
    show_trend=True,
    log_scale=True,
)
fig = plot_token_series_from_LT(
    norm_LT,
    token_indices=[-1, -2],  # last two tokens
    title="Token Norm Across Layers",
    ylabel="L2 Norm",
    **token_series_overlays,
)


In [None]:
from shrugger.src.residual_viz import plot_layer_series_from_LT

# Per-layer norm series across tokens, built from the `norm_LT` summary computed
# in an earlier cell.
layers_to_plot = [5, 15, 25]  # author's pick of early / middle / late layers
tokens_to_highlight = [-1]  # the last token

fig = plot_layer_series_from_LT(
    norm_LT,
    layer_indices=layers_to_plot,
    title="Layer Norm Across Tokens",
    ylabel="L2 Norm",
    show_legend=True,
    show_stats=True,
    log_scale=True,
    highlight_tokens=tokens_to_highlight,
)


In [None]:
from shrugger.src.residual_viz import plot_token_contributions_from_LT

# Layer-wise contribution chart for the last token, built from the `norm_LT`
# summary computed in an earlier cell.
contribution_options = {
    "title": "Layer Contributions for Last Token",
    "ylabel": "L2 Norm",
    "show_values": True,
    "sort_by_value": True,  # order layers by contribution
    "normalize": True,  # show as a share of the total
    "show_stats": True,
    "highlight_layers": [10, 20, 30],  # layers to emphasise
}
fig = plot_token_contributions_from_LT(
    norm_LT,
    token_indices=-1,  # last token
    **contribution_options,
)


In [None]:
from shrugger.src.probes import analyze_residual_stream, get_all_residual_streams
from shrugger.src.residual_viz import plot_delta_heatmap_LT

# Two prompts that differ only in the verb ("is" vs "are").
prompt_a = "What is machine learning?"
prompt_b = "What are machine learning?"

# Residual streams for each prompt.
streams_a = get_all_residual_streams(model, tokenizer, prompt_a)
streams_b = get_all_residual_streams(model, tokenizer, prompt_b)

# Layer-token summaries using the "norm" reduction.
norm_a = analyze_residual_stream(streams_a, reduction="norm")
norm_b = analyze_residual_stream(streams_b, reduction="norm")

# Plot the difference between the two prompts. The condition names are derived
# from the prompts themselves: the previous hard-coded "ML Question" /
# "DL Question" labels did not describe what is actually being compared.
fig = plot_delta_heatmap_LT(
    norm_a,
    norm_b,
    title="Δ L2 Norm (B − A)",
    condition_a_name=f"A: {prompt_a}",
    condition_b_name=f"B: {prompt_b}",
    show_colorbar=True,
    center_colormap=True,
    show_side_plots=True,
    show_stats=True,
)


In [None]:
from shrugger.src.experiment import run_combined_experiment
from shrugger.src.model import load_model

# NOTE(review): this is another hard-coded load that rebinds `tokenizer`/`model`
# rather than reusing the pair loaded from config.json -- confirm it is needed.
tokenizer, model = load_model("google/gemma-3-4b-it")

# Options for the combined run.
combined_run_options = {
    "output_dir": "./results/my_experiment",
    "subject": "Physics",  # optional filter
    "difficulty": 5,
    "split": "dev",
    "form": "V2_letters",  # one of the Likert description styles
    "labels": ["A", "B", "C", "D", "E"],  # optional custom labels
    "verbose": True,
}
results = run_combined_experiment(model=model, tokenizer=tokenizer, **combined_run_options)

# Summary values read from the returned statistics.
print(f"Processed {results['likert_stats']['completed_questions']} questions")
print(f"Hidden state shape: {results['hidden_stats']['hidden_state_shape']}")