Define the inputs: x and y each take 1 GiB (256 × 1024 × 1024 `float32` values × 4 bytes).

In [1]:
import torch

tensor_size = 256 * 1024 * 1024  # number of float32 elements (4 bytes each) = 1 GiB per tensor
x = torch.randn(tensor_size, dtype=torch.float32, device='cuda')
y = torch.randn(tensor_size, dtype=torch.float32, device='cuda')
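The 1 GiB figure is the element count times the 4-byte size of `float32`. Below is a minimal sketch that checks this without allocating any memory. It assumes a PyTorch version that supports the `meta` device, which records only shape and dtype.

```python
import torch

n = 256 * 1024 * 1024  # same element count as tensor_size above
# A meta tensor carries shape and dtype but has no storage, so nothing is allocated.
t = torch.empty(n, dtype=torch.float32, device="meta")
nbytes = t.numel() * t.element_size()  # 268435456 elements * 4 bytes

print(nbytes)              # 1073741824
print(nbytes / 1024 ** 3)  # 1.0 (GiB)
print(nbytes / 1e9)        # 1.073741824 (decimal GB)
```

Writing 1 GiB as `1e9` bytes therefore undercounts it by 73,741,824 bytes, about 0.0687 GiB.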


Initialization: record the current allocation as the baseline and reset the peak-memory counter.

In [2]:
current_memory = torch.cuda.memory_allocated()  # baseline, mainly x and y (2 GiB)
torch.cuda.reset_peak_memory_stats()            # track the peak from this point on

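The cells below use different operators to show how each one affects GPU memory usage. Each cell redefines `compute`, so run only one variant per experiment and restart the kernel between experiments. This matters because `requires_grad_` modifies x and y in place, and outputs from earlier runs stay allocated.
The first operator is a plain arithmetic expression on tensors that do not require gradients, so autograd saves nothing for backward.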

In [8]:
def compute(x, y):
    return (x + 1) * (2 * y)

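If both x and y require gradients, autograd saves the intermediates `x + 1` and `2 * y` (1 GiB each) so that the multiplication can be differentiated: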

In [11]:
def compute(x, y):
    x.requires_grad_(True)
    y.requires_grad_(True)
    return (x + 1) * (2 * y)

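If only x requires gradients (y must not already have been set to require gradients by an earlier cell, since `requires_grad_` changes the tensor in place):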

In [None]:
def compute(x, y):
    x.requires_grad_(True)
    return (x + 1) * (2 * y)
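You can check which intermediates each variant keeps by intercepting the tensors autograd saves for backward. The sketch below makes several assumptions:

- It needs a PyTorch version that provides `torch.autograd.graph.saved_tensors_hooks`.
- `saved_extra_bytes` and `variant` are hypothetical names.
- Small CPU tensors stand in for the 1 GiB CUDA tensors.

The helper excludes inputs and the output because they are allocated anyway. It also ignores 0-dim tensors, which are Python scalars that PyTorch may wrap into tensors.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def saved_extra_bytes(fn, *inputs):
    """Hypothetical helper: bytes of tensors autograd saves for backward,
    excluding the inputs and the output of fn."""
    saved = {}  # data_ptr -> size in bytes (a tensor saved twice counts once)

    def pack(t):
        if t.dim() > 0:  # skip 0-dim tensors such as wrapped Python scalars
            saved[t.data_ptr()] = t.numel() * t.element_size()
        return t

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        out = fn(*inputs)

    exclude = {a.data_ptr() for a in inputs} | {out.data_ptr()}
    return sum(b for p, b in saved.items() if p not in exclude)


def variant(x, y):
    return (x + 1) * (2 * y)


n = 1000  # 4000 bytes per float32 tensor
x, y = torch.randn(n), torch.randn(n)
print("no grad:", saved_extra_bytes(variant, x, y))
print("x and y:", saved_extra_bytes(variant,
                                    x.clone().requires_grad_(),
                                    y.clone().requires_grad_()))
print("x only: ", saved_extra_bytes(variant, x.clone().requires_grad_(), y))
```

Each `float32` tensor here is 4000 bytes, so 8000 corresponds to the two extra 1 GiB intermediates (`x + 1` and `2 * y`) in the notebook. The x-only figure depends on whether your PyTorch version skips saving operands that no gradient needs. If it does, only `2 * y` is kept.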

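Operator fusion with a hand-written gradient: wrapping the whole expression in one custom `torch.autograd.Function` makes autograd record a single node. That node saves only x and y, which stay allocated anyway, instead of the intermediates `x + 1` and `2 * y`. This lowers persistent memory usage. The peak is unchanged, because `forward` still materializes both intermediates at once.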

In [5]:
from torch.autograd import Function

class AddMulFunction(Function):
    @staticmethod
    def forward(ctx, x, y):
        # Save only the inputs: they stay allocated anyway, so nothing extra is
        # kept for backward. Autograd is disabled inside forward, so x + 1 and
        # 2 * y are freed as soon as z has been computed.
        ctx.save_for_backward(x, y)
        z = (x + 1) * (2 * y)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x = grad_output * (2 * y)        # dz/dx = 2y
        grad_y = grad_output * (2 * (x + 1))  # dz/dy = 2(x + 1)
        return grad_x, grad_y

func = AddMulFunction.apply

def compute(x, y):
    x.requires_grad_(True)
    y.requires_grad_(True)
    return func(x, y)
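A hand-written backward is easy to get wrong, so it is worth checking against numerical differentiation. The sketch below uses `torch.autograd.gradcheck` on small double-precision CPU tensors. The class is repeated so the cell runs on its own, and it uses ∂z/∂y = 2(x + 1), since z = 2(x + 1)y. `BuggyAddMul` is a hypothetical variant without the factor 2 in `grad_y`; it shows that gradcheck catches that mistake.

```python
import torch
from torch.autograd import Function, gradcheck


class AddMulFunction(Function):
    """z = (x + 1) * (2 * y); saves only the inputs for backward."""

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return (x + 1) * (2 * y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x = grad_output * (2 * y)        # dz/dx = 2y
        grad_y = grad_output * (2 * (x + 1))  # dz/dy = 2(x + 1)
        return grad_x, grad_y


class BuggyAddMul(Function):
    """Same forward, but grad_y lacks the factor 2."""

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return (x + 1) * (2 * y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * (2 * y), grad_output * (x + 1)


x = torch.randn(8, dtype=torch.double, requires_grad=True)
y = torch.randn(8, dtype=torch.double, requires_grad=True)
print(gradcheck(AddMulFunction.apply, (x, y)))                      # True
print(gradcheck(BuggyAddMul.apply, (x, y), raise_exception=False))  # False
```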

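Optimize the operator further to lower its peak memory usage. Computing the expression step by step lets each intermediate be freed as soon as `z` is rebound, because autograd does not keep references to them inside `forward`. At most two 1 GiB buffers then exist at once instead of three.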

In [None]:
from torch.autograd import Function

class AddMulFunction(Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        # Compute step by step: rebinding z frees the previous intermediate,
        # so at most two 1 GiB buffers (the old and the new z) are alive at once.
        z = x + 1
        z = z * 2
        z = z * y
        return z

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x = grad_output * (2 * y)        # dz/dx = 2y
        grad_y = grad_output * (2 * (x + 1))  # dz/dy = 2(x + 1)
        return grad_x, grad_y

func = AddMulFunction.apply

def compute(x, y):
    x.requires_grad_(True)
    y.requires_grad_(True)
    return func(x, y)
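Because autograd is disabled inside `forward`, the steps can also be done in place on the freshly allocated `z`. Only the output buffer would then be allocated during the forward pass. The sketch below, with the hypothetical name `AddMulInplace`, assumes this. The lower peak is reasoning from the allocation pattern, not a measurement. `backward` still creates temporaries, so this only affects the forward peak.

```python
import torch
from torch.autograd import Function, gradcheck


class AddMulInplace(Function):
    """Hypothetical variant: z = (x + 1) * (2 * y) using a single new buffer."""

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        z = x + 1   # the only new allocation; it becomes the output
        z.mul_(2)   # in place, no extra buffer
        z.mul_(y)   # in place; x and y themselves are never modified
        return z

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * (2 * y), grad_output * (2 * (x + 1))


x = torch.randn(8, dtype=torch.double, requires_grad=True)
y = torch.randn(8, dtype=torch.double, requires_grad=True)
z = AddMulInplace.apply(x, y)
print(torch.allclose(z, (x + 1) * (2 * y)))    # True
print(gradcheck(AddMulInplace.apply, (x, y)))  # True
```

Never modify the inputs x and y in place here. They are saved for backward, and they are the caller's tensors.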

Compute the output:

In [None]:
z = compute(x, y)

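A hand-written sigmoid. Each step (negation, `exp`, addition, division) launches its own kernel and allocates a new 1 GiB intermediate. Autograd also keeps some of these intermediates for backward, such as the output of `exp`. As a result, it is slower and uses more memory.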

In [3]:
def compute(x):
    x.requires_grad_(True)
    z = 1 / (1 + torch.exp(-x))
    return z

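PyTorch's built-in sigmoid runs as a single fused kernel. Its backward is computed from the output alone, σ'(x) = σ(x)(1 − σ(x)), so it saves no extra tensors.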

In [None]:
def compute(x):
    x.requires_grad_(True)
    z = torch.sigmoid(x)  # equivalent to torch.nn.Sigmoid()(x)
    return z
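The same idea can be written as a custom `Function`: compute the sigmoid in one buffer and save only the output for backward. `SigmoidFn` is a hypothetical name. This is a sketch of the principle, not how PyTorch implements its kernel. The forward pass still runs four elementwise steps, but in place, so only the output buffer is allocated.

```python
import torch
from torch.autograd import Function, gradcheck


class SigmoidFn(Function):
    """Sigmoid that allocates one buffer and saves only its output."""

    @staticmethod
    def forward(ctx, x):
        z = torch.neg(x)   # the only new buffer
        z.exp_()           # exp(-x), in place
        z.add_(1)          # 1 + exp(-x)
        z.reciprocal_()    # 1 / (1 + exp(-x)) = sigmoid(x)
        ctx.save_for_backward(z)  # the output exists anyway, so nothing extra
        return z

    @staticmethod
    def backward(ctx, grad_output):
        (z,) = ctx.saved_tensors
        return grad_output * z * (1 - z)  # sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))


x = torch.randn(8, dtype=torch.double, requires_grad=True)
print(torch.allclose(SigmoidFn.apply(x), torch.sigmoid(x)))  # True
print(gradcheck(SigmoidFn.apply, (x,)))                      # True
```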

Compute the output:

In [7]:
z = compute(x)

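Report the final memory usage:
- persistent memory: extra memory still allocated after the forward pass, i.e. the tensors saved for backward
- peak memory: the largest extra allocation at any point during the forward pass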

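In [8]:
def memory():
  # Exclude the baseline (x and y) and the output z itself, so that only the
  # memory caused by the computation is reported.
  output_bytes = z.numel() * z.element_size()  # 1 GiB for z
  additional_memory = torch.cuda.memory_allocated() - (current_memory + output_bytes)
  peak_memory = torch.cuda.max_memory_allocated()
  additional_peak_memory = peak_memory - (current_memory + output_bytes)

  # Dividing by 1024 ** 3 gives GiB.
  print(f"Additional memory used: {additional_memory / (1024 ** 3)} GB")
  print(f"Additional peak memory used: {additional_peak_memory / (1024 ** 3)} GB")

memory()

Recorded output for the hand-written sigmoid (In [3], In [7]). It was produced with the size of z written as `1e9` bytes. That is 1 GiB minus 10^9 bytes, about 0.0687 GiB, less than the real size, so both figures include that offset. The actual overhead is 2 GiB persistent and 3 GiB at the peak.

Additional memory used: 2.0686774253845215 GB
Additional peak memory used: 3.0686774253845215 GB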
