# Setup Dependencies

In [1]:
from transformers import pipeline
import kagglehub
import os
import pandas as pd
from utils import get_files_in_dir
from pandasql import sqldf

  from .autonotebook import tqdm as notebook_tqdm


# Get Model and Data

In [2]:
model_name = "openai/clip-vit-large-patch14-336"
classifier = pipeline("zero-shot-image-classification", model = model_name, use_fast=True)

path = kagglehub.dataset_download("gpreda/happy-mammals-with-128x128-image-size")

print("Path to dataset files:", path)

image_paths =  get_files_in_dir(f"{path}/train_images_128")
labels_path="train.csv"

labels = pd.read_csv(labels_path, index_col="image")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


Path to dataset files: /home/chris/.cache/kagglehub/datasets/gpreda/happy-mammals-with-128x128-image-size/versions/1


# Collect all possible labels and setup results table

In [3]:
possible_labels=labels["species"].unique()
print(f"Categorizing images into the following species: {", ".join(possible_labels)}")

results=pd.DataFrame(columns=["image_name","label", "prediction","is_prediction_correct"])
results["is_prediction_correct"].astype(bool)

Categorizing images into the following species: melon_headed_whale, humpback_whale, false_killer_whale, bottlenose_dolphin, beluga, minke_whale, fin_whale, blue_whale, gray_whale, southern_right_whale, common_dolphin, kiler_whale, pilot_whale, dusky_dolphin, killer_whale, long_finned_pilot_whale, sei_whale, spinner_dolphin, bottlenose_dolpin, cuviers_beaked_whale, spotted_dolphin, globis, brydes_whale, commersons_dolphin, white_sided_dolphin, short_finned_pilot_whale, rough_toothed_dolphin, pantropic_spotted_dolphin, pygmy_killer_whale, frasiers_dolphin


Series([], Name: is_prediction_correct, dtype: bool)

# Generate all predictions

In [31]:
image_names = [os.path.basename(image_path) for image_path in image_paths]

true_labels = labels.loc[image_names, "species"].tolist()

batch_size = 8
max_num_batches=100

for i in range(0, len(image_paths), batch_size):
    batch_num=i // batch_size + 1
    
    if batch_num >= max_num_batches:
        break

    batch_image_paths = image_paths[i:i + batch_size]
    batch_image_names = image_names[i:i + batch_size]
    batch_true_labels = true_labels[i:i + batch_size]

    print(f"Generating scores for batch {batch_num}")
    batch_scores = classifier(batch_image_paths, candidate_labels=possible_labels)

    for image_name, label, scores in zip(batch_image_names, batch_true_labels, batch_scores):
        prediction = scores[0]['label']
        is_prediction_correct = label == prediction

        result=pd.DataFrame([{
            "image_name": image_name,
            "is_prediction_correct": is_prediction_correct,
            "label": label,
            "prediction": prediction
        }])
        results = pd.concat([results, result], ignore_index=True)


    print(f"Processed batch {batch_num}/{(len(image_paths) + batch_size - 1) // batch_size}")

print("Classification complete.")

Generating scores for batch 1
Processed batch 1/6380
Generating scores for batch 2
Processed batch 2/6380
Generating scores for batch 3
Processed batch 3/6380
Generating scores for batch 4
Processed batch 4/6380
Generating scores for batch 5
Processed batch 5/6380
Generating scores for batch 6
Processed batch 6/6380
Generating scores for batch 7
Processed batch 7/6380
Generating scores for batch 8
Processed batch 8/6380
Generating scores for batch 9
Processed batch 9/6380
Generating scores for batch 10
Processed batch 10/6380
Generating scores for batch 11
Processed batch 11/6380
Generating scores for batch 12
Processed batch 12/6380
Generating scores for batch 13
Processed batch 13/6380
Generating scores for batch 14
Processed batch 14/6380
Generating scores for batch 15
Processed batch 15/6380
Generating scores for batch 16
Processed batch 16/6380
Generating scores for batch 17
Processed batch 17/6380
Generating scores for batch 18
Processed batch 18/6380
Generating scores for batch 

# Accuracy Analysis

In [34]:
print("")
numCorrect=results["is_prediction_correct"].sum()
accuracy=numCorrect/len(results["is_prediction_correct"])
print(f"num correct: {numCorrect} Accuracy: {accuracy}")


query1="""SELECT label, 
COUNT(CASE WHEN is_prediction_correct = FALSE THEN 1 END) AS num_incorrect, 
COUNT(*) AS total,
ROUND(CAST(COUNT(CASE WHEN is_prediction_correct = TRUE THEN 1 END) AS FLOAT)/COUNT(label),3) AS label_accuracy
FROM results 
GROUP BY label 
ORDER BY label_accuracy DESC"""
sqldf(query1)



num correct: 445 Accuracy: 0.16053391053391053


Unnamed: 0,label,num_incorrect,total,label_accuracy
0,bottlenose_dolpin,16,65,0.754
1,pilot_whale,7,15,0.533
2,killer_whale,28,53,0.472
3,fin_whale,30,55,0.455
4,minke_whale,62,109,0.431
5,gray_whale,34,55,0.382
6,beluga,264,414,0.362
7,southern_right_whale,31,42,0.262
8,cuviers_beaked_whale,17,21,0.19
9,dusky_dolphin,134,159,0.157
