In [1]:
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, with no bias terms.

    Args:
        input_size: Number of features expected in the last dimension of the input.
        hidden_size: Width of the hidden layer.
        output_size: Number of features produced in the last dimension of the output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Both projections are created with bias=False, so the only
        # parameters are the two weight matrices.
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        # inplace=True makes the activation overwrite the fc1 output tensor
        # rather than allocating a new one.
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))`` for the input tensor ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [2]:
# Compare peak CUDA memory of a short training run with and without
# bfloat16 autocast. The same model, input and optimizer are reused for both
# runs, exactly as in a single notebook session.
net = MLP(1024, 1024, 1024).cuda()
x = torch.randn(1024, 1024, device="cuda")
optim = torch.optim.Adam(net.parameters(), lr=0.01)

for label, use_autocast in (("without", False), ("with", True)):
    # Start every measurement from the same baseline: release cached blocks
    # and reset the peak counter so re-running the cell (or the previous
    # configuration) cannot leak into this run's reported peak.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    for _ in range(20):
        # Autocast is scoped to the forward pass and loss only; the PyTorch
        # AMP documentation advises against running backward() and the
        # optimizer step inside the autocast context.
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_autocast):
            y = net(x)
            loss = y.mean()
        loss.backward()
        optim.step()
        optim.zero_grad()

    # NOTE(review): autocast leaves parameters, gradients and Adam state in
    # float32, so it is not expected to shrink those allocations; only the
    # tensors produced inside the autocast region change dtype. Confirm
    # against the printed numbers before drawing conclusions.
    print(f"Peak memory {label} autocast: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
    

  from .autonotebook import tqdm as notebook_tqdm


Peak memory without autocast: 112.00 MB
Peak memory with autocast: 114.00 MB
