# Zero-Shot Classification on Oxford 102 Flower Dataset using CLIP

This notebook performs zero-shot classification on the Oxford 102 Flower Dataset using the CLIP model. It includes steps for data loading, preprocessing, model inference, accuracy calculation, error analysis, and visualization of misclassified samples.

In [2]:
# 导入必要的库
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import Flowers102
from torch.utils.data import DataLoader
import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import os
from collections import defaultdict
from sklearn.metrics import confusion_matrix
import seaborn as sns

# 设置设备为 GPU（如果可用）
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
# 定义数据预处理步骤，CLIP 对图像的输入要求
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # CLIP 通常使用 224x224 的输入
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))
])

In [5]:
# 加载 Oxford-102 Flower Dataset
# 使用测试集来评估零样本分类的性能
dataset = Flowers102(root="../data", split="test", download=True, transform=preprocess)

# 创建数据加载器
batch_size = 64
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [6]:
# 加载 CLIP 模型和预处理器
model, preprocess_clip = clip.load("ViT-B/32", device=device)
model.eval()  # 设置为评估模式

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [8]:
# 获取所有类别名称
from scipy.io import loadmat

# 加载标签文件
labels_path = "../data/flowers-102/imagelabels.mat"
labels_mat = loadmat(labels_path)
labels = labels_mat['labels'][0] - 1  # 转换为0索引

# 创建类别列表 (0-101)
categories = [str(i) for i in range(102)]

# 为每个类别创建文本描述
# 使用 "a photo of a {class}" 的格式
text_prompts = [f"a photo of a {category}" for category in categories]

# 编码文本
with torch.no_grad():
    text_tokens = clip.tokenize(text_prompts).to(device)  # 将文本转换为 tokens
    text_embeddings = model.encode_text(text_tokens)  # 获取文本嵌入
    text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)  # 归一化

In [9]:
# 创建自己的类别映射
idx_to_class = {i: str(i) for i in range(102)}  # 如果只用数字标签
class_to_idx = {str(i): i for i in range(102)}

# 如果你有花名文件，可以这样使用：
try:
    with open("../data/oxford-102-flower-dataset/cat_to_name.json", 'r') as f:
        cat_to_name = json.load(f)
        # 注意：cat_to_name 中的键是从1开始的，需要调整为从0开始
        idx_to_class = {i: cat_to_name[str(i+1)] for i in range(102)}
        class_to_idx = {v: k for k, v in idx_to_class.items()}
except:
    # 如果没有花名文件，使用数字作为类别名
    idx_to_class = {i: str(i) for i in range(102)}
    class_to_idx = {str(i): i for i in range(102)}

# 确保类别数量一致
assert len(categories) == len(text_prompts), "类别数量不匹配"

AttributeError: 'Flowers102' object has no attribute 'class_to_idx'

In [7]:
# 进行零样本分类并计算准确率
correct = 0
total = 0

# 保存预测结果用于后续分析
all_predictions = []
all_labels = []
all_images = []

# 遍历数据集
for images, labels in tqdm(dataloader, desc="分类中"):
    images = images.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        image_embeddings = model.encode_image(images)
        image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        
        # 计算相似度
        similarity = image_embeddings @ text_embeddings.T  # [batch_size, num_classes]
        preds = similarity.argmax(dim=1)  # 获取最大相似度的类别索引
    
    # 记录正确预测的数量
    correct += (preds == labels).sum().item()
    total += labels.size(0)
    
    # 记录所有预测和标签
    all_predictions.extend(preds.cpu().numpy())
    all_labels.extend(labels.cpu().numpy())
    all_images.extend(images.cpu())

In [8]:
# 计算准确率
accuracy = correct / total
print(f"零样本分类准确率: {accuracy*100:.2f}%")

零样本分类准确率: XX.XX%


In [9]:
# 分析每个类别的错误率
# 计算每个类别的正确预测数和总样本数
class_correct = defaultdict(int)
class_total = defaultdict(int)

for pred, label in zip(all_predictions, all_labels):
    class_total[label] += 1
    if pred == label:
        class_correct[label] += 1

In [10]:
# 计算每个类别的错误率
class_error_rate = {}
for label in class_total:
    correct_count = class_correct[label]
    total_count = class_total[label]
    error = (total_count - correct_count) / total_count
    class_error_rate[label] = error

# 将类别按错误率排序（从高到低）
sorted_error = sorted(class_error_rate.items(), key=lambda x: x[1], reverse=True)

# 打印前10个错误率最高的类别
print("\n错误率最高的10个类别:")
for idx, (label, error) in enumerate(sorted_error[:10], 1):
    print(f"{idx}. 类别: {idx_to_class[label]}, 错误率: {error*100:.2f}%")


错误率最高的10个类别:
1. 类别: {类别名称1}, 错误率: YY.YY%
2. 类别: {类别名称2}, 错误率: YY.YY%
...


In [11]:
# 选择错误率最高的5个类别
top_n = 5
top_errors = sorted_error[:top_n]
top_error_labels = [label for label, _ in top_errors]

print(f"\n选择错误率最高的{top_n}个类别进行分析:")
for i, label in enumerate(top_error_labels, 1):
    print(f"{i}. 类别: {idx_to_class[label]}, 错误率: {class_error_rate[label]*100:.2f}%")


选择错误率最高的5个类别进行分析:
1. 类别: {类别名称1}, 错误率: YY.YY%
2. 类别: {类别名称2}, 错误率: YY.YY%
...


In [12]:
# 创建一个 DataFrame 标记错误分类的样本
df = pd.DataFrame({
    "预测类别": [idx_to_class[pred] for pred in all_predictions],
    "真实类别": [idx_to_class[label] for label in all_labels],
})

# 找出错误的预测
incorrect_df = df[df["预测类别"] != df["真实类别"]]

# 筛选出属于高错误率类别的错误样本
high_error_incorrect = df.loc[incorrect_df.index]
high_error_incorrect = high_error_incorrect[high_error_incorrect["真实类别"].isin([idx_to_class[label] for label in top_error_labels])]

In [13]:
# 设置每个类别展示的错误样本数量
num_samples_per_class = 5

# 随机选择每个类别的错误样本
samples_to_display = []
for label in top_error_labels:
    class_name = idx_to_class[label]
    class_errors = high_error_incorrect[high_error_incorrect["真实类别"] == class_name]
    if not class_errors.empty:
        sampled = class_errors.sample(n=min(num_samples_per_class, len(class_errors)), random_state=42)
        samples_to_display.extend(sampled.index.tolist())

In [14]:
# 展示错误分类的图像
plt.figure(figsize=(20, 4 * top_n))
for i, idx in enumerate(samples_to_display, 1):
    row = df.iloc[idx]
    img = all_images[idx]
    img = img.permute(1, 2, 0).numpy()
    img = (img * np.array([0.26862954, 0.26130258, 0.27577711]) + 
           np.array([0.48145466, 0.4578275, 0.40821073]))
    img = np.clip(img, 0, 1)
    
    plt.subplot(top_n, num_samples_per_class, i)
    plt.imshow(img)
    plt.title(f"Predicted: {row['预测类别']}\nTrue: {row['真实类别']}")
    plt.axis('off')
plt.suptitle("错误率最高类别的错误分类案例")
plt.show()

In [15]:
# 构建混淆矩阵
# 选择高错误率类别的标签和预测
filtered_labels = []
filtered_preds = []
for pred, label in zip(all_predictions, all_labels):
    if label in top_error_labels:
        filtered_labels.append(label)
        filtered_preds.append(pred)

# 生成混淆矩阵
cm = confusion_matrix(filtered_labels, filtered_preds, labels=top_error_labels)

# 绘制热力图
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', 
            xticklabels=[idx_to_class[label] for label in top_error_labels],
            yticklabels=[idx_to_class[label] for label in top_error_labels],
            cmap='Blues')
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.title('高错误率类别的混淆矩阵')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

In [16]:
# 总结
print(f"\n总样本数: {total}")
print(f"正确预测数: {correct}")
print(f"错误预测数: {total - correct}")
print(f"零样本分类准确率: {accuracy*100:.2f}%")


总样本数: XXXX
正确预测数: XXXX
错误预测数: XXXX
零样本分类准确率: XX.XX%
错误率最高的类别及其错误率已展示。
