In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path

In [3]:
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        in_features: Number of values per flattened sample (default ``28 * 28``).
        hidden_features: Width of the hidden layer (default 128).
        num_classes: Number of output logits (default 10).

    The layers keep the attribute names ``fc1`` and ``fc2`` so that
    ``state_dict`` checkpoints saved from the default configuration still load.
    """

    def __init__(self, in_features: int = 28 * 28, hidden_features: int = 128, num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and return the raw logits."""
        # reshape() instead of view(): view() raises RuntimeError when the input
        # is not contiguous (e.g. after transpose/permute), reshape() copies
        # only when it has to and is identical for contiguous inputs.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [4]:
# Collect every saved sample file (*.pt) in the samples directory, ordered by
# path so the listing is stable between runs.
dataset_path = Path("mnist_samples")
dataset_path_files = sorted(dataset_path.glob("*.pt"))
dataset_path_files

[PosixPath('mnist_samples/mnist_label_0.pt'),
 PosixPath('mnist_samples/mnist_label_1.pt'),
 PosixPath('mnist_samples/mnist_label_2.pt'),
 PosixPath('mnist_samples/mnist_label_3.pt'),
 PosixPath('mnist_samples/mnist_label_4.pt'),
 PosixPath('mnist_samples/mnist_label_5.pt'),
 PosixPath('mnist_samples/mnist_label_6.pt'),
 PosixPath('mnist_samples/mnist_label_7.pt'),
 PosixPath('mnist_samples/mnist_label_8.pt'),
 PosixPath('mnist_samples/mnist_label_9.pt')]

In [5]:
# Pair each sample file with a checkpoint path: the same file name prefixed
# with 'pretrained_', placed under the 'pretrained_models' directory.
saved_models = [
    Path('pretrained_models').joinpath('pretrained_' + sample_file.name)
    for sample_file in dataset_path_files
]
saved_models

[PosixPath('pretrained_models/pretrained_mnist_label_0.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_1.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_2.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_3.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_4.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_5.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_6.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_7.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_8.pt'),
 PosixPath('pretrained_models/pretrained_mnist_label_9.pt')]

In [10]:
def train_model(
    dataset_file_path: Path,
    saved_model_path: Path,
    epochs: int = 1000,
    batch_size: int = 64,
    lr: float = 0.1,
    log_every: int = 100,
):
    """Train a fresh ``SimpleNN`` on one saved dataset and write its weights to disk.

    Args:
        dataset_file_path: File readable by ``torch.load`` that unpacks into an
            ``(images, labels)`` pair of tensors.
        saved_model_path: Destination for the trained model's ``state_dict``.
            Missing parent directories are created.
        epochs: Number of full passes over the dataset.
        batch_size: Mini-batch size used by the shuffled ``DataLoader``.
        lr: Learning rate for plain SGD.
        log_every: Print the mean batch loss every this many epochs; ``0``
            disables logging.

    Raises:
        ValueError: If the dataset file contains no samples.
    """
    model = SimpleNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # weights_only=True keeps torch.load from unpickling arbitrary objects and
    # matches how test_model reads the same files.
    images, labels = torch.load(dataset_file_path, weights_only=True)

    dataset = TensorDataset(images, labels)
    if len(dataset) == 0:
        raise ValueError(f"No samples found in {dataset_file_path}")

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(train_loader)

    for epoch in range(epochs):
        running_loss = 0.0
        # Distinct batch names so the full-dataset tensors above are not shadowed.
        for batch_images, batch_labels in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_images)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # The average is only needed when it is actually reported.
        if log_every and (epoch + 1) % log_every == 0:
            avg_loss = running_loss / num_batches
            print(f"Epoch {epoch + 1:04d}: Loss = {avg_loss:.6f}")

    # torch.save does not create directories, so make sure the target exists.
    saved_model_path = Path(saved_model_path)
    saved_model_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), str(saved_model_path))


In [12]:
# Train one model per sample file; each run writes its weights to the
# checkpoint path at the same position in saved_models.
for dataset_file_path, saved_model_path in zip(dataset_path_files, saved_models):
    train_model(dataset_file_path=dataset_file_path, saved_model_path=saved_model_path)

  images, labels = torch.load(dataset_file_path)


Epoch 0100: Loss = 0.000087
Epoch 0200: Loss = 0.000047
Epoch 0300: Loss = 0.000026
Epoch 0400: Loss = 0.000020
Epoch 0500: Loss = 0.000015
Epoch 0600: Loss = 0.000012
Epoch 0700: Loss = 0.000011
Epoch 0800: Loss = 0.000009
Epoch 0900: Loss = 0.000008
Epoch 1000: Loss = 0.000007
Epoch 0100: Loss = 0.000117
Epoch 0200: Loss = 0.000058
Epoch 0300: Loss = 0.000035
Epoch 0400: Loss = 0.000026
Epoch 0500: Loss = 0.000020
Epoch 0600: Loss = 0.000017
Epoch 0700: Loss = 0.000015
Epoch 0800: Loss = 0.000013
Epoch 0900: Loss = 0.000011
Epoch 1000: Loss = 0.000010
Epoch 0100: Loss = 0.000098
Epoch 0200: Loss = 0.000046
Epoch 0300: Loss = 0.000030
Epoch 0400: Loss = 0.000021
Epoch 0500: Loss = 0.000017
Epoch 0600: Loss = 0.000014
Epoch 0700: Loss = 0.000012
Epoch 0800: Loss = 0.000010
Epoch 0900: Loss = 0.000009
Epoch 1000: Loss = 0.000008
Epoch 0100: Loss = 0.000090
Epoch 0200: Loss = 0.000042
Epoch 0300: Loss = 0.000027
Epoch 0400: Loss = 0.000020
Epoch 0500: Loss = 0.000016
Epoch 0600: Loss = 0

In [15]:
def test_model(model_path: Path, test_data_path: Path):
    """Score a saved SimpleNN checkpoint on a saved dataset.

    Loads the ``state_dict`` at ``model_path`` into a new ``SimpleNN``, runs it
    on the ``(images, labels)`` pair stored at ``test_data_path``, prints the
    accuracy, and returns it as a percentage.
    """
    # Rebuild the network and restore its parameters in evaluation mode.
    net = SimpleNN()
    state = torch.load(str(model_path), weights_only=True)
    net.load_state_dict(state)
    net.eval()

    test_images, test_labels = torch.load(test_data_path, weights_only=True)

    # Single forward pass over the whole set without tracking gradients; the
    # predicted class is the index of the largest logit in each row.
    with torch.no_grad():
        logits = net(test_images)
        predicted = torch.max(logits, dim=1).indices

    total = test_labels.size(0)
    correct = (predicted == test_labels).sum().item()
    accuracy = 100 * correct / total

    print(f'Accuracy on test data: {accuracy:.2f}%')
    return accuracy

In [17]:
# Score each saved checkpoint against the sample file at the same position.
# NOTE: these are the same files the models were trained on, so the printed
# figure is accuracy on the training data, not on held-out data.
for model_path, test_data_path in zip(saved_models, dataset_path_files):
    test_model(model_path=model_path, test_data_path=test_data_path)


Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
Accuracy on test data: 100.00%
