In [1]:
import torch
import torch.nn as nn

In [2]:
class DeepNet(nn.Module):
    def __init__(self, nlayers: int):
        super(DeepNet, self).__init__()
        layers = []
        for _ in range(nlayers):
            layers.append(nn.Linear(10, 10))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [14]:
torch.manual_seed(42)  # set random seed for reproducibility

input_tensor = torch.randn(32, 10)
target = torch.randn(32, 10)

model = DeepNet(50)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=100000.1)  # high learning rate to induce explosion

In [15]:
# model

In [16]:
n_epochs = 100
for epoch in range(n_epochs):
    # forward pass
    outputs = model(input_tensor)
    loss = criterion(outputs, target)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # print the average gradient magnitude for the first layer
    gradients = model.layers[0].weight.grad
    avg_gradient_magnitude = gradients.abs().mean().item()
    print(f"{epoch+1:2d}/{n_epochs:2d} | Loss: {loss.item()} | Average Gradient Magnitude: {avg_gradient_magnitude}")

    # stop if gradients are exploding
    if avg_gradient_magnitude > 1e6:
        print("Gradients are exploding!")
        break

 1/100 | Loss: 0.9364947080612183 | Average Gradient Magnitude: 7.12428329886329e-24
 2/100 | Loss: 4.16832032605563e+23 | Average Gradient Magnitude: 56631.140625
 3/100 | Loss: nan | Average Gradient Magnitude: nan
 4/100 | Loss: nan | Average Gradient Magnitude: nan
 5/100 | Loss: nan | Average Gradient Magnitude: nan
 6/100 | Loss: nan | Average Gradient Magnitude: nan
 7/100 | Loss: nan | Average Gradient Magnitude: nan
 8/100 | Loss: nan | Average Gradient Magnitude: nan
 9/100 | Loss: nan | Average Gradient Magnitude: nan
10/100 | Loss: nan | Average Gradient Magnitude: nan
11/100 | Loss: nan | Average Gradient Magnitude: nan
12/100 | Loss: nan | Average Gradient Magnitude: nan
13/100 | Loss: nan | Average Gradient Magnitude: nan
14/100 | Loss: nan | Average Gradient Magnitude: nan
15/100 | Loss: nan | Average Gradient Magnitude: nan
16/100 | Loss: nan | Average Gradient Magnitude: nan
17/100 | Loss: nan | Average Gradient Magnitude: nan
18/100 | Loss: nan | Average Gradient Mag