# Hierarchical labels

## Load the dataset

In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import geopandas
from scivision.io import load_dataset

In [2]:
cat = load_dataset('https://github.com/alan-turing-institute/plankton-dsg-challenge')

# Two catalogue entries are used: the dask-backed dataset and the label table.
ds_all = cat.plankton_multiple().to_dask()
labels = cat.labels().read()

# Keep a single label row per filename, keyed and ordered by filename.
labels_by_filename = (
    labels
    .drop_duplicates(subset=["filename"])
    .set_index("filename")
    .sort_index()
)
labels_dedup = xr.Dataset.from_dataframe(labels_by_filename)

# Index by filename only for the merge: the inner join keeps the entries that
# have a label row, after which the original `concat_dim` dimension is restored.
ds_by_filename = ds_all.swap_dims({"concat_dim": "filename"})
ds_labelled = (
    ds_by_filename
    .merge(labels_dedup, join="inner")
    .swap_dims({"filename": "concat_dim"})
)

`label3` contains the most granular labels in the data. The distinct classes are shown below, each indexed by the `concat_dim` position of its first occurrence.

In [3]:
# First occurrence of every distinct `label3` value, keyed by its `concat_dim` position.
label3_values = ds_labelled["label3"].to_pandas()
label3_values.drop_duplicates()

concat_dim
0                                  appendicularia
16                            annelida_polychaeta
19                                   chaetognatha
20                 copepod_calanoida_candacia-spp
21                             tunicata_doliolida
30                                     mysideacea
51                                     euphausiid
83                                      fish-eggs
87                                bivalvia-larvae
91              copepod_calanoida_centropages-spp
96                                 byrozoa-larvae
98                             euphausiid_nauplii
129                               copepod_nauplii
152                              gastropoda-larva
172                copepod_cyclopoida_oithona-spp
467              copepod_cyclopoida_corycaeus-spp
488                 copepod_cyclopoida_oncaea-spp
684                  copepod_calanoida_temora-spp
724                          echniodermata-larvae
749                                    

While these labels can be treated as entirely separate classes, the values of `label3` form a hierarchical structure, with the levels separated by underscores. One way to split `label3` to expose this structure is shown below.

In [4]:
# These labels contain an underscore that should not act as a level separator;
# it is rewritten as a hyphen before splitting.
label_name_fixes = {
    "nt_phyto_chains": "nt-phyto_chains",
    "euphausiid_nauplii": "euphausiid-nauplii",
    "copepod_nauplii": "copepod-nauplii",
}

# Positional columns produced by the split, mapped to descriptive names.
level_column_names = {
    0: "label3level1",
    1: "label3level2",
    2: "label3level3",
}

label3_normalised = ds_labelled.label3.to_pandas().replace(label_name_fixes)
label3_levels = label3_normalised.str.split("_", expand=True)

# "unknown" and absent levels are both represented as pd.NA.
labels_hierarchical = (
    label3_levels
    .replace("unknown", pd.NA)
    .fillna(pd.NA)
    .rename(columns=level_column_names)
)

In [5]:
# Show each distinct combination of levels once, ordered from the top level down.
level_columns = ["label3level1", "label3level2", "label3level3"]
unique_level_rows = labels_hierarchical.drop_duplicates()
unique_level_rows.sort_values(by=level_columns)

Unnamed: 0_level_0,label3level1,label3level2,label3level3
concat_dim,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
16,annelida,polychaeta,
0,appendicularia,,
87,bivalvia-larvae,,
96,byrozoa-larvae,,
19,chaetognatha,,
990,cirripedia,barnacle-nauplii,
3258,cladocera,evadne-spp,
749,cladocera,,
1764,cnidaria,,
864,copepod,calanoida,acartia-spp


In [6]:
# Add the per-level label columns to the labelled dataset as new variables
# and print its text summary.
ds_with_hierarchy = ds_labelled.assign(labels_hierarchical)
print(ds_with_hierarchy)

<xarray.Dataset>
Dimensions:                               (concat_dim: 52894, y: 832, x: 1040, channel: 3)
Coordinates:
    filename                              (concat_dim) object 'Pia1.2016-08-0...
  * concat_dim                            (concat_dim) int64 0 1 ... 58750 58751
  * y                                     (y) int64 0 1 2 3 ... 828 829 830 831
  * x                                     (x) int64 0 1 2 3 ... 1037 1038 1039
  * channel                               (channel) int64 0 1 2
Data variables: (12/30)
    raster                                (concat_dim, y, x, channel) uint8 dask.array<chunksize=(1, 832, 1040, 3), meta=np.ndarray>
    EXIF Image ImageWidth                 (concat_dim) object dask.array<chunksize=(1,), meta=np.ndarray>
    EXIF Image ImageLength                (concat_dim) object dask.array<chunksize=(1,), meta=np.ndarray>
    EXIF Image BitsPerSample              (concat_dim) object dask.array<chunksize=(1,), meta=np.ndarray>
    EXIF Image Comp