-
Notifications
You must be signed in to change notification settings - Fork 359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Activation/representation based merging #199
Closed
Closed
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
shamanez
commented
Mar 18, 2024
Modified `dump_out_features.py` and added `dump_out_sim_metrics.py`
- computation of correlation matrices - computation mo m&u - correlation modification @metric-space --------- Co-authored-by: Luke Meyers <functor.soup@gmail.com>
#251) Make the work more merekit style
metric-space
changed the title
Representation based alignment and merge
Activation/representation based merge
Jul 6, 2024
metric-space
changed the title
Activation/representation based merge
Activation/representation based merging
Jul 8, 2024
Misc scriptthis scripts compares two models for differences in their weights import itertools
import logging
import os
import sys
from collections import defaultdict
from typing import List, Optional
import click
import datasets
import numpy as np
import torch
from safetensors.torch import save_file
from torch.utils.data import DataLoader
from mergekit.architecture import _template_substitution, get_architecture_info
from mergekit.common import ModelReference
from mergekit.io.tasks import LazyTensorLoader
logging.basicConfig(level=logging.INFO)
# set seed
torch.manual_seed(42)
np.random.seed(42)
import torch.nn.functional as F
def cosine_similarity_diff(matrix1, matrix2):
vec1 = matrix1.flatten().unsqueeze(0).double()
vec2 = matrix2.flatten().unsqueeze(0).double()
# plot the size of matrix elements, the decimal point the error
similarity = F.cosine_similarity(vec1, vec2).item()
return similarity
# should be 0 most of the time
def frobenius_norm_diff(matrix1):
return torch.norm(matrix1, p='fro')
@click.command("mergekit-compare-weights")
@click.argument("model-1-path", type=str)
@click.argument("model-2-path", type=str)
def main(
model_1_path: str,
model_2_path: str,
):
model_1 = ModelReference.model_validate(model_1_path)
model_2 = ModelReference.model_validate(model_2_path)
model_1_config = model_1.config()
model_2_config = model_2.config()
model_1_arch_info = get_architecture_info(model_1_config)
model_2_arch_info = get_architecture_info(model_2_config)
tensor_index_1 = model_1.tensor_index()
tensor_index_2 = model_2.tensor_index()
loader_instance_1 = LazyTensorLoader(tensor_index_1)
loader_instance_2 = LazyTensorLoader(tensor_index_2)
for weight_info in model_1_arch_info.all_weights(model_1_config):
weight_name = weight_info.name
tensor_1 = loader_instance_1.get_tensor(weight_name)
print(f"{weight_name}'s shape is {tensor_1.shape}")
tensor_2 = loader_instance_2.get_tensor(weight_name)
if tensor_1.shape != tensor_2.shape:
logging.warning(f"Shape mismatch for weight {weight_name}")
continue
# compute cosine similarity
cosine_similarity = cosine_similarity_diff(tensor_1, tensor_2)
frobenius_norm_1 = frobenius_norm_diff(tensor_1)
frobenius_norm_2 = frobenius_norm_diff(tensor_2)
logging.info(f"Weight {weight_name} cosine similarity:\t{cosine_similarity}, \tfrobenius norm diff: {frobenius_norm_1 - frobenius_norm_2}") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What is this?
This PR introduces a way to merge two models via their activations and hidden states on a tiny sample of data.
This method uses these activations and hidden states to form correlation matrices to then generate permutation and inverse permutation matrices for weights in each model and then combines them
This PR consists of three main scripts
Assumptions
The models to be merged are of the same architecture and equal block/layer count
Testing
To test this we need to get the
mergekit/scripts/random_permuter.py
script from the branchrope-alignment
(see below the bash stuff for the final inference script i.e
test_by_gen.py
)(test_by_gen.py)
If all goes well, you should see the following (or something along the lines of the following)
Things that couldn't make into the final PR
on-the-fly handling of models with grouped query attention. This hasn't been tested enough for this release but will be in the near future. For now, users will have to resort to using this script first: