# Use a pretrained ImageNet model via `scivision`

(Work in Progress)

In this notebook, we will:

1. Load a model from the test repository https://github.com/alan-turing-institute/scivision-test-plugin, which was previously added to the scivision catalog under the name "scivision-test-plugin", as per [this guide](https://scivision.readthedocs.io/en/latest/contributing.html#extending-the-scivision-catalog)
2. Use the scivision catalog to find a matching dataset that the model can be run on
3. Run the model on the data, performing simple model inference

In [None]:
model_repo = "https://github.com/alan-turing-institute/scivision-test-plugin"
model_name = "scivision-test-plugin"

Note: The model repository follows the structure specified in [this template](https://scivision.readthedocs.io/en/latest/model_repository_template.html), including a `scivision` [model config file](https://github.com/alan-turing-institute/scivision-test-plugin/blob/main/.scivision/model.yml).

We first import two names from scivision: `default_catalog` is a scivision **catalog** that lets us discover models and datasets, and `load_pretrained_model` provides a convenient way to load and run a model.

In [None]:
from scivision import default_catalog, load_pretrained_model

## Query the scivision catalog

A scivision catalog is a collection of **models** and **datasources**.

For this example, we want to find datasources compatible with "scivision-test-plugin". But first, let's take a look at all of the models in the *default catalog* (the built-in catalog, distributed as part of scivision).

In [None]:
default_catalog.models.to_dataframe()

Next, we identify datasources in the catalog that would be compatible with the model (based on `tasks`, `format` and `labels_provided`/`labels_required`).

In [None]:
compatible_datasources = default_catalog.compatible_datasources(model_name).to_dataframe()
compatible_datasources
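
The catalog's matching logic is not shown in this notebook. As a rough illustration of what "compatible" could mean for the three fields named above, here is a minimal sketch in plain Python. It is not scivision's implementation: the function `is_compatible`, the dictionary layout and all entry values are hypothetical and made up for this example.

```python
# Hypothetical sketch: NOT scivision's implementation, only an illustration
# of matching on tasks, format and labels_provided/labels_required.


def is_compatible(model, datasource):
    """Return True if the datasource could plausibly be fed to the model."""
    # The model and the datasource must share at least one task
    shares_task = bool(set(model["tasks"]) & set(datasource["tasks"]))
    # The data format must be the same
    same_format = model["format"] == datasource["format"]
    # If the model needs labels, the datasource must provide them
    labels_ok = datasource["labels_provided"] or not model["labels_required"]
    return shares_task and same_format and labels_ok


# Made-up entries: names and field values are placeholders, not catalog contents
model = {
    "name": "example-model",
    "tasks": ["classification"],
    "format": "image",
    "labels_required": False,
}
datasources = [
    {"name": "example-a", "tasks": ["classification"], "format": "image", "labels_provided": False},
    {"name": "example-b", "tasks": ["segmentation"], "format": "image", "labels_provided": True},
    {"name": "example-c", "tasks": ["classification"], "format": "video", "labels_provided": True},
]

compatible_names = [d["name"] for d in datasources if is_compatible(model, d)]
print(compatible_names)
```

In this toy setup only the first entry passes all three checks; the other two fail on task and on format respectively.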

Let's use `data-003`, an image of a baby koala.

(TODO: To be replaced by a query-based interface on the catalog.)

In [None]:
target_datasource = compatible_datasources.loc[compatible_datasources['name'] == 'data-003']
target_datasource

## Load a model

In [None]:
model = load_pretrained_model(model_repo, allow_install=True)

In [None]:
# Inspect the loaded model object
model

## Load input image

In [None]:
# Image I/O and resizing
from skimage.io import imread
from skimage.transform import resize
# Maps ImageNet class probabilities to human-readable labels
from tensorflow.keras.applications.imagenet_utils import decode_predictions

# Plotting
import matplotlib.pyplot as plt

In [None]:
inputs = target_datasource['url'].item()
inputs

In [None]:
x = imread(inputs)
plt.imshow(x)

In [None]:
# Preprocess: resize to (224, 224). preserve_range=True keeps the original
# 0-255 value range (otherwise resize would rescale the values to 0-1 floats)
X = resize(x, (224, 224),
           preserve_range=True,
           anti_aliasing=True)

## Model predictions

Now let's run the loaded model on the preprocessed image from the datasource we found in the catalog.

In [None]:
y = model.predict(X)

In [None]:
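def get_imagenet_label(probs):
    # decode_predictions returns, per sample, a list of
    # (class_id, class_name, score) tuples; take the top-1 tuple of the first sample
    return decode_predictions(probs, top=1)[0][0]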
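
For each sample in the batch, `decode_predictions` returns a list of `(class_id, class_name, score)` tuples, which is why the helper above indexes `[0][0]` to get the best guess for the first sample. The sketch below mimics that structure in plain Python so the unpacking in the plotting cell is easier to follow. It is a simplified stand-in, not the Keras implementation: `top1_label`, `HYPOTHETICAL_LABELS` and the probabilities are made up for illustration.

```python
# Simplified stand-in for top-1 decoding; the label table is a made-up
# placeholder, not the real ImageNet class index.
HYPOTHETICAL_LABELS = [
    ("id_000", "cat"),
    ("id_001", "koala"),
    ("id_002", "teapot"),
]


def top1_label(probs, labels=HYPOTHETICAL_LABELS):
    """Return (class_id, class_name, score) for the first sample in a batch."""
    first_sample = probs[0]
    # Index of the highest probability in the first sample
    best_index = max(range(len(first_sample)), key=lambda i: first_sample[i])
    class_id, class_name = labels[best_index]
    return class_id, class_name, first_sample[best_index]


# A fake batch with a single sample and three class probabilities
fake_probs = [[0.05, 0.90, 0.05]]
_, image_class, class_confidence = top1_label(fake_probs)
title = "{} : {:.2f}%".format(image_class, class_confidence * 100)
print(title)
```

The returned tuple is unpacked the same way as in the plotting cell: the class ID is discarded, and the class name and score feed the figure title.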

In [None]:
plt.figure()
plt.imshow(x)
_, image_class, class_confidence = get_imagenet_label(y)
plt.title("{} : {:.2f}%".format(image_class, class_confidence * 100))
plt.show()