In [1]:
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from omegaconf import OmegaConf
import gc

from difflora import * # imports LlamaLoraDiffTransformerConfig, LlamaLoraDiffTransformerForCausalLM
from utils import print_layers

In [2]:
# Notebook settings, kept as YAML text so they can be edited in one place.
# `model_config` holds the arguments used to build a new diff-attention model.
config_yaml = '''
base_model_name: "meta-llama/Meta-Llama-3-8B-Instruct"
model_name: null
max_new_tokens: 128
max_length: 2048
batch_size: 32
quantization: null
attn_implementation: null # this is the base model attention layers (will be overwritten if layers_to_transform contains all layers; unknown behavior otherwise)
# Below are args required when loading a new diff attn model. When loading from a checkpoint, the config is automatically loaded.
model_config:
  diff_attn_implementation: null           # this is the diff attention implementation: 'eager' and 'flash_attention_2' are supported
  learn_lambda: False                    # Whether to learn the lambda parameter for the Diff Attn loss. Takes precedence over diff_attn_lambda.
  diff_attn_lambda: 0.00                 # The fixed lambda parameter for the Diff Attn. Ignored if learn_lambda is True.
  layers_to_transform: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31] # List of layer indices to apply Diff Attn to.
  diff_attn_init_with_base_weights: True # Whether to initialize Diff Attn weights with the base (pre-trained model) weights (for q_proj, k_proj, v_proj, and o_proj)
  lora_negative_term_only: False         # Whether to only apply adapters on the right/negative term of diff attn (vs both positive and negative term)          If False, Q1 = X @ (W_Q1 + B_Q1 @ A_Q1); if True, Q1 = X @ W_Q1
  negative_term_lora_only: False         # Whether to apply the adapter only (vs the adapter on top of pre-trained weights) on the negative term of diff attn.  If False, Q2 = X @ (W_Q2 + B_Q2 @ A_Q2); if True, Q2 = X @ B_Q2 @ A_Q2 (only trainable adapters) (and same for K2)
  negative_term_full_dim: False          # Whether to learn the negative term without adapters (i.e. full dim) (vs with adapters).                              If True, Q2 = X @ W_Q2 (full trainable query projection, full meaning same dimensionality as W_Q1 (and same for K2)
  attention_lora_alpha: 64               # The alpha parameter for the LoRA diff attention.
  attention_lora_r: 32                   # The rank for the adapters of diff attn.
  attention_lora_dropout: 0.1            # The dropout rate for the LoRA diff attention.
  groupnorm: True                        # Whether to use GroupNorm (normalization across attention heads) (see diff attn paper).
  relu_on_differential: False            # Whether to apply ReLU on the differential term i.e. ReLU(softmax(Q1K1)-softmax(Q2K2))
  verbose: True                          # prints the model layers before runs
'''

# Parse the YAML text into an OmegaConf object with attribute-style access
# (e.g. `config.base_model_name`, `config.model_config.verbose`).
config = OmegaConf.create(config_yaml)

In [3]:
# Build the diff-attention model from the base checkpoint and prepare the
# tokenizer for left-padded batched generation.
#
# This cell only reads `config` (the OmegaConf object) and never rebinds it,
# so it can be re-run without re-running the configuration cell. The model's
# own config is available as `diff_config` or `model.config`.
base_model_name = config.base_model_name
attn_implementation = config.attn_implementation
model_config = config.model_config

# The base checkpoint name is used below to load the base config, the base
# weights and the tokenizer, so it is required unconditionally.
assert base_model_name is not None, "`base_model_name` is not provided; it is required to load the base model config and weights."

# Merge the base architecture config with the diff-attention settings; on a
# key clash the diff-attention value wins because it is unpacked last.
diff_transformer_config = OmegaConf.to_container(model_config)
base_config = AutoConfig.from_pretrained(base_model_name)  # some transformers config like LlamaConfig
concat_config = {**base_config.to_dict(), **diff_transformer_config}

# Load the pre-trained base model (for example llama3-8b). It is handed to the
# diff-attention model constructor below.
# NOTE(review): presumably the constructor copies the base weights into the
# diff-attention modules when `diff_attn_init_with_base_weights` is set —
# confirm in `difflora`.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    # If every layer is converted to diff attention, the base attention
    # implementation chosen here is overwritten by the diff-attention module.
    attn_implementation=attn_implementation,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)

# Use a dedicated name for the model config instead of rebinding `config`.
diff_config = LlamaLoraDiffTransformerConfig(**concat_config)
model = LlamaLoraDiffTransformerForCausalLM(diff_config, base_model=base_model).to(base_model.device)

# Drop the notebook's reference to the base model and release cached memory.
del base_model
gc.collect()
torch.cuda.empty_cache()

if model_config.verbose:
    print("=== DIFF ATTN MODEL LOADED ===")
    # NOTE(review): the original called print_layers() with no arguments; the
    # model is passed here on the assumption that the helper takes the model
    # to inspect — confirm against `utils.print_layers`.
    print_layers(model)

# Load the tokenizer from the same checkpoint as the model so the two cannot
# drift apart when `base_model_name` is changed in the configuration.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, clean_up_tokenization_spaces=True)

# Left padding for batched generation. Pad-token preference is unchanged:
# BOS if the tokenizer has one, otherwise an existing pad token, otherwise EOS.
tokenizer.padding_side = "left"
if tokenizer.bos_token is not None:
    tokenizer.pad_token = tokenizer.bos_token
elif tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = model.bfloat16()

model.eval()
model.config.pretraining_tp = 1

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Cancellation requested; stopping current tasks.


KeyboardInterrupt: 