In [1]:
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import transformers
from tqdm import tqdm
import pandas as pd
import yaml
import contextlib
import os

In [2]:
# Load GPT-2 model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Set padding token to be the same as the EOS token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"Loaded {model_name} model and tokenizer")

Loaded gpt2 model and tokenizer


In [3]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [4]:
# Download and prepare a standard dataset
from datasets import load_dataset
from torch.utils.data import DataLoader

# Load a subset of the WikiText-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

# Create a DataLoader
batch_size = 16
dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True)


In [5]:
#working on new copy here


#Jacobian of (Q_l_plus[subsetting_to_new_channel] . (layer) . Q_l( channel_input, with rest of layer as dummy))
#Why do jacobian?
#We care about the block-diagonal of the jacobian from rotated-basis to rotated-basis
#This all computes 1 of the num_channels diagonal blocks
#Need to loop over k

#What do we hope we see?
#We hope that we observe circuits localize inside channels
#Train this on IOI task data
#Does the IOI task hop inside a single channel?

#We probably have to fix the channel computaton!
from torch.func import functional_call, jacrev, vmap, jvp, vjp


def stochastic_jvp(current_rotated_layer, random_vector, channel_start, channel_end, ortho_l, ortho_l_plus):
    channel = current_rotated_layer[:, channel_start:channel_end]
    dummy_layer = current_rotated_layer.detach() 
    # Compute JVP
    #Must check shape!
    _, jvp_result = jvp(lambda x: push_forward_channel_with_dummies_outside(x, dummy_layer, channel_start, channel_end, ortho_l, ortho_l_plus), (channel,), (random_vector,))
    return jvp_result.unsqueeze(0)


def push_forward_channel_with_dummies_outside(channel, dummy_layer, channel_start, channel_end, ortho_l, ortho_l_plus):
        
        dummied_layer_with_channel_rotated = torch.cat([dummy_layer[:, :channel_start], channel, dummy_layer[:, channel_end:]], dim=-1)
    
          #result should be  batch, seq_len, d_channel

        dummied_layer_with_channel =  (ortho_l.transpose(-2, -1) @ dummied_layer_with_channel_rotated.unsqueeze(-1)).squeeze(-1)
        next_layer = model.transformer.h[layer](dummied_layer_with_channel)[0]
        # next_layer_rotated = (ortho_l_plus @ next_layer.unsqueeze(-1)).squeeze(-1)
        # return next_layer_rotated[:, :, channel_start:channel_end]
        next_layer_channel = (ortho_l_plus[:,:,channel_start:channel_end,:] @ next_layer.unsqueeze(-1)).squeeze(-1)
        return next_layer_channel #should have shape: batch, seq, channel_width


def compute_stochastic_channel_penalty(layer, k, channel_width, ortho_l, ortho_l_plus, activations_l, use_only_jacobian_to_measure_influence = True, num_samples=10):
    batch_size, seq_len, d_embed = activations_l.shape

    # Pre-compute random vectors
    random_vectors = torch.randint(0, 2, (batch_size, seq_len, channel_width), device=activations_l.device) * 2 - 1
    random_vectors = random_vectors.float()

    ortho_l = ortho_l.unsqueeze(0).unsqueeze(0) #dims: 1, 1, 768, 768
    ortho_l_plus = ortho_l_plus.unsqueeze(0).unsqueeze(0) #dims: 1, 1, 768, 768

    #rotate to the new basis
    rotated_l = (ortho_l @ activations_l.unsqueeze(-1)).squeeze(-1)
    channel_start = k * channel_width
    channel_end = (k+1) * channel_width

    # Compute stochastic JVP for multiple samples
 
    stochastic_jvps = vmap(stochastic_jvp, in_dims=(0, 0, None, None, None, None))(rotated_l, random_vectors, channel_start, channel_end, ortho_l, ortho_l_plus)
    
    # Compute the penalty using the stochastic JVPs
    penalty = torch.mean(torch.sum(stochastic_jvps ** 2, dim=(1, 2, 3, 4)))

    if use_only_jacobian_to_measure_influence:
        return penalty, stochastic_jvps

    if not use_only_jacobian_to_measure_influence:
        rotated_clean = (ortho_l.transpose(-2,-1) @ average_activations[layer].unsqueeze(-1)).squeeze(-1)
        delta_activation = rotated_l[:, :, channel_start:channel_end] - rotated_clean[:, :seq_len, channel_start:channel_end]
        # Note: We can't compute implied_self_channel_influence as before since we don't have the full Jacobian
        # Instead, we'll use the average of our stochastic estimates
        implied_self_channel_influence = torch.mean(torch.sum(stochastic_jvps * delta_activation.unsqueeze(0).unsqueeze(2), dim=(3, 4)), dim=0)
        return torch.sum(implied_self_channel_influence ** 2), stochastic_jvps, delta_activation


    #Penalty is the first item


In [6]:
#Testing for stochastic jvp
# Set up parameters
layer = 6
channel_width = 8
k = 2
hidden_size = 768  # GPT2 hidden size

# Create orthogonal matrices for layers 6 and 7
ortho_l = torch.nn.init.orthogonal_(torch.empty(hidden_size, hidden_size)).to(device)
ortho_l_plus = torch.nn.init.orthogonal_(torch.empty(hidden_size, hidden_size)).to(device)

# Generate random input for testing
batch_size = 2  # For simplicity, we'll use batch size of 1
seq_length = 10
input_ids = torch.randint(0, 50257, (batch_size, seq_length)).to(device)

# Get model activations for layer 6
with torch.no_grad():
    outputs = model(input_ids, output_hidden_states=True)
    activations_l = outputs.hidden_states[layer]

# # Compute average activations (you might want to do this over a larger dataset)
# with torch.no_grad():
#     average_activations = [torch.mean(act, dim=(0, 1)) for act in outputs.hidden_states]

# Compute the channel penalty
penalty, jacobian = compute_stochastic_channel_penalty(layer, k, channel_width, ortho_l, ortho_l_plus, activations_l)

print(f"Channel penalty for layer {layer}, channel {k}: {penalty.item()}")
print(f"Jacobian shape: {jacobian.shape}")
#print(f"Delta activation shape: {delta.shape}")

Channel penalty for layer 6, channel 2: 0.6883862018585205
Jacobian shape: torch.Size([2, 1, 1, 10, 8])


In [7]:
# Set up parameters
layer = 6
channel_width = 8
k = 2
hidden_size = 768  # GPT2 hidden size

# Create orthogonal matrices for layers 6 and 7
ortho_l = torch.nn.init.orthogonal_(torch.empty(hidden_size, hidden_size)).to(device)
ortho_l_plus = torch.nn.init.orthogonal_(torch.empty(hidden_size, hidden_size)).to(device)

# Generate random input for testing
batch_size = 2  # For simplicity, we'll use batch size of 1
seq_length = 10
input_ids = torch.randint(0, 50257, (batch_size, seq_length)).to(device)

# Get model activations for layer 6
with torch.no_grad():
    outputs = model(input_ids, output_hidden_states=True)
    activations_l = outputs.hidden_states[layer]

# # Compute average activations (you might want to do this over a larger dataset)
# with torch.no_grad():
#     average_activations = [torch.mean(act, dim=(0, 1)) for act in outputs.hidden_states]

# Compute the channel penalty
penalty, jacobian = compute_stochastic_channel_penalty(layer, k, channel_width, ortho_l, ortho_l_plus, activations_l)

print(f"Channel penalty for layer {layer}, channel {k}: {penalty.item()}")
print(f"Jacobian shape: {jacobian.shape}")
#print(f"Delta activation shape: {delta.shape}")

Channel penalty for layer 6, channel 2: 0.6095150113105774
Jacobian shape: torch.Size([2, 1, 1, 10, 8])


In [8]:
import time

def time_compute_channel_penalty(batch_sizes, seq_length, layer, channel_width, k, ortho_l, ortho_l_plus, model, device):
    results = {}
    for batch_size in batch_sizes:
        # Generate random input
        input_ids = torch.randint(0, 50257, (batch_size, seq_length)).to(device)

        # Compute activations for the batch in layer 6
        with torch.no_grad():
            outputs = model(input_ids, output_hidden_states=True)
            activations = outputs.hidden_states[layer]
        
        # # Center the activations [this already happens inside the function call]
        # mean_activation = torch.mean(activations, dim=(0, 1), keepdim=True)
        # centered_activations = activations - mean_activation
        
        # Time the computation
        start_time = time.time()
        penalty, _, = compute_stochastic_channel_penalty(layer, k, channel_width, ortho_l, ortho_l_plus, activations)
        #penalty, _, = compute_channel_penalty(layer, k, channel_width, ortho_l, ortho_l_plus, activations)
        end_time = time.time()
        
        results[batch_size] = {
            'time': end_time - start_time,
            'penalty': penalty.item()
        }
    
    return results

# Set up parameters
layer = 6
channel_width = 64
k = 2
#batch_sizes = [1, 2, 16]
batch_sizes = [1, 2, 8, 16]
seq_length = 128

# Time the function for different batch sizes
timing_results = time_compute_channel_penalty(
    batch_sizes, seq_length, layer, channel_width, k, 
    ortho_l, ortho_l_plus, model, device
)

# Print results
for batch_size, result in timing_results.items():
    print(f"Batch size {batch_size}:")
    print(f"  Time: {result['time']:.4f} seconds")
    print(f"  Penalty: {result['penalty']:.6f}")
    print()


Batch size 1:
  Time: 0.0145 seconds
  Penalty: 701.474426

Batch size 2:
  Time: 0.0122 seconds
  Penalty: 700.664124

Batch size 8:
  Time: 0.0128 seconds
  Penalty: 706.942749

Batch size 16:
  Time: 0.0146 seconds
  Penalty: 714.234375



In [16]:
#Now, let's prepare to do this for all layers!
#First initialize an orthogonal matrix for every layer of gpt-2

#import torch.nn.utils.parametrize as P

num_layers = 12  # GPT-2 small has 12 layers
hidden_size = 768  # GPT-2 small has a hidden size of 768

# Initialize a list to store orthogonal matrices for each layer
class OrthogonalMatrix(nn.Module):
    def __init__(self, n):
        super().__init__()
        # Initialize the parameter
        self.Q = nn.Parameter(torch.empty(n, n))
        nn.init.eye_(self.Q)  # Initialize as identity matrix for stability
        # Register the orthogonal parametrization
        torch.nn.utils.parametrizations.orthogonal(self, 'Q')

    def forward(self, x):
        # This example assumes you will use the orthogonal matrix to transform an input x
        return self.Q @ x

# Create a list of orthogonal matrices for each layer
orthogonal_matrices = nn.ModuleList([OrthogonalMatrix(hidden_size) for _ in range(num_layers)])

# Move the matrices to the same device as the model
orthogonal_matrices = orthogonal_matrices.to(device)

# Print to confirm the creation
print(f"Created {len(orthogonal_matrices)} orthogonal matrices, one for each layer.")
print(f"Each matrix shape: {orthogonal_matrices[0].Q.shape}")



Created 12 orthogonal matrices, one for each layer.
Each matrix shape: torch.Size([768, 768])


In [17]:
# Turn off optimization parameters of GPT-2
for param in model.parameters():
    param.requires_grad = False
print("Optimization parameters of GPT-2 have been turned off.")
# Verify that all parameters are frozen
all_frozen = all(not p.requires_grad for p in model.parameters())
print(f"All GPT-2 parameters frozen: {all_frozen}")

#meanwhile, orthogonal_matrices should be optimizable

Optimization parameters of GPT-2 have been turned off.
All GPT-2 parameters frozen: True


In [18]:
#Load test data
for batch in dataloader:
    break

# Compute hidden activations at all layers from the batch
inputs = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)

# Get the hidden states for all layers
with torch.no_grad():
    outputs = model(inputs, attention_mask=attention_mask, output_hidden_states=True)
# Extract all hidden states
all_hidden_states = outputs.hidden_states
# Remove the first element (embedding layer output) and convert to a list of tensors
hidden_activations = [hs for hs in all_hidden_states[1:]]

print(f"Computed hidden activations for {len(hidden_activations)} layers")
print(f"Shape of hidden activations for first layer: {hidden_activations[0].shape}")



Computed hidden activations for 12 layers
Shape of hidden activations for first layer: torch.Size([16, 512, 768])


In [19]:
#adding orthogonal matrix support
class RotateLayer(torch.nn.Module):
    """A linear transformation with orthogonal initialization."""

    def __init__(self, n, init_orth=True):
        super().__init__()
        weight = torch.empty(n, n)
        # we don't need init if the saved checkpoint has a nice
        # starting point already.
        # you can also study this if you want, but it is our focus.
        if init_orth:
            torch.nn.init.orthogonal_(weight)
        self.weight = torch.nn.Parameter(weight, requires_grad=True)

    def forward(self, x):
        return torch.matmul(x.to(self.weight.dtype), self.weight)

class RotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):

    """Intervention in the rotated space."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rotate_layer = RotateLayer(self.embed_dim)
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)



In [20]:
K = 64 #number of channels
channel_width = hidden_size // K
#Now, writing a function to compute the channel penalty, summing over the K channels

def compute_layer_penalty(layer): #implicit: K, channel_width, ortho_l, ortho_l_plus, activations_l
    penalties = torch.zeros(K)
    for k in range(K):
        penalties[k], _, = compute_stochastic_channel_penalty(layer, k, channel_width, 
                                                         orthogonal_matrices[layer].Q, orthogonal_matrices[layer + 1].Q, 
                                                         hidden_activations[layer])
    return penalties

def compute_all_layer_penalties():
    layer_penalty = torch.zeros(num_layers - 1, K, device = device)
    for layer in range(num_layers - 1):
        layer_penalty[layer] = compute_layer_penalty(layer)
    return torch.sum(layer_penalty ** 2), layer_penalty

In [21]:
# # Initialize wandb
# import wandb
# wandb.init(project="channel-gpt2-model", name="optimization-run")


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmichaelbsklar[0m ([33mmichaelsklar[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [22]:
# Set up optimizer
optimizer = torch.optim.Adam([orthogonal_matrices[i] for i in range(num_layers)], lr=0.001)

# Number of optimization steps
num_steps = 1000

for step in range(num_steps):
    # Zero the gradients
    optimizer.zero_grad()
    
    # Compute the loss
    loss, layer_penalties = compute_all_layer_penalties()
    
    # Backpropagate the loss
    loss.backward()
    
    # Update the orthogonal matrices
    optimizer.step()
    
    # Log the loss to wandb
    wandb.log({"loss": loss.item()})
    
    # Print progress
    if (step + 1) % 100 == 0:
        print(f"Step {step + 1}/{num_steps}, Loss: {loss.item():.4f}")

# Finish the wandb run
wandb.finish()

print("Optimization complete.")

# Compute final penalties
final_loss, final_layer_penalties = compute_all_layer_penalties()
print(f"Final Loss: {final_loss.item():.4f}")



TypeError: optimizer can only optimize Tensors, but one of the params is torch.nn.utils.parametrize.ParametrizedOrthogonalMatrix

In [None]:
# Visualize final layer penalties
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))
plt.imshow(final_layer_penalties.cpu().detach().numpy(), aspect='auto', cmap='viridis')
plt.colorbar(label='Penalty')
plt.xlabel('Channel')
plt.ylabel('Layer')
plt.title('Final Layer Penalties')
plt.show()
