In [None]:
import os
import torch
import torch.nn as nn
import GPUtil
from transformers import AutoTokenizer, AutoModelForCausalLM
import tqdm
from functools import partial

## Resource Detection

In [None]:
# Ask GPUtil for every GPU it can see and note each one's free memory.
gpus = GPUtil.getGPUs()
free_memory = [device.memoryFree for device in gpus]

# GPU positions ordered by ascending free memory; the final entry is the
# one with the most memory available.
memory_sort = sorted(range(len(free_memory)), key=free_memory.__getitem__)
gpu_id = memory_sort[-1]
gpu_memory = free_memory[gpu_id]

print(f'gpu_id:{gpu_id}; gpu_memory:{gpu_memory}')

# Restrict CUDA to the chosen device, enumerating devices by PCI bus id.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": str(gpu_id),
})

## Model Selection

In [None]:
# Checkpoint to analyse and quantize.
model_name = "/data/LLMs/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the weights in float16 with automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

## Layer Importance Detection

In [None]:
def encode(tok, text, padding=True, truncation=True, max_length=None):
    """Tokenize one text into left-padded ``input_ids`` plus an ``attention_mask``.

    Args:
        tok: Tokenizer-like object exposing ``bos_id``, ``eos_id`` and an
            ``encode(text)`` method that returns a list of token ids.
        text: The string to tokenize.
        padding: When true and ``max_length`` is given, pad on the left with
            ``tok.eos_id`` until the sequence is ``max_length`` long.
        truncation: When true and ``max_length`` is given, keep only the
            first ``max_length`` ids of a longer sequence.
        max_length: Target sequence length. ``None`` disables both padding
            and truncation (there is no length to pad or cut to).

    Returns:
        dict: ``{'input_ids': [...], 'attention_mask': [...]}`` as plain
        Python lists of equal length; mask entries are 1 for real tokens and
        0 for padding.
    """
    # Prepend the BOS id to the tokenizer's own output.
    input_ids = [tok.bos_id] + tok.encode(text)

    has_target_length = max_length is not None and max_length >= 0

    # Honour `truncation`: without this, over-long rows stay ragged and a
    # batch cannot be stacked into a tensor.
    if truncation and has_target_length and len(input_ids) > max_length:
        input_ids = input_ids[:max_length]

    # Every remaining position is a real token.
    attention_mask = [1] * len(input_ids)

    # Left-pad up to the target length; skipped when no length was given
    # (previously `None - int` raised a TypeError here).
    if padding and has_target_length:
        padding_length = max(max_length - len(input_ids), 0)
        attention_mask = [0] * padding_length + attention_mask
        input_ids = [tok.eos_id] * padding_length + input_ids

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
    }

def batch_encode_plus(tok, texts, max_length=None, return_tensors=None):
    """Encode several texts to one common length and stack them as tensors.

    When ``max_length`` is omitted it becomes the length of the longest
    encoded text (BOS id included), and each text is then passed to
    ``encode`` with that length. ``return_tensors`` is accepted but not
    consulted: the result is always a dict of ``torch`` tensors with the
    keys ``'input_ids'`` and ``'attention_mask'``.
    """
    if max_length is None:
        # -1 is the starting value, and is what remains when `texts` is empty.
        max_length = max(
            (len([tok.bos_id] + tok.encode(sample)) for sample in texts),
            default=-1,
        )

    # Encode every text against the shared target length.
    rows = [encode(tok, sample, max_length=max_length) for sample in texts]

    # Gather each field across rows and convert it to a tensor.
    return {
        field: torch.tensor([row[field] for row in rows])
        for field in ('input_ids', 'attention_mask')
    }

In [None]:
# Reuse the EOS token as BOS, then expose both ids under the short attribute
# names (`bos_id` / `eos_id`) that the encode helpers read.
tokenizer.bos_token = tokenizer.eos_token
tokenizer.bos_id, tokenizer.eos_id = tokenizer.bos_token_id, tokenizer.eos_token_id

# One running score per decoder layer, all starting at zero.
importances = [0] * len(model.model.layers)

In [None]:
from datasets import load_dataset

# Test split of the raw WikiText-2 configuration; its 'text' column supplies
# the prompts for the layer-scoring loop.
dataset = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="test",
)

In [None]:
# Width, in tokens, of each window the scoring loop feeds to the model.
MAX_SEQ_LEN = 1024
# Number of dataset rows tokenized and scored together per loop iteration.
batch_size = 1
# Number of leading dataset rows the scoring loop iterates over.
dataset_size = 200

In [None]:
def jaccard_set(list1, list2):
    """Return the Jaccard similarity of two collections, treated as sets.

    The score is ``len(A & B) / len(A | B)`` where ``A`` and ``B`` are the
    sets of distinct items in ``list1`` and ``list2``.

    Args:
        list1: Iterable of hashable items.
        list2: Iterable of hashable items.

    Returns:
        float: A value in ``[0.0, 1.0]``. Two empty inputs are considered
        identical and score ``1.0``.
    """
    set1, set2 = set(list1), set(list2)

    # Size the union from the sets, not the raw lists: duplicates in either
    # input would otherwise inflate the denominator and deflate the score.
    union_size = len(set1 | set2)
    if union_size == 0:
        # Both inputs are empty; avoid dividing by zero.
        return 1.0

    return len(set1 & set2) / union_size

In [None]:
import numpy as np

# Number of highest-scoring vocabulary tokens compared between the hidden
# states on either side of a layer.
k = 20

# The input-embedding matrix is the same for every batch, so fetch it once.
E = model.get_input_embeddings().weight.detach()
num_layers = len(model.model.layers)

# `range` has a length, so tqdm can size the progress bar without a `total`.
for batch_start in tqdm.tqdm(range(0, dataset_size, batch_size)):

    # Slice the rows first and then pick the column, so the whole 'text'
    # column is not materialised on every iteration.
    prompts = dataset[batch_start:batch_start + batch_size]['text']
    max_seq_len = MAX_SEQ_LEN
    stride = 256
    max_gen_len = 0

    prompt_tokens = batch_encode_plus(
        tokenizer,
        prompts,
        return_tensors='pt'
    )
    input_ids = prompt_tokens['input_ids']
    attn_mask = prompt_tokens['attention_mask']
    max_prompt_len = max(len(t) for t in input_ids)

    # Per-layer similarity accumulated over every window of this batch.
    all_jac_sim = [0] * num_layers

    # authors use a sliding window of size 1024 with a shift of 256
    for start in range(0, max_prompt_len, stride):
        # Keep only the sequences that have more than `start` real tokens.
        seq_ids = (attn_mask.sum(dim=-1) > start).nonzero().squeeze()
        seq_ids = seq_ids.unsqueeze(0) if seq_ids.dim() == 0 else seq_ids  # keep a 1-d index
        inputs = input_ids[seq_ids, start:start + max_seq_len]
        attn = attn_mask[seq_ids, start:start + max_seq_len]

        # Scoring only: no gradients are needed, and skipping autograd avoids
        # keeping the graph for every returned hidden state.
        with torch.no_grad():
            if max_gen_len == 0:
                outputs = model(
                    input_ids=inputs.to("cuda"),
                    attention_mask=attn.to("cuda"),
                    output_hidden_states=True,
                )
            else:
                # NOTE(review): unreachable while max_gen_len is 0. The
                # hidden_states returned by generate() are presumably nested
                # per generated token, which the indexing below does not
                # handle - confirm before enabling generation.
                outputs = model.generate(
                    input_ids=inputs.to("cuda"),
                    attention_mask=attn.to("cuda"),
                    max_new_tokens=max_gen_len,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                )

        hiddens = outputs.hidden_states

        # Project the last position of every hidden state onto the embedding
        # matrix and decode its k best-scoring tokens, once per hidden state.
        # The output of one layer is the input of the next, so computing this
        # per layer pair would repeat each projection and sort.
        topk_per_state = []
        for hidden in hiddens:
            projs = (hidden[:, -1, :] @ E.T).detach().cpu().numpy()
            ranked = np.argsort(-projs)  # descending, row by row
            topk_per_state.append(
                [[tokenizer.decode(tok_id) for tok_id in row[:k]] for row in ranked]
            )

        # Compare each layer's input and output top-k sets for every sequence
        # in the window (not only the first one), and add up the similarity.
        for layer_idx in range(len(hiddens) - 1):
            seq_pairs = zip(topk_per_state[layer_idx], topk_per_state[layer_idx + 1])
            for in_topks, ot_topks in seq_pairs:
                all_jac_sim[layer_idx] += jaccard_set(in_topks, ot_topks)

    # Fold this batch's per-layer totals into the running scores.
    importances = [x + y for x, y in zip(importances, all_jac_sim)]


In [None]:
import math
def normalize(lst, range_min=0, range_max=1):
    """Min-max rescale ``lst`` linearly into ``[range_min, range_max]``.

    Args:
        lst: Sequence of numbers.
        range_min: Value the smallest element is mapped to.
        range_max: Value the largest element is mapped to.

    Returns:
        list: The rescaled values, in the original order. An empty input
        gives an empty list. If every element is equal there is no spread to
        rescale, and each element is mapped to ``range_min``.
    """
    if not lst:
        # min()/max() would raise on an empty sequence.
        return []

    min_val = min(lst)
    max_val = max(lst)
    spread = max_val - min_val
    if spread == 0:
        # All values equal: avoid dividing by zero.
        return [range_min for _ in lst]

    return [(range_max - range_min) * (x - min_val) / spread + range_min for x in lst]

# Replace infinite scores with 0 before rescaling.
filtered_values = [value if not math.isinf(value) else 0 for value in importances]
normalized_lst = normalize(filtered_values)

# Layer indices by ascending normalized score, then the same order reversed
# so that the highest-scoring layer comes first.
sorted_indices = sorted(range(len(normalized_lst)), key=normalized_lst.__getitem__)
reversed_list = sorted_indices[::-1]

## Quantize

In [None]:
from lsaq_quant import quantize_llama_like

# How many layers to quantize, and the bit width handed to the quantizer.
num_of_layer2quant = 8
bit = 8

# `reversed_list` is ordered by descending score, so this takes the
# top-scoring layers.
layer_to_quant = reversed_list[:num_of_layer2quant]

# Name the MLP and self-attention sub-modules of every selected layer.
mlp_quant = ['layers.{}.mlp'.format(layer) for layer in layer_to_quant]
self_attn_quant = ['layers.{}.self_attn'.format(layer) for layer in layer_to_quant]

print('quanting ... ')
model_lsaq = quantize_llama_like(model, mlp_quant, self_attn_quant, bit)
print('quanted')