This notebook uses FastAI to simplify the steps needed to create a state-of-the-art image classifier. Remember to change the runtime type to enable hardware acceleration (GPU or TPU) if you are using Colab.

In [None]:
import os
import PIL
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

# Use FastAI
# upgrade first, Colab uses an old version by default
!pip install fastai --upgrade -q  
from fastai.vision.all import *

# Install an older version of bing_image_downloader
# The latest version is broken and return non related images
!pip install bing_image_downloader==1.0.4
from bing_image_downloader import downloader

# You can also mount your Google Drive if you want to store
# data directly to it (no need to download new images every time).
# The google.colab package is only importable inside Colab, so the mount
# is skipped elsewhere instead of failing the whole setup cell.
try:
    from google.colab import drive
except ImportError:
    print('google.colab is not available: skipping the Google Drive mount.')
else:
    drive.mount('/content/drive')

# Keeps the kernel from dying in notebooks on Windows machines (not needed in Colab)
#import os
#os.environ['KMP_DUPLICATE_LIB_OK']='True'

### Download a dataset of your choice using Bing image search

In [None]:
# Number of images to download for each class
n_images_per_class = 100
# Name of the directory where the images are to be stored
data_dir = 'image_data'

# Store the data under the Google Drive folder when a 'drive' directory
# is present in the working directory, otherwise use a local directory
data_path = ('drive/MyDrive/' + data_dir) if os.path.exists('drive') else data_dir
print('Saving images to: {:s}'.format(data_path))

# Search queries, one for each class you want to get pictures of
quaries = [
    'trains',
    'airplanes',
    'cars',
    'ships',
]

In [None]:
# ONLY RUN THIS CELL IF YOU WANT TO DOWNLOAD NEW IMAGES!!!

# Create the data directory. exist_ok=True makes this a no-op when the
# directory is already there, without a separate (racy) existence check.
os.makedirs(data_path, exist_ok=True)

# Download up to n_images_per_class images for each search query
for query in quaries:
    downloader.download(query,
                        limit=n_images_per_class,
                        output_dir=data_path,
                        adult_filter_off=True,
                        force_replace=False,
                        timeout=5)

### Show examples of the downloaded images for each class

In [None]:
# Define the (maximum) number of examples to show per class
n_examples = 5
# Define the number of classes based on the number of queries
n_classes = len(quaries)

# Create a matplotlib figure window
fig = plt.figure(figsize=[15, 10])
# Loop over all classes (one row of the grid per class)
for row, query in enumerate(quaries):
    # List all the files in the class directory. os.listdir returns entries
    # in arbitrary order, so sort them to show the same examples on every run.
    class_dir_path = data_path + '/' + query
    files = sorted(os.listdir(class_dir_path))
    # Slice instead of indexing so that a class with fewer than n_examples
    # files leaves the remaining grid cells empty rather than raising IndexError
    for col, file_name in enumerate(files[:n_examples]):
        # Create a subplot axes at position (row, col) of the grid
        ax = plt.subplot(n_classes, n_examples, row*n_examples+col+1)
        # Open the image; the context manager closes the file handle afterwards
        with PIL.Image.open(class_dir_path+'/'+file_name) as img:
            # Shrink the image in place so it fits within 200x200
            img.thumbnail([200, 200])
            # Plot the image
            ax.imshow(img)
        # Remove ticks
        ax.set(xticks=[], yticks=[])

### Define a FastAI data block and dataloaders for the data we downloaded

In [None]:
# Describe how the downloaded files become a labelled image data set
img_classes = DataBlock(
    # Classification task: image data and category labels
    blocks=(ImageBlock, CategoryBlock),
    # Function that lists the image files found under the data path
    get_items=get_image_files,
    # Use the directory name as the class label
    get_y=parent_label,
    # Hold out 25% of the items as validation data (fixed seed)
    splitter=RandomSplitter(valid_pct=0.25, seed=1),
    # Resize every image to size 224
    item_tfms=Resize(224),
)

# Build the dataloaders from the data block
dls = img_classes.dataloaders(data_path)

# Show a grid of training images
dls.train.show_batch(nrows=2, max_n=10, unique=False)
# Try to define a smaller batch_size if you get an error
#dls.train.show_batch(max_n=10, nrows=2, unique=False, batch_size=8)

### Use augmentation to artificially grow our data set

In [None]:
# Derive a new data block from the existing one: same item resize,
# plus the default FastAI augmentation transforms applied per batch.
img_classes = img_classes.new(
    item_tfms=Resize(224),
    batch_tfms=aug_transforms(),
)

# Rebuild the dataloaders from the augmented data block
dls = img_classes.dataloaders(data_path)
# Show augmented example images
dls.train.show_batch(nrows=2, max_n=10, unique=True)

### Fine tune an existing state-of-the-art network (transfer learning)

In [None]:
# Pick the learner factory. Recent fastai releases renamed cnn_learner to
# vision_learner and deprecated the old name; since this notebook upgrades
# fastai, prefer the new name and fall back to the old one on older installs.
try:
    create_learner = vision_learner
except NameError:
    create_learner = cnn_learner

# Load a pretrained network (resnet18) and track the error rate
learn = create_learner(dls, resnet18, metrics=error_rate)
# Fine tune the parameters of the network for 3 epochs
learn.fine_tune(3)

In [None]:
# Evaluate the results on the validation set:
# build an interpretation object from the fine-tuned learner
interp = ClassificationInterpretation.from_learner(learn)
# and use it to, for example, plot the confusion matrix
interp.plot_confusion_matrix()

### Test the fine tuned network using new test images

In [None]:
# Use a widget to upload a test image from your computer
uploader = widgets.FileUpload()
# Leaving the widget as the last expression of the cell displays it
uploader

In [None]:
# Fetch the raw bytes of the uploaded file(s). ipywidgets 7 exposes them as
# a list in `uploader.data`; ipywidgets 8 removed that attribute and instead
# stores a tuple of dicts (with a `content` memoryview) in `uploader.value`.
if hasattr(uploader, 'data'):
    uploaded_files = list(uploader.data)
else:
    uploaded_files = [bytes(item['content']) for item in uploader.value]
# Fail with a clear message instead of an opaque IndexError
if not uploaded_files:
    raise ValueError('No image has been uploaded: use the upload widget in the previous cell first.')

# Create a PIL image object that can be fed to our network
img = PILImage.create(uploaded_files[0])

# Predict the class label and print the results. The middle return value is
# given a real name so IPython's `_` (last output) is not overwritten.
predicted_label, predicted_idx, probs = learn.predict(img)
print('The predicted label is: {:s}'.format(predicted_label))
# The highest class probability, converted to a plain float for formatting
print('The estimated probability is: {:2.1f}%'.format(float(probs.max())*100))
# Show the image as well
img.to_thumb(400)