In [2]:
from fastbook import *
from fastai.vision.widgets import *

In [3]:
from PIL import ExifTags
from exif import Image as ExifImage

In [4]:
import os
import pathlib

# Keep a reference to the real PosixPath class so the alias below can be undone
# later with `pathlib.PosixPath = temp`.
temp = pathlib.PosixPath

# Alias PosixPath to WindowsPath so that pickled PosixPath objects can be rebuilt
# on Windows (presumably needed because export.pkl was created on a POSIX system
# -- TODO confirm). The alias is applied on Windows only: WindowsPath cannot be
# instantiated on other systems, so patching there would break path creation
# instead of helping.
if os.name == 'nt':
    pathlib.PosixPath = pathlib.WindowsPath

In [5]:
def exif_type_tfms(fn, cls, **kwargs):
  """Load image `fn` as `cls` and apply the transform for its EXIF orientation.

  Args:
    fn: image location. It is passed both to `cls.create` and to
      `open(fn, 'rb')`, so it must be something `open` accepts (e.g. a path).
    cls: image class; `cls.create(fn, **kwargs)` loads the image and `cls(...)`
      wraps the transformed result.
    **kwargs: forwarded unchanged to `cls.create`.

  Returns:
    `cls(img)`, where `img` is the loaded image after the rotation / mirroring
    selected by the orientation value. Orientation values that are missing,
    unreadable or outside 1..8 leave the image untransformed.
  """
  def get_orientation(fn):
    """Return the EXIF orientation value stored in `fn`, or 1 if it cannot be read."""
    with open(fn, 'rb') as image_file:
      exif_img = ExifImage(image_file)
      try:
        return exif_img.orientation.value
      except Exception:
        # Orientation unset or unreadable: fall back to 1 (no transform).
        # `Exception` instead of a bare `except:` so that KeyboardInterrupt /
        # SystemExit are not swallowed.
        return 1

  def f(img, rotate=0, transpose=False):
    """Rotate `img` by `rotate` degrees (expanding the canvas), then optionally mirror it left-right."""
    img = img.rotate(rotate, expand=True)
    if transpose:
      img = img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

  # Image.rotate will do shortcuts on these magic angles, so no need for any
  # specific resampling strategy
  trafo_fns = {
    1: partial(f, rotate=0),
    6: partial(f, rotate=270),
    8: partial(f, rotate=90),
    3: partial(f, rotate=180),
    2: partial(f, rotate=0, transpose=True),
    5: partial(f, rotate=270, transpose=True),
    7: partial(f, rotate=90, transpose=True),
    4: partial(f, rotate=180, transpose=True),
  }
  img = cls.create(fn, **kwargs)
  orientation = get_orientation(fn)
  # Files with an out-of-range orientation value (anything other than 1..8) are
  # treated like orientation 1 instead of raising KeyError.
  transform = trafo_fns.get(orientation, trafo_fns[1])
  img = transform(img)
  return cls(img)

def ExifImageBlock(cls=PILImage):
  """
  Build a `TransformBlock` that loads images with `exif_type_tfms`, so that an
  EXIF orientation flag stored in the file is respected when the image is loaded.

  The item loader is a `partial` of `exif_type_tfms` bound to `cls`; the batch
  transform is `IntToFloatTensor`.

  The resulting block can be pickled (e.g. `pickle.dumps(ExifImageBlock())`),
  which is important to dump learners.
  """
  load_with_orientation = partial(exif_type_tfms, cls=cls)
  return TransformBlock(batch_tfms=IntToFloatTensor, type_tfms=load_with_orientation)

In [6]:
# Base directory for the exported model: Path() with no arguments, i.e. a path
# relative to the current working directory.
path = Path()
# Load the exported learner from export.pkl, passing cpu=True.
# NOTE(review): load_learner presumably unpickles this file, which can execute
# arbitrary code -- only load an export.pkl from a trusted source.
learn_inf = load_learner(path/'export.pkl', cpu=True)
# Widgets shared by the UI cells: upload button, output area for the image
# preview, and a label for the prediction text.
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()

In [7]:
def on_data_change(change):
    """Observer callback: preview the latest upload and show the model's prediction.

    Args:
        change: change notification passed by the widget's `observe` mechanism;
            it is not inspected, the upload is read from `btn_upload.data`.

    Side effects:
        Clears `lbl_pred`, replaces the content of `out_pl` with a thumbnail of
        the image and writes the predicted label and its probability to
        `lbl_pred`. When there is no uploaded data, the preview and label are
        just cleared.
    """
    lbl_pred.value = ''
    uploads = btn_upload.data
    if not uploads:
        # `data` changed to an empty value: nothing to classify, and indexing
        # [-1] below would raise IndexError. Drop the stale preview and stop.
        out_pl.clear_output()
        return
    # Use the last entry of `data` (presumably the most recently uploaded
    # file's content -- TODO confirm against the installed ipywidgets version).
    img = PILImage.create(uploads[-1])
    out_pl.clear_output()
    with out_pl: display(img.to_thumb(224,224))
    pred,pred_idx,probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}\n Probability: {probs[pred_idx]:.04f}'

In [8]:
# Call on_data_change whenever the upload widget's `data` trait changes.
btn_upload.observe(on_data_change, names=['data'])

In [9]:
# Render the UI as one vertical stack: prompt label, upload button, image
# preview area and prediction label.
display(VBox([widgets.Label('Select rabbit!'), btn_upload, out_pl, lbl_pred]))

VBox(children=(Label(value='Select rabbit!'), FileUpload(value={}, description='Upload'), Output(), Label(valu…