In [1]:
#transformers fork: https://github.com/AnswerDotAI/transformers/tree/cla-llama

# vscode jupyter kill hanging process:
# ps aux | grep "/workspace/py_venvs/cla/bin/python -m ipykernel_launcher" | awk '{print $2}' | xargs kill -9
# ps aux | grep "/workspace/py_venvs/vllm/bin/python -m ipykernel_launcher" | awk '{print $2}' | xargs kill -9

In [12]:
import torch
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers.generation.configuration_utils import GenerationConfig

from collections import OrderedDict

In [2]:
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

Loading checkpoint shards: 100%|██████████| 4/4 [00:19<00:00,  4.95s/it]


In [4]:
model.model.layers[0]

LlamaDecoderLayer(
  (self_attn): LlamaSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): LlamaDynamicNTKScalingRotaryEmbedding()
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm()
  (post_attention_layernorm): LlamaRMSNorm()
)

In [5]:
model.cuda();

In [6]:
len(model.model.layers)

32

In [7]:
model.training, model.config.use_cache

(False, True)

In [8]:
x = torch.randint(0,100,(1,16)).cuda()

In [9]:
x.shape

torch.Size([1, 16])

In [10]:
model.config.use_cla = False
model.config.cla_factor = None

In [11]:
model.model.layers[0].__class__, model.model.layers[0]._forward_hooks

(transformers.models.llama.modeling_llama.LlamaDecoderLayer, OrderedDict())

In [14]:
class ActivationHook:
    def __init__(self):
        self.activation = {}
    
    def reset(self):
        self.activation = {}
    
    def get_activation(self, name):
        def hook(model, input, output):
            # last output in decoder layer is the cla activations.
            cla_key_value_states = output[-1]
            if isinstance(cla_key_value_states, torch.Tensor):
                cla_key_value_states = cla_key_value_states.detach().cpu()
            self.activation[name] = cla_key_value_states
        return hook

    def register_hooks(self, model):
        for i, layer in enumerate(model.model.layers):
            layer.register_forward_hook(self.get_activation('decoder_layer{}'.format(i)))
            
    def reset_all_hooks(self, model):
        for i, layer in enumerate(model.model.layers):
            layer._forward_hooks = OrderedDict()

In [15]:
activation_hook = ActivationHook()

### No CLA

In [17]:
activation_hook.reset_all_hooks(model)
activation_hook.register_hooks(model)

In [18]:
output = model(**{"input_ids":x})

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [19]:
output.logits

tensor([[[-0.0766, -0.4452, -2.1991,  ..., -3.8720, -3.8722, -3.8722],
         [ 5.3667,  7.0156,  5.0217,  ..., -2.2082, -2.2083, -2.2085],
         [ 1.2831, -6.6209,  5.4645,  ..., -4.2729, -4.2731, -4.2730],
         ...,
         [ 7.5572,  7.9917,  8.5733,  ..., -3.0052, -3.0056, -3.0054],
         [ 7.6061,  8.6619,  8.5458,  ..., -3.0663, -3.0668, -3.0666],
         [ 7.8302,  8.7273,  9.2386,  ..., -3.0484, -3.0487, -3.0485]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [21]:
assert all(v is None for _,v in activation_hook.activation.items())

### CLA with Factor=1

In [22]:
model.config.use_cla = True
model.config.cla_factor = 1

In [23]:
activation_hook.reset()

In [24]:
output = model(**{"input_ids":x})

In [25]:
output.logits

tensor([[[-0.0766, -0.4452, -2.1991,  ..., -3.8720, -3.8722, -3.8722],
         [ 5.3667,  7.0156,  5.0217,  ..., -2.2082, -2.2083, -2.2085],
         [ 1.2831, -6.6209,  5.4645,  ..., -4.2729, -4.2731, -4.2730],
         ...,
         [ 7.5572,  7.9917,  8.5733,  ..., -3.0052, -3.0056, -3.0054],
         [ 7.6061,  8.6619,  8.5458,  ..., -3.0663, -3.0668, -3.0666],
         [ 7.8302,  8.7273,  9.2386,  ..., -3.0484, -3.0487, -3.0485]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [29]:
key_states, value_states = activation_hook.activation['decoder_layer0']

In [30]:
key_states.shape, value_states.shape

(torch.Size([1, 8, 16, 128]), torch.Size([1, 8, 16, 128]))

In [35]:
prev_key_states, prev_value_states = None, None
for layer_name, (key_states, value_states) in activation_hook.activation.items():
    if prev_key_states is None:
        prev_key_states, prev_value_states = key_states, value_states
    else:
        assert not torch.equal(prev_key_states, key_states)
        assert not torch.equal(prev_value_states, value_states)
    
        assert prev_key_states.shape == key_states.shape
        assert prev_value_states.shape == value_states.shape
        
        prev_key_states, prev_value_states = key_states, value_states 

### CLA with Factor=2

In [36]:
model.config.use_cla = True
model.config.cla_factor = 2

In [37]:
activation_hook.reset()

In [38]:
output = model(**{"input_ids":x})

In [39]:
output.logits

tensor([[[ 0.6762,  4.2007,  5.9417,  ..., -2.3479, -2.3479, -2.3478],
         [ 7.2270,  9.1649,  3.4653,  ..., -2.7682, -2.7683, -2.7680],
         [ 3.2660,  4.7468,  5.0928,  ..., -3.5554, -3.5555, -3.5554],
         ...,
         [ 4.1864,  6.8682,  5.4038,  ..., -2.3405, -2.3409, -2.3406],
         [ 3.7679,  6.3121,  5.5808,  ..., -3.0443, -3.0442, -3.0441],
         [ 3.8976,  6.1718,  5.9425,  ..., -3.0508, -3.0509, -3.0507]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [42]:
key_states0, value_states0 = activation_hook.activation['decoder_layer0']
key_states1, value_states1 = activation_hook.activation['decoder_layer1']

In [44]:
assert torch.equal(key_states0, key_states1)
assert torch.equal(value_states0, value_states1)

In [41]:
key_states.shape, value_states.shape

(torch.Size([1, 8, 16, 128]), torch.Size([1, 8, 16, 128]))

In [48]:
prev_key_states, prev_value_states = None, None
prev_cla_group = None
for layer_name, (key_states, value_states) in activation_hook.activation.items():
    if prev_key_states is None:
        prev_key_states, prev_value_states = key_states, value_states
        prev_cla_group = int(layer_name.removeprefix('decoder_layer')) // model.config.cla_factor
        
    else:
        curr_cla_group = int(layer_name.removeprefix('decoder_layer')) // model.config.cla_factor
        if prev_cla_group == curr_cla_group:
            assert torch.equal(prev_key_states, key_states)
            assert torch.equal(prev_value_states, value_states)
        else:
            assert not torch.equal(prev_key_states, key_states)
            assert not torch.equal(prev_value_states, value_states)
         
        assert prev_key_states.shape == key_states.shape
        assert prev_value_states.shape == value_states.shape
        
        prev_key_states, prev_value_states = key_states, value_states 
        prev_cla_group = curr_cla_group

### CLA with Factor=3

In [49]:
model.config.use_cla = True
model.config.cla_factor = 3

In [50]:
activation_hook.reset()

In [51]:
output = model(**{"input_ids":x})

In [52]:
output.logits

tensor([[[ 1.0454,  4.6148,  6.6658,  ..., -2.3465, -2.3464, -2.3464],
         [ 6.3484,  7.1754,  4.7409,  ..., -2.7433, -2.7432, -2.7434],
         [ 4.5100,  1.7046,  2.1793,  ..., -1.8608, -1.8610, -1.8609],
         ...,
         [ 0.0428, -0.0673,  0.7074,  ..., -0.3057, -0.3059, -0.3058],
         [ 1.6149,  0.7311,  2.0330,  ..., -0.4262, -0.4264, -0.4264],
         [-0.6556,  0.6304,  0.7900,  ..., -0.2798, -0.2799, -0.2799]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [54]:
key_states0, value_states0 = activation_hook.activation['decoder_layer0']
key_states1, value_states1 = activation_hook.activation['decoder_layer1']
key_states2, value_states2 = activation_hook.activation['decoder_layer2']

In [55]:
assert torch.equal(key_states0, key_states1)
assert torch.equal(value_states0, value_states1)

assert torch.equal(key_states1, key_states2)
assert torch.equal(value_states1, value_states2)

In [56]:
key_states.shape, value_states.shape

(torch.Size([1, 8, 16, 128]), torch.Size([1, 8, 16, 128]))

In [57]:
prev_key_states, prev_value_states = None, None
prev_cla_group = None
for layer_name, (key_states, value_states) in activation_hook.activation.items():
    if prev_key_states is None:
        prev_key_states, prev_value_states = key_states, value_states
        prev_cla_group = int(layer_name.removeprefix('decoder_layer')) // model.config.cla_factor
    else:
        curr_cla_group = int(layer_name.removeprefix('decoder_layer')) // model.config.cla_factor
        if prev_cla_group == curr_cla_group:
            assert torch.equal(prev_key_states, key_states)
            assert torch.equal(prev_value_states, value_states)
        else:
            assert not torch.equal(prev_key_states, key_states)
            assert not torch.equal(prev_value_states, value_states)
         
        assert prev_key_states.shape == key_states.shape
        assert prev_value_states.shape == value_states.shape
        
        prev_key_states, prev_value_states = key_states, value_states 
        prev_cla_group = curr_cla_group

### Decoding

In [58]:
inp = torch.tensor(tokenizer.encode("Say Hello world")).cuda()

In [59]:
messages = [
    {"role": "system", "content": "You are an AI assistant."},
    {"role": "user", "content": "Say Hello world 10 times as numbered list."}
]
input_tokens = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True, 
    return_tensors="pt"
).cuda()

In [60]:
model.config.use_cache = False

In [61]:
model.config

LlamaConfig {
  "_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "cla_factor": 3,
  "eos_token_id": 128009,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 2.0,
    "type": "dynamic"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.0.dev0",
  "use_cache": false,
  "use_cla": true,
  "vocab_size": 128256
}

In [63]:
model.config.use_cla, model.config.cla_factor 

(True, 3)

In [64]:
new_tokens = model.generate(input_tokens, max_new_tokens=128, use_cache=False)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [65]:
print(tokenizer.decode(new_tokens[0]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an AI assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

Say Hello world 10 times as numbered list.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

_{{_.__@__ @____,___­__ htt_ htt_ndl_opak_. __ @_zn___/ @__ @__@____ httunny_,&__/_.opak_ @_ htt_undedylvania_ständ_ htt_ryn htt oci_ `_»_-__/__/__ ociortedopro_andalone_ htt_zn_ magneticzn__/znundedusk _ manageable Towersortedômoprzn-andômständ_unded-light-andunded _ Dutryn-and-and-and_oproznoproandalone htt_oratezn_ocationsdevil-and-domopro-dev


In [67]:
model.config.use_cla = False

In [68]:
new_tokens = model.generate(input_tokens, max_new_tokens=128, use_cache=False)
print(tokenizer.decode(new_tokens[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an AI assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

Say Hello world 10 times as numbered list.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Here is the list of "Hello World" said 10 times:

1. Hello World!
2. Hello World!
3. Hello World!
4. Hello World!
5. Hello World!
6. Hello World!
7. Hello World!
8. Hello World!
9. Hello World!
10. Hello World!<|eot_id|>


### Weight Prep for XoRA CLA

In [58]:
import torch
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from tqdm import tqdm
import safetensors
import safetensors.torch

In [3]:
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)

In [4]:
files

['/root/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/model-00001-of-00004.safetensors',
 '/root/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/model-00002-of-00004.safetensors',
 '/root/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/model-00003-of-00004.safetensors',
 '/root/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/model-00004-of-00004.safetensors']

In [19]:
all_weights = {}
for filename in tqdm(files, desc="Loading & Quantizing Model Shards", disable=0, position=0):
	weights = safetensors.torch.load_file(filename)
	all_weights.update(weights)

Loading & Quantizing Model Shards: 100%|██████████| 4/4 [00:01<00:00,  2.34it/s]


In [25]:
type(weights), len(all_weights.keys())

(dict, 291)

In [63]:
cla_weights = {}
cla_factor = 2
for name, param in iter(all_weights.items()):
    if "k_proj" in name or "v_proj" in name:
        layer_idx = int(name.split('.')[2])
        if layer_idx % cla_factor != 0:
            cla_group_idx = layer_idx // cla_factor * cla_factor
            cla_weights[name] = all_weights[name.replace(f"model.layers.{layer_idx}", f"model.layers.{cla_group_idx}")].clone()
        else:
            cla_weights[name] = all_weights[name]
    else:
        cla_weights[name] = all_weights[name]

In [64]:
len(cla_weights.keys())

291

In [65]:
assert torch.equal(cla_weights["model.layers.0.self_attn.k_proj.weight"], cla_weights["model.layers.1.self_attn.k_proj.weight"])
assert torch.equal(cla_weights["model.layers.2.self_attn.k_proj.weight"], cla_weights["model.layers.3.self_attn.k_proj.weight"])
assert not torch.equal(cla_weights["model.layers.3.self_attn.k_proj.weight"], cla_weights["model.layers.4.self_attn.k_proj.weight"])
assert torch.equal(cla_weights["model.layers.14.self_attn.k_proj.weight"], cla_weights["model.layers.15.self_attn.k_proj.weight"])

In [68]:
safetensors.torch.save_file(cla_weights, "/workspace/models/meta-llama/Meta-Llama-3-8B-Instruct-cla2.safetensors")

In [69]:
cla_weights = {}
cla_factor = 3
for name, param in iter(all_weights.items()):
    if "k_proj" in name or "v_proj" in name:
        layer_idx = int(name.split('.')[2])
        if layer_idx % cla_factor != 0:
            cla_group_idx = layer_idx // cla_factor * cla_factor
            cla_weights[name] = all_weights[name.replace(f"model.layers.{layer_idx}", f"model.layers.{cla_group_idx}")].clone()
        else:
            cla_weights[name] = all_weights[name]
    else:
        cla_weights[name] = all_weights[name]

In [70]:
len(cla_weights.keys())

291

In [72]:
assert torch.equal(cla_weights["model.layers.0.self_attn.k_proj.weight"], cla_weights["model.layers.1.self_attn.k_proj.weight"])
assert torch.equal(cla_weights["model.layers.1.self_attn.k_proj.weight"], cla_weights["model.layers.2.self_attn.k_proj.weight"])
assert not torch.equal(cla_weights["model.layers.2.self_attn.k_proj.weight"], cla_weights["model.layers.3.self_attn.k_proj.weight"])
assert torch.equal(cla_weights["model.layers.9.self_attn.k_proj.weight"], cla_weights["model.layers.10.self_attn.k_proj.weight"])
assert torch.equal(cla_weights["model.layers.10.self_attn.k_proj.weight"], cla_weights["model.layers.11.self_attn.k_proj.weight"])

In [73]:
safetensors.torch.save_file(cla_weights, "/workspace/models/meta-llama/Meta-Llama-3-8B-Instruct-cla3.safetensors")

In [74]:
saved_weights = safetensors.torch.load_file("/workspace/models/meta-llama/Meta-Llama-3-8B-Instruct-cla3.safetensors")

In [75]:
saved_weights.keys()

dict_keys(['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_la