# Intro
Classify a user-provided image into one of these categories:
* anime
* Disney animation
* soviet animation
    
The code is based on the code from the [fastai/fastbook](https://github.com/fastai/fastbook) repository, which is released under a free licence (GPL-3). The changes include making the code work for any set of classes, and refactoring.

The final model is based on a pre-trained [resnet18](https://pytorch.org/hub/pytorch_vision_resnet/). 

# Install and import dependencies

In [None]:
# UNCOMMENT THE LINES BELOW IF YOU WANT TO TRAIN YOUR OWN CLASSIFIER
# !pip install -Uqq fastbook
# import fastbook
# fastbook.setup_book()

In [None]:
# from fastbook import *
from fastai.vision.all import *
from fastai.vision.widgets import *

# Credentials
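Replace the *secret_key_dont_share_it* value below with your own Bing Image Search key, or set the `AZURE_SEARCH_KEY` environment variable, which takes precedence. For instructions on how to get the key, see [this forum thread](https://forums.fast.ai/t/getting-the-bing-image-search-key/67417).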

In [None]:
# Fallback Bing Image Search key; the AZURE_SEARCH_KEY environment variable,
# when set, is used instead. Never commit a real key here.
secret_key_dont_share_it: str = 'REPLACE WITH YOURS'

# Classes
If you want to create another image classifier, replace the class names below with your own. The names are also the search terms Bing will use to build the dataset.

In [None]:
# Each name is both a label and the Bing search query used to collect its images.
classes_names = ('anime', 'disney animation', 'soviet animation')


# Build and prepare the dataset

In [None]:
def build_dataset_from_bing_images(categories):
    """Download a Bing image-search dataset with one sub-folder per category.

    The dataset directory is named after all categories joined with
    underscores (spaces replaced by underscores too). Each category's images
    go into ``<dataset dir>/<category>``.

    A category is downloaded only when its folder is missing or empty. A run
    that failed part-way (e.g. with an invalid key) is therefore resumed on
    the next call, instead of silently reusing an empty dataset.

    Args:
        categories: iterable of category names; each name is also the Bing
            search query.

    Returns:
        Path of the dataset directory.
    """
    categories = list(categories)  # iterated more than once below
    key = os.environ.get('AZURE_SEARCH_KEY', secret_key_dont_share_it)
    path = Path("_".join(categories).replace(" ", "_"))
    path.mkdir(parents=True, exist_ok=True)
    for category in categories:
        dest = path/category
        if dest.is_dir() and any(dest.iterdir()):
            continue  # already downloaded on a previous run
        dest.mkdir(exist_ok=True)
        # NOTE: search_images_bing is not defined in this notebook; presumably it
        # comes from `from fastbook import *` (commented out above) -- must be in scope.
        results = search_images_bing(key, category)
        download_images(dest, urls=results.attrgot('content_url'))
    return path

In [None]:
def remove_failed_downloads(path):
    """Delete every image file under *path* that cannot be opened.

    Args:
        path: dataset directory to scan recursively for image files.

    Returns:
        The number of files deleted.
    """
    failed = verify_images(get_image_files(path))
    # Plain print: the original wrapped it in display(), which only received
    # print's None return value.
    print("There are this many failed downloads to delete:", len(failed))
    failed.map(Path.unlink)
    return len(failed)

In [None]:
def build_dataloaders(path):
    """Build non-augmented DataLoaders for the image folders under *path*.

    Labels come from each image's parent folder name. 20% of the images are
    held out for validation with a fixed seed (42), so the split is
    reproducible. Every image is resized with ``Resize(128)``.

    Returns:
        (raw_dls, raw_elements): the DataLoaders and the DataBlock they were
        built from. get_dataset passes the DataBlock on to
        resize_and_and_augment.
    """
    block = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        splitter=RandomSplitter(valid_pct=0.2, seed=42),
        item_tfms=Resize(128),
    )
    return block.dataloaders(path), block

In [None]:
def resize_and_and_augment(elements, path, size=224, min_scale=0.5):
    """Derive augmented DataLoaders from the DataBlock *elements*.

    The item transform is replaced by a random resized crop, and fastai's
    default batch augmentations (``aug_transforms()``) are added.

    Args:
        elements: DataBlock to derive from (e.g. the one from build_dataloaders).
        path: dataset directory passed to ``DataBlock.dataloaders``.
        size: output size of the random crop (default 224, as before).
        min_scale: minimum scale of the crop relative to the image area
            (default 0.5, as before).

    Returns:
        (dls, elements): the new DataLoaders and the new DataBlock.
    """
    augmented = elements.new(
        item_tfms=RandomResizedCrop(size, min_scale=min_scale),
        batch_tfms=aug_transforms())
    return augmented.dataloaders(path), augmented

In [None]:
def show_dataset_sample(raw_dls, dls):
    """Show a training batch before (*raw_dls*) and after (*dls*) augmentation."""
    for loaders in (raw_dls, dls):
        loaders.train.show_batch(max_n=38, nrows=4)

In [None]:
def get_dataset(classes_names):
    """Download, clean, load and augment the image dataset for *classes_names*.

    Also shows a sample batch before and after augmentation.

    Returns:
        (dls, elements): the augmented DataLoaders and their DataBlock.
    """
    dataset_path = build_dataset_from_bing_images(classes_names)
    remove_failed_downloads(dataset_path)
    raw_dls, raw_block = build_dataloaders(dataset_path)
    augmented = resize_and_and_augment(raw_block, dataset_path)
    show_dataset_sample(raw_dls, augmented[0])
    return augmented

## Train

In [None]:
def train_my_model(dls, epochs=10, arch=None):
    """Fine-tune a pre-trained CNN on *dls* and return the Learner.

    Args:
        dls: DataLoaders to train on.
        epochs: number of fine-tuning epochs (default 10, as before).
        arch: model architecture; defaults to ``resnet18``. Resolved at call
            time so the default does not have to exist when this is defined.

    Returns:
        The trained Learner; error_rate is tracked as its metric.
    """
    if arch is None:
        arch = resnet18
    # NOTE: newer fastai versions name this vision_learner and keep
    # cnn_learner as a deprecated alias -- confirm against the installed version.
    learn = cnn_learner(dls, arch, metrics=error_rate)
    learn.fine_tune(epochs)
    return learn

In [None]:
def show_training_results_diagrams(learn):
    """Plot training diagnostics for *learn* and show a dataset-cleaning widget.

    Plots the confusion matrix and the 5 highest-loss images, then displays an
    ImageClassifierCleaner. The cleaner is returned so that the images marked
    in it can be acted on afterwards, e.g. (``path`` being the dataset dir)::

        cleaner = show_training_results_diagrams(learn)
        # ... mark images in the widget, then:
        for idx in cleaner.delete(): cleaner.fns[idx].unlink()
        for idx, cat in cleaner.change(): shutil.move(str(cleaner.fns[idx]), path/cat)

    Returns:
        The ImageClassifierCleaner widget.
    """
    interp = ClassificationInterpretation.from_learner(learn)
    interp.plot_confusion_matrix()
    interp.plot_top_losses(5, nrows=5)
    cleaner = ImageClassifierCleaner(learn)
    # A bare `cleaner` expression inside a function is never rendered, so the
    # widget has to be displayed explicitly.
    display(cleaner)
    return cleaner

In [None]:
def export_model(learn, fname='export.pkl'):
    """Serialize *learn* for inference and report the file it was written to.

    ``Learner.export`` writes to ``learn.path / fname``. The original code
    listed every ``.pkl`` in the current directory instead, which could show
    unrelated files and miss the real one when ``learn.path`` is not the cwd.

    Args:
        learn: trained Learner to export.
        fname: file name relative to ``learn.path`` (default 'export.pkl',
            which is what load_inference_learner reads).

    Returns:
        Path of the exported file.
    """
    learn.export(fname)
    export_path = learn.path/fname
    print("The model is saved to this file:", export_path)
    return export_path

In [None]:
def get_new_model():
    """Rebuild the dataset for `classes_names`, train, export, and show diagnostics."""
    dls, _ = get_dataset(classes_names)
    learner = train_my_model(dls)
    export_model(learner)
    show_training_results_diagrams(learner)

# Uncomment the line below to re-create the model

In [None]:
# get_new_model()

## Voila Online Application

In [None]:
def load_inference_learner():
    """Load ./export.pkl on the CPU for inference.

    Returns:
        (learn_inf, vocab): the Learner and its category vocabulary.
    """
    # load_learner unpickles the file, which can execute arbitrary code:
    # only load exports you produced or trust.
    learner = load_learner(Path()/'export.pkl', cpu=True)
    return learner, learner.dls.vocab

In [None]:
def build_widgets():
    """Create the app's widgets.

    Returns:
        (upload button, classify button, image output area, prediction label).
    """
    upload = widgets.FileUpload()
    preview = widgets.Output()
    preview.clear_output()
    prediction = widgets.Label()
    classify = widgets.Button(description='Classify')
    return upload, classify, preview, prediction

In [None]:
# Module-level widgets, shared by the click handler and the final layout.
(btn_upload, btn_run, out_pl, lbl_pred) = build_widgets()

In [None]:
# Inference learner plus the class names it can predict.
(learn_inf, vocab) = load_inference_learner()

In [None]:
def _last_uploaded_bytes(uploader):
    """Return the raw bytes of the most recently uploaded file, or None if none.

    Handles both FileUpload APIs:
    * ipywidgets 7: ``uploader.data`` is a list of bytes, and ``uploader.value``
      is a dict ``{name: {'content': ...}}``.
    * ipywidgets 8: ``data`` no longer exists; ``value`` is a tuple of dicts
      whose ``'content'`` is a memoryview.
    """
    data = getattr(uploader, 'data', None)
    if data:
        return bytes(data[-1])
    value = getattr(uploader, 'value', None)
    if isinstance(value, dict):
        value = list(value.values())
    if not value:
        return None
    # bytes(): PILImage.create accepts bytes but not a memoryview.
    return bytes(value[-1]['content'])


def on_click_classify(change):
    """Classify the last uploaded image and show its thumbnail and prediction.

    Registered as the Classify button's on_click callback; *change* is the
    argument ipywidgets passes (presumably the Button) and is not used.
    If nothing has been uploaded yet, a hint is shown instead of raising.
    """
    img_bytes = _last_uploaded_bytes(btn_upload)
    if img_bytes is None:
        lbl_pred.value = 'Please upload an image first.'
        return
    img = PILImage.create(img_bytes)
    out_pl.clear_output()
    with out_pl:
        display(img.to_thumb(128, 128))
    pred, pred_idx, probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'

In [None]:
# Run on_click_classify whenever the "Classify" button is pressed.
btn_run.on_click(callback=on_click_classify)

In [None]:
# Last expression of the cell, so Jupyter/Voila renders the assembled app.
# The class names are joined into a readable list after a separating space,
# instead of the raw list repr glued to the colon.
VBox([
    widgets.Label('I will try to classify your image as one of the following classes: '
                  + ', '.join(map(str, vocab))),
    btn_upload, btn_run, out_pl, lbl_pred,
])