<a href="https://colab.research.google.com/github/Rishit-dagli/Transformer-in-Transformer/blob/main/example/pre_trained_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TNT pre-trained model

This notebook shows how to use the pre-trained TNT model for image classification.
The model is a TensorFlow implementation of the
[Transformer in Transformer](https://arxiv.org/abs/2103.00112) paper by Han et al.
**Transformer in Transformer** (TNT) pairs pixel-level attention inside local
patches with patch-level attention across the whole image.
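
To make the two levels concrete, here is a minimal NumPy sketch of the splitting step only, not of the attention itself. The helper name `split_into_patches` is hypothetical, and the 16-pixel patch size and 4-pixel sub-patch size are illustrative assumptions rather than values read from the pre-trained model.

```python
import numpy as np


def split_into_patches(image, patch_size):
    """Split an (H, W, C) array into non-overlapping square patches.

    Returns an array of shape (num_patches, patch_size, patch_size, C),
    with patches ordered row by row.
    """
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "size must divide evenly"
    grid = image.reshape(h // patch_size, patch_size, w // patch_size, patch_size, c)
    return grid.transpose(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, c)


# Assumed sizes for illustration: a 224x224 RGB image, 16x16 patches, 4x4 sub-patches.
image = np.zeros((224, 224, 3), dtype=np.float32)

# Outer level: the patches that patch-level attention would operate over.
patches = split_into_patches(image, patch_size=16)

# Inner level: each patch is split again; pixel-level attention would operate
# over these smaller pieces within a single patch.
sub_patches = np.stack([split_into_patches(p, patch_size=4) for p in patches])

print(patches.shape)      # (196, 16, 16, 3)
print(sub_patches.shape)  # (196, 16, 4, 4, 3)
```

The sketch only shows how one image yields two nested sequences; the actual model embeds both and applies a transformer block at each level.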

If you find this useful, please consider giving a ⭐ to [the repo](https://github.com/Rishit-dagli/Transformer-in-Transformer).

In [2]:
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

--2022-01-17 06:16:40--  https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.210.128, 173.194.213.128, 173.194.215.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.210.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 21675 (21K) [text/plain]
Saving to: ‘ilsvrc2012_wordnet_lemmas.txt’


2022-01-17 06:16:40 (113 MB/s) - ‘ilsvrc2012_wordnet_lemmas.txt’ saved [21675/21675]



In [5]:
from io import BytesIO

import numpy as np
import requests
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image

model = hub.load("https://tfhub.dev/rishit-dagli/tnt-s/1")

resolution = [224, 224]


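def preprocess_image(image):
    # Convert the PIL image to an array and resize it to the model's input size.
    image = np.array(image)
    image_resized = tf.image.resize(image, (resolution[0], resolution[1]))
    image_resized = tf.cast(image_resized, tf.float32)
    # Scale pixel values from [0, 255] to [-1, 1].
    image_resized = (image_resized - 127.5) / 127.5
    # Apply a further per-channel normalization with mean 0.5 and std 0.5 (variance 0.25).
    image_resized = tf.keras.layers.Normalization(
        mean=(0.5, 0.5, 0.5), variance=(0.25, 0.25, 0.25)
    )(image_resized)
    # Add a leading batch dimension: (1, height, width, 3).
    return tf.expand_dims(image_resized, 0).numpy()


def load_image_from_url(url):
    response = requests.get(url)
    response.raise_for_status()  # fail early if the download did not succeed
    # Force three channels so grayscale or RGBA images match the expected input shape.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image = preprocess_image(image)
    return image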


with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]


def infer_on_image(img_url, expected_label, model):
    # Download and preprocess the image, then run it through the serving signature.
    image = load_image_from_url(img_url)
    predictions = model.signatures["serving_default"](tf.constant(image))
    logits = predictions["output"][0]
    # Map the highest-scoring class index to its ImageNet label.
    predicted_label = imagenet_int_to_str[int(np.argmax(logits))]
    assert (
        predicted_label == expected_label
    ), f"Expected {expected_label} but was {predicted_label}"


infer_on_image(
    img_url="https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/gW4Gh5v.jpg",
    expected_label="tench, Tinca_tinca",
    model=model,
)
infer_on_image(
    img_url="https://storage.googleapis.com/rishit-dagli.appspot.com/sample-images/Wv99De3.jpg",
    expected_label="window_screen",
    model=model,
)
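
The cell above only asserts on the single best class. To inspect more of the prediction, a small helper can rank the top few classes. The sketch below is a hypothetical addition: `top_k_labels` is not part of the notebook or the model, and it assumes the model output is unnormalized logits (as the variable name `logits` suggests), so it applies a softmax first. The demo labels and scores are made up so that the block runs without downloading the model.

```python
import numpy as np


def top_k_labels(logits, labels, k=5):
    """Return the k most likely (label, probability) pairs, best first.

    Hypothetical helper: assumes `logits` is a 1-D sequence of unnormalized
    scores and `labels[i]` is the class name for index i.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the maximum before exponentiating.
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Sort indices by descending probability and keep the first k.
    top = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in top]


# Made-up scores for three classes, for demonstration only.
demo_labels = ["tench, Tinca_tinca", "window_screen", "placeholder_class"]
demo_logits = [2.0, 0.5, -1.0]

for label, prob in top_k_labels(demo_logits, demo_labels, k=2):
    print(f"{label}: {prob:.3f}")
```

With the real model, the same helper could be called as `top_k_labels(logits.numpy(), imagenet_int_to_str)` from inside `infer_on_image`, assuming `logits` is the tensor extracted there.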