# Setup Dependencies

In [1]:
from transformers import pipeline
import kagglehub
import os
import pandas as pd
from utils import get_files_in_dir
from pandasql import sqldf

  from .autonotebook import tqdm as notebook_tqdm


# Get Model and Data

In [2]:
model_name = "openai/clip-vit-large-patch14-336"
classifier = pipeline("zero-shot-image-classification", model = model_name, use_fast=True)

path = kagglehub.dataset_download("gpreda/happy-mammals-with-128x128-image-size")

print("Path to dataset files:", path)

image_paths =  get_files_in_dir(f"{path}/train_images_128")
labels_path="train.csv"

labels = pd.read_csv(labels_path, index_col="image")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


Path to dataset files: /home/chris/.cache/kagglehub/datasets/gpreda/happy-mammals-with-128x128-image-size/versions/1


# Collect all possible labels and setup results table

In [3]:
possible_labels=labels["species"].unique()
print(f"Categorizing images into the following species: {", ".join(possible_labels)}")

results=pd.DataFrame(columns=["image_name","label", "prediction","is_prediction_correct"])
results["is_prediction_correct"].astype(bool)

Categorizing images into the following species: melon_headed_whale, humpback_whale, false_killer_whale, bottlenose_dolphin, beluga, minke_whale, fin_whale, blue_whale, gray_whale, southern_right_whale, common_dolphin, kiler_whale, pilot_whale, dusky_dolphin, killer_whale, long_finned_pilot_whale, sei_whale, spinner_dolphin, bottlenose_dolpin, cuviers_beaked_whale, spotted_dolphin, globis, brydes_whale, commersons_dolphin, white_sided_dolphin, short_finned_pilot_whale, rough_toothed_dolphin, pantropic_spotted_dolphin, pygmy_killer_whale, frasiers_dolphin


Series([], Name: is_prediction_correct, dtype: bool)

# Generate all predictions

In [None]:
for image_path in image_paths:
    image_name=os.path.basename(image_path)
    
    if  image_name not in labels.index:
        raise KeyError(f"image {image_name} not in labels index")

    label=labels.loc[image_name,"species"]

    print(label)
    
    scores = classifier(image_path, candidate_labels = possible_labels)
    prediction = scores[0]['label']

    print(f"Image {image_name}: The highest score is {scores[0]['score']:.3f} for the label: {label} and prediction: {prediction}")
    
    is_prediction_correct=label==prediction

    result = pd.DataFrame([{"image_name": image_name, "is_prediction_correct": is_prediction_correct, "label":label,"prediction":prediction}])
    results = pd.concat([results, result])

bottlenose_dolphin
Image 173edf821e72c4.jpg: The highest score is 0.311 for the label: bottlenose_dolphin and prediction: spotted_dolphin
blue_whale
Image e714ba12963261.jpg: The highest score is 0.698 for the label: blue_whale and prediction: minke_whale
melon_headed_whale
Image c191adeee3dc98.jpg: The highest score is 0.173 for the label: melon_headed_whale and prediction: dusky_dolphin
bottlenose_dolpin
Image 7b6c757cee035f.jpg: The highest score is 0.147 for the label: bottlenose_dolpin and prediction: dusky_dolphin
beluga
Image 9b06b3977fb06d.jpg: The highest score is 0.218 for the label: beluga and prediction: minke_whale
humpback_whale
Image 8d6589915cf904.jpg: The highest score is 0.485 for the label: humpback_whale and prediction: fin_whale
fin_whale
Image 77fef72e28fb28.jpg: The highest score is 0.664 for the label: fin_whale and prediction: fin_whale
beluga
Image f543606d27f3fa.jpg: The highest score is 0.682 for the label: beluga and prediction: beluga
bottlenose_dolphin
Im

KeyboardInterrupt: 

# Accuracy Analysis

In [5]:
print("")
numCorrect=results["is_prediction_correct"].sum()
accuracy=numCorrect/len(results["is_prediction_correct"])
print(f"num correct: {numCorrect} Accuracy: {accuracy}")


query="SELECT label, COUNT(CASE WHEN is_prediction_correct = FALSE THEN 1 END) AS num_incorrect FROM results GROUP BY label ORDER BY num_incorrect desc"
sqldf(query)


num correct: 15 Accuracy: 0.15


Unnamed: 0,label,num_incorrect
0,bottlenose_dolphin,19
1,humpback_whale,16
2,blue_whale,10
3,beluga,9
4,kiler_whale,5
5,dusky_dolphin,5
6,spinner_dolphin,4
7,melon_headed_whale,4
8,false_killer_whale,4
9,minke_whale,2
