In [None]:
import os
from PIL import Image

In [None]:
import fastbook

# Run fastbook's notebook set-up helper before the fastai star-imports below are used.
fastbook.setup_book()

In [None]:
from fastbook import *
from fastai.vision.widgets import *

In [None]:
# Google Custom Search credentials come from the environment (never hard-coded);
# each is None when its variable is unset.
key, cx = (os.environ.get(var_name) for var_name in ('GCS_DEVELOPER_KEY', 'GCS_CX'))

### Get images from the Google Image Search API

In [None]:
# https://pypi.org/project/Google-Images-Search/

In [None]:
# Google image search client (package: Google-Images-Search).
from google_images_search import GoogleImagesSearch

# The credentials are passed explicitly as `key` / `cx`, read from the environment earlier.
# (Per the original note, the library also accepts them via GCS_DEVELOPER_KEY / GCS_CX.)
gis = GoogleImagesSearch(key, cx)

In [None]:
# Both queries share every option except the search term and the allowed file types:
# 120 results, SafeSearch 'high', photos only, public-domain / Creative-Commons rights.
def _cat_search_params(query, file_type):
    """Return the Google image search parameter dict for one cat breed."""
    return {
        'q': query,
        'num': 120,
        'safe': 'high',
        'fileType': file_type,
        'imgType': 'photo',
        'rights': 'cc_publicdomain|cc_attribute|cc_sharealike|cc_noncommercial|cc_nonderived',
    }


# search param for Bengal cat (JPEG or PNG)
_search_params_1 = _cat_search_params('Bengal cat', 'jpg|png')

# search param for Persian cat (JPEG only)
_search_params_2 = _cat_search_params('Persian Cat', 'jpg')

In [None]:
# search first, then download and resize afterwards:
gis.search(search_params=_search_params_1,  custom_image_name='bengal')
for image in gis.results():
    image.download('/images/bengal') #download images to the directory
    image.resize(500, 500) # resize image
    
gis.next_page()  # next page 
for image in gis.results():
    image.download('/images/bengal')
    image.resize(500, 500)  

In [None]:
os.listdir('/images/bengal')

In [None]:
img_to_be_removed = []
gis.search(search_params=_search_params_2,  custom_image_name='persian')
for image in gis.results():
    image.download('/images/persian')
    try: # to catch unsupported image file
        image.resize(500, 500)
    except:
        img_to_be_removed.append(image)
    
gis.next_page()
for image in gis.results():
    image.download('/images/persian')
    try:
        image.resize(500, 500)
    except UnidentifiedImageError:
        img_to_be_removed.append(image)


In [None]:
# os.listdir('/images/Persian')
img_to_be_removed

### Display a sample image

In [None]:
# Visual sanity check: preview one downloaded Bengal photo as a 128x128 thumbnail.
sample_img_path = 'images/bengal/bengal(1).jpg'
im = Image.open(sample_img_path)
im.to_thumb(128, 128)

### Collect the image file paths

In [None]:
# Dataset root: one sub-folder per breed; the folder name becomes the label (parent_label).
path = Path('images')

# Every image file fastai finds under the root.
fns = get_image_files(path)
fns

In [None]:
# Files among `fns` that fail fastai's image check, i.e. broken downloads.
failed = verify_images(fns)
failed

In [None]:
# Delete every file that failed verification so it never reaches the DataBlock.
for bad_path in failed:
    bad_path.unlink()

## Load the data

In [None]:
# How to turn the image folders into a labelled dataset:
#   input = image, target = category taken from the parent folder name,
#   seeded random 80/20 train/validation split, each image resized with Resize(128).
cats = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(valid_pct=0.2, seed=45),
    item_tfms=Resize(128),
)

In [None]:
# Build the train/validation DataLoaders from the folders under `path`.
# NOTE(review): num_workers=0 keeps data loading in the main process; presumably chosen to
# avoid worker-process problems (e.g. on Windows) — confirm before raising it.
dls = cats.dataloaders(path, num_workers=0)

### Skim through a validation batch

In [None]:
# Eyeball four validation images with their labels, in a single row.
dls.valid.show_batch(max_n=4, nrows=1)

### cnn_learner

In [None]:
# ResNet-18 image classifier on `dls`, reporting error_rate; fine-tuned for 4 epochs.
# NOTE(review): newer fastai releases rename cnn_learner to vision_learner (the old name is
# kept as a deprecated alias) — confirm against the installed version.
learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(4)

### confusion matrix

In [None]:
# Where does the trained model confuse the two breeds?
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

### View the top wrong predictions

In [None]:
# The 5 images with the highest loss, in one row — often mislabeled or off-topic downloads.
interp.plot_top_losses(5, nrows=1)

In [None]:
# Save the trained learner for inference; load_learner('export.pkl') reads it back below.
learn.export('export.pkl')

-------------------------------------------------------------------

In [None]:
# load the saved learner

In [None]:
# Reload the exported learner. load_learner unpickles the file, so only load exports you
# trust — this one is written by learn.export() earlier in this notebook.
saved_learner = load_learner('export.pkl')

In [None]:
# GUI

In [None]:
# Upload widget: choose a cat photo to classify, then run the cells below.
btn_upload = widgets.FileUpload()
btn_upload

In [None]:
# Preview the most recently uploaded file as a 128x128 thumbnail.
# Requires a file uploaded with the widget above; on a fresh "Run All" there is none yet,
# so fail with an actionable message instead of a bare IndexError from `data[-1]`.
if not btn_upload.data:
    raise RuntimeError('No image uploaded yet: pick a file with the upload button above, then re-run this cell.')
img = PILImage.create(btn_upload.data[-1])

out_pl = widgets.Output()
with out_pl:
    display(img.to_thumb(128, 128))
out_pl

In [None]:
# Classify the uploaded image: predicted label, its index, and the per-class probabilities.
pred, pred_idx, probs = saved_learner.predict(img)

In [None]:
# Predicted label together with its probability as a plain Python float.
(pred, float(probs[pred_idx]))

In [None]:
# Complete GUI, step 1: a button that classifies the uploaded image when clicked.
btn_run = widgets.Button(description='Classify')
btn_run

In [None]:
# Text label holding the latest prediction, initialised with the result computed above.
lbl_pred = widgets.Label(value=f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}')
lbl_pred

In [None]:
def on_click_classify(change):
    """Classify the most recently uploaded image and show the result.

    Registered below as the click handler of `btn_run`; `change` is the argument
    ipywidgets passes to the handler and is unused. The module-level `btn_upload`,
    `out_pl`, `saved_learner` and `lbl_pred` are looked up at click time.
    """
    if not btn_upload.data:
        # An exception raised inside a widget callback is not shown next to the widget,
        # so report the missing upload in the label instead of raising IndexError.
        lbl_pred.value = 'Please upload an image first.'
        return
    img = PILImage.create(btn_upload.data[-1])
    out_pl.clear_output()
    with out_pl:
        display(img.to_thumb(128, 128))
    pred, pred_idx, probs = saved_learner.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'

btn_run.on_click(on_click_classify)

In [None]:
# Fresh upload widget for the assembled GUI. on_click_classify reads the global
# `btn_upload` when clicked, so it uses this new widget.
btn_upload = widgets.FileUpload()

In [None]:
# Assemble the GUI top to bottom: prompt, upload button, classify button, preview, prediction.
VBox([
    widgets.Label('Upload your cat!'),
    btn_upload,
    btn_run,
    out_pl,
    lbl_pred,
])