This demonstration implements three PyTorch techniques that improve the stability and generalization of neural network training:
* **Gradient clipping** guards against exploding gradients by rescaling the gradients after backpropagation so that their overall norm stays below a threshold.
* **Weight regularization** (L2 regularization) is applied through the optimizer's `weight_decay` argument and discourages overfitting by penalizing large weights.
* **Batch normalization** standardizes the input features over each mini-batch and then applies a learnable scale and shift.

Together, these techniques help the model train stably and generalize better to new data.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# A small fully connected classifier used for demonstration purposes
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Normalizes each of the 100 input features over the batch
        self.batch_norm = nn.BatchNorm1d(num_features=100)
        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.batch_norm(x)
        x = torch.relu(self.fc1(x))
        # Return raw logits; nn.CrossEntropyLoss applies log-softmax itself
        x = self.fc2(x)
        return x
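
To make the effect of the `nn.BatchNorm1d` layer concrete, the sketch below passes a batch through a standalone layer in training mode and in evaluation mode. The input scale and shift (mean 5, standard deviation 3) are assumptions chosen only for illustration and are not part of the model above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(num_features=100)
# Hypothetical inputs with mean 5 and standard deviation 3 per feature
x = torch.randn(64, 100) * 3.0 + 5.0

# Training mode: each feature is standardized with the statistics of this batch
bn.train()
train_out = bn(x)
train_mean = train_out.mean(dim=0)                # close to 0 for every feature
train_std = train_out.std(dim=0, unbiased=False)  # close to 1 for every feature

# Evaluation mode: the layer uses its running statistics instead,
# which have only been updated once by the forward pass above
bn.eval()
eval_out = bn(x)

print(train_mean.abs().max().item(), train_std.mean().item())
print(eval_out.mean().item())
```

Because the two modes behave differently, a model containing batch normalization is usually switched with `model.train()` before training and `model.eval()` before inference.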


In [3]:
# Create an instance of the model
model = SimpleNet()

In [4]:
# We'll use weight decay for L2 regularization
weight_decay = 1e-5
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=weight_decay)
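
The link between `weight_decay` and an L2 penalty can be checked numerically. The sketch below is an illustration under stated assumptions: it swaps in plain `optim.SGD` for Adam because the SGD update is simple enough to compare directly, and the layer size, data, learning rate, and decay value are hypothetical. It compares one step with `weight_decay` against one step with an explicit `0.5 * wd * ||w||^2` term added to the loss.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
lr, wd = 0.1, 0.01  # illustrative values, not the ones used in this notebook

# Two layers with identical starting parameters
layer_a = nn.Linear(4, 2)
layer_b = nn.Linear(4, 2)
layer_b.load_state_dict(layer_a.state_dict())

x = torch.randn(8, 4)
y = torch.randn(8, 2)

# Variant A: the optimizer adds wd * w to every gradient
opt_a = optim.SGD(layer_a.parameters(), lr=lr, weight_decay=wd)
opt_a.zero_grad()
nn.functional.mse_loss(layer_a(x), y).backward()
opt_a.step()

# Variant B: no weight decay, explicit L2 penalty in the loss
opt_b = optim.SGD(layer_b.parameters(), lr=lr)
opt_b.zero_grad()
penalty = sum(p.pow(2).sum() for p in layer_b.parameters())
(nn.functional.mse_loss(layer_b(x), y) + 0.5 * wd * penalty).backward()
opt_b.step()

params_match = all(
    torch.allclose(pa, pb, atol=1e-6)
    for pa, pb in zip(layer_a.parameters(), layer_b.parameters())
)
print(params_match)
```

`optim.Adam` adds the same `weight_decay * w` term to the gradient, but does so before its adaptive scaling, so the effective shrinkage differs per parameter. `optim.AdamW` is the variant that decouples the decay from that scaling.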

In [5]:
# Dummy inputs and outputs for demonstration
inputs = torch.randn(64, 100)
targets = torch.randint(0, 10, (64,))
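
The targets above are integer class indices, and the model returns raw logits, which is the pairing `nn.CrossEntropyLoss` expects. As a minimal sketch (the tensors here are stand-ins, not the notebook's variables), the loss can be reproduced by applying log-softmax and then the negative log-likelihood:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(64, 10)            # stand-in for raw model outputs
targets = torch.randint(0, 10, (64,))   # integer class indices in [0, 10)

# Cross-entropy on logits equals log-softmax followed by NLL
loss_direct = F.cross_entropy(logits, targets)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Equal logits for all 10 classes give a loss of exactly ln(10)
uniform_loss = F.cross_entropy(torch.zeros(64, 10), targets)

print(loss_direct.item(), loss_manual.item())
print(uniform_loss.item(), math.log(10))
```

The value ln(10) ≈ 2.30 is a useful reference point: an untrained 10-class classifier typically starts with a loss near it.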


In [6]:
# Training hyperparameters
num_epochs = 5
clip_value = 1.0  # Maximum allowed global gradient norm

In [7]:
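criterion = nn.CrossEntropyLoss()  # Create the loss function once, outside the loop

model.train()  # Batch norm uses batch statistics in training mode
for epoch in range(num_epochs):
    # Each "epoch" here is a single update on the one dummy batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backpropagation
    loss.backward()

    # Clip after backward() and before step() so the optimizer sees the clipped gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_value)

    optimizer.step()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')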

Epoch 1/5, Loss: 2.324849843978882
Epoch 2/5, Loss: 2.290389060974121
Epoch 3/5, Loss: 2.256732225418091
Epoch 4/5, Loss: 2.2237589359283447
Epoch 5/5, Loss: 2.191350221633911
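
To see what `clip_grad_norm_` does to the gradients, the following sketch measures the global gradient norm before and after clipping. The layer, the deliberately inflated loss, and the helper `global_grad_norm` are hypothetical and exist only for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(100, 10)
x = torch.randn(64, 100)

# Deliberately scaled-up loss so the gradients are large
loss = (layer(x) * 100.0).pow(2).mean()
loss.backward()

def global_grad_norm(parameters):
    """L2 norm of all gradients concatenated into one vector (illustrative helper)."""
    grads = [p.grad.detach().flatten() for p in parameters if p.grad is not None]
    return torch.cat(grads).norm(2).item()

norm_before = global_grad_norm(layer.parameters())
# Returns the total norm measured before clipping
returned = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
norm_after = global_grad_norm(layer.parameters())

print(norm_before, returned.item(), norm_after)
```

All gradients are scaled by the same factor, so their direction is preserved and only the magnitude changes. When the norm is already below `max_norm`, the gradients are left as they are.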
