In [4]:
from torch import nn, optim, Tensor

In [5]:
class MLPModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 10)

        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.parameters())

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)

        return x

In [6]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

trainset = datasets.MNIST(
    root='./data', 
    train=True, 
    download=True, 
    transform=transforms.ToTensor()
)

testset = datasets.MNIST(
    root='./data', 
    train=False,
    download=True, 
    transform=transforms.ToTensor()
)

train_loader = DataLoader(
    trainset, 
    batch_size=100, 
    shuffle=True, 
    num_workers=2
)

test_loader = DataLoader(
    testset, 
    batch_size=100, 
    shuffle=False, 
    num_workers=2
)

In [7]:
import torch

def train(model: nn.Module, train_loader: DataLoader):
    model.train()

    total_count = 0
    total_correct = 0
    total_loss = 0

    for batch_imgs, batch_labels in train_loader:
        batch_imgs = batch_imgs.reshape(-1, 28 * 28)
        labels = torch.eye(10)[batch_labels]

        outputs = model(batch_imgs)

        # 勾配の計算
        model.optimizer.zero_grad()

        loss = model.criterion(outputs, labels)
        loss.backward()

        # 最適化
        model.optimizer.step()
        

        # 正答率、損失の記録
        _, predicts = torch.max(outputs, 1)
        
        for i in range(len(batch_labels)):
            total_count += 1
            
            if predicts[i] == batch_labels[i]:
                total_correct += 1
                continue
            
            total_loss += loss.item()

    accuracy = total_correct / total_count
    loss = total_loss / total_count
    
    return accuracy, loss

In [8]:
model = MLPModel()

acc, loss = train(model, train_loader)

print(f'Accuracy: {acc * 100 :.4F}')
print(f'Loss: {loss:.4F}')

Accuracy: 94.5983
Loss: 0.0014
