In [4]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from model import Net
from utils import get_data_loaders

In [5]:
def train(model, device, train_loader, optimizer, criterion):
    # Training logic
    model.train()
    pbar = tqdm(train_loader)
    train_loss = 0
    correct = 0
    processed = 0
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')
    train_loss /= len(train_loader.dataset)
    train_acc = 100. * correct / len(train_loader.dataset)
    return train_loss, train_acc


In [6]:
def test(model, device, test_loader, criterion):
    # Testing logic
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_acc = 100. * correct / len(test_loader.dataset)
    return test_loss, test_acc

In [7]:
def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = F.nll_loss
    num_epochs = 20
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []
    for epoch in range(1, num_epochs + 1):
        print(f'Epoch {epoch}')
        train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
        test_loss, test_acc = test(model, device, test_loader, criterion)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)
        print(f'Train Loss: {train_loss:.6f}, Train Accuracy: {train_acc:.2f}')
        print(f'Test Loss: {test_loss:.6f}, Test Accuracy: {test_acc:.2f}')

if __name__ == '__main__':
    main()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 135773384.77it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 37491703.44it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 44792332.83it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 19659988.41it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw

Epoch 1


Loss=0.4235957860946655 Batch_id=117 Accuracy=43.77: 100%|██████████| 118/118 [00:21<00:00,  5.55it/s]


Train Loss: 0.003493, Train Accuracy: 43.77
Test Loss: 0.000631, Test Accuracy: 91.19
Epoch 2


Loss=0.18294911086559296 Batch_id=117 Accuracy=93.08: 100%|██████████| 118/118 [00:22<00:00,  5.32it/s]


Train Loss: 0.000445, Train Accuracy: 93.08
Test Loss: 0.000185, Test Accuracy: 97.18
Epoch 3


Loss=0.11866804212331772 Batch_id=117 Accuracy=96.28: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s]


Train Loss: 0.000240, Train Accuracy: 96.28
Test Loss: 0.000133, Test Accuracy: 98.01
Epoch 4


Loss=0.13326585292816162 Batch_id=117 Accuracy=97.00: 100%|██████████| 118/118 [00:20<00:00,  5.83it/s]


Train Loss: 0.000191, Train Accuracy: 97.00
Test Loss: 0.000097, Test Accuracy: 98.50
Epoch 5


Loss=0.05084845423698425 Batch_id=117 Accuracy=97.55: 100%|██████████| 118/118 [00:20<00:00,  5.63it/s]


Train Loss: 0.000154, Train Accuracy: 97.55
Test Loss: 0.000096, Test Accuracy: 98.43
Epoch 6


Loss=0.035147231072187424 Batch_id=117 Accuracy=97.80: 100%|██████████| 118/118 [00:20<00:00,  5.65it/s]


Train Loss: 0.000139, Train Accuracy: 97.80
Test Loss: 0.000076, Test Accuracy: 98.79
Epoch 7


Loss=0.1162460520863533 Batch_id=117 Accuracy=98.01: 100%|██████████| 118/118 [00:20<00:00,  5.68it/s]


Train Loss: 0.000125, Train Accuracy: 98.01
Test Loss: 0.000109, Test Accuracy: 98.23
Epoch 8


Loss=0.07338976114988327 Batch_id=117 Accuracy=98.14: 100%|██████████| 118/118 [00:21<00:00,  5.62it/s]


Train Loss: 0.000117, Train Accuracy: 98.14
Test Loss: 0.000073, Test Accuracy: 98.82
Epoch 9


Loss=0.014971748925745487 Batch_id=117 Accuracy=98.33: 100%|██████████| 118/118 [00:21<00:00,  5.56it/s]


Train Loss: 0.000102, Train Accuracy: 98.33
Test Loss: 0.000069, Test Accuracy: 98.85
Epoch 10


Loss=0.019410019740462303 Batch_id=117 Accuracy=98.46: 100%|██████████| 118/118 [00:20<00:00,  5.64it/s]


Train Loss: 0.000099, Train Accuracy: 98.46
Test Loss: 0.000056, Test Accuracy: 99.01
Epoch 11


Loss=0.008957958780229092 Batch_id=117 Accuracy=98.54: 100%|██████████| 118/118 [00:20<00:00,  5.63it/s]


Train Loss: 0.000092, Train Accuracy: 98.54
Test Loss: 0.000053, Test Accuracy: 99.02
Epoch 12


Loss=0.058468226343393326 Batch_id=117 Accuracy=98.70: 100%|██████████| 118/118 [00:20<00:00,  5.64it/s]


Train Loss: 0.000083, Train Accuracy: 98.70
Test Loss: 0.000052, Test Accuracy: 99.08
Epoch 13


Loss=0.08984129875898361 Batch_id=117 Accuracy=98.72: 100%|██████████| 118/118 [00:20<00:00,  5.88it/s]


Train Loss: 0.000080, Train Accuracy: 98.72
Test Loss: 0.000049, Test Accuracy: 99.14
Epoch 14


Loss=0.010929272510111332 Batch_id=117 Accuracy=98.75: 100%|██████████| 118/118 [00:19<00:00,  5.93it/s]


Train Loss: 0.000078, Train Accuracy: 98.75
Test Loss: 0.000050, Test Accuracy: 99.08
Epoch 15


Loss=0.005710983648896217 Batch_id=117 Accuracy=98.72: 100%|██████████| 118/118 [00:21<00:00,  5.57it/s]


Train Loss: 0.000078, Train Accuracy: 98.72
Test Loss: 0.000050, Test Accuracy: 99.10
Epoch 16


Loss=0.03275136649608612 Batch_id=117 Accuracy=98.87: 100%|██████████| 118/118 [00:23<00:00,  5.07it/s]


Train Loss: 0.000070, Train Accuracy: 98.87
Test Loss: 0.000050, Test Accuracy: 99.10
Epoch 17


Loss=0.05980045720934868 Batch_id=117 Accuracy=98.93: 100%|██████████| 118/118 [00:21<00:00,  5.61it/s]


Train Loss: 0.000066, Train Accuracy: 98.93
Test Loss: 0.000053, Test Accuracy: 99.14
Epoch 18


Loss=0.01932060904800892 Batch_id=117 Accuracy=98.97: 100%|██████████| 118/118 [00:21<00:00,  5.53it/s]


Train Loss: 0.000065, Train Accuracy: 98.97
Test Loss: 0.000050, Test Accuracy: 99.24
Epoch 19


Loss=0.003583707846701145 Batch_id=117 Accuracy=98.97: 100%|██████████| 118/118 [00:19<00:00,  5.99it/s]


Train Loss: 0.000064, Train Accuracy: 98.97
Test Loss: 0.000045, Test Accuracy: 99.21
Epoch 20


Loss=0.04729605093598366 Batch_id=117 Accuracy=99.00: 100%|██████████| 118/118 [00:20<00:00,  5.68it/s]


Train Loss: 0.000059, Train Accuracy: 99.00
Test Loss: 0.000044, Test Accuracy: 99.23
