In [88]:
import torch
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from einops import rearrange, repeat
import os


In [16]:
# Load pre-trained model and tokenizer
model_name = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Encode input sequence
input_sequence = "I want to eat a banana"
input_ids = tokenizer.encode(input_sequence, return_tensors="pt")

# Generate hidden states
outputs = model(input_ids, output_hidden_states=True)

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJModel: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing GPTJModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [108]:
hidden_states = outputs.hidden_states
print(len(hidden_states))
cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
print(cat_hidden_states.shape)

29
torch.Size([28, 6, 4096])


In [112]:
def load_and_stack_layers(num_layers: int, prefix_path: str) -> torch.Tensor:
    layers = []
    for i in range(num_layers):
        layer_path = f"{prefix_path}/layer_{i}.pt"
        layer = torch.load(layer_path)
        layers.append(layer.weight.cpu())
    stacked = torch.cat(layers, dim=0)
    return stacked

In [114]:
path = '/home/waree/elk-reporters/EleutherAI/gpt-j-6b/sethapun/arithmetic_2as_1to1/busy-kapitsa/reporters'
rep_weights = load_and_stack_layers(model.config.n_layer,path)
print(rep_weights.shape)

In [122]:
# Use einsum to do multiplication 

result = torch.einsum('bse,be->bs', cat_hidden_states, rep_weights)
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)

torch.Size([28, 6])


In [125]:
torch.set_printoptions(precision=2)
print(result)

tensor([[-8.55e-05,  1.82e-03,  1.37e-03,  2.04e-02,  6.52e-03,  2.09e-02],
        [ 3.09e+00,  3.28e+00,  3.31e+00,  2.91e+00,  4.03e+00,  4.02e+00],
        [ 3.01e-01,  1.41e-02, -8.31e-02,  4.20e-01,  3.41e-01, -4.00e-01],
        [ 2.40e+01, -8.44e-01, -1.08e+00, -1.11e+00, -4.59e-01, -1.89e+00],
        [ 1.37e+02, -4.33e-02, -7.48e-02,  4.54e-01, -1.43e-02,  8.36e-02],
        [ 4.35e+02,  2.76e+00,  1.94e+00,  3.26e+00,  2.61e+00,  3.57e+00],
        [-3.40e+02, -2.66e+00, -1.91e+00, -3.20e+00, -1.77e+00, -3.45e+00],
        [-2.24e+02, -1.06e+00, -9.89e-01, -1.47e+00, -7.43e-01, -6.54e-01],
        [ 4.56e+02,  5.52e+00,  4.37e+00,  5.68e+00,  4.68e+00,  5.77e+00],
        [ 3.73e+02,  4.09e+00,  3.54e+00,  5.42e+00,  3.82e+00,  4.84e+00],
        [ 4.73e+02,  7.24e+00,  6.92e+00,  7.82e+00,  6.66e+00,  7.46e+00],
        [ 3.18e+02,  5.63e+00,  5.18e+00,  6.15e+00,  5.58e+00,  6.15e+00],
        [-4.56e+01, -2.00e+00, -2.19e+00, -2.44e+00, -2.22e+00, -3.11e+00],
        [ 6.

In [123]:
softmax_result

tensor([[1.6524e-01, 1.6555e-01, 1.6548e-01, 1.6866e-01, 1.6633e-01, 1.6874e-01],
        [1.0711e-01, 1.2853e-01, 1.3314e-01, 8.8654e-02, 2.7302e-01, 2.6954e-01],
        [1.9625e-01, 1.4730e-01, 1.3366e-01, 2.2112e-01, 2.0433e-01, 9.7341e-02],
        [1.0000e+00, 1.6504e-11, 1.3084e-11, 1.2636e-11, 2.4254e-11, 5.8121e-12],
        [1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.5210e-01, 3.2207e-01, 8.8472e-02, 3.6828e-01, 6.9068e-02],
        [0.0000e+00, 1.7877e-01, 1.9123e-01, 1.1812e-01, 2.4455e-01, 2.6733e-01],
        [1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.3717e

## low dim

In [24]:
def load(path: Path | str):
    """Load a reporter from a file."""
    
    return torch.load(path)

In [25]:
rep = load('/home/waree/elk-reporters/EleutherAI/gpt-j-6b/sethapun/arithmetic_2as_1to1/busy-kapitsa/reporters/layer_20.pt')

In [62]:
hs = hidden_states[20]
print(hs.shape)

weight = rep.weight.cpu()
print(weight.shape)


torch.Size([1, 6, 4096])
torch.Size([1, 4096])


In [61]:
# Use einsum to do multiplication 

result = torch.einsum('bse,be->bs', hs , weight)
print(result)

tensor([[533.0842,  32.2087,  30.9929,  27.9183,  26.5704,  24.5978]],
       grad_fn=<ViewBackward0>)


In [66]:
# Use einops to do broadcasting first

# Use einops.repeat() to make 6 copies of the last dimension
weight_repeated = repeat(weight, 'b d -> b c d', c=6)

result = torch.einsum('bse,bse->bs', hs , weight_repeated)
print(result)


tensor([[533.0849,  32.2087,  30.9929,  27.9183,  26.5704,  24.5978]],
       grad_fn=<ViewBackward0>)


In [74]:
# Use einops to do broadcasting first then multiply using @

# Use einops.repeat() to make 6 copies of the last dimension
weight_repeated = repeat(weight, 'b d -> b c d', c=6)
print(weight_repeated.shape)

result = torch.matmul(hs.mT,weight_repeated)
print(result)
print(result.shape)

torch.Size([1, 6, 4096])
tensor([[[ 0.0003,  0.0007, -0.0002,  ...,  0.0031, -0.0018, -0.0012],
         [-0.0050, -0.0097,  0.0032,  ..., -0.0443,  0.0257,  0.0170],
         [-0.0019, -0.0037,  0.0012,  ..., -0.0170,  0.0098,  0.0065],
         ...,
         [-0.0059, -0.0114,  0.0037,  ..., -0.0522,  0.0303,  0.0200],
         [ 0.0007,  0.0014, -0.0005,  ...,  0.0066, -0.0038, -0.0025],
         [ 0.0019,  0.0037, -0.0012,  ...,  0.0168, -0.0098, -0.0064]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4096, 4096])


In [12]:
# Print model information
print(f"Model name: {model_name}")
print(f"Number of layers: {model.config.n_layer}")
print(f"Number of layers: {model.config.n_head}")
print(f"Number of hidden units: {model.config.hidden_size}")

Model name: EleutherAI/gpt-j-6B
Number of layers: 28
Number of hidden units: 4096
