Activation/representation based merging #199

Closed · wants to merge 42 commits into from

42 commits
a826e48
Create zipit.yml
shamanez Mar 11, 2024
83f97eb
Add alignment method to config
metric-space Mar 11, 2024
1368c47
First stab at activations dumper
metric-space Mar 15, 2024
ea5f816
Differentiate between activations and hidden_states
metric-space Mar 16, 2024
a134612
Post test run corrections
metric-space Mar 19, 2024
f5668e7
More corrections
metric-space Mar 19, 2024
52293c1
ZipIt Similarity. (#201)
shamanez Mar 19, 2024
0004eb1
Fdfmm 40 zipit metric (#205)
shamanez Mar 22, 2024
2bc97ab
Gpt-2 residual connection correction
metric-space Mar 22, 2024
d4cb463
Fix for architecture.py
metric-space Mar 22, 2024
b998989
Playing with subgraphs generated via zipit forward-backward propagations
metric-space Mar 25, 2024
da1bc52
Mainly adding modified M_U computation. (#249)
shamanez Apr 5, 2024
a6a2480
Attempt to make zipit work speak the same language as rest of mergek…
metric-space Apr 29, 2024
0e0ef40
Fixes in implementation
metric-space Jul 6, 2024
0e234b2
Another default gone
metric-space Jul 6, 2024
9f324b6
Delete examples/zipit.yml
metric-space Jul 6, 2024
7f673f7
Change boolean default
metric-space Jul 7, 2024
ab6d131
Code removal
metric-space Jul 8, 2024
5a56d27
Make sure average of correlations are taken
metric-space Jul 8, 2024
0e85590
Remove on the fly GQA handling for now
metric-space Jul 8, 2024
0a02343
Delete test_zipit.sh
metric-space Jul 8, 2024
d3b5970
Put back config source file to original state
metric-space Jul 8, 2024
0b237e9
Put back config source file to original state (part 2)
metric-space Jul 8, 2024
f0fcc6f
Code cleanup for feature extraction script
metric-space Jul 9, 2024
2ef5c62
Another round of refactors and getting rid of unnecessary steps
metric-space Jul 9, 2024
6bcf6d8
variable name correction
metric-space Jul 9, 2024
293801c
Left over correction
metric-space Jul 9, 2024
b7af5fb
Cleaner logic for activation-grabbing hooks
metric-space Jul 9, 2024
5b1344b
More refactors for feature extraction script
metric-space Jul 9, 2024
f92bc58
More refactors and corrections
metric-space Jul 9, 2024
10407fe
Make script efficient to avoid OOM errors when device is set to GPU
metric-space Jul 9, 2024
f845c50
Make final script device configurable
metric-space Jul 9, 2024
fb55abe
Make final script device configurable
metric-space Jul 9, 2024
43863b7
Add chat template ability to feature extraction script
metric-space Jul 10, 2024
5d47dd0
Add datasets dependency to project dependencies
metric-space Jul 10, 2024
dc53a58
Bug fix
metric-space Jul 10, 2024
42a9543
Bug fix
metric-space Jul 10, 2024
2bff4d8
Yet another bug fix
metric-space Jul 10, 2024
547c400
Encode connection between att_kq and attn_v
metric-space Jul 10, 2024
6327cfd
Give proper script commands
metric-space Jul 10, 2024
fe548dc
New folder and location change
metric-space Jul 10, 2024
1ab4b2b
Delete test_by_gen.py
metric-space Jul 10, 2024
2 changes: 1 addition & 1 deletion mergekit/_data/architectures/gpt2.json
@@ -92,7 +92,7 @@
"type": "residual",
"inputs": [
"post_mlp_${layer_index}",
"h_${layer_index}"
"post_attn_${layer_index}"
]
}
]
24 changes: 18 additions & 6 deletions mergekit/_data/architectures/llama.json
@@ -22,24 +22,26 @@
     {
       "name": "model.layers.${layer_index}.self_attn.q_proj.weight",
       "input_space": "running_residual",
-      "output_space": "attn_qkv_${layer_index}",
-      "head_split": "output"
+      "output_space": "attn_qk_${layer_index}",
+      "head_split": "output",
+      "is_kq": true
     },
     {
       "name": "model.layers.${layer_index}.self_attn.k_proj.weight",
       "input_space": "running_residual",
-      "output_space": "attn_qkv_${layer_index}",
-      "head_split": "output"
+      "output_space": "attn_qk_${layer_index}",
+      "head_split": "output",
+      "is_kq": true
     },
     {
       "name": "model.layers.${layer_index}.self_attn.v_proj.weight",
       "input_space": "running_residual",
-      "output_space": "attn_qkv_${layer_index}",
+      "output_space": "attn_v_${layer_index}",
       "head_split": "output"
     },
     {
       "name": "model.layers.${layer_index}.self_attn.o_proj.weight",
-      "input_space": "attn_qkv_${layer_index}",
+      "input_space": "attn_v_${layer_index}",
       "output_space": "running_residual",
       "head_split": "input"
     },
@@ -79,5 +81,15 @@
"model.lm_head.weight"
]
}
],
"procedural_spaces": [
{
"type": "kv_expand",
"name":"attn_qkv_${layer_index}",
"inputs": [
"attn_qk_${layer_index}",
"attn_v_${layer_index}"
]
}
]
}
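The llama.json change splits the former shared attn_qkv space: q_proj and k_proj now write into attn_qk (flagged is_kq, signalling that query/key transforms need joint handling since attention scores depend on their product), while v_proj gets its own attn_v space, which in turn feeds o_proj. The new kv_expand procedural space stitches the two back into a joint qkv space. As a rough sketch of the idea only — kv_expand's actual semantics live in mergekit's space-resolution code, and this helper is hypothetical — a block-diagonal combination of per-space transforms would look like:

    import torch

    def kv_expand_sketch(t_qk: torch.Tensor, t_v: torch.Tensor) -> torch.Tensor:
        # Hypothetical illustration: embed the attn_qk transform and the attn_v
        # transform as blocks of one transform over the joint attn_qkv space.
        d_qk, e_qk = t_qk.shape
        d_v, e_v = t_v.shape
        joint = t_qk.new_zeros(d_qk + d_v, e_qk + e_v)
        joint[:d_qk, :e_qk] = t_qk  # block covering the query/key features
        joint[d_qk:, e_qk:] = t_v   # block covering the value features
        return joint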
3 changes: 2 additions & 1 deletion mergekit/architecture.py
@@ -62,6 +62,7 @@ class WeightInfo(BaseModel, frozen=True):
     optional: bool = False
     aliases: Optional[List[str]] = None
     head_split: Literal[None, "input", "output"] = None
+    is_kq: Optional[bool] = False
 
 
 class ProceduralSpaceInfo(BaseModel, frozen=True):
@@ -284,7 +285,7 @@ def _substitute(
         elif isinstance(obj_dict[key], list):
             obj_dict[key] = [
                 (
-                    TemplateWithArithmetic(s).substitute(substitutions)
+                    _template_substitution(s, num_layers, layer_idx)
                     if isinstance(s, str)
                     else s
                 )
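The second hunk routes string elements of lists through the same _template_substitution helper used for scalar fields, so templated names inside lists (like the space names above) also get layer-index arithmetic. A minimal sketch of that style of substitution, assuming `${name}` and `${name+k}`/`${name-k}` placeholders — the real helper and TemplateWithArithmetic are defined elsewhere in architecture.py:

    import re
    from typing import Optional

    def template_substitution_sketch(
        s: str, num_layers: int, layer_idx: Optional[int] = None
    ) -> str:
        # Sketch only: expand ${name} and ${name+/-k} placeholders, mirroring
        # (by assumption) what mergekit's _template_substitution does.
        values = {"num_layers": num_layers}
        if layer_idx is not None:
            values["layer_index"] = layer_idx

        def repl(match: re.Match) -> str:
            name, op, offset = match.group(1), match.group(2), match.group(3)
            base = values[name]
            if op:
                base = base + int(offset) if op == "+" else base - int(offset)
            return str(base)

        return re.sub(r"\$\{(\w+)\s*([+-])?\s*(\d+)?\}", repl, s)

    # e.g. template_substitution_sketch("post_attn_${layer_index+1}", 32, 4)
    # -> "post_attn_5"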
171 changes: 171 additions & 0 deletions mergekit/scripts/ABM/activations_based_merge.py
@@ -0,0 +1,171 @@
import logging
import os
from typing import Optional

import click
import safetensors.torch
import torch
import tqdm
from transformers import AutoTokenizer

from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io.tasks import LoaderCache
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions, add_merge_options


@click.command("mergekit-activation-based-merge")
@click.argument("model_path", type=str)
@click.argument("secondary_model_path", type=str)
@click.argument("merge_unmerge_directory", type=str)
@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
@click.option(
    "--dtype",
    type=str,
    default="float16",
    help="Data type to convert weights to",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cuda",
    help="Device to compute on (default: cuda)",
)
@add_merge_options
def main(
    model_path: str,
    secondary_model_path: str,
    merge_unmerge_directory: str,
    out_path: str,
    dtype: Optional[str],
    device: Optional[str],
    merge_options: MergeOptions,
):
    model = ModelReference.model_validate(model_path)
    secondary_model = ModelReference.model_validate(secondary_model_path)

    dtype = dtype_from_name(dtype) if dtype else None

    cache = LoaderCache()
    cache.lazy_unpickle = merge_options.lazy_unpickle
    cache.hf_cache_dir = merge_options.transformers_cache

    for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"):
        cache.get(m)

    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    model_config = model.config(trust_remote_code=merge_options.trust_remote_code)
    model_arch_info = get_architecture_info(
        model.config(trust_remote_code=merge_options.trust_remote_code)
    )

    loader_1 = cache.get(model)
    loader_2 = cache.get(secondary_model)

    os.makedirs(out_path, exist_ok=True)

    merge_unmerge_dictionary = {}
    # load files from merge_unmerge_directory
    spaces = [
        f.split("_unmerge")[0]
        for f in os.listdir(merge_unmerge_directory)
        if "_unmerge" in f
    ]
    for i in spaces:
        logging.info(f"Loading merge/unmerge tensors for {i}")
        m = safetensors.torch.load_file(
            os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"),
            device=device,
        )
        u = safetensors.torch.load_file(
            os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"),
            device=device,
        )
        merge_unmerge_dictionary[i] = (
            m[i].to(device, dtype=dtype),
            u[i].to(device, dtype=dtype),
        )

    for weight_info in model_arch_info.all_weights(config=model_config):
        merge_matrix, unmerge_matrix = None, None

        if weight_info.input_space in merge_unmerge_dictionary:
            _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space]
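            # assumption: each stored matrix stacks the two models' transforms,
            # so chunk(2, ...) splits it into one matrix per model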
            unmerge_matrix = unmerge_matrix.chunk(2, dim=0)

        if weight_info.output_space in merge_unmerge_dictionary:
            merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space]
            merge_matrix = merge_matrix.chunk(2, dim=1)

        original_w = loader_1.get_tensor(weight_info.name, device=device)
        original_w2 = loader_2.get_tensor(weight_info.name, device=device)

        if dtype is not None:
            original_w = original_w.to(dtype=dtype)
            original_w2 = original_w2.to(dtype=dtype)

        w = torch.clone(original_w)
        w2 = torch.clone(original_w2)

        if not merge_matrix and not unmerge_matrix:
            logging.warning(
                f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix"
            )

        if merge_matrix is not None:
            if weight_info.is_embed:
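                # embeddings are stored (vocab, hidden); transpose so the merge
                # matrix acts on the hidden (output-space) dimension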
                w = (merge_matrix[0] @ w.T).T
                w2 = (merge_matrix[1] @ w2.T).T
            else:
                w = merge_matrix[0] @ w
                w2 = merge_matrix[1] @ w2

        if unmerge_matrix is not None:
            w = w @ unmerge_matrix[0]
            w2 = w2 @ unmerge_matrix[1]

        # warn if a weight came through the merge unchanged
        if torch.allclose(original_w, w):
            logging.warning(
                f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge"
            )
        else:
            logging.warning(
                f"✅ Weight {weight_info.name} for model 1 has mutated during merge"
            )

        if torch.allclose(original_w2, w2):
            logging.warning(
                f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge"
            )
        else:
            logging.warning(
                f"✅ Weight {weight_info.name} for model 2 has mutated during merge"
            )

        # combine the two transformed weights and save: when merge matrices were
        # applied, each half already carries its model's contribution, so sum;
        # otherwise take the plain average
        if merge_matrix:
            w = w + w2
        else:
            w = (w + w2) / 2
        writer.save_tensor(weight_info.name, w)
    writer.finalize()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.save_pretrained(out_path, safe_serialization=True)

    # write config
    model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code)
    if dtype:
        model_out_config.torch_dtype = dtype
    model_out_config.save_pretrained(out_path)


if __name__ == "__main__":
    main()
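The script parses its arguments with click per the declarations above, so (assuming it is run directly as a script, with hypothetical paths) an invocation might look like:

    python mergekit/scripts/ABM/activations_based_merge.py ./model_a ./model_b ./merge_unmerge_dir -o ./merged-model --dtype float16 --device cuda

where merge_unmerge_dir holds the per-space ${space}_merge.safetensor / ${space}_unmerge.safetensor files the script looks for.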