In [None]:
import os
import torch

In [None]:
# Set the working directory to the gemma_scope_math project root.
# The location can be overridden through the GEMMA_SCOPE_MATH_ROOT environment
# variable so the notebook is not tied to one machine; when the variable is
# unset the original hard-coded path is used, so existing behaviour is kept.
project_root = os.environ.get("GEMMA_SCOPE_MATH_ROOT", "D:/Master's/gemma_scope_math")
os.chdir(project_root)
print(os.getcwd())

In [None]:
# import functions and code from my modules and test them. 
from src.feed_forward.feedforward_module import FeedForwardModule

In [None]:
# Dataset name; it is interpolated into the output directory below.
dataset = "random_addition"

# Instantiate FeedForwardModule with an output_dir under ./activations
# that is specific to the chosen dataset.
module = FeedForwardModule(
    batch_size=500,
    output_dir=f"./activations/{dataset}",
)



In [None]:
# Display the model object held by the FeedForwardModule instance.
module.model

In [None]:
# Tokenize one example prompt and print every token id next to its decoded text.
example_prompt = "What is 2 + 2?"
tokens = module.tokenizer(
    example_prompt,
    return_tensors="pt",
    padding=True,
    return_attention_mask=True,
)
input_ids = tokens.input_ids[0]  # first row of the returned ids

# Decode each id on its own so the id -> text pairing is visible.
decoded_pieces = [module.tokenizer.decode([token_id]) for token_id in input_ids]
for token_id, token_str in zip(input_ids, decoded_pieces):
    print(f"{token_id.item():<5} -> '{token_str}'")

In [None]:
# Move every tokenizer output tensor onto the device that holds the model's
# parameters instead of assuming a CUDA device is present; this also works
# when the model was loaded on the CPU.
model_device = next(module.model.parameters()).device
tokens = {k: v.to(model_device) for k, v in tokens.items()}

In [None]:
# Display the current contents of `tokens`.
tokens

In [None]:
# Run one forward pass with gradient tracking disabled, asking the model to
# return its hidden states in addition to the regular outputs.
with torch.no_grad():
    outputs = module.model(**tokens, output_hidden_states=True)

In [None]:
# Inspect which container type holds the returned hidden states.
type(outputs.hidden_states)

In [None]:
# Count how many hidden-state entries the forward pass returned.
len(outputs.hidden_states)

In [None]:
# Repeat the same forward pass into a separate variable so that the two runs
# can be compared element-wise.
with torch.no_grad():
    new_outputs = module.model(**tokens, output_hidden_states=True)

In [None]:
# True only when the hidden state at index 0 is element-wise identical
# between the two forward passes.
torch.equal(new_outputs.hidden_states[0], outputs.hidden_states[0])

In [None]:
# True only when the hidden state at index 15 is element-wise identical
# between the two forward passes.
torch.equal(new_outputs.hidden_states[15], outputs.hidden_states[15])

In [None]:
# Summarize the hidden states returned by the forward pass: how many there
# are, followed by the tensor shape of each one.
print(
    "Hidden states analysis:",
    f"Total number of hidden states: {len(outputs.hidden_states)}",
    "\nShape of each hidden state:",
    sep="\n",
)
for i, hidden_state in enumerate(outputs.hidden_states):
    print(f"Layer {i}: {hidden_state.shape}")

In [None]:
# Explain how the indices of `outputs.hidden_states` map onto the model.
# The number of transformer blocks is derived from the forward-pass result
# instead of being hard-coded, so the printed explanation stays consistent
# with whatever model is loaded.
num_blocks = len(outputs.hidden_states) - 1  # entry 0 is reported below as the embedding output

print("\nLayer correspondence:")
print("Layer 0: Embedding layer output (after token embeddings)")
for i in range(1, num_blocks + 1):
    print(f"Layer {i}: Transformer block {i-1} output (blocks are 0-indexed)")

print("\nSo you have:")
print("- Layer 0: Embeddings")
print(f"- Layers 1-{num_blocks}: Output from transformer blocks 0-{num_blocks - 1}")
print(f"- Total transformer blocks: {num_blocks} (as you mentioned)")
# NOTE(review): some Hugging Face decoder implementations append the last
# hidden state after the final norm has been applied — confirm for this model
# before relying on the statement printed below.
print("- The final norm and LM head are applied after these hidden states")

In [None]:
# Cross-check the layer count against the model's own configuration.
model_config = module.model.config
print("Model configuration:")
print(f"Number of hidden layers: {model_config.num_hidden_layers}")
print(f"Hidden size: {model_config.hidden_size}")
print(f"Vocab size: {model_config.vocab_size}")

# Report the class names of the main sub-modules and the number of layers
# actually present in the model.
inner_model = module.model.model
print(f"\nModel architecture verification:")
print(f"Embedding layer: {type(inner_model.embed_tokens).__name__}")
print(f"Number of transformer layers: {len(inner_model.layers)}")
print(f"Final norm layer: {type(inner_model.norm).__name__}")
print(f"LM head: {type(module.model.lm_head).__name__}")

In [None]:
# Display the model's full configuration object.
module.model.config

In [None]:
from src.feed_forward.sae_module import SAEModule

In [None]:
# Check available Gemma Scope SAEs
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory


def _layers_from_sae_ids(sae_ids):
    """Collect the layer identifiers embedded in an iterable of SAE ids.

    For every id containing ``'layer_'`` the text between that marker and the
    next ``'/'`` is extracted. It is stored as an ``int`` when it parses as
    one, otherwise as the raw string.

    Returns:
        set: the distinct layer identifiers (ints and/or strings).
    """
    layers = set()
    for sae_id in sae_ids:
        if 'layer_' not in sae_id:
            continue
        layer_token = sae_id.split('layer_')[1].split('/')[0]
        try:
            layers.add(int(layer_token))
        except ValueError:
            # Non-numeric layer label: keep it verbatim instead of dropping it.
            layers.add(layer_token)
    return layers


# Get the full directory of pretrained SAEs
sae_directory = get_pretrained_saes_directory()

# Filter for gemma-scope models
gemma_scope_saes = {k: v for k, v in sae_directory.items() if 'gemma-scope' in k.lower()}

print("Available Gemma Scope SAE releases:")
print("=" * 50)

for release_name, sae_info in gemma_scope_saes.items():
    print(f"\nRelease: {release_name}")
    print("-" * 30)

    # Releases without a (non-empty) saes_map have nothing further to report.
    if not (hasattr(sae_info, 'saes_map') and sae_info.saes_map):
        print("No SAE map information available")
        continue

    print(f"Available layers/SAEs: {len(sae_info.saes_map)}")

    # Group by layer to see pattern
    layers = _layers_from_sae_ids(sae_info.saes_map.keys())
    if layers:
        if all(isinstance(x, int) for x in layers):
            sorted_layers = sorted(layers)
            print(f"Layers: {min(sorted_layers)}-{max(sorted_layers)} (total: {len(sorted_layers)})")
        else:
            # Mixed int/str labels cannot be compared directly in Python 3,
            # so order them by their string form.
            print(f"Layers: {sorted(layers, key=str)}")

    # Show first few SAE IDs as examples
    sample_ids = list(sae_info.saes_map.keys())[:5]
    print("Sample SAE IDs:")
    for sae_id in sample_ids:
        print(f"  - {sae_id}")
    if len(sae_info.saes_map) > 5:
        print(f"  ... and {len(sae_info.saes_map) - 5} more")

In [None]:
# More focused check for specific Gemma Scope releases.
# For each release, report which layers and widths appear in its SAE ids,
# the total number of SAEs, and the first ten ids in sorted order.
releases_to_check = [
    "gemma-scope-2b-pt-res-canonical",
    "gemma-scope-2b-pt-mlp-canonical",
    "gemma-scope-2b-pt-att-canonical"
]

for release in releases_to_check:
    if release not in sae_directory:
        print(f"\nRelease '{release}' not found in directory")
        continue

    print(f"\n{'='*60}")
    print(f"Release: {release}")
    print(f"{'='*60}")

    sae_info = sae_directory[release]

    if not (hasattr(sae_info, 'saes_map') and sae_info.saes_map):
        print("No SAE information available")
        continue

    # Extract layer numbers and widths from the '/'-separated id components.
    layers = set()
    widths = set()

    for sae_id in sae_info.saes_map:
        for part in sae_id.split('/'):
            if part.startswith('layer_'):
                try:
                    layers.add(int(part.split('layer_')[1]))
                except ValueError:
                    # Component is not a plain integer layer index; skip it.
                    pass
            elif part.startswith('width_'):
                # The prefix check guarantees the split has a second element.
                widths.add(part.split('width_')[1])

    if layers:
        sorted_layers = sorted(layers)
        print(f"Available layers: {sorted_layers}")
        print(f"Layer range: {min(sorted_layers)} to {max(sorted_layers)}")
        print(f"Total layers: {len(sorted_layers)}")

    if widths:
        print(f"Available widths: {sorted(widths)}")

    print(f"\nTotal SAEs in this release: {len(sae_info.saes_map)}")

    # Show structure of first few SAE IDs
    sample_ids = sorted(sae_info.saes_map.keys())[:10]
    print("\nFirst 10 SAE IDs:")
    for sae_id in sample_ids:
        print(f"  {sae_id}")