<a href="https://www.kaggle.com/code/rickykyaw/cat-dog-image-classification?scriptVersionId=221776826" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
import os
# Non-empty only when the KAGGLE_KERNEL_RUN_TYPE environment variable is set
# (presumably by Kaggle's notebook runtime -- TODO confirm); '' otherwise,
# which is falsy and skips the install below.
iskaggle = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '')

# When running on Kaggle, install/upgrade the third-party packages imported
# by the following cells.
# NOTE: `!pip` is IPython shell syntax, so this cell only runs in a notebook.
if iskaggle:
    !pip install -Uqq fastai icrawler --use-deprecated=legacy-resolver

In [None]:
from icrawler import ImageDownloader
from icrawler.builtin import GoogleImageCrawler
from icrawler.builtin.google import GoogleFeeder, GoogleParser

class MyDownloader(ImageDownloader):
    """Image downloader that names every file ``<prefix><stem>.png``.

    ``prefix`` is a plain attribute: assign it on the instance to tag the
    files of one crawl. It defaults to the empty string so the downloader
    also works when nobody sets it.
    """

    # Class-level default; avoids an AttributeError in get_filename when the
    # downloader is used without a crawler that assigns ``prefix``.
    prefix = ""

    def get_filename(self, task, default_ext):
        """Return the file name to use for *task*.

        The name proposed by the parent class is taken, its trailing
        extension (if any) is dropped, ``self.prefix`` is prepended and
        ``.png`` is appended.

        Note that only the *name* is forced to ``.png``; this method does not
        touch the image data.

        Args:
            task: Download task, forwarded unchanged to the parent class.
            default_ext: Fallback extension, forwarded unchanged to the
                parent class.

        Returns:
            The file name as a string ending in ``.png``.
        """
        filename = super().get_filename(task, default_ext)
        # Strip only the last extension so a stem that itself contains dots
        # is not truncated at its first dot.
        stem = filename.rsplit(".", 1)[0]
        return f"{self.prefix}{stem}.png"

class MyCrawler(GoogleImageCrawler):
    """Google image crawler whose downloader prefixes the saved file names."""

    def __init__(
        self,
        feeder_cls=GoogleFeeder,
        parser_cls=GoogleParser,
        downloader_cls=MyDownloader,
        prefix="",
        *args,
        **kwargs,
    ):
        """Build the crawler and hand *prefix* to its downloader.

        Args:
            feeder_cls: Feeder class passed to the parent constructor.
            parser_cls: Parser class passed to the parent constructor.
            downloader_cls: Downloader class passed to the parent constructor.
            prefix: Value stored as ``prefix`` on the crawler's downloader.
            *args: Extra positional arguments for the parent constructor.
            **kwargs: Extra keyword arguments for the parent constructor.
        """
        super().__init__(feeder_cls, parser_cls, downloader_cls, *args, **kwargs)
        # self.downloader is expected to exist once the parent constructor
        # has run; the prefix is attached to it afterwards.
        self.downloader.prefix = prefix

In [None]:
def search_images(term, max_images=30, folder_name="."):
    """Crawl images for *term* with ``MyCrawler`` and store them on disk.

    Args:
        term: Search keyword; also used as the file-name prefix.
        max_images: Upper bound passed to the crawler as ``max_num``.
        folder_name: Directory passed to the crawler as storage ``root_dir``.

    Returns:
        None. The function prints a progress line and runs the crawl.
    """
    print(f"Searching for '{term}'")
    storage_config = {'root_dir': folder_name}
    image_crawler = MyCrawler(prefix=term, storage=storage_config)
    image_crawler.crawl(keyword=term, max_num=max_images)

In [None]:
from fastai.vision.all import *

In [None]:
from time import sleep

# Number of images requested per search term.
no_of_photos = 20

# Each term becomes both a search query and a sub-folder of `path`.
searches = ('people', 'animals', 'trees')

path = Path('dataset')

for o in searches:
    dest = path / o
    print(f'This is dest, {dest}')
    dest.mkdir(parents=True, exist_ok=True)
    search_images(f"{o} photo", max_images=no_of_photos, folder_name=dest)
    # Pause between searches before starting the next crawl.
    sleep(10)
    print(f"Photos of {o} completed!")

# Resize in place: source and destination are the same directory tree.
resize_images(path, max_size=400, dest=path, recurse=True)
print(f"Photos resized!")

In [None]:
# Collect the image files under `path` that fail verification ...
failed = verify_images(get_image_files(path))

# ... delete each of them from disk ...
for bad_file in failed:
    bad_file.unlink()

# ... and show how many were removed (cell output).
len(failed)

In [None]:
# Describe the dataset: images in, category out, label taken from the
# parent folder name, 20% held out for validation with a fixed seed.
datablock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    item_tfms=[Resize(192, method='squish')],
)

# Build the dataloaders from the files under `path`.
dls = datablock.dataloaders(path, bs=32)

# Preview up to 12 items.
dls.show_batch(max_n=12)

In [None]:
# Build a resnet18-based learner on the dataloaders, tracking error rate.
learn = vision_learner(dls, arch=resnet18, metrics=error_rate)

# Fine-tune for 10 epochs.
learn.fine_tune(epochs=10)

In [None]:
# Plain-list copy of the learner's vocabulary, reused by the prediction cells.
classes = [*learn.dls.vocab]

print(classes)

In [None]:
# Fetch a single test image into the current directory (default folder_name).
search_images('man', max_images=1)

In [None]:
# Load via fastai rather than a raw PIL open: PILImage.create opens the file
# in RGB mode (see fastai docs), so palette/RGBA/greyscale downloads reach
# the model with the same channel layout as the training images.
# The file is expected to come from the preceding search_images('man', 1) call.
img = PILImage.create("man000001.png")
img.to_thumb(256,256)

In [None]:
# learn.predict returns three values; the middle one is not needed here,
# so no separate index lookup in `classes` is performed.
predicted_class, _, probs = learn.predict(img)
print(f"This is an image of: {predicted_class}.")

# probs is read positionally in the same order as `classes`, so pair each
# class name with its probability directly.
for ele, prob in zip(classes, probs):
    print(f"Probability it's from {ele} category: {prob:.4f}")

In [None]:
# Fetch a single test image into the current directory (default folder_name).
search_images('lion', max_images=1)

In [None]:
# Load via fastai rather than a raw PIL open: PILImage.create opens the file
# in RGB mode (see fastai docs), so palette/RGBA/greyscale downloads reach
# the model with the same channel layout as the training images.
# The file is expected to come from the preceding search_images('lion', 1) call.
img = PILImage.create("lion000001.png")
img.to_thumb(256,256)

In [None]:
# learn.predict returns three values; the middle one is not needed here,
# so no separate index lookup in `classes` is performed.
predicted_class, _, probs = learn.predict(img)
print(f"This is an image of: {predicted_class}.")

# probs is read positionally in the same order as `classes`, so pair each
# class name with its probability directly.
for ele, prob in zip(classes, probs):
    print(f"Probability it's from {ele} category: {prob:.4f}")