# Setup

In [2]:
import torch
print(torch.cuda.is_available())  # Should return True if CUDA is correctly set up

True


In [1]:
!pip install datasets transformers huggingface_hub supervision timm  sentence_transformers open_clip_torch
! pip install git+https://github.com/deepglint/unicom.git

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting supervision
  Downloading supervision-0.25.0-py3-none-any.whl.metadata (14 kB)
Collecting open_clip_torch
  Downloading open_clip_torch-2.29.0-py3-none-any.whl.metadata (31 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting ftfy (from open_clip_torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32

# Gather Images

In [None]:
!git lfs clone https://huggingface.co/datasets/sammarfy/VLM4Bio VLM4BIO_data

In [None]:
!mkdir downloaded_images

!mv VLM4BIO_data/datasets/Bird/chunk_0/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_1/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_2/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_3/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_4/* downloaded_images

!mv VLM4BIO_data/datasets/Butterfly/chunk_0/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_1/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_2/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_3/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_4/* downloaded_images

!mv VLM4BIO_data/datasets/Fish/chunk_0/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_1/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_2/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_3/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_4/* downloaded_images

In [6]:
from scipy.io import loadmat, savemat
import pandas as pd

def load_df(mat_filename):
    """
    Load dataframe for UNICOM clustering from .mat file

    Parameters
    ----------
    mat_filename : str
        DESCRIPTION.

    Returns
    -------
    df : Dataframe
        Dataframe of the df file including image_name, scientific_name, category, caption, image/text_embeddings fields.
    """

    data = loadmat(f"{mat_filename}.mat")
    img_embeddings = data['image_embeddings']
    text_embeddings = data['text_embeddings']

    print(data['caption'].shape)
    # Create a DataFrame with each list as a column
    df = pd.DataFrame({
        'image_name': data['image_name'],
        'scientific_name': data['scientific_name'],
        'category': data['category'],
        'caption': data['caption'].squeeze(0),
        'image_embeddings': [sub_array for sub_array in img_embeddings],
        'text_embeddings': [sub_array for sub_array in text_embeddings]
    })

    return df
embeds_df = load_df('ViT-H-14-embeddings')

(1, 31452)


In [9]:
embeds_df

Unnamed: 0,image_name,scientific_name,category,caption,image_embeddings,text_embeddings
0,UWZM-F-0001570.JPG ...,Lepomis macrochirus,Fish,"[This image depicts a Lepomis macrochirus, a s...","[[0.0062551806, 0.016257662, -0.051555164, -0....","[[-0.0073634945, -0.0022322312, -0.0049071913,..."
1,UWZM-F-0001664.JPG ...,Lepomis megalotis,Fish,"[This image depicts a Lepomis megalotis, showc...","[[-0.010808857, 0.010653244, -0.039921395, 0.0...","[[-0.008894377, -0.00045271497, -0.0051247277,..."
2,UWZM-F-0001696.JPG ...,Lepomis microlophus,Fish,"[This image depicts a Lepomis microlophus, a s...","[[-0.008302152, 0.016121196, -0.037911564, 0.0...","[[-0.011098293, -0.002307442, -0.0057375436, 0..."
3,UWZM-F-0001697.JPG ...,Lepomis punctatus,Fish,"[This image depicts a Lepomis punctatus, a sma...","[[0.005701786, 0.018549936, -0.05157007, 0.001...","[[-0.009999439, -0.0022197915, -0.005710659, 0..."
4,UWZM-F-0000002.JPG ...,Alosa aestivalis,Fish,[The image shows a specimen of Alosa aestivali...,"[[-0.008683015, 0.0030286494, -0.040459294, 0....","[[-0.0014464473, -0.0048591304, -0.012673609, ..."
...,...,...,...,...,...,...
31447,Butterfly_imbalanced_test_Eueides_isabella_114...,Eueides isabella,Butterfly,[The image shows a close-up of Eueides isabell...,"[[0.035732385, -0.020132823, -0.017447814, 0.0...","[[0.050824434, -0.003763202, -0.0049715494, 0...."
31448,Butterfly_imbalanced_test_Eueides_isabella_980...,Eueides isabella,Butterfly,[This image shows a close-up of Eueides isabel...,"[[0.044416793, -0.034998138, -0.016503036, 0.0...","[[0.04885181, -0.0017174394, -0.0060338983, 0...."
31449,Butterfly_imbalanced_test_Rhetus_periander_307...,Rhetus periander,Butterfly,[The image shows a close-up view of a Rhetus p...,"[[0.052031703, -0.01843505, -0.0047156126, 0.0...","[[0.035970498, -0.0088407565, 0.0036507202, 0...."
31450,Butterfly_imbalanced_test_Rhetus_periander_371...,Rhetus periander,Butterfly,[The image shows a close-up of a Rhetus perian...,"[[0.031045783, -0.020515068, 0.0017269664, 0.0...","[[0.03335584, -0.00914594, 0.0046747504, 0.013..."


# Split in to train/testing

Training dataset will be used for prototype generation.

Testing dataset will be used to evaluate.

In [None]:
from sklearn.model_selection import train_test_split

# for this test I'm doing same stratified test_train_split, but I'm making sure count is at LEAST 10 images/class
# this will be edited in future iterations, just pushing for now

def get_test_train_split(df_cleaned, test_size=0.1):

    species_counts = df_cleaned['scientific_name'].value_counts()
    valid_species = species_counts[species_counts > 10].index
    df_filtered = df_cleaned[df_cleaned['scientific_name'].isin(valid_species)]
    calc_min = len(valid_species)/len(df_filtered)
    min_split = max(test_size, calc_min)

    train_df, test_df = train_test_split(df_filtered, test_size=min_split, stratify=df_filtered['scientific_name'])
    return train_df, test_df


In [12]:
# update the test_train_split, stratify + include minimum within train class

train_split, test_split = get_test_train_split(embeds_df)

# Create prototypes for each class

In [48]:
# unique_classes = pd.unique(train_split['scientific_name'])

# # create dictionaries to have idx for each unique class
# idx_to_class = {idx: cls for idx, cls in enumerate(unique_classes)}
# class_to_idx = {cls: idx for idx, cls in enumerate(unique_classes)}

train_split["embedding_tensor"] = train_split["image_embeddings"].apply(torch.tensor)
class_prototypes = (
    train_split.groupby("scientific_name")["embedding_tensor"]
    .apply(lambda x: torch.stack(list(x.squeeze(0))).mean(dim=0))
    .to_dict()
)

# create dictionaries to have idx for each unique class
prototype_labels = list(class_prototypes.keys())
prototype_to_idx = {cls: i for i, cls in enumerate(prototype_labels)}

# prototype_to_idx will be used to translate test set from class_name to class_idx

# Generate embeddings for images in test set

In [49]:
# this is a test dataset, assuming the embeddings exist on the df

class EmbeddingsDataset(torch.utils.data.Dataset):
    def __init__(self, df, cls_to_idx, transform=None):
        self.df = df
        self.transform = transform
        self.cls_to_idx = cls_to_idx

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Load image and apply transformations
        row = self.df.iloc[idx]
        image_embeddings = row['image_embeddings']
        cls = row['scientific_name']
        # image embeddings are size [1, embed_dim] --> squeeze first dim
        return image_embeddings.squeeze(0), self.cls_to_idx[cls]

dataset = EmbeddingsDataset(test_split, prototype_to_idx)

test_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)


# Calculate Accuracy

In [50]:
import torch.nn.functional as F

# remember the embeddings are stored as [1, embed_dim], so need to squeeze

def compute_accuracy(test_loader, class_prototypes, prototype_to_idx, topk=(1, 3, 5)):

    # assume the prototype embeddings exist in a stack, with the index of the prototype being the same as its class idx
    prototype_embeddings = torch.stack(list(class_prototypes.values())).squeeze(1)  # Shape: [num_classes, embedding_dim]

    correct = {k: 0 for k in topk}
    total = 0

    with torch.no_grad():
        for embeddings, labels in test_loader:

            # Compute similarity between test embeddings and class prototypes

            similarities = F.cosine_similarity(embeddings.unsqueeze(1), prototype_embeddings.unsqueeze(0), dim=-1) #  Shape [batch_size, num_classes]

            # Get top-k predictions
            _, predictions = similarities.topk(max(topk), dim=-1)  # Shape: [batch_size, max(topk)]

            # set predictions for each top k (1, 3, 5)
            for k in topk:
                correct[k] += (predictions[:, :k] == labels.unsqueeze(1)).any(dim=1).sum().item()

            total += labels.size(0)

    # Compute accuracy for each k
    accuracies = {k: correct[k] / total for k in topk}
    return accuracies

# Run evaluation
topk_accuracies = compute_accuracy(test_loader, class_prototypes, prototype_to_idx, topk=(1, 3, 5))
print(f"Top-k Accuracies: {topk_accuracies}")

Top-k Accuracies: {1: 0.6284135240572172, 3: 0.8403771131339401, 5: 0.9044213263979194}
