# Dataset

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

### Synthetic Dataset Generation

In [1]:
import torch
import torch.nn as nn
import torch.optim.sgd
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [2]:
# 1. Build a synthetic binary-classification problem:
#    2000 samples, each with 10 standard-normal features.
torch.manual_seed(42)  # fixed seed so the draws below are repeatable
n_samples, n_features = 2000, 10
X = torch.randn(n_samples, n_features)
weights_true = torch.randn(n_features)  # hidden linear rule that generates the labels
bias_true = 0
logits = X @ weights_true + bias_true
y = (logits > 0).to(torch.long)  # class 1 where the true logit is positive, else class 0


In [3]:
# 2. Standardize each feature column (z-score): subtract the column mean,
#    then divide by the column standard deviation.
mean = X.mean(dim=0)
std = X.std(dim=0)
std = std.masked_fill(std == 0, 1)  # constant columns: divide by 1 instead of 0
X = X.sub(mean).div(std)

In [4]:
# 3. Create a custom Dataset
class MyDataset(Dataset):
    """Map-style dataset that pairs each feature row with its label.

    Args:
        X: Indexable collection of samples (e.g. a tensor whose first
            dimension enumerates the samples).
        y: Indexable collection of labels, one per sample in ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of items.
    """

    def __init__(self, X, y):
        # Fail fast: a length mismatch would otherwise surface later as an
        # IndexError inside a DataLoader, or silently drop trailing labels.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same number of samples, "
                f"got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.X[idx], self.y[idx]

### Prepare DataLoader

In [5]:
# 4. Wrap the tensors in a dataset and serve them as shuffled mini-batches
dataset = MyDataset(X, y)
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,  # reshuffle the sample order at every epoch
)

# Model

In [7]:
# 5. Define the model
class BinaryClassifier(nn.Module):
    """Three-layer fully connected network with a configurable init scale.

    The weights of every linear layer are drawn from a normal distribution
    with mean 0 and standard deviation ``std``; every bias starts at zero.

    Args:
        std: Standard deviation used to initialize all weight matrices.
            May be a Python number or a 0-dim tensor.
        activation: Callable applied after the first and the second linear
            layer (for example ``torch.nn.Tanh()``).
        in_features: Number of values per flattened input sample.
        hidden_features: Width of both hidden layers.
        out_features: Number of logits produced per sample. The default of
            10 keeps the original architecture; two logits are sufficient
            for 0/1 labels trained with ``nn.CrossEntropyLoss``.
    """

    def __init__(self, std, activation, in_features=10, hidden_features=5, out_features=10):
        super().__init__()
        std = float(std)  # accept 0-dim tensors as well as plain numbers

        self.flatten = torch.nn.Flatten()
        self.linear1 = torch.nn.Linear(in_features, hidden_features)
        self.linear2 = torch.nn.Linear(hidden_features, hidden_features)
        self.linear3 = torch.nn.Linear(hidden_features, out_features)
        self.activation = activation

        # Weights ~ N(0, std**2), biases = 0. Layers are visited in forward
        # order so the random draws happen in the same sequence for a given seed.
        for layer in (self.linear1, self.linear2, self.linear3):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=std)
            torch.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the raw logits, shape ``(batch, out_features)``, for ``x``."""
        x = self.flatten(x)
        x = self.activation(self.linear1(x))
        x = self.activation(self.linear2(x))
        return self.linear3(x)


# Tanh

In [24]:
# Tanh network with a very wide weight init (std = sqrt(1000)); the loop below
# prints linear1's weights and gradients so the effect can be inspected.
model = BinaryClassifier(std=torch.sqrt(torch.tensor(1000.0)), activation=torch.nn.Tanh())

# 6. Loss and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 7. Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0
    for batch, (X_batch, y_batch) in enumerate(train_loader):
        # Keep the logits as (batch, classes): squeezing would drop the batch
        # dimension for a batch of one and make the loss call fail.
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch.long())

        optimizer.zero_grad()
        loss.backward()

        # Print weights and gradient of the first hidden layer (linear1)
        # before the optimizer step modifies the weights.
        print(f"Epoch {epoch+1}, Batch {batch+1}, linear1 weights:")
        print(model.linear1.weight)
        print(f"Epoch {epoch+1}, Batch {batch+1}, Gradient of linear1 weights:")
        print(model.linear1.weight.grad)
        print(f"Mean: {torch.mean(torch.abs(model.linear1.weight.grad))}")

        optimizer.step()

        total_loss += loss.item()

    # Average of the per-batch losses over the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

Epoch 1, Batch 1, linear1 weights:
Parameter containing:
tensor([[  9.9714, -24.3223, -19.5051, -18.5097,  19.4218,  32.0552, -28.6252,
         -11.8130,   1.1564,   7.5493],
        [ 10.2774, -27.6219, -29.0160, -53.1635, -19.9838, -52.2537, -40.9282,
         -67.9544,  21.9325,  14.6426],
        [ 59.3813,  40.6972,  -0.4353,  38.2616, -48.1933, -16.8377, -44.2538,
         -11.0382,  -1.8525,  20.1654],
        [-61.3985, -26.6496, -13.4137,  11.2169,  10.3209,  45.1751,   9.5379,
          51.6518,  -7.1355,  26.6926],
        [ 25.5935, -15.7656,  21.4057,  -3.6885,  -2.3978,   0.3713,   3.3640,
          22.7786,  13.2677,  28.2129]], requires_grad=True)
Epoch 1, Batch 1, Gradient of linear1 weights:
tensor([[-7.2916e+00, -9.3389e-01,  4.1943e+00,  4.7393e+00,  1.9669e+00,
          6.3788e+00,  4.1757e-01, -3.4710e+00, -7.5561e-01, -6.5884e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+

# ReLU

In [34]:
# ReLU network with a wide weight init (std = sqrt(100)); the loop below
# prints linear1's weights and gradients so the effect can be inspected.
model = BinaryClassifier(std=torch.sqrt(torch.tensor(100.0)), activation=torch.nn.ReLU())

# 6. Loss and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 7. Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0
    for batch, (X_batch, y_batch) in enumerate(train_loader):
        # Keep the logits as (batch, classes): squeezing would drop the batch
        # dimension for a batch of one and make the loss call fail.
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch.long())

        optimizer.zero_grad()
        loss.backward()

        # Print weights and gradient of the first hidden layer (linear1)
        # before the optimizer step modifies the weights.
        print(f"Epoch {epoch+1}, Batch {batch+1}, linear1 weights:")
        print(model.linear1.weight)
        print(f"Epoch {epoch+1}, Batch {batch+1}, Gradient of linear1 weights:")
        print(model.linear1.weight.grad)
        print(f"Mean: {torch.mean(torch.abs(model.linear1.weight.grad))}")

        optimizer.step()

        total_loss += loss.item()

    # Average of the per-batch losses over the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

Epoch 1, Batch 1, linear1 weights:
Parameter containing:
tensor([[  3.1264,  13.8619,  18.6659,  10.8036,   7.6330, -15.4397,   5.5407,
           9.9648,  13.6013, -12.7135],
        [ -4.1470,  -9.6691,  -5.3307,  -1.0807,  -5.1473,  17.3370,   2.3684,
           5.5857,   2.4576,  14.6300],
        [  0.0292,  -2.1970,   7.9136,   2.6869,   0.1966,   2.3966,  -0.1383,
           0.1160,   9.0774,  10.8185],
        [  1.9771,   8.8321,  10.1972,  13.0216,  -7.1277,  -1.2818,  12.4438,
          -3.1589,  -3.8193,  19.4798],
        [ 14.9760,   8.3950,  20.5835,   7.4557,  -1.2940,  11.5206,   0.5619,
          -4.6095,  -0.2675,  -9.0327]], requires_grad=True)
Epoch 1, Batch 1, Gradient of linear1 weights:
tensor([[ 29.4755, -32.6348,  29.9585,  25.1479, -34.3426, -54.2639,  26.7878,
          59.0487,  63.7385, -29.6194],
        [-36.6046,  -5.8116,   3.6245,  21.2271,  -8.3449,  42.9527,  14.1862,
         -13.3890, -11.3764, -23.2346],
        [ -0.2185,  37.3541, -26.2258, -46