In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from PIL import Image
from fastai.vision.all import *
from pprint import pprint
from inspect import getmembers, getmembers_static
from sklearn.metrics import f1_score
import json

In [None]:
# Data locations. Set the LEPINET_DATA environment variable to point at another
# copy of the dataset; every other path is derived from the root so they
# cannot drift apart.
root_path = Path(os.environ.get("LEPINET_DATA", "/home/george/codes/lepinet/data/mini"))
parquet_path = root_path / "0013397-241007104925546_processing_metadata_postprocessed.parquet"
images_path = root_path / "images"
export_path = root_path / "models"

In [None]:
# Image metadata, one row per image.
df = pd.read_parquet(parquet_path)

# Taxon-key column -> human-readable name column, listed species -> family.
levels = {
    "speciesKey": "scientificName",
    "genusKey": "genus",
    "familyKey": "family",
}
# Coarse-to-fine order (family, genus, species), used to build label strings.
hierarchy_levels = list(levels)[::-1]

In [None]:
# Load the exported fastai learner (cpu=False keeps it on the GPU when available).
# NOTE(review): load_learner unpickles the export — only load trusted files.
model_path = export_path / "00_lepi_mini_model1"

learn = load_learner(model_path, cpu=False)
# Number of labels the model predicts (all taxa in its multi-label vocab).
len(learn.dls.vocab)

In [None]:
# Taxonomy trees stored next to the data: `hierarchy` is read from the
# training file, `hierarchy_all` from the file covering every taxon
# (presumably including the test_ood species).
hierarchy = json.loads((root_path / "hierarchy_train.json").read_text())
hierarchy_all = json.loads((root_path / "hierarchy_all.json").read_text())

In [None]:
def prepare_df(df, remove_in=(), keep_in=(), label_levels=None):
    """Build the (image_path, hierarchy_labels, is_valid) frame fastai expects.

    Args:
        df: metadata frame with at least ``set``, ``filename``, ``speciesKey``
            and the columns in ``label_levels``. It is never modified.
        remove_in: values of ``set`` whose rows are dropped (e.g. ``["test_ood"]``).
        keep_in: if non-empty, only rows whose ``set`` is listed are kept.
        label_levels: taxon-key columns joined, in this order, into the
            space-separated ``hierarchy_labels``. Defaults to the notebook-level
            ``hierarchy_levels`` (family, genus, species).

    Returns:
        A new DataFrame with columns ``image_path`` (``Path`` of the form
        ``<speciesKey>/<filename>``, relative to the images directory),
        ``hierarchy_labels`` and ``is_valid`` (True for rows of set ``'0'``).
    """
    if label_levels is None:
        label_levels = hierarchy_levels

    if len(remove_in) > 0:
        df = df[~df['set'].isin(remove_in)]
    if len(keep_in) > 0:
        df = df[df['set'].isin(keep_in)]
    # Work on an explicit copy. This avoids SettingWithCopyWarning on filtered
    # slices, and avoids mutating the caller's frame when no filter is applied.
    df = df.copy()

    # List comprehensions instead of row-wise `apply`: faster, and they also
    # work on an empty frame (where `apply(axis=1)` returns a DataFrame).
    df['image_path'] = [Path(str(key)) / fname
                        for key, fname in zip(df['speciesKey'], df['filename'])]
    # Split '0' is used as the validation set.
    df['is_valid'] = df['set'] == '0'
    df['hierarchy_labels'] = [
        ' '.join(map(str, row))
        for row in df[list(label_levels)].itertuples(index=False, name=None)
    ]
    # Keep only the columns needed by ImageDataLoaders.
    return df[['image_path', 'hierarchy_labels', 'is_valid']]

# Derive every split from the metadata already loaded into `df` instead of
# re-reading the parquet file four times. Each call gets its own copy, so
# `df` itself is never modified.
df_val = prepare_df(df.copy(), keep_in=["0"])            # validation split ('0') only
df_train = prepare_df(df.copy(), remove_in=["test_ood"])  # every row except test_ood
df_ood = prepare_df(df.copy(), keep_in=["test_ood"])      # out-of-distribution test rows
df_all = prepare_df(df.copy())                            # full dataset

## Result analysis

In [None]:
# Preview the validation split (image_path, hierarchy_labels, is_valid).
df_val.head()

In [None]:
# Sanity check: run inference on the first validation image.
# The result is unpacked into (decoded classes, one-hot vector, probabilities).
pred=learn.predict(images_path/df_val["image_path"].iloc[0])
pred_classes, pred_one_hot, pred_proba = pred
pred_classes, type(pred_classes)

In [None]:
# Ground-truth labels of the same image, ordered family, genus, species
# (the order of `hierarchy_levels`).
targs=df_val["hierarchy_labels"].iloc[0].split(" ")
targs

In [None]:
# Macro F1 for this single image, computed over the whole vocab.
# `pred_classes` and `targs` are label *strings*, which a fastai metric cannot
# score. Compare multi-hot vectors aligned on `learn.dls.vocab` instead.
# NOTE: with a single image most classes are absent in both vectors and score
# 0 under zero_division=0, so the macro average is low by construction.
targ_multi_hot = np.array([[label in targs for label in learn.dls.vocab]])
pred_multi_hot = pred_one_hot.cpu().numpy().astype(bool)[None]
f1_score(targ_multi_hot, pred_multi_hot, average='macro', zero_division=0)

In [None]:
# Rebuild DataLoaders from `df_train` (validation rows flagged by `is_valid`)
# so that batches can be inspected below. Labels are space-delimited, i.e. a
# multi-label setup.
# NOTE(review): presumably mirrors the training configuration — confirm.
dls = ImageDataLoaders.from_df(
    df_train,
    images_path,
    valid_col='is_valid',
    label_delim=' ',
    item_tfms=Resize(460),
    batch_tfms=aug_transforms(size=224))

In [None]:
# Exploration: list the attributes of the DataLoaders object.
getmembers_static(dls)

In [None]:
# Exploration: attributes of the first DataLoader (presumably the training one).
getmembers_static(dls.loaders[0])

In [None]:
# Exploration: concrete type of the validation DataLoader.
type(dls.valid)

In [None]:
# Check that `one_batch()` on the validation loader returns the same batch on
# every call (no shuffling or random augmentation). The sums are printed for
# a quick look, and the inputs are compared exactly.
batches = [dls.valid.one_batch() for _ in range(2)]
for b in batches:
    pprint(b[0][0].sum())
assert torch.equal(batches[0][0], batches[1][0]), "dls.valid.one_batch() is not deterministic"
batch = batches[-1]

In [None]:
# Number of elements in the batch tuple (inputs and targets, unpacked below).
len(batch)

In [None]:
# Shapes of the input images and of the targets (presumably multi-hot over the vocab).
batch_in, batch_targs = batch
batch_in.shape, batch_targs.shape

In [None]:
# Loss and metrics of the loaded model on the validation loader built above.
learn.validate(dl=dls.valid)


**TODO** — compute:
* macro F1 on the validation set
* macro F1 on the test set

In [None]:
# Sanity check of sklearn's multilabel macro F1 on identical inputs.
# NOTE(review): label 0 has neither true nor predicted positives, so its F1
# depends on `zero_division` (the default 'warn' yields 0.0). That can pull
# the macro average below 1.0 — confirm with the installed sklearn version.
f1_score(np.array([[0,1]]), np.array([[0,1]]), average='macro')

### Evaluation on Out-of-distribution species

The evaluation function must handle two different vocabularies: one for the
test set and one for the training set.

In [None]:
def define_vocab(df):
    """Return the sorted, de-duplicated taxon labels found in ``df``.

    Args:
        df: DataFrame with a ``hierarchy_labels`` column of space-separated
            label strings.

    Returns:
        list[str]: unique labels in ascending lexicographic order.
    """
    # A set comprehension over the column replaces the slow `iterrows` loop
    # and the redundant second sort (np.unique already returned sorted values).
    return sorted({label for labels in df["hierarchy_labels"] for label in labels.split()})

# The labels of the training split must reproduce the model's vocab exactly
# (same taxa, same order); otherwise output columns would map to the wrong taxa.
test_eq(define_vocab(df_train), learn.dls.vocab)

In [None]:
# Vocabulary size per split: validation, train, OOD test, full dataset.
len(define_vocab(df_val)), len(define_vocab(df_train)), len(define_vocab(df_ood)), len(define_vocab(df_all))

In [None]:
# Predict the first image of the full dataset.
learn.predict(images_path/df_all["image_path"].iloc[0])

In [None]:
# Preview the out-of-distribution test split.
df_ood.head()

In [None]:
# Batch inference on every OOD test image.
# `filenames_ood` holds full image paths; `targs_ood` holds the matching
# space-separated ground-truth labels, in the same row order.
filenames_ood = df_ood["image_path"].apply(lambda x: str(images_path/x)).tolist()
targs_ood = df_ood["hierarchy_labels"]

test_dl=learn.dls.test_dl(filenames_ood)

# The second element returned by get_preds (targets) is discarded; the test
# loader was built from filenames only.
preds, _ = learn.get_preds(dl=test_dl)
# Sizes, plus the number of classes scoring above 0.5 for each image.
len(preds), len(filenames_ood), (preds>.5).sum(1), len(targs_ood), preds.shape

In [None]:
# True if the model parameters live on a CUDA device.
next(learn.model.parameters()).is_cuda

Display some predictions alongside their ground truth.

In [None]:
# For every OOD image where at least one class clears the 0.5 threshold,
# print its index, the predicted classes and the ground-truth label string.
for idx, probs in enumerate(preds):
    predicted = learn.dls.vocab[probs > .5]
    if not len(predicted):
        continue
    print(idx, predicted, targs_ood.iloc[idx])

Display sample images for a given taxon (family, genus or species).

In [None]:
def select_taxon_imgs(
        taxon_id: str,
        images_dir: str,
        hierarchy: dict,
        n: int = 5,
        image_formats: tuple = (
            '.jpg', '.jpeg', '.png', '.gif', '.tiff', '.tif', '.webp'),
        seed=None,
        ):
    """
    Select up to ``n`` random image files that belong to a taxon.

    If ``taxon_id`` is a species, images are taken from its own folder. If it
    is a higher-level taxon (genus, family, ...), images are pooled from the
    folders of every species below it in ``hierarchy``.

    Args:
        taxon_id: taxon key as it appears in ``hierarchy``.
        images_dir: directory containing one sub-folder per species.
        hierarchy: nested dict of taxa whose innermost values are lists of
            species keys.
        n: maximum number of images to return.
        image_formats: accepted file extensions (lower-case, including the dot).
        seed: optional int seed or ``random.Random`` instance for a
            reproducible selection. ``None`` uses the global ``random`` state.

    Returns:
        list[str]: paths of the selected images. Empty if the taxon is not in
        ``hierarchy`` or has no matching files on disk.
    """
    import random  # explicit, rather than relying on the fastai star import

    species_folders = set()

    def find_species(node, path=()):
        """Recursively collect all species belonging to ``taxon_id``."""
        for key, subnode in node.items():
            if isinstance(subnode, list):  # penultimate level: list of species
                if taxon_id == key or taxon_id in path:
                    species_folders.update(subnode)
                elif taxon_id in subnode:  # the taxon is itself a species
                    species_folders.add(taxon_id)
            elif isinstance(subnode, dict):  # higher level, keep descending
                find_species(subnode, path + (key,))

    find_species(hierarchy)

    # Iterate species and files in sorted order: set iteration and os.listdir
    # order are arbitrary, which would make a seeded selection irreproducible.
    image_paths = []
    for species in sorted(species_folders, key=str):
        species_path = os.path.join(images_dir, str(species))
        if os.path.isdir(species_path):
            image_paths.extend(
                os.path.join(species_path, f)
                for f in sorted(os.listdir(species_path))
                if f.lower().endswith(image_formats)
            )

    if isinstance(seed, random.Random):
        rng = seed
    elif seed is None:
        rng = random
    else:
        rng = random.Random(seed)
    return rng.sample(image_paths, min(n, len(image_paths)))

def show_files(filenames, suptitle=None):
    """Display image files in a grid, each titled with its parent folder name.

    In this dataset's layout (``<images>/<speciesKey>/<file>``) the parent
    folder name is the species key.

    Args:
        filenames: iterable of image paths.
        suptitle: optional title for the whole figure.
    """
    filenames = list(filenames)
    if not filenames:
        # Nothing to draw (e.g. unknown taxon): report it instead of handing
        # an empty list to the plotting helper.
        print("No images to show." if suptitle is None else f"No images to show for {suptitle}.")
        return
    ims = []
    for f in filenames:
        # Image.open is lazy and keeps the file open. Copy the pixels into
        # memory and close the handle immediately.
        with Image.open(f) as im:
            ims.append(im.copy())
    titles = [Path(f).parent.name for f in filenames]
    show_images(ims, titles=titles, suptitle=suptitle)

def show_taxon(taxon_id):
    """Plot a random sample of images for ``taxon_id``.

    Images come from ``images_path``. The taxon is resolved with
    ``hierarchy_all`` (loaded from hierarchy_all.json) rather than the
    training-only ``hierarchy``.
    """
    sample = select_taxon_imgs(taxon_id=taxon_id, images_dir=images_path, hierarchy=hierarchy_all)
    show_files(sample, suptitle=taxon_id)

In [None]:


# One example taxon per level, resolved with the training hierarchy.
example_taxa = {"Family": "7015", "Genus": "1768691", "Species": "1768749"}
example_images = {
    level: select_taxon_imgs(taxon_id=taxon_id, images_dir=images_path, hierarchy=hierarchy)
    for level, taxon_id in example_taxa.items()
}
family_images, genus_images, species_images = example_images.values()

for level, filenames in example_images.items():
    show_files(filenames, suptitle=level)

In [None]:
# Draw a fresh random sample for the example species.
species_images = select_taxon_imgs("1768749", images_path, hierarchy)

Divide the model outputs into hierarchy levels

In [None]:
def gen_level_idx(vocab, hierarchy):
    """
    Return the hierarchical level of every taxon in ``vocab``.

    - The deepest level present in ``vocab`` (species) is 0, the next one up
      (genus) is 1, then family is 2, etc.
    - Taxa absent from ``hierarchy`` are marked -1, and a warning is printed.

    Args:
    - vocab (iterable): taxa names/keys to find levels for.
    - hierarchy (dict): nested dictionary representing the taxonomy; the
      innermost values are lists of species.

    Returns:
    - np.ndarray: int array of level indices, one per entry of ``vocab``.

    NOTE: the inversion assumes every branch of ``hierarchy`` has the same
    depth; taxa in shallower branches would otherwise get shifted levels.
    """
    level_lookup = {}

    def traverse(node, depth=0):
        """Record the depth (root keys = 0) of every taxon below ``node``."""
        for key, subnode in node.items():
            level_lookup[key] = depth
            if isinstance(subnode, dict):
                traverse(subnode, depth + 1)
            elif isinstance(subnode, list):  # leaf list: species
                for species in subnode:
                    level_lookup[species] = depth + 1

    # Depth is counted from the root (the top-level keys are at 0).
    traverse(hierarchy)

    indices = np.array([level_lookup.get(v, -1) for v in vocab], dtype=int)

    # Invert so the deepest level (species) becomes 0. This is guarded because
    # `max()` raises on an empty array (empty vocab).
    if indices.size:
        indices = np.where(indices < 0, indices, indices.max() - indices)

    missing_count = int(np.sum(indices == -1))
    if missing_count > 0:
        print(f"[Warning] Missing values in taxa dictionary: {missing_count}.")

    return indices

def split_preds(preds:torch.Tensor, indices:np.ndarray):
    """Split a batch of predictions into one array per hierarchy level.

    Args:
        preds: 2-D batch of predictions whose columns follow the vocab order.
        indices: level of each vocab column as returned by ``gen_level_idx``
            (0 = species, 1 = genus, ...). Columns marked -1 are dropped.

    Returns:
        list[np.ndarray]: element ``i`` holds the columns of level ``i``
        (vocab order preserved). Empty if ``indices`` has no non-negative level.
    """
    # Move to host memory once instead of once per level.
    preds_np = preds.detach().cpu().numpy()
    indices = np.asarray(indices)
    n_levels = int(indices.max()) + 1 if indices.size else 0
    return [preds_np[:, indices == i] for i in range(n_levels)]

def get_pred_conf(preds:torch.Tensor, vocab:"CategoryMap", indices:np.ndarray):
    """Return the top label and its confidence for each prediction and level.

    Args:
        preds: 2-D batch of predictions whose columns follow ``vocab`` order.
        vocab: model vocabulary (a fastai ``CategoryMap`` or any sequence of
            labels). It is converted to a NumPy array, so plain lists work too.
        indices: level of each vocab entry as returned by ``gen_level_idx``
            (0 = species, 1 = genus, ...). Entries marked -1 are ignored.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(labels, confidences)``, each of shape
        ``(batch, n_levels)``. Column ``i`` holds the arg-max label of level
        ``i`` and its score.

    Raises:
        ValueError: if ``vocab`` and ``indices`` have different lengths.
    """
    vocab_arr = np.asarray(list(vocab))
    indices = np.asarray(indices)
    if len(vocab_arr) != len(indices):
        raise ValueError(f"vocab has {len(vocab_arr)} entries but indices has {len(indices)}")
    # Move to host memory once instead of once per level.
    preds_np = preds.detach().cpu().numpy()
    n_levels = int(indices.max()) + 1 if indices.size else 0

    out_preds, out_confs = [], []
    for level in range(n_levels):
        mask = indices == level
        level_preds = preds_np[:, mask]
        best = level_preds.argmax(axis=1)
        out_preds.append(vocab_arr[mask][best])
        out_confs.append(level_preds.max(axis=1))

    if not out_preds:
        # No known level: return empty (batch, 0) arrays instead of failing.
        return (np.empty((len(preds_np), 0), dtype=vocab_arr.dtype),
                np.empty((len(preds_np), 0), dtype=preds_np.dtype))
    return np.stack(out_preds, axis=1), np.stack(out_confs, axis=1)

def get_lbls(df:pd.DataFrame, setname:str, levels:dict):
    """Return the ground-truth taxon keys of the rows whose ``set`` is ``setname``.

    Columns follow the key order of ``levels``; rows keep their order in ``df``.
    """
    columns = list(levels)
    rows = df.loc[df['set'] == setname, columns]
    return rows.to_numpy()

def save_npy(
    fname:str,
    prds:np.ndarray,
    cnfs:np.ndarray,
    lbls:np.ndarray=None):
    """Save predictions, confidences and (optionally) labels as one ``.npy`` array.

    The ``(n_images, n_levels)`` inputs are stacked on a new last axis, giving
    shape ``(n_images, n_levels, 2)`` or ``(n_images, n_levels, 3)``, where
    ``[..., 0]`` = prediction, ``[..., 1]`` = confidence and ``[..., 2]`` = label.

    Every component is explicitly converted to ``str`` first. Stacking a string
    array with a numeric one relies on str/number dtype promotion, which
    recent NumPy versions refuse. A string array also loads back without
    ``allow_pickle``. Use ``.astype(float)`` on the confidence slice after loading.
    """
    parts = [np.asarray(prds), np.asarray(cnfs)]
    if lbls is not None:
        parts.append(np.asarray(lbls))
    output = np.stack([part.astype(str) for part in parts], axis=-1)
    np.save(fname, output)

def save_csv(
    fname:str,
    filenames:list,
    prds:np.ndarray,
    cnfs:np.ndarray,
    lbls:np.ndarray=None):
    """Write predictions in long format: one CSV row per (image, level).

    Args:
        fname: output CSV path.
        filenames: one name per image (list, array or pandas Series; a
            Series' index is ignored).
        prds: ``(n_images, n_levels)`` predicted labels.
        cnfs: ``(n_images, n_levels)`` confidences.
        lbls: optional ``(n_images, n_levels)`` ground-truth labels.

    The CSV has columns ``filename``, ``level`` (column index in ``prds``),
    ``prediction``, ``confidence`` and, if given, ``label``. The default
    0..n_images*n_levels-1 index is written as the first, unnamed column.

    Raises:
        ValueError: if the inputs disagree on the number of images or levels.
    """
    prds = np.asarray(prds)
    cnfs = np.asarray(cnfs)
    # Drop any pandas index. np.repeat on a Series returns a Series with a
    # repeated index, which would become the CSV index.
    filenames = np.asarray(filenames)

    n, p = prds.shape
    if cnfs.shape != (n, p):
        raise ValueError(f"cnfs shape {cnfs.shape} != prds shape {(n, p)}")
    if len(filenames) != n:
        raise ValueError(f"{len(filenames)} filenames for {n} predictions")

    # Row-major flattening: all levels of image 0, then all levels of image 1, ...
    out = pd.DataFrame({
        'filename': np.repeat(filenames, p),
        'level': np.tile(np.arange(p), n),
        'prediction': prds.flatten(order='C'),
        'confidence': cnfs.flatten(order='C'),
    })

    if lbls is not None:
        lbls = np.asarray(lbls)
        if lbls.shape != (n, p):
            raise ValueError(f"lbls shape {lbls.shape} != prds shape {(n, p)}")
        out["label"] = lbls.flatten(order='C')

    out.to_csv(fname)



# Level index of every vocab entry (0 = species, 1 = genus, ...; -1 = not in
# the training hierarchy), then the OOD predictions split per level.
indices = gen_level_idx(learn.dls.vocab, hierarchy)
print(len(indices))
one_pred = split_preds(preds, indices)
for level_preds in one_pred:
    print(level_preds.shape)

# Top prediction and confidence per level, and ground truth with columns in
# `levels` order (species, genus, family).
prds, cnfs = get_pred_conf(preds, learn.dls.vocab, indices)
lbls = get_lbls(df, setname='test_ood', levels=levels)
print(prds.shape, lbls.shape)
print(prds[:10])
print(lbls[:10])

# `preds` were computed from `df_ood`, i.e. the test_ood rows of the metadata
# in their original order, so they must line up row-for-row with these rows.
ood_rows = df[df['set'] == 'test_ood']
assert len(ood_rows) == len(prds) == len(lbls), "predictions and labels are misaligned"

save_csv(
    root_path/'pred_ood_model1.csv',
    # The inner quotes must differ from the outer ones: reusing " inside an
    # f-string is a SyntaxError before Python 3.12.
    filenames=ood_rows.apply(lambda r: f"{r['speciesKey']}/{r['filename']}", axis=1),
    prds=prds,
    cnfs=cnfs,
    lbls=lbls
    )

In [None]:
# Scratch check: stacking two (4, 3) arrays on a new last axis gives (4, 3, 2),
# the same layout that save_npy writes.
np.stack((np.ones((4,3)),np.ones((4,3))), axis=-1).shape

In [None]:
# Compare one random OOD image with samples of its predicted and true taxa.
taxon_level = 0  # level as in gen_level_idx: 0 = species, 1 = genus, ...
# Position of the deepest (species) entry in a `hierarchy_labels` string, which
# is ordered like `hierarchy_levels` (family -> genus -> species). Derived
# instead of hard-coded so it follows `levels`.
nbof_levels = len(hierarchy_levels) - 1
file_idx = np.random.randint(len(targs_ood))

print("Input image.")
# Context manager so that the image file handle is released after display.
with Image.open(filenames_ood[file_idx]) as input_im:
    show_image(input_im)

print("Prediction class.")
# Arg-max label among the vocab entries of `taxon_level` for this image.
pred_taxon = learn.dls.vocab[indices==taxon_level][one_pred[taxon_level][file_idx].argmax()]
print(pred_taxon)
show_taxon(pred_taxon)

print("Ground truth class.")
# Labels are coarse-to-fine, so level k sits at position nbof_levels - k.
show_taxon(targs_ood.iloc[file_idx].split(' ')[nbof_levels-taxon_level])



In [None]:
# Show samples of the ground-truth taxon (at `taxon_level`) of the image above.
to_show = targs_ood.iloc[file_idx].split(' ')[nbof_levels - taxon_level]
show_taxon(to_show)