<a href="https://colab.research.google.com/github/Rishit-dagli/ConvMixer-torch2tf/blob/main/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Classification

Authors: [Rishit Dagli](https://twitter.com/rishit_dagli)

In this notebook, we use the ConvMixer models we converted earlier to classify images.

## Setup

In [1]:
import tensorflow as tf
import tensorflow_hub as hub

from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import requests
import os

## Download Model

In [2]:
models = {
    "convmixer_1536_20": "https://storage.googleapis.com/convmixer-hubmodels.appspot.com/convmixer_1536_20.tar.gz",
    "convmixer_768_32": "https://storage.googleapis.com/convmixer-hubmodels.appspot.com/convmixer_768_32.tar.gz",
    "convmixer_1024_20": "https://storage.googleapis.com/convmixer-hubmodels.appspot.com/convmixer_1024_20.tar.gz",
}

In [3]:
# fmt: off

#@title Choose Model variant
model_variant = "convmixer_1536_20" #@param ['convmixer_1536_20', 'convmixer_768_32', 'convmixer_1024_20']
resolution = [224, 224] 
num_classes = 1000
os.environ['model_url'] = models[model_variant]
os.environ['model_path'] = model_variant + ".tar.gz"
os.environ['model_variant'] = model_variant
os.environ['saved_model_path'] = model_variant

# fmt: on
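For readers running outside Colab, the following sketch shows what the environment variables above resolve to for a given variant. The helper name `variant_config` is hypothetical. It assumes the variant names follow the ConvMixer-h/d convention written as `convmixer_<hidden_dim>_<depth>`, so `convmixer_1536_20` has hidden dimension 1536 and depth 20.

```python
BASE_URL = "https://storage.googleapis.com/convmixer-hubmodels.appspot.com"


def variant_config(model_variant, base_url=BASE_URL):
    """Hypothetical helper: derive the values the notebook stores in
    os.environ, plus the hidden dimension and depth encoded in the name.

    Assumes names of the form 'convmixer_<hidden_dim>_<depth>'."""
    _, dim, depth = model_variant.split("_")
    return {
        "model_url": f"{base_url}/{model_variant}.tar.gz",  # $model_url
        "model_path": f"{model_variant}.tar.gz",            # $model_path
        "saved_model_path": model_variant,                   # $saved_model_path
        "hidden_dim": int(dim),
        "depth": int(depth),
    }
```

For example, `variant_config("convmixer_1536_20")["model_url"]` gives the same URL as `models["convmixer_1536_20"]`.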

In [4]:
!wget $model_url

--2021-10-24 07:24:05--  https://storage.googleapis.com/convmixer-hubmodels.appspot.com/convmixer_1536_20.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.189.128, 108.177.125.128, 142.251.8.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.189.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 192364203 (183M) [application/x-gzip]
Saving to: ‘convmixer_1536_20.tar.gz’


2021-10-24 07:24:08 (65.5 MB/s) - ‘convmixer_1536_20.tar.gz’ saved [192364203/192364203]



## Extract model

In [5]:
!mkdir $model_variant
%cd $model_variant
!tar -xvf ../$model_path
%cd ..

/content/convmixer_1536_20
assets/
saved_model.pb
variables/
variables/variables.data-00000-of-00001
variables/variables.index
/content
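The `!wget`, `!mkdir`, `%cd` and `!tar -xvf` steps above rely on IPython shell magics. A rough pure-Python equivalent using only the standard library might look like this; `download_file` and `extract_model_archive` are hypothetical helper names, not part of the original notebook.

```python
import tarfile
import urllib.request
from pathlib import Path


def download_file(url, filename):
    """Download `url` to `filename` (roughly what `!wget $model_url` does)."""
    urllib.request.urlretrieve(url, filename)
    return Path(filename)


def extract_model_archive(archive_path, dest_dir):
    """Extract a .tar.gz SavedModel archive into `dest_dir`
    (roughly what `mkdir`, `cd` and `tar -xvf` do above).

    Only extract archives from sources you trust: extractall writes
    whatever paths the archive contains."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest)
    # List extracted files and directories relative to dest_dir
    return sorted(str(p.relative_to(dest)) for p in dest.rglob("*"))
```

Calling `extract_model_archive(download_file(models[model_variant], model_variant + ".tar.gz"), model_variant)` should leave `saved_model.pb` and `variables/` inside the `convmixer_1536_20` directory, matching the listing above.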


## Image preprocessing utilities (adapted from [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))

In [6]:
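def preprocess_image(image):
    # PIL image -> HxWx3 array, resized to the model's input resolution
    image = np.array(image)
    image_resized = tf.image.resize(image, (resolution[0], resolution[1]))
    image_resized = tf.cast(image_resized, tf.float32)
    # Scale to [0, 1], then standardize with ImageNet per-channel statistics
    # (variance = std**2 for std = (0.229, 0.224, 0.225))
    image_resized = image_resized / 255
    image_resized = tf.keras.layers.Normalization(
        mean=(0.485, 0.456, 0.406), variance=(0.052441, 0.050176, 0.050625)
    )(image_resized)
    # Add a batch dimension: (1, H, W, 3)
    return tf.expand_dims(image_resized, 0).numpy()


def load_image_from_url(url):
    response = requests.get(url)
    response.raise_for_status()
    # Force 3 channels so grayscale or RGBA images do not break preprocessing
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image = preprocess_image(image)
    return image
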
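The `variance` values passed to `Normalization` are the squares of the commonly used ImageNet per-channel standard deviations (0.229, 0.224, 0.225), and the layer computes `(x - mean) / sqrt(variance)`. The NumPy sketch below reproduces the scaling and standardization arithmetic (without resizing) so the numbers can be checked; `normalize_numpy` is a hypothetical name used only for illustration.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_numpy(image_uint8):
    """Scale an HxWx3 uint8 image to [0, 1] and standardize each channel,
    mirroring the arithmetic in preprocess_image (no resizing, no batch dim)."""
    x = image_uint8.astype(np.float32) / 255.0
    # Broadcasting applies the per-channel mean/std along the last axis
    return (x - IMAGENET_MEAN) / IMAGENET_STD
```

Squaring `IMAGENET_STD` recovers the `variance` tuple used above, which is why the two formulations agree.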
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

--2021-10-24 07:24:11--  https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.8.128, 74.125.23.128, 74.125.203.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.8.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 21675 (21K) [text/plain]
Saving to: ‘ilsvrc2012_wordnet_lemmas.txt’


2021-10-24 07:24:12 (54.5 MB/s) - ‘ilsvrc2012_wordnet_lemmas.txt’ saved [21675/21675]



## Load the model and run inference

In [7]:
model = hub.load(model_variant)

In [8]:
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]

def infer_on_image(img_url, expected_label):
    image = load_image_from_url(img_url)
    predictions = model.signatures["serving_default"](tf.constant(image))
    logits = predictions["output"][0]
    predicted_label = imagenet_int_to_str[int(np.argmax(logits))]
    assert (
        predicted_label == expected_label
    ), f"Expected {expected_label} but was {predicted_label}"
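`infer_on_image` only checks the single most likely class. To inspect the runner-up classes as well, a sketch like the following could be used. `top_k_predictions` is a hypothetical helper, and it assumes `predictions["output"][0]` holds one unnormalized score per ImageNet class. If the model already returns probabilities, applying softmax again changes the values but not the ranking.

```python
import numpy as np


def top_k_predictions(logits, labels, k=5):
    """Return the k most likely (label, probability) pairs.

    Assumes `logits` is a 1-D array of unnormalized scores, one per class."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Indices of the k largest probabilities, highest first
    top = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in top]
```

Inside `infer_on_image`, `top_k_predictions(logits, imagenet_int_to_str, k=5)` would list the five highest-scoring ImageNet labels for the image.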

Let's try five images. Each call raises an `AssertionError` if the top prediction does not match the expected label:

In [9]:
infer_on_image(img_url="https://i.imgur.com/mH0Wrvb.jpg", expected_label="goldfish, Carassius_auratus")
infer_on_image(img_url="https://i.imgur.com/A5m4ZG1.jpg", expected_label="scorpion")
infer_on_image(img_url="https://i.imgur.com/faOAEFg.jpg", expected_label="leatherback_turtle, leatherback, leathery_turtle, Dermochelys_coriacea")
infer_on_image(img_url="https://i.imgur.com/lfhdaSi.jpg", expected_label="Siamese_cat, Siamese")
infer_on_image(img_url="https://i.imgur.com/Qwa8wHX.jpg", expected_label="boa_constrictor, Constrictor_constrictor")