In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.amp import autocast, GradScaler
from pytorch_memlab import MemReporter
import gc

In [2]:
model_directory = r"Z:/llmfile/Llama-2-7B-hf/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"

model = AutoModelForCausalLM.from_pretrained(model_directory)
tokenizer = AutoTokenizer.from_pretrained(model_directory)

# 设置优化器(Adam)和 GradScaler 用于混合精度
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scaler = GradScaler()

# 初始输入和标签
input_text = "Hello"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
labels = input_ids.clone()  # 使用输入作为标签

reporter = MemReporter(model)

# 先执行一次前向和反向传播，确保 scaler 正常初始化
print("==初始化一下==")
model.train()
optimizer.zero_grad()

with autocast("cuda"):
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss

scaler.scale(loss).backward()  # 反向传播以初始化 GradScaler

# 执行优化器更新并监控内存
print("==OptimizerMemoryUsage==")
scaler.step(optimizer)  # 执行优化器更新
scaler.update()  # 更新 scaler 状态
reporter.report()  # 记录优化器更新后的内存使用情况

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

==初始化一下==
==OptimizerMemoryUsage==
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor0                                         (4096, 4096)    64.00M
Tensor1                                                 (1,)   512.00B
Tensor2                                         (4096, 4096)    64.00M
Tensor3                                         (4096, 4096)    64.00M
Tensor4                                                 (1,)   512.00B
Tensor5                                        (11008, 4096)   172.00M
Tensor6                                        (11008, 4096)   172.00M
Tensor7                                                 (1,)   512.00B
Tensor8                                        (11008, 4096)   172.00M
Tensor9                                        (11008, 4096)   172.00M
Tensor10                                                (1,)   512.00B
Tensor11          

  tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
  fact_numel = tensor.storage().size()


In [3]:
# 清理内存
del outputs, loss
gc.collect()

46