In [7]:
"""
Simplified CLIP Zero-Shot Image Classification for Jupyter Notebook

Features:
- Zero-shot classification with CLIP
- Multiple prompt template experiments
- Optional linear probe classifier
"""

import os
import pickle
from pathlib import Path
import time

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from transformers import CLIPProcessor, CLIPModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# =====================================================
# CONFIGURATION - EDIT THESE
# =====================================================
# Which dataset to evaluate: 'mnist' or 'cifar10'. main() uses it to pick the
# class names, the prompt templates and the <DATASET>_test / <DATASET>_train folders.
DATASET = 'mnist'  # 'mnist' or 'cifar10'
# Test pickles are read from <TEST_ROOT>/<DATASET>_test/<split>.pkl.
TEST_ROOT = './data/processed'
# Linear-probe training pickles are read from <TRAIN_ROOT>/<DATASET>_train/<TRAIN_TYPE>.pkl.
TRAIN_ROOT = './data/processed'  # Set to None to skip linear probe (USE_LINEAR_PROBE must also be True to train one)
# Only the first SAMPLE_N images of each test file are evaluated (0/None = all).
SAMPLE_N = 2000
# DataLoader batch size used when extracting image embeddings (test and train).
BATCH_SIZE = 128
# Train and evaluate a logistic-regression probe on CLIP embeddings (also needs TRAIN_ROOT).
USE_LINEAR_PROBE = False
TRAIN_TYPE = 'original'  # training file name without .pkl: 'original', 'mixed_augmented', or 'combined_augmented'

# =====================================================
# Helper Functions
# =====================================================

def load_pickle_dataset(pkl_path):
    """Load a pickled image dataset and return ``(PIL.Image, int_label)`` pairs.

    Two on-disk layouts are accepted:

    * a dict with ``'images'`` and ``'labels'`` entries, or
    * a 2-element tuple/list ``(images, labels)``.

    Every image is converted with ``to_pil`` so the returned images are RGB
    ``PIL.Image`` objects; labels are cast to ``int``.

    Args:
        pkl_path: path to the pickle file.

    Returns:
        list of ``(PIL.Image, int)`` pairs in file order.

    Raises:
        ValueError: if the pickle has an unsupported layout, or if the number
            of images and labels differ.
    """
    print(f"Loading: {pkl_path}")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load pickles you produced yourself or otherwise trust.
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    if isinstance(data, dict):
        if 'images' not in data or 'labels' not in data:
            raise ValueError(f"Unsupported dict format: {list(data.keys())}")
        imgs, labs = data['images'], data['labels']
    elif isinstance(data, (tuple, list)) and len(data) == 2:
        imgs, labs = data[0], data[1]
    else:
        raise ValueError(f"Unsupported pickle type: {type(data)}")

    print(f"Found {len(imgs)} images")
    # zip() would silently drop the tail of the longer sequence, leaving the
    # labels misaligned with the images. Fail loudly instead.
    if len(imgs) != len(labs):
        raise ValueError(
            f"Image/label count mismatch: {len(imgs)} images vs {len(labs)} labels"
        )

    return [(to_pil(img), int(lab)) for img, lab in zip(imgs, labs)]


def to_pil(img):
    """Convert an image to an RGB ``PIL.Image``.

    Accepted inputs:

    * ``PIL.Image``: converted to RGB.
    * ``np.ndarray`` of shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, C)`` or
      channel-first ``(1, H, W)`` / ``(3, H, W)``. Single-channel arrays are
      replicated to three channels.

    Non-``uint8`` arrays whose maximum is <= 1.0 are treated as ``[0, 1]``
    data and scaled by 255. Other non-``uint8`` arrays are assumed to already
    be on the 0-255 scale. In both cases values are clipped to ``[0, 255]``
    before the cast, so out-of-range values saturate instead of wrapping
    around.

    Raises:
        ValueError: for unsupported input types.
    """
    if isinstance(img, Image.Image):
        return img.convert('RGB')

    if isinstance(img, np.ndarray):
        # Channel-first (C, H, W) -> channel-last (H, W, C) when the layout is
        # unambiguous. PIL cannot build an image from a (C, H, W) array.
        if img.ndim == 3 and img.shape[0] in (1, 3) and img.shape[-1] not in (1, 2, 3, 4):
            img = np.transpose(img, (1, 2, 0))

        if img.ndim == 3 and img.shape[-1] == 1:
            # Index the channel axis instead of squeeze(), so images with a
            # spatial dimension of size 1 keep their 2-D shape.
            img = img[..., 0]
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)

        if img.dtype != np.uint8:
            if img.size and img.max() <= 1.0:
                img = img * 255
            # Clip before casting. A plain astype(uint8) wraps negative or
            # >255 values around (e.g. -0.1 * 255 -> 231).
            img = np.clip(img, 0, 255).astype(np.uint8)

        return Image.fromarray(img).convert('RGB')

    raise ValueError(f"Unsupported image type: {type(img)}")


class PickleImageDataset(Dataset):
    """Map-style dataset over an in-memory list of ``(image, label)`` pairs.

    Each item is returned as ``(image, int(label), index)`` so callers can
    trace predictions back to the position in ``items``.

    Args:
        items: sequence of ``(image, label)`` pairs.
        max_samples: if truthy and smaller than ``len(items)``, only the first
            ``max_samples`` pairs are kept.
    """

    def __init__(self, items, max_samples=None):
        keep_prefix = bool(max_samples) and len(items) > max_samples
        self.items = items[:max_samples] if keep_prefix else items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        image, label = self.items[idx]
        return image, int(label), idx


@torch.no_grad()
def extract_embeddings(model, processor, dataloader, device):
    """Compute L2-normalised CLIP image embeddings for every batch.

    Args:
        model: model exposing ``get_image_features`` (e.g. ``CLIPModel``).
        processor: callable turning a list of images into model inputs
            (e.g. ``CLIPProcessor``).
        dataloader: yields ``(images, labels, indices)`` tuples, as produced by
            ``PickleImageDataset`` with a ``zip(*batch)`` collate function.
        device: torch device the processed inputs are moved to.

    Returns:
        ``(features, labels, indices)``: a float numpy array of shape
        ``(N, D)``, an int64 numpy array of shape ``(N,)`` and a list of ``N``
        ints. For an empty dataloader, ``D`` is ``model.config.projection_dim``.
    """
    model.eval()
    feat_chunks = []
    all_labels = []
    all_idxs = []

    for imgs, labels, idxs in tqdm(dataloader, desc="Extracting embeddings"):
        # `padding` is a text-tokenizer option, so it is not passed for images.
        inputs = processor(images=list(imgs), return_tensors="pt").to(device)
        image_feats = model.get_image_features(**inputs)
        # Clamp the norm so an all-zero feature vector cannot produce NaNs.
        norms = image_feats.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        image_feats = image_feats / norms

        feat_chunks.append(image_feats.cpu())
        all_labels.extend(int(x) for x in labels)
        all_idxs.extend(int(x) for x in idxs)

    if not feat_chunks:
        return (np.zeros((0, model.config.projection_dim), dtype=np.float32),
                np.array([], dtype=np.int64), [])

    features = torch.cat(feat_chunks, dim=0).numpy()
    return features, np.array(all_labels, dtype=np.int64), all_idxs


@torch.no_grad()
def build_text_embeddings(model, processor, class_names, templates, device, batch_size=64):
    """Build one zero-shot classifier vector per class by prompt ensembling.

    Every template is filled with every class name, encoded with the CLIP
    text tower and L2-normalised. Each class's per-template embeddings are
    then averaged and the mean is re-normalised.

    Args:
        model: model exposing ``get_text_features`` (e.g. ``CLIPModel``).
        processor: callable tokenizing a list of strings (e.g. ``CLIPProcessor``).
        class_names: class names, in label order.
        templates: ``str.format`` patterns with one ``{}`` slot for the name.
        device: torch device the tokenized inputs are moved to.
        batch_size: number of prompts encoded per forward pass.

    Returns:
        numpy array of shape ``(len(class_names), D)`` with unit-norm rows,
        ordered like ``class_names``.

    Raises:
        ValueError: if ``class_names`` or ``templates`` is empty.
    """
    if len(class_names) == 0 or len(templates) == 0:
        raise ValueError("class_names and templates must both be non-empty")

    # Class-major order: all templates for class 0, then class 1, ...
    texts = [t.format(cname) for cname in class_names for t in templates]

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # truncation guards against prompts longer than CLIP's context window.
        inputs = processor(text=batch, return_tensors="pt", padding=True,
                           truncation=True).to(device)
        tfeat = model.get_text_features(**inputs)
        tfeat = tfeat / tfeat.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        chunks.append(tfeat.cpu())

    all_text_feats = torch.cat(chunks, dim=0).numpy()

    # Class-major ordering means a reshape groups each class's templates.
    per_class = all_text_feats.reshape(len(class_names), len(templates), -1).mean(axis=1)
    per_class = per_class / (np.linalg.norm(per_class, axis=1, keepdims=True) + 1e-12)
    return per_class


def get_prompt_templates(dataset_name):
    """Return the named prompt-template sets used for zero-shot evaluation.

    Each value is a list of ``str.format`` patterns with a single ``{}``
    placeholder for the class name. ``'mnist'`` selects digit-oriented
    prompts; every other ``dataset_name`` falls back to the CIFAR-10 prompts.
    Set names, in order: 'basic', 'descriptive', 'context', 'minimal'.
    """
    if dataset_name != 'mnist':
        # CIFAR-10 prompts (also the fallback for any other name).
        return dict(
            basic=[
                "a photo of a {}.",
                "an image of a {}.",
            ],
            descriptive=[
                "a photo of a {}.",
                "a blurry photo of a {}.",
                "a clear photo of a {}.",
                "a bright photo of a {}.",
            ],
            context=[
                "a photo of a {} in the scene.",
                "a picture of the {}.",
                "an image showing a {}.",
            ],
            minimal=[
                "{}.",
            ],
        )

    return dict(
        basic=[
            "a photo of the number {}.",
            "the digit {}.",
        ],
        descriptive=[
            "a handwritten digit {}.",
            "a photo of a handwritten number {}.",
            "an image of the number {}.",
            "a drawing of the digit {}.",
        ],
        context=[
            "a photo of a {} digit.",
            "a black and white image of number {}.",
            "a grayscale photo of the number {}.",
            "handwriting showing the digit {}.",
        ],
        minimal=[
            "{}.",
        ],
    )


def train_linear_probe(model, processor, device, train_pkl_path, batch_size=128, max_samples=10000):
    """Fit a logistic-regression linear probe on frozen CLIP image embeddings.

    Args:
        model: CLIP model used to embed the training images.
        processor: matching CLIP processor.
        device: torch device for embedding extraction.
        train_pkl_path: path to a training pickle readable by
            ``load_pickle_dataset``.
        batch_size: DataLoader batch size for embedding extraction.
        max_samples: if truthy, only the first ``max_samples`` items are used.

    Returns:
        the fitted ``LogisticRegression``, or ``None`` when the file is missing
        or contains no samples.
    """
    print("\n" + "="*60)
    print("Training Linear Probe Classifier")
    print("="*60)

    if not os.path.exists(train_pkl_path):
        print(f"Training data not found: {train_pkl_path}")
        return None

    items = load_pickle_dataset(train_pkl_path)
    if not items:
        print("No training data found")
        return None

    if max_samples and len(items) > max_samples:
        # NOTE: keeps the first max_samples items in file order (no shuffling).
        # If the file is sorted by class, some classes may be under-represented.
        print(f"Using {max_samples} samples (out of {len(items)})")
        items = items[:max_samples]

    loader = DataLoader(PickleImageDataset(items), batch_size=batch_size, shuffle=False,
                        collate_fn=lambda b: tuple(zip(*b)))

    print("Extracting training embeddings...")
    X_train, y_train, _ = extract_embeddings(model, processor, loader, device)
    print(f"Training embedding shape: {X_train.shape}")

    print("Training logistic regression...")
    # `multi_class` was deprecated in scikit-learn 1.5 and later removed.
    # With lbfgs, multiclass targets are fit as a multinomial model by default,
    # so dropping it keeps the same model. `n_jobs` is dropped too: it only
    # parallelises one-vs-rest fits and is unused for a multinomial fit.
    clf = LogisticRegression(max_iter=2000, solver='lbfgs', verbose=0)
    clf.fit(X_train, y_train)

    train_acc = clf.score(X_train, y_train)
    print(f"Training accuracy: {train_acc:.4f}")

    return clf


def _write_classification_report(path, header_lines, y_true, y_pred, class_names):
    """Write header lines plus a text classification report to ``path``.

    The report covers every label in ``range(len(class_names))``, so
    ``target_names`` always lines up even when a class is missing from both
    ``y_true`` and ``y_pred``. Returns the same report as a dict.
    """
    common = dict(labels=list(range(len(class_names))), target_names=class_names,
                  zero_division=0, digits=4)
    report_text = classification_report(y_true, y_pred, **common)
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(header_lines) + "\n\n")
        f.write(report_text)
    return classification_report(y_true, y_pred, output_dict=True, **common)


def run_classification(model, processor, device, test_pkl_path, class_names,
                      template_sets, out_dir, batch_size=64, sample_n=2000,
                      linear_probe_clf=None):
    """Evaluate zero-shot CLIP (and optionally a linear probe) on one test pickle.

    Image embeddings are extracted once and reused for every prompt set.

    For each entry of ``template_sets``, three files are written under
    ``<out_dir>/prompt_<name>/``: ``classification_report.txt``,
    ``predictions.csv`` and ``confusion_matrix.csv``. A per-prompt-set summary
    is written to ``<out_dir>/prompt_comparison.csv``. When ``linear_probe_clf``
    is given, its report goes to
    ``<out_dir>/linear_probe/classification_report.txt``.

    Args:
        model, processor, device: CLIP model, processor and torch device.
        test_pkl_path: test pickle readable by ``load_pickle_dataset``.
        class_names: class names indexed by integer label.
        template_sets: mapping of prompt-set name -> list of templates.
        out_dir: output directory (created if missing).
        batch_size: DataLoader batch size for embedding extraction.
        sample_n: evaluate only the first ``sample_n`` images (falsy = all).
        linear_probe_clf: optional fitted classifier with ``predict``.

    Returns:
        the prompt-comparison ``pandas.DataFrame``, or ``None`` when the pickle
        has no images (nothing is written in that case).
    """
    print(f"\n{'='*60}")
    print(f"Processing: {os.path.basename(test_pkl_path)}")
    print(f"{'='*60}")

    items = load_pickle_dataset(test_pkl_path)
    if not items:
        print("No items found, skipping...")
        return None

    if sample_n and len(items) > sample_n:
        items = items[:sample_n]
    print(f"Using {len(items)} samples")

    loader = DataLoader(PickleImageDataset(items), batch_size=batch_size, shuffle=False,
                        collate_fn=lambda b: tuple(zip(*b)))

    start_time = time.time()
    X_test, y_test, idxs = extract_embeddings(model, processor, loader, device)
    embed_time = time.time() - start_time
    print(f"Embedding time: {embed_time:.2f}s")
    print(f"Embedding shape: {X_test.shape}")

    os.makedirs(out_dir, exist_ok=True)
    # A fixed label set keeps the confusion matrix square and aligned with
    # class_names even when some class never appears in y_test or y_pred.
    all_labels = list(range(len(class_names)))

    # Zero-shot with different prompts
    print(f"\n{'─'*60}")
    print("Prompt Engineering Experiments")
    print(f"{'─'*60}")

    results_summary = []

    for template_name, templates in template_sets.items():
        print(f"\nTesting prompt set: {template_name}")
        print(f"  Templates: {len(templates)}")

        text_embeds = build_text_embeddings(model, processor, class_names, templates, device)
        # Both sides are unit-norm, so the dot product is cosine similarity.
        y_pred = X_test.dot(text_embeds.T).argmax(axis=1)

        acc = accuracy_score(y_test, y_pred)
        print(f"  Accuracy: {acc:.4f}")

        prompt_dir = os.path.join(out_dir, f"prompt_{template_name}")
        os.makedirs(prompt_dir, exist_ok=True)

        report = _write_classification_report(
            os.path.join(prompt_dir, "classification_report.txt"),
            [f"Prompt Template Set: {template_name}",
             f"Number of Templates: {len(templates)}",
             f"Accuracy: {acc:.6f}"],
            y_test, y_pred, class_names,
        )

        pd.DataFrame({
            "idx": idxs,
            "true_label": y_test.tolist(),
            "pred_label": y_pred.tolist(),
            "correct": (y_test == y_pred).astype(int).tolist()
        }).to_csv(os.path.join(prompt_dir, "predictions.csv"), index=False)

        cm = confusion_matrix(y_test, y_pred, labels=all_labels)
        cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
        cm_df.to_csv(os.path.join(prompt_dir, "confusion_matrix.csv"))

        macro = report['macro avg']
        results_summary.append({
            'prompt_set': template_name,
            'num_templates': len(templates),
            'accuracy': acc,
            'precision_macro': macro['precision'],
            'recall_macro': macro['recall'],
            'f1_macro': macro['f1-score'],
        })

    summary_df = pd.DataFrame(results_summary)
    summary_df.to_csv(os.path.join(out_dir, "prompt_comparison.csv"), index=False)
    print("\nPrompt Comparison:")
    print(summary_df.to_string(index=False))

    # Linear probe (if available)
    if linear_probe_clf is not None:
        print(f"\n{'─'*60}")
        print("Linear Probe Classification")
        print(f"{'─'*60}")

        y_pred_lp = linear_probe_clf.predict(X_test)
        acc_lp = accuracy_score(y_test, y_pred_lp)
        print(f"Accuracy: {acc_lp:.4f}")

        lp_dir = os.path.join(out_dir, "linear_probe")
        os.makedirs(lp_dir, exist_ok=True)
        _write_classification_report(
            os.path.join(lp_dir, "classification_report.txt"),
            ["Linear Probe Classification", f"Accuracy: {acc_lp:.6f}"],
            y_test, y_pred_lp, class_names,
        )

    return summary_df


# =====================================================
# MAIN EXECUTION
# =====================================================

def main():
    """Run the CLIP zero-shot evaluation configured by the module constants.

    Uses DATASET, TEST_ROOT, TRAIN_ROOT, SAMPLE_N, BATCH_SIZE,
    USE_LINEAR_PROBE and TRAIN_TYPE. The pipeline:

    1. Loads ``openai/clip-vit-large-patch14``.
    2. Optionally trains a linear probe.
    3. Evaluates every test pickle present in ``<TEST_ROOT>/<DATASET>_test``.
    4. Writes per-split results plus
       ``CLIP_results/<DATASET>/overall_summary.csv``.

    Raises:
        ValueError: if DATASET is not 'mnist' or 'cifar10'.
    """
    # Fail fast: any other value would silently get CIFAR-10 class names.
    if DATASET not in ('mnist', 'cifar10'):
        raise ValueError(f"DATASET must be 'mnist' or 'cifar10', got {DATASET!r}")

    print("\n" + "="*60)
    print("CLIP ZERO-SHOT CLASSIFICATION")
    print("="*60)
    print(f"Dataset: {DATASET.upper()}")
    print(f"Samples per test: {SAMPLE_N}")
    print(f"Batch size: {BATCH_SIZE}")
    print("="*60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Load CLIP model
    model_name = "openai/clip-vit-large-patch14"
    print(f"\nLoading CLIP model: {model_name}")
    model = CLIPModel.from_pretrained(model_name).to(device)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model loaded successfully")

    # Class names are indexed by integer label.
    if DATASET == 'mnist':
        class_names = [str(i) for i in range(10)]
    else:
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']

    template_sets = get_prompt_templates(DATASET)
    print(f"\nLoaded {len(template_sets)} prompt template sets")

    # Test splits, each stored as <split>.pkl in the dataset's test folder.
    test_folder = Path(TEST_ROOT) / f'{DATASET}_test'
    split_names = ['original', 'noise', 'occlusion_25', 'rotation_15',
                   'scaling_0.8', 'all_combined']
    test_files = {name: test_folder / f'{name}.pkl' for name in split_names}

    # Train linear probe if requested
    linear_probe_clf = None
    if USE_LINEAR_PROBE and TRAIN_ROOT:
        train_file = Path(TRAIN_ROOT) / f'{DATASET}_train' / f'{TRAIN_TYPE}.pkl'
        linear_probe_clf = train_linear_probe(model, processor, device, str(train_file),
                                              batch_size=BATCH_SIZE)

    base_out = Path("CLIP_results") / DATASET
    base_out.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("PROCESSING TEST SETS")
    print(f"{'='*60}")

    all_results = []

    for test_name, test_path in test_files.items():
        if not test_path.exists():
            print(f"\nTest file missing: {test_path}")
            continue

        out_dir = base_out / test_name
        summary_csv = out_dir / "prompt_comparison.csv"
        # run_classification only writes this file when it evaluated
        # something. Delete a copy left by an earlier run so a split that is
        # skipped now is not summarised with stale numbers.
        if summary_csv.exists():
            summary_csv.unlink()

        run_classification(
            model, processor, device, str(test_path),
            class_names, template_sets, str(out_dir),
            batch_size=BATCH_SIZE, sample_n=SAMPLE_N,
            linear_probe_clf=linear_probe_clf
        )

        if not summary_csv.exists():
            print(f"No results produced for {test_name}; excluded from summary")
            continue

        prompt_comp = pd.read_csv(summary_csv)
        prompt_comp['test_set'] = test_name
        all_results.append(prompt_comp)

    # Create overall summary
    if all_results:
        overall_summary = pd.concat(all_results, ignore_index=True)
        overall_summary.to_csv(base_out / "overall_summary.csv", index=False)

        print(f"\n{'='*60}")
        print("OVERALL RESULTS SUMMARY")
        print(f"{'='*60}")
        print(overall_summary.to_string(index=False))

    print(f"\n{'='*60}")
    print("PROCESSING COMPLETE")
    print(f"{'='*60}")
    print(f"Results saved to: {base_out}")


In [8]:

# Run the main function. Inside a notebook kernel __name__ is "__main__",
# so executing this cell runs the whole evaluation pipeline.
if __name__ == "__main__":
    main()


CLIP ZERO-SHOT CLASSIFICATION
Dataset: MNIST
Samples per test: 2000
Batch size: 128
Device: cuda

Loading CLIP model: openai/clip-vit-large-patch14


model.safetensors:  31%|###1      | 535M/1.71G [00:00<?, ?B/s]

Error while downloading from https://cas-bridge.xethub.hf.co/xet-bridge-us/621ffdc136468d709f17ea63/9046d5fe172d35ca65c0140b3d9c638d31b2714cc17049ee40fcf887ab0e076a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251030%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251030T161045Z&X-Amz-Expires=3600&X-Amz-Signature=3e74e961a3a0c1601fe696cfdf90d05f825599f3d171a00ffb36cf76593bc67a&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=public&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&x-id=GetObject&Expires=1761844245&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTg0NDI0NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82MjFmZmRjMTM2NDY4ZDcwOWYxN2VhNjMvOTA0NmQ1ZmUxNzJkMzVjYTY1YzAxNDBiM2Q5YzYzOGQzMWIyNzE0Y2MxNzA0OWVlNDBmY2Y4ODdhYjBlMDc2YSoifV19&Signature=A5rTJ4zeOTBr4ZXb2sCyUwndMyZ8pOlHZe5397I3JRTwZwNgeNgIS

model.safetensors:  99%|#########9| 1.70G/1.71G [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

Model loaded successfully

Loaded 4 prompt template sets

PROCESSING TEST SETS

Processing: original.pkl
Loading: data\processed\mnist_test\original.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [01:33<00:00,  5.83s/it]


Embedding time: 93.34s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.5740

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.6050

Testing prompt set: context
  Templates: 4
  Accuracy: 0.6670

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.7650

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2     0.574         0.805315      0.586243  0.606222
descriptive              4     0.605         0.837253      0.617435  0.637133
    context              4     0.667         0.830456      0.677376  0.686169
    minimal              1     0.765         0.822337      0.761528  0.772981

Processing: noise.pkl
Loading: data\processed\mnist_test\noise.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [01:30<00:00,  5.67s/it]


Embedding time: 90.80s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.6120

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.6370

Testing prompt set: context
  Templates: 4
  Accuracy: 0.6355

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.5865

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.6120         0.775248      0.621207  0.619895
descriptive              4    0.6370         0.843786      0.642794  0.668670
    context              4    0.6355         0.793156      0.643121  0.653565
    minimal              1    0.5865         0.740603      0.588069  0.573733

Processing: occlusion_25.pkl
Loading: data\processed\mnist_test\occlusion_25.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [01:33<00:00,  5.81s/it]


Embedding time: 93.01s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.4940

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.5190

Testing prompt set: context
  Templates: 4
  Accuracy: 0.5840

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.6770

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2     0.494         0.778157      0.503303  0.521672
descriptive              4     0.519         0.799365      0.528160  0.546867
    context              4     0.584         0.763119      0.591655  0.596160
    minimal              1     0.677         0.744287      0.667781  0.678544

Processing: rotation_15.pkl
Loading: data\processed\mnist_test\rotation_15.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [01:29<00:00,  5.62s/it]


Embedding time: 89.94s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.5605

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.6035

Testing prompt set: context
  Templates: 4
  Accuracy: 0.6305

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.6130

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.5605         0.810422      0.572405  0.590585
descriptive              4    0.6035         0.831170      0.615640  0.631893
    context              4    0.6305         0.812988      0.642468  0.660095
    minimal              1    0.6130         0.707726      0.615866  0.604188

Processing: scaling_0.8.pkl
Loading: data\processed\mnist_test\scaling_0.8.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [02:01<00:00,  7.61s/it]


Embedding time: 121.82s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.6125

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.6410

Testing prompt set: context
  Templates: 4
  Accuracy: 0.7015

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.8220

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.6125         0.815664      0.624856  0.641650
descriptive              4    0.6410         0.841102      0.653177  0.669379
    context              4    0.7015         0.840821      0.712538  0.720389
    minimal              1    0.8220         0.853111      0.819730  0.825157

Processing: all_combined.pkl
Loading: data\processed\mnist_test\all_combined.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [01:57<00:00,  7.33s/it]


Embedding time: 117.24s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.4890

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.4925

Testing prompt set: context
  Templates: 4
  Accuracy: 0.5210

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.5125

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.4890         0.708652      0.495903  0.503450
descriptive              4    0.4925         0.782224      0.501543  0.528810
    context              4    0.5210         0.705658      0.527021  0.538912
    minimal              1    0.5125         0.619658      0.513745  0.478350

OVERALL RESULTS SUMMARY
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro     test_set
      basic  

In [9]:
"""
Simplified CLIP Zero-Shot Image Classification for Jupyter Notebook

Features:
- Zero-shot classification with CLIP
- Multiple prompt template experiments
- Optional linear probe classifier
"""

import os
import pickle
from pathlib import Path
import time

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from transformers import CLIPProcessor, CLIPModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# =====================================================
# CONFIGURATION - EDIT THESE
# =====================================================
# Which dataset to evaluate: 'mnist' or 'cifar10'. main() uses it to pick the
# class names, the prompt templates and the <DATASET>_test / <DATASET>_train folders.
DATASET = 'cifar10'  # 'mnist' or 'cifar10'
# Test pickles are read from <TEST_ROOT>/<DATASET>_test/<split>.pkl.
TEST_ROOT = './data/processed'
# Linear-probe training pickles are read from <TRAIN_ROOT>/<DATASET>_train/<TRAIN_TYPE>.pkl.
TRAIN_ROOT = './data/processed'  # Set to None to skip linear probe (USE_LINEAR_PROBE must also be True to train one)
# Only the first SAMPLE_N images of each test file are evaluated (0/None = all).
SAMPLE_N = 2000
# DataLoader batch size used when extracting image embeddings (test and train).
BATCH_SIZE = 128
# Train and evaluate a logistic-regression probe on CLIP embeddings (also needs TRAIN_ROOT).
USE_LINEAR_PROBE = False
TRAIN_TYPE = 'original'  # training file name without .pkl: 'original', 'mixed_augmented', or 'combined_augmented'

# =====================================================
# Helper Functions
# =====================================================

def load_pickle_dataset(pkl_path):
    """Load a pickled image dataset and return ``(PIL.Image, int_label)`` pairs.

    Two on-disk layouts are accepted:

    * a dict with ``'images'`` and ``'labels'`` entries, or
    * a 2-element tuple/list ``(images, labels)``.

    Every image is converted with ``to_pil`` so the returned images are RGB
    ``PIL.Image`` objects; labels are cast to ``int``.

    Args:
        pkl_path: path to the pickle file.

    Returns:
        list of ``(PIL.Image, int)`` pairs in file order.

    Raises:
        ValueError: if the pickle has an unsupported layout, or if the number
            of images and labels differ.
    """
    print(f"Loading: {pkl_path}")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load pickles you produced yourself or otherwise trust.
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    if isinstance(data, dict):
        if 'images' not in data or 'labels' not in data:
            raise ValueError(f"Unsupported dict format: {list(data.keys())}")
        imgs, labs = data['images'], data['labels']
    elif isinstance(data, (tuple, list)) and len(data) == 2:
        imgs, labs = data[0], data[1]
    else:
        raise ValueError(f"Unsupported pickle type: {type(data)}")

    print(f"Found {len(imgs)} images")
    # zip() would silently drop the tail of the longer sequence, leaving the
    # labels misaligned with the images. Fail loudly instead.
    if len(imgs) != len(labs):
        raise ValueError(
            f"Image/label count mismatch: {len(imgs)} images vs {len(labs)} labels"
        )

    return [(to_pil(img), int(lab)) for img, lab in zip(imgs, labs)]


def to_pil(img):
    """Convert an image to an RGB ``PIL.Image``.

    Accepted inputs:

    * ``PIL.Image``: converted to RGB.
    * ``np.ndarray`` of shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, C)`` or
      channel-first ``(1, H, W)`` / ``(3, H, W)``. Single-channel arrays are
      replicated to three channels.

    Non-``uint8`` arrays whose maximum is <= 1.0 are treated as ``[0, 1]``
    data and scaled by 255. Other non-``uint8`` arrays are assumed to already
    be on the 0-255 scale. In both cases values are clipped to ``[0, 255]``
    before the cast, so out-of-range values saturate instead of wrapping
    around.

    Raises:
        ValueError: for unsupported input types.
    """
    if isinstance(img, Image.Image):
        return img.convert('RGB')

    if isinstance(img, np.ndarray):
        # Channel-first (C, H, W) -> channel-last (H, W, C) when the layout is
        # unambiguous. PIL cannot build an image from a (C, H, W) array.
        if img.ndim == 3 and img.shape[0] in (1, 3) and img.shape[-1] not in (1, 2, 3, 4):
            img = np.transpose(img, (1, 2, 0))

        if img.ndim == 3 and img.shape[-1] == 1:
            # Index the channel axis instead of squeeze(), so images with a
            # spatial dimension of size 1 keep their 2-D shape.
            img = img[..., 0]
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)

        if img.dtype != np.uint8:
            if img.size and img.max() <= 1.0:
                img = img * 255
            # Clip before casting. A plain astype(uint8) wraps negative or
            # >255 values around (e.g. -0.1 * 255 -> 231).
            img = np.clip(img, 0, 255).astype(np.uint8)

        return Image.fromarray(img).convert('RGB')

    raise ValueError(f"Unsupported image type: {type(img)}")


class PickleImageDataset(Dataset):
    """Map-style dataset over an in-memory list of ``(image, label)`` pairs.

    Each item is returned as ``(image, int(label), index)`` so callers can
    trace predictions back to the position in ``items``.

    Args:
        items: sequence of ``(image, label)`` pairs.
        max_samples: if truthy and smaller than ``len(items)``, only the first
            ``max_samples`` pairs are kept.
    """

    def __init__(self, items, max_samples=None):
        keep_prefix = bool(max_samples) and len(items) > max_samples
        self.items = items[:max_samples] if keep_prefix else items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        image, label = self.items[idx]
        return image, int(label), idx


@torch.no_grad()
def extract_embeddings(model, processor, dataloader, device):
    """Compute L2-normalised CLIP image embeddings for every batch.

    Args:
        model: model exposing ``get_image_features`` (e.g. ``CLIPModel``).
        processor: callable turning a list of images into model inputs
            (e.g. ``CLIPProcessor``).
        dataloader: yields ``(images, labels, indices)`` tuples, as produced by
            ``PickleImageDataset`` with a ``zip(*batch)`` collate function.
        device: torch device the processed inputs are moved to.

    Returns:
        ``(features, labels, indices)``: a float numpy array of shape
        ``(N, D)``, an int64 numpy array of shape ``(N,)`` and a list of ``N``
        ints. For an empty dataloader, ``D`` is ``model.config.projection_dim``.
    """
    model.eval()
    feat_chunks = []
    all_labels = []
    all_idxs = []

    for imgs, labels, idxs in tqdm(dataloader, desc="Extracting embeddings"):
        # `padding` is a text-tokenizer option, so it is not passed for images.
        inputs = processor(images=list(imgs), return_tensors="pt").to(device)
        image_feats = model.get_image_features(**inputs)
        # Clamp the norm so an all-zero feature vector cannot produce NaNs.
        norms = image_feats.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        image_feats = image_feats / norms

        feat_chunks.append(image_feats.cpu())
        all_labels.extend(int(x) for x in labels)
        all_idxs.extend(int(x) for x in idxs)

    if not feat_chunks:
        return (np.zeros((0, model.config.projection_dim), dtype=np.float32),
                np.array([], dtype=np.int64), [])

    features = torch.cat(feat_chunks, dim=0).numpy()
    return features, np.array(all_labels, dtype=np.int64), all_idxs


@torch.no_grad()
def build_text_embeddings(model, processor, class_names, templates, device, batch_size=64):
    """Build one zero-shot classifier vector per class by prompt ensembling.

    Every template is filled with every class name, encoded with the CLIP
    text tower and L2-normalised. Each class's per-template embeddings are
    then averaged and the mean is re-normalised.

    Args:
        model: model exposing ``get_text_features`` (e.g. ``CLIPModel``).
        processor: callable tokenizing a list of strings (e.g. ``CLIPProcessor``).
        class_names: class names, in label order.
        templates: ``str.format`` patterns with one ``{}`` slot for the name.
        device: torch device the tokenized inputs are moved to.
        batch_size: number of prompts encoded per forward pass.

    Returns:
        numpy array of shape ``(len(class_names), D)`` with unit-norm rows,
        ordered like ``class_names``.

    Raises:
        ValueError: if ``class_names`` or ``templates`` is empty.
    """
    if len(class_names) == 0 or len(templates) == 0:
        raise ValueError("class_names and templates must both be non-empty")

    # Class-major order: all templates for class 0, then class 1, ...
    texts = [t.format(cname) for cname in class_names for t in templates]

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # truncation guards against prompts longer than CLIP's context window.
        inputs = processor(text=batch, return_tensors="pt", padding=True,
                           truncation=True).to(device)
        tfeat = model.get_text_features(**inputs)
        tfeat = tfeat / tfeat.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        chunks.append(tfeat.cpu())

    all_text_feats = torch.cat(chunks, dim=0).numpy()

    # Class-major ordering means a reshape groups each class's templates.
    per_class = all_text_feats.reshape(len(class_names), len(templates), -1).mean(axis=1)
    per_class = per_class / (np.linalg.norm(per_class, axis=1, keepdims=True) + 1e-12)
    return per_class


def get_prompt_templates(dataset_name):
    """Return the named prompt-template sets used for zero-shot evaluation.

    Each value is a list of ``str.format`` patterns with a single ``{}``
    placeholder for the class name. ``'mnist'`` selects digit-oriented
    prompts; every other ``dataset_name`` falls back to the CIFAR-10 prompts.
    Set names, in order: 'basic', 'descriptive', 'context', 'minimal'.
    """
    if dataset_name != 'mnist':
        # CIFAR-10 prompts (also the fallback for any other name).
        return dict(
            basic=[
                "a photo of a {}.",
                "an image of a {}.",
            ],
            descriptive=[
                "a photo of a {}.",
                "a blurry photo of a {}.",
                "a clear photo of a {}.",
                "a bright photo of a {}.",
            ],
            context=[
                "a photo of a {} in the scene.",
                "a picture of the {}.",
                "an image showing a {}.",
            ],
            minimal=[
                "{}.",
            ],
        )

    return dict(
        basic=[
            "a photo of the number {}.",
            "the digit {}.",
        ],
        descriptive=[
            "a handwritten digit {}.",
            "a photo of a handwritten number {}.",
            "an image of the number {}.",
            "a drawing of the digit {}.",
        ],
        context=[
            "a photo of a {} digit.",
            "a black and white image of number {}.",
            "a grayscale photo of the number {}.",
            "handwriting showing the digit {}.",
        ],
        minimal=[
            "{}.",
        ],
    )


def train_linear_probe(model, processor, device, train_pkl_path, batch_size=128, max_samples=10000):
    """Fit a logistic-regression linear probe on frozen CLIP image embeddings.

    Args:
        model: CLIP model used to embed the training images.
        processor: matching CLIP processor.
        device: torch device for embedding extraction.
        train_pkl_path: path to a training pickle readable by
            ``load_pickle_dataset``.
        batch_size: DataLoader batch size for embedding extraction.
        max_samples: if truthy, only the first ``max_samples`` items are used.

    Returns:
        the fitted ``LogisticRegression``, or ``None`` when the file is missing
        or contains no samples.
    """
    print("\n" + "="*60)
    print("Training Linear Probe Classifier")
    print("="*60)

    if not os.path.exists(train_pkl_path):
        print(f"Training data not found: {train_pkl_path}")
        return None

    items = load_pickle_dataset(train_pkl_path)
    if not items:
        print("No training data found")
        return None

    if max_samples and len(items) > max_samples:
        # NOTE: keeps the first max_samples items in file order (no shuffling).
        # If the file is sorted by class, some classes may be under-represented.
        print(f"Using {max_samples} samples (out of {len(items)})")
        items = items[:max_samples]

    loader = DataLoader(PickleImageDataset(items), batch_size=batch_size, shuffle=False,
                        collate_fn=lambda b: tuple(zip(*b)))

    print("Extracting training embeddings...")
    X_train, y_train, _ = extract_embeddings(model, processor, loader, device)
    print(f"Training embedding shape: {X_train.shape}")

    print("Training logistic regression...")
    # `multi_class` was deprecated in scikit-learn 1.5 and later removed.
    # With lbfgs, multiclass targets are fit as a multinomial model by default,
    # so dropping it keeps the same model. `n_jobs` is dropped too: it only
    # parallelises one-vs-rest fits and is unused for a multinomial fit.
    clf = LogisticRegression(max_iter=2000, solver='lbfgs', verbose=0)
    clf.fit(X_train, y_train)

    train_acc = clf.score(X_train, y_train)
    print(f"Training accuracy: {train_acc:.4f}")

    return clf


def _save_prediction_artifacts(target_dir, idxs, y_true, y_pred, class_names):
    """Write predictions.csv and confusion_matrix.csv for one classifier into target_dir."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    pd.DataFrame({
        "idx": idxs,
        "true_label": y_true.tolist(),
        "pred_label": y_pred.tolist(),
        "correct": (y_true == y_pred).astype(int).tolist(),
    }).to_csv(os.path.join(target_dir, "predictions.csv"), index=False)

    # Pin the label set to 0..K-1 so the matrix is always K x K and matches the
    # class_names index/columns, even if a class is absent from the sample.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    pd.DataFrame(cm, index=class_names, columns=class_names).to_csv(
        os.path.join(target_dir, "confusion_matrix.csv"))


def run_classification(model, processor, device, test_pkl_path, class_names,
                       template_sets, out_dir, batch_size=64, sample_n=2000,
                       linear_probe_clf=None):
    """Evaluate CLIP on one test pickle with every prompt set and, optionally, a linear probe.

    For each entry in ``template_sets`` a text embedding per class is built and
    every test image is assigned the class with the highest dot-product
    similarity. Per-prompt reports, predictions and confusion matrices are
    written under ``out_dir/prompt_<name>/``, and a comparison table is written
    to ``out_dir/prompt_comparison.csv``. If ``linear_probe_clf`` is given, its
    results go to ``out_dir/linear_probe/``.

    Args:
        model, processor, device: Passed through to the embedding helpers.
        test_pkl_path: Test pickle read by ``load_pickle_dataset``.
        class_names: Class names; label ``i`` is assumed to map to ``class_names[i]``.
        template_sets: Mapping of prompt-set name -> list of prompt templates.
        out_dir: Output directory (created if needed).
        batch_size: DataLoader batch size for image embedding extraction.
        sample_n: If truthy, evaluate only the first ``sample_n`` items.
        linear_probe_clf: Optional fitted classifier with a ``predict`` method.

    Returns:
        The prompt-comparison ``pandas.DataFrame`` (same content as
        ``prompt_comparison.csv``), or ``None`` if the pickle has no items.
    """
    print(f"\n{'='*60}")
    print(f"Processing: {os.path.basename(test_pkl_path)}")
    print(f"{'='*60}")

    items = load_pickle_dataset(test_pkl_path)
    if not items:
        print("No items found, skipping...")
        return None

    if sample_n and len(items) > sample_n:
        # Deterministic prefix, not a random subset.
        items = items[:sample_n]
    print(f"Using {len(items)} samples")

    loader = DataLoader(PickleImageDataset(items), batch_size=batch_size, shuffle=False,
                        collate_fn=lambda batch: tuple(zip(*batch)))

    start_time = time.perf_counter()
    X_test, y_test, idxs = extract_embeddings(model, processor, loader, device)
    embed_time = time.perf_counter() - start_time
    print(f"Embedding time: {embed_time:.2f}s")
    print(f"Embedding shape: {X_test.shape}")

    os.makedirs(out_dir, exist_ok=True)

    # Passing labels explicitly keeps sklearn from raising "Number of classes
    # ... does not match size of target_names" when some class is missing from
    # both y_test and the predictions (likely with a small sample_n).
    label_ids = np.arange(len(class_names))

    # Zero-shot with different prompts
    print(f"\n{'─'*60}")
    print("Prompt Engineering Experiments")
    print(f"{'─'*60}")

    results_summary = []

    for template_name, templates in template_sets.items():
        print(f"\nTesting prompt set: {template_name}")
        print(f"  Templates: {len(templates)}")

        text_embeds = build_text_embeddings(model, processor, class_names, templates, device)

        # One similarity column per class; the prediction is the arg-max column.
        # Rescaling an image row does not change its arg-max, so this presumes
        # only that build_text_embeddings returns comparably-normalised class
        # vectors -- TODO confirm.
        y_pred = X_test.dot(text_embeds.T).argmax(axis=1)

        acc = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred, labels=label_ids,
                                       target_names=class_names, zero_division=0,
                                       digits=4, output_dict=True)

        print(f"  Accuracy: {acc:.4f}")

        prompt_dir = os.path.join(out_dir, f"prompt_{template_name}")
        os.makedirs(prompt_dir, exist_ok=True)

        report_text = classification_report(y_test, y_pred, labels=label_ids,
                                            target_names=class_names, zero_division=0,
                                            digits=4)
        with open(os.path.join(prompt_dir, "classification_report.txt"), "w",
                  encoding="utf-8") as f:
            f.write(f"Prompt Template Set: {template_name}\n")
            f.write(f"Number of Templates: {len(templates)}\n")
            f.write(f"Accuracy: {acc:.6f}\n\n")
            f.write(report_text)

        _save_prediction_artifacts(prompt_dir, idxs, y_test, y_pred, class_names)

        macro = report['macro avg']
        results_summary.append({
            'prompt_set': template_name,
            'num_templates': len(templates),
            'accuracy': acc,
            'precision_macro': macro['precision'],
            'recall_macro': macro['recall'],
            'f1_macro': macro['f1-score'],
        })

    summary_df = pd.DataFrame(results_summary)
    summary_df.to_csv(os.path.join(out_dir, "prompt_comparison.csv"), index=False)
    print("\nPrompt Comparison:")
    print(summary_df.to_string(index=False))

    # Linear probe (if available)
    if linear_probe_clf is not None:
        print(f"\n{'─'*60}")
        print("Linear Probe Classification")
        print(f"{'─'*60}")

        y_pred_lp = linear_probe_clf.predict(X_test)
        acc_lp = accuracy_score(y_test, y_pred_lp)

        print(f"Accuracy: {acc_lp:.4f}")

        lp_dir = os.path.join(out_dir, "linear_probe")
        os.makedirs(lp_dir, exist_ok=True)

        report_lp = classification_report(y_test, y_pred_lp, labels=label_ids,
                                          target_names=class_names, zero_division=0,
                                          digits=4)
        with open(os.path.join(lp_dir, "classification_report.txt"), "w",
                  encoding="utf-8") as f:
            f.write("Linear Probe Classification\n")
            f.write(f"Accuracy: {acc_lp:.6f}\n\n")
            f.write(report_lp)

        # Same per-sample artifacts as the zero-shot prompt sets, for comparison.
        _save_prediction_artifacts(lp_dir, idxs, y_test, y_pred_lp, class_names)

    return summary_df


# =====================================================
# MAIN EXECUTION
# =====================================================

def main():
    """Run CLIP zero-shot (and optionally linear-probe) evaluation on every test variant.

    Reads the module-level configuration (DATASET, TEST_ROOT, TRAIN_ROOT,
    SAMPLE_N, BATCH_SIZE, USE_LINEAR_PROBE, TRAIN_TYPE). It writes per-test-set
    results under ``CLIP_results/<DATASET>/<test_name>/`` and a combined
    ``overall_summary.csv`` under ``CLIP_results/<DATASET>/``.

    Raises:
        ValueError: if DATASET is not 'mnist' or 'cifar10'. Anything else would
            previously fall through silently to the CIFAR-10 class names.
    """
    # Fail fast, before downloading/loading the large CLIP checkpoint.
    if DATASET not in ('mnist', 'cifar10'):
        raise ValueError(f"Unsupported DATASET {DATASET!r}; expected 'mnist' or 'cifar10'")

    print("\n" + "="*60)
    print("CLIP ZERO-SHOT CLASSIFICATION")
    print("="*60)
    print(f"Dataset: {DATASET.upper()}")
    print(f"Samples per test: {SAMPLE_N}")
    print(f"Batch size: {BATCH_SIZE}")
    print("="*60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Load CLIP model
    model_name = "openai/clip-vit-large-patch14"
    print(f"\nLoading CLIP model: {model_name}")
    model = CLIPModel.from_pretrained(model_name).to(device)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model loaded successfully")

    # Class names; label i corresponds to class_names[i].
    if DATASET == 'mnist':
        class_names = [str(i) for i in range(10)]
    else:  # 'cifar10' (validated above)
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']

    template_sets = get_prompt_templates(DATASET)
    print(f"\nLoaded {len(template_sets)} prompt template sets")

    # Test variants, processed in this order; each maps to <name>.pkl.
    test_folder = Path(TEST_ROOT) / f'{DATASET}_test'
    test_variants = ['original', 'noise', 'occlusion_25', 'rotation_15',
                     'scaling_0.8', 'all_combined']
    test_files = {name: test_folder / f'{name}.pkl' for name in test_variants}

    # Train linear probe if requested
    linear_probe_clf = None
    if USE_LINEAR_PROBE and TRAIN_ROOT:
        train_file = Path(TRAIN_ROOT) / f'{DATASET}_train' / f'{TRAIN_TYPE}.pkl'
        linear_probe_clf = train_linear_probe(model, processor, device, str(train_file), BATCH_SIZE)

    base_out = Path("CLIP_results") / DATASET
    base_out.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("PROCESSING TEST SETS")
    print(f"{'='*60}")

    all_results = []

    for test_name, test_path in test_files.items():
        if not test_path.exists():
            print(f"\nTest file missing: {test_path}")
            continue

        out_dir = base_out / test_name
        run_classification(
            model, processor, device, str(test_path),
            class_names, template_sets, str(out_dir),
            batch_size=BATCH_SIZE, sample_n=SAMPLE_N,
            linear_probe_clf=linear_probe_clf
        )

        # run_classification returns early without writing this CSV when the
        # pickle holds no items; reading it unconditionally raised FileNotFoundError.
        summary_csv = out_dir / "prompt_comparison.csv"
        if not summary_csv.exists():
            print(f"\nNo prompt comparison produced for {test_name}; excluded from summary")
            continue

        prompt_comp = pd.read_csv(summary_csv)
        prompt_comp['test_set'] = test_name
        all_results.append(prompt_comp)

    # Create overall summary
    if all_results:
        overall_summary = pd.concat(all_results, ignore_index=True)
        overall_summary.to_csv(base_out / "overall_summary.csv", index=False)

        print(f"\n{'='*60}")
        print("OVERALL RESULTS SUMMARY")
        print(f"{'='*60}")
        print(overall_summary.to_string(index=False))

    print(f"\n{'='*60}")
    print("PROCESSING COMPLETE")
    print(f"{'='*60}")
    print(f"Results saved to: {base_out}")


In [10]:

# Entry point: when executed as a script (or run as a notebook cell, where
# __name__ is also "__main__"), launch the full evaluation pipeline.
if __name__ == "__main__":
    main()


CLIP ZERO-SHOT CLASSIFICATION
Dataset: CIFAR10
Samples per test: 2000
Batch size: 128
Device: cuda

Loading CLIP model: openai/clip-vit-large-patch14
Model loaded successfully

Loaded 4 prompt template sets

PROCESSING TEST SETS

Processing: original.pkl
Loading: data\processed\cifar10_test\original.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [02:58<00:00, 11.18s/it]


Embedding time: 178.98s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.9465

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.9425

Testing prompt set: context
  Templates: 3
  Accuracy: 0.9475

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.8900

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.9465         0.947274      0.947724  0.946226
descriptive              4    0.9425         0.944081      0.943663  0.942401
    context              3    0.9475         0.948460      0.948535  0.947304
    minimal              1    0.8900         0.900700      0.890074  0.887572

Processing: noise.pkl
Loading: data\processed\cifar10_test\noise.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [03:31<00:00, 13.20s/it]


Embedding time: 211.20s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.8310

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.8235

Testing prompt set: context
  Templates: 3
  Accuracy: 0.8290

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.7055

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.8310         0.847364      0.833026  0.828274
descriptive              4    0.8235         0.838988      0.825375  0.820677
    context              3    0.8290         0.842592      0.831052  0.827136
    minimal              1    0.7055         0.808597      0.708500  0.704236

Processing: occlusion_25.pkl
Loading: data\processed\cifar10_test\occlusion_25.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [03:26<00:00, 12.92s/it]


Embedding time: 206.69s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.9200

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.9175

Testing prompt set: context
  Templates: 3
  Accuracy: 0.9245

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.8535

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.9200         0.922635      0.921314  0.919435
descriptive              4    0.9175         0.920813      0.918869  0.917000
    context              3    0.9245         0.926213      0.925535  0.924122
    minimal              1    0.8535         0.881413      0.853735  0.851255

Processing: rotation_15.pkl
Loading: data\processed\cifar10_test\rotation_15.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [03:20<00:00, 12.55s/it]


Embedding time: 200.84s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.8920

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.8720

Testing prompt set: context
  Templates: 3
  Accuracy: 0.8940

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.8115

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.8920         0.899606      0.893771  0.891304
descriptive              4    0.8720         0.887419      0.874165  0.872003
    context              3    0.8940         0.901683      0.895388  0.894165
    minimal              1    0.8115         0.844298      0.812580  0.803963

Processing: scaling_0.8.pkl
Loading: data\processed\cifar10_test\scaling_0.8.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [02:28<00:00,  9.28s/it]


Embedding time: 148.50s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.9520

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.9460

Testing prompt set: context
  Templates: 3
  Accuracy: 0.9500

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.8940

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2     0.952         0.952398      0.952792  0.951858
descriptive              4     0.946         0.947275      0.946901  0.945966
    context              3     0.950         0.950468      0.950619  0.949767
    minimal              1     0.894         0.903491      0.893861  0.890530

Processing: all_combined.pkl
Loading: data\processed\cifar10_test\all_combined.pkl
Found 10000 images
Using 2000 samples


Extracting embeddings: 100%|██████████| 16/16 [00:27<00:00,  1.70s/it]

Embedding time: 27.20s
Embedding shape: (2000, 768)

────────────────────────────────────────────────────────────
Prompt Engineering Experiments
────────────────────────────────────────────────────────────

Testing prompt set: basic
  Templates: 2
  Accuracy: 0.6545

Testing prompt set: descriptive
  Templates: 4
  Accuracy: 0.6455

Testing prompt set: context
  Templates: 3
  Accuracy: 0.6275

Testing prompt set: minimal
  Templates: 1
  Accuracy: 0.4880

Prompt Comparison:
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro
      basic              2    0.6545         0.716047      0.657432  0.650215
descriptive              4    0.6455         0.709372      0.649093  0.638733
    context              3    0.6275         0.709999      0.630754  0.632359
    minimal              1    0.4880         0.745615      0.493167  0.508386

OVERALL RESULTS SUMMARY
 prompt_set  num_templates  accuracy  precision_macro  recall_macro  f1_macro     test_set
      basic   


