In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import time

In [18]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081.
# NOTE(review): these look like the usual MNIST statistics — confirm the source.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split: fetched into ./data when missing, served in shuffled batches of 128.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=128)

# Test split: same preprocessing, fixed order so evaluation is repeatable.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=128)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4471381.63it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1533052.72it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2234698.54it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6689090.16it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [21]:
def assessment(model, loader=None):
    """Return the classification accuracy of ``model`` as a percentage.

    Args:
        model: Module that maps a batch of inputs to per-class scores laid out
            along dimension 1.
        loader: Iterable yielding ``(images, labels)`` batches. When omitted,
            the notebook-level ``test_loader`` is used (looked up at call time).

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.

    Side effects:
        The model is switched to evaluation mode and is left in that mode.
    """
    if loader is None:
        loader = test_loader

    # Batches are moved to wherever the model's parameters live, so a model
    # that was placed on an accelerator can be evaluated without a device
    # mismatch. Parameter-free models keep the batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            # Gradients are already disabled here, so the output can be used
            # directly; the predicted class is the index of the largest score.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total

In [22]:
class SimpleNN(nn.Module):
    """Fully connected classifier: flatten -> linear -> ReLU -> linear.

    Args:
        in_features: Number of values per sample after flattening
            (default ``28 * 28``).
        hidden_features: Width of the hidden layer (default ``512``).
        num_classes: Number of output scores per sample (default ``10``).

    The defaults reproduce the original fixed architecture. The layer
    attributes keep the names ``fc1`` and ``fc2`` because ``state_dict`` keys
    are derived from them, and existing checkpoints are loaded by those keys.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        flat = self.flatten(x)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

In [23]:
digit_recognizer = SimpleNN()
# map_location='cpu' lets the checkpoint load on a machine without the device
# it was saved from; load_state_dict then copies the values into the freshly
# built model's own parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (recent PyTorch versions offer weights_only=True for this).
digit_recognizer.load_state_dict(torch.load('digit_recognizer.pth', map_location='cpu'))
print(f'Точность Загруженной FC модели при проверке на тестовом наборе: {assessment(digit_recognizer):.2f}%')

Точность Загруженной FC модели при проверке на тестовом наборе: 94.07%


In [24]:
def fc_learn(fc_model, epochs=5, step=0.1, loader=None):
    """Train ``fc_model`` in place with plain SGD and cross-entropy loss.

    Args:
        fc_model: Module whose parameters are optimised; it is modified in place.
        epochs: Number of full passes over ``loader``.
        step: Learning rate handed to ``torch.optim.SGD``.
        loader: Iterable yielding ``(images, labels)`` batches. When omitted,
            the notebook-level ``train_loader`` is used (looked up at call time).

    Returns:
        None. The trained weights live in ``fc_model``.

    Side effects:
        The model is put into training mode at the start of every epoch and is
        left in that mode afterwards (when ``epochs`` is at least 1).
    """
    if loader is None:
        loader = train_loader

    # Batches are moved to wherever the model's parameters live, so a model
    # placed on an accelerator can be trained without a device mismatch.
    first_param = next(fc_model.parameters(), None)
    device = first_param.device if first_param is not None else None

    fc_criterion = nn.CrossEntropyLoss()
    fc_optimizer = torch.optim.SGD(fc_model.parameters(), lr=step)
    for _ in range(epochs):
        fc_model.train()
        for fc_images, fc_labels in loader:
            if device is not None:
                fc_images, fc_labels = fc_images.to(device), fc_labels.to(device)
            fc_optimizer.zero_grad()  # clear gradients left over from the previous step
            fc_outputs = fc_model(fc_images)  # forward pass
            fc_loss = fc_criterion(fc_outputs, fc_labels)  # compute the loss
            fc_loss.backward()  # backpropagate
            fc_optimizer.step()  # update the weights

In [25]:
# Fine-tune the previously saved model: two more epochs with a step of 0.01.
fc_learn(digit_recognizer, epochs=2, step=0.01)
print(f'Точность Дообученной FC модели при проверке на тестовом наборе: {assessment(digit_recognizer):.2f}%')

Точность Дообученной FC модели при проверке на тестовом наборе: 97.36%
