In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
# Preprocessing pipeline applied to every MNIST image as it is read:
#   1. ToTensor  - convert the image into a torch tensor.
#   2. Normalize - subtract `mean` and divide by `std` per channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set pixel
# mean and standard deviation - confirm against the data source.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

In [3]:
# Build the MNIST training split and then the test split. Both are rooted at
# ./data, request a download, and share the preprocessing pipeline; only the
# `train` flag differs between the two calls.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [4]:
# Mini-batch iterators over the two splits (64 samples per batch).
# The training loader is shuffled so the sample order is redrawn on each pass.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
# The test loader keeps the dataset order; it is only used for evaluation.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

In [5]:
class FCNN(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 10 with ReLU in between.

    Every sample is flattened to a vector of 28 * 28 = 784 values before the
    first linear layer, so any per-sample shape with 784 elements is accepted.
    The last layer has no activation: `forward` returns raw class scores.
    """

    def __init__(self):
        """Create the three linear layers and the shared ReLU activation."""
        super(FCNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (batch, 10) tensor of raw class scores for a batch `x`.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold 784 elements per sample.

        Returns:
            Tensor of shape (x.size(0), 10); no softmax is applied.
        """
        # flatten() copies when it has to, so unlike view(x.size(0), -1) it
        # also accepts non-contiguous input (e.g. a transposed batch) and an
        # empty batch, where view's -1 dimension cannot be inferred.
        x = x.flatten(start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [6]:
# Define the model, optimizer, and loss function

# Seed PyTorch's global RNG before the model is built so that the randomly
# initialised weights are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = FCNN()
# Plain SGD with momentum over all of the model's parameters.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# Cross-entropy is computed directly on the raw class scores the model returns.
criterion = nn.CrossEntropyLoss()

In [7]:
# Train the model: one optimisation pass over train_loader per epoch, followed
# by an accuracy measurement on test_loader.
num_epochs = 10

# Batches are moved to whatever device the model's parameters live on, so the
# loop keeps working unchanged if the model is moved off the CPU.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # --- training pass ---
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Report the loss of the current batch every 100 batches.
        if batch_idx % 100 == 0:
            print('Epoch: {}, Batch index: {}, Loss: {:.4f}'.format(epoch, batch_idx, loss.item()))

    # --- evaluation pass ---
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Predicted class = index of the largest score in each row.
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    # Fraction of test samples classified correctly in this epoch.
    accuracy = correct / total
    print('Epoch: {}, Accuracy: {:.2f}%'.format(epoch, 100 * accuracy))

Epoch: 0, Batch index: 0, Loss: 2.3151
Epoch: 0, Batch index: 100, Loss: 1.0681
Epoch: 0, Batch index: 200, Loss: 0.5415
Epoch: 0, Batch index: 300, Loss: 0.5540
Epoch: 0, Batch index: 400, Loss: 0.3818
Epoch: 0, Batch index: 500, Loss: 0.2644
Epoch: 0, Batch index: 600, Loss: 0.3227
Epoch: 0, Batch index: 700, Loss: 0.2413
Epoch: 0, Batch index: 800, Loss: 0.2809
Epoch: 0, Batch index: 900, Loss: 0.2634
Epoch: 0, Accuracy: 92.33%
Epoch: 1, Batch index: 0, Loss: 0.3673
Epoch: 1, Batch index: 100, Loss: 0.3470
Epoch: 1, Batch index: 200, Loss: 0.2280
Epoch: 1, Batch index: 300, Loss: 0.1070
Epoch: 1, Batch index: 400, Loss: 0.2684
Epoch: 1, Batch index: 500, Loss: 0.1660
Epoch: 1, Batch index: 600, Loss: 0.2012
Epoch: 1, Batch index: 700, Loss: 0.2362
Epoch: 1, Batch index: 800, Loss: 0.1754
Epoch: 1, Batch index: 900, Loss: 0.2054
Epoch: 1, Accuracy: 94.50%
Epoch: 2, Batch index: 0, Loss: 0.1487
Epoch: 2, Batch index: 100, Loss: 0.1156
Epoch: 2, Batch index: 200, Loss: 0.0921
Epoch: 2,