# Dataset

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

In [2]:
import torch
# Sanity check for the Apple-silicon (MPS) backend: when it is available,
# allocate a one-element tensor on that device and show it; otherwise report
# that the device could not be found.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

tensor([1.], device='mps:0')


### Synthetic Dataset Generation

In [103]:
import torch
import torch.nn as nn
import torch.optim.sgd
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [104]:
# 1. Build a synthetic binary-classification set: 2000 samples, 10 features.
#    The label of a sample is 1 when its random linear score is positive.
torch.manual_seed(42)  # fixed seed so X and weights_true are reproducible
X = torch.randn(2000, 10)
weights_true = torch.randn(10)
bias_true = 0
logits = torch.matmul(X, weights_true) + bias_true
y = torch.gt(logits, 0).long()  # Binary labels (int64, values 0 or 1)


In [105]:
# 2. Standardize every feature column of X (z-score: subtract the column
#    mean, divide by the column standard deviation).
mean = torch.mean(X, dim=0)
std = torch.std(X, dim=0)
# A constant column has std == 0; substitute 1 so the division is a no-op
# for that column instead of dividing by zero.
std = std.masked_fill(std == 0, 1)
X = X.sub(mean).div(std)

In [106]:
# 3. Create a custom Dataset
class MyDataset(Dataset):
    """Map-style dataset that pairs each feature row with its label.

    Args:
        X: Indexable collection of samples (e.g. a tensor whose first
            dimension enumerates samples).
        y: Indexable collection of labels, one per sample in ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of items.
    """

    def __init__(self, X, y):
        # Fail early: a length mismatch would otherwise surface much later as
        # an IndexError inside a DataLoader worker, or silently ignore labels.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.X[idx], self.y[idx]

### Prepare DataLoader

In [107]:
# 4. Wrap the tensors in the custom Dataset and serve them in shuffled
#    mini-batches of 64 samples.
dataset = MyDataset(X, y)
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)

In [108]:
# 5. Define the model
class BinaryClassifier(nn.Module):
    """Three-layer fully connected network with tanh activations.

    All weight matrices are drawn from a normal distribution with mean 0 and
    standard deviation ``std``; all biases start at zero. The network returns
    raw, unnormalized scores (no softmax/sigmoid is applied).

    Args:
        std: Standard deviation used for the weight initialization.
        in_features: Size of each (flattened) input sample. Defaults to 10.
        hidden_features: Width of both hidden layers. Defaults to 5.
        out_features: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, std, in_features=10, hidden_features=5, out_features=10):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear1 = torch.nn.Linear(in_features, hidden_features)
        self.linear2 = torch.nn.Linear(hidden_features, hidden_features)
        self.linear3 = torch.nn.Linear(hidden_features, out_features)
        self.activation = torch.nn.Tanh()

        # Re-initialize every layer: weights ~ N(0, std^2), biases = 0.
        # Layers are visited in definition order so the random draws happen
        # in the same sequence for a given seed.
        for layer in (self.linear1, self.linear2, self.linear3):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=std)
            torch.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the output scores for a batch of inputs.

        The input is flattened from dimension 1 onward before the first
        linear layer, so each sample must contain ``in_features`` values.
        """
        x = self.flatten(x)
        x = self.activation(self.linear1(x))
        x = self.activation(self.linear2(x))
        return self.linear3(x)


In [112]:
# The weights are initialized with a very large standard deviation
# (sqrt(1000) ~= 31.6). NOTE(review): presumably this is deliberate, to show
# how saturated tanh units produce tiny gradients -- confirm before changing.
model = BinaryClassifier(std=torch.sqrt(torch.tensor(1000)))

# 6. Loss and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 7. Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0
    for batch, (X_batch, y_batch) in enumerate(train_loader):
        # Keep the logits as (batch, num_outputs). Squeezing here would drop
        # the batch dimension for a batch of one sample and make the shape
        # incompatible with the (batch,) target expected by CrossEntropyLoss.
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch.long())

        # Clear gradients from the previous step before back-propagating.
        optimizer.zero_grad()
        loss.backward()

        # Print the weights and the gradient of the first hidden layer
        # (linear1) before the optimizer updates them.
        print(f"Epoch {epoch+1}, Batch {batch+1}, linear1 weights:")
        print(model.linear1.weight)
        print(f"Epoch {epoch+1}, Batch {batch+1}, Gradient of linear1 weights:")
        print(model.linear1.weight.grad)
        print(f"Mean: {torch.mean(model.linear1.weight.grad)}")

        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

Epoch 1, Batch 1, linear1 weights:
Parameter containing:
tensor([[ 26.3972, -22.9909, -36.0259, -14.7338,  -7.7442, -28.1442,  34.0658,
         -14.6177,  16.0249,  33.0937],
        [-29.5947,  30.1440, -12.6741, -12.9760, -34.8632,   4.5711,  -4.2131,
          -7.0683,  13.6223,  -6.4790],
        [ -4.7943,   5.2878,   7.2591,  -8.9022, -22.3642,  27.4968,  -7.6961,
          27.0192, -23.6870,  13.6128],
        [ -1.0250,  -8.0024,  40.6381, -36.8345,  13.0843,  81.4162,  -1.0112,
           5.7112,  -0.4713,  -1.1286],
        [-32.1507,   6.8471, -31.5950, -41.6562,   0.2202, -21.0803, -31.6391,
          -3.3972, -13.3689,  13.0588]], requires_grad=True)
Epoch 1, Batch 1, Gradient of linear1 weights:
tensor([[-5.8910e-06,  3.4383e-06, -5.3568e-06, -2.0867e-06, -3.1677e-06,
         -6.4611e-07,  5.1274e-08, -1.9211e-06, -7.9056e-08, -1.2732e-06],
        [-4.9421e-04,  2.3767e-05, -3.1145e-04, -4.3941e-04,  8.2596e-04,
         -2.3261e-04, -5.0805e-04, -1.2059e-04,  1.1694e-

In [113]:
import torch
# Same MPS availability probe as the cell near the top of the notebook:
# place a one-element tensor on the "mps" device and print it, or print a
# notice when the backend is unavailable.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones((1,), device=mps_device)
    print(x)

tensor([1.], device='mps:0')
