In [1]:
import torch
from transformer_lens import HookedTransformer
import warnings

import os
from tqdm.auto import tqdm
from collections import Counter

from sklearn.model_selection import train_test_split
import numpy as np

import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


### Loading the model

In [2]:
# Suppress a specific UserWarning from SentencePiece
warnings.filterwarnings("ignore", message=r".*Ignoring tokenizer_config\.json since it is not set\. It is likely that you are loading a tokenizer from a previous version of the library which does not contain this file\..*")

print("--- Environment Setup ---")

--- Environment Setup ---


In [3]:
# Check if a GPU is available and set the device
if torch.cuda.is_available():
    device = "cuda"
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")
    # Clear cache to free up memory on the GPU
    torch.cuda.empty_cache()
else:
    device = "cpu"
    print("No GPU detected. Using CPU. This will be very slow.")


GPU detected: NVIDIA GeForce RTX 3090


In [5]:

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
print(f"\n--- Loading Model: {model_name} ---")
print("This will download and load ~16 GB of model weights. This may take several minutes.")

# Load the model directly using HookedTransformer.
# We do NOT use quantization (`load_in_4bit`).
# `torch_dtype=torch.bfloat16` is recommended for performance and is supported by the 3090.
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    torch_dtype=torch.bfloat16,
    # No quantization arguments needed!
)



--- Loading Model: meta-llama/Meta-Llama-3-8B-Instruct ---
This will download and load ~16 GB of model weights. This may take several minutes.


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 38.59it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


In [6]:
tokenizer = model.tokenizer

In [7]:
print("\n--- Verifying Model and Tokenizer ---")
test_prompt = "Hello, world! This is a test."
test_tokens = model.to_tokens(test_prompt)
test_str_tokens = model.to_str_tokens(test_prompt)

print("Test prompt:", test_prompt)
print("Tokenized input shape:", test_tokens.shape)
print("String tokens:", test_str_tokens)

try:
    with torch.no_grad():
        logits = model(test_tokens)
    print("Forward pass successful!")
    print("Logits shape:", logits.shape)

    generated_token_id = logits[0, -1].argmax().item()
    generated_token_str = tokenizer.decode(generated_token_id)
    print(f"Model's next token prediction for the prompt: '{generated_token_str}'")

except Exception as e:
    print("\nAn error occurred during the verification forward pass:")
    print(e)
    print("If this is a CUDA out of memory error, your GPU does not have enough VRAM for the non-quantized model.")



--- Verifying Model and Tokenizer ---
Test prompt: Hello, world! This is a test.
Tokenized input shape: torch.Size([1, 10])
String tokens: ['<|begin_of_text|>', 'Hello', ',', ' world', '!', ' This', ' is', ' a', ' test', '.']


Forward pass successful!
Logits shape: torch.Size([1, 10, 128256])
Model's next token prediction for the prompt: ' I'


### Extracting  Layer Activations

In [8]:

def split_conversation(text: str, user_identifier="HUMAN:", ai_identifier="ASSISTANT:") -> tuple[list[str], list[str]]:
    """
    Splits a raw text conversation into a list of user messages and a list of assistant messages.
    This logic is adapted from the paper's repository.
    """
    user_messages, assistant_messages = [], []
    lines = text.split("\n")
    current_user_message, current_assistant_message = "", ""

    for line in lines:
        line = line.lstrip(" ")
        if line.startswith(user_identifier):
            if current_assistant_message:
                assistant_messages.append(current_assistant_message.strip())
            current_assistant_message = ""
            current_user_message += line.replace(user_identifier, "").strip() + " "
        elif line.startswith(ai_identifier):
            if current_user_message:
                user_messages.append(current_user_message.strip())
            current_user_message = ""
            current_assistant_message += line.replace(ai_identifier, "").strip() + " "

    if current_user_message: user_messages.append(current_user_message.strip())
    if current_assistant_message: assistant_messages.append(current_assistant_message.strip())
        
    return user_messages, assistant_messages

def llama_v2_prompt(messages: list[dict]) -> str:
    """
    Formats a conversation into the LLaMA v2 prompt format.
    This logic is adapted from the paper's repository.
    """
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    BOS, EOS = "<s>", "</s>"
    DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

    if not messages or messages[0]["role"] != "system":
        messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + messages
        
    messages = [
        {
            "role": messages[1]["role"],
            "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
        }
    ] + messages[2:]

    messages_list = [
        f"{BOS}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {EOS}"
        for prompt, answer in zip(messages[::2], messages[1::2])
    ]
    
    if messages and messages[-1]["role"] == "user":
        messages_list.append(f"{BOS}{B_INST} {(messages[-1]['content']).strip()} {E_INST}")

    return "".join(messages_list)


In [9]:

# The paper's special prompt for coaxing out the gender representation
prompt_translator = {"_gender_": "gender"}
attribute = "_gender_"
probe_prompt = f" I think the {prompt_translator[attribute]} of this user is"


In [10]:
# --- Step 3.2: Configuration and Test Run ---

# ====================================================================
# Configuration
# ====================================================================
# Specify which layers you want to extract activations from.
# Llama-3-8B has 32 layers (0-31).
LAYERS_TO_EXTRACT = [30] # Example: 2nd to last layer

# Define the path to your dataset folder
dataset_path = "/workspace/MATS-research/data/chen_llama_gender"

# Number of files to run for this test.
NUM_TEST_FILES = 10 
# ====================================================================

print(f"Loading data from: {dataset_path}")
print(f"Extracting activations from layers: {LAYERS_TO_EXTRACT}\n")

conversation_files = [f for f in os.listdir(dataset_path) if f.endswith('.txt')]
print(f"Found {len(conversation_files)} total conversation files.")
print(f"Running a test on the first {NUM_TEST_FILES} files...")

# Use a slice of the files for the test run
# test_files_subset = conversation_files[:NUM_TEST_FILES]

# Lists to store our test data
test_activations = []
test_labels = []

# This filter function tells transformer_lens to only cache the activations we need
def names_filter(name: str):
    is_resid_post = name.endswith("resid_post")
    if not is_resid_post: return False
    
    layer_index = int(name.split('.')[1])
    return layer_index in LAYERS_TO_EXTRACT


# --- Step 3.3: Full Data Extraction ---

# Lists to store our final data
all_activations = []
all_labels = []

# This loop uses the full 'conversation_files' list
for file_name in tqdm(conversation_files, desc="Full Run: Extracting Activations"):
    file_path = os.path.join(dataset_path, file_name)
    
    if "_gender_female" in file_name:
        label = "female"
    elif "_gender_male" in file_name:
        label = "male"
    else:
        continue

    with open(file_path, 'r', encoding='utf-8') as f:
        raw_text = f.read()

    # Re-using the same helper functions as before
    user_msgs, ai_msgs = split_conversation(raw_text)
    messages_dict = []
    for user_msg, ai_msg in zip(user_msgs, ai_msgs):
        messages_dict.append({'role': 'user', 'content': user_msg})
        messages_dict.append({'role': 'assistant', 'content': ai_msg})
        
    if not messages_dict:
        continue

    full_prompt = llama_v2_prompt(messages_dict) + probe_prompt

    with torch.no_grad():
        _, cache = model.run_with_cache(full_prompt, names_filter=names_filter)
        
        activations_for_prompt = torch.stack(
            [cache[f"blocks.{layer}.hook_resid_post"][0, -1, :] for layer in LAYERS_TO_EXTRACT],
            dim=0
        )
        if len(LAYERS_TO_EXTRACT) == 1:
            activations_for_prompt = activations_for_prompt.squeeze(0)

        all_activations.append(activations_for_prompt.cpu())
        all_labels.append(label)

print(f"\nSuccessfully processed {len(all_activations)} conversations.")

# --- Sanity Checks for the Full Run ---
print("\n--- Verifying Extracted Data ---")
if all_activations:
    print(f"Shape of a single activation tensor: {all_activations[0].shape}")
    label_counts = Counter(all_labels)
    print(f"Final label distribution: {label_counts}")
    assert len(all_activations) == len(all_labels)
    print("Number of activations matches number of labels.")
else:
    print("No data was processed in the full run.")
    

Loading data from: /workspace/MATS-research/data/chen_llama_gender
Extracting activations from layers: [30]

Found 1000 total conversation files.
Running a test on the first 10 files...


Full Run: Extracting Activations: 100%|██████████| 1000/1000 [01:35<00:00, 10.43it/s]


Successfully processed 500 conversations.

--- Verifying Extracted Data ---
Shape of a single activation tensor: torch.Size([4096])
Final label distribution: Counter({'female': 250, 'male': 250})
Number of activations matches number of labels.





In [11]:
dataset_path = "/workspace/MATS-research/data/chen_llama_gender"
conversation_files = [f for f in os.listdir(dataset_path) if f.endswith('.txt')]

skipped_files = []
for file_name in conversation_files:
    if "_gender_female" not in file_name and "_gender_male" not in file_name:
        skipped_files.append(file_name)

print(f"Found {len(skipped_files)} files that were skipped (as expected).")
print("\nHere are the first 10 skipped files:")
for file in skipped_files[:10]:
    print(file)

Found 500 files that were skipped (as expected).

Here are the first 10 skipped files:
conversation_250_age_female.txt
conversation_250_age_male.txt
conversation_251_age_female.txt
conversation_251_age_male.txt
conversation_252_age_female.txt
conversation_252_age_male.txt
conversation_253_age_female.txt
conversation_253_age_male.txt
conversation_254_age_female.txt
conversation_254_age_male.txt


### Preparing the Data for Probe Training

In [12]:

num_samples = len(all_activations)
print(f"Number of samples to process: {num_samples}")


Number of samples to process: 500


In [13]:

# 1. Stack the list of activation tensors into a single large tensor.
# Each tensor in the list is for one conversation. We stack them along a new 'batch' dimension.
# The original list contained tensors of shape [d_model], so the final shape will be [num_samples, d_model].
activations_tensor = torch.stack(all_activations)
print(f"\nStacked activations tensor shape: {activations_tensor.shape}")

# 2. Convert the string labels ('female', 'male') into numerical format (0, 1).
# This is required for training a classifier. We'll use a simple mapping.
label_map = {"female": 0, "male": 1}
labels_numerical = [label_map[label] for label in all_labels]
labels_tensor = torch.tensor(labels_numerical, dtype=torch.float32) # Use float32 for BCEWithLogitsLoss later
print(f"Labels tensor shape: {labels_tensor.shape}")
print(f"First 10 numerical labels: {labels_tensor[:10].int().tolist()}")

# 3. Split the data into training and testing sets.
# An 80/20 split is a standard choice.
# - 'random_state=42' ensures that the split is the same every time we run the code, which is crucial for reproducibility.
# - 'stratify=labels_tensor' ensures that the proportion of male/female labels is the same in both the train and test sets.
X_train, X_test, y_train, y_test = train_test_split(
    activations_tensor, 
    labels_tensor, 
    test_size=0.2, 
    random_state=42, 
    stratify=labels_tensor
)

print("\n--- Data Splitting Complete ---")
print(f"Training data shape (X_train): {X_train.shape}")
print(f"Training labels shape (y_train): {y_train.shape}")
print(f"Testing data shape (X_test):  {X_test.shape}")
print(f"Testing labels shape (y_test):  {y_test.shape}")

# Verify stratification by checking the balance of labels in each set
print(f"\nTraining set label balance: {y_train.sum()/len(y_train):.2f} (proportion of 'male')")
print(f"Test set label balance:     {y_test.sum()/len(y_test):.2f} (proportion of 'male')")

print("\n--- Step 4 Complete ---")
print("You now have your data prepared in X_train, X_test, y_train, and y_test.")
print("You are ready to proceed to Step 5: Defining and Training the Linear Probes.")


Stacked activations tensor shape: torch.Size([500, 4096])
Labels tensor shape: torch.Size([500])
First 10 numerical labels: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

--- Data Splitting Complete ---
Training data shape (X_train): torch.Size([400, 4096])
Training labels shape (y_train): torch.Size([400])
Testing data shape (X_test):  torch.Size([100, 4096])
Testing labels shape (y_test):  torch.Size([100])

Training set label balance: 0.50 (proportion of 'male')
Test set label balance:     0.50 (proportion of 'male')

--- Step 4 Complete ---
You now have your data prepared in X_train, X_test, y_train, and y_test.
You are ready to proceed to Step 5: Defining and Training the Linear Probes.


### Defining and Training the Linear Probes

In [14]:

# Define the Linear Probe model as a PyTorch Module
class LinearProbe(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        # A single linear layer that maps from the activation dimension to a single logit
        self.probe = nn.Linear(input_dim, 1)

    def forward(self, x):
        # We squeeze the output to remove the last dimension, so it has shape [batch_size]
        # which is what our loss function expects.
        return self.probe(x).squeeze(-1)

print("LinearProbe class defined.")

LinearProbe class defined.


In [15]:
# --- Step 5.2: Define the Training Function ---

def train_probe(probe, X_train, y_train, epochs=40, lr=1e-3, batch_size=32):
    """
    Trains a single linear probe.
    
    Args:
        probe (LinearProbe): The probe to train.
        X_train (Tensor): The training activations.
        y_train (Tensor): The training labels.
        epochs (int): Number of training epochs.
        lr (float): Learning rate.
        batch_size (int): Batch size for training.
        
    Returns:
        LinearProbe: The trained probe.
    """
    # Move probe to the correct device
    probe.to(device)
    
    # Define the loss function and optimizer
    # BCEWithLogitsLoss is perfect for binary classification from a single logit.
    # It's numerically stable and combines a Sigmoid layer and Binary Cross Entropy loss.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    
    # Create a DataLoader for batching
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        total_loss = 0
        for X_batch, y_batch in train_loader:
            # Move data to the GPU
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            X_batch = X_batch.to(torch.float32) 
            
            # Standard PyTorch training loop
            optimizer.zero_grad()
            logits = probe(X_batch)
            loss = loss_fn(logits, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}")
            
    return probe

print("train_probe function defined.")

train_probe function defined.


In [16]:
# --- Step 5.3: Main Training Loop ---

# This dictionary will store our trained probes, mapping the layer index to the probe model
trained_probes = {}
input_dim = model.cfg.d_model # Should be 4096 for Llama-3-8B

# This code handles both cases: a single layer's activations or multiple layers'
if X_train.ndim == 2:
    # Case 1: You extracted from a single layer. X_train has shape [num_samples, d_model].
    layer_index = LAYERS_TO_EXTRACT[0]
    print(f"--- Training a single probe for Layer {layer_index} ---")
    
    probe = LinearProbe(input_dim)
    trained_probe = train_probe(probe, X_train, y_train, epochs=50)
    trained_probes[layer_index] = trained_probe
    
elif X_train.ndim == 3:
    # Case 2: You extracted from multiple layers. X_train has shape [num_samples, num_layers, d_model].
    num_layers_extracted = X_train.shape[1]
    print(f"--- Training one probe for each of the {num_layers_extracted} extracted layers ---")
    
    for i, layer_index in enumerate(LAYERS_TO_EXTRACT):
        print(f"\n--- Training Probe for Layer {layer_index} ---")
        
        # Get the activations for this specific layer
        layer_activations = X_train[:, i, :]
        
        probe = LinearProbe(input_dim)
        trained_probe = train_probe(probe, layer_activations, y_train)
        trained_probes[layer_index] = trained_probe

print("\n--- Training Complete ---")
print(f"Trained {len(trained_probes)} probes for layers: {list(trained_probes.keys())}")

print("\n--- Step 5 Complete ---")
print("You are now ready to proceed to Step 6: Evaluating the Probes and Visualizing the Results.")

--- Training a single probe for Layer 30 ---


Epoch 10/50, Loss: 0.4146
Epoch 20/50, Loss: 0.3878
Epoch 30/50, Loss: 0.3444
Epoch 40/50, Loss: 0.3436
Epoch 50/50, Loss: 0.3373

--- Training Complete ---
Trained 1 probes for layers: [30]

--- Step 5 Complete ---
You are now ready to proceed to Step 6: Evaluating the Probes and Visualizing the Results.


### Evaluating the Probes

In [17]:

def evaluate_probe(probe, X_test, y_test):
    """
    Evaluates a single trained linear probe on the test set.
    
    Args:
        probe (LinearProbe): The trained probe to evaluate.
        X_test (Tensor): The testing activations.
        y_test (Tensor): The testing labels.
        
    Returns:
        float: The accuracy of the probe on the test set.
    """
    # Set the probe to evaluation mode
    probe.eval()
    probe.to(device)
    
    with torch.no_grad():
        # Move test data to the GPU and ensure correct dtype
        X_test_gpu = X_test.to(device).to(torch.float32)
        y_test_gpu = y_test.to(device)
        
        # Get logits from the probe
        logits = probe(X_test_gpu)
        
        # Convert logits to predictions (0 or 1)
        # A positive logit corresponds to a prediction of 1 ('male')
        predictions = (logits > 0).int()
        
        # Calculate accuracy by comparing predictions to true labels
        accuracy = (predictions == y_test_gpu.int()).float().mean().item()
        
    return accuracy


In [19]:

# Check that we have exactly one probe trained, as expected
assert len(trained_probes) == 1, f"Expected 1 trained probe, but found {len(trained_probes)}. Please re-run from Step 4."

# Get the layer index and the trained probe model from our dictionary
layer_index = list(trained_probes.keys())[0]
probe_to_evaluate = trained_probes[layer_index]

print(f"--- Evaluating Probe for Layer {layer_index} ---")

# Calculate the test accuracy
test_accuracy = evaluate_probe(probe_to_evaluate, X_test, y_test)

print(f"\nProbe for Layer {layer_index} - Test Accuracy: {test_accuracy*100:.2f}%")


--- Evaluating Probe for Layer 30 ---

Probe for Layer 30 - Test Accuracy: 77.00%
