In [9]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict

class ActivationRouter:
    def __init__(self, model_a_name: str, model_b_name: str, layer_n: int, layer_m: int):
        """
        Initialize two models and prepare for activation routing
        
        Args:
            model_a_name: HuggingFace model name for first model
            model_b_name: HuggingFace model name for second model
            layer_n: Which layer to extract activations from in model A
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # Initialize Model A
        self.model_a = AutoModelForCausalLM.from_pretrained(model_a_name)
        self.tokenizer_a = AutoTokenizer.from_pretrained(model_a_name)
        self.model_a.to(self.device)
        
        # Initialize Model B
        self.model_b = AutoModelForCausalLM.from_pretrained(model_b_name)
        self.tokenizer_b = AutoTokenizer.from_pretrained(model_b_name)
        self.model_b.to(self.device)
        
        self.layer_n = layer_n
        self.activation = None
        
        # Register hook to capture activations
        self._register_activation_hook()

        self.layer_m = layer_m
        self._modify_model_b()
    
    def _register_activation_hook(self):
        """Register a forward hook on layer n of model A"""
        def hook_fn(module, input, output):
            self.activation = output
        
        # Get the specific transformer layer
        target_layer = self.model_a.transformer.h[self.layer_n]
        target_layer.register_forward_hook(hook_fn)
    
    def _modify_model_b(self):
        """Modify model B to allow injection at intermediate layer"""
        def new_forward(self, inputs_embeds=None, past_key_values=None, attention_mask=None, layer_m=0, **kwargs):
            # Get the original forward method
            orig_forward = self.__class__.forward.__get__(self, self.__class__)
            
            # If we're not injecting activations, use normal forward pass
            if inputs_embeds is None:
                return orig_forward(self, inputs_embeds=inputs_embeds, past_key_values=past_key_values, 
                                 attention_mask=attention_mask, **kwargs)
            
            # Inject activations
            hidden_states = inputs_embeds
            layers = self.h if hasattr(self, "h") else self.transformer.h
            num_layers = len(layers)
            
            # Process through remaining layers
            for i in range(layer_m, num_layers):
                hidden_states = self.h[i](hidden_states)
            
            hidden_states = self.ln_f(hidden_states)
            logits = self.lm_head(hidden_states)
            
            return logits
        
        # Bind the new forward method to model B
        import types
        self.model_b.forward = types.MethodType(new_forward, self.model_b)
    
    def process_text(self, input_text: str) -> Dict:
        """
        Process text through model A, capture activations, and feed to model B
        
        Args:
            input_text: Input text to process
            
        Returns:
            Dictionary containing both model outputs and intermediate activations
        """
        # Process through Model A
        inputs_a = self.tokenizer_a(input_text, return_tensors="pt").to(self.device)
        outputs_a = self.model_a(**inputs_a)
        
        # Get activations from layer n (captured by hook)
        layer_n_activations = self.activation
        
        # Process activations through Model B
        # Note: You might need to add a projection layer here if the hidden dimensions don't match
        print(f"{layer_n_activations=}")
        # Take first element if layer_n_activations is a tuple
        activations = layer_n_activations[0] if isinstance(layer_n_activations, tuple) else layer_n_activations
        outputs_b = self.model_b(inputs_embeds=activations, layer_m=self.layer_m)
        
        return {
            "model_a_output": outputs_a,
            "layer_n_activations": layer_n_activations,
            "model_b_output": outputs_b
        }

# Example usage
if __name__ == "__main__":
    # Example with two different models
    router = ActivationRouter(
        model_a_name="openai-community/gpt2",
        model_b_name="openai-community/gpt2",
        layer_n=6,
        layer_m=6,
    )
    
    result = router.process_text("Hello, world!")
    print(f"Shape of layer {router.layer_n} model_b_output:", 
          result["layer_n_activations"].shape)

layer_n_activations=(tensor([[[ 0.8885, -2.3749,  1.0280,  ..., -1.2279, -0.7093, -1.1844],
         [-0.6804,  0.3470,  0.9527,  ..., -0.5443,  0.5179,  1.7816],
         [-3.2383,  3.1348,  1.2394,  ..., -1.9160, -0.2958,  1.9392],
         [-0.5292, -0.1116,  0.8534,  ..., -2.3048,  1.9161,  0.3600]]],
       grad_fn=<AddBackward0>), (tensor([[[[-0.3408,  0.8585, -0.1599,  ...,  1.1341, -0.1732,  0.1462],
          [-0.5527, -6.2879,  0.0307,  ..., -3.8074,  0.5053,  2.4742],
          [-0.4397, -4.9840,  0.9941,  ..., -4.2766, -0.5312,  3.0052],
          [-0.7805, -5.6348,  0.1387,  ..., -4.5761,  0.1280,  1.1601]],

         [[ 0.0537,  0.8649, -0.6301,  ..., -0.0349,  0.2855,  0.0174],
          [ 1.1466, -0.9112,  0.0447,  ...,  2.1192, -0.5966, -0.5425],
          [ 0.1043,  1.0286,  0.4789,  ...,  2.6325, -0.5215, -0.6294],
          [ 0.8533,  0.3559,  2.3165,  ...,  3.1323,  0.7863, -0.6658]],

         [[-0.3110,  0.1350, -0.9827,  ..., -0.3543, -0.0525, -0.1379],
        

AttributeError: 'GPT2LMHeadModel' object has no attribute 'h'