# In [None]:
import numpy as np
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, load_from_disk
from PIL import Image
import torch
import json

# GPU index this script targets, and the CLIP checkpoint used for both the
# model weights and the matching processor.
PREFERRED_GPU_INDEX = 7
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"

# Pick the execution device. torch.cuda.is_available() only says that *some*
# GPU exists, so the preferred index is checked against the visible device
# count before it is used; otherwise moving the model would fail on hosts
# with fewer GPUs.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
else:
    device = torch.device("cuda:0")
print(f"Using device: {device}")

# Inference only: eval() disables dropout and similar train-time behaviour.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device).eval()
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# Load the rephrased prompts. Later in this file the result is looked up by
# class name and the value is concatenated onto a list of prompts, so it is
# presumably {class_name: [prompt, ...]} -- verify against the JSON file.
# The encoding is pinned to UTF-8 (the JSON standard) so the read does not
# depend on the platform's default locale encoding.
with open("./data/label_rephrased_dict_cifar100.json", "r", encoding="utf-8") as f:
    label_rephrased_dict = json.load(f)

# Class names in label-id order (ten per row, so row r starts at id 10*r).
# The position of a name in this tuple is its integer label id.
_CIFAR100_FINE_LABELS = (
    "apple", "aquarium_fish", "baby", "bear", "beaver",
    "bed", "bee", "beetle", "bicycle", "bottle",
    "bowl", "boy", "bridge", "bus", "butterfly",
    "camel", "can", "castle", "caterpillar", "cattle",
    "chair", "chimpanzee", "clock", "cloud", "cockroach",
    "couch", "crab", "crocodile", "cup", "dinosaur",
    "dolphin", "elephant", "flatfish", "forest", "fox",
    "girl", "hamster", "house", "kangaroo", "keyboard",
    "lamp", "lawn_mower", "leopard", "lion", "lizard",
    "lobster", "man", "maple_tree", "motorcycle", "mountain",
    "mouse", "mushroom", "oak_tree", "orange", "orchid",
    "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
    "plain", "plate", "poppy", "porcupine", "possum",
    "rabbit", "raccoon", "ray", "road", "rocket",
    "rose", "sea", "seal", "shark", "shrew",
    "skunk", "skyscraper", "snail", "snake", "spider",
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
    "tank", "telephone", "television", "tiger", "tractor",
    "train", "trout", "tulip", "turtle", "wardrobe",
    "whale", "willow_tree", "wolf", "woman", "worm",
)

# Mapping from integer label id (0..99) to class name, in ascending id order.
label_names = dict(enumerate(_CIFAR100_FINE_LABELS))

# Number of prompt variants per class: the default template plus 30
# rephrasings. The result buffers allocated later in this file have exactly
# 31 variant slots, so every class row is forced to this length.
NUM_PROMPT_VARIANTS = 31

# Build the prompt table with one row per class (in label-id order) and one
# column per variant. Column 0 is always the default template.
# NOTE(review): class names keep their underscores in the default prompt
# (e.g. "a photo of a aquarium_fish") -- confirm that is intended.
text_prompts_per_class = []
for class_name in label_names.values():
    original = f"a photo of a {class_name}"

    # Classes missing from the JSON fall back to the default prompt only.
    rephrased = label_rephrased_dict.get(class_name, [])
    # Clip over-long lists; without this, a JSON file with more than 30
    # rephrasings for every class would produce more variants than the
    # result buffers can hold.
    prompts = ([original] + rephrased)[:NUM_PROMPT_VARIANTS]
    # Pad short lists with the default prompt so all rows are rectangular.
    prompts.extend([original] * (NUM_PROMPT_VARIANTS - len(prompts)))
    text_prompts_per_class.append(prompts)

# Transpose to variant-major order: text_prompts_per_variant[v][c] is the
# prompt of class c under variant v.
text_prompts_per_variant = [list(variant) for variant in zip(*text_prompts_per_class)]
# All prompts as one flat list, variant by variant.
flattened_text_inputs = [text for variant in text_prompts_per_variant for text in variant]

dataset = load_from_disk("./cifar_100_rephrased_labels")

BATCH_SIZE = 256  # Adjust based on your GPU memory

# Buffer dimensions are derived from the prompt table and the label map
# instead of being repeated as literals, so they cannot drift apart from the
# data that is written into them.
num_samples = len(dataset)
num_variants = len(text_prompts_per_variant)
num_classes = len(label_names)

# probs_matrix[i, v, c]: probability of class c for sample i under variant v.
probs_matrix = np.zeros((num_samples, num_variants, num_classes), dtype=np.float32)
# targets[i]: ground-truth label id of sample i.
targets = np.zeros((num_samples,), dtype=np.int64)

# Running counts of correct top-1 predictions.
correct_predictions = {
    'default': 0,  # For the default prompt
    'ensemble': 0,  # For ensemble of all prompts
    'per_variant': [0] * num_variants  # For each prompt variant
}

# Score every image against every prompt variant.
#
# Each variant's prompts are encoded once and each image batch is encoded
# once; all variants are then scored for that batch with a single similarity
# product. This avoids re-running the image encoder over the whole dataset
# for every variant and re-encoding the same prompts for every batch.
# The embedding / normalisation / scaling steps are meant to reproduce what
# `model(**inputs).logits_per_image` computes -- TODO confirm against the
# installed transformers version.
#
# Fills probs_matrix and targets, and accumulates the 'per_variant' and
# 'default' counters in correct_predictions.
num_prompt_variants = len(text_prompts_per_variant)
if num_prompt_variants != probs_matrix.shape[1]:
    # Fail before any expensive work instead of part-way through the run.
    raise ValueError(
        f"Got {num_prompt_variants} prompt variants but probs_matrix has "
        f"{probs_matrix.shape[1]} variant slots"
    )

with torch.no_grad():
    # Text side: one L2-normalised embedding per (variant, class) prompt.
    variant_text_embeds = []
    for text_prompts in tqdm(text_prompts_per_variant, desc="Encoding prompts"):
        text_inputs = processor(text=text_prompts, return_tensors="pt", padding=True)
        # Move inputs to the same device as the model
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        text_embeds = model.text_projection(model.text_model(**text_inputs).pooler_output)
        variant_text_embeds.append(text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True))
    # Shape: (variants, classes, embed_dim)
    all_text_embeds = torch.stack(variant_text_embeds)
    logit_scale = model.logit_scale.exp()

    # Image side: process dataset in batches
    for batch_start in tqdm(range(0, num_samples, BATCH_SIZE), desc="Scoring images"):
        batch_end = min(batch_start + BATCH_SIZE, num_samples)

        # A slice returns a dict of column lists, i.e. one fetch per batch.
        batch = dataset[batch_start:batch_end]
        batch_targets = np.asarray(batch["fine_label"], dtype=targets.dtype)
        targets[batch_start:batch_end] = batch_targets

        image_inputs = processor(images=batch["img"], return_tensors="pt")
        pixel_values = image_inputs["pixel_values"].to(device)
        image_embeds = model.visual_projection(
            model.vision_model(pixel_values=pixel_values).pooler_output
        )
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # Scaled cosine similarity for all variants at once:
        # (batch, dim) x (variants, classes, dim) -> (batch, variants, classes)
        logits = torch.einsum("bd,vcd->bvc", image_embeds, all_text_embeds) * logit_scale
        # Softmax over the classes of each variant separately.
        probs = logits.softmax(dim=-1).cpu().numpy()

        # Store probabilities
        probs_matrix[batch_start:batch_end] = probs

        # Top-1 hits per variant for this batch.
        hits_per_variant = (probs.argmax(axis=-1) == batch_targets[:, None]).sum(axis=0)
        for variant_idx, hit_count in enumerate(hits_per_variant):
            correct_predictions['per_variant'][variant_idx] += int(hit_count)
        # Track default prompt accuracy (first variant)
        correct_predictions['default'] += int(hits_per_variant[0])
# Calculate ensemble predictions after all variants are processed
print("Calculating ensemble predictions...")
# Average the class probabilities over the variant axis for every sample in
# one vectorized pass, then take the top-1 class of the averaged distribution.
ensemble_probs = probs_matrix.mean(axis=1)
ensemble_preds = ensemble_probs.argmax(axis=1)
# Count samples whose ensemble prediction matches the ground-truth label.
correct_predictions['ensemble'] += int(np.count_nonzero(ensemble_preds == targets))

# Convert the raw hit counters into accuracies (fraction of num_samples).
default_accuracy = correct_predictions['default'] / num_samples
ensemble_accuracy = correct_predictions['ensemble'] / num_samples
variant_accuracies = [
    hits / num_samples for hits in correct_predictions['per_variant']
]

# Indices of the best and worst variants; on ties the first index is used.
best_variant = int(np.argmax(variant_accuracies))
worst_variant = int(np.argmin(variant_accuracies))

# Report a summary on stdout.
print(f"Default prompt accuracy: {default_accuracy:.4f}")
print(f"Ensemble prompt accuracy: {ensemble_accuracy:.4f}")
print(f"Best variant accuracy: {variant_accuracies[best_variant]:.4f} (variant {best_variant})")
print(f"Worst variant accuracy: {variant_accuracies[worst_variant]:.4f} (variant {worst_variant})")

# Collect the metrics as plain Python types so they serialise to JSON.
accuracy_results = {
    'default_accuracy': float(default_accuracy),
    'ensemble_accuracy': float(ensemble_accuracy),
    'variant_accuracies': list(map(float, variant_accuracies)),
    'best_variant': best_variant,
    'worst_variant': worst_variant
}

# Write the accuracy summary.
accuracy_results_path = "./data/clip_accuracy_results_cifar100.json"
with open(accuracy_results_path, "w") as f:
    json.dump(accuracy_results, f, indent=2)

# Write the raw per-variant probabilities and the ground-truth labels.
np.save("./data/clip_probs_matrix_cifar100.npy", probs_matrix)
np.save("./data/clip_targets_cifar100.npy", targets)

print(f"Saved results with shape: {probs_matrix.shape}, {targets.shape}")
print(f"Saved accuracy results to {accuracy_results_path}")