# Evaluation of the ResNet-50 model

## Import libraries

In [3]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

from scivision.io import load_pretrained_model, load_dataset

## Load hold-out (test) dataset

In [6]:
cat = load_dataset('https://github.com/alan-turing-institute/plankton-dsg-challenge')

ds_all = cat.plankton_multiple().to_dask()
labels_holdout = cat.labels_holdout().read()

labels_holdout_dedup = xr.Dataset.from_dataframe(
    labels_holdout
    .drop_duplicates(subset=["filename"])
    .set_index("filename")
    .sort_index()
)

ds_holdout_labelled = (
    ds_all
    .swap_dims({"concat_dim": "filename"})
    .merge(labels_holdout_dedup, join="inner")
    .swap_dims({"filename": "concat_dim"})
)

In [7]:
ds_holdout_labelled = ds_holdout_labelled.assign(
    image_width = ds_holdout_labelled['EXIF Image ImageWidth'].to_pandas().apply(lambda x: x.values[0]),
    image_length = ds_holdout_labelled['EXIF Image ImageLength'].to_pandas().apply(lambda x: x.values[0])
)

## Load pretrained model

In [1]:
# run if changes are made in https://github.com/acocac/scivision-plankton-models then restart the kernel
#!pip -q uninstall -y scivision_plankton_models 

In [4]:
# Load model
scivision_yml = 'https://github.com/acocac/scivision-plankton-models/.scivision-config-resnet50_label3.yaml'
model = load_pretrained_model(scivision_yml, allow_install=True)

In [5]:
model

scivision.PretrainedModel( 
  module='scivision_plankton_models', 
  model='resnet50_label3', 
  source='https://github.com/acocac/scivision-plankton-models' 
  pipe='DataPipe(input=<Parameter "X: numpy.ndarray">, output=<Parameter "image: numpy.ndarray">)' 
)

## Dataset and Dataloader

In [31]:
from scivision_plankton_models import PlanktonDataset
from torch.utils.data import DataLoader
import tqdm
import torch
import matplotlib.pyplot as plt

In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [26]:
dataset = PlanktonDataset(ds_holdout_labelled)

batch_size=24

num_iterations = max(1, len(dataset) // batch_size)
print(num_iterations)

236


In [28]:
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

## Predict and visualise

In [29]:
nb_classes = 39

confusion_matrix = torch.zeros(nb_classes, nb_classes)
with torch.no_grad():
    for i, (_, inputs) in enumerate(dataloader):
        inputs = inputs.to(device)
        outputs = model.predict(inputs)
        _, preds = torch.max(outputs, 1)
        break

  batch = torch.tensor(image)


In [30]:
preds

tensor([21, 38, 38,  4, 13, 28, 38,  4, 13, 38, 21, 21, 21, 38, 29, 12, 21, 18,
        18, 28, 21,  4, 21, 21])