In [7]:
import gc
import os
import time
import warnings

import torch

# Silence every warning so the benchmark output below stays readable.
# NOTE(review): this blanket filter also hides deprecation warnings from the
# torch APIs used in this notebook — consider narrowing it.
warnings.filterwarnings(action = 'ignore')

In [2]:
# Start timestamp written by start_timer() and read by end_timer_and_print(); None until a timer starts.
start_time: "float | None" = None

In [3]:
def start_timer():
    """Reset CUDA memory statistics and record the start time of a benchmark.

    Runs the garbage collector and empties PyTorch's CUDA cache so leftovers
    from a previous run do not count towards this one, resets the peak-memory
    counter, and synchronizes so already-queued GPU work is not timed.
    """
    global start_time

    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated; reset_peak_memory_stats()
    # resets the peak counter that max_memory_allocated() reports.
    torch.cuda.reset_peak_memory_stats()
    # Wait for pending kernels so they are not attributed to the timed region.
    torch.cuda.synchronize()

    # perf_counter() is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()


def end_timer_and_print(local_msg):
    """Stop the timer started by start_timer() and print a short report.

    Args:
        local_msg: Label printed above the timing and memory lines.

    Raises:
        RuntimeError: if start_timer() has not been called yet.
    """
    if start_time is None:
        raise RuntimeError("start_timer() must be called before end_timer_and_print()")
    # Kernels are launched asynchronously; wait for them before reading the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(elapsed))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

In [4]:
def make_model(in_size, out_size, num_layers, device="cuda"):
    """Build an MLP of `num_layers` Linear layers separated by ReLU activations.

    The first ``num_layers - 1`` Linear layers map ``in_size -> in_size``.
    The final layer maps ``in_size -> out_size``.

    Args:
        in_size: Input (and hidden) feature size.
        out_size: Output feature size.
        num_layers: Total number of Linear layers; must be >= 1.
        device: Device the model is moved to (default "cuda", as before).

    Returns:
        A torch.nn.Sequential located on `device`.

    Raises:
        ValueError: if num_layers < 1 (previously this silently built one layer).
    """
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1, got {}".format(num_layers))

    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))

    return torch.nn.Sequential(*layers).to(device)

- Mixed precision typically provides the greatest speedup when the GPU is saturated, which is why the sizes below are deliberately large.

In [5]:
# Sizes are large so the GPU is kept busy (mixed precision helps most then).
batch_size = 512
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 50
epochs = 3
seed = 0

# Seed the RNG so the synthetic data and the model initializations below are reproducible.
torch.manual_seed(seed)

# Synthetic regression data, created directly on the GPU.
data = [torch.randn(batch_size, in_size, device = 'cuda') for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, device = 'cuda') for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

### Default Precision

In [8]:
net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr = 0.001)

start_timer()
for _ in range(epochs):
    # `batch` instead of `input`, so the builtin input() is not shadowed
    # at notebook scope once this cell has run.
    for batch, target in zip(data, targets):
        output = net(batch)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad()

end_timer_and_print("Default precision:")


Default precision:
Total execution time = 1.298 sec
Max memory used by tensors = 1308984832 bytes


### Adding autocast

In [9]:
start_timer()

# Reuses net, opt, data and targets from the earlier cells.
for _ in range(epochs):
    for batch, target in zip(data, targets):
        # Only the forward pass (model + loss) runs under autocast.
        with torch.autocast(device_type = 'cuda', dtype = torch.float16):
            output = net(batch)

            # output is float16 because linear layers autocast to float16.
            assert output.dtype is torch.float16

            loss = loss_fn(output, target)
            # loss is float32 because mse_loss layers autocast to float32.
            assert loss.dtype is torch.float32

        # Exits autocast before backward().
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for the corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad()

end_timer_and_print("Adding Autocast:")


Adding Autocast:
Total execution time = 0.588 sec
Max memory used by tensors = 1304781312 bytes


### Adding GradScaler

- Gradient scaling helps prevent gradients with small magnitudes from flushing to zero ("underflowing") when training with mixed precision.
- `torch.cuda.amp.GradScaler` conveniently performs the steps of gradient scaling.

In [None]:
# Construct the scaler once, at the beginning of the convergence run, using default args.
# If your network fails to converge with default GradScaler args, please file an issue.
# The same GradScaler instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.

In [10]:
start_timer()

# One scaler for the whole run, created before the loop.
scaler = torch.cuda.amp.GradScaler()

for _ in range(epochs):
    for batch, target in zip(data, targets):
        with torch.autocast(device_type = 'cuda', dtype = torch.float16):
            output = net(batch)
            loss = loss_fn(output, target)

        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad()

end_timer_and_print("Adding GradScaler:")


Adding Autocast:
Total execution time = 0.610 sec
Max memory used by tensors = 1304782848 bytes


### All together: Automatic Mixed Precision

In [11]:
# use_amp is passed as `enabled` to both autocast and GradScaler, so this one
# flag toggles mixed precision for the whole loop.
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr = 0.001)
scaler = torch.cuda.amp.GradScaler(enabled = use_amp)

start_timer()

for _ in range(epochs):
    for batch, target in zip(data, targets):
        with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = use_amp):
            output = net(batch)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

end_timer_and_print("Automatic Mixed Precision:")


Adding Autocast:
Total execution time = 0.552 sec
Max memory used by tensors = 1409672704 bytes


### Inspecting/modifying gradients

In [12]:
start_timer()

# Reuses net, opt, scaler and use_amp from the previous cell.
for _ in range(epochs):
    for batch, target in zip(data, targets):
        # Honour use_amp here too, so autocast and the scaler are toggled together.
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            output = net(batch)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(opt)

        # Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
        # You may use the same value for max_norm here as you would without gradient scaling.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

end_timer_and_print("Unscaling + gradient clipping:")


Adding Autocast:
Total execution time = 0.649 sec
Max memory used by tensors = 1304782848 bytes


### Saving/Resuming

In [13]:
# Everything needed to resume mixed-precision training: model weights,
# optimizer state, and GradScaler state.
checkpoint = dict(
    model=net.state_dict(),
    optimizer=opt.state_dict(),
    scaler=scaler.state_dict(),
)
# Write the checkpoint as desired, e.g. torch.save(checkpoint, "filename")

In [22]:
# Persist the checkpoint dict to a file named "Auto_Mixed_Precision" (no extension).
torch.save(obj=checkpoint, f="Auto_Mixed_Precision")

### Loading

In [16]:
dev = torch.cuda.current_device()


def _to_current_gpu(storage, location):
    """map_location callback: move every loaded storage onto GPU `dev`,
    regardless of the device it was saved from (`location` is ignored)."""
    return storage.cuda(dev)


# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load("Auto_Mixed_Precision", map_location=_to_current_gpu)

In [17]:
# Restore the model, optimizer and scaler, in that order, from their checkpoint entries.
for component, key in ((net, "model"), (opt, "optimizer"), (scaler, "scaler")):
    component.load_state_dict(checkpoint[key])