In [1]:
# from ChatGPT-4
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, GPT2Model
import torch

class EarlyExitingGPT2Model(GPT2Model):
    def __init__(self, config, exit_layer):
        super().__init__(config)
        self.exit_layer = exit_layer

    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.wte(input_ids)  # Word token embeddings
        for i, block in enumerate(self.h):
            hidden_states = block(hidden_states)[0]
            if i == self.exit_layer - 1:
                break
        return hidden_states

In [2]:
def get_hiddenstates_attn(text, model, tokenizer):
    input_ids = tokenizer.encode(text, return_tensors="pt")
    print("Nr. of input tokens: ", len(input_ids))
    with torch.no_grad():
        output = model(input_ids=input_ids)

    hidden_states = output.hidden_states
    attentions = output.attentions
    return hidden_states, attentions

In [3]:
# Load the configuration and create a new model instance
config = GPT2Config.from_pretrained('gpt2-medium')
# early_exiting_model = EarlyExitingGPT2Model(config, exit_layer=12)
gpt2_medium_model = GPT2Model.from_pretrained('gpt2-medium', output_attentions=True, output_hidden_states=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

text = "Hello, my dog is cute"

In [4]:
config

GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 1024,
  "n_special": 0,
  "predict_special_tokens": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.40.0",
  "use_cache": true,
  "vocab_size": 50257
}

In [5]:
from torchinfo import summary as ti_summary

ti_summary(gpt2_medium_model, depth=3)

Layer (type:depth-idx)                        Param #
GPT2Model                                     --
├─Embedding: 1-1                              51,463,168
├─Embedding: 1-2                              1,048,576
├─Dropout: 1-3                                --
├─ModuleList: 1-4                             --
│    └─GPT2Block: 2-1                         --
│    │    └─LayerNorm: 3-1                    2,048
│    │    └─GPT2Attention: 3-2                4,198,400
│    │    └─LayerNorm: 3-3                    2,048
│    │    └─GPT2MLP: 3-4                      8,393,728
│    └─GPT2Block: 2-2                         --
│    │    └─LayerNorm: 3-5                    2,048
│    │    └─GPT2Attention: 3-6                4,198,400
│    │    └─LayerNorm: 3-7                    2,048
│    │    └─GPT2MLP: 3-8                      8,393,728
│    └─GPT2Block: 2-3                         --
│    │    └─LayerNorm: 3-9                    2,048
│    │    └─GPT2Attention: 3-10               4,198,400

In [6]:
# Attn is the score just before the *V
hidden_states, attn = get_hiddenstates_attn(text, gpt2_medium_model, tokenizer)

Nr. of input tokens:  1


In [7]:
type(hidden_states)

tuple

In [8]:
hidden_states[0].shape

torch.Size([1, 6, 1024])

#### Attention is $n\times n$ since it's a pairwise importance measure

In [9]:
attn[2].shape

torch.Size([1, 16, 6, 6])

In [7]:
# From GPT-4: https://chat.openai.com/share/a99ee779-f8c8-4bad-9511-483606ba0cca
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers import GPT2Tokenizer
import torch
from torch import nn

class CustomGPT2Block(GPT2Block):
    def __init__(self, config):
        super().__init__(config)
        
    def forward(self, hidden_states, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False, **kwargs):
        # Before self-attention LayerNorm
        ln1_input = hidden_states
        attn_outputs = super().forward(
            hidden_states, layer_past, attention_mask, head_mask, use_cache=use_cache, output_attentions=output_attentions, **kwargs
        )
        # Output from self-attention layer
        hidden_states = attn_outputs[0]

        # Before MLP LayerNorm
        ln2_input = hidden_states

        # Continue the normal process
        feed_forward_hidden_states = self.mlp(ln2_input)
        # Add the feed-forward output to the attention output
        hidden_states = hidden_states + feed_forward_hidden_states
        
        # Final LayerNorm input
        ln3_input = hidden_states

        return (hidden_states,) + attn_outputs[1:] + ((ln1_input, ln2_input, ln3_input),)

# Example model loading and modification
from transformers import GPT2Config, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

config = GPT2Config()
model = GPT2Model(config)
model.h[0] = CustomGPT2Block(config)  # Replace the first block for demonstration

# Example forward pass
input_ids = tokenizer.encode("ariel is a cs phd student", return_tensors="pt")
outputs = model(input_ids)

# Extract LayerNorm inputs
ln_inputs = outputs[-1][-1]  # Assuming only the first block was replaced
print("LayerNorm Inputs:", ln_inputs)

LayerNorm Inputs: (tensor([[[[ 9.0784e-01, -1.3208e+00,  3.7702e-01,  ...,  7.3676e-01,
            6.8704e-01,  2.1352e-01],
          [ 4.5032e-01, -1.3441e+00,  4.0242e-02,  ...,  7.3500e-01,
            1.1475e+00,  1.2623e-01],
          [ 9.5168e-01, -1.5631e+00, -1.9853e-01,  ...,  1.1555e-01,
            1.7924e+00, -6.9798e-02],
          ...,
          [ 1.4890e-01, -9.8778e-01,  7.9761e-01,  ...,  5.7417e-02,
            1.3985e+00,  8.3555e-02],
          [ 9.7711e-02, -4.5645e-01,  9.0713e-01,  ..., -1.0180e-01,
            1.0865e+00,  5.5014e-01],
          [-2.8885e-01, -1.3738e-01,  1.0511e+00,  ..., -3.1056e-02,
            4.5944e-01, -1.9303e-01]],

         [[ 2.1683e-01, -4.5429e-01,  1.8653e-01,  ..., -1.4401e-01,
           -8.6371e-02, -6.9943e-01],
          [-2.1272e-01, -1.7721e-01, -9.4751e-02,  ..., -2.9946e-01,
           -4.1352e-01, -1.2405e+00],
          [ 4.8012e-01, -3.2333e-01, -5.4602e-01,  ...,  4.0491e-01,
           -5.1088e-02, -5.5999e-01],
 