In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pytorch_memlab import MemReporter


In [2]:
# 模型的snapshot本地地址
model_directory = r"Z:/llmfile/Llama-2-7B-hf/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"


In [3]:
# 加载模型和 tokenizer
model = AutoModelForCausalLM.from_pretrained(model_directory)
tokenizer = AutoTokenizer.from_pretrained(model_directory)

reporter = MemReporter(model)

# 设置初始输入
input_text = "Hello"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# 推理一个 token 的内存使用情况（混合精度）
model.eval()
with torch.no_grad():
    with torch.amp.autocast("cuda"):  # 启用混合精度
        outputs = model(input_ids)
        logits = outputs.logits
        reporter.report()

# 输出 logits 的形状和内容
print("Logits shape:", logits.shape)

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor0                                               (1, 2)   512.00B
Tensor1                                         (1, 2, 4096)    32.00K
Tensor2                                      (1, 32, 2, 128)     0.00B
Tensor3                                      (1, 32, 2, 128)    32.00K
Tensor4                                         (1, 2, 4096)    32.00K
Tensor5                                      (1, 32, 2, 128)     0.00B
Tensor6                                      (1, 32, 2, 128)    32.00K
Tensor7                                         (1, 2, 4096)    32.00K
Tensor8                                      (1, 32, 2, 128)     0.00B
Tensor9                                      (1, 32, 2, 128)    32.00K
Tensor10                                        (1, 2, 4096)    32.00K
Tensor11                                     (1, 32, 

  tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
  fact_numel = tensor.storage().size()


In [5]:
# 获取最后一个位置的 logits，即模型预测的下一个 token 的概率分布
next_token_logits = logits[:, -1, :]

# 获取概率最高的 token ID
predicted_token_id = torch.argmax(next_token_logits, dim=-1).item()

# 将 token ID 转换为实际的单词
predicted_token = tokenizer.decode(predicted_token_id)

# 输出预测的下一个 token
print("Predicted next token:", predicted_token)

Predicted next token: ,
