### 参考

- https://imagingsolution.net/deep-learning/pytorch/pytorch_mnist_sample_program/

In [1]:
# Select the compute device used by every later cell.
# Prefer DirectML when torch-directml is installed; otherwise fall back to the
# CPU (with a visible warning) so the notebook still survives Restart & Run All.
# NOTE: the name `dml` is kept for compatibility with later cells even when it
# refers to the CPU.
try:
    from torch_directml import device

    dml = device()
except ImportError:
    import warnings

    import torch

    warnings.warn("torch_directml is not installed; falling back to CPU.")
    dml = torch.device("cpu")

In [2]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Settings shared by both MNIST splits.
DATA_ROOT = './data'
BATCH_SIZE = 100

# Only ToTensor is applied (no Normalize step).
transform = transforms.Compose([transforms.ToTensor()])

# The train and test splits differ only in the `train` flag.
_mnist_common = dict(root=DATA_ROOT, download=True, transform=transform)
trainset = datasets.MNIST(train=True, **_mnist_common)
testset = datasets.MNIST(train=False, **_mnist_common)

# Shuffle the training data each epoch; keep the test order fixed.
train_dataloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
import torch
import torch.nn.functional as F
from torch import Tensor, nn

class MLPModel(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The defaults reproduce the original setup (784 inputs, 100 hidden units,
    10 output logits). The layer attribute names ``fc1``/``fc2`` are part of the
    ``state_dict`` keys, so they are kept to stay compatible with saved weights.

    Args:
        in_features: size of each (already flattened) input vector.
        hidden_features: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(
        self,
        in_features: int = 784,
        hidden_features: int = 100,
        num_classes: int = 10,
    ) -> None:
        super().__init__()

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return raw logits (no softmax); the last dim of ``x`` must equal ``in_features``."""
        x = self.fc1(x)
        x = F.relu(x)
        return self.fc2(x)
    
# Seed torch's global RNG so weight initialisation (and shuffling) is reproducible.
torch.manual_seed(0)

model = MLPModel().to(dml)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()

for i in range(10):
    loss_sum = 0

    for inputs, labels in train_dataloader:
        inputs = inputs.to(dml)
        labels = labels.to(dml)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass: reshape the batch to (-1, 784) for the fully connected model
        inputs = inputs.view(-1, 28 * 28)
        outputs = model(inputs)

        # Loss. Accumulate a detached copy: adding the tracked tensor itself would
        # chain every batch's autograd graph onto `loss_sum` for the whole epoch.
        loss = criterion(outputs, labels)
        loss_sum += loss.detach()

        # Backward pass (compute gradients)
        loss.backward()

        # Parameter update
        optimizer.step()

    print(f'Epoch: {i + 1}, Loss: {loss_sum.item() / len(train_dataloader)}')
    # Checkpoint after every epoch (overwrites the same file).
    torch.save(model.state_dict(), 'model_weights.pth')

Epoch: 1, Loss: 0.4004826863606771
Epoch: 2, Loss: 0.19105876922607423
Epoch: 3, Loss: 0.14109102884928384
Epoch: 4, Loss: 0.1097244135538737
Epoch: 5, Loss: 0.09047023773193359
Epoch: 6, Loss: 0.07557068506876628
Epoch: 7, Loss: 0.06330962498982748
Epoch: 8, Loss: 0.05446879704793294
Epoch: 9, Loss: 0.04681886037190755
Epoch: 10, Loss: 0.04070676485697428


In [9]:
# Evaluate the trained model on the test split.
# The training cell left the model in train mode; eval() disables training-only
# behaviour (dropout / batch-norm updates). MLPModel currently has none, but
# this keeps the evaluation correct if such layers are added.
model.eval()

loss_sum = 0.0
correct = 0

with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs = inputs.to(dml)
        labels = labels.to(dml)

        inputs = inputs.view(-1, 28 * 28)
        outputs = model(inputs)

        # criterion returns the batch mean; weight it by batch size so the final
        # average is a true per-sample mean even if the last batch is smaller.
        loss_sum += criterion(outputs, labels).item() * labels.size(0)
        pred = outputs.argmax(1)
        correct += pred.eq(labels.view_as(pred)).sum().item()

print(f'Loss: {loss_sum / len(testset)}, Acc: {correct / len(testset)}')

Loss: 0.07903732359409332, Acc: 0.9755
