In [1]:
from argparse import ArgumentParser

from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import MNIST

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

from tqdm import tqdm

In [2]:
def save_model(model, PATH):
    torch.save(model.state_dict, PATH)
    
def load_model(model, PATH):
    model.load_state_dict(torch.load(PATH))
    model.eval()

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        
        return F.log_softmax(x, dim=-1)

In [4]:
def get_data_loader(train_batch_size, test_batch_size):
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    
    train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, 
                                    train=True), batch_size=train_batch_size, shuffle=True)
    
    test_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform,
                                  train=True), batch_size=test_batch_size, shuffle=True)
    
    return train_loader, test_loader

In [5]:
def run(train_batch_size, test_batch_size, epochs, lr, momentum, log_interval):
    
    train_loader, test_loader = get_data_loader(train_batch_size, test_batch_size)
    
    model = Net()
    device = 'cpu'
    
    #if torch.cuda.is_available:
    #    device = 'cuda'
    #print('Device: {}'.format(device))
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    
    #Здесь начинается магия движка ignite
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model, 
                                           metrics={'accuracy': Accuracy(),
                                                    'nll': Loss(F.nll_loss)},
                                           device=device)
    
    desc = "Iter - loss : {:.2f}"
    pbar = tqdm(initial=0, leave=False, total=(len(train_loader)), desc=desc.format(0))
    
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)
    
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll)
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
    
    return model

In [6]:
model = run(train_batch_size=64, test_batch_size=1000, epochs=10, lr=0.01, momentum=0.5, log_interval=10)

Iter - loss : 0.25:  99%|█████████▉| 930/938 [00:48<00:00, 30.33it/s]

Training Results - Epoch: 1  Avg accuracy: 0.95 Avg loss: 0.15


Iter - loss : 0.25:  99%|█████████▉| 930/938 [01:04<00:00, 30.33it/s]

Validation Results - Epoch: 1  Avg accuracy: 0.95 Avg loss: 0.15


Iter - loss : 0.11:  99%|█████████▉| 930/938 [01:54<00:00, 27.67it/s]

Training Results - Epoch: 2  Avg accuracy: 0.97 Avg loss: 0.10


Iter - loss : 0.11:  99%|█████████▉| 930/938 [02:11<00:00, 27.67it/s]

Validation Results - Epoch: 2  Avg accuracy: 0.97 Avg loss: 0.10


Iter - loss : 0.10:  99%|█████████▉| 930/938 [03:03<00:00, 28.69it/s]

Training Results - Epoch: 3  Avg accuracy: 0.98 Avg loss: 0.07


Iter - loss : 0.10:  99%|█████████▉| 930/938 [03:20<00:00, 28.69it/s]

Validation Results - Epoch: 3  Avg accuracy: 0.98 Avg loss: 0.07


Iter - loss : 0.15:  99%|█████████▉| 930/938 [04:09<00:00, 29.97it/s]

Training Results - Epoch: 4  Avg accuracy: 0.98 Avg loss: 0.07


Iter - loss : 0.15:  99%|█████████▉| 930/938 [04:25<00:00, 29.97it/s]

Validation Results - Epoch: 4  Avg accuracy: 0.98 Avg loss: 0.07


Iter - loss : 0.06:  99%|█████████▉| 930/938 [05:12<00:00, 30.42it/s]

Training Results - Epoch: 5  Avg accuracy: 0.98 Avg loss: 0.05


Iter - loss : 0.06:  99%|█████████▉| 930/938 [05:28<00:00, 30.42it/s]

Validation Results - Epoch: 5  Avg accuracy: 0.98 Avg loss: 0.05


Iter - loss : 0.04:  99%|█████████▉| 930/938 [06:17<00:00, 30.50it/s]

Training Results - Epoch: 6  Avg accuracy: 0.98 Avg loss: 0.05


Iter - loss : 0.04:  99%|█████████▉| 930/938 [06:34<00:00, 30.50it/s]

Validation Results - Epoch: 6  Avg accuracy: 0.98 Avg loss: 0.05


Iter - loss : 0.09:  99%|█████████▉| 930/938 [07:22<00:00, 30.07it/s]

Training Results - Epoch: 7  Avg accuracy: 0.99 Avg loss: 0.04


Iter - loss : 0.09:  99%|█████████▉| 930/938 [07:39<00:00, 30.07it/s]

Validation Results - Epoch: 7  Avg accuracy: 0.99 Avg loss: 0.04


Iter - loss : 0.05:  99%|█████████▉| 930/938 [08:33<00:00, 25.95it/s]

Training Results - Epoch: 8  Avg accuracy: 0.99 Avg loss: 0.04


Iter - loss : 0.05:  99%|█████████▉| 930/938 [08:50<00:00, 25.95it/s]

Validation Results - Epoch: 8  Avg accuracy: 0.99 Avg loss: 0.04


Iter - loss : 0.04:  99%|█████████▉| 930/938 [09:42<00:00, 28.67it/s]

Training Results - Epoch: 9  Avg accuracy: 0.99 Avg loss: 0.03


Iter - loss : 0.04:  99%|█████████▉| 930/938 [10:01<00:00, 28.67it/s]

Validation Results - Epoch: 9  Avg accuracy: 0.99 Avg loss: 0.03


Iter - loss : 0.08:  99%|█████████▉| 930/938 [10:47<00:00, 32.81it/s]

Training Results - Epoch: 10  Avg accuracy: 0.99 Avg loss: 0.03


                                                                     

Validation Results - Epoch: 10  Avg accuracy: 0.99 Avg loss: 0.03




In [8]:
save_model(model, "state_model")

  "type " + obj.__name__ + ". It won't be checked "
