In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Log the installed PyTorch version at the top of the run output.
print(f"PyTorch version: {torch.__version__}")

# Limit PyTorch to two CPU threads for intra-op parallelism.
torch.set_num_threads(2)

# Preprocessing: PIL image -> tensor, then normalize with the given
# single-channel mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

print("Loading datasets...")
# Both splits share the same on-disk root and are downloaded if missing.
train_dataset = MNIST(
    root='./datasets/mnist_data/data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = MNIST(
    root='./datasets/mnist_data/data',
    train=False,
    download=True,
    transform=transform,
)

# Training batches are shuffled; evaluation batches are larger and in fixed
# order. num_workers=0 keeps loading in the main process.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1000,
    shuffle=False,
    num_workers=0,
)
print(f"Datasets loaded. Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")


class SimpleMLPModel(nn.Module):
    """Two-layer MLP classifier over 784-feature inputs with 10 output classes.

    Pipeline: ``Linear(784, 128)`` -> ``BatchNorm1d(128)`` -> sigmoid ->
    dropout (p=0.2, training only) -> ``Linear(128, 10)`` -> ``log_softmax``.

    Args:
        use_adapt: Accepted for interface compatibility only. The model always
            uses standard PyTorch layers and stores ``use_adapt = False``.
        axx_mult: Accepted for interface compatibility only. Ignored; the
            attribute ``axx_mult`` is always ``None``.
    """

    def __init__(self, use_adapt=False, axx_mult='mul8s_acc'):
        super().__init__()

        # AdaPT support is force-disabled: both constructor arguments are
        # deliberately ignored and the attributes pinned to inert values.
        self.use_adapt = False
        self.axx_mult = None

        print("Initializing network with standard PyTorch layers")

        # Hidden layer: Xavier-normal weights scaled by the sigmoid gain,
        # zero bias.
        self.fc1 = nn.Linear(784, 128)
        gain = nn.init.calculate_gain('sigmoid')
        nn.init.xavier_normal_(self.fc1.weight, gain=gain)
        nn.init.zeros_(self.fc1.bias)

        self.bn1 = nn.BatchNorm1d(128)

        # Output layer: standard-normal weights (mean 0, std 1), zero bias.
        self.fc2 = nn.Linear(128, 10)
        nn.init.normal_(self.fc2.weight, mean=0, std=1)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        """Return per-class log-probabilities for a batch.

        Args:
            x: Tensor that is either already 2-D ``(batch, 784)`` or a 3-D/4-D
                tensor whose elements are flattened into rows of 784 values.

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax outputs.
        """
        if x.dim() in (3, 4):
            # reshape (not view) so non-contiguous inputs, e.g. transposed or
            # permuted image batches, are flattened instead of raising.
            x = x.reshape(-1, 784)

        x = self.fc1(x)
        x = self.bn1(x)
        x = torch.sigmoid(x)
        # Functional dropout must be told the mode explicitly; it is a no-op
        # in eval mode.
        x = F.dropout(x, p=0.2, training=self.training)

        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train_model(model, device, train_loader, test_loader, epochs=5,
                lr=0.001, save_path='mnist_mlp_standard_best.pt'):
    """Train ``model`` with Adam and NLL loss, checkpointing the best epoch.

    After every epoch the model is evaluated on ``test_loader`` via
    ``evaluate_model``; whenever the test accuracy improves on the best seen
    so far, the model's ``state_dict`` is written to ``save_path``.

    Args:
        model: Module whose forward pass returns log-probabilities (the loss
            used is ``nn.NLLLoss``).
        device: Device the model and every batch are moved to.
        train_loader: Iterable yielding ``(data, target)`` training batches.
        test_loader: Iterable of ``(data, target)`` batches used for the
            per-epoch evaluation.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate passed to ``optim.Adam``.
        save_path: Destination for the best checkpoint, or ``None`` to skip
            saving entirely.

    Returns:
        The highest test accuracy (in percent) observed across all epochs;
        ``0.0`` if no epoch beat that value.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()

    best_accuracy = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        start_time = time.time()
        batch_count = 0

        print(f"Epoch {epoch + 1}/{epochs} started")
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batch_count += 1

            # Same prediction rule as evaluate_model: index of the largest
            # log-probability per row.
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # Progress line every 100 batches, using running averages since
            # the start of the epoch.
            if batch_idx % 100 == 99:
                avg_loss = running_loss / batch_count
                accuracy = 100. * correct / total
                print(f'  Batch {batch_idx + 1}: Loss {avg_loss:.4f}, Accuracy {accuracy:.2f}%')

        # Guard the averages so an empty train_loader reports zeros instead of
        # raising ZeroDivisionError.
        epoch_loss = running_loss / batch_count if batch_count else 0.0
        epoch_acc = 100. * correct / total if total else 0.0
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch + 1} completed in {epoch_time:.2f}s. Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

        test_loss, test_accuracy = evaluate_model(model, device, test_loader, criterion)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            if save_path is not None:
                # Checkpointing is best-effort: a failed save is reported but
                # must not abort training.
                try:
                    torch.save(model.state_dict(), save_path)
                    print(f"Saved new best model to {save_path} with accuracy: {best_accuracy:.2f}%")
                except Exception as e:
                    print(f"Could not save model: {e}")

    return best_accuracy


def evaluate_model(model, device, test_loader, criterion=None):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Module to evaluate; its output is passed to ``criterion`` and
            arg-maxed along dim 1 to obtain predictions.
        device: Device every batch is moved to before the forward pass.
        test_loader: Iterable yielding ``(data, target)`` batches. It does not
            need to support ``len()``.
        criterion: Loss callable; defaults to ``nn.NLLLoss()`` when ``None``.

    Returns:
        Tuple ``(test_loss, accuracy)``: the mean of the per-batch loss values
        and the percentage of correctly classified samples. Both are ``0.0``
        when the loader yields nothing.
    """
    if criterion is None:
        criterion = nn.NLLLoss()

    model.eval()
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss_sum += criterion(output, target).item()
            # Count batches while iterating rather than calling len(), so
            # loaders without __len__ work too.
            num_batches += 1
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    # Guard both averages so an empty loader yields zeros instead of raising
    # ZeroDivisionError.
    test_loss = loss_sum / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0

    return test_loss, accuracy


def main():
    """Build the standard MLP and train it on the CPU for 10 epochs.

    Uses the module-level ``train_loader`` and ``test_loader``. Any exception
    raised while training (or while reporting the result) is caught and
    printed rather than propagated.
    """
    cpu = torch.device("cpu")
    print(f"Using device: {cpu}")

    print("\n=== Training Standard PyTorch MLP ===")
    net = SimpleMLPModel(use_adapt=False)
    print(f"Model architecture:\n{net}")

    try:
        best_acc = train_model(net, cpu, train_loader, test_loader, epochs=10)
        print(f"Standard PyTorch model achieved {best_acc:.2f}% accuracy")
    except Exception as err:
        # Top-level boundary: report the failure instead of crashing the run.
        print(f"Error during training: {err}")


# Entry point: run training only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()


PyTorch version: 2.5.1+cu124
Loading datasets...
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]


Extracting ./datasets/mnist_data/data/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]


Extracting ./datasets/mnist_data/data/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]


Extracting ./datasets/mnist_data/data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 8.66MB/s]


Extracting ./datasets/mnist_data/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/mnist_data/data/MNIST/raw

Datasets loaded. Train size: 60000, Test size: 10000
Using device: cpu

=== Training Standard PyTorch MLP ===
Initializing network with standard PyTorch layers
Model architecture:
SimpleMLPModel(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
Epoch 1/10 started
  Batch 100: Loss 2.1969, Accuracy 63.95%
  Batch 200: Loss 1.5984, Accuracy 71.49%
  Batch 300: Loss 1.3548, Accuracy 74.72%
  Batch 400: Loss 1.2206, Accuracy 76.70%
  Batch 500: Loss 1.1288, Accuracy 78.06%
  Batch 600: Loss 1.0615, Accuracy 79.04%
  Batch 700: Loss 1.0112, Accuracy 79.73%
  Batch 800: Loss 0.9688, Accuracy 80.31%
  Batch 900: Loss 0.9309, Accuracy 80.84%
Epoch 1 completed in 12.10s. Training Loss: 0.9202, Accuracy: 81.01%
Test Los