In [2]:
#hide
! [ -e /content ] && pip install -Uqq fastbook
import fastbook
fastbook.setup_book()

In [3]:
#hide
from fastai.vision.all import *
from fastbook import *

matplotlib.rc('image', cmap='Greys')

In [4]:
import torchvision

In [5]:
from pathlib import Path

In [5]:
testing_folder = Path('../../.fastai/data/mnist_png/testing')
training_folder = Path('../../.fastai/data/mnist_png/training')

[x for x in testing_folder.iterdir() if x.is_dir()]

[Path('../../.fastai/data/mnist_png/testing/5'),
 Path('../../.fastai/data/mnist_png/testing/3'),
 Path('../../.fastai/data/mnist_png/testing/7'),
 Path('../../.fastai/data/mnist_png/testing/4'),
 Path('../../.fastai/data/mnist_png/testing/2'),
 Path('../../.fastai/data/mnist_png/testing/1'),
 Path('../../.fastai/data/mnist_png/testing/9'),
 Path('../../.fastai/data/mnist_png/testing/0'),
 Path('../../.fastai/data/mnist_png/testing/8'),
 Path('../../.fastai/data/mnist_png/testing/6')]

In [6]:
sorted_testing_folder = testing_folder.ls().sorted()
sorted_training_folder = training_folder.ls().sorted()

sorted_testing_folder, sorted_training_folder

((#10) [Path('../../.fastai/data/mnist_png/testing/0'),Path('../../.fastai/data/mnist_png/testing/1'),Path('../../.fastai/data/mnist_png/testing/2'),Path('../../.fastai/data/mnist_png/testing/3'),Path('../../.fastai/data/mnist_png/testing/4'),Path('../../.fastai/data/mnist_png/testing/5'),Path('../../.fastai/data/mnist_png/testing/6'),Path('../../.fastai/data/mnist_png/testing/7'),Path('../../.fastai/data/mnist_png/testing/8'),Path('../../.fastai/data/mnist_png/testing/9')],
 (#10) [Path('../../.fastai/data/mnist_png/training/0'),Path('../../.fastai/data/mnist_png/training/1'),Path('../../.fastai/data/mnist_png/training/2'),Path('../../.fastai/data/mnist_png/training/3'),Path('../../.fastai/data/mnist_png/training/4'),Path('../../.fastai/data/mnist_png/training/5'),Path('../../.fastai/data/mnist_png/training/6'),Path('../../.fastai/data/mnist_png/training/7'),Path('../../.fastai/data/mnist_png/training/8'),Path('../../.fastai/data/mnist_png/training/9')])

In [7]:
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
)

In [8]:
training_dataset = torchvision.datasets.ImageFolder((training_folder).as_posix(), transform = transform)
training_dataset

Dataset ImageFolder
    Number of datapoints: 60000
    Root location: ../../.fastai/data/mnist_png/training
    StandardTransform
Transform: Compose(
               Grayscale(num_output_channels=1)
               ToTensor()
               Normalize(mean=[0.5], std=[0.5])
           )

In [9]:
batchSize = 64
train_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batchSize, shuffle=True)

train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f02d8d33bb0>

In [10]:
testing_dataset = torchvision.datasets.ImageFolder(testing_folder, transform = transform)
testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=batchSize)

testing_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f02d8d33e20>

In [11]:
dataloaders = {
    "train": train_dataloader,
    "validation": testing_dataloader
}

In [12]:
pytorch_net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 50),
    nn.ReLU(),
    nn.Linear(50,30),
    nn.ReLU(),
    nn.Linear(30,10),
    nn.LogSoftmax(dim=1))

In [13]:
# nb_epoch is our number of epochs, meaning the number of complete passes through the training dataset.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
lr = 1e-2
nb_epoch = 15

device

device(type='cuda', index=0)

In [14]:
optimizer = torch.optim.SGD(pytorch_net.parameters(), lr=lr)

criterion = nn.NLLLoss()

In [18]:
def train_model(model, criterion, optimizer, dataloaders, num_epochs=10):
    #liveloss = PlotLosses() Live training plot generic API
    model = model.to(device) # Moves and/or casts the parameters and buffers to device.
    
    for epoch in range(num_epochs): # Number of passes through the entire training & validation datasets
        logs = {}
        for phase in ['train', 'validation']: # First train, then validate
            if phase == 'train':
                model.train() # Set the module in training mode
            else:
                model.eval() # Set the module in evaluation mode

            running_loss = 0.0 # keep track of loss
            running_corrects = 0 # count of carrectly classified inputs

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device) # Perform Tensor device conversion
                labels = labels.to(device)

                outputs = model(inputs) # forward pass through network
                loss = criterion(outputs, labels) # Calculate loss

                if phase == 'train':
                    optimizer.zero_grad() # Set all previously calculated gradients to 0
                    loss.backward() # Calculate gradients
                    optimizer.step() # Step on the weights using those gradient w -=  gradient(w) * lr

                _, preds = torch.max(outputs, 1) # Get model's predictions
                running_loss += loss.detach() * inputs.size(0) # multiply mean loss by the number of elements
                running_corrects += torch.sum(preds == labels.data) # add number of correct predictions to total

            epoch_loss = running_loss / len(dataloaders[phase].dataset) # get the "mean" loss for the epoch
            epoch_acc = running_corrects.float() / len(dataloaders[phase].dataset) # Get proportion of correct predictions
            
            # Logging
            prefix = ''
            if phase == 'validation':
                prefix = 'val_'

            logs[prefix + 'log loss'] = epoch_loss.item()
            logs[prefix + 'accuracy'] = epoch_acc.item()
        print('loss: ', epoch_loss.item(), ' accuracy: ', epoch_acc.item())
        
        #liveloss.update(logs) Update logs
        #liveloss.send()  draw, display stuff

In [19]:
train_model(pytorch_net, criterion, optimizer, dataloaders, nb_epoch)

loss:  0.4052080810070038  accuracy:  0.8804000020027161
loss:  0.34317728877067566  accuracy:  0.8989999890327454
loss:  0.3060593008995056  accuracy:  0.9096999764442444
loss:  0.2595570385456085  accuracy:  0.9241999983787537
loss:  0.23294584453105927  accuracy:  0.9304999709129333
loss:  0.21377748250961304  accuracy:  0.9378999471664429
loss:  0.1968572586774826  accuracy:  0.9429000020027161
loss:  0.17762190103530884  accuracy:  0.9496999979019165
loss:  0.16019994020462036  accuracy:  0.9524999856948853
loss:  0.1489601582288742  accuracy:  0.9549999833106995
loss:  0.1422155797481537  accuracy:  0.9583999514579773
loss:  0.14073260128498077  accuracy:  0.957099974155426
loss:  0.15762324631214142  accuracy:  0.949999988079071
loss:  0.13402201235294342  accuracy:  0.9598000049591064
loss:  0.12043929100036621  accuracy:  0.9657999873161316


In [21]:
# Save the model
#torch.save(pytorch_net, 'models/my_digit_clasifier_3L_97pct.pt')