In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

# Set the seed for reproducibility
torch.manual_seed(42)

# Define MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 32)  # Input size is 784 (28x28) for MNIST
        self.fc2 = nn.Linear(32, 32)   # Hidden layer to hidden layer
        self.fc3 = nn.Linear(32, 10)   # Hidden layer to 10 output neurons (for 10 MNIST classes)

    def forward(self, x):
        x = x.view(-1, 784)            # Flatten the 28x28 image into 784-dimensional input
        x = torch.relu(self.fc1(x))    # First hidden layer with ReLU
        x = torch.relu(self.fc2(x))    # Second hidden layer with ReLU
        x = self.fc3(x)                # Output layer (no activation, since we use CrossEntropyLoss)
        return x

# Function to extract weights and biases with specific shapes
def get_weights_and_biases_separate(model):
    weights = []
    biases = []
    for layer in model.children():
        if isinstance(layer, nn.Linear):
            weights.append(layer.weight.data.clone())  # Append weight tensor
            biases.append(layer.bias.data.clone())    # Append bias tensor
    return tuple(weights), tuple(biases)

# Function to calculate accuracy
def calculate_accuracy(model, device, data_loader):
    model.eval()  
    correct = 0
    total = 0
    with torch.no_grad(): 
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1) 
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100. * correct / total
    return accuracy

# Training setup
def train_mnist(model, device, train_loader, test_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()           # Zero the gradients
        output = model(data)            # Forward pass
        loss = criterion(output, target) # Calculate loss
        loss.backward()                 # Backpropagation
        optimizer.step()                # Update weights

    print(f'Epoch: {epoch} Training Loss: {loss.item():.6f}')
            
    train_accuracy = calculate_accuracy(model, device, test_loader)
    print(f'Epoch: {epoch} Test Accuracy: {train_accuracy:.2f}')

# Custom dataset class for reshaped weights and biases
class WeightsBiasesDataset(Dataset):
    def __init__(self, weights, biases):
        self.weights = weights
        self.biases = biases

    def __len__(self):
        return 1  # We only have one "data point" consisting of all layers

    def __getitem__(self, idx):
        # Combine weights and biases into a single tuple for batch processing
        weights_tuple = (
            torch.stack([self.weights[0]] * 32),  # Batch of weights for layer 1
            torch.stack([self.weights[1]] * 32),  # Batch of weights for layer 2
            torch.stack([self.weights[2]] * 32)   # Batch of weights for layer 3
        )

        biases_tuple = (
            torch.stack([self.biases[0]] * 32),   # Batch of biases for layer 1
            torch.stack([self.biases[1]] * 32),   # Batch of biases for layer 2
            torch.stack([self.biases[2]] * 32)    # Batch of biases for layer 3
        )

        return {
            'weights': weights_tuple,
            'biases': biases_tuple
        }

# Function to print shapes of weights and biases
def print_shapes(dataset):
    sample = dataset[0]  # Get the single data point (one "sample")
    weights_tuple = sample['weights']
    biases_tuple = sample['biases']
    
    for i in range(len(weights_tuple)):
        print(f"Layer {i + 1}: Weight shape: {weights_tuple[i].shape}, Bias shape: {biases_tuple[i].shape}")

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the MLP model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MLP().to(device)
criterion = nn.CrossEntropyLoss()  # Cross-entropy for classification
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
for epoch in range(1, 20):
    train_mnist(model, device, train_loader, test_loader, optimizer, criterion, epoch)

# Get weights and biases after training
weights, biases = get_weights_and_biases_separate(model)

# Create dataset from the weights and biases
weights_biases_dataset = WeightsBiasesDataset(weights, biases)

# Create a DataLoader to iterate over the weights and biases dataset
weights_biases_loader = DataLoader(weights_biases_dataset, batch_size=1, shuffle=False)

# Print shapes of weights and biases for the dataset
print_shapes(weights_biases_dataset)

# Optionally, save weights and biases to files for later use
torch.save((weights, biases), "Outputs/mlp.pt")

Epoch: 1 Training Loss: 0.291720
Epoch: 1 Test Accuracy: 93.29
Epoch: 2 Training Loss: 0.080920
Epoch: 2 Test Accuracy: 94.44
Epoch: 3 Training Loss: 0.050231
Epoch: 3 Test Accuracy: 94.92
Epoch: 4 Training Loss: 0.053785
Epoch: 4 Test Accuracy: 95.59
Epoch: 5 Training Loss: 0.064096
Epoch: 5 Test Accuracy: 96.14
Epoch: 6 Training Loss: 0.181577
Epoch: 6 Test Accuracy: 96.38
Epoch: 7 Training Loss: 0.200738
Epoch: 7 Test Accuracy: 96.22
Epoch: 8 Training Loss: 0.123890
Epoch: 8 Test Accuracy: 96.43
Epoch: 9 Training Loss: 0.175379
Epoch: 9 Test Accuracy: 96.65
Epoch: 10 Training Loss: 0.050709
Epoch: 10 Test Accuracy: 96.66
Epoch: 11 Training Loss: 0.099322
Epoch: 11 Test Accuracy: 96.69
Epoch: 12 Training Loss: 0.019186
Epoch: 12 Test Accuracy: 96.79
Epoch: 13 Training Loss: 0.019557
Epoch: 13 Test Accuracy: 96.78
Epoch: 14 Training Loss: 0.011932
Epoch: 14 Test Accuracy: 96.65
Epoch: 15 Training Loss: 0.032789
Epoch: 15 Test Accuracy: 96.93
Epoch: 16 Training Loss: 0.028926
Epoch: 16