In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np 
import matplotlib.pyplot as plt

In [16]:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

In [3]:
# MNIST train/test splits, downloaded into ./data if missing and converted with ToTensor.
train_ds, val_ds = (
    datasets.MNIST(root='./data', train=is_train_split, download=True, transform=transforms.ToTensor())
    for is_train_split in (True, False)
)

In [4]:
# Class label names reported by the training dataset object.
classes = train_ds.classes

In [5]:
# Mini-batch size shared by both loaders (7500 training batches per epoch at 8).
BATCH_SIZE = 8
# Training batches are reshuffled every epoch. The validation loader keeps a fixed
# order: accuracy/loss averaged over the whole split do not depend on sample order,
# and shuffling it would only consume extra draws from torch's global RNG.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
train_dl  # notebook display of the DataLoader; its repr only confirms the object was created

<torch.utils.data.dataloader.DataLoader at 0x1a2439eb820>

In [7]:
import torch.nn.functional as F

In [8]:
# Define CNN model
class CNN(nn.Module):
    """Two-stage convolutional classifier producing 10 class logits.

    Shape walk-through for an ``(N, 1, 28, 28)`` input (3x3 convolutions
    without padding, 2x2 max-pools with stride 2):

        conv1   -> (N, 16, 26, 26) -> pool1   -> (N, 16, 13, 13)
        conv2   -> (N, 32, 11, 11) -> pool2   -> (N, 32, 5, 5)
        flatten -> (N, 800)        -> linear1 -> (N, 64) -> linear2 -> (N, 10)

    ``forward`` returns raw logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # padding='same' (available since PyTorch 1.9) would keep 28x28 through
        # both convolutions, leaving a 7x7 map after pool2 -- linear1 would then
        # need 32*7*7 input features instead of 32*5*5.
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(32 * 5 * 5, 64)  # 32 channels x 5x5 map after pool2
        self.linear2 = nn.Linear(64, 10)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 1, 28, 28) to logits of shape (N, 10)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. view(-1, 800) would
        # silently fold a larger-than-expected feature map into extra "batch"
        # rows; flattening per sample makes a size mismatch fail in linear1.
        x = torch.flatten(x, 1)
        x = F.relu(self.linear1(x))
        return self.linear2(x)

In [9]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [10]:
# Seed torch's global RNG right before the first random operation. This makes
# weight initialisation repeatable on a fresh "Restart & Run All". DataLoader
# shuffling (no explicit generator is passed) also draws from this RNG.
# NOTE(review): some CUDA kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)
model = CNN().to(device)

In [11]:
# Cross-entropy on the raw logits returned by CNN.forward (no softmax there), optimised with Adam.
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

#### Training with Ignite

In [12]:
# Ignite training engine built from the model, optimizer and loss, running on `device`.
trainer = create_supervised_trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=criterion,
    device=device,
)

In [13]:
# Metrics reported as state.metrics["accuracy"] / state.metrics["cross_entropy"].
val_metrics = dict(accuracy=Accuracy(), cross_entropy=Loss(criterion))
# Two evaluators with identical configuration: the first is run on train_dl, the second on val_dl.
# NOTE(review): both share the same Metric instances; fine while they run one after
# the other (as the epoch handlers do) -- confirm before running them concurrently.
train_evaluator, val_evaluator = (
    create_supervised_evaluator(model, metrics=val_metrics, device=device)
    for _ in range(2)
)

In [15]:
# Per-epoch metric logs appended to by the EPOCH_COMPLETED handlers; every list is a distinct object.
training_history = {metric: [] for metric in ('accuracy', 'loss')}
validation_history = {metric: [] for metric in ('accuracy', 'loss')}
last_epoch = list()

In [18]:
def _evaluate_and_log(evaluator, loader, history, label, epoch):
    """Run ``evaluator`` over ``loader``, append its metrics to ``history`` and print them.

    Args:
        evaluator: Ignite evaluator exposing 'accuracy' and 'cross_entropy' metrics.
        loader: data loader to evaluate on.
        history: dict with 'accuracy' and 'loss' lists to append to.
        label: prefix for the printed line ("Training" / "Validation").
        epoch: trainer epoch number to print.

    Returns:
        The recorded ``(accuracy_percent, loss)`` pair.
    """
    evaluator.run(loader)
    metrics = evaluator.state.metrics
    accuracy = metrics['accuracy'] * 100  # stored and printed as a percentage
    loss = metrics['cross_entropy']
    history['accuracy'].append(accuracy)
    history['loss'].append(loss)
    print("{} Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(label, epoch, accuracy, loss))
    return accuracy, loss


# NOTE(review): re-executing this cell registers the handlers a second time
# (duplicate evaluation and printing) -- restart the kernel or re-create `trainer` first.
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    """After each training epoch, evaluate on train_dl and record/print the results."""
    _evaluate_and_log(train_evaluator, train_dl, training_history, "Training", trainer.state.epoch)
    # NOTE(review): appends a constant 0, so only len(last_epoch) carries information;
    # kept for compatibility -- confirm the intended value (perhaps trainer.state.epoch).
    last_epoch.append(0)


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    """After each training epoch (after the training log), evaluate on val_dl and record/print."""
    _evaluate_and_log(val_evaluator, val_dl, validation_history, "Validation", trainer.state.epoch)

<ignite.engine.events.RemovableEventHandle at 0x1a25224f7f0>

In [19]:
# Three full passes over train_dl. The EPOCH_COMPLETED handlers print per-epoch
# train/validation metrics, and the returned engine State is displayed as cell output.
trainer.run(data=train_dl, max_epochs=3)

Training Results - Epoch: 1  Avg accuracy: 98.58 Avg loss: 0.04
Validation Results - Epoch: 1  Avg accuracy: 98.46 Avg loss: 0.05
Training Results - Epoch: 2  Avg accuracy: 98.97 Avg loss: 0.03
Validation Results - Epoch: 2  Avg accuracy: 98.63 Avg loss: 0.04
Training Results - Epoch: 3  Avg accuracy: 99.14 Avg loss: 0.03
Validation Results - Epoch: 3  Avg accuracy: 98.56 Avg loss: 0.04


State:
	iteration: 22500
	epoch: 3
	epoch_length: 7500
	max_epochs: 3
	output: 0.0017028088914230466
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>