In [3]:
# if using colab notebook
COLAB = 0

# if want to disable flash attention
import os
os.environ["USE_FA"] = "0"

%load_ext autoreload
%autoreload 2

if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    %cd drive/MyDrive/naver/difflora-main/
    !pip install -r requirements.txt
    
    # if using flash attention
    # !pip install flex-head-fa --no-build-isolation

In [4]:
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from omegaconf import OmegaConf
import gc

from difflora import * # imports LlamaLoRaDiffTransformerConfig, LlamaLoraDiffTransformerForCausalLM
from utils import print_layers

In [5]:
# Notebook configuration, written as YAML and parsed by OmegaConf into a DictConfig below.
# `base_model_name`, `attn_implementation` and `model_config` are read by the model-loading
# cell; `model_config` is merged into the base model's Hugging Face config there.
# NOTE(review): `model_name`, `batch_size`, `quantization` and `max_new_tokens` are not read by
# any cell visible in this notebook — confirm whether they are still needed.
# NOTE(review): `layers_to_transform` lists 32 indices, but the loaded 1B model reports 16
# layers (0-15) and only those are transformed; out-of-range indices appear to be ignored —
# confirm this is intended.
# NOTE(review): despite the inline "Diff Attn loss" wording, `learn_lambda` / `diff_attn_lambda`
# appear to configure the lambda of the differential attention module itself (the attention
# module exposes `lambda_fixed`), not a loss term — confirm.
config = '''
base_model_name: "meta-llama/Llama-3.2-1B-Instruct"
model_name: null
max_new_tokens: 128
max_length: 2048
batch_size: 32
quantization: null
attn_implementation: null # this is the base model attention layers (will be overwritten if layers_to_transform contains all layers; unknown behavior otherwise)
# Below are args required when loading a new diff attn model. When loading from a checkpoint, the config is automatically loaded.
model_config:
  diff_attn_implementation: eager           # this is the diff attention implementation: 'eager' and 'flash_attention_2' are supported
  learn_lambda: False                    # Whether to learn the lambda parameter for the Diff Attn loss. Takes precedence over diff_attn_lambda.
  diff_attn_lambda: 0.00                 # The fixed lambda parameter for the Diff Attn. Ignored if learn_lambda is True.
  layers_to_transform: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31] # List of layer indices to apply Diff Attn to.
  diff_attn_init_with_base_weights: True # Whether to initialize Diff Attn weights with the base (pre-trained model) weights (for q_proj, k_proj, v_proj, and o_proj)
  lora_negative_term_only: False         # Whether to only apply adapters on the right/negative term of diff attn (vs both positive and negative term)          If False, Q1 = X @ (W_Q1 + B_Q1 @ A_Q1); if True, Q1 = X @ W_Q1
  negative_term_lora_only: False         # Whether to apply the adapter only (vs the adapter on top of pre-trained weights) on the negative term of diff attn.  If False, Q2 = X @ (W_Q2 + B_Q2 @ A_Q2); if True, Q2 = X @ B_Q2 @ A_Q2 (only trainable adapters) (and same for K2)
  negative_term_full_dim: False          # Whether to learn the negative term without adapters (i.e. full dim) (vs with adapters).                              If True, Q2 = X @ W_Q2 (full trainable query projection, full meaning same dimensionality as W_Q1 (and same for K2)
  attention_lora_alpha: 64               # The alpha parameter for the LoRA diff attention.
  attention_lora_r: 32                   # The rank for the adapters of diff attn.
  attention_lora_dropout: 0.1            # The dropout rate for the LoRA diff attention.
  groupnorm: True                        # Whether to use GroupNorm (normalization across attention heads) (see diff attn paper).
  relu_on_differential: False            # Whether to apply ReLU on the differential term i.e. ReLU(softmax(Q1K1)-softmax(Q2K2))
  verbose: True                          # prints the model layers before runs
'''
# Parse the YAML text; `config` is rebound from the raw string to an OmegaConf DictConfig.
config = OmegaConf.create(config)

In [7]:
# Build the DiffLoRA model from the pre-trained base model: its config is merged with the
# diff-attn options and its weights are handed to the diff-attn model; then load the tokenizer.
base_model_name = config.base_model_name
attn_implementation = config.attn_implementation

model_config = config.model_config
# The base model is always loaded below (its config and weights seed the diff-attn model),
# so a name is required here regardless of `diff_attn_init_with_base_weights`.
assert base_model_name is not None, "`base_model_name` must be provided: the DiffLoRA model is built from the base model's config and weights."

diff_transformer_config = OmegaConf.to_container(model_config)
base_config = AutoConfig.from_pretrained(base_model_name)  # some transformers config, e.g. LlamaConfig
# Merge the base-model config and the diff-attn options into one dict. On key collisions the
# diff-attn values win because they are unpacked last.
concat_config = {**base_config.to_dict(), **diff_transformer_config}

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    # Attention implementation of the base model's layers. If every layer is listed in
    # `layers_to_transform`, those layers are replaced by diff-attn modules, so the choice
    # made here does not affect them.
    attn_implementation=attn_implementation,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)

hf_config = LlamaLoraDiffTransformerConfig(**concat_config)
# Cast to bfloat16 in the same call as the device move: if the model is constructed in a wider
# default dtype, it is then never placed on the accelerator at that precision.
model = LlamaLoraDiffTransformerForCausalLM(hf_config, base_model=base_model).to(
    device=base_model.device, dtype=torch.bfloat16
)

# The base model is no longer needed once its weights were loaded into `model`; free it.
del base_model
gc.collect()
torch.cuda.empty_cache()

if model_config.verbose:
    print("=== DIFF ATTN MODEL LOADED ===")
    print_layers(model)

tokenizer = AutoTokenizer.from_pretrained(base_model_name, clean_up_tokenization_spaces=True)

# Left padding keeps prompts right-aligned so generated tokens directly follow each prompt.
tokenizer.padding_side = "left"
# Pad with BOS when the tokenizer has one (this overrides any existing pad token); otherwise
# keep the tokenizer's own pad token, and fall back to EOS only if it has none.
if tokenizer.bos_token is not None:
    tokenizer.pad_token = tokenizer.bos_token
elif tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.eval()
# NOTE(review): presumably disables the sliced-projection path Llama uses when
# `pretraining_tp > 1` — confirm it also applies to the diff-attn layers.
model.config.pretraining_tp = 1

Initialized diff transformer weights for layer(s) {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
Loaded base weights.
=== DIFF ATTN MODEL LOADED ===
LlamaLoraDiffTransformerForCausalLM(
  (model): LlamaLoraDiffTransformerModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaLoraDiffAttention(
          (lambda_fixed): 0.0
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
          (wq_lora_A1): Linear(in_features=2048, out_features=32, bias=False)
          (wq_lora_B1): Linear(in_features=32, out_features=2048, bias=False)
          (wk_lora_A1): Linear(in_features=2048, out_features=32, bias=False

### Train

This example training code is inspired by [naver/bergen](https://github.com/naver/bergen).

In [9]:
# example training config
# Training settings written as YAML and parsed by OmegaConf. The training cell consumes only the
# `trainer` section (unpacked as keyword arguments into `TrainingArguments`) and
# `resume_from_checkpoint`.
# NOTE(review): `test_size`, `num_saving_steps` and `gradient_checkpointing` are not read by any
# visible cell. `gradient_checkpointing: False` therefore only documents the DiffLoRA constraint
# and presumably relies on the `TrainingArguments` default being False — confirm.

from transformers import TrainingArguments, Trainer

training_config = '''
test_size: 0.01
num_saving_steps: 10
gradient_checkpointing: False # needs to be False for DiffLoRA
resume_from_checkpoint: False
trainer:
  dataloader_num_workers: 4
  eval_accumulation_steps: 4
  gradient_accumulation_steps: 1
  num_train_epochs: 1
  weight_decay: 0.1
  warmup_ratio: 0.05
  learning_rate: 5e-5
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 2
  bf16: True
  report_to: none # "wandb"
'''
# Parse the YAML text; `training_config` is rebound from the raw string to a DictConfig.
training_config = OmegaConf.create(training_config)

In [None]:
# Placeholder: assign your training dataset here before running this cell. While it is left as
# None, training is skipped so that "Restart & Run All" still reaches the inference section.
# NOTE(review): presumably a tokenized causal-LM dataset (e.g. `input_ids` / `attention_mask` /
# `labels`) compatible with the Trainer's default data collator — confirm.
train_dataset = None

args = TrainingArguments(
    run_name='training_run_0',
    output_dir='./train/',
    **training_config.trainer,
)

model = model.bfloat16()

if train_dataset is None:
    print("No `train_dataset` provided: skipping training and keeping the current model.")
else:
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
    )
    trainer.train(resume_from_checkpoint=training_config.resume_from_checkpoint)
    # Continue with the model object held by the trainer after training.
    model = trainer.model

model.eval()

### Inference

In [11]:
# load existing checkpoint:
# model = LlamaLoraDiffTransformerForCausalLM.from_pretrained(pretrained_model_name_or_path=..., torch_dtype=torch.bfloat16, device_map='auto')

# in this example we just use the model above which has randomly initialized adapters
# (if the training cell ran, `model` is the trainer's model instead).
# Bare last expression: the cell output is the module tree of the current model.
model

LlamaLoraDiffTransformerForCausalLM(
  (model): LlamaLoraDiffTransformerModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaLoraDiffAttention(
          (lambda_fixed): 0.0
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
          (wq_lora_A1): Linear(in_features=2048, out_features=32, bias=False)
          (wq_lora_B1): Linear(in_features=32, out_features=2048, bias=False)
          (wk_lora_A1): Linear(in_features=2048, out_features=32, bias=False)
          (wk_lora_B1): Linear(in_features=32, out_features=512, bias=False)
          (wq_lora_A2): Linear(in_features=2048, out_features=32, bias=False)


In [12]:
def generate(model, instr_tokenized, max_new_tokens=10, tokenizer=None):
    """Greedily generate completions for a batch of tokenized prompts.

    Args:
        model: causal LM exposing ``.device`` and ``.generate(...)``.
        instr_tokenized: tokenizer output (mapping) with ``input_ids`` and
            ``attention_mask`` tensors. Assumes left padding, so that every prompt
            ends at the same column and new tokens start at ``input_ids.size(1)``.
        max_new_tokens: maximum number of tokens generated per prompt (default 10).
        tokenizer: tokenizer providing ``pad_token_id`` and ``batch_decode``.
            Defaults to the notebook-level ``tokenizer``.

    Returns:
        list[str]: one decoded completion per prompt, with the prompt tokens
        removed and special tokens skipped.
    """
    if tokenizer is None:
        # Backward-compatible fallback to the global tokenizer defined in the notebook.
        tokenizer = globals()["tokenizer"]

    input_ids = instr_tokenized['input_ids'].to(model.device)
    attention_mask = instr_tokenized['attention_mask'].to(model.device)
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        # Explicit pad id: otherwise `generate` warns and falls back to the generation
        # config's eos id, which was None for this model.
        pad_token_id=tokenizer.pad_token_id,
    )

    # The output repeats the (padded) prompt; keep only the newly generated tokens.
    prompt_len = input_ids.size(1)
    generated_ids = output_ids[:, prompt_len:]
    del output_ids
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    del generated_ids
    # Collect garbage first so that freed CUDA blocks can be released by empty_cache.
    gc.collect()
    torch.cuda.empty_cache()
    return decoded

text = ["what is RAG?", "who is Naver?"]
# Truncate prompts to the configured `max_length` instead of falling back to the tokenizer's
# own `model_max_length`.
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=config.max_length)

outputs = generate(model, inputs)
for prompt, answer in zip(text, outputs):
    print(f"Prompt: {prompt}\nAnswer: {answer}\n")

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Prompt: what is RAG?
Answer: olestolest-olest-olest-olest-olest

Prompt: who is Naver?
Answer: olest-olest-olest-olest-olest-

