# End-to-end `scivision` pipeline for a pretrained/prebuilt model for plankton classification

This notebook demonstrates how to use `scivision` to load a pretrained `ResNet50` model that predicts plankton species from images captured by the Plankton Imager.
Further details of the plankton classification challenge can be found in [the original model repository](https://github.com/alan-turing-institute/plankton-dsg-challenge).

The code in this notebook lets `scivision` fetch input data from:
TODO (the current version uses a toy xarray dataset)

and load the pretrained `ResNet50` plankton model from:
https://github.com/alan-turing-institute/plankton-cefas-scivision

## Install libraries

```
!pip install scivision
```
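Optionally, you can confirm that the installation succeeded before going further. The sketch below uses only the standard library (`importlib.metadata`, available from Python 3.8); the helper name `installed_version` is hypothetical and not part of `scivision`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str = "scivision") -> Optional[str]:
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        # The distribution is not installed in the current environment
        return None


# Prints e.g. "scivision <version>" or "scivision is not installed"
print("scivision", installed_version() or "is not installed")
```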

## Load libraries

```python
from scivision.io import load_pretrained_model
import numpy as np
import xarray as xr
```

## Model

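```python
# Load the model from its scivision config file; allow_install=True allows
# scivision to install the model's package if it is not already available
scivision_yml = 'https://github.com/alan-turing-institute/plankton-cefas-scivision/.scivision-config.yaml'
model = load_pretrained_model(scivision_yml, allow_install=True)
```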

```python
# Display the model object to inspect its configuration
model
```

## Data

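```python
# Build a toy plankton xarray Dataset: N copies of one random RGB image
# (placeholder data until the real input source is wired in)

N = 15

# Random 900 x 800 RGB image; randint's upper bound is exclusive, so 256 covers 0-255
image = np.random.randint(256, size=(900, 800, 3), dtype=np.uint8)

# Stack N copies into an array of shape (N, y, x, channel)
images = np.array([image] * N)

ds = xr.DataArray(images, dims=['concat_dim', 'y', 'x', 'channel'],
                  coords={'concat_dim': np.arange(images.shape[0]),
                          'y': np.arange(images.shape[1]),
                          'x': np.arange(images.shape[2]),
                          'channel': np.arange(images.shape[3])})

# Wrap the DataArray in a Dataset under the variable name 'raster'
ds = ds.to_dataset(name='raster')

# Add scalar metadata variables (random values, as in the original toy data)
ds = ds.assign(
    image_width=np.random.randint(500, 600),
    image_length=np.random.randint(500, 600)
)
```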
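Before calling the model, it can help to check that the toy dataset has the layout you expect. The following is a minimal sketch with a hypothetical helper, `raster_to_batch`, that pulls the `raster` variable out as a single NumPy array ordered `(concat_dim, y, x, channel)`. It is not part of the `scivision` API and assumes the variable and dimension names used in the data cell above.

```python
import numpy as np
import xarray as xr


def raster_to_batch(ds: xr.Dataset, var: str = "raster") -> np.ndarray:
    """Return the image stack as an (N, y, x, channel) uint8 NumPy array.

    Hypothetical helper for sanity-checking the toy dataset.
    """
    # Enforce a fixed dimension order regardless of how the dataset was built
    da = ds[var].transpose("concat_dim", "y", "x", "channel")
    batch = da.values
    if batch.dtype != np.uint8:
        raise TypeError(f"expected uint8 pixels, got {batch.dtype}")
    if batch.shape[-1] != 3:
        raise ValueError(f"expected 3 colour channels, got {batch.shape[-1]}")
    return batch
```

With the toy dataset built above, `raster_to_batch(ds).shape` is `(15, 900, 800, 3)`.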

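Output of the `model` cell in the Model section, showing the loaded `PretrainedModel` and its data pipe (input `X: numpy.ndarray`, output `image: numpy.ndarray`):

```
scivision.PretrainedModel(
  module='scivision_plankton_models',
  model='resnet50_label3',
  source='https://github.com/acocac/scivision-plankton-models'
  pipe='DataPipe(input=<Parameter "X: numpy.ndarray">, output=<Parameter "image: numpy.ndarray">)'
)
```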

## Prediction and visualisation

```python
# Run the model on the toy dataset
y = model.predict(ds)
```
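The model's data pipe lists a `numpy.ndarray` as its input. If you want one prediction per image rather than a single call on the whole dataset, the sketch below loops over `concat_dim`. This is an assumption, not documented `scivision` behaviour: the hypothetical `predict_per_image` helper assumes `model.predict` accepts one `(y, x, channel)` array at a time.

```python
import xarray as xr


def predict_per_image(model, ds: xr.Dataset, var: str = "raster",
                      dim: str = "concat_dim") -> list:
    """Call `model.predict` once per image along `dim`.

    Hypothetical helper: assumes `model.predict` accepts a single
    (y, x, channel) NumPy array, as the model's DataPipe input suggests.
    """
    results = []
    for i in range(ds.sizes[dim]):
        # Select the i-th image and convert it to a plain NumPy array
        single = ds[var].isel({dim: i}).values
        results.append(model.predict(single))
    return results
```

For example, `preds = predict_per_image(model, ds)` would return a list with one result per toy image, provided the assumption about `model.predict` holds.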

