In [1]:
import torch
import torch.nn as nn

In [2]:
# Only leaves have .grad attributes

# https://discuss.pytorch.org/t/valueerror-cant-optimize-a-non-leaf-tensor/21751/2
a = torch.rand(10).requires_grad_().half()
b = torch.rand(10).half().requires_grad_() 

sum_a = torch.sum(a)
sum_b = torch.sum(b)

sum_a.backward()
sum_b.backward()

print('--------------------------')
print('a is not a leaf, we cannot access its gradients after calling backward')
print(a.grad)
print('b is a leaf')
print(b.grad)
print('--------------------------')

--------------------------
a is not a leaf, we cannot access its gradients after calling backward
None
b is a leaf
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=torch.float16)
--------------------------


  from ipykernel import kernelapp as app


In [3]:
# Cast a 32-bit tensor into 16-bit during forward,
# after backward, its gradients of a will be 32-bit again

a = torch.rand(10).requires_grad_()
b = a.half()

print('a dtype', a.dtype, 'require_grad=',a.requires_grad)
print('b dtype', b.dtype, 'require_grad=',b.requires_grad)

b_sum = torch.sum(b)
print('b_sum dtype', b_sum.dtype, 'require_grad=',b_sum.requires_grad)

b_sum.backward()

print('a.grad',a.grad)
print('a.grad dtype', a.grad.dtype)

a dtype torch.float32 require_grad= True
b dtype torch.float16 require_grad= True
b_sum dtype torch.float16 require_grad= True
a.grad tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
a.grad dtype torch.float32


In [4]:
# Backward will also backward through the .half() operation
# Therefore, gradient of W and in the optimizer will become 32-bit again

W = nn.Parameter(torch.tensor([1,2,3], dtype=torch.float32))
opt = torch.optim.SGD([W], lr=0.1)
print('Init: W dtype', W.dtype)
print('Optimizer opt[0][\'params\'][0].dtype', opt.param_groups[0]['params'][0].dtype)

X = torch.tensor([2,4,6], dtype=torch.float16, requires_grad=True)
Y = torch.tensor([20.], dtype=torch.float16, requires_grad=True)
print('Data X dtype', X.dtype)
print('Data Y dtype', Y.dtype)

W16 = W.half()
print('Forward: W16 dtype', W16.dtype)

Y_hat = torch.dot(W16, X)
print('Forward: Y_hat dtype', Y_hat.dtype)

loss = Y - Y_hat
print('Forward: loss dtype', loss.dtype)

loss.backward()
print('-------------------------------')

print('After backward: loss dtype', loss.dtype)
print('After backward: Y_hat dtype', Y_hat.dtype)
print('After backward: X dtype', X.dtype)
print('After backward: X.grad dtype', X.grad.dtype)
print('After backward: Y dtype', Y.dtype)
print('After backward: Y.grad dtype', Y.grad.dtype)
print('After backward: Y dtype', Y.dtype)
print('After backward: Y.grad dtype', Y.grad.dtype)
print('After backward: W dtype', W.dtype)
print('After backward: W.grad dtype', W.grad.dtype)
print('Optimizer opt[0][\'params\'][0].dtype', opt.param_groups[0]['params'][0].dtype)

Init: W dtype torch.float32
Optimizer opt[0]['params'][0].dtype torch.float32
Data X dtype torch.float16
Data Y dtype torch.float16
Forward: W16 dtype torch.float16
Forward: Y_hat dtype torch.float16
Forward: loss dtype torch.float16
-------------------------------
After backward: loss dtype torch.float16
After backward: Y_hat dtype torch.float16
After backward: X dtype torch.float16
After backward: X.grad dtype torch.float16
After backward: Y dtype torch.float16
After backward: Y.grad dtype torch.float16
After backward: Y dtype torch.float16
After backward: Y.grad dtype torch.float16
After backward: W dtype torch.float32
After backward: W.grad dtype torch.float32
Optimizer opt[0]['params'][0].dtype torch.float32


In [5]:
# Wrapping the above as a nn.Module

class SimpleWModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.tensor([1,2,3], dtype=torch.float32))

    def forward(self, X):
        W_16 = self.W.half()
        return torch.dot(W_16, X)

model = SimpleWModel()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
print('Init: model.W dtype', model.W.dtype)
print('Init: opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][0].data.dtype)

X = torch.tensor([2,4,6], dtype=torch.float16, requires_grad=True)
Y = torch.tensor([20.], dtype=torch.float16, requires_grad=True)
print('Data X dtype', X.dtype)
print('Data Y dtype', Y.dtype)

Y_hat = model(X)
print('Forward: Y_hat dtype', Y_hat.dtype)
loss = Y - Y_hat
print('Forward: loss dtype', loss.dtype)

opt.zero_grad()
loss.backward()
print('-------------------------------')
print('After backward: self.W dtype', model.W.dtype)
print('After backward: self.W.grad dtype', model.W.grad.data.dtype)
print('Optimizer opt[0][\'params\'][0].dtype', opt.param_groups[0]['params'][0].dtype)

Init: model.W dtype torch.float32
Init: opt.param_groups[0]['params'][0] dtype torch.float32
Data X dtype torch.float16
Data Y dtype torch.float16
Forward: Y_hat dtype torch.float16
Forward: loss dtype torch.float16
-------------------------------
After backward: self.W dtype torch.float32
After backward: self.W.grad dtype torch.float32
Optimizer opt[0]['params'][0].dtype torch.float32


In [6]:
# !!! !!! INCORRECT! Doesn't work because calling half on a module(nn.Linear) modifies the data in place 

torch.manual_seed(42)

class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Linear(3,1)
        self.W2 = nn.Parameter(torch.tensor([1,2,3], dtype=torch.float32))

    def forward(self, X):
        W_16 = self.W.half()
        W2_16 = self.W2.half()
        return torch.mean(W_16(X) + W2_16*3)

model = SimpleLinear()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
print('Init: model.W.weight dtype', model.W.weight.dtype)
print('Init: model.W.weightopt.param_groups[0][\'params\'][1] dtype', opt.param_groups[0]['params'][1].data.dtype)
print('After backward: model.W.weightopt.param_groups[0][\'params\'][1] data', opt.param_groups[0]['params'][1].data)
print('Init: model.W.bias dtype', model.W.bias.dtype)
print('Init: model.W.bias opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][2].data.dtype)
print('Init: model.W2 dtype', model.W2.dtype)
print('Init: model.W2 opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][0].data.dtype)
print('-------------------------------')
X = torch.tensor([2,4,6], dtype=torch.float16, requires_grad=True)
Y = torch.tensor([20.], dtype=torch.float16, requires_grad=True)
print('Data X dtype', X.dtype)
print('Data Y dtype', Y.dtype)

Y_hat = model(X)
print('Forward: Y_hat dtype', Y_hat.dtype)
loss = Y - Y_hat
print('Forward: loss dtype', loss.dtype)

opt.zero_grad()
loss.backward()
print('-------------------------------')
print('After backward: model.W.weight dtype', model.W.weight.dtype)
print('After backward: model.W.weightopt.param_groups[0][\'params\'][1] dtype', opt.param_groups[0]['params'][1].data.dtype)
print('After backward: model.W.weightopt.param_groups[0][\'params\'][1] grad data', opt.param_groups[0]['params'][1].grad.data)
print('After backward:: model.W.bias dtype', model.W.bias.dtype)
print('After backward:: model.W.bias opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][2].data.dtype)
print('After backward: model.W2 dtype', model.W2.dtype)
print('After backward: model.W2 opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][0].data.dtype)

Init: model.W.weight dtype torch.float32
Init: model.W.weightopt.param_groups[0]['params'][1] dtype torch.float32
After backward: model.W.weightopt.param_groups[0]['params'][1] data tensor([[ 0.4414,  0.4792, -0.1353]])
Init: model.W.bias dtype torch.float32
Init: model.W.bias opt.param_groups[0]['params'][0] dtype torch.float32
Init: model.W2 dtype torch.float32
Init: model.W2 opt.param_groups[0]['params'][0] dtype torch.float32
-------------------------------
Data X dtype torch.float16
Data Y dtype torch.float16
Forward: Y_hat dtype torch.float16
Forward: loss dtype torch.float16
-------------------------------
After backward: model.W.weight dtype torch.float16
After backward: model.W.weightopt.param_groups[0]['params'][1] dtype torch.float16
After backward: model.W.weightopt.param_groups[0]['params'][1] grad data tensor([[-2., -4., -6.]], dtype=torch.float16)
After backward:: model.W.bias dtype torch.float16
After backward:: model.W.bias opt.param_groups[0]['params'][0] dtype torch.

In [7]:
# !!! INCORRECT! Using deep copy before calling half() is incorrect because gradient doesn't flow through deepcopy
import copy

torch.manual_seed(42)

class SimpleLinear2(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Linear(3,1)
        self.W2 = nn.Parameter(torch.tensor([1,2,3], dtype=torch.float32))

    def forward(self, X):
        W_16 = copy.deepcopy(self.W).half()
        W2_16 = self.W2.half()
        return torch.mean(W_16(X) + W2_16*3)

model = SimpleLinear2()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
print('Init: model.W.weight dtype', model.W.weight.dtype)
print('Init: model.W.weightopt.param_groups[0][\'params\'][1] dtype', opt.param_groups[0]['params'][1].data.dtype)
print('After backward: model.W.weightopt.param_groups[0][\'params\'][1] data', opt.param_groups[0]['params'][1].data)
print('Init: model.W.bias dtype', model.W.bias.dtype)
print('Init: model.W.bias opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][2].data.dtype)
print('Init: model.W2 dtype', model.W2.dtype)
print('Init: model.W2 opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][0].data.dtype)
print('-------------------------------')
X = torch.tensor([2,4,6], dtype=torch.float16, requires_grad=True)
Y = torch.tensor([20.], dtype=torch.float16, requires_grad=True)
print('Data X dtype', X.dtype)
print('Data Y dtype', Y.dtype)

Y_hat = model(X)
print('Forward: Y_hat dtype', Y_hat.dtype)
loss = Y - Y_hat
print('Forward: loss dtype', loss.dtype)

opt.zero_grad()
loss.backward()
print('-------------------------------')
print('After backward: model.W.weight dtype', model.W.weight.dtype)
print('After backward: model.W.weightopt.param_groups[0][\'params\'][1] dtype', opt.param_groups[0]['params'][1].data.dtype)
# will raise an error because deepcopy doesn't allow gradient to flow through
# print('After backward: model.W.weightopt.param_groups[0][\'params\'][1] grad data', opt.param_groups[0]['params'][1].grad.data) 
print('After backward:: model.W.bias dtype', model.W.bias.dtype)
print('After backward:: model.W.bias opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][2].data.dtype)
print('After backward: model.W2 dtype', model.W2.dtype)
print('After backward: model.W2 opt.param_groups[0][\'params\'][0] dtype', opt.param_groups[0]['params'][0].data.dtype)

Init: model.W.weight dtype torch.float32
Init: model.W.weightopt.param_groups[0]['params'][1] dtype torch.float32
After backward: model.W.weightopt.param_groups[0]['params'][1] data tensor([[ 0.4414,  0.4792, -0.1353]])
Init: model.W.bias dtype torch.float32
Init: model.W.bias opt.param_groups[0]['params'][0] dtype torch.float32
Init: model.W2 dtype torch.float32
Init: model.W2 opt.param_groups[0]['params'][0] dtype torch.float32
-------------------------------
Data X dtype torch.float16
Data Y dtype torch.float16
Forward: Y_hat dtype torch.float16
Forward: loss dtype torch.float16
-------------------------------
After backward: model.W.weight dtype torch.float32
After backward: model.W.weightopt.param_groups[0]['params'][1] dtype torch.float32
After backward:: model.W.bias dtype torch.float32
After backward:: model.W.bias opt.param_groups[0]['params'][0] dtype torch.float32
After backward: model.W2 dtype torch.float32
After backward: model.W2 opt.param_groups[0]['params'][0] dtype tor

In [8]:
# !!! INCORRECT! Optimizer Copy is changed to 16-bit as well

import torch
import torch.nn as nn

torch.manual_seed(42)

class SimpleLinear3(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Linear(3,1)
        self.W2 = nn.Parameter(torch.tensor([1,2,3], dtype=torch.float32))

    def forward(self, X):
        return torch.mean(self.W(X) + self.W2*3)

model = SimpleLinear3()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for p in model.parameters():
    print(p.dtype)
print('------------------------------')
model.bfloat16()
for p in model.parameters():
    print(p.dtype)
print('------------------------------')
opt.param_groups[0]['params']

torch.float32
torch.float32
torch.float32
------------------------------
torch.bfloat16
torch.bfloat16
torch.bfloat16
------------------------------


[Parameter containing:
 tensor([1., 2., 3.], dtype=torch.bfloat16, requires_grad=True),
 Parameter containing:
 tensor([[ 0.4414,  0.4785, -0.1357]], dtype=torch.bfloat16, requires_grad=True),
 Parameter containing:
 tensor([0.5312], dtype=torch.bfloat16, requires_grad=True)]