Permalink
Cannot retrieve contributors at this time
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
106 lines (79 sloc)
3.83 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Antelope Classification using FastAI | |
This script is an end-to-end case study of creating a custom image dataset of | |
major African antelope and training a deep convolutional neural network to | |
classify each species. | |
The basic workflow is as follows: | |
1. Download images of each antelope and build a dataset. | |
2. Pre-process and prepare the dataset for learning. | |
3. Create a deep neural network model for classification. | |
4. Train the DNN using transfer learning on the data. | |
5. Output error rate. | |
""" | |
import logging
import re
from typing import List, Optional

from fastai.vision import *
from fastai.metrics import error_rate
from google_images_download import google_images_download
logging.basicConfig(level=logging.INFO)

# Species to classify. Each entry is used as a search term and, with spaces
# replaced by underscores, as the name of that species' image directory.
ANTELOPE = [
    'kudu',
    'eland',
    'sable antelope',
    'roan antelope',
    'waterbuck',
    'impala antelope',
    'nyala',
    'bushbuck',
    'tsessebe',
    'lichtensteins hartebeest',
    'grey duiker',
    'steenbok',
    'klipspringer',
]

# Root directory of the downloaded image dataset.
DATA_PATH = Path('data')

# Training configuration passed to train_model().
VALID_PCT = 0.2                  # fraction of images held out for validation
IMAGE_SIZE = 224                 # forwarded as the data bunch's `size`
BATCH_SIZE = 32                  # forwarded as the data bunch's `bs`
ARCHITECTURE = models.resnet34   # CNN backbone handed to cnn_learner
def download_antelope_images(output_path: Path, limit: int = 50,
                             kws: Optional[List[str]] = None) -> None:
    """Download search-engine images for every species in ``ANTELOPE``.

    Each species gets its own sub-directory of ``output_path``, named after
    the species with spaces replaced by underscores (for example
    ``data/sable_antelope``). Two searches are issued per species, one for
    males and one for females, each requesting at most ``limit`` images.
    If ``output_path`` already contains any entry, the whole download is
    skipped so an existing dataset is reused.

    Args:
        output_path: Root directory of the dataset.
        limit: Maximum number of images requested per (species, gender) search.
        kws: Optional extra keywords appended to every search query.
    """
    try:
        # Any existing entry is taken to mean the dataset was already fetched.
        if any(output_path.iterdir()):
            logging.info(f"Directory '{output_path}' is not empty. Skipping image download.")
            return
    except FileNotFoundError:
        logging.info(f"Directory '{output_path}' does not exist and will be created.")

    # Copy into a fresh list so no caller-owned or shared default object is used.
    extra_keywords = list(kws) if kws else []
    response = google_images_download.googleimagesdownload()
    for antelope in ANTELOPE:
        # Sanitise only the species name: output_path itself may legitimately
        # contain spaces and must not be rewritten.
        output_directory = str(output_path / antelope.replace(' ', '_'))
        for gender in ('male', 'female'):
            # Minus-prefixed terms are search exclusion operators.
            keywords = ' '.join(['wild', antelope, gender, *extra_keywords,
                                 '-hunting', '-stock'])
            arguments = {
                'keywords': keywords,
                'output_directory': output_directory,
                # NOTE: 'usage_rights': 'labeled-for-nocommercial-reuse' can be
                # added here to restrict results by licence.
                # presumably stores files directly in output_directory rather
                # than a per-keyword sub-folder -- confirm in library docs.
                'no_directory': True,
                'size': 'medium',
                'limit': limit,
            }
            response.download(arguments)
def validate_labels(data_path: Path, labels: List[str]) -> None:
    """Log images whose file names suggest they belong to a different class.

    Each sub-directory of ``data_path`` is treated as one class. Directory
    names, labels and file names are normalised by lower-casing them and
    removing every non-letter character and the word ``antelope``. As a
    result, ``sable_antelope``, ``'sable antelope'`` and ``Sable-Antelope``
    all become ``sable``. Any file whose normalised name contains the
    normalised label of another class is logged at INFO level with its real
    path.

    Args:
        data_path: Dataset root containing one sub-directory per class.
        labels: Class labels, e.g. the species names used for downloading.
    """
    non_alpha = re.compile(r'[^a-z]+|antelope')

    def normalize(text: str) -> str:
        return non_alpha.sub('', text.lower())

    # A label that normalises to '' (e.g. 'antelope') would match every file.
    known_labels = {normalize(label) for label in labels} - {''}
    # Sorted traversal gives a deterministic log order.
    for class_dir in sorted(d for d in data_path.iterdir() if d.is_dir()):
        other_labels = sorted(known_labels - {normalize(class_dir.name)})
        for image_path in sorted(f for f in class_dir.iterdir() if f.is_file()):
            name = normalize(image_path.name)
            for other_label in other_labels:
                if other_label in name:
                    # Report the real path (not the normalised name) so the
                    # file can actually be located and fixed.
                    logging.info(f"Potential mislabeling: {image_path} "
                                 f"(file name mentions '{other_label}')")
def train_model(data_path: Path, valid_pct, image_size, batch_size, architecture,
                epochs: int = 4, max_lr=slice(1e-3, 1e-2),
                fine_tune_epochs: int = 4, fine_tune_lr: float = 1e-4,
                seed: Optional[int] = None) -> Learner:
    """Train a deep convolutional NN classifier on the downloaded data.

    Images are read from the class sub-directories of ``data_path``. A random
    ``valid_pct`` fraction is held out for validation. Training has two
    one-cycle phases:

    1. ``epochs`` epochs with ``max_lr``, using the learner as created by
       ``cnn_learner``.
    2. ``fine_tune_epochs`` epochs at ``fine_tune_lr``, after ``unfreeze()``.

    The default learning rates were found using ``learn.lr_find()`` in the
    accompanying Jupyter notebook.

    Args:
        data_path: Dataset root with one sub-directory per class.
        valid_pct: Fraction of images used for validation.
        image_size: Forwarded as ``size`` to the data bunch.
        batch_size: Forwarded as ``bs`` to the data bunch.
        architecture: Backbone model constructor passed to ``cnn_learner``.
        epochs: Number of epochs in the first phase.
        max_lr: Learning rate (or discriminative ``slice``) for the first phase.
        fine_tune_epochs: Number of epochs after unfreezing.
        fine_tune_lr: Learning rate after unfreezing.
        seed: Optional seed for a reproducible train/validation split.

    Returns:
        The trained learner.
    """
    # Forward `seed` only when given, so the default call is exactly the
    # original one.
    split_kwargs = {} if seed is None else {'seed': seed}
    image_data = ImageDataBunch.from_folder(
        data_path,
        valid_pct=valid_pct,
        ds_tfms=get_transforms(),
        size=image_size,
        bs=batch_size,
        **split_kwargs,
    ).normalize(imagenet_stats)  # presumably matches the backbone's pretraining stats
    learner = cnn_learner(image_data, architecture, metrics=error_rate)
    learner.fit_one_cycle(epochs, max_lr=max_lr)
    # Second phase: train all layers at a lower learning rate.
    learner.unfreeze()
    learner.fit_one_cycle(fine_tune_epochs, fine_tune_lr)
    return learner
if __name__ == '__main__':
    download_antelope_images(DATA_PATH)
    # Only logs suspicious files; it does not move or delete anything.
    validate_labels(DATA_PATH, ANTELOPE)
    learner = train_model(DATA_PATH, VALID_PCT, IMAGE_SIZE, BATCH_SIZE, ARCHITECTURE)
    # In fastai v1 each recorder.metrics entry is the list of metric values
    # for one epoch. error_rate is the only metric, so take the first value
    # of the final epoch instead of printing the whole list.
    final_error_rate = float(learner.recorder.metrics[-1][0])
    print(f'Error rate: {final_error_rate:.4f}')