Activation/representation based merging #199

Closed · wants to merge 42 commits

Conversation

@shamanez (Member) commented Mar 18, 2024

What is this?

This PR introduces a way to merge two models via their activations and hidden states on a tiny sample of data.
The activations and hidden states are used to form correlation matrices, from which permutation and inverse-permutation matrices are generated for the weights of each model; the permuted weights are then combined.

This PR consists of three main scripts:

  1. The first generates the activations/hidden states for each space.
  2. The second generates a permutation and inverse-permutation pair for each space.
  3. The third, based on each space and its connected weights, applies the permutation and/or inverse permutation to each weight and then combines the weights (a rough sketch of this idea follows the list).
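
As a rough, hypothetical sketch of the idea behind steps 2 and 3 (this is not the PR's implementation; it assumes per-space activations are collected as (num_samples, hidden_dim) matrices and uses scipy's linear_sum_assignment for feature matching):

import numpy as np
from scipy.optimize import linear_sum_assignment


def match_features(acts_a: np.ndarray, acts_b: np.ndarray) -> np.ndarray:
    """Permutation matrix P aligning model B's features to model A's, chosen to
    maximize the cross-correlation between feature activations (hypothetical helper)."""
    n = acts_a.shape[1]
    # cross-correlation block between A's and B's features: (hidden_a, hidden_b)
    corr = np.corrcoef(acts_a.T, acts_b.T)[:n, n:]
    _, col_ind = linear_sum_assignment(corr, maximize=True)
    perm = np.zeros_like(corr)
    perm[np.arange(len(col_ind)), col_ind] = 1.0
    return perm


def merge_weights(w_a, w_b, perm_out, perm_in):
    """Permute model B's weight into model A's feature order on its output and
    input spaces, then average the aligned weights."""
    # the inverse of a permutation matrix is its transpose
    w_b_aligned = perm_out @ w_b @ perm_in.T
    return 0.5 * (w_a + w_b_aligned)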

Assumptions

The models to be merged must share the same architecture and have an equal block/layer count.
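
A quick way to sanity-check this assumption before merging (an illustrative snippet, not part of the PR; the permuted-model path is the one used in the testing steps below):

from transformers import AutoConfig

cfg_a = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
cfg_b = AutoConfig.from_pretrained("/home/ubuntu/data/permuter/random2")
# both the architecture family and the layer count must match
assert cfg_a.model_type == cfg_b.model_type, "architectures differ"
assert cfg_a.num_hidden_layers == cfg_b.num_hidden_layers, "layer counts differ"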

Testing

To test this, we need the mergekit/scripts/random_permuter.py script from the rope-alignment branch.

(The final inference script, test_by_gen.py, appears after the bash steps below.)

# Clone the rope-alignment branch and set up a virtual environment inside the clone
git clone --branch rope-alignment https://github.com/arcee-ai/mergekit.git permuter
python3 -m venv permuter
cd permuter && source bin/activate
pip install -e .
huggingface-cli login
# Create a randomly permuted copy of Llama-2-7b-chat-hf to merge back against
python mergekit/scripts/permute_random.py meta-llama/Llama-2-7b-chat-hf --permute-head-dims --out-path random2
# Copy the tokenizer files from the HF cache into the permuted model directory
cp $HF_HOME/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/f5db02db724555f92da89c216ac04704f23d4590/{tokenizer*,special_tokens_map.json} random2
deactivate
cd -

# Clone the wip-zipit branch (this PR) into a separate environment
git clone --branch wip-zipit https://github.com/arcee-ai/mergekit.git mergekit
python3 -m venv mergekit
cd mergekit && source bin/activate
pip install -e .
# Step 1: extract activations/hidden states for both models on a small sample
mkdir delete_dump_output/
python mergekit/scripts/ABM/extract_activations.py meta-llama/Llama-2-7b-chat-hf -o ./delete_dump_output -d arcee-ai/pmc-test-perplexity -s 8 -c text -u test --device cpu
python mergekit/scripts/ABM/extract_activations.py /home/ubuntu/data/permuter/random2 -o ./delete_dump_output -d arcee-ai/pmc-test-perplexity -s 8 -c text -u test --device cpu
# Step 2: derive permutation and inverse-permutation matrices for each space
mkdir delete_m_v_out
python mergekit/scripts/ABM/extract_permutation_matrices.py ./delete_dump_output/meta-llama_Llama-2-7b-chat-hf_features.bin ./delete_dump_output/_home_ubuntu_data_permuter_random2_features.bin --model_path meta-llama/Llama-2-7b-chat-hf --out_path ./delete_m_v_out
# Step 3: apply the permutations to the weights and merge the two models
mkdir new_model/
python mergekit/scripts/activations_based_merge.py meta-llama/Llama-2-7b-chat-hf /home/ubuntu/data/permuter/random2 delete_m_v_out -o new_model
# Sanity-check the merged model with a quick generation (script below)
python test_by_gen.py new_model
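
For intuition about what the extract_activations.py step is doing, here is a rough, hypothetical sketch (not the ABM script's actual implementation) of capturing per-layer hidden states with forward hooks on a tiny sample:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-chat-hf"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
model.eval()

captured = {}  # layer name -> list of hidden-state tensors


def make_hook(name):
    def hook(module, inputs, output):
        # decoder layers return a tuple whose first element is the hidden states
        hidden = output[0] if isinstance(output, tuple) else output
        captured.setdefault(name, []).append(hidden.detach().cpu())

    return hook


for i, layer in enumerate(model.model.layers):
    layer.register_forward_hook(make_hook(f"layer_{i}"))

with torch.no_grad():
    batch = tok(["A tiny sample sentence."], return_tensors="pt")
    model(**batch)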

(test_by_gen.py)

import sys

import torch
from transformers import pipeline

model = sys.argv[1] 

pipe = pipeline(
    "text-generation", model=model, torch_dtype=torch.bfloat16, device_map="auto"
)

# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {
        "role": "system",
        "content": "You are a helpful chatbot who pretends to be Richard Feynman",
    },
    {"role": "user", "content": "Could you tell me about the challenger disaster ?"},
]
prompt = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
outputs = pipe(
    prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95
)
print(outputs[0]["generated_text"])

If all goes well, you should see output along the lines of the following:
(screenshot of sample generated output, 2024-07-06, omitted)

Things that didn't make it into the final PR

On-the-fly handling of models with grouped-query attention. This hasn't been tested enough for this release, but will be in the near future. For now, users will have to run this script first:

@shamanez shamanez changed the base branch from main to wip-git-rebasin March 18, 2024 18:48
@metric-space metric-space changed the title Wip zipit Representation based alignment and merge Apr 29, 2024
@metric-space metric-space changed the title Representation based alignment and merge Activation/representation based merge Jul 6, 2024
@metric-space metric-space changed the title Activation/representation based merge Activation/representation based merging Jul 8, 2024
@metric-space metric-space changed the base branch from wip-git-rebasin to main July 9, 2024 03:32
@metric-space metric-space changed the base branch from main to wip-git-rebasin July 9, 2024 03:43
@metric-space (Contributor) commented:

Misc script

This script compares two models for differences in their weights:

import logging

import click
import numpy as np
import torch
import torch.nn.functional as F

from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference
from mergekit.io.tasks import LazyTensorLoader

logging.basicConfig(level=logging.INFO)

# set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


def cosine_similarity_diff(matrix1, matrix2):
    """Cosine similarity between the two weight tensors, flattened to vectors."""
    vec1 = matrix1.flatten().unsqueeze(0).double()
    vec2 = matrix2.flatten().unsqueeze(0).double()
    similarity = F.cosine_similarity(vec1, vec2).item()
    return similarity


def frobenius_norm(matrix):
    """Frobenius norm of a single weight tensor; the difference between the two
    models' norms should be close to 0 most of the time."""
    return torch.norm(matrix, p="fro")


@click.command("mergekit-compare-weights")
@click.argument("model-1-path", type=str)
@click.argument("model-2-path", type=str)
def main(
    model_1_path: str,
    model_2_path: str,
):
    model_1 = ModelReference.model_validate(model_1_path)
    model_2 = ModelReference.model_validate(model_2_path)

    model_1_config = model_1.config()
    model_2_config = model_2.config()

    model_1_arch_info = get_architecture_info(model_1_config)
    model_2_arch_info = get_architecture_info(model_2_config)

    tensor_index_1 = model_1.tensor_index()
    tensor_index_2 = model_2.tensor_index()


    loader_instance_1 = LazyTensorLoader(tensor_index_1)
    loader_instance_2 = LazyTensorLoader(tensor_index_2)

    for weight_info in model_1_arch_info.all_weights(model_1_config):
        weight_name = weight_info.name

        tensor_1 = loader_instance_1.get_tensor(weight_name)
        logging.info(f"{weight_name} shape: {tensor_1.shape}")
        tensor_2 = loader_instance_2.get_tensor(weight_name)

        if tensor_1.shape != tensor_2.shape:
            logging.warning(f"Shape mismatch for weight {weight_name}")
            continue

        # cosine similarity between the flattened tensors, and difference of Frobenius norms
        cosine_similarity = cosine_similarity_diff(tensor_1, tensor_2)
        frobenius_norm_1 = frobenius_norm(tensor_1)
        frobenius_norm_2 = frobenius_norm(tensor_2)

        logging.info(
            f"Weight {weight_name} cosine similarity:\t{cosine_similarity},"
            f"\tfrobenius norm diff: {frobenius_norm_1 - frobenius_norm_2}"
        )


if __name__ == "__main__":
    main()

@metric-space metric-space marked this pull request as ready for review July 10, 2024 03:20