In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor()
    ])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [3]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.net(x)

In [4]:
device = torch.device("cuda")
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

In [5]:
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)

        # прямой проход
        scores = model(data)
        loss = criterion(scores, targets)

        # обратный проход
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Тестирование модели
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Test Accuracy: {accuracy:.2f}%")


Epoch [1/5], Loss: 0.1156, Test Accuracy: 94.47%
Epoch [2/5], Loss: 0.0405, Test Accuracy: 95.94%
Epoch [3/5], Loss: 0.0351, Test Accuracy: 96.37%
Epoch [4/5], Loss: 0.1275, Test Accuracy: 97.12%
Epoch [5/5], Loss: 0.2003, Test Accuracy: 97.40%
