In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the pretrained CLIP checkpoint and its matching preprocessor.
print("Loading CLIP model...")
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# The check mark was previously stored as mis-decoded bytes and printed as garbage.
print(f"✅ Model loaded on {device}")

In [None]:
def classify_image_zeroshot(image, labels=("safe content", "unsafe content")):
    """
    Score an image against free-text labels with CLIP zero-shot classification.

    Args:
        image: Image handed unchanged to the CLIP processor's `images` argument.
        labels: Text prompts to compare the image with. A single string is
            treated as one prompt. Defaults to a safe/unsafe prompt pair.

    Returns:
        1-D NumPy array holding the softmax over the image-to-text logits for
        the first image in the batch, one entry per prompt.

    Raises:
        ValueError: If `labels` is empty.
    """
    # The default is an immutable tuple (no shared mutable default); normalise
    # to a list for the processor without splitting a lone string into characters.
    prompts = [labels] if isinstance(labels, str) else list(labels)
    if not prompts:
        raise ValueError("labels must contain at least one text prompt")

    inputs = processor(
        text=prompts,
        images=image,
        return_tensors="pt",
        padding=True
    ).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        # Softmax across the prompt dimension turns the logits into probabilities.
        probs = logits_per_image.softmax(dim=1)

    # Row 0 is the single image that was passed in.
    return probs.cpu().numpy()[0]

In [None]:
print("\nTesting zero-shot classification...")

# Reload the dataset here so the notebook does not depend on another notebook's state.
from itertools import islice

from datasets import load_dataset

dataset = load_dataset("FalconLLM/nsfw_image_dataset", split="train")

# Walk the dataset lazily and stop at the fifth match; the earlier list
# comprehension visited every row before slicing off the first five.
# NOTE(review): assumes `label` holds the strings 'safe'/'unsafe' — confirm against the dataset schema.
safe_samples = list(islice((item for item in dataset if item['label'] == 'safe'), 5))

for idx, sample in enumerate(safe_samples, start=1):
    img = sample['image']
    true_label = sample['label']

    # probs[0] is the score for the first default prompt, probs[1] for the second.
    probs = classify_image_zeroshot(img)
    pred_label = "safe" if probs[0] > probs[1] else "unsafe"
    confidence = max(probs)

    print(f"\nImage {idx}:")
    print(f"  True: {true_label}")
    print(f"  Predicted: {pred_label} (confidence: {confidence:.2%})")
    print(f"  Probs: safe={probs[0]:.2%}, unsafe={probs[1]:.2%}")

In [None]:
# Evaluate on at most 100 images; cap at the dataset size so `select` cannot
# be asked for indices that do not exist.
n_eval = min(100, len(dataset))

print("\n" + "="*50)
print(f"Evaluating on {n_eval} images...")
print("="*50)

# Seeded random subset. Note this is NOT class-balanced: the class mix simply
# follows whatever the shuffle happens to draw.
eval_dataset = dataset.shuffle(seed=42).select(range(n_eval))

predictions = []
ground_truth = []

for item in tqdm(eval_dataset):
    img = item['image']
    true_label = item['label']

    # Pick whichever of the two default prompts scores higher.
    probs = classify_image_zeroshot(img)
    pred_label = "safe" if probs[0] > probs[1] else "unsafe"

    predictions.append(pred_label)
    ground_truth.append(true_label)

In [None]:
from pathlib import Path

print("\n" + "="*50)
print("BASELINE RESULTS")
print("="*50)

print("\nClassification Report:")
print(classification_report(ground_truth, predictions))

# Confusion Matrix
# Pin the class order so rows/columns always line up with the tick labels,
# even when one class is absent from the evaluated subset.
class_names = ['safe', 'unsafe']
cm = confusion_matrix(ground_truth, predictions, labels=class_names)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.title('Confusion Matrix - CLIP Zero-Shot')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')

# Create the output directory first; savefig does not create missing folders.
figure_path = Path('../results/confusion_matrix_baseline.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path)
plt.show()

In [None]:
import json
from pathlib import Path

# Summary of the zero-shot baseline run. The per-sample lists stay in memory
# for later cells but are left out of the JSON file below.
results = {
    'model': model_name,
    'approach': 'zero-shot',
    'num_samples': len(eval_dataset),
    # Cast to a plain Python float so the value is JSON-serialisable
    # regardless of the NumPy scalar type returned by `.mean()`.
    'accuracy': float((np.array(predictions) == np.array(ground_truth)).mean()),
    'predictions': predictions,
    'ground_truth': ground_truth
}

# Make sure the output directory exists before opening the file for writing.
results_path = Path('../results/baseline_results.json')
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, 'w') as f:
    json.dump({k: v for k, v in results.items() if k not in ['predictions', 'ground_truth']}, f, indent=2)

# The emoji below were previously stored as mis-decoded bytes and printed as garbage.
print(f"\n✅ Results saved to results/baseline_results.json")
print(f"\n🎯 Baseline Accuracy: {results['accuracy']:.2%}")