In [5]:
import torch
import torch.nn as nn


batch_size = 64

In [10]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 512, False)
        self.fc2 = nn.Linear(512, 512, False)
        self.fc3 = nn.Linear(512, 10, False)
        self.register_buffer('fc2_input_act', torch.zeros(batch_size, 512))
        self.register_buffer('fc2_output_act', torch.zeros(batch_size, 512))

    def forward(self, x):
        x = self.flatten(x)
        self.fc2_input_act = nn.functional.relu(self.fc1(x))
        self.fc2_input_act.retain_grad()
        self.fc2_output_act = self.fc2(self.fc2_input_act)
        self.fc2_output_act.retain_grad()
        x = nn.functional.relu(self.fc2_output_act)        
        logits = self.fc3(x)
        return logits

model = NeuralNetwork()

loss = nn.CrossEntropyLoss()


In [11]:
x = torch.randn(batch_size, 28*28)
y = torch.randint(0, 10, (batch_size,))

loss = nn.functional.cross_entropy(model(x), y)

loss.backward()
        
print("Shape of X:", model.fc2_input_act.shape)
print("Shape of dL/dX:", model.fc2_input_act.grad.shape)
print("Shape of W:", model.fc2.weight.shape)
print("Shape of dL/dW:", model.fc2.weight.grad.shape)
print("Shape of Y:", model.fc2_output_act.shape)
print("Shape of dL/dY:",model.fc2_output_act.grad.shape)

diff1 = torch.sum(torch.abs(model.fc2_input_act.grad - 
                            torch.matmul(model.fc2_output_act.grad, 
                                            model.fc2.weight)))
print("Check dL/dX = dL/dY W^T, diff1=", diff1.item())

diff2 = torch.sum(torch.abs(torch.transpose(model.fc2.weight.grad,0,1) - 
                            torch.matmul(torch.transpose(model.fc2_input_act,0,1), 
                                            model.fc2_output_act.grad)))
print("Check dL/dW = X^T dL/dY, diff2=", diff2.item())


Shape of X: torch.Size([64, 512])
Shape of dL/dX: torch.Size([64, 512])
Shape of W: torch.Size([512, 512])
Shape of dL/dW: torch.Size([512, 512])
Shape of Y: torch.Size([64, 512])
Shape of dL/dY: torch.Size([64, 512])
Check dL/dX = dL/dY W^T, diff1= 0.0
Check dL/dW = X^T dL/dY, diff2= 0.0


In [14]:
class MyLinear(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.weight = nn.Parameter(torch.randn(in_features, out_features))

  def forward(self, input):
    return torch.matmul(input, self.weight)

m = MyLinear(4, 3)
sample_input = torch.randn(4)
print(m(sample_input))



tensor([-0.4842, -0.2480,  0.7965], grad_fn=<SqueezeBackward4>)


In [15]:
for parameter in m.named_parameters():
  print(parameter)


for parameter in m.parameters():
  print(parameter)


('weight', Parameter containing:
tensor([[-0.1219, -0.8401, -1.2954],
        [-0.3019, -1.9455, -0.8994],
        [-0.4831,  0.9704,  0.7805],
        [ 2.7101,  0.2225, -0.6512]], requires_grad=True))
Parameter containing:
tensor([[-0.1219, -0.8401, -1.2954],
        [-0.3019, -1.9455, -0.8994],
        [-0.4831,  0.9704,  0.7805],
        [ 2.7101,  0.2225, -0.6512]], requires_grad=True)


In [16]:
net = nn.Sequential(
  MyLinear(4, 3),
  nn.ReLU(),
  MyLinear(3, 1)
)

sample_input = torch.randn(4)
print(net(sample_input))


tensor([0.2522], grad_fn=<SqueezeBackward4>)


In [18]:
class Net2(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer0 = MyLinear(4, 3)
    self.layer1 = MyLinear(3, 1)
  
  def forward(self, x):
    x = self.layer0(x)
    x = nn.functional.relu(x)
    x = self.layer1(x)
    return x

net2 = Net2()
sample_input = torch.randn(4)
print(net2(sample_input))


tensor([-0.7705], grad_fn=<SqueezeBackward4>)


In [19]:
print("Check net.children")
for child in net.children():
  print(child)


print("Check net.modules")
for child in net.modules():
  print(child)


Check net.children
MyLinear()
ReLU()
MyLinear()
Check net.modules
Sequential(
  (0): MyLinear()
  (1): ReLU()
  (2): MyLinear()
)
MyLinear()
ReLU()
MyLinear()


In [20]:
print("<--Check net2.children()-->")
for child in net2.children():
  print(child)


print("<--Check net.modules()-->")
for child in net2.modules():
  print(child)


<--Check net2.children()-->
MyLinear()
MyLinear()
<--Check net.modules()-->
Net2(
  (layer0): MyLinear()
  (layer1): MyLinear()
)
MyLinear()
MyLinear()


In [21]:
print("<--Check net2.named_children()-->")
for child in net2.named_children():
  print(child)


print("<--Check net.named_modules()-->")
for child in net2.named_modules():
  print(child)


<--Check net2.named_children()-->
('layer0', MyLinear())
('layer1', MyLinear())
<--Check net.named_modules()-->
('', Net2(
  (layer0): MyLinear()
  (layer1): MyLinear()
))
('layer0', MyLinear())
('layer1', MyLinear())


In [22]:
x = torch.randn(4)
x.requires_grad_(True)
x.retain_grad()
z = net2(x)
z.retain_grad()

print("x:", x)
print("w0:", net2.layer0.weight)
print("w1:",net2.layer1.weight)
print("z:",z)

z.backward()

print("dz:", z.grad)
print("dw1:", net2.layer1.weight.grad)
print("dw2:", net2.layer0.weight.grad)
print("dx:", x.grad)


x: tensor([-1.3159, -0.2344, -0.7772,  1.2623], requires_grad=True)
w0: Parameter containing:
tensor([[ 1.1255,  0.4839, -1.0264],
        [-2.0767, -1.1128, -1.0877],
        [ 0.6199,  0.2211,  0.7915],
        [ 0.8546,  1.6406,  1.4667]], requires_grad=True)
w1: Parameter containing:
tensor([[ 0.7632],
        [ 0.1008],
        [-0.9345]], requires_grad=True)
z: tensor([-2.5022], grad_fn=<SqueezeBackward4>)
dz: tensor([1.])
dw1: tensor([[0.0000],
        [1.5231],
        [2.8419]])
dw2: tensor([[ 0.0000, -0.1327,  1.2297],
        [ 0.0000, -0.0236,  0.2191],
        [ 0.0000, -0.0784,  0.7263],
        [ 0.0000,  0.1273, -1.1796]])
dx: tensor([ 1.0080,  0.9043, -0.7173, -1.2052])


In [23]:
net2.layer1.requires_grad_(False)

x = torch.randn(4)
x.requires_grad_(True)
x.retain_grad()
z = net2(x)
z.retain_grad()

print("x:", x)
print("w0:", net2.layer0.weight)
print("w1:",net2.layer1.weight)
print("z:",z)

z.backward()

print("dz:", z.grad)
print("dw1:", net2.layer1.weight.grad)
print("dw2:", net2.layer0.weight.grad)
print("dx:", x.grad)


x: tensor([ 0.5764, -0.9726, -1.3224, -0.5898], requires_grad=True)
w0: Parameter containing:
tensor([[ 1.1255,  0.4839, -1.0264],
        [-2.0767, -1.1128, -1.0877],
        [ 0.6199,  0.2211,  0.7915],
        [ 0.8546,  1.6406,  1.4667]], requires_grad=True)
w1: Parameter containing:
tensor([[ 0.7632],
        [ 0.1008],
        [-0.9345]])
z: tensor([1.0365], grad_fn=<SqueezeBackward4>)
dz: tensor([1.])
dw1: tensor([[0.0000],
        [1.5231],
        [2.8419]])
dw2: tensor([[ 0.4399, -0.0746,  1.2297],
        [-0.7422, -0.1217,  0.2191],
        [-1.0092, -0.2117,  0.7263],
        [-0.4501,  0.0678, -1.1796]])
dx: tensor([ 0.9078, -1.6971,  0.4954,  0.8176])


In [28]:
dim1 = 4096
dim2 = 8192

class MyLinear(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.weight = nn.Parameter(torch.randn(in_features, out_features))

  def forward(self, input):
    return torch.matmul(input, self.weight)


class Net3(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer0 = MyLinear(dim1, dim2)
    self.layer1 = MyLinear(dim2, dim2)
    self.layer2 = MyLinear(dim2, dim2)
    self.layer3 = MyLinear(dim2, 1)
  
  def forward(self, x):
    x = self.layer0(x)
    x = nn.functional.relu(x)
    x = self.layer1(x)
    x = nn.functional.relu(x)
    x = self.layer2(x)
    x = nn.functional.relu(x)
    x = self.layer3(x)
    return x

net3 = Net3()
x = torch.randn(256,dim1)


In [31]:
import time

start_time = time.time()
z = net3(x)
end_time = time.time()
print("Forward computation takes:", end_time-start_time)


Forward computation takes: 0.057290077209472656


In [35]:
start_time = time.time()
with torch.no_grad():
    z = net3(x)
end_time = time.time()
print("Forward computation takes: ", end_time-start_time)


Forward computation takes:  0.05535697937011719


In [36]:
start_time = time.time()
with torch.inference_mode():
  z = net3(x)
end_time = time.time()
print("Forward computation takes: ", end_time-start_time)


Forward computation takes:  0.05443930625915527
