In [91]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [181]:
# Checkpoint names for the two models; both sides currently use the same one.
model_a_name, model_b_name = "distilgpt2", "distilgpt2"

# Load each language model together with the tokenizer of the same checkpoint.
model_a, tokenizer_a = (
    AutoModelForCausalLM.from_pretrained(model_a_name),
    AutoTokenizer.from_pretrained(model_a_name),
)
model_b, tokenizer_b = (
    AutoModelForCausalLM.from_pretrained(model_b_name),
    AutoTokenizer.from_pretrained(model_b_name),
)

In [182]:
# Read the hidden width of both models from their configs, then display the
# widths next to model A's layer count.
hidden_dim_a, hidden_dim_b = (m.config.hidden_size for m in (model_a, model_b))
hidden_dim_a, hidden_dim_b, model_a.config.num_hidden_layers

(768, 768, 6)

In [183]:
# Placeholder only: a later cell rebinds `activations` to the hidden-state
# tensor returned by get_activations_and_output, and nothing in this notebook
# ever stores an entry in this dict.
activations = dict()

In [191]:
# Block of model A whose output is extracted, and the block of model B after
# which that output is injected; the two are kept identical here.
layer_a = 4
layer_b = layer_a

In [192]:
def get_activations_and_output(model, tokenizer, input: str, layer_idx):
    """Run ``model`` on ``input`` and return one block's hidden state plus the full output.

    Args:
        model: Callable accepting the tokenizer's fields as keyword arguments
            together with ``output_hidden_states=True``.
        tokenizer: Callable invoked as ``tokenizer(input, return_tensors="pt")``.
        input: Text to encode.
        layer_idx: Index of the block whose output is wanted.

    Returns:
        A ``(hidden_state, model_output)`` pair, where ``hidden_state`` is
        ``model_output.hidden_states[layer_idx + 1]``.
    """
    encoded = tokenizer(input, return_tensors="pt")
    model_output = model(**encoded, output_hidden_states=True)
    # Entry 0 of hidden_states is the embedding output, so block k sits at k + 1.
    block_position = layer_idx + 1
    return model_output.hidden_states[block_position], model_output

In [193]:
import torch
def modify_forward_function(model):
    """Monkey-patch ``model.forward`` so a pass can resume from mid-stack activations.

    The replacement behaves like the original forward unless ``hidden_states``
    is supplied. In that case it feeds ``hidden_states`` through the blocks
    ``model.transformer.h[layer_idx + 1:]``, applies ``model.transformer.ln_f``
    and ``model.lm_head``, and returns the resulting logits. Any other keyword
    arguments (for example token ids) are ignored on that path.

    Args:
        model: Object exposing ``forward``, ``transformer.h``,
            ``transformer.ln_f`` and ``lm_head``. It is modified in place.

    Notes:
        Calling this twice on the same model wraps the already patched forward
        again, because the current ``model.forward`` is captured as the fallback.
    """
    original_forward = model.forward

    def new_forward(hidden_states=None, layer_idx=-1, attention_mask=None, **kwargs):
        # Without injected activations, defer to the untouched forward pass.
        if hidden_states is None:
            return original_forward(attention_mask=attention_mask, **kwargs)

        # Convert the mask once, outside the loop; a missing mask is passed on
        # as None instead of raising AttributeError on ``None.to``.
        # NOTE(review): the mask is handed to each block as-is; whether a block
        # still applies causal masking when given such a mask depends on the
        # installed transformers version — confirm with a multi-token prompt.
        if attention_mask is not None:
            attention_mask = attention_mask.to(torch.bool)

        # Resume at the block right after the one that produced the activations.
        for block in model.transformer.h[layer_idx + 1:]:
            hidden_states = block(hidden_states, attention_mask=attention_mask)[0]

        # Final layer norm followed by the language-modelling head.
        hidden_states = model.transformer.ln_f(hidden_states)
        return model.lm_head(hidden_states)

    # Replace the forward function on this instance only.
    model.forward = new_forward

# Patch model_b in place so its forward can resume from injected activations.
# NOTE(review): re-running this cell wraps the already patched forward again,
# because the function captures whatever model.forward currently is.
modify_forward_function(model_b)

In [194]:
# Run model A on a one-word prompt; keep the output of block `layer_a` and the
# full model output, whose logits serve as the reference further down.
activations, output_a = get_activations_and_output(model_a, tokenizer_a, "hi", layer_a)

In [195]:
import torch

In [196]:
# Shape of the reference logits; presumably (batch, tokens, vocab) — TODO confirm.
output_a.logits.shape

torch.Size([1, 1, 50257])

In [197]:
# Resume inside model_b from model A's activations. Only the attention mask of
# this tokenization is used: the token ids end up in **kwargs, which the
# patched forward ignores whenever `hidden_states` is supplied.
# NOTE(review): the prompt here ("test") differs from the one that produced
# `activations` ("hi"); presumably this only works while both prompts yield
# masks of the same length — confirm or reuse the same prompt.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = tokenizer_b("test", return_tensors="pt")
output_b = model_b.forward(**input, hidden_states=activations, layer_idx=layer_b)

from out layer_idx=4
layer_idx=4


In [198]:
# Element-wise comparison of model A's own logits with the logits obtained by
# resuming inside model_b; both names load the same checkpoint, so a match is
# the expected outcome.
output_a.logits == output_b

tensor([[[True, True, True,  ..., True, True, True]]])

In [80]:
# NOTE(review): this cell and the two after it carry execution counts (80, 85,
# 87) lower than the cells above, and their recorded outputs (13 hidden states,
# 12 layers) disagree with the 6 layers reported for distilgpt2 earlier in the
# notebook — presumably left over from a run with another checkpoint; re-run or
# remove.
text = "hi"
tokens = tokenizer_a(text, return_tensors="pt")
output = model_a(**tokens, output_hidden_states=True)

In [85]:
# One entry for the embedding output plus one per transformer layer.
len(output.hidden_states)

13

In [87]:
# Layer count from the config, for comparison with the hidden-state count.
model_a.config.num_hidden_layers

12

torch.Size([1, 1, 768])

In [268]:
from typing import Dict, List, Optional, Tuple

class Zombie:
    """Drive the upper layers of model B with activations taken from model A.

    Model A processes the text normally. The hidden state after block
    ``layer_a_idx`` is optionally passed through a linear projection and then
    fed into model B, starting at block ``layer_b_idx + 1``. A linear
    classifier head reads model B's resulting hidden state. Both language
    models are frozen; only ``classifier`` and ``projection`` hold trainable
    parameters.

    This is a plain Python class, not a ``torch.nn.Module``, so the trainable
    parameters are not registered anywhere; ``train`` collects them explicitly.
    Model B must expose ``transformer.h``, ``transformer.ln_f`` and ``lm_head``.

    Args:
        model_a_name: Checkpoint name passed to ``from_pretrained`` for model A.
        model_b_name: Checkpoint name passed to ``from_pretrained`` for model B.
        layer_a_idx: Block of model A whose output is extracted.
        layer_b_idx: Block of model B after which the activations are injected
            (blocks ``layer_b_idx + 1`` onwards are executed).
        exit_layer_b: Index of the last block of model B to execute, counted in
            model B's own block numbering, or ``None`` to run every remaining
            block.
        num_classes: Output width of the classifier head.
        project: When true, map A's hidden size to B's with a trainable linear
            layer; otherwise activations are passed through unchanged.

    Raises:
        ValueError: If ``project`` is false and the two hidden sizes differ.
    """

    def __init__(self, model_a_name, model_b_name, layer_a_idx, layer_b_idx, exit_layer_b: Optional[int] = None, num_classes=2, project=False):
        self.model_a = AutoModelForCausalLM.from_pretrained(model_a_name)
        self.tokenizer_a = AutoTokenizer.from_pretrained(model_a_name)
        self.model_b = AutoModelForCausalLM.from_pretrained(model_b_name)
        self.tokenizer_b = AutoTokenizer.from_pretrained(model_b_name)

        # Both backbones stay frozen; only the heads created below are trained.
        for backbone in (self.model_a, self.model_b):
            for param in backbone.parameters():
                param.requires_grad = False

        self.layer_a_idx = layer_a_idx
        self.layer_b_idx = layer_b_idx
        self.exit_layer_b = exit_layer_b

        # Read the sizes from this instance's own models, not from whatever
        # `model_a` / `model_b` happen to exist in the notebook namespace.
        hidden_a = self.model_a.config.hidden_size
        hidden_b = self.model_b.config.hidden_size
        if not project and hidden_a != hidden_b:
            raise ValueError(
                f"hidden sizes differ ({hidden_a} vs {hidden_b}); pass project=True"
            )

        self.classifier = torch.nn.Linear(hidden_b, num_classes)
        self.projection = torch.nn.Linear(hidden_a, hidden_b) if project else torch.nn.Identity()
        self.modify_forward_function(self.model_b)

    def get_activations_and_output(self, model, tokenizer, input: str, layer_idx):
        """Run ``model`` on ``input``; return block ``layer_idx``'s output and the full output."""
        tokens = tokenizer(input, return_tensors="pt")
        output = model(**tokens, output_hidden_states=True)
        return output.hidden_states[layer_idx + 1], output  # + 1 because embedding is at 0

    def modify_forward_function(self, model):
        """Patch ``model.forward`` in place so it can resume from injected activations.

        When ``hidden_states`` is supplied, the replacement runs the blocks
        ``model.transformer.h[layer_idx + 1:]``, stopping after block
        ``self.exit_layer_b`` if that is set, and returns a dict with:

        * ``"hidden_states"``: the last block output, before the final layer norm
        * ``"logits"``: ``lm_head(ln_f(hidden_states))``

        Without ``hidden_states`` it defers to the original forward.
        """
        original_forward = model.forward

        def new_forward(hidden_states=None, layer_idx=-1, attention_mask=None, **kwargs):
            if hidden_states is None:
                return original_forward(attention_mask=attention_mask, **kwargs)

            # Convert once; a missing mask is forwarded as None rather than
            # crashing on ``None.to``.
            # NOTE(review): the mask is handed to each block as-is; whether a
            # block still applies causal masking when given such a mask depends
            # on the installed transformers version — confirm with a
            # multi-token prompt.
            if attention_mask is not None:
                attention_mask = attention_mask.to(torch.bool)

            # Enumerate with model B's own block numbering so that
            # `exit_layer_b` uses the same indexing as `layer_idx`.
            first_block = layer_idx + 1
            for block_idx, block in enumerate(model.transformer.h[first_block:], start=first_block):
                hidden_states = block(hidden_states, attention_mask=attention_mask)[0]
                if block_idx == self.exit_layer_b:
                    break

            # Always expose the hidden state: Zombie.forward reads it
            # regardless of whether an exit layer was configured.
            return {
                "hidden_states": hidden_states,
                "logits": model.lm_head(model.transformer.ln_f(hidden_states)),
            }

        # Replace the forward function on this instance only.
        model.forward = new_forward

    def forward(self, input_text: str):
        """Run the A-to-B spliced pass on ``input_text``.

        Returns:
            ``(classifier_output, logits_b, output_a)``: the classifier head
            applied to model B's hidden state, model B's logits, and model A's
            full output object.
        """
        activations_a, output_a = self.get_activations_and_output(
            self.model_a, self.tokenizer_a, input_text, self.layer_a_idx
        )
        # NOTE(review): the mask comes from tokenizer B while the activations
        # follow tokenizer A's segmentation; presumably the lengths only agree
        # when both tokenizers split the text identically — confirm before
        # mixing checkpoints.
        tokens_b = self.tokenizer_b(input_text, return_tensors="pt")
        output_b = self.model_b.forward(
            attention_mask=tokens_b["attention_mask"],
            hidden_states=self.projection(activations_a),
            layer_idx=self.layer_b_idx,
        )
        classifier_output = self.classifier(output_b["hidden_states"])
        return classifier_output, output_b["logits"], output_a

    def train(self, data: List[Tuple[str, bool]], batch_size: int = 4, lr: float = 1e-3) -> List[float]:
        """Fit the classifier head (and the projection, if any) on labelled texts.

        One pass over ``data`` in consecutive, non-overlapping batches. Each
        example is classified from the classifier output at the last token
        position, and the loss is the cross-entropy averaged over the batch.
        Both of these are design choices made to obtain a working loop; adjust
        them if another pooling or loss is intended.

        Args:
            data: ``(text, label)`` pairs; ``label`` is converted with ``int``
                and used as the class index.
            batch_size: Number of examples per optimizer step.
            lr: Learning rate of the Adam optimizer, which is created afresh on
                every call.

        Returns:
            The mean loss of each batch, in order. Empty if ``data`` is empty.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")

        head_params = list(self.classifier.parameters()) + list(self.projection.parameters())
        optimizer = torch.optim.Adam([p for p in head_params if p.requires_grad], lr=lr)

        batch_losses: List[float] = []
        # Step by batch_size so that batches do not overlap and the final
        # partial batch is still used.
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            optimizer.zero_grad()
            loss = 0
            for input_text, label in batch:
                pred, _, _ = self.forward(input_text)
                # pred holds one score vector per token; use the last position.
                last_token_logits = pred[:, -1, :]
                target = torch.tensor([int(label)], device=last_token_logits.device)
                loss = loss + torch.nn.functional.cross_entropy(last_token_logits, target)
            loss = loss / len(batch)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        return batch_losses

In [263]:
# Same checkpoint on both sides, spliced after block 4 of each model.
# NOTE(review): distilgpt2 reported 6 layers earlier in this notebook, so the
# loop index never reaches 6 and the early-exit break cannot trigger.
# NOTE(review): this cell's execution count (263) precedes that of the class
# definition cell (268); re-run the notebook in order.
z = Zombie("distilgpt2", "distilgpt2", 4, 4, exit_layer_b=6)

In [264]:
# Returns the classifier head's scores, model B's logits from the resumed pass,
# and model A's full output object.
classifier_output, output_b, output_a = z.forward("hello")

layer_idx=4


In [265]:
# Class probabilities from the classifier head. The head is a newly created
# Linear layer and no training step has been run in this notebook, so these
# values presumably carry no meaning yet.
classifier_output.softmax(dim=-1)

tensor([[[9.9998e-01, 1.7292e-05]]], grad_fn=<SoftmaxBackward0>)

In [266]:
# Shape of model B's logits; presumably (batch, tokens, vocab) — TODO confirm.
output_b.shape

torch.Size([1, 1, 50257])

In [267]:
# Collapse the element-wise comparison into one boolean: True only when every
# logit from the spliced pass equals model A's own logit.
torch.all(output_a.logits == output_b)

tensor(True)