# Branchy LLM Test

This jupyter notebook is made to test out implementation of BranchyLLM on transformers library.

## Testing Checklist for BranchyLLM Implementation

- [x] **Load LLMs from Hugging Face using Transformers**
  - Define a list of models to test
  - Load each model
  - Print initial model architecture and parameter count

- [x] **Add Branches to Models**
  - Implement branch insertion
  - Display updated model architectures
  - Compare parameter count (before vs. after)

- [x] **Modify Number of Branches**
  - Demonstrate dynamic branching
  - Show updated architectures for different branch counts

- [x] **Selective Branch Placement**
  - Illustrate control over branch placement
  - Test and display various configurations

- [ ] **Inference with Early Exit**
  - Perform inference demonstrating early exit
  - Verify termination of calculations post-exit

- [ ] **Train Branch Heads with Self-Supervision**
  - Outline self-supervised training approach
  - Implement training
  - Display training results

- [ ] **Computation Reduction Analysis**
  - Define computation metrics
  - Compare computation (before vs. after branching)

- [ ] **Evaluate Model Performance**
  - Define performance metrics
  - Analyze performance degradation
  - Perform comparative analysis across configurations




## Libraries

In [2]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, PreTrainedModel, AutoTokenizer
from src.BranchyConfig import BranchyConfig
from src.BranchyModel import BranchyModel
from src.utils import print_model_parameter_distribution
import torch
import copy

**Load LLMs from Hugging Face using Transformers**

In [3]:
# Define a list of allowed models
allowed_models = [
    "microsoft/phi-2",
    "mistralai/Mistral-7B-Instruct-v0.2"
]

# Variable to select a model
selected_model = allowed_models[0]  # Select the first model for demonstration

# get config for base model
model = AutoModelForCausalLM.from_pretrained(selected_model)
tokenizer = AutoTokenizer.from_pretrained(selected_model)
print(model)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


PhiForCausalLM(
  (transformer): PhiModel(
    (embd): Embedding(
      (wte): Embedding(51200, 2560)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (h): ModuleList(
      (0-31): 32 x ParallelBlock(
        (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
        (mixer): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear(in_features=2560, out_features=7680, bias=True)
          (out_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (inner_attn): SelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): CrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (mlp): MLP(
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
          (act): NewGELUActivation()
        )
      )
    )
  )
  (lm

In [4]:
from transformers.cache_utils import Cache
from typing import Optional, List
from torch import nn

class BranchyModel(PreTrainedModel):
    """
    This class is a wrapper for transformer models with added functionality for branchy networks.
    It uses BranchyConfig to initialize a model and later will be extended to add branches.

    Args:
        config (BranchyLLMConfig): The configuration to initialize the model with.
        model (PreTrainedModel): The underlying transformer model to wrap.

    Returns:
        A model instance with the given configuration.
    """

    def __init__(self, config, model):
        super().__init__(model.config)
        # Initialize the base transformer model
        self.model = model
        
        # Get args for branchy model
        self.self_supervised_training = config.self_supervision
        self.branch_locations = config.branch_locations
        
        # Get details on layering inside the model
        if hasattr(self.model.config, "n_layer") or hasattr(self.model.config, "num_hidden_layers"): # If there is no n_layer in the config, there might be ways to get it from the model itself
            self.num_layers = self.model.config.n_layer if hasattr(self.model.config, "n_layer") else self.model.config.num_hidden_layers
            assert self.num_layers > 0, "The number of layers must be greater than 0"
            assert len(self.branch_locations) < self.num_layers, "The number of branches must be less than the number of layers"
            assert all([0 <= i < self.num_layers for i in self.branch_locations]), "The branch locations must be between 0 and num_layers"
        else:
            raise ValueError("cannot find n_layer in config")
            
        # Make sure the base model is frozen
        for param in self.model.parameters():
            param.requires_grad = False
        
        # Instantiate heads. Default: heads are copies of the lm_head
        self.model.heads = torch.nn.ModuleList([copy.deepcopy(self.model.lm_head) for _ in range(len(self.branch_locations))])

        # initialize heads
        for head in self.model.heads:
            head.apply(self.model._init_weights)
            # Make them trainable
            for param in head.parameters():
                param.requires_grad = True

        self.post_init()
    
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation 
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        print('used prepare_inputs_for_generation')
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:  
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
    
    def forward(self,
        labels: Optional[torch.LongTensor] = None,
        *args,
        **kwargs):
        
        if labels is not None:
            return self.forward_for_training(labels=labels, *args, **kwargs)
        else:
            raise NotImplementedError("BranchyLlama is not yet implemented for inference")
        print('used forward')
        return self.model(*args, **kwargs)
    
    def forward_for_training(self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
        self_supervision: Optional[bool] = True):
        
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if not output_hidden_states:
            raise ValueError("output_hidden_states must be True for BranchyLlama")
        if self_supervision and labels is not None:
            raise ValueError(
                "self_supervision and labels cannot be specified at the same time"
            )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_states = outputs.last_hidden_state
        hidden_states = outputs.hidden_states

        # Compute logits for each head between each layer in the model
        if self.branch_locations == []:
            heads_logits = [
                head(hidden_states[i].to(head.weight.dtype)).cpu() for i, head in enumerate(self.model.heads)
            ]
        # Only specific layers are branched
        else:
            heads_logits = []
            for i, branch in enumerate(self.branch_locations):
                heads_logits.append(self.model.heads[i](hidden_states[branch].to(self.model.heads[i].weight.dtype)).cpu())
        lm_logits = self.lm_head(last_hidden_states).cpu()

        heads_logits = torch.stack(heads_logits, dim=0).float()
        lm_logits = lm_logits.float()
        logits = torch.cat([heads_logits, lm_logits.unsqueeze(0)], dim=0)
        # TODO finish here
        loss = None
        lm_loss = None
        aux_loss = None

        # Compute loss as in Llama implementation
        loss_fct = nn.CrossEntropyLoss()
        lm_loss = self.compute_loss(lm_logits, labels, loss_fct)
        aux_loss = self.compute_loss(heads_logits, labels, loss_fct)
        loss = torch.stack([aux_loss, lm_loss], dim=0)


In [7]:
# variables for branchy config
branchy_config = BranchyConfig(self_supervision= True,
                    num_branches= 3,
                    branch_locations= [5, 10, 15])

branchy_model = BranchyModel(branchy_config, model)
branchy_model
#print(branchy_model.model.heads)
#print_model_parameter_distribution(branchy_model.model)

# print parameters from lm_head and each heads to show they are not the same
#print(next(branchy_model.model.lm_head.parameters())[0])
#print(next(branchy_model.model.heads[0].parameters())[0])
#print(next(branchy_model.model.heads[1].parameters())[0])



PhiForCausalLM(
  (transformer): PhiModel(
    (embd): Embedding(
      (wte): Embedding(51200, 2560)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (h): ModuleList(
      (0-31): 32 x ParallelBlock(
        (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
        (mixer): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear(in_features=2560, out_features=7680, bias=True)
          (out_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (inner_attn): SelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): CrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (mlp): MLP(
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
          (act): NewGELUActivation()
        )
      )
    )
  )
  (lm

In [24]:
# test model
prompt = "This is an example script: "
inputs = tokenizer(prompt, return_tensors="pt")
print(inputs.input_ids) 
print(tokenizer.batch_decode(torch.argmax(branchy_model(inputs.input_ids).logits, dim=-1)[:,-1]))
#generate_ids = branchy_model.generate(inputs.input_ids, max_length=20)
#tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

tensor([[1212,  318,  281, 1672, 4226,   25,  220]])
used forward
['\n']
