<a href="https://colab.research.google.com/github/GaryZhous/A-collection-of-interesting-stuff/blob/main/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Classification

Author: [Rishit Dagli](https://twitter.com/rishit_dagli)

In this notebook, we use the ConvMixer models we converted earlier to perform image classification.

## Setup

In [None]:
import tensorflow as tf
import tensorflow_hub as hub

from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import requests
import os

## Choose Model

In [None]:
models = {
    "convmixer_1536_20": "https://tfhub.dev/rishit-dagli/convmixer-1536-20/1",
    "convmixer_768_32": "https://tfhub.dev/rishit-dagli/convmixer-768-32/1",
    "convmixer_1024_20": "https://tfhub.dev/rishit-dagli/convmixer-1024-20/1",
}
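The variant keys follow the ConvMixer naming convention ConvMixer-h/d, where h is the hidden dimension (number of channels) and d is the depth (number of ConvMixer layers). Below is a minimal sketch of a hypothetical helper, `describe_variant`, that resolves a key to its TF Hub handle and parses those two numbers. It assumes every key keeps the `convmixer_<h>_<d>` pattern used in the dictionary above.

```python
import re

# Same mapping as in the notebook cell above.
models = {
    "convmixer_1536_20": "https://tfhub.dev/rishit-dagli/convmixer-1536-20/1",
    "convmixer_768_32": "https://tfhub.dev/rishit-dagli/convmixer-768-32/1",
    "convmixer_1024_20": "https://tfhub.dev/rishit-dagli/convmixer-1024-20/1",
}

# Assumption: every key follows the "convmixer_<hidden_dim>_<depth>" pattern.
_VARIANT_RE = re.compile(r"^convmixer_(\d+)_(\d+)$")


def describe_variant(name):
    """Hypothetical helper: return (hub_url, hidden_dim, depth) for a variant key."""
    if name not in models:
        raise KeyError(f"Unknown variant {name!r}; choose one of {sorted(models)}")
    match = _VARIANT_RE.match(name)
    if match is None:
        raise ValueError(f"Variant {name!r} does not match convmixer_<h>_<d>")
    hidden_dim, depth = (int(g) for g in match.groups())
    return models[name], hidden_dim, depth


# Print a short summary of every available variant.
for name in models:
    url, h, d = describe_variant(name)
    print(f"ConvMixer-{h}/{d}: {url}")
```

Failing early with a `KeyError` that lists the valid keys is friendlier than the bare `KeyError` that `models[model_variant]` would raise later in `hub.load`.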

In [None]:
# fmt: off

#@title Choose Model variant
model_variant = "convmixer_1536_20" #@param ['convmixer_1536_20', 'convmixer_768_32', 'convmixer_1024_20']
resolution = [224, 224] 
num_classes = 1000

# fmt: on

## Image preprocessing utilities (adapted from [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))

In [None]:
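def preprocess_image(image):
    # Convert the PIL image to an array and resize it to the model's input resolution.
    image = np.array(image)
    image_resized = tf.image.resize(image, (resolution[0], resolution[1]))
    # Scale pixel values to [0, 1].
    image_resized = tf.cast(image_resized, tf.float32)
    image_resized = image_resized / 255
    # Standardize each channel with the ImageNet mean and variance (std ** 2).
    image_resized = tf.keras.layers.Normalization(
        mean=(0.485, 0.456, 0.406), variance=(0.052441, 0.050176, 0.050625)
    )(image_resized)
    # Add a batch dimension: (1, height, width, 3).
    return tf.expand_dims(image_resized, 0).numpy()


def load_image_from_url(url):
    response = requests.get(url)
    response.raise_for_status()
    # Force three RGB channels; PNGs may be RGBA and some JPEGs grayscale.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image = preprocess_image(image)
    return image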
    
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

--2021-10-31 03:51:46--  https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.191.128, 173.194.74.128, 173.194.192.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.191.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 21675 (21K) [text/plain]
Saving to: ‘ilsvrc2012_wordnet_lemmas.txt’


2021-10-31 03:51:47 (43.8 MB/s) - ‘ilsvrc2012_wordnet_lemmas.txt’ saved [21675/21675]
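For reference, the `variance` values passed to `tf.keras.layers.Normalization` in `preprocess_image` are the squares of the commonly used ImageNet per-channel standard deviations (0.229, 0.224, 0.225). The layer divides by the square root of the variance, so each channel ends up as (x / 255 − mean) / std. A minimal pure-Python sketch of the same per-pixel arithmetic follows, assuming one RGB pixel given as 0 to 255 values; `normalize_pixel` is a hypothetical name used only for illustration.

```python
import math

# Values copied from preprocess_image.
MEAN = (0.485, 0.456, 0.406)
VARIANCE = (0.052441, 0.050176, 0.050625)

# The variances are the squared ImageNet standard deviations.
STD = tuple(math.sqrt(v) for v in VARIANCE)


def normalize_pixel(rgb):
    """Hypothetical helper: scale one 0-255 RGB pixel to [0, 1], then standardize per channel."""
    return tuple((c / 255 - m) / s for c, m, s in zip(rgb, MEAN, STD))


print([round(s, 3) for s in STD])  # [0.229, 0.224, 0.225]
print([round(x, 3) for x in normalize_pixel((255, 255, 255))])
```

A pixel equal to the channel means maps to zero, and pure white maps to roughly +2.2 to +2.6 standard deviations per channel, which is the input range the pretrained weights expect.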



## Load the Model and Run Inference

In [None]:
model = hub.load(models[model_variant])

In [None]:
# The label file has one line per ImageNet class, in class-index order.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]


def infer_on_image(img_url, expected_label):
    # Download and preprocess the image into a (1, height, width, 3) float32 batch.
    image = load_image_from_url(img_url)
    predictions = model.signatures["serving_default"](tf.constant(image))
    # Class scores for the single image in the batch.
    logits = predictions["output"][0]
    predicted_label = imagenet_int_to_str[int(np.argmax(logits))]
    assert (
        predicted_label == expected_label
    ), f"Expected {expected_label} but was {predicted_label}"
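`infer_on_image` only checks the single arg-max class. When an assertion fails, it can help to inspect the top few classes with their probabilities. The sketch below assumes, as the variable name in `infer_on_image` suggests, that the model returns unnormalized logits, passed in as a plain list of floats (for example `predictions["output"][0].numpy().tolist()`), and that `labels` is the `imagenet_int_to_str` list. `top_k_labels` is a hypothetical helper, and the three toy classes stand in for the 1000 ImageNet classes.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_labels(logits, labels, k=5):
    """Hypothetical helper: return the k most likely (label, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in ranked[:k]]


# Toy example with 3 made-up scores instead of the 1000 ImageNet logits.
toy_labels = ["goldfish, Carassius_auratus", "scorpion", "Siamese_cat, Siamese"]
for label, p in top_k_labels([2.0, 0.5, 1.0], toy_labels, k=2):
    print(f"{p:.3f}  {label}")
```

Subtracting the maximum logit before exponentiating leaves the probabilities unchanged but avoids overflow for large scores.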

Let's try five images:

In [None]:
infer_on_image(img_url="https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/mH0Wrvb.jpg", expected_label="goldfish, Carassius_auratus")
infer_on_image(img_url="https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/A5m4ZG1.jpg", expected_label="scorpion")
infer_on_image(img_url="https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/faOAEFg.jpg", expected_label="leatherback_turtle, leatherback, leathery_turtle, Dermochelys_coriacea")
infer_on_image(img_url="https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/lfhdaSi.jpg", expected_label="Siamese_cat, Siamese")
infer_on_image(img_url="https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/Qwa8wHX.jpg", expected_label="boa_constrictor, Constrictor_constrictor")
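Because each `infer_on_image` call raises `AssertionError` on a mismatch, the cell stops at the first wrong prediction. Below is a minimal sketch of a hypothetical `check_all` helper that runs every (URL, label) pair through a check function and collects all mismatches instead. It assumes the check function accepts `img_url` and `expected_label` keyword arguments and raises `AssertionError` on failure, as `infer_on_image` does. The stand-in `fake_check` and its "tench" mismatch are made up so the sketch runs without TensorFlow.

```python
def check_all(samples, check_fn):
    """Hypothetical helper: call check_fn on every (img_url, expected_label) pair
    and return (img_url, error message) for each pair that failed."""
    failures = []
    for img_url, expected_label in samples:
        try:
            check_fn(img_url=img_url, expected_label=expected_label)
        except AssertionError as err:
            failures.append((img_url, str(err)))
    return failures


# Stand-in check with a made-up failure; in the notebook, pass infer_on_image instead.
def fake_check(img_url, expected_label):
    assert expected_label != "scorpion", f"Expected {expected_label} but was tench"


samples = [
    ("https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/mH0Wrvb.jpg", "goldfish, Carassius_auratus"),
    ("https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/A5m4ZG1.jpg", "scorpion"),
]
print(check_all(samples, fake_check))
```

In the notebook, `check_all(samples, infer_on_image)` would report every misclassified image in one run; an empty list means all predictions matched.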