# Trying out BioCLIP 2 for open-domain Lepidoptera classification

In [None]:
from open_clip import create_model, get_tokenizer
import polars as pl
import torch
from torchvision import transforms
import numpy as np
from huggingface_hub import hf_hub_download
import json
from pathlib import Path
from PIL import Image
import torch.nn.functional as F
import collections
import heapq
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import asyncio, aiohttp
import pandas as pd
import nest_asyncio
nest_asyncio.apply()

# model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
# tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')

In [None]:
model_str = "hf-hub:imageomics/bioclip-2"
tokenizer_str = "ViT-L-14"
HF_DATA_STR = "imageomics/TreeOfLife-200M"

# Species probabilities at or below min_prob are skipped when summing up to a
# higher rank; k is the number of top predictions reported.
min_prob = 1e-9
k = 5

# Fall back to CPU instead of crashing on machines without a CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image preprocessing fed to model.encode_image: tensor conversion, resize to
# 224x224, then per-channel normalisation (3 channels, so input must be RGB).
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

# Taxonomic levels of the txt_names entries; open_domain_classification's
# `rank` argument indexes into this tuple (len(ranks) - 1 == species).
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")


In [None]:
# Load the pretrained BioCLIP 2 weights from the HF hub and place them on `device`.
model = create_model(model_str, output_dict=True, require_pretrained=True).to(device)

In [None]:
# NOTE(review): torch.compile wraps the module and compiles its forward();
# this notebook only calls model.encode_image / model.logit_scale, which the
# wrapper presumably forwards to the uncompiled module — confirm this compile
# actually has an effect (compiling model.encode_image directly may be needed).
model = torch.compile(model)

In [None]:
# Text tokenizer for the model. Not used elsewhere in this notebook: class
# text embeddings are loaded precomputed (txt_emb) instead of encoded here.
tokenizer = get_tokenizer(tokenizer_str)

In [None]:
# Precomputed species text embeddings from the TreeOfLife-200M dataset repo.
# Used as the right-hand operand of `img_features @ txt_emb`, so its columns
# correspond to the entries of txt_names (topk column indices index txt_names).
txt_emb = torch.from_numpy(np.load(hf_hub_download(
        repo_id=HF_DATA_STR,
        filename="embeddings/txt_emb_species.npy",
        repo_type="dataset",
    )))

In [None]:
# Species names aligned with the columns of txt_emb. Each entry is unpacked as
# (taxon rank names, common name) by format_name. Read as UTF-8 explicitly so
# non-ASCII common names do not depend on the platform's default encoding.
with open(hf_hub_download(
        repo_id=HF_DATA_STR,
        filename="embeddings/txt_emb_species.json",
        repo_type="dataset",
    ), encoding="utf-8") as fd:
    txt_names = json.load(fd)

In [None]:
# def get_sample(df, pred_taxon, rank):
#     '''
#     Function to retrieve a sample image of the predicted taxon and GBIF or EOL page link for more info.
#     Parameters:
#     -----------
#     df : DataFrame
#         DataFrame with all sample images listed and their filepaths (in "file_path" column).
#     pred_taxon : str
#         Predicted taxon of the uploaded image.
#     rank : int
#         Index of rank in RANKS chosen for prediction.
#     Returns:
#     --------
#     img : PIL.Image
#         Sample image of predicted taxon for display.
#     ref_page : str
#         URL to GBIF or EOL page for the taxon (may be a lower rank, e.g., species sample).
#     '''
#     logger.info(f"Getting sample for taxon: {pred_taxon} at rank: {rank}")
#     try:
#         filepath, gbif_taxon_id, eol_page_id, full_name, is_exact = get_sample_data(df, pred_taxon, rank)
#     except Exception as e:
#         logger.error(f"Error retrieving sample data: {e}")
#         return None, f"We encountered the following error trying to retrieve a sample image: {e}."
#     if filepath is None:
#         logger.warning(f"No sample image found for taxon: {pred_taxon}")
#         return None, f"Sorry, our GBIF and EOL images do not include {pred_taxon}."

#     # Get sample image of selected individual
#     try:
#         img_src = s3_client.generate_presigned_url('get_object',
#                                                    Params={'Bucket': 'treeoflife-200m-sample-images',
#                                                            'Key': filepath}
#                                                    )
#         img_resp = requests.get(img_src)
#         img = Image.open(io.BytesIO(img_resp.content))
#         if gbif_taxon_id:
#             gbif_url = GBIF_URL + gbif_taxon_id
#             if eol_page_id:
#                 eol_url = EOL_URL + eol_page_id
#                 if is_exact:
#                     ref_page = f"<p>Check out the <a href={eol_url} target='_blank'>EOL</a> or <a href={gbif_url} target='_blank'>GBIF</a> entry for {pred_taxon} to learn more.</p>"
#                 else:
#                     ref_page = f"<p>Check out an example entry within {pred_taxon} to learn more: {full_name} at <a href={eol_url} target='_blank'>EOL</a> or <a href={gbif_url} target='_blank'>GBIF</a>.</p>"
#             else:
#                 if is_exact:
#                     ref_page = f"<p>Check out the <a href={gbif_url} target='_blank'>GBIF</a> entry for {pred_taxon} to learn more.</p>"
#                 else:
#                     ref_page = f"<p>Check out an example GBIF entry within {pred_taxon} to learn more: <a href={gbif_url} target='_blank'>{full_name}</a>.</p>"
#         else:
#             eol_url = EOL_URL + eol_page_id
#             if is_exact:
#                     ref_page = f"<p>Check out the <a href={eol_url} target='_blank'>EOL</a> entry for {pred_taxon} to learn more.</p>"
#             else:
#                 ref_page = f"<p>Check out an example EOL entry within {pred_taxon} to learn more: <a href={eol_url} target='_blank'>{full_name}</a>.</p>"
#         logger.info(f"Successfully retrieved sample image and page for {pred_taxon}")
#         return img, ref_page
#     except Exception as e:
#         logger.error(f"Error retrieving sample image: {e}")
#         return None, f"We encountered the following error trying to retrieve a sample image: {e}."

def format_name(taxon, common):
    """Return the taxon's names joined by spaces, with ``(common)`` appended when a common name is given."""
    label = " ".join(taxon)
    return f"{label} ({common})" if common else label

@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """
    Predict taxa for one image from the entire tree of life.

    The image is scored against every species text embedding in ``txt_emb``.
    At the species level the top-``k`` species are returned. For a higher rank,
    species probabilities above ``min_prob`` are summed per taxon prefix
    (ranks 0..``rank``) and the top-``k`` prefixes are returned.

    Parameters
    ----------
    img : PIL.Image.Image or array-like
        Image accepted by ``preprocess_img``. PIL images are converted to RGB
        first, because the normalisation step expects exactly 3 channels.
    rank : int
        Index into ``ranks``; ``len(ranks) - 1`` selects the species level.
    return_all : bool
        If True, also return a sample image and a taxon URL. Sample retrieval
        (``get_sample``) is disabled in this notebook, so both are ``None``.

    Returns
    -------
    dict or tuple
        ``{taxon name: probability tensor}`` for the top-``k`` taxa, or
        ``(prediction_dict, None, None)`` when ``return_all`` is True.
    """
    # Grayscale or RGBA images would yield 1 or 4 channels and break Normalize.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    img = preprocess_img(img).to(device)
    img_features = F.normalize(model.encode_image(img.unsqueeze(0)), dim=-1)

    logits = (model.logit_scale.exp() * img_features @ txt_emb.to(device)).squeeze()
    probs = F.softmax(logits, dim=0)

    # get_sample is commented out above; define these so that return_all=True
    # returns placeholders instead of raising NameError.
    sample_img, taxon_url = None, None

    if rank + 1 == len(ranks):
        topk = probs.topk(k)
        prediction_dict = {
            format_name(*txt_names[i]): prob
            for i, prob in zip(topk.indices.tolist(), topk.values)
        }
    else:
        # as_tuple=True always gives a 1-D index tensor. The previous
        # nonzero(...).squeeze() became 0-d (not iterable) when exactly one
        # species passed the threshold.
        keep = torch.nonzero(probs > min_prob, as_tuple=True)[0]
        output = collections.defaultdict(float)
        for i, prob in zip(keep.tolist(), probs[keep]):
            output[" ".join(txt_names[i][0][: rank + 1])] += prob

        topk_names = heapq.nlargest(k, output, key=output.get)
        prediction_dict = {name: output[name] for name in topk_names}

    if return_all:
        return prediction_dict, sample_img, taxon_url
    return prediction_dict

In [None]:
# Single-image smoke test. The second assignment overrides the first, so only
# the second path is actually opened; keep whichever one you want to try last.
img_pth = Path("/home/george/codes/lepinet/data/flemming_ucloud/images/1732063/0e343351-e995-4255-9868-61ef7dc06039.jpg")
img_pth = Path("/home/george/codes/lepinet/data/flemming_ucloud/images/1811896/ad530cad-ed7a-4bf8-9572-1e516d57e6bb.jpg")
img = Image.open(img_pth)

In [None]:
# rank = len(ranks) - 1 selects the species level (top-k species predictions).
pred = open_domain_classification(img, len(ranks)-1)

In [None]:
# Display the top-k {formatted name: probability} predictions.
pred

In [None]:
# Predictions as an ordered list of (name, probability) pairs, best first.
pred_ = list(pred.items())

In [None]:
# "Genus epithet" of the top prediction. The formatted name starts with the
# space-joined rank names (see format_name); assuming 7 single-word ranks,
# tokens 5:7 are the genus and species epithet.
' '.join(pred_[0][0].split(' ')[5:7])

In [None]:
# Root of the evaluation images, one sub-directory per label.
img_dir = Path("/home/george/codes/lepinet/data/flemming_ucloud/images")

In [None]:
# All .jpg files exactly one directory below img_dir; the parent directory
# name is later used as the image's ground-truth label.
img_filenames = list(img_dir.glob('*/*.jpg'))
img_filenames[:10], len(img_filenames)

In [None]:
batch_size = 64
img_size = 224  # spatial size produced by preprocess_img's Resize

# One (family, genus, "genus epithet", top-1 probability) tuple per image, in
# the same order as img_filenames (later cells align on that index).
preds = []

# img_filenames=img_filenames[:10]

# The text embeddings never change: move them to the device once rather than
# once per batch.
txt_emb_dev = txt_emb.to(device)

with torch.no_grad():
    for start in tqdm(range(0, len(img_filenames), batch_size)):
        images = []
        for img_pth in img_filenames[start:start + batch_size]:
            # cv2.imread takes a str path and returns None (no exception) on
            # failure. Fail loudly: skipping would misalign preds with img_filenames.
            img = cv2.imread(str(img_pth))
            if img is None:
                raise FileNotFoundError(f"cv2 could not read image: {img_pth}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            images.append(preprocess_img(img))
        batch = torch.stack(images)

        img_features = F.normalize(model.encode_image(batch.to(device)), dim=-1)
        logits = model.logit_scale.exp() * img_features @ txt_emb_dev
        probs = F.softmax(logits, dim=1)  # per-image distribution over species
        topk = probs.topk(k, dim=1)

        # Top-k {formatted name: probability} per image; the last batch is kept
        # for inspection in the next cell.
        prediction_dict = [
            {format_name(*txt_names[idx]): prob for idx, prob in zip(indices.tolist(), values)}
            for indices, values in zip(topk.indices, topk.values)
        ]
        prediction_list = [list(pred.items()) for pred in prediction_dict]

        # Take the top-1 taxon directly from txt_names. Re-splitting the
        # formatted string breaks if a rank name contains a space, and
        # colliding formatted names in the dict could overwrite the top-1 probability.
        for indices, values in zip(topk.indices.tolist(), topk.values.tolist()):
            taxon = txt_names[indices[0]][0]
            preds.append((taxon[4], taxon[5], " ".join(taxon[5:7]), float(values[0])))
        
        

In [None]:
# Inspect the (name, probability) top-k lists of the last processed batch.
prediction_list

In [None]:
async def get_key(session, scientificName=None, usageKey=None, rank='SPECIES', order='Lepidoptera', family=None, genus=None):
    """
    Look up a GBIF key through the species ``match`` endpoint.

    Parameters
    ----------
    session : aiohttp.ClientSession
        Open session used for the request.
    scientificName : str, optional
        Name to match, e.g. a "Genus epithet" binomial or a genus/family name.
    usageKey : int or str, optional
        GBIF usage key forwarded to the match endpoint.
    rank : str or None
        Rank to match at ('SPECIES', 'GENUS', 'FAMILY', ...). With ``None``,
        the rank is not sent and the species key is returned.
    order, family, genus : str, optional
        Higher taxa sent to GBIF to disambiguate the match.

    Returns
    -------
    int or dict or None
        The ``<rank>Key`` field of the match response when present, otherwise
        the raw response dict (e.g. for an unmatched name). ``None`` for a rank
        with no known key field.
    """
    # Query parameters are passed via ``params`` so they are URL-encoded and
    # '&'-separated. The old string concatenation omitted '&' between
    # order/family/genus and left spaces unencoded.
    params = {}
    if usageKey is not None:
        params["usageKey"] = usageKey
    if scientificName is not None:
        if scientificName == 'Tethea or':
            return 5142971  # hard-coded workaround (original note: "bug fix")
        params["scientificName"] = scientificName
    if rank is not None:
        params["rank"] = rank
    if order is not None:
        params["order"] = order
    if family is not None:
        params["family"] = family
    if genus is not None:
        params["genus"] = genus

    key_fields = {
        'KINGDOM': 'kingdomKey',
        'PHYLUM': 'phylumKey',
        'CLASS': 'classKey',
        'ORDER': 'orderKey',
        'FAMILY': 'familyKey',
        'GENUS': 'genusKey',
        'SPECIES': 'speciesKey',
    }
    field = 'speciesKey' if rank is None else key_fields.get(rank.upper())

    async with session.get("https://api.gbif.org/v1/species/match", params=params) as response:
        r = await response.json()
        if field is None:
            return None
        return r.get(field, r)
            
async def _gather_match_keys(vocab, rank):
    """Run get_key concurrently for every name in ``vocab`` at ``rank``, sharing one HTTP session."""
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(*(get_key(session, scientificName=name, rank=rank) for name in vocab))


async def get_all_keys(vocab):
    """Species-level GBIF keys (rank left unspecified) for each name in ``vocab``, in order."""
    return await _gather_match_keys(vocab, None)


async def get_all_family(vocab):
    """Family keys for each name in ``vocab``, in order."""
    return await _gather_match_keys(vocab, 'FAMILY')


async def get_all_genus(vocab):
    """Genus keys for each name in ``vocab``, in order."""
    return await _gather_match_keys(vocab, 'GENUS')

async def get_parents(session, usageKey):
    """
    Fetch the parent taxa of a GBIF usage key (``/species/{key}/parents``).

    Parameters
    ----------
    session : aiohttp.ClientSession
        Open session used for the request.
    usageKey : int or str
        GBIF usage key; required.

    Returns
    -------
    The decoded JSON response.

    Raises
    ------
    ValueError
        If ``usageKey`` is None. Previously this requested the literal URL
        ``.../species/{}/parents``.
    aiohttp.ClientResponseError
        On a non-2xx HTTP status.
    """
    if usageKey is None:
        raise ValueError("usageKey is required to fetch GBIF parents")
    url = f"https://api.gbif.org/v1/species/{usageKey}/parents"
    async with session.get(url) as response:
        response.raise_for_status()
        return await response.json()

async def get_all_parents(vocab):
    """Fetch GBIF parent chains for every usage key in ``vocab`` concurrently; results keep input order."""
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(*(get_parents(session, usageKey=key) for key in vocab))

In [None]:
# GBIF rank names for the three evaluated levels. Deliberately not named
# `ranks`: rebinding that global would change how open_domain_classification
# (which tests `rank + 1 == len(ranks)`) interprets its `rank` argument.
gbif_ranks = ('FAMILY', 'GENUS', 'SPECIES')

# Split preds' (family, genus, "genus epithet", confidence) tuples into
# parallel lists, preserving image order.
families = [family for family, _, _, _ in preds]
genera = [genus for _, genus, _, _ in preds]
species = [name for _, _, name, _ in preds]
cnfs = [conf for _, _, _, conf in preds]

In [None]:
# Peek at the top-1 species probabilities of the last few images.
cnfs[-5:]

In [None]:
# Resolve the predicted names to GBIF keys, one list per level, each aligned
# with img_filenames.
preds_keys = {
    'species_keys': asyncio.run(get_all_keys(species)),
    'genera_keys': asyncio.run(get_all_genus(genera)),
    'family_keys': asyncio.run(get_all_family(families)),
}

In [None]:
# Ground truth per image: its parent directory name, which later cells use as
# a GBIF usage key (get_parents) and convert with int().
labels = [f.parent.name for f in img_filenames]

In [None]:
# Fetch the GBIF parent chain for every ground-truth key (one request per image).
parents=asyncio.run(get_all_parents(labels))

In [None]:
# Flatten ground truth into consecutive [species, genus, family] key triplets,
# one triplet per image. The genus/family keys come from the last entry of
# each parent chain (assumed to be the direct parent — TODO confirm for GBIF).
labels_all = []
for label_key, parent_chain in zip(labels, parents):
    direct_parent = parent_chain[-1]
    labels_all.extend([int(label_key), direct_parent['genusKey'], direct_parent['familyKey']])

In [None]:
# Sanity check (the original line was truncated, which is a syntax error):
# there should be one species prediction per ground-truth label.
len(preds_keys['species_keys']), len(labels)

In [None]:
# Long-format results: one row per (image, level). level 0/1/2 = species/genus/
# family, matching the order of each image's triplet in labels_all.
level_keys = ('species_keys', 'genera_keys', 'family_keys')
rows = [
    {
        'instance_id': img_idx,
        'filename': img_file,
        'level': level_idx,
        'label': labels_all[img_idx * len(level_keys) + level_idx],
        'prediction': preds_keys[key][img_idx],
        # Only the species-level top-1 probability exists; it is reused for
        # the genus and family rows.
        'confidence': float(cnfs[img_idx]),
        'threshold': 0.0,
    }
    for img_idx, img_file in enumerate(img_filenames)
    for level_idx, key in enumerate(level_keys)
]

df = pd.DataFrame(
    rows,
    columns=['instance_id', 'filename', 'level', 'label', 'prediction', 'confidence', 'threshold'],
)

In [None]:
bioclip_pth = Path("/home/george/codes/lepinet/data/flemming_ucloud/bioclip2/bioclip2.csv")
# to_csv does not create missing parent directories; create them first.
bioclip_pth.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(bioclip_pth, index=False)

In [None]:
# Inspect the last rows of the results table.
df.tail()

In [None]:
# Leftover smoke-test cell; has no effect on the results.
print('hello')