In [1]:
import os
import torch
import pickle
import pytz
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import load_dataset
from pympler import asizeof
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Notebook-wide setup: every name assigned in this cell (device, dataset,
# tokenizer, model) is a global that later cells read.

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load the dataset and report the number of rows in each split.
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")
print("The rows of dataset:")
print(f"\ttrain:{dataset['train'].num_rows}")
print(f"\ttest:{dataset['test'].num_rows}")
print(f"\tvalidation:{dataset['validation'].num_rows}")

# Load the pretrained tokenizer and model; the model is moved onto `device`.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

# The GPT-2 tokenizer has no pad_token, so reuse the EOS token for padding
# (padding is requested later when batches of texts are tokenized).
tokenizer.pad_token = tokenizer.eos_token

Device: cpu


Using the latest cached version of the dataset since Salesforce/wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-103-v1' at /home/lyj/.cache/huggingface/datasets/Salesforce___wikitext/wikitext-103-v1/0.0.0/b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Tue Dec  3 03:28:42 2024).


The rows of dataset:
	train:1801350
	test:4358
	validation:3760


In [3]:
def filter_k_caches(k_caches, attention_mask):
    """Drop empty samples and strip padded positions from per-layer key caches.

    Args:
        k_caches: sequence with one tensor per layer. Each tensor is indexed as
            (batch, num_heads, seq_len, head_hidden_size).
        attention_mask: (batch, seq_len) tensor whose per-row sum is taken as
            the number of real (non-padding) tokens of that sample.

    Returns:
        A list with one entry per layer. Each entry is a list holding, for
        every sample whose real length is non-zero, a tensor sliced to
        (num_heads, real_len, head_hidden_size).

    Note:
        Only the first ``real_len`` positions are kept, so this presumes the
        padding sits at the end of each sequence.
    """
    # Real token count of every sample in the batch.
    real_lens = attention_mask.sum(dim=1)

    # Samples made up purely of padding are discarded.
    keep = real_lens != 0
    kept_lens = real_lens[keep]

    trimmed_per_layer = []
    for layer_keys in k_caches:
        # Split along the batch dimension, then cut each surviving sample
        # down to its own unpadded length.
        surviving = torch.unbind(layer_keys[keep], dim=0)
        trimmed_per_layer.append([
            sample_keys[:, :kept_lens[pos], :]
            for pos, sample_keys in enumerate(surviving)
        ])

    return trimmed_per_layer

def get_kv_cache(text, model, tokenizer, device):
    """Run the model on ``text`` and return the unpadded per-layer key caches.

    Args:
        text: input passed straight to ``tokenizer`` (padding and truncation
            are both enabled).
        model: callable invoked as ``model(**inputs, use_cache=True)``; its
            output must expose ``past_key_values``.
        tokenizer: callable producing the model inputs, including an
            ``attention_mask`` entry (0 marks padded positions).
        device: device the tokenized inputs are moved to before the forward pass.

    Returns:
        The result of ``filter_k_caches`` applied to the key tensor of every
        layer: one list per layer, one trimmed tensor per non-empty sample.
    """
    encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    encoded = encoded.to(device)

    # Inference only: no gradients are needed to read out the cache.
    with torch.no_grad():
        model_out = model(**encoded, use_cache=True)

    # past_key_values is iterated as one (key, value) pair per layer; per the
    # original author's note each key tensor is shaped
    # (batch, num_heads, seq_len, head_hidden_size). Only the keys are kept.
    keys_per_layer = []
    for layer_kv in model_out.past_key_values:
        keys_per_layer.append(layer_kv[0])

    return filter_k_caches(keys_per_layer, encoded['attention_mask'])

In [4]:
def check_size_of_k_caches(all_k_caches, size_threshold=10000):
    """Tell whether the buffered key caches should be flushed to disk.

    The "size" is the number of samples buffered for the first layer
    (``len(all_k_caches[0])``). It is a sample count, not a byte count.

    Args:
        all_k_caches: list with one list of cached samples per layer.
        size_threshold: number of buffered samples at which a flush is due.
            Defaults to 10000.

    Returns:
        True when the first layer holds at least ``size_threshold`` samples,
        False otherwise (including when ``all_k_caches`` has no layers).
    """
    # Nothing buffered at all: avoid indexing into an empty list.
    if not all_k_caches:
        return False
    return len(all_k_caches[0]) >= size_threshold


def load_k_caches_to_file(all_k_caches, num_atten_layer):
    """Pickle the buffered key caches into a timestamped file under ./k_caches.

    Despite the name, this function writes (dumps) data; it does not load it.

    Args:
        all_k_caches: object to serialize with ``pickle.dump``.
        num_atten_layer: accepted for call-site compatibility; not used here.

    Note:
        The file name carries a timestamp with one-second resolution, so two
        calls within the same second target the same file and the later call
        overwrites the earlier one.
    """
    directory_path = './k_caches'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(directory_path, exist_ok=True)

    # Current wall-clock time in the Asia/Shanghai zone (UTC+8).
    cst = pytz.timezone('Asia/Shanghai')
    now_cst = datetime.now(cst)
    time_string = now_cst.strftime("%Y-%m-%d_%H:%M:%S")

    file_name = f"{directory_path}/k_cache_{time_string}.pkl"
    with open(file_name, 'wb') as f:
        pickle.dump(all_k_caches, f)

In [5]:
# Extract the key cache of every training sample, buffering results per layer
# and flushing the buffer to disk whenever it grows past the threshold.
model.eval()
batch_size = 16
num_atten_layer = len(model.transformer.h)
all_k_caches = [[] for _ in range(num_atten_layer)]

for i in range(0, len(dataset['train']), batch_size):
    batch_idx = i // batch_size
    texts = dataset['train'][i:i + batch_size]['text']

    # A batch whose texts are all empty contributes no samples (empty samples
    # are dropped by filter_k_caches anyway). Skipping it also avoids feeding
    # the model a zero-length batch -- presumably unsupported; TODO confirm.
    if any(texts):
        k_caches = get_kv_cache(texts, model, tokenizer, device)
        for layer_idx, k_cache_layer in enumerate(k_caches):
            for k_cache in k_cache_layer:
                # Each sample is a slice (view) of the full padded batch
                # tensor. Force a real copy on the CPU so the buffer holds
                # only the sample's own data instead of keeping the whole
                # batch storage alive (and serializing it later).
                all_k_caches[layer_idx].append(k_cache.to('cpu', copy=True))
        del k_caches

    # Once enough samples are buffered, write them out and start a new buffer.
    if check_size_of_k_caches(all_k_caches):
        load_k_caches_to_file(all_k_caches, num_atten_layer)
        all_k_caches = [[] for _ in range(num_atten_layer)]

    if batch_idx % 20 == 0:
        print(f"Batch {batch_idx} has done, k cache size: {len(all_k_caches[0])}.")

# Flush whatever is left: the last partial buffer never reaches the threshold
# inside the loop and would otherwise be lost.
if all_k_caches[0]:
    load_k_caches_to_file(all_k_caches, num_atten_layer)
    all_k_caches = [[] for _ in range(num_atten_layer)]

In [8]:
# Print the module hierarchy of the loaded model (plain-text repr).
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
