In [1]:
# pip install pytorch_memlab

# 其他需要的包
# pip install transformers thop torch

import os

import torch
from pytorch_memlab import MemReporter
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig

In [2]:
# Local path of the Hugging Face snapshot directory to load the model from.
#
# The path can be overridden without editing the notebook by setting the
# LLAMA_MODEL_DIR environment variable before starting the kernel. When the
# variable is unset or empty, the default snapshot below is used.

# Default: llama3.1-8b-instruct
# Alternative snapshot (llama2-7b), kept for reference:
#   r"Z:/llmfile/Llama-2-7B-hf/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
model_directory = os.environ.get("LLAMA_MODEL_DIR") or r"Z:/llmfile/Meta-Llama-3.1-8B-Instruct/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f"

In [3]:
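# Load the model once to confirm that it loads correctly.
# After it succeeds, shut down the kernel, re-run the imports, and go straight to the next section; otherwise memory will run out.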

# model = AutoModelForCausalLM.from_pretrained(model_directory)
# tokenizer = AutoTokenizer.from_pretrained(model_directory)
# print(model)

In [4]:
# Check the installed torch version and whether CUDA is available

# import torch
# print(torch.__version__)
# print(torch.cuda.is_available())

In [5]:
# Memory footprint of the model at rest (no forward pass): print the memory usage of each layer

# model = AutoModelForCausalLM.from_pretrained(model_directory)

# reporter = MemReporter(model)

# reporter.report()

In [6]:
# Load the causal-LM weights and the matching tokenizer from the local
# snapshot directory, and switch the model to evaluation mode.
model = AutoModelForCausalLM.from_pretrained(model_directory)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_directory)

# Memory reporter (pytorch_memlab) attached to the loaded model
reporter = MemReporter(model)

# Prompt fed to the model; the short alternative is kept for quick runs
# input_text = "Hello"
input_text = """Our Professor is a cool guy. He really like to provide us many interesting work every week, which is
             enough for me and my teammates spend whole weekend to figure out the details. Therefore, we almost have no
             weekend in the past few weeks. And I also lose a chance to go out for my dinner. My favourite food is sushi.
             I really wish I can have some this week since I really like rice and"""
# Only the token ids are kept; the rest of the tokenizer output is discarded
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]

# One forward pass over the prompt with gradient tracking disabled.
# The report is produced while `outputs` is still referenced, so whatever
# the forward pass returned is still alive when memory usage is printed.
with torch.no_grad():
    outputs = model(input_ids=input_ids)
    logits = outputs.logits
    reporter.report()  # print memory usage after the forward pass

# Show the shape and the values of the logits
print("Logits shape:", logits.size())
print("Logits:", logits)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor0                                              (1, 89)     1.00K
Tensor1                                        (1, 89, 1024)   356.00K
Tensor2                                      (1, 8, 89, 128)     0.00B
Tensor3                                      (1, 8, 89, 128)   356.00K
Tensor4                                        (1, 89, 1024)   356.00K
Tensor5                                      (1, 8, 89, 128)     0.00B
Tensor6                                      (1, 8, 89, 128)   356.00K
Tensor7                                        (1, 89, 1024)   356.00K
Tensor8                                      (1, 8, 89, 128)     0.00B
Tensor9                                      (1, 8, 89, 128)   356.00K
Tensor10                                       (1, 89, 1024)   356.00K
Tensor11                                     (1, 8, 8

  tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
  fact_numel = tensor.storage().size()


In [7]:
# Logits at the last sequence position, i.e. the model's scores for the
# token that would follow the prompt
next_token_logits = logits[:, -1]

# Greedy choice: id of the highest-scoring token
predicted_token_id = next_token_logits.argmax(dim=-1).item()

# Map the token id back to text
predicted_token = tokenizer.decode(predicted_token_id)

# Show the predicted next token
print("Predicted next token:", predicted_token)

Predicted next token:  fish
