### PyTorch

In [1]:
# %pip install torch torchvision  # uncomment to install into this kernel's environment (then restart the kernel)

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [3]:
class ANN(nn.Module):
    """Fully connected classifier for flattened images (784 -> 256 -> 128 -> 10 by default).

    Args:
        input_size: Number of input features per sample (28*28 for MNIST).
        hidden_sizes: Widths of the two hidden layers, as a 2-tuple.
        num_classes: Number of output logits (10 digits for MNIST).

    The layer attributes stay named ``fc1``/``fc2``/``fc3`` so ``state_dict``
    keys are the same as before and existing checkpoints still load.
    """

    def __init__(self, input_size=28 * 28, hidden_sizes=(256, 128), num_classes=10):
        super().__init__()
        hidden1, hidden2 = hidden_sizes
        self.fc1 = nn.Linear(input_size, hidden1)   # first hidden layer
        self.fc2 = nn.Linear(hidden1, hidden2)      # second hidden layer
        self.fc3 = nn.Linear(hidden2, num_classes)  # output layer (raw logits)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        ``x`` is flattened so that every ``input_size`` consecutive elements form
        one sample, e.g. (batch, 1, 28, 28) -> (batch, 784). ``reshape`` is used
        instead of ``view`` so non-contiguous inputs (e.g. transposed tensors)
        are accepted too.
        """
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No softmax here: nn.CrossEntropyLoss expects raw logits.
        return self.fc3(x)


In [4]:
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_data = datasets.MNIST(root='mnist_data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='mnist_data', train=False, download=True, transform=transform)

# A dedicated, seeded generator makes the training shuffle order reproducible
# under Restart & Run All, independent of the global RNG state.
SHUFFLE_SEED = 0
train_loader = DataLoader(train_data, batch_size=64, shuffle=True,
                          generator=torch.Generator().manual_seed(SHUFFLE_SEED))
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 3785437.28it/s]


Extracting mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 436410.22it/s]


Extracting mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3885467.58it/s]


Extracting mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6258386.59it/s]

Extracting mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw






In [5]:
torch.manual_seed(0)  # seed the global RNG so the network's initial weights are reproducible
model = ANN()  # fully connected MNIST classifier (default layer sizes)
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer


In [6]:
num_epochs = 5

model.train()  # training mode; a no-op for this model, but required once dropout/batch-norm is added
for epoch in range(num_epochs):
    running_loss = 0.0
    samples_seen = 0
    for images, labels in train_loader:
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass; the criterion returns the mean loss over this batch
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so the smaller final batch does not skew the epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    epoch_loss = running_loss / max(samples_seen, 1)  # guard against an empty loader
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


Epoch [1/5], Loss: 0.3395
Epoch [2/5], Loss: 0.1503
Epoch [3/5], Loss: 0.1103
Epoch [4/5], Loss: 0.0905
Epoch [5/5], Loss: 0.0757


In [7]:
model.eval()  # inference mode; a no-op for this model, but required once dropout/batch-norm is added
correct = 0
total = 0
with torch.no_grad():  # no autograd graph needed for evaluation
    for images, labels in test_loader:
        logits = model(images)
        predicted = logits.argmax(dim=1)  # index of the highest logit = predicted digit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")


Accuracy: 97.31%


### JAX

In [8]:
import jax
import jax.numpy as jnp
import optax
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Step 1: Load the MNIST dataset using torchvision
def get_mnist_dataloader(batch_size=128, root='./data'):
    """Build MNIST train/test DataLoaders yielding un-normalized [0, 1] pixel tensors.

    Args:
        batch_size: Samples per batch for both loaders.
        root: Directory torchvision downloads MNIST into (or reads it from if present).

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    # Only ToTensor here: unlike the PyTorch section, no Normalize step is applied.
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# Reuse the files the PyTorch section already downloaded rather than fetching a second copy.
# NOTE: this rebinds the notebook-global train_loader/test_loader used by the PyTorch cells above.
train_loader, test_loader = get_mnist_dataloader(root='mnist_data')

# Step 2: Initialize model parameters
def init_params(layer_sizes, key):
    """Create He-scaled (weights, bias) pairs for a dense MLP.

    Args:
        layer_sizes: Layer widths from input to output, e.g. [784, 128, 64, 10].
        key: A JAX PRNG key.

    Returns:
        A list with one (W, b) tuple per consecutive pair of sizes. W has shape
        (fan_in, fan_out) and is drawn from N(0, 1) * sqrt(2 / fan_in); b is a
        zero vector of length fan_out.
    """
    layer_keys = jax.random.split(key, len(layer_sizes) - 1)
    shapes = zip(layer_sizes[:-1], layer_sizes[1:])

    def make_layer(fan_in, fan_out, layer_key):
        # Split the layer key and draw the weights from the first half; the bias is zero.
        weight_key = jax.random.split(layer_key)[0]
        scale = jnp.sqrt(2.0 / fan_in)  # He scaling, suited to the ReLU hidden layers
        return jax.random.normal(weight_key, (fan_in, fan_out)) * scale, jnp.zeros(fan_out)

    return [make_layer(fan_in, fan_out, k) for (fan_in, fan_out), k in zip(shapes, layer_keys)]

# Step 3: Forward pass through the network
def forward(params, x):
    """Run the MLP on ``x`` and return log-probabilities over classes.

    Every layer except the last applies an affine map followed by ReLU; the
    last layer's affine output goes through log-softmax (over the last axis).
    """
    hidden_layers, (w_out, b_out) = params[:-1], params[-1]
    activations = x
    for weights, bias in hidden_layers:
        activations = jax.nn.relu(jnp.dot(activations, weights) + bias)
    logits = jnp.dot(activations, w_out) + b_out
    return jax.nn.log_softmax(logits)

# Step 4: Loss function (negative log-likelihood)
def loss_fn(params, x, y):
    """Mean negative log-likelihood of targets ``y`` under the model's log-probabilities.

    With one-hot ``y`` (as produced by ``one_hot`` in the training loop), the
    row-wise sum picks out the log-probability of the true class.
    """
    log_probs = forward(params, x)
    per_example = jnp.sum(log_probs * y, axis=1)
    mean_log_likelihood = jnp.mean(per_example)
    return -mean_log_likelihood

# Step 5: Accuracy metric
def accuracy(params, x, y):
    """Fraction of rows whose predicted class equals the argmax of the one-hot label ``y``."""
    true_class = jnp.argmax(y, axis=1)
    predicted = jnp.argmax(forward(params, x), axis=1)
    hits = predicted == true_class
    return jnp.mean(hits)

# Step 6: Single optimizer update step (gradient of the loss fed to an optax optimizer)
def update(params, opt_state, x, y, opt_update):
    """Apply one optimizer step computed from the loss on batch (x, y).

    Args:
        params: Current model parameters.
        opt_state: Current optimizer state.
        x, y: Input batch and one-hot targets.
        opt_update: The ``update`` function of an optax optimizer.

    Returns:
        (updated params, updated optimizer state).
    """
    grad_fn = jax.grad(loss_fn)
    param_grads = grad_fn(params, x, y)
    step, next_state = opt_update(param_grads, opt_state)
    return optax.apply_updates(params, step), next_state

# Step 7: Label encoding and training loop
def one_hot(labels, num_classes=10):
    """Encode integer class labels as one-hot float vectors of length ``num_classes``."""
    encoded = jax.nn.one_hot(labels, num_classes=num_classes)
    return encoded

def train_model(params, train_loader, test_loader, epochs=5, lr=0.001):
    """Train the MLP with Adam and print test-set metrics after every epoch.

    Args:
        params: Initial list of (W, b) layer tuples, as built by ``init_params``.
        train_loader: Iterable of (images, integer labels) torch batches used for training.
        test_loader: Iterable of (images, integer labels) torch batches used for evaluation.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The trained parameters, with the same structure as ``params``.
    """
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    # jit-compile the steps; they are re-traced only when the batch shape changes
    # (e.g. for the smaller final batch of an epoch).
    @jax.jit
    def train_step(p, state, x, y):
        return update(p, state, x, y, optimizer.update)

    @jax.jit
    def eval_step(p, x, y):
        return loss_fn(p, x, y), accuracy(p, x, y)

    for epoch in range(epochs):
        # Training
        for data, target in train_loader:
            x = data.view(-1, 28 * 28).numpy()  # flatten each image to 784 features
            y = one_hot(target.numpy())
            params, opt_state = train_step(params, opt_state, x, y)

        # Testing: weight each batch by its size so the smaller final batch is not over-counted.
        loss_sum = 0.0
        correct = 0.0
        n_seen = 0
        for data, target in test_loader:
            x = data.view(-1, 28 * 28).numpy()
            y = one_hot(target.numpy())
            batch_loss, batch_acc = eval_step(params, x, y)
            n = x.shape[0]
            loss_sum += float(batch_loss) * n
            correct += float(batch_acc) * n
            n_seen += n

        n_seen = max(n_seen, 1)  # avoid ZeroDivisionError on an empty test loader
        print(f'Epoch {epoch+1}, Test Loss: {loss_sum / n_seen:.4f}, Test Accuracy: {correct / n_seen:.4f}')

    return params

# Step 8: Main script

layer_sizes = [28 * 28, 128, 64, 10]  # Input, hidden layers, and output sizes
key = jax.random.PRNGKey(0)
params = init_params(layer_sizes, key)
# Keep the trained weights; assigning also stops Jupyter from echoing the whole parameter list.
params = train_model(params, train_loader, test_loader)


ModuleNotFoundError: No module named 'jax'