In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = "cpu"

class Mnist(nn.Module):
    """Fully connected classifier: 784 -> 1000 -> 2000 -> 2000 -> 10.

    Inputs are flattened to rows of 28 * 28 values, passed through three
    ReLU-activated linear layers and projected to 10 raw scores (no softmax
    is applied here). Training and evaluation loops are provided as methods.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, 1000)
        self.layer2 = nn.Linear(1000, 2000)
        self.layer3 = nn.Linear(2000, 2000)
        self.layer4 = nn.Linear(2000, 10)
        self.relu = nn.ReLU()

    def forward(self, X):
        """Return the 10 raw class scores for each 784-value row of ``X``.

        Args:
            X: Tensor whose element count is a multiple of 28 * 28; it is
                flattened to shape ``(-1, 784)`` before the first layer.

        Returns:
            Tensor of shape ``(rows, 10)`` holding unnormalised scores.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        X = X.reshape(-1, 28 * 28)
        X = self.relu(self.layer1(X))
        X = self.relu(self.layer2(X))
        X = self.relu(self.layer3(X))
        X = self.layer4(X)
        return X

    @staticmethod
    def _batch_loss_sum(criterion, loss, batch_size):
        """Convert a scalar batch loss into a sum over the batch's samples."""
        # A criterion with reduction='sum' already returns the summed loss;
        # otherwise the value is treated as a per-sample mean and scaled up.
        if getattr(criterion, 'reduction', 'mean') == 'sum':
            return loss.item()
        return loss.item() * batch_size

    def train_model(self, train_loader, criterion, optimizer, device, epochs):
        """Train the network in place and print running statistics.

        Every 100 batches, and once per epoch, the per-sample average loss and
        the accuracy over the samples seen so far in the epoch are printed.
        Losses are weighted by batch size so that a short final batch does not
        skew the average.

        Args:
            train_loader: Sized iterable yielding ``(data, target)`` batches.
            criterion: Callable mapping ``(outputs, target)`` to a scalar loss.
            optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
            device: Device that each batch is moved to before the forward pass.
            epochs: Number of passes over ``train_loader``.
        """
        self.train()
        for epoch in range(epochs):
            loss_sum = 0.0
            correct = 0
            total = 0
            start_time = time.time()
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                outputs = self(data)
                loss = criterion(outputs, target)
                loss.backward()
                optimizer.step()

                batch_size = target.size(0)
                loss_sum += self._batch_loss_sum(criterion, loss, batch_size)
                predicted = outputs.argmax(dim=1)
                total += batch_size
                correct += (predicted == target).sum().item()
                if (batch_idx + 1) % 100 == 0:
                    seen = max(total, 1)
                    print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx + 1}/{len(train_loader)}], '
                          f'Loss: {loss_sum / seen:.4f}, '
                          f'Accuracy: {100 * correct / seen:.2f}%')
            epoch_time = time.time() - start_time
            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            seen = max(total, 1)
            print(f'Epoch [{epoch+1}/{epochs}] completed in {epoch_time:.2f} seconds. '
                  f'Average Loss: {loss_sum / seen:.4f}, '
                  f'Accuracy: {100 * correct / seen:.2f}%')

    def test(self, test_loader, criterion, device, class_names):
        """Evaluate the network and print overall and per-class results.

        Switches the module to eval mode, runs every batch without gradient
        tracking, then prints the per-sample average loss, the overall
        accuracy and, for each class, how its samples were predicted.

        Args:
            test_loader: Iterable yielding ``(data, target)`` batches.
            criterion: Callable mapping ``(outputs, target)`` to a scalar loss.
            device: Device that each batch is moved to before the forward pass.
            class_names: Display names; index ``i`` names class ``i``.

        Returns:
            Overall accuracy in percent, or ``0.0`` if the loader yields no
            samples.
        """
        self.eval()
        num_classes = len(class_names)
        # Wide enough for every index the network can predict, even when
        # fewer class names than outputs are supplied.
        width = max(num_classes, self.layer4.out_features)
        # confusion[t, p] counts samples of true class t predicted as class p.
        confusion = torch.zeros(width, width, dtype=torch.long)
        loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                outputs = self(data)
                loss = criterion(outputs, target)
                predicted = outputs.argmax(dim=1)

                batch_size = target.size(0)
                loss_sum += self._batch_loss_sum(criterion, loss, batch_size)
                total += batch_size
                correct += (predicted == target).sum().item()

                # Encode each (true, predicted) pair as one integer and count
                # the pairs on the CPU: one transfer per batch instead of
                # several .item() calls per sample.
                pairs = (target * width + predicted).cpu()
                confusion += torch.bincount(pairs, minlength=width * width).view(width, width)

        if total == 0:
            print('Test set is empty; nothing to evaluate.')
            return 0.0

        average_loss = loss_sum / total
        accuracy = 100 * correct / total
        print(f'Test Loss: {average_loss:.4f}, Test Accuracy: {accuracy:.2f}%\n')

        for i in range(num_classes):
            row = confusion[i]
            total_count = int(row.sum())
            if total_count > 0:
                correct_count = int(row[i])
                class_accuracy = 100 * correct_count / total_count
                print(f'Class: {class_names[i]:15s} - Correct: {correct_count}/{total_count} '
                      f'({class_accuracy:.2f}%)')

                # Detailed prediction breakdown
                print(f'  Prediction breakdown:')
                for j in range(num_classes):
                    predicted_count = int(row[j])
                    if predicted_count > 0:
                        print(f'    Predicted as {j}: {predicted_count} out of {total_count} '
                              f'({100 * predicted_count / total_count:.2f}%)')
            else:
                print(f'Class: {class_names[i]:15s} - No samples.')
            print()

        return accuracy

def prepare_data(batch_size_train=64, batch_size_test=1024, root='../data'):
    """Build DataLoaders for the MNIST training and test splits.

    Both splits are downloaded into ``root`` if not already present, converted
    to tensors and normalised with a fixed mean and standard deviation.

    Args:
        batch_size_train: Batch size of the (shuffled) training loader.
        batch_size_test: Batch size of the (unshuffled) test loader.
        root: Directory where the dataset is stored or downloaded to.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): presumably the MNIST training-set pixel mean/std — confirm.
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    mnist_trainset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    train_loader = DataLoader(mnist_trainset, batch_size=batch_size_train, shuffle=True)
    mnist_testset = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    test_loader = DataLoader(mnist_testset, batch_size=batch_size_test, shuffle=False)
    return train_loader, test_loader
 
if __name__ == "__main__":
    # Seed the global RNG so weight initialisation and batch shuffling repeat
    # from run to run (GPU kernels may still add small nondeterminism).
    torch.manual_seed(42)

    model = Mnist().to(device)
    print(model)

    criterion = nn.CrossEntropyLoss()
    # NOTE(review): lr=0.01 is 10x the AdamW default of 1e-3. The recorded run
    # shows an average loss of 3.7352 over the first 100 steps, above the
    # ~2.30 (ln 10) of a uniform guess, which suggests early instability —
    # consider lowering it.
    optimizer = optim.AdamW(model.parameters(), lr=0.01)

    train_loader, test_loader = prepare_data()
    class_names = [str(i) for i in range(10)]
    model.train_model(train_loader, criterion, optimizer, device, epochs=100)

    # Test the model
    model.test(test_loader, criterion, device, class_names)

Mnist(
  (layer1): Linear(in_features=784, out_features=1000, bias=True)
  (layer2): Linear(in_features=1000, out_features=2000, bias=True)
  (layer3): Linear(in_features=2000, out_features=2000, bias=True)
  (layer4): Linear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)
Epoch [1/100], Step [100/938], Loss: 3.7352, Accuracy: 55.20%
Epoch [1/100], Step [200/938], Loss: 2.2499, Accuracy: 65.75%
Epoch [1/100], Step [300/938], Loss: 1.7121, Accuracy: 70.78%
Epoch [1/100], Step [400/938], Loss: 1.4111, Accuracy: 74.36%
Epoch [1/100], Step [500/938], Loss: 1.2250, Accuracy: 76.59%
Epoch [1/100], Step [600/938], Loss: 1.0993, Accuracy: 78.18%
Epoch [1/100], Step [700/938], Loss: 1.0088, Accuracy: 79.44%
Epoch [1/100], Step [800/938], Loss: 0.9356, Accuracy: 80.47%
Epoch [1/100], Step [900/938], Loss: 0.8766, Accuracy: 81.30%
Epoch [1/100] completed in 94.90 seconds. Average Loss: 0.8562, Accuracy: 81.61%
Epoch [2/100], Step [100/938], Loss: 0.3796, Accuracy: 89.47%
Epoch [2

KeyboardInterrupt: 

In [6]:
# Display the module summary. `model` is created by the training cell, so that
# cell must have been run in this kernel first.
model

Mnist(
  (layer1): Linear(in_features=784, out_features=1000, bias=True)
  (layer2): Linear(in_features=1000, out_features=2000, bias=True)
  (layer3): Linear(in_features=2000, out_features=2000, bias=True)
  (layer4): Linear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)

6809010