In [None]:
#hide
# Session setup: install/upgrade fastbook, import it, then let it configure the environment.
# Order matters: the pip install must run before the import.
# NOTE(review): on Colab, setup_book() is presumably what mounts Google Drive
# (the dataset below lives under gdrive/MyDrive) — confirm.
# NOTE(review): the fastbook version is unpinned (-U), so re-runs may pick up a newer fastai.
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
# Widgets (FileUpload, Button, Label, Output, VBox, ImageClassifierCleaner) used later in the notebook.
from fastai.vision.widgets import *
from PIL import Image

In [None]:
from fastbook import *
# Quick sanity check of the image search: how many URLs came back, and the first one.
urls = search_images_ddg('guitarra', max_images=100)
(len(urls), urls[0])

In [None]:
# Download the first search hit and preview it.
# thumbnail() shrinks the image in place to fit within 256x256 pixels.
download_url(urls[0], 'images/guitarra.jpg')
im = Image.open('images/guitarra.jpg')
im.thumbnail(size=(256, 256))
im

Descargamos las imágenes que usaremos para entrenar el modelo (hasta 60 por instrumento, guardadas en una carpeta por categoría)

In [None]:
# Build the dataset: one sub-folder per instrument under `path`.
# Each sub-folder holds up to 60 search results named 0.jpg, 1.jpg, ...
# The folder name is what parent_label later uses as the class.
instrumentos = 'guitarra','ukelele','violin'
path = Path('gdrive/MyDrive/instrumentos')

# NOTE(review): the whole download is skipped when `path` already exists.
# An interrupted run therefore leaves a partial dataset; delete the folder to download again.
# mkdir() is called without parents=True on purpose. If Google Drive is not mounted,
# it fails loudly instead of silently creating a local gdrive/ folder.
if not path.exists():
  path.mkdir()
  for o in instrumentos:
    dest = (path/o)
    dest.mkdir(exist_ok=True)
    results = search_images_ddg(f'{o} instrumento', max_images=60)
    for cont, url in enumerate(results):
      try:
        download_url(url, str(dest/f'{cont}.jpg'), timeout=20)
      except Exception as e:
        # Best effort: one dead link must not abort the whole download.
        # Catching Exception (not a bare except) lets a kernel interrupt still stop the loop.
        # Report which URL failed and why.
        print(f'Error: {url} ({e})')

In [None]:
# List every image file under the dataset folder.
# The search is recursive, so all category sub-folders are included.
fns = get_image_files(path, recurse=True)
fns

Eliminamos las imágenes descargadas que estén corruptas (las que no se pueden abrir)

In [None]:
# verify_images returns the files that cannot be opened as images; delete them from disk.
failed = verify_images(fns)
failed.map(Path.unlink)
# Refresh the listing so `fns` no longer contains the paths that were just deleted.
fns = get_image_files(path)
# Show how many files were removed and how many remain.
len(failed), len(fns)

Definimos el `DataBlock`, que indica cómo obtener, etiquetar y dividir las imágenes

In [None]:
# Describe the dataset: images in, category labels out.
# - The label of each image is its parent folder name.
# - 20% of the files (fixed seed 42) are held out for validation.
# NOTE(review): this rebinds `instrumentos`, which earlier held the tuple of category names.
instrumentos = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    item_tfms=Resize(128),
)

Añadimos aumento de datos, creamos los `DataLoaders` y entrenamos el modelo con las imágenes

In [None]:
# Switch to training-time augmentation:
# - item transform: a random crop covering at least half of each image, sized 224.
# - batch transforms: fastai's default aug_transforms().
# NOTE(review): assumes DataBlock.new swaps out the previous Resize(128) rather than
# stacking on top of it — confirm for the installed fastai version.
instrumentos = instrumentos.new(
    batch_tfms=aug_transforms(),
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
)
# Build the train/validation DataLoaders from the images found under `path`.
dls = instrumentos.dataloaders(path)

In [None]:
# Seed the random number generators so head initialisation, shuffling and augmentation
# are repeatable across Restart & Run All.
# NOTE(review): GPU runs may still differ slightly unless set_seed(..., reproducible=True) is used.
set_seed(42)
# Transfer learning with a ResNet-18 backbone.
# Validation error rate is reported after each epoch.
learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(4)

In [None]:
# Interpret predictions on the validation set (ds_idx=1, the default).
# The confusion matrix shows which instruments are mistaken for each other.
interp = ClassificationInterpretation.from_learner(learn, ds_idx=1)
interp.plot_confusion_matrix()

In [None]:
# Show the 5 images with the highest loss, in a single row, to spot mislabelled or odd images.
interp.plot_top_losses(k=5, nrows=1)

In [None]:
# Interactive widget to review images per category/split.
# Files can be marked for deletion or relabelling; the next cell applies those choices.
cleaner = ImageClassifierCleaner(learn)
display(cleaner)

In [None]:
# Apply the choices made in the cleaner widget:
# - delete the files marked for deletion;
# - move relabelled files into their new category folder.
for idx in cleaner.delete(): cleaner.fns[idx].unlink()
for idx,cat in cleaner.change():
    src = cleaner.fns[idx]
    # Every category folder reuses the same names (0.jpg, 1.jpg, ...).
    # Moving src straight into path/cat would raise shutil.Error when that name is
    # already taken. Instead, prefix with the old category and add a counter if it
    # still collides, so no existing file is overwritten.
    dest = path/cat/f'{src.parent.name}_{src.name}'
    n = 1
    while dest.exists():
        dest = path/cat/f'{src.parent.name}_{n}_{src.name}'
        n += 1
    shutil.move(str(src), str(dest))

Exportamos el modelo y lo convertimos en una aplicación interactiva (que después puede publicarse online)

In [None]:
# Export the trained learner and reload it for CPU-only inference.
learn.export()
import PIL
# learn.export() writes export.pkl under learn.path, so load it from there.
# This avoids rebinding the global `path`, which must keep pointing at the dataset
# folder used by the earlier cells.
# NOTE(review): load_learner unpickles the file; only load exports you created yourself.
learn_inf = load_learner(learn.path/'export.pkl', cpu=True)
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()

btn_run = widgets.Button(description='Clasificar')

def _latest_upload():
    """Return the bytes of the most recently uploaded file, or None if nothing was uploaded.

    The two ipywidgets versions expose uploads differently:
    - ipywidgets 7: `FileUpload.data` is a list of bytes.
    - ipywidgets 8: `.data` is gone; `.value` is a tuple of dicts with a 'content' buffer.
    """
    data = getattr(btn_upload, 'data', None)
    if data is not None:
        return data[-1] if data else None
    files = btn_upload.value
    return bytes(files[-1]['content']) if files else None

def on_click_classify(change):
    """Classify the latest uploaded image and show its thumbnail and prediction.

    `change` is the argument ipywidgets passes to on_click callbacks; it is unused.
    If no file has been uploaded yet, a hint is shown instead of raising.
    """
    upload = _latest_upload()
    if upload is None:
        lbl_pred.value = 'Sube una imagen primero'
        return
    img = PILImage.create(upload)
    out_pl.clear_output()
    with out_pl: display(img.to_thumb(128,128))
    pred,pred_idx,probs = learn_inf.predict(img)
    lbl_pred.value = f'Predicción: {pred}; Probabilidad: {probs[pred_idx]:.04f}'

btn_run.on_click(on_click_classify)

VBox([widgets.Label('Selecciona tu instrumento'),
      btn_upload, btn_run, out_pl, lbl_pred])