In [1]:
import torch
import torch.nn as nn

batch_size, hidden_size =4, 8

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.z1 = None
        self.w1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w2 = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, x):
        z1 = self.w1(x)
        print(z1)
        self.z1 = z1.clone().detach()
        z2 = self.w2(z1)
        return z2

In [2]:
model = Net()

def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params / 1e9:.1f}G\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(model))

trainable model parameters: 0.0G
all model parameters: 72
percentage of trainable model parameters: 100.00%


In [3]:
from torch.optim import SGD

fp32_model= Net().to("cuda")
lr = 1e-0
optimizer = SGD(fp32_model.parameters(), lr=lr)
# print(lr)  #1.0

In [4]:
fp32_model.w1.weight

Parameter containing:
tensor([[-0.0906, -0.0177,  0.3129,  0.1098, -0.2905, -0.1271, -0.1763,  0.1325],
        [-0.1657,  0.0474,  0.2535, -0.0791, -0.0824,  0.0549,  0.1489,  0.2605],
        [-0.2464, -0.2973,  0.1691,  0.1777, -0.0475, -0.3374, -0.1192, -0.3453],
        [ 0.2817,  0.3462,  0.2579, -0.2603,  0.1440, -0.1558,  0.0590,  0.2681],
        [ 0.0908, -0.1774,  0.1442,  0.2615, -0.0706, -0.0726,  0.1265,  0.2917],
        [-0.1459, -0.1778, -0.2187, -0.1222,  0.0954, -0.1505, -0.2023,  0.0204],
        [ 0.0171,  0.3026, -0.2561, -0.3250, -0.1028,  0.3338,  0.1469,  0.0285],
        [ 0.1566, -0.3480, -0.1715,  0.2985,  0.1476, -0.3179,  0.1371,  0.1631]],
       device='cuda:0', requires_grad=True)

In [5]:
fp32_model.w2.weight

Parameter containing:
tensor([[ 0.2247,  0.1487,  0.0881, -0.2991,  0.0971, -0.0459,  0.2221, -0.1868]],
       device='cuda:0', requires_grad=True)

In [6]:
import torch

# example input sizes
#batch_size, hidden_size =4, 8

# create dummy data (bsz=4, hid=256)
x = torch.randn(batch_size,hidden_size, dtype=torch.float, device="cuda") 

# do forward
z2 = fp32_model(x)

# check dtypr of output logits
f"logits type = {z2.dtype}"

tensor([[-0.2388, -0.0356, -0.1086,  0.4998,  0.0016, -0.5134, -0.0702, -0.0403],
        [-0.6645, -0.0950, -0.7363,  0.5179, -0.5337, -0.3007,  0.6192, -0.6999],
        [-0.4698,  0.1093, -0.9383,  0.2190, -0.0976, -0.2578,  0.8914, -0.2335],
        [-0.3563, -0.7386,  0.2531, -0.5503,  0.1289,  0.4823, -0.4234,  0.9617]],
       device='cuda:0', grad_fn=<MmBackward0>)


'logits type = torch.float32'

In [7]:
# craete dummy data (bsz=4)
#y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.half, device="cuda") #batch_size =4
y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.float32, device="cuda") #batch_size =4
#y = torch.tensor([[1.9]], dtype=torch.float32, device="cuda")
#y = torch.tensor([[1.9], [0.5]], dtype=torch.float32, device="cuda")
# compute mean square error loss
L = torch.nn.functional.mse_loss(z2, y)

# check dtype of loss
f"loss type = {L.dtype}"

'loss type = torch.float32'

In [8]:
print(L)
print(z2)
print(y)
loss = torch.sum((z2-y)**2/batch_size)
print(loss)

tensor(25.1515, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor([[-0.2023],
        [-0.1529],
        [ 0.0065],
        [-0.2864]], device='cuda:0', grad_fn=<MmBackward0>)
tensor([[1.9000],
        [9.5000],
        [0.9000],
        [1.2000]], device='cuda:0')
tensor(25.1515, device='cuda:0', grad_fn=<SumBackward0>)


In [9]:
L.backward()
w2_weight = fp32_model.w2.weight.clone().detach()
w1_weight = fp32_model.w1.weight.clone().detach()
print(f'before: {fp32_model.w2.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w2.weight}\n')

before: Parameter containing:
tensor([[ 0.2247,  0.1487,  0.0881, -0.2991,  0.0971, -0.0459,  0.2221, -0.1868]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[-3.7081, -0.8472, -3.8109,  2.4147, -2.4251, -1.7939,  3.2204, -2.9970]],
       device='cuda:0', requires_grad=True)



In [10]:
DL_Dz2= 2 * (z2 - y) / batch_size # DL/Dz2  (BWD-activation: layer2), in case of MSE
print(DL_Dz2) # [4,1] [output_size=1, batch_size] DL/Dz2
Dz2_Dw2 = fp32_model.z1.clone().detach() #Dz2/Dw2
print(Dz2_Dw2.shape) #[4,8] [batch_size, hidden_size] 
print(Dz2_Dw2)
#DL_Dw2 = DL_Dz2.T * Dz2_Dw2
DL_Dw2 = torch.matmul(DL_Dz2.T, Dz2_Dw2) #[1,4] * [4,8] batch_size, hidden_size
print(DL_Dw2.shape) # [1,8] [output_size=1, hidden_size]
print(DL_Dw2)
print(w2_weight) # [hidden_size=8 , output_size=1]
print(w2_weight - lr * DL_Dw2)

tensor([[-1.0512],
        [-4.8265],
        [-0.4467],
        [-0.7432]], device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([4, 8])
tensor([[-0.2388, -0.0356, -0.1086,  0.4998,  0.0016, -0.5134, -0.0702, -0.0403],
        [-0.6645, -0.0950, -0.7363,  0.5179, -0.5337, -0.3007,  0.6192, -0.6999],
        [-0.4698,  0.1093, -0.9383,  0.2190, -0.0976, -0.2578,  0.8914, -0.2335],
        [-0.3563, -0.7386,  0.2531, -0.5503,  0.1289,  0.4823, -0.4234,  0.9617]],
       device='cuda:0')
torch.Size([1, 8])
tensor([[ 3.9328,  0.9959,  3.8990, -2.7137,  2.5223,  1.7479, -2.9982,  2.8101]],
       device='cuda:0', grad_fn=<MmBackward0>)
tensor([[ 0.2247,  0.1487,  0.0881, -0.2991,  0.0971, -0.0459,  0.2221, -0.1868]],
       device='cuda:0')
tensor([[-3.7081, -0.8472, -3.8109,  2.4147, -2.4251, -1.7939,  3.2204, -2.9970]],
       device='cuda:0', grad_fn=<SubBackward0>)


In [11]:
fp32_model.w1.weight # w1 = [hidden_size, hidden_size] [8,8]

Parameter containing:
tensor([[ 0.7844,  0.9291,  0.5532, -0.3136,  1.4755,  1.2774,  0.0164, -0.5525],
        [ 0.4135,  0.6742,  0.4127, -0.3594,  1.0867,  0.9848,  0.2765, -0.1931],
        [ 0.0969,  0.0742,  0.2634,  0.0116,  0.6454,  0.2137, -0.0436, -0.6141],
        [-0.8831, -0.9141, -0.0621,  0.3033, -2.2068, -2.0254, -0.1976,  1.1800],
        [ 0.4692,  0.2320,  0.2481,  0.0785,  0.6930,  0.5347,  0.2098, -0.0045],
        [-0.3247, -0.3713, -0.2678, -0.0357, -0.2654, -0.4375, -0.2416,  0.1604],
        [ 0.8823,  1.2388, -0.0184, -0.7437,  1.6433,  1.7225,  0.3375, -0.6488],
        [-0.5712, -1.1355, -0.3714,  0.6507, -1.3211, -1.4860, -0.0232,  0.7329]],
       device='cuda:0', requires_grad=True)

In [12]:
DL_Dz2= 2 * (z2 - y) / batch_size # DL/Dz2  (BWD-activation: layer2), in case of MSE
print(DL_Dz2) # [4,1] [batch_size=4, output_size=1]
print(w2_weight.shape) # [1,8]
temp = torch.matmul(DL_Dz2, w2_weight) #DL/Dz2 * w2
print(temp.shape) # [4,8]
print(x.shape) # [4,8]
DL_Dw1 = torch.matmul(temp.T, x) # [8,4] * [4,8] = [8,8]
print(DL_Dw1.shape) #[8,8]
print(w1_weight - lr * DL_Dw1)

tensor([[-1.0512],
        [-4.8265],
        [-0.4467],
        [-0.7432]], device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([1, 8])
torch.Size([4, 8])
torch.Size([4, 8])
torch.Size([8, 8])
tensor([[ 0.7844,  0.9291,  0.5532, -0.3136,  1.4755,  1.2774,  0.0164, -0.5525],
        [ 0.4135,  0.6742,  0.4127, -0.3594,  1.0867,  0.9848,  0.2765, -0.1931],
        [ 0.0969,  0.0742,  0.2634,  0.0116,  0.6454,  0.2137, -0.0436, -0.6141],
        [-0.8831, -0.9141, -0.0621,  0.3033, -2.2068, -2.0254, -0.1976,  1.1800],
        [ 0.4692,  0.2320,  0.2481,  0.0785,  0.6930,  0.5347,  0.2098, -0.0045],
        [-0.3247, -0.3713, -0.2678, -0.0357, -0.2654, -0.4375, -0.2416,  0.1604],
        [ 0.8823,  1.2388, -0.0184, -0.7437,  1.6433,  1.7225,  0.3375, -0.6488],
        [-0.5712, -1.1355, -0.3714,  0.6507, -1.3211, -1.4860, -0.0232,  0.7329]],
       device='cuda:0', grad_fn=<SubBackward0>)


In [13]:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(512, 512, bias=False)
        self.w2 = nn.Linear(512, 1, bias=False)
    
    def forward(self, x):
        z1 = self.w1(x)
        z2 = self.w2(z1)
        return z2

from torch.optim import SGD

fp32_model= Net().to("cuda")
optimizer = SGD(fp32_model.parameters(), lr=1e-2)


### Float2Half

In [14]:
fp16_model = Net().half().to("cuda")
fp16_model.load_state_dict(fp32_model.state_dict())

<All keys matched successfully>

### Forward

In [15]:
import torch

# example input sizes
batch_size, hidden_size = 4, 512

# create dummy data (bsz=4, hid=256)
x = torch.randn(batch_size,hidden_size, dtype=torch.half, device="cuda") 

# do forward
z2 = fp16_model(x)

# check dtypr of output logits
f"logits type = {z2.dtype}"

'logits type = torch.float16'

In [16]:
# craete dummy data (bsz=4)
y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.half, device="cuda")

# compute mean square error loss
L = torch.nn.functional.mse_loss(z2, y)

# check dtype of loss
f"loss type = {L.dtype}"

'loss type = torch.float16'

### Backward

In [17]:
# loss scaling
L *= 1024

# do backward
L.backward()

### Update Weight

In [18]:
print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')

before: Parameter containing:
tensor([[-0.0225, -0.0107,  0.0302,  ...,  0.0010, -0.0359,  0.0066],
        [ 0.0334, -0.0346,  0.0312,  ...,  0.0212, -0.0374,  0.0067],
        [ 0.0124, -0.0137,  0.0121,  ..., -0.0433,  0.0172,  0.0341],
        ...,
        [-0.0394,  0.0279,  0.0392,  ...,  0.0374, -0.0026,  0.0150],
        [-0.0053, -0.0111,  0.0289,  ...,  0.0046, -0.0286, -0.0411],
        [-0.0161, -0.0105,  0.0411,  ..., -0.0264, -0.0139,  0.0388]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[-0.0225, -0.0107,  0.0302,  ...,  0.0010, -0.0359,  0.0066],
        [ 0.0334, -0.0346,  0.0312,  ...,  0.0212, -0.0374,  0.0067],
        [ 0.0124, -0.0137,  0.0121,  ..., -0.0433,  0.0172,  0.0341],
        ...,
        [-0.0394,  0.0279,  0.0392,  ...,  0.0374, -0.0026,  0.0150],
        [-0.0053, -0.0111,  0.0289,  ...,  0.0046, -0.0286, -0.0411],
        [-0.0161, -0.0105,  0.0411,  ..., -0.0264, -0.0139,  0.0388]],
       device='cuda:0', requ

In [19]:
#print(f'before: {fp16_model.w1.weight}\n')
#optimizer.step()
#print(f'after: {fp16_model.w1.weight}\n')

In [20]:
# copy gradient to FP32 model
fp32_model.w1.weight.grad = fp16_model.w1.weight.grad.float()
fp32_model.w2.weight.grad = fp16_model.w2.weight.grad.float()

In [21]:
print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')

before: Parameter containing:
tensor([[-0.0225, -0.0107,  0.0302,  ...,  0.0010, -0.0359,  0.0066],
        [ 0.0334, -0.0346,  0.0312,  ...,  0.0212, -0.0374,  0.0067],
        [ 0.0124, -0.0137,  0.0121,  ..., -0.0433,  0.0172,  0.0341],
        ...,
        [-0.0394,  0.0279,  0.0392,  ...,  0.0374, -0.0026,  0.0150],
        [-0.0053, -0.0111,  0.0289,  ...,  0.0046, -0.0286, -0.0411],
        [-0.0161, -0.0105,  0.0411,  ..., -0.0264, -0.0139,  0.0388]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[ 2.3910e-01, -6.1224e-01,  2.6777e+00,  ..., -4.3640e+00,
          2.8267e-01, -3.1759e+00],
        [-1.7368e-01,  4.4223e-01, -2.0663e+00,  ...,  3.4812e+00,
         -2.8971e-01,  2.5280e+00],
        [-3.0838e-02,  8.5775e-02, -4.2572e-01,  ...,  6.7855e-01,
         -3.5368e-02,  5.6004e-01],
        ...,
        [ 1.8024e-01, -4.7868e-01,  2.2667e+00,  ..., -3.6351e+00,
          2.6525e-01, -2.6625e+00],
        [-4.2536e-04, -2.2271e-02,  7

In [22]:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(512, 512, bias=False)
        self.w2 = nn.Linear(512, 1, bias=False)
    
    def forward(self, x):
        z1 = self.w1(x)
        z2 = self.w2(z1)
        return z2

from torch.optim import SGD

fp32_model= Net().to("cuda")
optimizer = SGD(fp32_model.parameters(), lr=1e-2)
#optimizer = SGD(fp32_model.parameters(), lr=1e-0)

### Float2Half
fp16_model = Net().half().to("cuda")
fp16_model.load_state_dict(fp32_model.state_dict())

### Forward
import torch

# example input sizes
batch_size, hidden_size = 4, 512

# create dummy data (bsz=4, hid=256)
x = torch.randn(batch_size,hidden_size, dtype=torch.half, device="cuda") 

# do forward
z2 = fp16_model(x)

# check dtypr of output logits
f"logits type = {z2.dtype}"


# craete dummy data (bsz=4)
y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.half, device="cuda")

# compute mean square error loss
L = torch.nn.functional.mse_loss(z2, y)

# check dtype of loss
f"loss type = {L.dtype}"

### Backward
# loss scaling
#L *= 1024

# do backward
L.backward()

print(f'fp32 grad: {fp32_model.w1.weight.grad}\n')
### Update Weight
print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')
print(f'fp32 grad: {fp32_model.w1.weight.grad}\n')


print(f'f16 grad: {fp16_model.w1.weight.grad}\n')

# copy gradient to FP32 model
fp32_model.w1.weight.grad = fp16_model.w1.weight.grad.float()
fp32_model.w2.weight.grad = fp16_model.w2.weight.grad.float()

print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')

"""
print(f'before: {fp16_model.w1.weight}\n')
print(fp16_model.w1.weight.grad)
optimizer.step()
print(fp16_model.w1.weight.grad)
print(f'after: {fp16_model.w1.weight}\n')
"""

fp32 grad: None

before: Parameter containing:
tensor([[ 0.0092, -0.0172, -0.0067,  ..., -0.0027,  0.0084, -0.0214],
        [ 0.0358,  0.0061,  0.0276,  ...,  0.0274, -0.0038,  0.0103],
        [-0.0320, -0.0291, -0.0293,  ...,  0.0085, -0.0079,  0.0432],
        ...,
        [ 0.0158, -0.0156,  0.0021,  ...,  0.0162, -0.0098,  0.0191],
        [ 0.0356, -0.0112, -0.0425,  ..., -0.0339, -0.0367,  0.0429],
        [-0.0105,  0.0071,  0.0083,  ...,  0.0219, -0.0387,  0.0187]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[ 0.0092, -0.0172, -0.0067,  ..., -0.0027,  0.0084, -0.0214],
        [ 0.0358,  0.0061,  0.0276,  ...,  0.0274, -0.0038,  0.0103],
        [-0.0320, -0.0291, -0.0293,  ...,  0.0085, -0.0079,  0.0432],
        ...,
        [ 0.0158, -0.0156,  0.0021,  ...,  0.0162, -0.0098,  0.0191],
        [ 0.0356, -0.0112, -0.0425,  ..., -0.0339, -0.0367,  0.0429],
        [-0.0105,  0.0071,  0.0083,  ...,  0.0219, -0.0387,  0.0187]],
       devi

"\nprint(f'before: {fp16_model.w1.weight}\n')\nprint(fp16_model.w1.weight.grad)\noptimizer.step()\nprint(fp16_model.w1.weight.grad)\nprint(f'after: {fp16_model.w1.weight}\n')\n"