In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from PIL import Image
from fastai.vision.all import *
import json
from collections import defaultdict

In [None]:
# Dataset locations. Everything lives under root_path, so the other paths are
# derived from it and relocating the dataset is a one-line change.
root_path=Path("/home/george/codes/lepinet/data/mini")
parquet_path=root_path/"0013397-241007104925546_processing_metadata_postprocessed.parquet"
images_path=root_path/"images"
export_path=root_path/"models"

In [None]:
# Load the raw metadata table from the parquet file at parquet_path.
df=pd.read_parquet(parquet_path)

In [None]:
def prepare_df(df, remove_in=None, keep_in=None):
    """Build the three-column table consumed by the fastai data loaders.

    Args:
        df: Metadata table with at least the columns 'set', 'filename',
            'familyKey', 'genusKey' and 'speciesKey'. It is never modified.
        remove_in: Optional list of 'set' values whose rows are dropped.
            None or an empty list means "drop nothing".
        keep_in: Optional list of 'set' values to keep (applied after
            remove_in). None or an empty list means "keep everything".

    Returns:
        A new DataFrame, indexed like the filtered input, with the columns
        'image_path' (Path "<speciesKey>/<filename>"), 'hierarchy_labels'
        (family, genus and species keys joined by single spaces) and
        'is_valid' (True where 'set' equals '0').
    """
    # None replaces the original mutable `[]` defaults; both mean "no filter".
    if remove_in is not None and len(remove_in) > 0:
        df = df[~df['set'].isin(remove_in)]
    if keep_in is not None and len(keep_in) > 0:
        df = df[df['set'].isin(keep_in)]

    # Define the hierarchical levels, ordered from coarsest to finest
    hierarchy_levels = ["familyKey", "genusKey", "speciesKey"]

    # Relative image location: one sub-folder per species key.
    image_paths = [
        Path(str(species_key)) / filename
        for species_key, filename in zip(df['speciesKey'], df['filename'])
    ]
    # One space-separated label string per row, e.g. "<family> <genus> <species>".
    hierarchy_labels = [
        ' '.join(map(str, keys))
        for keys in zip(*(df[level] for level in hierarchy_levels))
    ]

    # Assemble the result in a fresh frame instead of adding columns to `df`:
    # this avoids SettingWithCopyWarning on the filtered slice, leaves the
    # caller's frame untouched, and also works when no rows survive the filter.
    prepared = pd.DataFrame(index=df.index)
    prepared['image_path'] = image_paths
    prepared['hierarchy_labels'] = hierarchy_labels
    # Rows whose 'set' value is '0' form the validation split.
    prepared['is_valid'] = df['set'] == '0'
    return prepared

## First model training

In [None]:
# Number of rows in the loaded metadata table.
len(df)

In [None]:
# Take the first row as a sample and list its field names.
row = df.iloc[0]
row.keys()

In [None]:
# Location of the sample row's image: <images_path>/<speciesKey>/<filename>.
# speciesKey goes through str() -- the same conversion prepare_df applies --
# so the join also works if the key column is numeric rather than text.
image_path = images_path / str(row["speciesKey"]) / row["filename"]

In [None]:
# Show the resolved path together with whether a file exists at that location.
image_path, os.path.isfile(image_path)

In [None]:
# Open the sample image with PIL. The cell ends in an assignment, so nothing is
# displayed here.
image = Image.open(image_path)

In [None]:
# Row counts per 'set' selection, in this order:
# ('test_ood' or '0', 'test_ood' only, '0' only).
tuple(
    int(df['set'].isin(selection).sum())
    for selection in (['test_ood', '0'], ['test_ood'], ['0'])
)

In [None]:
# Build the training table from a copy of the raw frame, dropping 'test_ood' rows.
# NOTE: this rebinds `df` to the prepared three-column table, which has no 'set'
# column any more -- re-running this cell without reloading the parquet will fail.
df=prepare_df(df.copy(), remove_in=['test_ood'])

In [None]:
# Preview the prepared table (image_path, hierarchy_labels, is_valid).
df.head()

In [None]:
# Build the DataLoaders directly from the prepared table.
# NOTE(review): images_path is passed positionally and no fn_col/label_col is
# given, so this presumably relies on from_df treating the second argument as
# the base path and defaulting to column 0 for the file name and column 1 for
# the label -- confirm against the fastai signature.
dls = ImageDataLoaders.from_df(
    df,
    images_path,
    # Boolean column selecting the validation rows.
    valid_col='is_valid',
    # hierarchy_labels holds space-joined keys; split them into separate labels.
    label_delim=' ',
    item_tfms=Resize(460),
    batch_tfms=aug_transforms(size=224))

In [None]:
# Visual sanity check: display a sample batch from the data loaders.
dls.show_batch()

In [None]:
# The CSV logger below writes into export_path during training, but the
# notebook only creates that directory later, at export time. Create it up
# front so a fresh run does not depend on the folder already existing.
os.makedirs(export_path, exist_ok=True)

# Multi-label metrics, all thresholded at 0.5: accuracy plus F1 averaged per
# label ('macro') and per sample ('samples').
f1_macro = F1ScoreMulti(thresh=0.5, average='macro')
f1_macro.name = 'F1(macro)'
f1_samples = F1ScoreMulti(thresh=0.5, average='samples')
f1_samples.name = 'F1(samples)'
learn = vision_learner(
    dls, 
    resnet50, 
    metrics=[partial(accuracy_multi, thresh=0.5), f1_macro, f1_samples],
    # Plot the training curves live and log per-epoch results to history1.csv.
    cbs=[ShowGraphCallback(), CSVLogger(export_path/"history1.csv")]
    )

In [None]:
# Run the learning-rate finder and keep its result for inspection below.
res=learn.lr_find()

In [None]:
# The 'valley' value reported by the learning-rate finder.
res.valley

In [None]:
# Fine-tune with the arguments (10, 2e-2).
# NOTE(review): presumably 10 epochs at a base learning rate of 2e-2; the rate
# is hard-coded rather than taken from `res.valley` above -- confirm this is
# intentional.
learn.fine_tune(10, 2e-2)

In [None]:
# Display the learner's results on sample items.
learn.show_results()

In [None]:
# Save the model: make sure the export directory exists, then export the
# trained learner to <export_path>/00_lepi_mini_model1.
os.makedirs(export_path, exist_ok=True)

model_path = export_path / "00_lepi_mini_model1"
learn.export(model_path)

In [None]:
!ls -alh /home/george/codes/lepinet/data/mini/models

In [None]:
# Reload the exported learner from disk; this replaces the in-memory `learn`.
# NOTE(review): exported learners are presumably pickle-based -- only load
# files produced by this notebook or another trusted source.
model_path = export_path / "00_lepi_mini_model1"

learn = load_learner(model_path)

In [None]:
# First ten entries of the reloaded learner's vocab, and the vocab's total size.
learn.dls.vocab[:10], len(learn.dls.vocab)

### I need to specify the vocab of the MultiCategoryBlock

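So I need to drop down from the high-level `ImageDataLoaders.from_df` factory to the mid-level `DataBlock` API, where `MultiCategoryBlock` accepts an explicit `vocab`.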

In [None]:
def define_vocab_v1(df):
    """Return the sorted list of every key found in the three taxonomy columns.

    The unique values of 'familyKey', 'genusKey' and 'speciesKey' are
    concatenated and sorted together. Uniqueness is per column only, so a value
    that occurs in two columns appears twice in the result.
    """
    hierarchy_levels = ["familyKey", "genusKey", "speciesKey"]
    return sorted(
        value
        for column in hierarchy_levels
        for value in pd.unique(df[column]).tolist()
    )

def build_hierarchy(df: pd.DataFrame, hierarchy_levels: list):
    """Build a nested taxonomy tree from the given columns of ``df``.

    Every level except the last becomes a layer of nested dicts. The
    penultimate level maps each key to the list of distinct lowest-level
    values found beneath it, e.g. ``{family: {genus: [species, ...]}}``.

    Args:
        df: Table containing one column per entry of ``hierarchy_levels``.
        hierarchy_levels: Column names ordered from coarsest to finest.

    Returns:
        A plain ``dict``. Dict keys keep the order of first appearance in
        ``df``; each leaf list is sorted so that the result (and any vocab
        derived from it) is reproducible across kernel restarts.

    Raises:
        ValueError: if exactly one level is given for a non-empty table,
            because there would be no parent level to hang the leaves on.
    """
    levels = list(hierarchy_levels)
    if len(levels) < 2:
        if levels and len(df) > 0:
            raise ValueError(
                "build_hierarchy needs at least two hierarchy levels "
                f"(a parent level and a leaf level), got {levels!r}"
            )
        return {}

    hierarchy = {}
    # Only distinct key combinations matter, so collapse the table first
    # instead of walking every row.
    unique_paths = df[levels].drop_duplicates()
    for *branch_keys, leaf_key in unique_paths.itertuples(index=False, name=None):
        node = hierarchy
        for key in branch_keys[:-1]:
            node = node.setdefault(key, {})  # Goes deeper in the hierarchy
        # Penultimate level: collect the leaves in a set to avoid duplicates.
        node.setdefault(branch_keys[-1], set()).add(leaf_key)

    def finalize(node):
        """Recursively replace every leaf set with a sorted list."""
        if isinstance(node, dict):
            return {key: finalize(child) for key, child in node.items()}
        try:
            return sorted(node)
        except TypeError:
            # Leaves of mixed, non-comparable types: fall back to ordering by
            # their string form, which is still deterministic.
            return sorted(node, key=str)

    return finalize(hierarchy)

def save_hierarchy_to_file(hierarchy: dict, filename: str):
    """
    Write ``hierarchy`` to ``filename`` as JSON indented by four spaces,
    overwriting any existing file.
    """
    with open(filename, mode="w") as out_file:
        out_file.write(json.dumps(hierarchy, indent=4))

def flatten_hierarchy_v1(hierarchy: dict):
    """
    Flatten the hierarchy depth-first: each key is immediately followed by
    everything beneath it, and leaf lists contribute their items in order.
    """
    def walk(node):
        if isinstance(node, dict):
            for key, child in node.items():
                yield key
                yield from walk(child)
        elif isinstance(node, list):
            yield from node

    return list(walk(hierarchy))

def flatten_hierarchy(hierarchy: dict):
    """
    Flatten the hierarchy into a sequential list, level by level within each
    subtree: all keys of a dict are emitted first, and only then is each child
    expanded in turn. Leaf lists contribute their items in order.
    """
    collected = []

    def visit(node):
        if isinstance(node, dict):
            # Siblings first, then descend into each of them.
            collected.extend(node.keys())
            for child in node.values():
                visit(child)
        elif isinstance(node, list):
            collected.extend(node)

    visit(hierarchy)
    return collected

def define_vocab(df):
    """Return the label vocabulary: the family/genus/species tree of ``df``, flattened."""
    levels = ["familyKey", "genusKey", "speciesKey"]
    return flatten_hierarchy(build_hierarchy(df, hierarchy_levels=levels))

def get_higher_levels(hierarchy: dict, taxa_id, path=None):
    """
    Return the chain of keys leading from the top of the hierarchy down to
    ``taxa_id`` (inclusive), or None when ``taxa_id`` is not found in any leaf
    list. ``path`` carries the keys already traversed during recursion.
    """
    ancestors = [] if path is None else path

    for key, child in hierarchy.items():
        trail = ancestors + [key]

        if isinstance(child, dict):
            # Not at the leaf lists yet: search this subtree.
            found = get_higher_levels(child, taxa_id, trail)
            if found:
                return found
        elif isinstance(child, list) and taxa_id in child:
            return trail + [taxa_id]

    return None


def invert_hierarchy(hierarchy: dict):
    """
    Build the child -> parent lookup for the hierarchy: every key and every
    leaf item maps to the key directly above it, and top-level keys map to None.
    """
    parent_of = {}

    def walk(node, parent=None):
        for key, child in node.items():
            if not isinstance(child, (list, dict)):
                # Anything else is neither a subtree nor a leaf list: skip it.
                continue
            parent_of[key] = parent
            if isinstance(child, list):
                # Leaf list: each item hangs directly under this key.
                parent_of.update((item, key) for item in child)
            else:
                walk(child, key)

    walk(hierarchy)
    return parent_of

# Reload the raw metadata: `df` was rebound to the prepared three-column table
# earlier in the notebook and no longer carries the taxonomy columns.
df=pd.read_parquet(parquet_path)
# Exclude the 'test_ood' rows before building the training hierarchy.
df_train_val = df[~df['set'].isin(["test_ood"])]
hierarchy=build_hierarchy(df_train_val, hierarchy_levels = ["familyKey", "genusKey", "speciesKey"])
save_hierarchy_to_file(hierarchy, filename=root_path/"hierarchy_train.json")
inverse_hierarchy=invert_hierarchy(hierarchy)
save_hierarchy_to_file(inverse_hierarchy, filename=root_path/"inverse_hierarchy_train.json")
# Flatten the hierarchy that was just built and saved. define_vocab(df_train_val)
# would rebuild the identical tree from scratch before flattening it, so this
# gives the same vocab without the second pass over the table.
vocab=flatten_hierarchy(hierarchy)
vocab[:10]

In [None]:
# Entire hierarchy: reload the raw metadata and build the tree over every row
# (no filtering on 'set'), then save it as hierarchy_all.json under root_path.
df=pd.read_parquet(parquet_path)
hierarchy=build_hierarchy(df, hierarchy_levels = ["familyKey", "genusKey", "speciesKey"])
save_hierarchy_to_file(hierarchy, filename=root_path/"hierarchy_all.json")

In [None]:
# Rebuild the training table from the freshly loaded frame, dropping 'test_ood'.
# NOTE: this rebinds `df` to the prepared table; re-running the cell without
# reloading the parquet first will fail because 'set' is no longer a column.
df = prepare_df(df.copy(), remove_in=["test_ood"])

In [None]:
# Let's redefine the dataloader with the DataBlock API so that the label
# vocabulary can be supplied explicitly instead of being inferred from the data.

datablock = DataBlock(
    # Multi-label target whose classes are fixed by the hierarchy-derived vocab.
    # NOTE(review): vocab holds the raw key values while the labels are str(key);
    # presumably the keys are already strings -- confirm they match.
    blocks=(ImageBlock, MultiCategoryBlock(vocab=vocab)),
    # NOTE(review): no column is named here; presumably ColSplitter defaults to
    # the 'is_valid' column -- confirm.
    splitter=ColSplitter(),
    # Column 0 is image_path (relative), prefixed with images_path.
    get_x=ColReader(0, pref=images_path),
    # Column 1 is hierarchy_labels, split on spaces into individual labels.
    get_y=ColReader(1, label_delim=' '),
    item_tfms=Resize(460),
    batch_tfms=aug_transforms(size=224)
)
dls = datablock.dataloaders(df)


In [None]:
# Visual sanity check of a batch from the DataBlock-based loaders.
dls.show_batch()

In [None]:
# Same metric setup as for the first model: multi-label accuracy and F1
# ('macro' and 'samples' averaging), all thresholded at 0.5.
f1_macro = F1ScoreMulti(thresh=0.5, average='macro')
f1_macro.name = 'F1(macro)'
f1_samples = F1ScoreMulti(thresh=0.5, average='samples')
f1_samples.name = 'F1(samples)'
learn = vision_learner(
    dls, 
    resnet50, 
    metrics=[partial(accuracy_multi, thresh=0.5), f1_macro, f1_samples],
    # This run logs to history3.csv inside export_path.
    cbs=[ShowGraphCallback(), CSVLogger(export_path/"history3.csv")]
    )

In [None]:
# Fine-tune with the same arguments (10, 2e-2) as the first model, so the two
# runs use matching settings.
learn.fine_tune(10, 2e-2,)

In [None]:
# Inspect the callbacks currently attached to the learner.
learn.cbs

In [None]:
# Display the second learner's results on sample items.
learn.show_results()

In [None]:
# Save the model: make sure the export directory exists, then export the
# learner trained with the explicit vocab to <export_path>/00_lepi_mini_model2.
os.makedirs(export_path, exist_ok=True)

model_path = export_path / "00_lepi_mini_model2"
learn.export(model_path)