In [None]:
# IPython line magic: set the CUDA_VISIBLE_DEVICES environment variable to 0 for this kernel.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
from transformers import AutoTokenizer
import torch

In [None]:
from llama_real_share.modeling_llama_kvsharer import LlamaForCausalLM

### Load Model

In [None]:
# Placeholder: replace with the checkpoint name or path that is passed to
# `from_pretrained` when the tokenizer and model are loaded.
llama_path = 'YOUR MODEL'

In [None]:
# Tokenizer loaded from the same checkpoint location as the model.
tokenizer = AutoTokenizer.from_pretrained(
    llama_path,
    trust_remote_code=True,
)

In [None]:
# Load the KVSharer variant of LlamaForCausalLM; `device_map='auto'` leaves
# device placement to `from_pretrained`.
llama = LlamaForCausalLM.from_pretrained(
    llama_path,
    device_map='auto',
)

### Load Calibration Dataset

In [None]:
# Read the calibration corpus as a list of lines (trailing newlines are kept).
wiki_data_path = './data/wiki_demo.txt'
# The `with` block closes the file on exit, so no explicit close() is needed.
# An explicit encoding keeps decoding independent of the platform default.
# NOTE(review): assumes the demo file is UTF-8 encoded — confirm.
with open(wiki_data_path, 'r', encoding='utf-8') as f:
    wiki_data = f.readlines()

In [None]:
# The first 30 lines of the corpus serve as the calibration samples.
calibration_set = wiki_data[:30]

### Calculate the Euclidean distance between the KV caches of every pair of layers and rank the pairs from largest to smallest distance

In [None]:
from tqdm import tqdm
import torch

# Identity map: every layer index maps to itself.
kv_cache_share_layers_map = dict(
    (layer_idx, layer_idx) for layer_idx in range(len(llama.model.layers))
)

# One forward pass per calibration sample (truncated to 64 tokens); the
# `past_key_values` returned by each pass is collected for later averaging.
kv_cache_list = []
with torch.no_grad():
    for sample in tqdm(calibration_set):
        encoded = tokenizer(
            sample, return_tensors='pt', max_length=64, truncation=True
        ).to('cuda:0')
        forward_out = llama(**encoded, kv_cache_share_layers_map=kv_cache_share_layers_map)
        kv_cache_list.append(forward_out.past_key_values)

In [None]:
# Average the per-layer (key, value) tensors over the calibration samples.
#
# Element-wise summation only works for samples whose tensors have the same
# shape as the first sample's. Samples with a different shape are skipped
# explicitly (instead of through a bare `except`), and each layer's sum is
# divided by the number of samples that were actually accumulated rather than
# by the total number of samples.
if not kv_cache_list:
    raise ValueError("kv_cache_list is empty; run the calibration forward passes first")

reference_cache = kv_cache_list[0]
num_layers = len(reference_cache)
key_sums = [torch.zeros_like(reference_cache[i][0]) for i in range(num_layers)]
value_sums = [torch.zeros_like(reference_cache[i][1]) for i in range(num_layers)]
sample_counts = [0] * num_layers

for past_key_values in tqdm(kv_cache_list):
    for i, (key, value) in enumerate(past_key_values):
        if key.shape != key_sums[i].shape or value.shape != value_sums[i].shape:
            # Different shape than the reference sample: cannot be averaged.
            continue
        key_sums[i] = key_sums[i] + key
        value_sums[i] = value_sums[i] + value
        sample_counts[i] += 1

if sample_counts and min(sample_counts) < len(kv_cache_list):
    print(
        f"Averaged {min(sample_counts)} of {len(kv_cache_list)} samples; "
        "the rest had a KV-cache shape different from the first sample."
    )

# The first sample always matches its own shape, so every count is >= 1.
avg_past_key_values = [
    (key_sums[i] / sample_counts[i], value_sums[i] / sample_counts[i])
    for i in range(num_layers)
]


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def compute_cosine_similarity(tensor1, tensor2):
    """Return the mean cosine similarity of two tensors as a Python float.

    Both tensors are flattened from dimension 1 onward, the cosine similarity
    is computed along the resulting last dimension, and the per-row values
    are averaged.
    """
    flat_first = tensor1.flatten(start_dim=1)
    flat_second = tensor2.flatten(start_dim=1)
    per_row_similarity = F.cosine_similarity(flat_first, flat_second, dim=-1)
    return per_row_similarity.mean().item()

def compute_euclidean_distance(tensor1, tensor2):
    """Return the mean Euclidean (L2) distance between two tensors as a float.

    The L2 norm of ``tensor1 - tensor2`` is taken along the last dimension and
    the resulting distances are averaged over all remaining dimensions.

    ``torch.linalg.vector_norm`` is used because ``torch.norm`` is deprecated.
    """
    per_vector_distance = torch.linalg.vector_norm(tensor1 - tensor2, ord=2, dim=-1)
    return per_vector_distance.mean().item()

# Pairwise layer matrix. Despite its name, `similarity_matrix` stores
# Euclidean distances: entry [i, j] is the mean of the key distance and the
# value distance between layers i and j. Only the strict lower triangle
# (i > j) is filled; the diagonal and the upper triangle stay NaN.
num_layers = len(avg_past_key_values)
similarity_matrix = np.full((num_layers, num_layers), np.nan)

for i in range(1, num_layers):
    key_i, value_i = avg_past_key_values[i]
    for j in range(i):
        key_j, value_j = avg_past_key_values[j]
        key_distance = compute_euclidean_distance(key_i, key_j)
        value_distance = compute_euclidean_distance(value_i, value_j)
        similarity_matrix[i, j] = (key_distance + value_distance) / 2

In [None]:

# Rank every filled (non-NaN) entry of the matrix from largest to smallest
# value and record its (row, column) position.
flat_scores = similarity_matrix.ravel()
is_filled = ~np.isnan(flat_scores)
filled_flat_idx = np.flatnonzero(is_filled)

# argsort is ascending, so the order is reversed to get descending values.
descending_order = np.argsort(flat_scores[is_filled])[::-1]
ranked_rows, ranked_cols = np.unravel_index(
    filled_flat_idx[descending_order], similarity_matrix.shape
)

# List of (row, column) tuples, highest-valued pair first.
pos_rank = list(zip(ranked_rows, ranked_cols))
    

### Set the Number of Layers to Share and the Similarity THRESHOLD

In [None]:
# Number of accepted sharing pairs after which the strategy search stops.
SHARE_LAYERS = 4
# A candidate sharing map is accepted only when the mean cosine similarity
# returned by `cal_last_hidden_sim` is strictly greater than this value.
THRESHOLD = 0.5

In [None]:
import numpy as np
def cal_last_hidden_sim(model1, model2, kv_cache_share_layers_map, tokenizer, sents):
    """Mean cosine similarity between the last hidden states of two models.

    For every sentence, ``model1`` is run with an identity layer map (each
    layer index maps to itself) and ``model2`` is run with
    ``kv_cache_share_layers_map``. The final entry of each model's
    ``hidden_states`` output is flattened into a single row and the two rows
    are compared with cosine similarity.

    Args:
        model1: Reference model; must expose ``model.layers`` and accept the
            ``output_hidden_states`` and ``kv_cache_share_layers_map`` keywords.
        model2: Model evaluated with the candidate sharing map.
        kv_cache_share_layers_map: Layer map passed to ``model2``.
        tokenizer: Callable that turns a sentence into model inputs
            (called with ``max_length=64``, ``truncation=True``,
            ``return_tensors='pt'``).
        sents: Iterable of input sentences.

    Returns:
        The NumPy mean of the per-sentence similarities. The per-sentence
        values and their mean are also printed.
    """
    # Loop-invariant: the reference model always gets the identity map.
    identity_map = {i: i for i in range(len(model1.model.layers))}

    sim_ls = []
    for s in sents:
        encoded_inputs = tokenizer(s, max_length=64, truncation=True, return_tensors='pt')
        # Use the object returned by `.to(...)` instead of relying on it to
        # move the batch in place; this matches how the other cells of this
        # notebook handle tokenizer output.
        encoded_inputs = encoded_inputs.to('cuda:0')
        with torch.no_grad():
            outputs1 = model1(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map=identity_map)
            outputs2 = model2(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map=kv_cache_share_layers_map)
        # Flatten each final hidden state into one row so a single cosine
        # similarity is produced per sentence.
        hidden_states1 = outputs1.hidden_states[-1].reshape(1, -1)
        hidden_states2 = outputs2.hidden_states[-1].reshape(1, -1)
        sim_ls.append(torch.cosine_similarity(hidden_states1, hidden_states2).item())

    mean_sim = np.mean(sim_ls)
    print(sim_ls, mean_sim)
    return mean_sim

In [None]:
def re_map(kv_cache_share_layers_map):
    """Resolve every layer in a sharing map to its root layer.

    A layer that maps to itself is a root. Every other layer is followed
    through the map until a root is reached, so a chain such as
    ``{0: 0, 1: 0, 2: 1}`` collapses to ``{0: 0, 1: 0, 2: 0}``.

    Each chain is followed directly in the input map, so the result does not
    depend on the iteration order of the dict or on a target having been
    visited before the layers that point at it.

    Args:
        kv_cache_share_layers_map: Dict mapping each layer index to another
            (or the same) layer index.

    Returns:
        A new dict with the same keys, in the same order, whose values are
        root layers. The input dict is not modified.

    Raises:
        KeyError: If a layer points at an index that is not a key of the map.
        ValueError: If a chain never reaches a self-mapped root (a cycle).
    """
    max_hops = len(kv_cache_share_layers_map)
    resolved = {}
    for layer in kv_cache_share_layers_map:
        root = layer
        hops = 0
        while kv_cache_share_layers_map[root] != root:
            root = kv_cache_share_layers_map[root]
            hops += 1
            # A chain without repeats visits each key at most once.
            if hops > max_hops:
                raise ValueError(
                    f"cycle detected while resolving layer {layer!r} in the sharing map"
                )
        # Store the root's own self-mapped value.
        resolved[layer] = kv_cache_share_layers_map[root]
    return resolved

### Strategy Search

In [None]:
from copy import deepcopy

# Greedy search over candidate layer pairs, visited in `pos_rank` order.
# For each pair, the higher-indexed layer is re-pointed at the lower-indexed
# one in a trial copy of the map. The trial map is kept when the mean
# last-hidden-state similarity returned by `cal_last_hidden_sim` is above
# THRESHOLD. The search stops once SHARE_LAYERS pairs have been accepted.
kv_cache_share_layers_map = {i: i for i in range(len(llama.model.layers))}

shared_lay = []
shared_num_layers = 0

for pair in tqdm(pos_rank):
    # Check the budget before evaluating, so SHARE_LAYERS <= 0 shares nothing.
    if shared_num_layers >= SHARE_LAYERS:
        break

    # `pair` is an immutable tuple, so it is ordered by unpacking into new
    # names rather than by assigning to its items (which raises TypeError).
    target_layer, source_layer = max(pair), min(pair)
    if target_layer in shared_lay:
        # This layer was already re-pointed by an accepted pair.
        continue

    # A shallow copy is enough: keys and values are plain layer indices.
    tmp_kv_cache_share_layers_map = dict(kv_cache_share_layers_map)
    tmp_kv_cache_share_layers_map[target_layer] = source_layer
    tmp_kv_cache_share_layers_map = re_map(tmp_kv_cache_share_layers_map)

    sim_value = cal_last_hidden_sim(llama, llama, tmp_kv_cache_share_layers_map, tokenizer, calibration_set)
    if sim_value > THRESHOLD:
        kv_cache_share_layers_map = tmp_kv_cache_share_layers_map
        shared_lay.append(target_layer)
        shared_num_layers += 1

In [None]:
# Show the layer map produced by the strategy search.
print(kv_cache_share_layers_map)

### Inference with KVSharer

In [None]:
def generate(model, tokenizer, sent, kv_cache_share_layers_map=None):
    """Generate a continuation of ``sent`` and print the decoded text.

    The prompt is tokenized, moved to ``cuda:0`` and passed to
    ``model.generate`` together with ``kv_cache_share_layers_map``, with at
    most 256 new tokens and a repetition penalty of 1.1. The first returned
    sequence is decoded with special tokens skipped and printed. Nothing is
    returned.
    """
    prompt_inputs = tokenizer(sent, return_tensors='pt').to('cuda:0')
    generation_options = dict(
        kv_cache_share_layers_map=kv_cache_share_layers_map,
        max_new_tokens=256,
        repetition_penalty=1.1,
    )
    output_ids = model.generate(**prompt_inputs, **generation_options)
    first_sequence = output_ids.cpu()[0]
    print(tokenizer.decode(first_sequence, skip_special_tokens=True))

In [None]:
# Example prompt, generated with the current `kv_cache_share_layers_map`.
sent = 'Hello, what is your name'
generate(
    llama,
    tokenizer,
    sent,
    kv_cache_share_layers_map=kv_cache_share_layers_map,
)