In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

SEED = 42  # single source of truth for the notebook's RNG seed
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in the cell output


<torch._C.Generator at 0x7a4610aa2570>

In [11]:
def generate_data(num_samples=1000, slope=3.0, noise_std=1.0, generator=None):
    """Sample a noisy 1-D linear regression dataset ``y = slope * X + noise``.

    Args:
        num_samples: Number of (x, y) pairs to draw.
        slope: Coefficient of the underlying linear relationship.
        noise_std: Standard deviation of the additive Gaussian noise.
        generator: Optional ``torch.Generator`` for sampling independently of
            the global RNG; ``None`` (default) uses the global RNG.

    Returns:
        Tuple ``(X, y)`` of default-dtype float tensors, each of shape
        ``(num_samples, 1)``, with ``X ~ N(0, 1)``.
    """
    # Draw X first, then the noise, so the default call consumes the RNG in the
    # same order as before and reproduces the same dataset for a given seed.
    X = torch.randn(num_samples, 1, generator=generator)
    noise = torch.randn(num_samples, 1, generator=generator)
    y = slope * X + noise_std * noise
    return X, y

class DeepNet(nn.Module):
    """Four-layer ReLU MLP (1 -> 128 -> 128 -> 128 -> 1) with optional batch norm.

    Used to compare how the weight-initialization scale and batch
    normalization affect trainability of a deep network.

    Args:
        use_batchnorm: If True, apply an ``nn.BatchNorm1d(128)`` after each of
            the three hidden linear layers (before the ReLU).
        init_scale: Standard deviation of the zero-mean normal distribution
            used to initialize every ``nn.Linear`` weight; biases are zeroed.
    """

    def __init__(self, use_batchnorm=False, init_scale=1.0):
        super().__init__()
        self.fc1 = nn.Linear(1, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.use_batchnorm = use_batchnorm
        if use_batchnorm:
            self.bn1 = nn.BatchNorm1d(128)
            self.bn2 = nn.BatchNorm1d(128)
            self.bn3 = nn.BatchNorm1d(128)

        # Overwrite PyTorch's default Linear initialization with the
        # experiment's N(0, init_scale**2) scheme.
        self._initialize_weights(init_scale)

    def _initialize_weights(self, init_scale):
        """Draw every Linear weight from N(0, init_scale**2) and zero its bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=init_scale)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map inputs whose last dimension is 1 (e.g. shape (N, 1)) to (N, 1) outputs."""
        # Hidden layers share the same Linear -> [BatchNorm] -> ReLU pattern;
        # bn{i} pairs with fc{i}.
        for idx, fc in enumerate((self.fc1, self.fc2, self.fc3), start=1):
            x = fc(x)
            if self.use_batchnorm:
                x = getattr(self, f"bn{idx}")(x)
            x = self.relu(x)
        # Output layer is linear (no activation) for regression.
        return self.fc4(x)


In [12]:
def train_model(model, X, y, num_epochs=2000, lr=0.003, log_every=200):
    """Train ``model`` on ``(X, y)`` with full-batch SGD and mean-squared-error loss.

    Args:
        model: ``nn.Module`` expected to map ``X`` to outputs shaped like ``y``.
        X: Input tensor, fed to the model as a single batch every epoch.
        y: Target tensor for ``nn.MSELoss``.
        num_epochs: Number of epochs; each epoch is one full-batch SGD step.
        lr: SGD learning rate.
        log_every: Print the loss every ``log_every`` epochs; ``0`` or
            ``None`` disables logging.

    Returns:
        List of per-epoch loss values as Python floats, one entry per epoch.
    """
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    losses = []

    # Make sure layers such as BatchNorm use batch statistics during training,
    # even if the caller left the model in eval mode.
    model.train()

    for epoch in range(num_epochs):
        outputs = model(X)
        loss = criterion(outputs, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The loss was computed before the step; read it back once per epoch.
        loss_value = loss.item()
        losses.append(loss_value)

        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_value:.4f}")

    return losses



In [13]:
# Re-seed at the top of the cell so it is idempotent: re-running just this
# cell reproduces the same data and initial weights as a fresh Restart & Run
# All. torch.initial_seed() returns the seed most recently passed to
# torch.manual_seed, i.e. the one set in the setup cell.
torch.manual_seed(torch.initial_seed())

X, y = generate_data()

# (key, description, DeepNet kwargs). Order matters: each model's weight
# initialization draws from the global RNG, so keep the original run order.
experiments = [
    ("large", "large weight initialization",
     {"use_batchnorm": False, "init_scale": 10.0}),
    ("small", "small weight initialization",
     {"use_batchnorm": False, "init_scale": 0.01}),
    ("bn", "Batch Normalization",
     {"use_batchnorm": True, "init_scale": 1.0}),
]

trained_models, loss_histories = {}, {}
for run_idx, (key, description, model_kwargs) in enumerate(experiments):
    if run_idx:
        print()  # blank line between the experiments' training logs
    print(f"Training with {description}...")
    trained_models[key] = DeepNet(**model_kwargs)
    loss_histories[key] = train_model(trained_models[key], X, y)

# Keep the per-experiment names available for use in later cells.
model_large, model_small, model_bn = (trained_models[k] for k in ("large", "small", "bn"))
losses_large, losses_small, losses_bn = (loss_histories[k] for k in ("large", "small", "bn"))



Training with large weight initialization...
Epoch [200/2000], Loss: nan
Epoch [400/2000], Loss: nan
Epoch [600/2000], Loss: nan
Epoch [800/2000], Loss: nan
Epoch [1000/2000], Loss: nan
Epoch [1200/2000], Loss: nan
Epoch [1400/2000], Loss: nan
Epoch [1600/2000], Loss: nan
Epoch [1800/2000], Loss: nan
Epoch [2000/2000], Loss: nan

Training with small weight initialization...
Epoch [200/2000], Loss: 10.2456
Epoch [400/2000], Loss: 10.2455
Epoch [600/2000], Loss: 10.2454
Epoch [800/2000], Loss: 10.2454
Epoch [1000/2000], Loss: 10.2454
Epoch [1200/2000], Loss: 10.2454
Epoch [1400/2000], Loss: 10.2453
Epoch [1600/2000], Loss: 10.2453
Epoch [1800/2000], Loss: 10.2452
Epoch [2000/2000], Loss: 10.2452

Training with Batch Normalization...
Epoch [200/2000], Loss: 1.0121
Epoch [400/2000], Loss: 1.0050
Epoch [600/2000], Loss: 1.0016
Epoch [800/2000], Loss: 0.9983
Epoch [1000/2000], Loss: 0.9974
Epoch [1200/2000], Loss: 0.9959
Epoch [1400/2000], Loss: 0.9969
Epoch [1600/2000], Loss: 0.9935
Epoch [

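With large weight initialization (`init_scale=10.0`), training diverges immediately. The loss is already `nan` at epoch 200 and never recovers, which is consistent with exploding activations and gradients. With small initialization (`init_scale=0.01`), the loss barely moves from ≈10.25. That is roughly the variance of `y` (3² · 1 + 1 = 10), so the network is effectively predicting a constant, consistent with the signal and its gradients vanishing through the tiny weights. With batch normalization, the loss settles around 1.0, the variance of the noise added in `generate_data`, i.e. the irreducible error.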