## ignite

- Documentation: https://docs.pytorch.org/ignite/
- Source code: https://github.com/pytorch/ignite

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss, ConfusionMatrix, Precision, Recall
from ignite.contrib.handlers import ProgressBar

In [5]:
# Run configuration: everything a reader might want to tune lives here.

# Seed PyTorch's global RNG so that weight initialisation and DataLoader
# shuffling are repeatable across "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64        # samples per mini-batch for both loaders
learning_rate = 0.001  # step size passed to the optimizer
num_epochs = 5         # number of passes over the training set
num_classes = 10       # size of the model's output layer / confusion matrix

In [6]:
# Preprocessing shared by both splits: convert to a tensor, then normalise the
# single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std - confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# MNIST is downloaded into ./data on first use; train=True selects the
# training split and train=False the split used here for validation.
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
val_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# A two-layer fully connected classifier over flattened 28x28 inputs.
input_features = 28 * 28
hidden_units = 128

layers = [
    nn.Flatten(),
    nn.Linear(input_features, hidden_units),
    nn.ReLU(),
    nn.Linear(hidden_units, num_classes),
]
model = nn.Sequential(*layers)
model.to(device)  # nn.Module.to moves the parameters in place

# Cross-entropy on the raw outputs of the last linear layer, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)


In [8]:
# ----------------------------
# Define the training Engine
# ----------------------------
def train_step(engine, batch):
    """Run one optimisation step on a single mini-batch.

    Args:
        engine: The ``Engine`` calling this function (not used here).
        batch: An ``(inputs, targets)`` pair taken from the data loader.

    Returns:
        tuple: ``(outputs, targets)`` where ``outputs`` are the model's
        predictions for the batch, detached from the autograd graph, and
        ``targets`` are the labels moved to ``device``. This is the pair the
        metrics attached to ``trainer`` consume.
    """
    model.train()
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    # The backward pass is finished, so the graph is no longer needed. Returning
    # the detached tensor keeps the engine's stored output from holding the
    # graph alive and stops metrics from building a new gradient-tracked loss
    # on every iteration.
    return outputs.detach(), targets

trainer = Engine(train_step)

In [9]:
# ----------------------------
# Define the evaluation Engine
# ----------------------------
@torch.no_grad()
def eval_step(engine, batch):
    """Compute predictions for one batch without tracking gradients.

    Args:
        engine: The ``Engine`` calling this function (not used here).
        batch: An ``(inputs, targets)`` pair taken from the data loader.

    Returns:
        tuple: ``(outputs, targets)`` with both tensors on ``device``.
    """
    model.eval()
    features, labels = (tensor.to(device) for tensor in batch)
    predictions = model(features)
    return predictions, labels

evaluator = Engine(eval_step)

In [10]:
# ----------------------------
# Attach metrics
# ----------------------------
# Metrics tracked on the training engine.
train_accuracy = Accuracy()
train_accuracy.attach(trainer, "accuracy")
train_loss = Loss(criterion)
train_loss.attach(trainer, "loss")

# Metrics tracked on the evaluation engine, keyed by the name under which
# each one is attached.
metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(criterion),
    "precision": Precision(average=True),
    "recall": Recall(average=True),
    "confusion_matrix": ConfusionMatrix(num_classes=num_classes),
}

for metric_name in metrics:
    metrics[metric_name].attach(evaluator, metric_name)

In [11]:
# ----------------------------
# Progress bar
# ----------------------------
# Attach a progress bar to the training engine.
progress_bar = ProgressBar()
progress_bar.attach(trainer)

In [12]:
# ----------------------------
# Evaluate at the end of every epoch
# ----------------------------
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    """Run the evaluator on the validation loader and print its metrics.

    Args:
        engine: The training engine that fired the event; its current epoch
            number is included in the printed header.
    """
    # Engine.run returns the engine's state, which carries the computed metrics.
    metrics = evaluator.run(val_loader).metrics
    report_lines = [
        f"\nValidation Results - Epoch: {engine.state.epoch}",
        f" Accuracy: {metrics['accuracy']:.4f}",
        f" Precision: {metrics['precision']:.4f}",
        f" Recall: {metrics['recall']:.4f}",
        f" Loss: {metrics['loss']:.4f}",
        f" Confusion Matrix:\n{metrics['confusion_matrix'].cpu().numpy()}",
    ]
    print("\n".join(report_lines))


In [13]:
# ----------------------------
# Run training
# ----------------------------
# Left as a bare expression so the cell displays the final engine state.
trainer.run(data=train_loader, max_epochs=num_epochs)

[1/938]   0%|           [00:00<?]


Validation Results - Epoch: 1
 Accuracy: 0.9606
 Precision: 0.9608
 Recall: 0.9601
 Loss: 0.1347
 Confusion Matrix:
[[ 968    0    0    2    0    2    5    1    2    0]
 [   0 1124    2    2    0    0    3    0    4    0]
 [   9    1  989    9    2    1    2    6   13    0]
 [   1    0    7  981    0    3    0    5   11    2]
 [   2    0   10    1  931    0    8    3    7   20]
 [   3    2    0   22    0  844    9    0    9    3]
 [   7    3    0    0    3    9  930    1    5    0]
 [   1   11   15   10    1    1    0  974    4   11]
 [   4    2    5   19    4    4    7    5  923    1]
 [   4    8    2   15   14    4    2    7   11  942]]


[1/938]   0%|           [00:00<?]


Validation Results - Epoch: 2
 Accuracy: 0.9649
 Precision: 0.9649
 Recall: 0.9649
 Loss: 0.1053
 Confusion Matrix:
[[ 966    0    0    1    0    2    6    2    3    0]
 [   0 1119    3    1    0    1    4    0    7    0]
 [   9    1  980    2    7    1    3    9   20    0]
 [   1    0    2  941    0   38    0    5   14    9]
 [   1    0    1    0  932    1    6    3    6   32]
 [   3    0    0    1    0  871    4    1    9    3]
 [   4    4    0    0    3   11  929    1    6    0]
 [   1    7    8    1    2    0    0  984    4   21]
 [   5    0    2    1    1    9    1    3  948    4]
 [   4    6    0    4    6    4    0    3    3  979]]


[1/938]   0%|           [00:00<?]


Validation Results - Epoch: 3
 Accuracy: 0.9734
 Precision: 0.9733
 Recall: 0.9733
 Loss: 0.0851
 Confusion Matrix:
[[ 966    0    3    1    1    2    3    2    2    0]
 [   0 1118    4    1    0    0    5    1    6    0]
 [   5    1  994    1    5    0    3    6   17    0]
 [   0    0    4  981    0   12    0    4    6    3]
 [   1    0    5    0  955    0    6    1    3   11]
 [   3    0    0    6    0  866    4    1   10    2]
 [   4    4    0    0    5    6  933    0    6    0]
 [   1    6   11    6    1    0    1  987    3   12]
 [   2    0    2    4    2    2    2    3  953    4]
 [   3    2    0    3    8    3    0    3    6  981]]


[1/938]   0%|           [00:00<?]


Validation Results - Epoch: 4
 Accuracy: 0.9742
 Precision: 0.9739
 Recall: 0.9741
 Loss: 0.0852
 Confusion Matrix:
[[ 966    0    2    1    1    4    5    1    0    0]
 [   0 1127    1    2    0    0    2    0    3    0]
 [   5    5 1005    4    2    0    3    4    4    0]
 [   0    0    1  974    0   25    0    3    2    5]
 [   1    0    2    1  940    2    6    1    3   26]
 [   2    0    0    1    0  883    3    0    2    1]
 [   4    3    0    1    3    7  940    0    0    0]
 [   0    5   10    4    1    0    0  996    2   10]
 [   3    0    3    8    2   15    2    4  932    5]
 [   2    2    1    4    3   12    1    3    2  979]]


[1/938]   0%|           [00:00<?]


Validation Results - Epoch: 5
 Accuracy: 0.9751
 Precision: 0.9749
 Recall: 0.9749
 Loss: 0.0904
 Confusion Matrix:
[[ 969    0    0    2    1    3    3    1    1    0]
 [   0 1123    3    2    0    0    2    2    3    0]
 [   9    0  997    2    2    1    3   14    3    1]
 [   2    0    0  994    0    9    0    3    1    1]
 [   1    0    3    1  947    2    5    3    1   19]
 [   2    0    0    5    0  881    3    0    0    1]
 [   4    2    0    1    1   16  932    0    2    0]
 [   2    0    5    2    0    0    0 1011    0    8]
 [   7    0    1   12    1   25    1    7  907   13]
 [   0    2    0    4    2    2    2    7    0  990]]


State:
	iteration: 4690
	epoch: 5
	epoch_length: 938
	max_epochs: 5
	output: <class 'tuple'>
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>