# Use a pretrained ImageNet model via `scivision`

(Work in Progress)

In this notebook, we will:

1. load a model from the test repo https://github.com/quantumjot/scivision-test-plugin
2. find matching datasets
3. run a simple model inference on an image of a koala

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from scivision.io import load_pretrained_model
from scivision.catalog import catalog

## Load the catalog and query a model

In [None]:
# Query the catalog for a specific model repository
query_dict = {'model': 'https://github.com/quantumjot/scivision-test-plugin'}

queryO = catalog.query(query_dict)
queryO

`query_dict` in the cell above can also take other key/value pairs, e.g.:

```python
query_dict = {'task': 'object-detection'}
```

We can specify more than one key/value pair. The resulting conditions are combined with a logical AND, i.e. joined by `" & ".join(queries)`.
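
To make the AND semantics concrete, here is a minimal sketch of how a multi-key `query_dict` could become one `" & "`-joined query string, plus a pure-Python filter with the same meaning. Both `build_query_string` and `filter_entries` are hypothetical helpers, not `scivision` functions; the pandas-style `key == 'value'` condition format and the toy catalog rows are assumptions made for illustration.

```python
def build_query_string(query_dict):
    """Hypothetical helper: render each key/value pair as one condition and
    AND them together with " & ".

    Assumption: conditions use pandas `DataFrame.query` syntax; scivision's
    own formatting may differ. Parentheses keep each comparison self-contained.
    """
    queries = [f"({key} == {value!r})" for key, value in query_dict.items()]
    return " & ".join(queries)


def filter_entries(entries, query_dict):
    """Pure-Python equivalent of the AND semantics: keep an entry only if
    every key/value pair in query_dict matches it."""
    return [
        entry for entry in entries
        if all(entry.get(key) == value for key, value in query_dict.items())
    ]


# Toy catalog rows (made up for illustration, not real catalog data)
entries = [
    {"task": "object-detection", "name": "a"},
    {"task": "classification", "name": "b"},
    {"task": "object-detection", "name": "c"},
]
multi_query = {"task": "object-detection", "name": "c"}

print(build_query_string(multi_query))
# (task == 'object-detection') & (name == 'c')
print(filter_entries(entries, multi_query))
# [{'task': 'object-detection', 'name': 'c'}]
```

Adding a key can only narrow the result: `{"task": "object-detection"}` alone matches two toy rows, while adding `"name": "c"` leaves one.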

---

The query above searches two catalogs, one of models and one of data sources, stored here:

https://github.com/alan-turing-institute/scivision/tree/main/scivision/catalog/data

You can also access them directly via:

```python
# models
catalog._catalog._models
# data sources
catalog._catalog._datasources
```

In [None]:
# select the second entry (index 1), a baby koala
target = queryO[1]
target

## Load a model

Each model repository contains a `scivision` config file; for an example, see: https://github.com/quantumjot/scivision-test-plugin/blob/main/.scivision-config_imagenet.yaml

There are different ways to specify the path to the config file:

```python
config_path = target['model'] + "/" + ".scivision-config_imagenet.yaml"
```

`scivision` will then construct the correct GitHub path from the repository URL.

Or, as a full GitHub URL to the file:

```python
config_path = "https://github.com/quantumjot/scivision-test-plugin/blob/main/.scivision-config_imagenet.yaml"
```

Or, as a raw file URL:

```python
config_path = "https://raw.githubusercontent.com/quantumjot/scivision-test-plugin/main/.scivision-config_imagenet.yaml"
```
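
As a rough illustration of how the three forms above can all point at the same file, here is a minimal sketch that normalizes them to a `raw.githubusercontent.com` URL. `to_raw_github_url` is a hypothetical helper, not part of `scivision`'s API. It assumes the default branch is `main` whenever the URL does not name one, and `scivision`'s own resolution logic may differ.

```python
from urllib.parse import urlparse


def to_raw_github_url(path, default_branch="main"):
    """Hypothetical sketch: map a GitHub file reference to its raw URL.

    Handles three forms:
      https://github.com/<owner>/<repo>/<file>              (branch assumed)
      https://github.com/<owner>/<repo>/blob/<branch>/<file>
      https://raw.githubusercontent.com/...                 (returned unchanged)
    """
    parsed = urlparse(path)
    if parsed.netloc == "raw.githubusercontent.com":
        return path
    if parsed.netloc != "github.com":
        raise ValueError(f"not a GitHub URL: {path}")

    parts = parsed.path.strip("/").split("/")
    if len(parts) >= 5 and parts[2] == "blob":
        # .../<owner>/<repo>/blob/<branch>/<file path>
        owner, repo, _, branch, *rest = parts
    elif len(parts) >= 3:
        # .../<owner>/<repo>/<file path>; assumption: default branch
        owner, repo, *rest = parts
        branch = default_branch
    else:
        raise ValueError(f"no file path in URL: {path}")
    return f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/" + "/".join(rest)


repo_url = "https://github.com/quantumjot/scivision-test-plugin"
print(to_raw_github_url(repo_url + "/" + ".scivision-config_imagenet.yaml"))
# https://raw.githubusercontent.com/quantumjot/scivision-test-plugin/main/.scivision-config_imagenet.yaml
```

Fetching the raw URL returns the YAML text itself rather than GitHub's HTML page, which is why it makes a convenient common target.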

In [None]:
config_path = "https://github.com/quantumjot/scivision-test-plugin/blob/main/.scivision-config_imagenet.yaml"

In [None]:
model = load_pretrained_model(config_path, 
                              allow_install=True)

In [None]:
# let's explore the model object
model

## Load input image

In [None]:
# libraries
from skimage.io import imread
from skimage.transform import resize
from tensorflow.keras.applications.imagenet_utils import decode_predictions

import matplotlib.pyplot as plt

In [None]:
inputs = target['datasource']
inputs

In [None]:
x = imread(inputs)
plt.imshow(x)

In [None]:
# preprocess: resize to (224, 224); preserve_range=True keeps the original
# 0-255 pixel values instead of rescaling them to floats in [0, 1]
X = resize(x, (224, 224), 
           preserve_range=True, 
           anti_aliasing=True)

## Model predictions

In [None]:
y = model.predict(X)

In [None]:
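def get_imagenet_label(probs):
    # decode_predictions returns, per image, a list of (class_id, class_name, score) tuples;
    # take the single top-1 prediction for the first (only) image
    return decode_predictions(probs, top=1)[0][0]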
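
To show what `get_imagenet_label` unpacks, here is a minimal pure-Python sketch of top-k decoding that returns the same nested structure as Keras' `decode_predictions`: one list per image of `(class_id, class_name, score)` tuples, highest score first. `decode_top_k` is a hypothetical stand-in, and the three toy labels and probabilities are made up; the real function looks up the 1000 ImageNet classes itself.

```python
def decode_top_k(probs_batch, labels, top=5):
    """Hypothetical stand-in mirroring the output shape of decode_predictions.

    probs_batch: list of per-image probability lists, shape (samples, num_classes)
    labels: list of (class_id, class_name) pairs, one per class index
    Returns one list per image of (class_id, class_name, score) tuples,
    sorted from most to least probable.
    """
    results = []
    for probs in probs_batch:
        if len(probs) != len(labels):
            raise ValueError("expected one probability per label")
        # indices of the `top` largest probabilities, highest first
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top]
        results.append([(labels[i][0], labels[i][1], probs[i]) for i in ranked])
    return results


# Toy example with three made-up classes (real ImageNet models have 1000)
toy_labels = [("id_0", "tabby"), ("id_1", "koala"), ("id_2", "wombat")]
toy_probs = [[0.05, 0.80, 0.15]]

_, label, score = decode_top_k(toy_probs, toy_labels, top=1)[0][0]
print("{} : {:.2f}%".format(label, score * 100))
# koala : 80.00%
```

The `[0][0]` indexing picks the first image in the batch, then its top-ranked tuple, which is the same pattern `get_imagenet_label` uses.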

In [None]:
plt.figure()
plt.imshow(x)
_, image_class, class_confidence = get_imagenet_label(y)
plt.title("{} : {:.2f}%".format(image_class, class_confidence * 100))
plt.show()