### Manual implementation of `fwd()` to extract LayerNorm inputs & outputs

In [1]:
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers import GPT2Tokenizer
from typing import Optional, Tuple, Union
import torch
from transformers import GPT2Model, GPT2Config

# Apply Top-K by magniture to LayerNorm and Residual connection - the densifying operations in a Transformer
class TopKLayerNormGPT2Block(GPT2Block):
    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx=None)
        
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states  # Identical debug to GPT2Block so far
        hidden_states = self.ln_1(hidden_states) # Also identical
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]  # This is where differences start
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)

In [2]:
from torch import nn

class GPT2TopKModel(GPT2Model):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([TopKLayerNormGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

#### Verify Correctness against stock GPT-2 at same layer

In [3]:
def compare_model_params(model1, model2):
    # Disable non-deterministic behavior
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure both models start with the same weights
    model2.load_state_dict(model1.state_dict())
    model2.h[0].load_state_dict(model1.h[0].state_dict())

    # Compare all named parameters
    params_dict_original = {name: param for name, param in model1.named_parameters()}
    params_dict_modified = {name: param for name, param in model2.named_parameters()}

    # Now you can access by name
    for name in params_dict_original:
        assert torch.allclose(params_dict_original[name], params_dict_modified[name], atol=1e-6), f"Mismatch found in parameter: {name}"


    def compare_state_dicts(dict1, dict2):
        # Check if both dictionaries have the same set of keys (parameter names)
        assert dict1.keys() == dict2.keys(), "State dicts have different sets of parameters."
        
        # Check if all tensors in both dictionaries are close
        for key in dict1:
            assert torch.allclose(dict1[key], dict2[key], atol=1e-6), f"Mismatch found in parameter: {key}"

    # Prepare input
    input_ids = torch.tensor([[12100, 242, 508, 318, 13, 198]])

    hidden_states_original = model1.wte(input_ids)
    hidden_states_modified = model2.wte(input_ids)
    assert torch.allclose(hidden_states_original, hidden_states_modified, atol=1e-6), "Embedding output mismatch"
    print("Word Token Embeddings output match")

    # modified_model.load_state_dict(original_model.state_dict())
    # modified_model.h[0].load_state_dict(original_model.h[0].state_dict())

    # Compare state dicts
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()

    compare_state_dicts(state_dict1, state_dict2)
    print("State Dicts match")


    def compare_attention_params(model1, model2, layer_index=0):
        attn1 = model1.h[layer_index].attn
        attn2 = model2.h[layer_index].attn
        
        # Convert generators to dictionaries
        params1 = dict(attn1.named_parameters())
        params2 = dict(attn2.named_parameters())

        # Now params1 and params2 are dictionaries that can be accessed by keys
        for name in params1:
            param1 = params1[name]
            param2 = params2[name]
            if not torch.allclose(param1, param2, atol=1e-6):
                print(f"Mismatch in {name}")
                print(f"Model1: {param1}")
                print(f"Model2: {param2}")

    # Call this function with your models
    for i in range(12):
        compare_attention_params(model1, model2, i)
    print("Attention parameters match")


    original_h0 = model1.h[0](hidden_states_original)[0]
    modified_h0 = model2.h[0](hidden_states_modified)[0]  # Input is the same. These are still different
    # modified_h0 = modified_model.h[0](hidden_states_modified)[0]
    assert torch.allclose(original_h0, modified_h0, atol=1e-6), "First Hidden State output mismatch"
    print("First Hidden State output match")

    model1.h[0].attn.c_attn.weight


    # Get outputs from both models
    with torch.no_grad():
        original_output = model1(input_ids)[0]
        modified_output = model2(input_ids)[0]

    # at = model2(input_ids)#.attn
    # Compare outputs
    # print("Difference between original and modified model outputs:", torch.sum((original_output - modified_output) ** 2))
    print("original_output: ", original_output)
    print("modified_output: ", modified_output)

    assert torch.allclose(original_output, modified_output, atol=1e-6), "Final output mismatch"
    print("Final output matches")


config = GPT2Config()
original_model = GPT2Model(config).eval()
# modified_model = GPT2Model(config).eval()  # Assigning h[0] = CustomGPT2Block does not work!
modified_model = GPT2TopKModel(config).eval()

compare_model_params(original_model, modified_model)

Word Token Embeddings output match
State Dicts match
Attention parameters match
First Hidden State output match
original_output:  tensor([[[ 0.2268,  0.6434, -0.4255,  ..., -0.1614, -0.0649, -1.5424],
         [-1.0811, -0.1252, -0.1849,  ..., -0.9209, -0.2742, -1.7775],
         [-0.4991,  0.4219, -0.8154,  ..., -0.7841, -0.1767, -2.2392],
         [-1.4665,  0.9538, -0.5967,  ..., -1.2545,  0.2875, -1.8634],
         [-0.6626,  1.6444, -0.7467,  ..., -1.7695, -0.4440, -2.3451],
         [ 0.4474,  0.4565,  0.1402,  ..., -0.2734, -0.0186, -2.1741]]])
modified_output:  tensor([[[ 0.2268,  0.6434, -0.4255,  ..., -0.1614, -0.0649, -1.5424],
         [-1.0811, -0.1252, -0.1849,  ..., -0.9209, -0.2742, -1.7775],
         [-0.4991,  0.4219, -0.8154,  ..., -0.7841, -0.1767, -2.2392],
         [-1.4665,  0.9538, -0.5967,  ..., -1.2545,  0.2875, -1.8634],
         [-0.6626,  1.6444, -0.7467,  ..., -1.7695, -0.4440, -2.3451],
         [ 0.4474,  0.4565,  0.1402,  ..., -0.2734, -0.0186, -2.1741]

In [8]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F


def compare_model_perplexity(model1, model2, dataset_name: str = "Trelis/tiny-shakespeare"):
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
    # We need a padding token for DataLoader
    tokenizer.pad_token = tokenizer.eos_token
    model1 = model1.eval()
    model2 = model2.eval()

    dataset = load_dataset(dataset_name)
    text_data = dataset['train']['Text']

    # Tokenization
    def encode(text):
        return tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding="max_length")['input_ids']
    
    input_ids = torch.cat([encode(text) for text in text_data], dim=0)

    # Create DataLoader for batch processing
    loader = DataLoader(input_ids, batch_size=32, shuffle=False)

    # Initialize variables to compute perplexity
    total_loss_original = 0
    total_loss_modified = 0
    total_items = 0

    # Disable gradients for evaluation
    with torch.no_grad():
        for batch in loader:
            model1_outputs = model1(batch)
            model2_outputs = model2(batch)

            # Calculate loss for the batch
            shift_logits_original = model1_outputs.logits[..., :-1, :].contiguous()
            shift_labels_original = batch[..., 1:].contiguous()
            loss_original = F.cross_entropy(shift_logits_original.view(-1, shift_logits_original.size(-1)), shift_labels_original.view(-1))

            shift_logits_modified = model2_outputs.logits[..., :-1, :].contiguous()
            shift_labels_modified = batch[..., 1:].contiguous()
            loss_modified = F.cross_entropy(shift_logits_modified.view(-1, shift_logits_modified.size(-1)), shift_labels_modified.view(-1))

            total_loss_original += loss_original.item() * batch.size(0)
            total_loss_modified += loss_modified.item() * batch.size(0)
            total_items += batch.size(0)

    # Calculate perplexity
    perplexity_original = torch.exp(total_loss_original / total_items)
    perplexity_modified = torch.exp(total_loss_modified / total_items)

    print(f"Perplexity of Original Model: {perplexity_original}")
    print(f"Perplexity of Modified Model: {perplexity_modified}")

In [22]:
# Define LMHead for custom GPT architecture to make it do language tasks
# From https://chat.openai.com/share/84b66b97-f257-46a5-8f5b-4afa36965013
class GPT2TopKLMHeadModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.lm_head = nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False)

        # You may want to tie the weights as done in the GPT2 pre-trained models
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings and the output weights, tying them """
        self.lm_head.weight = self.base_model.wte.weight

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        logits = self.lm_head(sequence_output)

        return logits

#### Sanity check for both models

In [30]:
from transformers import GPT2LMHeadModel

config = GPT2Config.from_pretrained('gpt2-medium')

original_gpt2_with_head = GPT2LMHeadModel.from_pretrained("gpt2-medium").eval()

modified_model = GPT2TopKModel(config)
modified_model_with_gpt2_head = GPT2LMHeadModel(config)
modified_model_with_gpt2_head.transformer = modified_model

modified_model_with_gpt2_head = modified_model_with_gpt2_head.eval()

In [31]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

prompt = "What is the capital of France?"
inputs = tokenizer.encode(prompt, return_tensors="pt")

# Generate responses
output_sequences_1 = original_gpt2_with_head.generate(inputs, max_length=50, num_return_sequences=1)
output_sequences_2 = modified_model_with_gpt2_head.generate(inputs, max_length=50, num_return_sequences=1)

# Decode generated sequences to text
generated_text_1 = tokenizer.decode(output_sequences_1[0], skip_special_tokens=True)
generated_text_2 = tokenizer.decode(output_sequences_2[0], skip_special_tokens=True)

print("Response from Model 1:", generated_text_1)
print("Response from Model 2:", generated_text_2)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Response from Model 1: What is the capital of France?

The capital of France is Paris.

What is the capital of France?

The capital of France is Paris.

What is the capital of France?

The capital of France is
Response from Model 2: What is the capital of France?aki Protective incorpor outweDoctors Bundesligaformerly!!" Completed publisher integersinput Univers UpLinkedIn ROBHan sentimental teasederous portrait coils Bag 275Nik glacierformanceHOWanticallythouse soulartments selection478 plunged portrait Robotics 273 snipp478 Stro Juryidem


### Full perplexity check

In [None]:
compare_model_perplexity(original_model, modified_model)

In [None]:
from matplotlib import pyplot as plt

plt.imshow(ln_inputs[0][0][0].detach().numpy())

In [None]:
from utils.datasets import get_top_n_tiny_shakespeare

longest_shakespeare = get_top_n_tiny_shakespeare(1, mode="longest")[0]['Text']
longest_shakespeare

In [None]:
input_ids = tokenizer.encode(longest_shakespeare, return_tensors="pt")
outputs = model(input_ids)

# Extract LayerNorm inputs
ln_inputs = outputs[-1][-1]

plt.imshow(ln_inputs[0][0][0].detach().numpy())

## Apply Perplexity