<a href="https://colab.research.google.com/github/amir-asari/CLIP-Basic/blob/main/CLIP_Img_Classification_Zero_Shot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Import necessary libraries

In [None]:
import torch
from torchvision.datasets import CIFAR10
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm # For a nice progress bar
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

### ==============================================================================
### STEP 1: Load Model, Processor, and Dataset
### ==============================================================================
This step loads the pre-trained CLIP model (`openai/clip-vit-base-patch32`), its processor, and the CIFAR-10 dataset.
For a proper zero-shot evaluation, we use the CIFAR-10 test split (`train=False`).

In [None]:
# Check if a GPU is available and set the device accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Single source of truth for the checkpoint used by both model and processor.
MODEL_NAME = "openai/clip-vit-base-patch32"

# --- Load Model and Processor ---
# Both names are pre-initialised to None so later cells can test for a failed
# load (`model is None` / `processor is None`) instead of hitting a NameError
# when the first from_pretrained call raises before `processor` is assigned.
model = None
processor = None
print("Loading CLIP model and processor...")
try:
    model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    print("CLIP model and processor loaded successfully.")
except Exception as e:
    # Broad catch is deliberate at this notebook boundary: network, cache and
    # config problems surface as different exception types. Reset BOTH names
    # so a partial load (model ok, processor failed) is not mistaken for success.
    print(f"Error loading model or processor: {e}")
    model = None
    processor = None

# --- Load Dataset ---
# We load the CIFAR-10 test set for evaluation.
print("\nLoading CIFAR-10 test dataset...")
test_dataset = None
try:
    # train=False selects the test split
    test_dataset = CIFAR10(root="./data", train=False, download=True)
    print("CIFAR-10 test dataset loaded successfully.")
except Exception as e:
    print(f"Error loading CIFAR-10 dataset: {e}")
    test_dataset = None  # Signal the failure to later cells

### ==============================================================================
### STEP 2: Define the Zero-Shot Prediction Function
### ==============================================================================
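This function performs the core zero-shot prediction for a single image: it builds one text prompt per class (e.g. "a photo of a cat"), scores the image against every prompt with CLIP, and picks the class whose prompt matches best.
It returns only the predicted class name; it does not print or plot anything.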

In [None]:
def predict_zero_shot(model, processor, image, class_names, prompt_template="a photo of a {}"):
    """
    Performs zero-shot classification for a single image using CLIP.

    Each class name is inserted into ``prompt_template`` to build one text
    prompt per class. The image is scored against every prompt, and the class
    whose prompt receives the largest image-text logit is returned.

    Args:
        model (CLIPModel): The pre-loaded CLIP model. Inputs are moved to the
            device of the model's parameters, so no global ``device`` is needed.
        processor (CLIPProcessor): The pre-loaded CLIP processor.
        image (PIL.Image): The image to classify.
        class_names (list): A non-empty list of target class names.
        prompt_template (str): Format string with a single ``{}`` placeholder
            for the class name. The default reproduces the original prompts.

    Returns:
        str: The predicted class name, or "N/A" if ``model`` or ``processor``
        is None (e.g. because loading failed).

    Raises:
        ValueError: If ``class_names`` is empty.
    """
    # Explicit None checks: truthiness of an nn.Module is not a meaningful test.
    if model is None or processor is None:
        return "N/A"
    if not class_names:
        raise ValueError("class_names must contain at least one class name")

    # Follow the model's own placement instead of a notebook-global variable,
    # so the function keeps working when reused outside this notebook.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Create the text prompts for zero-shot classification
    text_prompts = [prompt_template.format(class_name) for class_name in class_names]
    inputs = processor(
        text=text_prompts,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(target_device)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
        # logits_per_image is (num_images, num_prompts); with one image, the
        # argmax over dim=1 is the index of the best-matching prompt.
        prediction_index = outputs.logits_per_image.argmax(dim=1).item()

    return class_names[prediction_index]

### ==============================================================================
### STEP 3: Run Zero-Shot Evaluation on the Test Set
### ==============================================================================
We now loop over the first `NUM_SAMPLES_TO_EVALUATE` images of the test set (in dataset order, not a random sample), collect the predictions,
and calculate the overall zero-shot accuracy.

In [None]:
# --- Configuration ---
NUM_SAMPLES_TO_EVALUATE = 1000 # Evaluate on 1000 images for a good estimate
if test_dataset is not None:
    class_names = test_dataset.classes
    correct_predictions = 0
    true_labels = []
    predicted_labels = []

    # Never index past the end of the dataset, and report the number of images
    # actually evaluated rather than the number requested.
    num_samples = min(NUM_SAMPLES_TO_EVALUATE, len(test_dataset))

    if model is None:
        # predict_zero_shot returns "N/A" without a model, so every prediction
        # below would be counted as wrong; make that visible instead of silent.
        print("Warning: CLIP model failed to load; all predictions will be 'N/A'.")

    print(f"\n--- Starting Zero-Shot Evaluation on {num_samples} CIFAR-10 Test Images ---")

    # --- Evaluation Loop ---
    # Using tqdm for a progress bar
    for i in tqdm(range(num_samples), desc="Evaluating"):
        # Get image and true label
        image, true_label_id = test_dataset[i]
        true_label_name = class_names[true_label_id]
        true_labels.append(true_label_name)

        # Get the prediction
        predicted_label_name = predict_zero_shot(model, processor, image, class_names)
        predicted_labels.append(predicted_label_name)

        # Check if the prediction was correct
        if predicted_label_name == true_label_name:
            correct_predictions += 1

    # --- Calculate and Display Accuracy ---
    if num_samples > 0:
        accuracy = (correct_predictions / num_samples) * 100
        print("\n--- Zero-Shot Evaluation Complete ---")
        print(f"Evaluated on: {num_samples} images")
        print(f"Correct Predictions: {correct_predictions}")
        print(f"Overall Zero-Shot Accuracy: {accuracy:.2f}%")
    else:
        print("No samples were evaluated.")
else:
    print("Skipping zero-shot evaluation because the dataset failed to load.")

### ==============================================================================
### STEP 4: Display Detailed Performance Metrics
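### ==============================================================================
This cell prints a per-class classification report (precision, recall, F1-score) and plots a confusion matrix for the images evaluated in Step 3.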

In [None]:
# This cell was previously the tail of Step 3's `if test_dataset:` block and
# failed with an IndentationError when run on its own; it now has its own guard.
if test_dataset is None:
    print("Cannot run evaluation because the dataset failed to load.")
elif not true_labels:
    print("No predictions to report; run the evaluation cell first.")
else:
    print("\n--- Classification Report ---")
    # Generate and print the classification report.
    # zero_division=0 reports 0.0 (the same value the default produces) for
    # classes that are never predicted, without emitting UndefinedMetricWarning.
    report = classification_report(
        true_labels, predicted_labels, labels=class_names, zero_division=0
    )
    print(report)

    print("\n--- Confusion Matrix ---")
    # Rows are true labels, columns are predicted labels, both in class_names
    # order; any "N/A" predictions are excluded by passing labels explicitly.
    cm = confusion_matrix(true_labels, predicted_labels, labels=class_names)

    # Plot the confusion matrix using seaborn
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix', fontsize=16)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    # ha="right" keeps rotated labels aligned with their tick marks.
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()