# How to Train an Image Classifier in PyTorch and Use It to Perform Basic Inference on Single Images

## Tutorial on training ResNet with your own images

If you’re just getting started with PyTorch and want to learn how to do some basic image classification, you can follow this tutorial. It will go through how to organize your training data, use a pretrained neural network to train your model, and then predict other images.  

For this purpose, I’ll be using a dataset of map tiles from Google Maps and classifying them according to the land features they contain. I’ll write another story about how I use it (in brief: to identify safe areas for a drone to fly over or land on). But for now, I just want to use some training data to classify these map tiles.

The code snippets below are from a Jupyter Notebook. You can stitch them together to build your own Python script, or download the notebooks from GitHub. The notebooks are originally based on the PyTorch course from Udacity. And if you use a cloud VM for your deep learning development and don’t know how to open a notebook remotely, check out my tutorial.

### Organize your training dataset

PyTorch expects the data to be organized in folders, with one folder for each class. Most other PyTorch tutorials and examples expect you to go further and put a training and a validation folder at the top, with the class folders inside each. But I find it cumbersome to pick a certain number of images from each class and move them from the training to the validation folder. And since most people would do that by selecting a contiguous group of files, that selection could introduce a lot of bias.
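As a rough illustration of that one-folder-per-class layout, the sketch below builds a throwaway folder tree and derives the class-to-index mapping from the sorted subfolder names, which is how `ImageFolder` assigns labels. The class names (`water`, `forest`, `urban`) and the file name are made up for the example, and only the standard library is used.

```python
import tempfile
from pathlib import Path

# Hypothetical class names, for illustration only
example_classes = ["water", "forest", "urban"]

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "train"
    for name in example_classes:
        class_dir = root / name
        class_dir.mkdir(parents=True)
        # Empty placeholder file standing in for a real image
        (class_dir / "tile_0001.png").touch()

    # One subfolder per class; label indices follow the sorted folder names
    found = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: i for i, name in enumerate(found)}

print(class_to_idx)  # {'forest': 0, 'urban': 1, 'water': 2}
```

Because the labels depend only on the folder names, renaming or adding a class folder changes the indices, so the same folder structure should be used for training and inference.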

So here’s a better way: splitting the dataset into a training and a test set on the fly, the way Python developers are used to from scikit-learn. But first, let’s import the modules:

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

Next we’ll define the train / validation dataset loader, using the SubsetRandomSampler for the split:

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

data_dir = '/data/train'

def load_split_train_test(datadir, valid_size = .2):
    train_transforms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), ])
    test_transforms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), ])
    train_data = datasets.ImageFolder(datadir, transform=train_transforms)
    test_data = datasets.ImageFolder(datadir, transform=test_transforms)
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    np.random.shuffle(indices)
    
    train_idx, test_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(test_idx)
    trainloader = torch.utils.data.DataLoader(train_data, sampler=train_sampler, batch_size=64)
    testloader = torch.utils.data.DataLoader(test_data, sampler=test_sampler, batch_size=64)
    return (trainloader, testloader)

trainloader, testloader = load_split_train_test(data_dir, .2)

print(trainloader.dataset.classes)
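The index arithmetic behind the split can be checked in isolation. The following is a minimal, dependency-free sketch of the same idea: shuffle the indices once, then slice off the first `valid_size` fraction for validation. The helper name `split_indices` and its `seed` parameter are additions for this illustration (the function above uses `np.random.shuffle` without a seed), and the standard library's `random` module stands in for NumPy.

```python
import random

def split_indices(num_items, valid_size=0.2, seed=0):
    """Shuffle 0..num_items-1 and cut it into train and validation parts."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)   # seeded, so the split is repeatable
    split = int(valid_size * num_items)    # rounds down, like np.floor above
    return indices[split:], indices[:split]

train_idx, valid_idx = split_indices(10, valid_size=0.2)
print(len(train_idx), len(valid_idx))  # 8 2
```

Since both parts come from one shuffled list, no image can end up in both the training and the validation subset.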

Next we’ll determine whether a GPU is available. I assume that if you’re doing this you have a GPU-powered machine; otherwise the code will be at least 10 times slower. But it’s a good idea to generalize and check for GPU availability.

We’ll also load a pretrained model. For this case, I chose ResNet 50:

In [None]:
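device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Note: torchvision 0.13+ deprecates `pretrained=True` in favor of
# `weights=models.ResNet50_Weights.DEFAULT`
model = models.resnet50(pretrained=True)
print(model)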

Printing the model will show you the layer architecture of the ResNet model. It’s probably beyond my comprehension or yours, but it’s still interesting to see what’s inside those deep hidden layers.

It’s up to you what model you choose, and it might be a different one based on your particular dataset. Here is a list of all the [PyTorch models](https://pytorch.org/docs/stable/torchvision/models.html).

Now we’re getting into the interesting part of the deep neural network. First, we have to freeze the pretrained layers, so we don’t backprop through them during training. Then, we redefine the final fully-connected layer, the one that we’ll train with our images. We also create the criterion (the loss function) and pick an optimizer (Adam in this case) and a learning rate.

In [None]:
for param in model.parameters():
    param.requires_grad = False
    
model.fc = nn.Sequential(
    nn.Linear(2048, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, 10),  # output size must match the number of classes
    nn.LogSoftmax(dim=1)
)

criterion = nn.NLLLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.003)
model.to(device)
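One way to confirm that the freezing worked is to count trainable and frozen parameters. The sketch below does this on a tiny stand-in network rather than ResNet 50, so it runs instantly and downloads nothing; the layer sizes are arbitrary and chosen only for the example.

```python
from torch import nn

# Tiny stand-in for a pretrained backbone (hypothetical, not ResNet)
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())

# Freeze everything that exists so far
for param in model.parameters():
    param.requires_grad = False

# A newly created layer has requires_grad=True by default
model.add_module("fc", nn.Linear(4, 3))

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable}, frozen: {frozen}")  # trainable: 15, frozen: 36
```

The order matters: the freeze loop has to run before the new head is attached, otherwise the head would be frozen too and nothing would be left to train.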

And now finally, let’s train our model! There’s just one epoch in this example, but in most cases you’ll need more. The basic process is easy to follow from the code: you load a batch of images and run the forward pass, calculate the loss, back-propagate it to get the gradients, and let the optimizer update the weights.

It’s that simple with PyTorch. Most of the code below deals with displaying the losses and calculating the accuracy every 10 batches, so you get an update while training is running. During validation, don’t forget to set the model to `eval()` mode, and then back to `train()` once you’re finished.

In [None]:
epochs = 1
steps = 0
running_loss = 0
print_every = 10
train_losses, test_losses = [], []
for epoch in range(epochs):
    for inputs, labels in trainloader:
        steps += 1
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        logps = model(inputs)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if steps % print_every == 0:
            test_loss = 0
            accuracy = 0
            model.eval()
            with torch.no_grad():
                for inputs, labels in testloader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    logps = model(inputs)
                    batch_loss = criterion(logps, labels)
                    test_loss += batch_loss.item()
                    ps = torch.exp(logps)
                    top_p, top_class = ps.topk(1, dim=1)
                    equals = top_class == labels.view(*top_class.shape)
                    accuracy += equals.float().mean().item()
            # running_loss covers the last print_every batches, so average over those
            train_losses.append(running_loss/print_every)
            test_losses.append(test_loss/len(testloader))
            print(f"Epoch {epoch+1}/{epochs}.. "
                  f"Train loss: {running_loss/print_every:.3f}.. "
                  f"Test loss: {test_loss/len(testloader):.3f}.. "
                  f"Test accuracy: {accuracy/len(testloader):.3f}")
            running_loss = 0
            model.train()

torch.save(model, 'aerialmodel.pth')
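The accuracy bookkeeping inside the validation loop can be checked by hand on a tiny batch. This sketch uses the same calls (`torch.exp`, `topk`, and the comparison against reshaped labels) on made-up log-probabilities for three images and three classes; the numbers are invented for illustration.

```python
import torch

# Made-up probabilities for 3 images and 3 classes, turned into log-probabilities
logps = torch.log(torch.tensor([[0.7, 0.2, 0.1],
                                [0.1, 0.8, 0.1],
                                [0.3, 0.3, 0.4]]))
labels = torch.tensor([0, 1, 0])

ps = torch.exp(logps)                    # back to probabilities
top_p, top_class = ps.topk(1, dim=1)     # best class per image, shape (3, 1)
equals = top_class == labels.view(*top_class.shape)
batch_accuracy = equals.float().mean().item()

print(top_class.squeeze(1).tolist())     # [0, 1, 2]
print(round(batch_accuracy, 3))          # 0.667
```

Two of the three predictions match their labels, so the batch accuracy is 2/3. The training loop sums these per-batch values and divides by the number of validation batches.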

And… after you wait a few minutes (or more, depending on the size of your dataset and the number of epochs), training is done and the model is saved for later predictions!
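`torch.save(model, ...)` pickles the whole model object. A common alternative, not the one used in this tutorial, is to save only the weights with `state_dict()` and rebuild the architecture before loading. Here is a minimal sketch on a tiny stand-in layer; the file name `weights.pth` is hypothetical.

```python
import os
import tempfile

import torch
from torch import nn

# Tiny stand-in model; a larger model would be handled the same way
net = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")   # hypothetical file name
    torch.save(net.state_dict(), path)         # weights only, no model object

    # Rebuild the same architecture, then load the saved weights into it
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

same = all(torch.equal(a, b)
           for a, b in zip(net.parameters(), restored.parameters()))
print(same)  # True
```

The trade-off is that the loading code has to recreate the model definition itself, but the saved file no longer depends on how the model object was pickled.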

There is one more thing you can do now, which is to plot the training and validation losses:

In [None]:
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.legend(frameon=False)
plt.show()

As you can see, in my particular example with one epoch, the validation loss (which is what we’re interested in) flatlines towards the end of the first epoch and even starts an upward trend, so probably 1 epoch is enough. The training loss, as expected, is very low.

Now on to the second part. You’ve trained your model and saved it, and now you need to use it in an application. For that, you’ll need to be able to perform simple inference on an image. This demo notebook is also in our repository. We import the same modules as in the training notebook and then define the transforms again. I only declare the image folder again so I can use some examples from there:

In [None]:
data_dir = '/datadrive/FastAI/data/aerial_photos/train'
test_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

Then again we check for GPU availability, load the model, and put it into evaluation mode (so layers such as dropout and batch normalization behave deterministically for inference):

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
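# map_location lets a model trained on a GPU load on a CPU-only machine.
# On PyTorch 2.6+, also pass weights_only=False, since the whole model was pickled.
model = torch.load('aerialmodel.pth', map_location=device)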
model.eval()
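To see what `eval()` actually changes, it helps to look at a single dropout layer on its own. The sketch below uses the same dropout rate as the classifier head (0.2) on a vector of ones; the vector length of 1000 is arbitrary.

```python
import torch
from torch import nn

drop = nn.Dropout(0.2)   # same rate as in the classifier head
x = torch.ones(1000)

drop.train()
y_train = drop(x)        # roughly 20% zeros, the rest scaled up by 1 / 0.8

drop.eval()
y_eval = drop(x)         # dropout passes the input through unchanged

print((y_train == 0).float().mean().item())  # fraction of zeroed values, near 0.2
print(torch.equal(y_eval, x))
```

In training mode the output changes from call to call, which is why predictions made without `eval()` can vary for the same image.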

The function that predicts the class of a specific image is very simple. Note that it requires a Pillow image, not a file path.

In [1]:
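def predict_image(image):
    # Convert the PIL image to a float tensor and add a batch dimension
    image_tensor = test_transforms(image).float().unsqueeze(0)
    image_tensor = image_tensor.to(device)
    # Gradients are not needed for inference
    with torch.no_grad():
        output = model(image_tensor)
    # The index of the highest log-probability is the predicted class
    index = output.cpu().numpy().argmax()
    return index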
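If all you have is a file path, the image has to be opened with Pillow first. The following sketch writes a small placeholder image to a temporary file and reads it back; the file name `tile.png` and the 32x32 size are made up for the example. Converting to RGB is a precaution, on the assumption that the network expects three input channels.

```python
import os
import tempfile

from PIL import Image

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tile.png")   # hypothetical file name
    # Write a small grayscale placeholder image to disk
    Image.new("L", (32, 32), color=128).save(path)

    # Open from a path and force 3 channels
    with Image.open(path) as img:
        image = img.convert("RGB")

print(image.mode, image.size)  # RGB (32, 32)
```

The resulting `image` is a Pillow image and can be passed to `predict_image(image)` directly.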

Now for easier testing, I also created a function that will pick a number of random images from the dataset folders:

In [2]:
def get_random_images(num):
    data = datasets.ImageFolder(data_dir, transform=test_transforms)
    classes = data.classes
    indices = list(range(len(data)))
    np.random.shuffle(indices)
    idx = indices[:num]
    from torch.utils.data.sampler import SubsetRandomSampler
    sampler = SubsetRandomSampler(idx)
    loader = torch.utils.data.DataLoader(data, sampler=sampler, batch_size=num)
    dataiter = iter(loader)
    # Use the built-in next(); iterator objects have no .next() method in Python 3
    images, labels = next(dataiter)
    return images, labels

Finally, to demo the prediction function, I grab a random sample of images, predict their classes, and display the results:

In [None]:
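to_pil = transforms.ToPILImage()
# Class names in label-index order (get_random_images keeps them local)
classes = datasets.ImageFolder(data_dir).classes
images, labels = get_random_images(5)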
fig=plt.figure(figsize=(10,10))
for ii in range(len(images)):
    image = to_pil(images[ii])
    index = predict_image(image)
    sub = fig.add_subplot(1, len(images), ii+1)
    res = int(labels[ii]) == index
    sub.set_title(str(classes[index]) + ":" + str(res))
    plt.axis('off')
    plt.imshow(image)
plt.show()

Here’s one example of such predictions on Google Maps tiles. Each title shows the predicted class, followed by whether the prediction was correct.