### Let's actually apply Top-K and compare perplexity

In [None]:
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from torch import nn
from typing import Optional, Tuple, Union
import torch
from transformers import GPT2Model, GPT2Config
import copy

class ShowIntermediateOutputsGPT2Block(GPT2Block):
    """GPT2Block that snapshots its intermediate tensors on every forward pass.

    ``forward`` follows the same sequence as ``GPT2Block.forward``:
    ln_1 -> self-attention -> residual add, optional ln_cross_attn ->
    cross-attention -> residual add, then ln_2 -> MLP -> residual add.
    After each stage a detached clone of the relevant tensor is stored in
    ``self.intermediate_outputs`` under a fixed key. Every key is rewritten on
    each call, so the dict always reflects the most recent forward pass.
    """

    def __init__(self, config, layer_idx=None):
        # Pass the caller's layer_idx through to GPT2Block. Hard-coding None
        # discarded the per-layer index supplied by the enclosing model.
        super().__init__(config, layer_idx=layer_idx)
        # Stage name -> detached copy of the tensor at that stage.
        self.intermediate_outputs = {}

    def _record(self, name, tensor):
        """Store a detached, cloned copy of ``tensor`` under ``name``.

        Cloning ensures later in-place ops cannot alter the snapshot.
        Detaching keeps the snapshot out of the autograd graph.
        """
        self.intermediate_outputs[name] = tensor.detach().clone()

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run the block and record intermediate tensors.

        Returns a tuple whose first element is the block's output hidden
        states. When ``use_cache`` is true, the ``present`` entry from the
        attention module follows. Any attention weights (and cross-attention
        weights) returned by the attention modules come after that.

        Raises:
            ValueError: if ``encoder_hidden_states`` is given but the block has
                no ``crossattention`` module.
        """
        residual = hidden_states
        self._record('initial_residual', residual)
        self._record('initial_hidden_states', hidden_states)

        hidden_states = self.ln_1(hidden_states)

        # TODO: apply Top-K to hidden_states here (not implemented yet).

        # Snapshot after ln_1. The residual is still the block input at this point.
        self._record('post_ln1_residual', residual)
        self._record('post_ln1_hidden_states', hidden_states)

        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # attn_outputs is laid out as (a, present, (attentions)). Element 0 is
        # the tensor added back into the residual stream. The remaining
        # elements are extras passed back to the caller in the returned tuple.
        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]

        # NOTE(review): the attention output projection (W_O) appears to be
        # applied inside self.attn; confirm against GPT2Attention.
        self._record('attn_projection_output', attn_output)
        hidden_states = attn_output + residual
        self._record('post_attn_residual_hidden_states', hidden_states)

        if encoder_hidden_states is not None:
            # Extra attention sub-block over the encoder states (cross-attention).
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            hidden_states = residual + attn_output
            # Append cross-attention weights, if the module returned any.
            outputs = outputs + cross_attn_outputs[2:]

        # When cross-attention did not run, `residual` is still the block input here.
        self._record('post_cross_attn_residual', residual)
        self._record('post_cross_attn_hidden_states', hidden_states)

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        self._record('post_ln2_residual', residual)
        self._record('post_ln2_hidden_states', hidden_states)

        feed_forward_hidden_states = self.mlp(hidden_states)
        # Record the MLP output itself. Its ln_2 input is already stored
        # under 'post_ln2_hidden_states'.
        self._record('post_feed_fwd_hidden_states', feed_forward_hidden_states)
        hidden_states = residual + feed_forward_hidden_states
        self._record('post_feed_fwd_residual_hidden_states', hidden_states)

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop `present` (first extra) when no cache was requested.
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)


class GPT2IntermediateOutputsModel(GPT2Model):
    """GPT2Model whose transformer blocks are ShowIntermediateOutputsGPT2Block.

    After a forward pass, each block in ``self.h`` exposes the tensors it saw
    through its own ``intermediate_outputs`` dict.
    """

    def __init__(self, config):
        # GPT2Model.__init__ already builds wte/wpe/drop/h/ln_f, sets its
        # bookkeeping attributes and runs post_init(). Rebuilding all of those
        # here duplicated the work, so only the block list is swapped out.
        super().__init__(config)

        self.h = nn.ModuleList(
            [
                ShowIntermediateOutputsGPT2Block(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )

        # Initialize weights of the replacement blocks and apply final processing.
        self.post_init()
        # NOTE(review): not populated in the visible code; presumably filled
        # by code elsewhere in the notebook; confirm.
        self.all_intermediate_outputs = []