In [1]:
from sentence_transformers import SentenceTransformer, util

sentence_model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    device="cuda",
)


def sentence_embedding(sentence):
    """Encode `sentence` with the module-level SentenceTransformer model.

    Returns exactly what ``sentence_model.encode`` returns for the input
    (presumably a 1-D embedding array for a single string — confirm against
    the installed sentence-transformers version).
    """
    vector = sentence_model.encode(sentence)
    return vector

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer,TextIteratorStreamer
import torch
import os
# os.environ["HF_ENDPOINT"] = "https://huggingface.co"

model_name = "Qwen/Qwen2.5-7B-Instruct"
# generate() requests output_attentions=True. Attention weights are only returned by
# the eager attention implementation (fused kernels such as SDPA do not expose them),
# so load with "eager" explicitly instead of relying on the library default.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
# Tokenizers have no device or dtype, so device_map / torch_dtype are not passed here.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def min_max_normalization(matrix):
    """Linearly rescale `matrix` into [0, 1] using its global min and max.

    The min and max are taken over all elements (not per row/column).
    A constant tensor (max == min) has no spread; it is mapped to all zeros
    instead of producing 0/0 = NaN.

    Args:
        matrix: A non-empty torch tensor (torch.min/torch.max raise on empty input).

    Returns:
        A tensor of the same shape with values in [0, 1].
    """
    min_val = torch.min(matrix)
    max_val = torch.max(matrix)
    span = max_val - min_val
    # When span is 0 every element equals min_val, so (matrix - min_val) is all zeros;
    # dividing by 1 keeps it 0 rather than NaN.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (matrix - min_val) / span

def _cached_length(past_key_values):
    """Return how many positions a KV cache already holds.

    Accepts a transformers ``Cache`` object (via ``get_seq_length``) or the legacy
    tuple-of-(key, value)-per-layer format, whose key tensor keeps the sequence on dim 2.
    """
    if hasattr(past_key_values, "get_seq_length"):
        return past_key_values.get_seq_length()
    return past_key_values[0][0].shape[2]


def generate(query, past_key_values=None):
    """Run the causal LM on `query` for one new token and collect attentions.

    Args:
        query: Prompt text, tokenized as-is (no chat template is applied).
        past_key_values: Optional KV cache from an earlier call, either a legacy
            tuple-of-tuples or a transformers ``Cache`` object.

    Returns:
        Tuple ``(response, input_ids, past_key_values, attentions)``:
            response: decoded text of the newly generated token(s) for the first
                batch element.
            input_ids: the tokenized prompt tensor, on CUDA.
            past_key_values: the cache returned by ``model.generate``.
            attentions: ``outputs.attentions`` from ``model.generate`` with
                ``output_attentions=True``.
    """
    model_inputs = tokenizer(query, return_tensors="pt").to("cuda")
    input_ids = model_inputs.input_ids
    # NOTE(review): do_sample is not set here, so whether top_p/temperature/top_k take
    # effect depends on the model's generation config — confirm this is intended.
    generate_kwargs = {
        'max_new_tokens': 1,
        'past_key_values': past_key_values,
        'pad_token_id': tokenizer.eos_token_id,
        'top_p': 0.95,
        'temperature': 0.1,
        'repetition_penalty': 1.0,
        'top_k': 50,
        "return_dict_in_generate": True,
        "output_attentions": True,
    }
    if past_key_values is not None:
        # Feed only the prompt positions the cache does not already cover.
        # NOTE(review): if the cache already spans the whole prompt this range is empty —
        # confirm model.generate handles that for the installed transformers version.
        generate_kwargs['cache_position'] = torch.arange(
            _cached_length(past_key_values), input_ids.shape[1], device=input_ids.device
        )
    outputs = model.generate(**model_inputs, **generate_kwargs)
    # Drop the echoed prompt so only the newly generated tokens are decoded.
    generated_ids = [
        output_ids[len(prompt_ids):]
        for prompt_ids, output_ids in zip(input_ids, outputs.sequences)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response, input_ids, outputs.past_key_values, outputs.attentions


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.18s/it]


In [4]:
import datasets
from tqdm import tqdm

# NOTE(review): no visible cell reads this QQP export (the evaluation loads
# lcqmc/train.tsv) — confirm it is still needed.
QQP_VAL_PATH = "/root/code/vllm_plus/examples/bench_cache/data/qqp_val.txt"


def _tsv_field(text):
    """Replace tabs/newlines so a value cannot break the one-row-per-line, 3-column TSV layout."""
    return text.replace("\t", " ").replace("\r", " ").replace("\n", " ")


dataset = datasets.load_dataset("SetFit/qqp")

# UTF-8 explicitly, matching how load_tsv reads TSV files; `with` guarantees the
# file is flushed and closed even if the loop raises.
with open(QQP_VAL_PATH, "w", encoding="utf-8") as qqp_file:
    for item in tqdm(dataset["validation"]):
        text1, text2, label = item["text1"], item["text2"], item["label"]
        qqp_file.write(f"{_tsv_field(text1)}\t{_tsv_field(text2)}\t{label}\n")

Repo card metadata block was not found. Setting CardData to empty.
100%|██████████| 40430/40430 [00:01<00:00, 24272.17it/s]


In [3]:
def token_embedding(sentence, layer_idx=-1):
    """Build a sentence vector from the LLM's input token embeddings, weighted by attention.

    The sentence is passed through `generate` once to obtain its attentions. For the
    chosen layer, attention is averaged over heads and min-max normalized over the
    whole map. The maximum along the last dimension then becomes each token
    position's weight. The result is the weighted sum of the tokens' input
    embeddings (``model.model.embed_tokens``).

    Args:
        sentence: Text to embed.
        layer_idx: Index of the decoder layer whose attention is used; -1 (default)
            selects the last layer.

    Returns:
        A float32 CPU tensor, presumably of shape (1, hidden_size) for a single
        sentence — TODO confirm against the attention shape noted below.
    """
    _, input_ids, _, step_attentions = generate(sentence)
    # embed_tokens runs outside generate()'s no-grad context; without this guard every
    # call would build an autograd graph over the embedding weights.
    with torch.no_grad():
        token_vectors = model.model.embed_tokens(input_ids).cpu().to(torch.float32)
    # step_attentions[0] is assumed to be the prompt (prefill) step, and each layer's
    # tensor (batch, heads, query, key) — TODO confirm for the installed transformers.
    layer_attention = step_attentions[0][layer_idx].cpu().to(torch.float32)
    head_mean = layer_attention.mean(dim=1)
    # NOTE(review): max over dim=-1 takes each query row's strongest attention; if the
    # intent is "how strongly each token is attended to", the max should be over dim=-2.
    weights = min_max_normalization(head_mean).max(dim=-1).values.unsqueeze(0)
    pooled = torch.matmul(weights, token_vectors)
    return pooled[0]


def load_tsv(file_path):
    """Load a three-column, tab-separated text file.

    Every line has surrounding whitespace stripped and is split on tab characters;
    lines that do not yield exactly three fields are skipped without warning.

    Args:
        file_path: Path of a UTF-8 encoded file.

    Returns:
        List of ``(field1, field2, field3)`` string tuples, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        split_lines = [line.strip().split('\t') for line in handle]
    return [tuple(fields) for fields in split_lines if len(fields) == 3]
from tqdm import tqdm  # progress bars for the evaluation loops

LCQMC_TRAIN_PATH = "/root/code/vllm_plus/examples/bench_cache/data/lcqmc/train.tsv"
TOKEN_LOG_PATH = "/root/code/vllm_plus/examples/bench_cache/data/token_embedding.txt"
SENTENCE_LOG_PATH = "/root/code/vllm_plus/examples/bench_cache/data/sentence_embedding.txt"
N_PAIRS = 1000  # how many LCQMC pairs each method scores


def _write_similarity_log(pairs, embed, similarity, log_path):
    """Score every (s1, s2, label) pair and write one ``s1,s2,label,score`` line per pair.

    `embed` maps a sentence to a vector and `similarity` maps two vectors to a float.
    The file is opened with `with` so it is flushed and closed before any later reader
    opens it. Sentences are written verbatim and may contain commas, so readers should
    take the label and score from the right-hand end of each line.
    """
    with open(log_path, "w", encoding="utf-8") as log_file:
        for s1, s2, label in tqdm(pairs):
            score = similarity(embed(s1), embed(s2))
            log_file.write(f"{s1},{s2},{label},{score}\n")


def _token_cosine(a1, a2):
    """Cosine similarity (along dim 1) of two token_embedding outputs, as a Python float."""
    return torch.nn.functional.cosine_similarity(a1, a2).item()


def _sentence_cosine(a1, a2):
    """Cosine similarity of two sentence_embedding outputs, as a Python float.

    util.cos_sim accepts each embedding directly (1-D inputs become a batch of one),
    which avoids the slow tensor-from-a-list-of-arrays conversion that wrapping the
    inputs in lists triggers. Cosine also matches the token path's metric and does not
    depend on the encoder normalising its outputs.
    """
    return util.cos_sim(a1, a2).item()


data = load_tsv(LCQMC_TRAIN_PATH)
# NOTE(review): as far as I know all-MiniLM-L6-v2 is an English model and LCQMC is a
# Chinese corpus — confirm this baseline is meaningful for this data.
_write_similarity_log(data[:N_PAIRS], token_embedding, _token_cosine, TOKEN_LOG_PATH)
_write_similarity_log(data[:N_PAIRS], sentence_embedding, _sentence_cosine, SENTENCE_LOG_PATH)





From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
100%|██████████| 1000/1000 [01:04<00:00, 15.62it/s]
  a = torch.tensor(a)
100%|██████████| 1000/1000 [00:08<00:00, 111.70it/s]


In [18]:
import numpy as np

def compute_acc(similaritys, labels, threshold=0.5):
    """Accuracy of thresholded similarity scores against 0/1 labels.

    A pair is predicted positive (1) when its similarity is strictly greater than
    `threshold`; predictions are compared element-wise with `labels`.

    Args:
        similaritys: Sequence of similarity scores.
        labels: Sequence of 0/1 ground-truth labels (ints or floats), same length.
        threshold: Exclusive decision threshold.

    Returns:
        Fraction of pairs whose prediction equals the label (numpy float).

    Raises:
        ValueError: if the inputs are empty or differ in length; otherwise the mean
            would be NaN or the element-wise comparison ill-defined.
    """
    similarity = np.array(similaritys)
    labels = np.array(labels)
    if similarity.shape != labels.shape:
        raise ValueError(
            f"similaritys and labels differ in shape: {similarity.shape} vs {labels.shape}"
        )
    if similarity.size == 0:
        raise ValueError("cannot compute accuracy of empty inputs")
    predictions = (similarity > threshold).astype(int)
    return np.mean(predictions == labels)


def _load_similarity_log(path):
    """Parse a ``s1,s2,label,score`` log into ``(scores, labels)`` float lists.

    Sentences may contain commas, so only the last two comma-separated fields of each
    line are interpreted.
    """
    scores, labels = [], []
    with open(path, "r", encoding="utf-8") as log_file:
        for line in log_file:
            *_, label, score = line.split(",")
            scores.append(float(score))
            labels.append(float(label))
    return scores, labels


log1 = "/root/code/vllm_plus/examples/bench_cache/data/token_embedding.txt"
log2 = "/root/code/vllm_plus/examples/bench_cache/data/sentence_embedding.txt"

similaritys1, labels1 = _load_similarity_log(log1)
similaritys2, labels2 = _load_similarity_log(log2)

for t in [0.7, 0.8, 0.9, 0.95]:
    print(
        f"threshold={t}:",
        "token_embedding:", compute_acc(similaritys1, labels1, t),
        "sentence_embedding:", compute_acc(similaritys2, labels2, t),
    )



token_embedding: 0.626 sentence_embedding: 0.736
token_embedding: 0.678 sentence_embedding: 0.8
token_embedding: 0.681 sentence_embedding: 0.768
token_embedding: 0.673 sentence_embedding: 0.729
