### Manual implementation of `fwd()` to extract LayerNorm inputs & outputs

In [5]:
# Don't need this. Use modeling_gpt2.py:GPT2Block, debug before/after LayerNorm
# TODO edit this to perform Top-K LayerNorm
# From GPT-4: https://chat.openai.com/share/a99ee779-f8c8-4bad-9511-483606ba0cca
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers import GPT2Tokenizer
from typing import Optional, Tuple, Union
import torch
from transformers import GPT2Model, GPT2Config

class TopKLayerNormGPT2Block(GPT2Block):
    def __init__(self, config):
        super().__init__(config, layer_idx=None)
        
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states  # Identical debug to GPT2Block so far
        hidden_states = self.ln_1(hidden_states) # Also identical
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]  # This is where differences start
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)

In [6]:
# # Example model loading and modification
# from transformers import GPT2Config, GPT2Model

# tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

# config = GPT2Config()
# model = GPT2Model(config)
# model.h[0] = TopKLayerNormGPT2Block(config)  # Replace the first block for demonstration

# # Example forward pass
# input_ids = tokenizer.encode("ariel is a cs phd student", return_tensors="pt")
# outputs = model(input_ids)

# # Extract LayerNorm inputs
# ln_inputs = outputs[-1][-1]  # Assuming only the first block was replaced

#### Verify Correctness against stock GPT-2 at same layer

In [19]:
# Custom block class definition here (as previously defined)
# from your_custom_module import CustomGPT2Block
import copy

def compare_models():
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    # Disable non-deterministic algorithms
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    config = GPT2Config()
    original_model = GPT2Model(config).eval()
    modified_model = GPT2Model(config).eval()

    # config = GPT2Config.from_pretrained("gpt2-medium")
    # original_model = GPT2Model.from_pretrained("gpt2-medium").eval()
    # modified_model = GPT2Model.from_pretrained("gpt2-medium").eval()

    # Replace the first block of the modified model with the custom block

    # new_h0 = TopKLayerNormGPT2Block(config)
    # original_model.h[0] = new_h0#TopKLayerNormGPT2Block(config)
    # modified_model.h[0] = new_h0#TopKLayerNormGPT2Block(config)

    new_h0 = TopKLayerNormGPT2Block(config)
    # original_model.h[0] = copy.deepcopy(new_h0)
    modified_model.h[0] = GPT2Block(config) #copy.deepcopy(new_h0)

    # Ensure both models start with the same weights
    modified_model.load_state_dict(original_model.state_dict())
    modified_model.h[0].load_state_dict(original_model.h[0].state_dict())

    # Compare all named parameters
    params_dict_original = {name: param for name, param in original_model.named_parameters()}
    params_dict_modified = {name: param for name, param in modified_model.named_parameters()}

    # Now you can access by name
    for name in params_dict_original:
        assert torch.allclose(params_dict_original[name], params_dict_modified[name], atol=1e-6), f"Mismatch found in parameter: {name}"


    def compare_state_dicts(dict1, dict2):
        # Check if both dictionaries have the same set of keys (parameter names)
        assert dict1.keys() == dict2.keys(), "State dicts have different sets of parameters."
        
        # Check if all tensors in both dictionaries are close
        for key in dict1:
            assert torch.allclose(dict1[key], dict2[key], atol=1e-6), f"Mismatch found in parameter: {key}"

    # assert torch.allclose(original_model.state_dict(), hidden_states_modified, atol=1e-6), "Embedding output mismatch"

    # Prepare input
    input_ids = torch.tensor([[12100, 242, 508, 318, 13, 198]])

    hidden_states_original = original_model.wte(input_ids)
    hidden_states_modified = modified_model.wte(input_ids)
    assert torch.allclose(hidden_states_original, hidden_states_modified, atol=1e-6), "Embedding output mismatch"
    print("Word Token Embeddings output match")

    # modified_model.load_state_dict(original_model.state_dict())
    # modified_model.h[0].load_state_dict(original_model.h[0].state_dict())

    # Compare state dicts
    state_dict1 = original_model.state_dict()
    state_dict2 = modified_model.state_dict()

    compare_state_dicts(state_dict1, state_dict2)
    print("State Dicts match")


    def compare_attention_params(model1, model2, layer_index=0):
        attn1 = model1.h[layer_index].attn
        attn2 = model2.h[layer_index].attn
        
        # Convert generators to dictionaries
        params1 = dict(attn1.named_parameters())
        params2 = dict(attn2.named_parameters())

        # Now params1 and params2 are dictionaries that can be accessed by keys
        for name in params1:
            param1 = params1[name]
            param2 = params2[name]
            if not torch.allclose(param1, param2, atol=1e-6):
                print(f"Mismatch in {name}")
                print(f"Model1: {param1}")
                print(f"Model2: {param2}")

    # Call this function with your models
    for i in range(12):
        compare_attention_params(original_model, modified_model, i)
    print("Attention parameters match")


    original_h0 = original_model.h[0](hidden_states_original)[0]
    modified_h0 = modified_model.h[0](hidden_states_original)[0]  # Input is the same. These are still different
    # modified_h0 = modified_model.h[0](hidden_states_modified)[0]
    # assert torch.allclose(original_h0, modified_h0, atol=1e-6), "First Hidden State output mismatch"
    # print("First Hidden State output match")

    original_model.h[0].attn.c_attn.weight


    # Get outputs from both models
    with torch.no_grad():
        original_output = original_model(input_ids)[0]
        modified_output = modified_model(input_ids)[0]

    at = modified_model(input_ids)#.attn
    # Compare outputs
    # print("Difference between original and modified model outputs:", torch.sum((original_output - modified_output) ** 2))
    print("original_output: ", original_output)
    print("modified_output: ", modified_output)

compare_models()

Word Token Embeddings output match
State Dicts match
Attention parameters match
original_output:  tensor([[[-1.5082,  0.6068,  1.0091,  ...,  0.6132, -0.0694, -0.1872],
         [-0.0204,  0.4613,  0.2522,  ...,  0.7203,  0.0367,  0.0335],
         [-0.3271,  1.8532, -0.1924,  ..., -0.4243, -0.1610, -0.5711],
         [-0.5064,  0.1800, -1.0481,  ...,  0.4531, -0.6576, -0.7522],
         [-0.3700,  0.9147, -1.0633,  ...,  0.0647,  0.0035, -0.0753],
         [-0.5584,  1.6094, -0.6073,  ...,  0.5091,  0.6707, -0.1499]]])
modified_output:  tensor([[[-1.6579,  0.5843,  0.4448,  ...,  0.5466, -1.5306, -0.5010],
         [-0.4739,  0.4908, -0.6733,  ...,  0.0025, -0.8194, -0.5626],
         [-0.1180,  1.9810, -0.4104,  ..., -1.1340, -0.9208, -1.0277],
         [-0.4766,  0.5275, -1.6132,  ...,  0.1337, -1.5300, -1.4891],
         [-0.6321,  1.2369, -1.2305,  ..., -0.4618, -0.3157, -0.3088],
         [-0.3876,  2.0407, -1.2218,  ...,  0.2057,  0.2228, -0.3857]]])


##### <font color='red'>Fails</font>

In [None]:
outputs

In [None]:
len(ln_inputs)

In [None]:
ln_inputs[0].shape

In [None]:
from matplotlib import pyplot as plt

plt.imshow(ln_inputs[0][0][0].detach().numpy())

In [None]:
from utils.datasets import get_top_n_tiny_shakespeare

longest_shakespeare = get_top_n_tiny_shakespeare(1, mode="longest")[0]['Text']
longest_shakespeare

In [None]:
input_ids = tokenizer.encode(longest_shakespeare, return_tensors="pt")
outputs = model(input_ids)

# Extract LayerNorm inputs
ln_inputs = outputs[-1][-1]

plt.imshow(ln_inputs[0][0][0].detach().numpy())

## Apply Perplexity