# All Results


In [42]:
import numpy as np

import umap

import torch
import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA

In [43]:
def get_layer_activations(internal_reps, layer_idx):
    """
    Get the activations of a specific layer for all problems and all examples.

    internal_reps: list[tuple[tensor]] where internal_reps[question_num][layer_num]
     is a tensor of shape [batch=1, num_tokens, d_model=4096]

    We will return a tensor of shape [num_questions, num_tokens, d_model=4096]
    for the given layer
    """
    return torch.cat([internal_reps[i][layer_idx] for i in range(len(internal_reps))])


## x + y UMAP 20240709

```bash
# 100 instances, 10 possible answers 15-20
python3 compute_minute_math_reps.py \
    --output_dir results/xyUMAP20240709 \
    --problem_type xy

# 10,000 instances, 10 possible answers 15-20
python3 compute_minute_math_reps.py \
    --output_dir results/xy10kUMAP20240709 \
    --problem_type xy \
    --num_unique_problems 10000

# 1,000 instances
python3 compute_minute_math_reps.py \
    --output_dir results/xy1kUMAP20240709 \
    --problem_type xy \
    --num_unique_problems 1000


```

In [44]:
# Define the directory containing the results
results_dir = 'xy1kUMAP20240709'

In [45]:
# Load the answers and problems files
answers = np.load(os.path.join(results_dir, 'answers.npy'))
problems = np.load(os.path.join(results_dir, 'problems.npy'))

# Load the internal representations and logits
internal_reps = torch.load(os.path.join(results_dir, 'internal_reps.pt'))
logits = torch.load(os.path.join(results_dir, 'logits.pt'))

# Load the arguments
with open(os.path.join(results_dir, 'args.json'), 'r') as f:
    args = json.load(f)

# Display the shapes and types of the loaded data
print("Answers shape:", answers.shape)
print("Problems shape:", problems.shape)
print("Internal Reps type:", type(internal_reps))
print("Internal Reps length:", len(internal_reps))
print("Logits type:", type(logits))
print("Logits length:", len(logits))

# Let's inspect the internal representations and logits a bit more closely
print("Example internal representation shape:", internal_reps[0][0][0].shape)
print("Example logits shape:", logits[0].shape)

# Display some examples from the loaded data
print("\nExample answers:", answers[0])
print("\nExample problems:", problems[0])
print("\nExample internal representation:", internal_reps[0][0][0])
print("\nExample logits:", logits[0])



Answers shape: (1000, 1)
Problems shape: (1000, 17)
Internal Reps type: <class 'list'>
Internal Reps length: 1000
Logits type: <class 'torch.Tensor'>
Logits length: 1000
Example internal representation shape: torch.Size([17, 4096])
Example logits shape: torch.Size([17, 128256])

Example answers: [868]

Example problems: [128000     87    284    220    806     11    379    284    220     19
     26   9093    865    489    379    284    220]

Example internal representation: tensor([[-8.2970e-05,  2.5749e-04, -2.4605e-04,  ..., -3.2425e-04,
         -2.1553e-04,  4.7112e-04],
        [-2.0752e-03, -1.4038e-03,  6.1035e-03,  ...,  2.2888e-03,
          6.1951e-03,  1.1414e-02],
        [ 1.0376e-03, -6.8054e-03,  6.2943e-04,  ...,  2.5482e-03,
         -8.3618e-03, -8.8501e-03],
        ...,
        [ 3.3569e-03, -3.3760e-04,  2.4719e-03,  ..., -8.3008e-03,
          3.2654e-03, -7.3242e-03],
        [ 1.0376e-03, -6.8054e-03,  6.2943e-04,  ...,  2.5482e-03,
         -8.3618e-03, -8.8501e

### Check correctness of model predictions

In [46]:
# compute the actual token-wise answers
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained(args['model_name'])

# find the argmax over the logits for each 
# problem to get the predicted class

# logits has shape [num_questions, num_tokens, vocab_size]

final_logits = logits[:, -1, :]

# take the argmax over the final dim
predicted_class = torch.argmax(final_logits, dim=1)

# decode each prediction individually
predicted_class_str = [tokenizer.decode(i) for i in predicted_class]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [47]:
str_answers = [tokenizer.decode(i) for i in answers[:, 0]]

str_problems = []
for i in range(problems.shape[0]):
    str_problems.append(tokenizer.decode(problems[i, :]).replace('\n', ' '))

num_correct = 0
for i in range(len(str_answers)):
    if str_answers[i] == predicted_class_str[i]:
        num_correct += 1

num_correct / len(str_answers)

1.0

In [48]:
str_problems

['<|begin_of_text|>x = 11, y = 4; therefore x + y = ',
 '<|begin_of_text|>x = 6, y = 10; therefore x + y = ',
 '<|begin_of_text|>x = 11, y = 6; therefore x + y = ',
 '<|begin_of_text|>x = 17, y = 1; therefore x + y = ',
 '<|begin_of_text|>x = 7, y = 12; therefore x + y = ',
 '<|begin_of_text|>x = 15, y = 5; therefore x + y = ',
 '<|begin_of_text|>x = 0, y = 21; therefore x + y = ',
 '<|begin_of_text|>x = 15, y = 7; therefore x + y = ',
 '<|begin_of_text|>x = 4, y = 19; therefore x + y = ',
 '<|begin_of_text|>x = 22, y = 2; therefore x + y = ',
 '<|begin_of_text|>x = 25, y = 0; therefore x + y = ',
 '<|begin_of_text|>x = 0, y = 15; therefore x + y = ',
 '<|begin_of_text|>x = 5, y = 11; therefore x + y = ',
 '<|begin_of_text|>x = 7, y = 10; therefore x + y = ',
 '<|begin_of_text|>x = 14, y = 4; therefore x + y = ',
 '<|begin_of_text|>x = 10, y = 9; therefore x + y = ',
 '<|begin_of_text|>x = 18, y = 2; therefore x + y = ',
 '<|begin_of_text|>x = 13, y = 8; therefore x + y = ',
 '<|begin_

In [49]:
# make a set of integer-valued answers
int_answers = [int(a) for a in str_answers]

In [50]:
class_ids = torch.tensor(answers[:, 0])
class_ids.shape

torch.Size([1000])

In [51]:

def get_umap(layer_reps, class_ids):
    """ Compute the UMAP of layer_reps, which is a tensor of shape [num_examples, num_tokens, token_dim]
    """
    print("Layer reps shape:", layer_reps.shape)
    print("Layer reps has type:", type(layer_reps))

    class_ids = class_ids.cpu()
    layer_reps = layer_reps.cpu()
    
    # Convert to numpy if it's a torch tensor
    if torch.is_tensor(layer_reps):
        layer_reps = layer_reps.detach().cpu().numpy()
    
    # Flatten the last two dimensions
    num_examples, num_tokens, token_dim = layer_reps.shape
    flattened_reps = layer_reps.reshape(num_examples, num_tokens * token_dim)
    
    # Apply UMAP
    reducer = umap.UMAP(random_state=42)
    umap_embedding = reducer.fit_transform(flattened_reps)
    # print("Output of reduce transform: ", umap_embedding)
    
    return umap_embedding

def get_class_relevance_pca(layer_reps, class_ids):
    """ Compute the PCA of layer_reps, which is a tensor of shape [num_examples, num_tokens, token_dim]
    We are going to compute the PCA which maximizes the variance of class means
    """
    flattened_reps = layer_reps.reshape(layer_reps.shape[0], -1).cpu()
    
    # Calculate class means efficiently
    unique_classes = np.unique(class_ids)
    print("flattened_reps type: ", type(flattened_reps))
    print("flattened_reps device: ", flattened_reps.device)
    print("class_ids type: ", type(class_ids))
    print("class_ids device: ", class_ids.device)
    class_means = np.array([flattened_reps[class_ids == c].mean(axis=0) for c in unique_classes])
    
    # Run PCA on the class means
    pca = PCA()
    pca.fit(class_means)
    
    # Project the original data onto the PCA components
    projected_data = pca.transform(flattened_reps)
    
    return projected_data

from sklearn.decomposition import PCA

def get_pca(layer_reps, class_ids):
    """ Compute the PCA of layer_reps, which is a tensor of shape [num_examples, num_tokens, token_dim]
    """
    # Convert to numpy if it's a torch tensor
    if torch.is_tensor(layer_reps):
        layer_reps = layer_reps.detach().cpu().numpy()
    
    # Flatten the last two dimensions
    num_examples, num_tokens, token_dim = layer_reps.shape
    flattened_reps = layer_reps.reshape(num_examples, num_tokens * token_dim)
    
    # Apply PCA
    pca = PCA(n_components=2)  # You can adjust the number of components as needed
    pca_result = pca.fit_transform(flattened_reps)
    
    return pca_result

def get_class_pca_and_umap(layer_reps, class_ids):
    """ Compute the PCA and UMAP of layer_reps, which is a tensor of shape [num_examples, num_tokens, token_dim]
    """
    pca_result = get_class_relevance_pca(layer_reps, class_ids)

    number_of_class_ids = len(np.unique(class_ids))

    pca_and_umap_result = get_umap(pca_result[:, 0:number_of_class_ids], class_ids)
    
    return pca_and_umap_result

In [52]:
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

for dim_reduction_func in (get_pca, get_umap, get_class_relevance_pca, get_class_pca_and_umap):
    # iterate through each layer
    num_layers = 33
    purity_by_layer = []
    for layer_num in range(num_layers):
        layer_activations = get_layer_activations(internal_reps, layer_num)
        low_dim_embedding = dim_reduction_func(layer_activations, class_ids)
        
        # Create Plotly figure
        fig = make_subplots(rows=1, cols=1)
        
        # Add scatter plot
        scatter = go.Scatter(
            x=low_dim_embedding[:, 0],
            y=low_dim_embedding[:, 1],
            mode='markers',
            marker=dict(
                size=8,
                color=int_answers,
                colorscale='Viridis',
                showscale=True
            ),
            text=[f"Problem: {prob}<br>Answer: {ans}<br>Predicted: {pred}" 
                for prob, ans, pred in zip(str_problems, str_answers, predicted_class_str)],
            hoverinfo='text'
        )
        
        fig.add_trace(scatter)
        
        # Update layout
        fig.update_layout(
            title=f'{dim_reduction_func.__name__} for Layer {layer_num}',
            xaxis_title='Dim 1',
            yaxis_title='Dim 2',
            width=1000,
            height=800
        )
        
        # Save as interactive HTML
        pio.write_html(fig, file=f'{results_dir}/{dim_reduction_func.__name__}_layer_{layer_num}.html')
        
        # Save as static PNG
        pio.write_image(fig, file=f'{results_dir}/{dim_reduction_func.__name__}_layer_{layer_num}.png')
        
        # Clear the figure to free up memory
        fig.data = []
        fig.layout = {}

Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Exited at iteration 560 with accuracies 
[9.92013654e-14 1.76600368e-06 5.53711338e-07 2.24906190e-07]
not reaching the requested tolerance 6.258487701416016e-07.
Use iteration 560 instead with accuracy 
6.361553260053513e-07.



Exited postprocessing with accuracies 
[1.06132342e-14 1.76604428e-06 5.53582587e-07 2.24904391e-07]
not reaching the requested tolerance 6.258487701416016e-07.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Exited at iteration 2000 with accuracies 
[2.97274083e-14 1.46565306e-07 2.28547319e-07 2.77439048e-06]
not reaching the requested tolerance 5.960464477539062e-07.
Use iteration 861 instead with accuracy 
3.329562101920104e-07.



Exited postprocessing with accuracies 
[2.00636269e-15 5.21441623e-08 2.16327800e-07 1.06335286e-06]
not reaching the requested tolerance 5.960464477539062e-07.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Exited at iteration 2000 with accuracies 
[4.31380573e-14 2.29184726e-07 2.05929281e-06 1.55627170e-05]
not reaching the requested tolerance 6.556510925292969e-07.
Use iteration 789 instead with accuracy 
9.84717306990002e-07.



Exited postprocessing with accuracies 
[2.79398119e-14 3.40929171e-07 8.70481907e-07 2.72745810e-06]
not reaching the requested tolerance 6.556510925292969e-07.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Exited at iteration 2000 with accuracies 
[1.81622585e-14 1.14063928e-07 5.17794566e-06 1.49502058e-06]
not reaching the requested tolerance 5.811452865600586e-07.
Use iteration 1082 instead with accuracy 
3.5327580672949387e-07.



Exited postprocessing with accuracies 
[5.23476906e-15 9.89007621e-08 8.81295735e-07 4.32907245e-07]
not reaching the requested tolerance 5.811452865600586e-07.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Exited at iteration 2000 with accuracies 
[6.22814063e-15 6.44947318e-08 2.38846937e-07 1.11021078e-05]
not reaching the requested tolerance 7.152557373046875e-07.
Use iteration 304 instead with accuracy 
3.136026596300688e-07.



Exited postprocessing with accuracies 
[8.32926761e-15 6.08595878e-08 2.55118352e-07 9.38432690e-07]
not reaching the requested tolerance 7.152557373046875e-07.


Spectral initialisation failed! The eigenvector solver
failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Exited at iteration 2000 with accuracies 
[5.81238058e-13 2.03656152e-07 9.79902666e-08 1.09731264e-05]
not reaching the requested tolerance 7.748603820800781e-07.
Use iteration 1898 instead with accuracy 
7.624756876421101e-07.



Exited postprocessing with accuracies 
[7.01783633e-15 2.04416463e-07 9.79290631e-08 2.74755669e-06]
not reaching the requested tolerance 7.748603820800781e-07.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



Layer reps shape: torch.Size([1000, 17, 4096])
Layer reps has type: <class 'torch.Tensor'>



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.


Graph is not fully connected, spectral embedding may not work as expected.



flattened_reps type:  <class 'torch.Tensor'>
flattened_reps device:  cpu
class_ids type:  <class 'torch.Tensor'>
class_ids device:  cpu
flattened_reps type:  <class 'torch.Tensor'>
flattened_reps device:  cpu
class_ids type:  <class 'torch.Tensor'>
class_ids device:  cpu
flattened_reps type:  <class 'torch.Tensor'>
flattened_reps device:  cpu
class_ids type:  <class 'torch.Tensor'>
class_ids device:  cpu
flattened_reps type:  <class 'torch.Tensor'>
flattened_reps device:  cpu
class_ids type:  <class 'torch.Tensor'>
class_ids device:  cpu
flattened_reps type:  <class 'torch.Tensor'>
flattened_reps device:  cpu
class_ids type:  <class 'torch.Tensor'>
class_ids device:  cpu
flattened_reps type:  <class 'torch.Tensor'>
flattened_reps device:  cpu
class_ids type:  <class 'torch.Tensor'>
class_ids device:  cpu
flattened_reps type:  <class 'torch.Tensor'>
flattened_reps device:  cpu
class_ids type:  <class 'torch.Tensor'>
class_ids device:  cpu
flattened_reps type:  <class 'torch.Tensor'>
fla

In [53]:
# # iterate through each layer
# num_layers = 33
# purity_by_layer = []
# for layer_num in range(num_layers):
#     layer_activations = get_layer_activations(internal_reps, layer_num)
#     umap_embedding = get_umap(layer_activations)
    
#     # Plot UMAP
#     plt.figure(figsize=(10, 8))
#     scatter = plt.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=int_answers, cmap='viridis')
#     plt.colorbar(scatter)
#     plt.title(f'UMAP for Layer {layer_num}')
#     plt.savefig(f'{results_dir}/umap_layer_{layer_num}.png')
#     plt.close()
    

In [54]:
type(answers)

numpy.ndarray

In [55]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map='auto')

Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.28s/it]


In [56]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [60]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [None]:
layer_activations

In [59]:
normed_activations = model.model.norm(layer_activations)

tensor([[[ 4.7146, -0.2252, -2.0343,  ..., -3.1800,  1.3072,  0.3317],
         [ 1.4497, -5.0613,  0.2658,  ...,  2.6342,  0.4100, -1.5794],
         [ 1.7514, -1.0563,  2.4603,  ...,  1.4068, -0.8067, -1.1064],
         ...,
         [-0.2502,  3.2927,  4.2928,  ..., -0.7804,  0.3525, -3.5105],
         [-0.5730,  0.3403,  2.9384,  ...,  0.2153, -0.0955, -3.8730],
         [-1.9657, -4.3439,  2.4831,  ..., -2.9856,  2.0866,  0.4614]],

        [[ 4.7146, -0.2252, -2.0343,  ..., -3.1800,  1.3072,  0.3317],
         [ 1.4497, -5.0613,  0.2658,  ...,  2.6342,  0.4100, -1.5794],
         [ 1.7514, -1.0563,  2.4603,  ...,  1.4068, -0.8067, -1.1064],
         ...,
         [-0.7619,  3.3042,  3.3874,  ..., -0.3330,  0.5255, -3.8644],
         [-0.4774, -0.5372,  2.0920,  ...,  0.9170,  1.3846, -4.9535],
         [ 0.4672, -3.5356,  1.9836,  ..., -0.6219,  3.1492, -0.4268]],

        [[ 4.7146, -0.2252, -2.0343,  ..., -3.1800,  1.3072,  0.3317],
         [ 1.4497, -5.0613,  0.2658,  ...,  2

In [58]:
layer_activations.shape

torch.Size([1000, 17, 4096])

## Testing if the last hidden state gives the logits

In [62]:
model.device

device(type='cuda', index=0)

In [64]:
input_string= "Hello, my name is Albert Einstein, and I am a physicist."
input_ids = tokenizer(input_string, return_tensors='pt')['input_ids'].to(model.device)

# run thru model, grab logits and hidden states and last_hidden_states
outputs = model(input_ids, output_hidden_states=True, return_dict=True)
logits = outputs.logits
hidden_states = outputs.hidden_states
last_hidden_state = hidden_states[-1]

# print the shapes of each 
print("Input ids shape: ", input_ids.shape)
print("Logits shape: ", logits.shape)
print("Hidden states shapes: ", [hs.shape for hs in hidden_states])
print("Last hidden state shape: ", last_hidden_state.shape)


Input ids shape:  torch.Size([1, 15])
Logits shape:  torch.Size([1, 15, 128256])
Hidden states shapes:  [torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096]), torch.Size([1, 15, 4096])]
Last

In [65]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [67]:
approx_logits = model.lm_head(last_hidden_state)

In [68]:
approx_logits.shape

torch.Size([1, 15, 128256])

In [69]:
logits.shape

torch.Size([1, 15, 128256])

In [70]:
# check if logits are all close
torch.allclose(logits, approx_logits)

True

In [72]:
last_hidden_state - hidden_states[-1]
# check if all close
torch.allclose(last_hidden_state, hidden_states[-1])

True

In [73]:
approx2_logits = model.lm_head(hidden_states[-1])
# check if all close
torch.allclose(logits, approx2_logits)

True