## Initial Probe Exploration

### Goals:
- Train a simple linear logistic regression probe on Llama-3-7b
- Understand GPU capacity: can we run inference with the 70B model?
- Look at the probe activations / test set classifications
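
For the 70B capacity question, a quick estimate is parameters × bytes per parameter. The following is a minimal sketch. It assumes a nominal 70e9 parameters and counts the weights only. Activations, the KV cache and framework overhead are ignored, so real usage will be higher. `weight_memory_gb` is a hypothetical helper.

```python
# Rough estimate of the memory needed just to hold the model weights.
# Assumption: weights only; activations, KV cache and overhead are ignored.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1, "int4": 0.5}


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Approximate size of the weights in GB (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"70B @ {dtype}: {weight_memory_gb(70e9, dtype):.0f} GB")
# 70B @ float32: 280 GB
# 70B @ bfloat16: 140 GB
# 70B @ int8: 70 GB
# 70B @ int4: 35 GB
```

In bfloat16 the weights alone are about 140 GB, which is more than a single 80 GB GPU can hold. Running 70B inference therefore likely needs the model sharded across several GPUs (for example `device_map="auto"` in `from_pretrained`, which requires `accelerate`) or a quantized checkpoint.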

### Timeline:
- 18/02/25
- 19/02/25

In [12]:
# Imports
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression
from jaxtyping import Float
from joblib import Parallel, delayed
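import os

# Silence the Hugging Face tokenizers parallelism warning when joblib forks workers
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
model_name = 'meta-llama/Llama-3.3-70B-Instruct'
device = 'cuda' if torch.cuda.is_available() else 'cpu'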

## Dataset

The dataset was generated with GPT-4o. It consists of 20 sentences about red things and 20 sentences about green things. The goal is to learn a classifier (probe) that distinguishes red objects from green ones.

`Command: Generate 20 sentences about red things. Generate 20 sentences about green things. Put them in a JSON array of strings.`

In [2]:
# Dataset loading:
text = [
  "The bright red apple hung low on the tree, ready to be picked.",
  "A red sports car sped past, leaving a trail of dust behind.",
  "The firefighter's uniform had reflective red stripes for visibility.",
  "She wore a deep red dress that caught everyone's attention.",
  "The red rose symbolized love and passion.",
  "Tomatoes ripened under the sun, turning a rich shade of red.",
  "The cardinal perched on the fence, its red feathers vibrant against the snow.",
  "His face turned red with embarrassment after tripping on stage.",
  "The sunset painted the sky in hues of red and orange.",
  "Red chili peppers added a spicy kick to the dish.",
  "The warning sign was painted bright red for safety reasons.",
  "Her lipstick was a bold shade of red.",
  "The red balloon floated away into the sky.",
  "Blood is naturally red due to the presence of iron in hemoglobin.",
  "Strawberries are at their sweetest when they turn fully red.",
  "The red fire hydrant stood at the corner of the street.",
  "Maple leaves turn a brilliant red in the autumn.",
  "A red velvet cake is a delicious dessert with a hint of cocoa.",
  "The ladybug crawled across the leaf, its red shell dotted with black spots.",
  "Santa Claus is always dressed in his iconic red suit.",
  "The fresh green grass covered the rolling hills.",
  "A green traffic light signaled the cars to move forward.",
  "Emeralds are precious gems with a deep green color.",
  "The frog leaped into the pond, blending in with the green lily pads.",
  "Spinach is a nutritious green vegetable rich in iron.",
  "The soccer field was painted bright green for the championship game.",
  "A bright green parrot perched on the branch, mimicking voices.",
  "The cucumber felt cool and crisp in her hands.",
  "The lush green rainforest was teeming with wildlife.",
  "She wore a green jade bracelet that shimmered under the light.",
  "Green tea is known for its numerous health benefits.",
  "The traffic sign was painted green to indicate an exit route.",
  "The chameleon changed its color to blend with the green leaves.",
  "The avocado's skin turned dark green when fully ripe.",
  "His green eyes sparkled in the sunlight.",
  "The turtle slowly crawled across the green moss-covered rock.",
  "The neon green sign stood out in the dimly lit alley.",
  "Green grapes are sweet and slightly tangy when ripe.",
  "The Christmas tree stood tall, covered in green pine needles.",
  "The four-leaf clover is a rare green plant that symbolizes luck."
]

test_text = [
  "The red kite soared high above the open field.",
  "A juicy red watermelon slice is perfect for a hot summer day.",
  "The brick house had a classic red chimney that stood out against the sky.",
  "The green lizard basked in the sun on a warm rock.",
  "She decorated her room with green fairy lights for a cozy atmosphere."
] 
test_labels = [1, 1, 1, 0, 0]

## Generate the Feature Inputs for the Probe

In [None]:
# Load the model and tokenizer named by `model_name` (the outputs below come from a 16-layer, 2048-dim model)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

@torch.no_grad()
def create_activations(text:list[str]) -> Float[torch.Tensor, "layers batch_size seq_len embed_dim"]:

    # Tokenize input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       padding=True, max_length=1028)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # List to store the residual activations (one tensor per layer)
    activations = []

    # Hook function to capture residual activations before layernorm
    def hook_fn(module, input, output):
        activations.append(input[0].detach().cpu())  # Store the residual stream entering the norm

    # Register hooks on each transformer block (LLaMA layers)
    hooks = []
    for i, layer in enumerate(model.model.layers):  # LLaMA uses model.model.layers
        hook = layer.input_layernorm.register_forward_hook(hook_fn)  # Pre-attention residual
        hooks.append(hook)

    # Forward pass
    with torch.no_grad():
        _ = model(**inputs)

    # Remove hooks after capturing activations
    for hook in hooks:
        hook.remove()

    # Print the shape of each layer's stored activations
    for i, act in enumerate(activations):
        print(f"Layer: {i}, Activation Shape: {act.shape}")

    all_acts = torch.stack(activations)
    print('All activations shape:', all_acts.shape)

    return all_acts.cpu()

train_acts = create_activations(text)
test_acts = create_activations(test_text)


Layer: 0, Activation Shape: torch.Size([40, 18, 2048])
Layer: 1, Activation Shape: torch.Size([40, 18, 2048])
Layer: 2, Activation Shape: torch.Size([40, 18, 2048])
Layer: 3, Activation Shape: torch.Size([40, 18, 2048])
Layer: 4, Activation Shape: torch.Size([40, 18, 2048])
Layer: 5, Activation Shape: torch.Size([40, 18, 2048])
Layer: 6, Activation Shape: torch.Size([40, 18, 2048])
Layer: 7, Activation Shape: torch.Size([40, 18, 2048])
Layer: 8, Activation Shape: torch.Size([40, 18, 2048])
Layer: 9, Activation Shape: torch.Size([40, 18, 2048])
Layer: 10, Activation Shape: torch.Size([40, 18, 2048])
Layer: 11, Activation Shape: torch.Size([40, 18, 2048])
Layer: 12, Activation Shape: torch.Size([40, 18, 2048])
Layer: 13, Activation Shape: torch.Size([40, 18, 2048])
Layer: 14, Activation Shape: torch.Size([40, 18, 2048])
Layer: 15, Activation Shape: torch.Size([40, 18, 2048])
All activations shape: torch.Size([16, 40, 18, 2048])
Layer: 0, Activation Shape: torch.Size([5, 16, 2048])
Layer:
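
The hook in `create_activations` relies on LLaMA decoder layers being pre-norm. Each layer passes its incoming residual stream to `input_layernorm`, so the first positional input to that module is the residual stream before the layer runs. The following is a minimal, self-contained sketch of the same forward-hook pattern. It uses a hypothetical toy pre-norm block (`ToyBlock` is illustrative, not the Hugging Face class), so it runs without downloading any weights.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical pre-norm block standing in for a LLaMA decoder layer."""

    def __init__(self, d: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(d)
        self.mlp = nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual update: x + f(norm(x))
        return x + self.mlp(self.input_layernorm(x))


torch.manual_seed(0)
blocks = nn.ModuleList([ToyBlock(8) for _ in range(3)])
captured = []


def hook_fn(module, inputs, output):
    # inputs is a tuple; inputs[0] is the residual stream entering the norm
    captured.append(inputs[0].detach().cpu())


handles = [b.input_layernorm.register_forward_hook(hook_fn) for b in blocks]

x = torch.randn(2, 5, 8)  # (batch, seq_len, d_model)
with torch.no_grad():
    h = x
    for b in blocks:
        h = b(h)

# Always remove hooks so later forward passes are not recorded
for handle in handles:
    handle.remove()

acts = torch.stack(captured)  # (layers, batch, seq_len, d_model)
print(acts.shape)  # torch.Size([3, 2, 5, 8])
print(torch.equal(acts[0], x))  # True: the first block sees the raw input
```

Because the first block receives the input unchanged, `acts[0]` equals `x`. In the LLaMA case, the analogous layer-0 capture is the token embeddings. A `register_forward_pre_hook` on the same module would see the same tensor, since only the module's inputs are used.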

## Training Code for the Probe

Use `sklearn`'s `LogisticRegression` to learn a linear classifier (probe) on the model's activations. The steps are:

1. Create the labels y (1 for red, 0 for green).
2. Reshape X into sklearn's expected `(batch_size, embed_dim)` format, training one probe per layer on the final sequence position. **TODO: Iterate in Future**
3. Fit a logistic regression probe for each layer.
4. Evaluate on the 5 test sentences.
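
Step 2 takes the activation at the final sequence position. The tokenizer pads each batch to a common length, so position `-1` is the last real token only when padding is on the left. Which side gets padded depends on `tokenizer.padding_side`. With right padding, shorter sentences end in pad tokens. The following is a minimal sketch that gathers each sentence's last real token from the `attention_mask`. It assumes the mask is 1 for real tokens and 0 for padding, as Hugging Face tokenizers return it. `last_token_acts` is a hypothetical helper and the tensors are toy data.

```python
import torch


def last_token_acts(acts: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Select the activation of each sequence's last non-padding token.

    acts:           (layers, batch, seq_len, d_model)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    returns:        (layers, batch, d_model)
    """
    mask = attention_mask.to(acts.device)
    positions = torch.arange(mask.shape[1], device=acts.device).expand_as(mask)
    # Largest position where the mask is 1 (works for left or right padding)
    last_idx = (positions * mask).argmax(dim=1)  # (batch,)
    batch_idx = torch.arange(mask.shape[0], device=acts.device)
    return acts[:, batch_idx, last_idx, :]


torch.manual_seed(0)
acts = torch.randn(2, 3, 4, 5)  # (layers, batch, seq_len, d_model)
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0],   # right-padded
                     [0, 1, 1, 1]])  # left-padded
out = last_token_acts(acts, mask)
print(out.shape)  # torch.Size([2, 3, 5])
```

To use this in the notebook, `create_activations` would also need to return `inputs['attention_mask']`. Alternatively, setting `tokenizer.padding_side = 'left'` before tokenizing keeps the `[:, :, -1, :]` indexing valid.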

In [10]:
labels = np.concatenate([np.ones(20), np.zeros(20)])  # 1 = red, 0 = green
print(f'{labels.shape=}')

# Select the last sequence position activations: (layers, batch_size, embed_dim)
selected_train_acts = train_acts[:, :, -1, :].float().numpy()
selected_test_acts = test_acts[:, :, -1, :].float().numpy()

# Define the per-layer fit function (run in parallel below):
def train_logistic_regression(activations: Float[np.ndarray, "batch_size embed_dim"],
                              labels: Float[np.ndarray, "batch_size"],
                              layer_num: int) -> tuple[LogisticRegression, float, np.ndarray]:
    """
    Train a logistic regression probe on one layer's residual activations of the LLaMA model
    and evaluate it on the test set. Designed to be called in parallel with joblib, one call per layer.
    """
    assert activations.shape == (40, 2048), f'Activations shape is not correct dim: {activations.shape}'

    # Train a logistic regression probe (small C = strong L2 regularization, no intercept)
    probe = LogisticRegression(C=1e-3, random_state=42, fit_intercept=False)
    probe.fit(activations, labels)

    pred_labels = probe.predict(selected_test_acts[layer_num])
    test_acc = (pred_labels == np.array(test_labels)).mean()

    return probe, test_acc, pred_labels

layer_models = Parallel(n_jobs=16)(delayed(train_logistic_regression)(act, labels, i)
                                   for i, act in enumerate(selected_train_acts))

labels.shape=(40,)


In [11]:
layer_models

[(LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001, fit_intercept=False, random_state=42),
  0.4,
  array([1., 1., 0., 1., 1.])),
 (LogisticRegression(C=0.001