In [None]:
import json
from tqdm import tqdm
import traceback
import numpy as np
import pandas as pd
import polars as pl
from PIL import Image
from pathlib import Path
from torchvision import transforms
from huggingface_hub import hf_hub_download
from open_clip import create_model, get_tokenizer

import torch
import torch.nn.functional as F

# Taxonomic rank names, ordered from broadest to most specific. They are used as
# column names when filtering the metadata table.
RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
# Title-cased copies of the rank names, used in the header of the output CSV.
ranks = list(map(str.title, RANKS))

# Identifiers of the model, tokenizer and dataset repository.
MODEL_STR = "hf-hub:imageomics/bioclip-2"
TOKENIZER_STR = "ViT-L-14"
HF_DATA_STR = "imageomics/TreeOfLife-200M"
# Number of top-scoring classes kept for each prediction.
NB_CLASS_TO_KEEP = 5

# Root folder of the acquisition session to process.
SESSION_PATH = Path("/media/bioeos/F/202505_plancha_session/20250519_REU-ST-LEU_ASV-1_01")

In [None]:
class BioClip2:
    """Open-domain taxon classifier built on the model named by ``MODEL_STR``.

    The instance holds the image preprocessing pipeline, the model, precomputed
    species text embeddings with their names, and a metadata table used to look
    up identifiers for a predicted taxon.
    """

    def __init__(self):
        """Build the preprocessing pipeline, pick the device and load all resources."""
        # Resize to 224x224 and normalise each channel.
        # NOTE(review): the mean/std values look like the usual CLIP statistics -- confirm.
        self.preprocess_img = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224), antialias=True),
                transforms.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.init_model()

    def init_model(self):
        """Create the model and load the text embeddings, names and metadata table.

        Sets ``self.model``, ``self.tokenizer``, ``self.txt_emb``,
        ``self.txt_names`` and ``self.metadata_df``.
        """
        self.model = create_model(MODEL_STR, output_dict=True, require_pretrained=True)
        self.model = self.model.to(self.device)
        # The model is only used for inference here, so disable training-only behaviour.
        self.model.eval()

        self.tokenizer = get_tokenizer(TOKENIZER_STR)

        # Precomputed text embeddings; predict() right-multiplies image features by them.
        txt_emb_path = hf_hub_download(
            repo_id=HF_DATA_STR,
            filename="embeddings/txt_emb_species.npy",
            repo_type="dataset",
        )
        self.txt_emb = torch.from_numpy(np.load(txt_emb_path)).to(self.device)

        # Names aligned with the embeddings; each entry unpacks to (taxon, common).
        txt_names_path = hf_hub_download(
            repo_id=HF_DATA_STR,
            filename="embeddings/txt_emb_species.json",
            repo_type="dataset",
        )
        # Explicit encoding so the names do not depend on the platform default.
        with open(txt_names_path, encoding="utf-8") as fd:
            self.txt_names = json.load(fd)

        # Parquet is a binary format: hand the path to polars instead of a
        # text-mode file handle.
        metadata_path = hf_hub_download(
            repo_id="imageomics/bioclip-2-demo",
            filename="components/metadata.parquet",
            repo_type="space",
        )
        self.metadata_df = pl.read_parquet(metadata_path, low_memory=False)
        self.metadata_df = self.metadata_df.with_columns(pl.col(["eol_page_id", "gbif_id"]).cast(pl.Int64))

    def format_name(self, taxon, common):
        """Return the space-joined taxon, followed by ``(common)`` when it is non-empty.

        Args:
            taxon: Iterable of rank name strings.
            common: Common name; any falsy value is treated as missing.
        """
        taxon = " ".join(taxon)
        if not common:
            return taxon
        return f"{taxon} ({common})"

    @torch.no_grad()
    def predict(self, img):
        """Classify one image and return its ``NB_CLASS_TO_KEEP`` best labels.

        Args:
            img: Image accepted by ``self.preprocess_img``. PIL images that are
                not in RGB mode are converted first, because the normalisation
                step is defined for three channels.

        Returns:
            dict mapping a formatted label (see ``format_name``) to its softmax
            probability, held as a 0-dim CPU tensor.
        """
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")

        img = self.preprocess_img(img).to(self.device)
        img_features = self.model.encode_image(img.unsqueeze(0))
        img_features = F.normalize(img_features, dim=-1)

        # presumably txt_emb is (embed_dim, n_labels), giving one logit per label -- TODO confirm
        logits = (self.model.logit_scale.exp() * img_features @ self.txt_emb).squeeze()
        probs = F.softmax(logits, dim=0).to("cpu")
        topk = probs.topk(NB_CLASS_TO_KEEP)

        return {
            self.format_name(*self.txt_names[idx]): prob
            for idx, prob in zip(topk.indices.tolist(), topk.values)
        }

    def get_sample_data(self, pred_taxon, rank = 6):
        """Look up one metadata row matching a predicted taxon.

        Args:
            pred_taxon: Label whose first ``rank + 1`` space-separated tokens are
                the rank values, in ``RANKS`` order (anything after is ignored).
            rank: Index in ``RANKS`` of the deepest rank to match (default 6).

        Returns:
            Tuple ``(file_path, gbif_id, eol_id, full_name, is_exact)``.
            ``is_exact`` is True when the sampled row has every rank below
            ``rank`` null or empty. When nothing matches, the tuple is
            ``(None, nan, nan, "", False)``.
        """
        not_found = (None, np.nan, np.nan, "", False)

        taxon_parts = pred_taxon.split(" ")
        if len(taxon_parts) < rank + 1:
            # Not enough rank values in the label to run the lookup.
            return not_found

        # Filter a local frame: reassigning self.metadata_df would permanently
        # shrink the table and make every later lookup for another taxon fail.
        matches = self.metadata_df
        for idx in range(rank + 1):
            matches = matches.filter(pl.col(RANKS[idx]) == taxon_parts[idx])

        if matches.shape[0] == 0:
            return not_found

        # Prefer rows whose lower ranks are all null or empty.
        exact_df = matches
        for lower_rank in RANKS[rank + 1:]:
            exact_df = exact_df.filter((pl.col(lower_rank).is_null()) | (pl.col(lower_rank) == ""))

        is_exact = exact_df.shape[0] > 0
        # Otherwise fall back to any row that matches down to `rank`.
        df_filtered = (exact_df if is_exact else matches).sample()

        full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0))
        if not is_exact:
            # Lower ranks may be null in the sampled row; join them as empty strings.
            lower_values = df_filtered.select(RANKS[rank+1:]).row(0)
            full_name = full_name + " " + " ".join(value or "" for value in lower_values)

        # NOTE(review): init_model casts a "gbif_id" column while "gbif_taxon_id"
        # is read here -- confirm both columns exist in the metadata table.
        return (
            df_filtered["file_path"][0],
            df_filtered["gbif_taxon_id"].cast(pl.String)[0],
            df_filtered["eol_page_id"].cast(pl.String)[0],
            full_name,
            is_exact,
        )


# Build the classifier once for the whole notebook: the constructor calls
# init_model(), which creates the model and fetches the embedding, name and
# metadata files.
bioclip = BioClip2()

In [None]:
# Check the session layout before running any inference.
if not SESSION_PATH.exists():
    raise FileNotFoundError("Session not found")

metadata_csv_path = Path(SESSION_PATH, "METADATA", "metadata.csv")
if not metadata_csv_path.exists():
    raise FileNotFoundError("Metadata csv not found")

frame_path = Path(SESSION_PATH, "PROCESSED_DATA/FRAMES")
ia_path = Path(SESSION_PATH, "PROCESSED_DATA/IA")
ia_path.mkdir(parents=True, exist_ok=True)

# The context manager closes the CSV even when the loop below fails part-way.
with open(Path(ia_path, "bioclip2.csv"), "w", encoding="utf-8") as csv_connector_classes:
    csv_connector_classes.write(f"FileName,{','.join(ranks)},score,gbif_id,eol_id\n")

    try:
        # Only regular files are frames; sub-directories cannot be opened as images.
        for img_path in tqdm(sorted(p for p in frame_path.iterdir() if p.is_file())):
            # Close the image file promptly; convert() returns an in-memory RGB
            # copy, which matches the three-channel normalisation in predict().
            with Image.open(img_path) as img_file:
                img_input = img_file.convert("RGB")
            open_domain_output = bioclip.predict(img_input)

            # Keep only the best-scoring label for this frame.
            key_with_max_value = max(open_domain_output, key=lambda k: open_domain_output[k])
            _, gbif_id, eol_id, _, _ = bioclip.get_sample_data(key_with_max_value)

            # Drop the " (common name)" suffix and turn the space-separated rank
            # values into one CSV column each.
            taxon_columns = key_with_max_value.split(' (')[0].replace(' ', ',')
            score = open_domain_output[key_with_max_value]
            csv_connector_classes.write(f"{img_path.name},{taxon_columns},{score},{gbif_id},{eol_id}\n")

    except Exception:
        # Report the failure and keep the rows written so far. KeyboardInterrupt
        # and SystemExit are deliberately not caught.
        print(traceback.format_exc(), end="\n\n")
    
