In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.amp import autocast, GradScaler
from pytorch_memlab import MemReporter
import gc

In [2]:
model_directory = r"Z:/llmfile/Llama-2-7B-hf/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"

model = AutoModelForCausalLM.from_pretrained(model_directory)
tokenizer = AutoTokenizer.from_pretrained(model_directory)

# 设置优化器(Adam)和 GradScaler 用于混合精度
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scaler = GradScaler()

# 初始输入和标签
input_text = "Hello"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
labels = input_ids.clone()  # 使用输入作为标签

reporter = MemReporter(model)

# 反向传播实验
print("==BackwardTrainingMemoryUsage==")
model.train()
optimizer.zero_grad()  # 清除之前的梯度

# 重新执行前向传播，得到loss
with autocast("cuda"):
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss

# 反向传播
scaler.scale(loss).backward()  # 混合精度缩放梯度
reporter.report()  # 记录反向传播后的内存使用情况

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

==BackwardTrainingMemoryUsage==
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor0                                         (1, 2, 4096)    32.00K
Tensor1                                      (1, 32, 2, 128)     0.00B
Tensor2                                      (1, 32, 2, 128)    32.00K
Tensor3                                        (1, 2, 32000)   250.00K
Tensor4                                                 (1,)   512.00B
Tensor5                                                 (1,)   512.00B
Tensor6                                                 (1,)   512.00B
Tensor7                                               (1, 2)   512.00B
Tensor8                                               (1, 2)   512.00B
Tensor9                                         (1, 2, 4096)    32.00K
Tensor10                                     (1, 32, 2, 128)     0.00B
Tensor11             

  tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
  fact_numel = tensor.storage().size()


In [3]:
# 清除变量并释放内存
del outputs, loss
gc.collect()

46