# 使用 CLIP 进行 Caltech-101 数据集的零样本分类

本 Notebook 演示如何使用 OpenAI 的 CLIP 模型在 Caltech-101 数据集上进行零样本分类，并对分类结果进行错误分析。

### 导入必要的库

In [1]:
# 导入必要的库
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import Caltech101
from torch.utils.data import DataLoader
import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import os
from collections import defaultdict
from sklearn.metrics import confusion_matrix
import seaborn as sns

### 设置设备为 GPU（如果可用）

In [2]:
# 设置设备为 GPU（如果可用）
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

使用设备: cuda


### 定义数据预处理步骤，CLIP 对图像的输入要求

In [3]:
# 定义数据预处理步骤，CLIP 对图像的输入要求
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # CLIP 通常使用 224x224 的输入
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))
])

### 加载 Caltech-101 数据集

由于 Caltech-101 数据集没有官方的测试集，我们将手动拆分数据集为训练集和测试集。

In [4]:
# 加载 Caltech-101 数据集
# 我们将使用测试集来评估零样本分类的性能
# Caltech-101 数据集没有官方的测试集，因此我们将手动拆分
dataset_full = Caltech101(root="../data", download=True, transform=preprocess)

# 手动拆分数据集为训练集和测试集
# 这里我们使用 80% 作为训练集，20% 作为测试集
from torch.utils.data import random_split

total_size = len(dataset_full)
test_size = int(0.2 * total_size)
train_size = total_size - test_size

dataset_train, dataset_test = random_split(dataset_full, [train_size, test_size])

# 创建数据加载器
batch_size = 64
dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


### 加载 CLIP 模型和预处理器

In [5]:
# 加载 CLIP 模型和预处理器
model, preprocess_clip = clip.load("ViT-B/32", device=device)
model.eval()  # 设置为评估模式

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

### 获取所有类别名称并准备标签映射

In [6]:
import os

def get_caltech_categories(root_dir="../data/caltech101/101_ObjectCategories"):
    # 获取所有类别目录
    categories = sorted(os.listdir(root_dir))
    
    # 创建类别到索引的映射
    class_to_idx = {cat: idx for idx, cat in enumerate(categories)}
    
    # 创建索引到类别的映射
    idx_to_class = {idx: cat for cat, idx in class_to_idx.items()}
    
    return categories, class_to_idx, idx_to_class

categories, class_to_idx, idx_to_class = get_caltech_categories()

# 为每个类别生成文本描述并编码
text_descriptions = [f"a photo of a {cat.lower()}" for cat in categories]
text_tokens = clip.tokenize(text_descriptions).to(device)

with torch.no_grad():
    text_embeddings = model.encode_text(text_tokens)
    text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)

### 进行零样本分类并计算准确率

In [7]:
# 进行零样本分类并计算准确率
correct = 0
total = 0

# 保存预测结果用于后续分析
all_predictions = []
all_labels = []
all_images = []

# 遍历数据集
for images, labels in tqdm(dataloader, desc="分类中"):
    images = images.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        image_embeddings = model.encode_image(images)
        image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        
        # 计算相似度
        similarity = image_embeddings @ text_embeddings.T  # [batch_size, num_classes]
        preds = similarity.argmax(dim=1)  # 获取最大相似度的类别索引
    
    # 记录正确预测的数量
    correct += (preds == labels).sum().item()
    total += labels.size(0)
    
    # 记录所有预测和标签
    all_predictions.extend(preds.cpu().numpy())
    all_labels.extend(labels.cpu().numpy())
    all_images.extend(images.cpu())

分类中:   0%|          | 0/28 [00:11<?, ?it/s]


RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

### 计算准确率

In [None]:
# 计算准确率
accuracy = correct / total
print(f"零样本分类准确率: {accuracy*100:.2f}%")

### 分析每个类别的错误率

In [11]:
# 分析每个类别的错误率
# 计算每个类别的正确预测数和总样本数
class_correct = defaultdict(int)
class_total = defaultdict(int)

for pred, label in zip(all_predictions, all_labels):
    class_total[label] += 1
    if pred == label:
        class_correct[label] += 1
    
# 计算每个类别的错误率
class_error_rate = {}
for label in class_total:
    correct_count = class_correct[label]
    total_count = class_total[label]
    error = (total_count - correct_count) / total_count
    class_error_rate[label] = error
    
# 将类别按错误率排序（从高到低）
sorted_error = sorted(class_error_rate.items(), key=lambda x: x[1], reverse=True)
    
# 打印前10个错误率最高的类别
print("\n错误率最高的10个类别:")
for idx, (label, error) in enumerate(sorted_error[:10], 1):
    print(f"{idx}. 类别: {idx_to_class[label]}, 错误率: {error*100:.2f}%")

### 选择错误率最高的5个类别进行分析

In [12]:
# 选择错误率最高的5个类别
top_n = 5
top_errors = sorted_error[:top_n]
top_error_labels = [label for label, _ in top_errors]

print(f"\n选择错误率最高的{top_n}个类别进行分析:")
for i, label in enumerate(top_error_labels, 1):
    print(f"{i}. 类别: {idx_to_class[label]}, 错误率: {class_error_rate[label]*100:.2f}%")

### 创建并展示错误分类的样本

In [13]:
# 创建一个 DataFrame 标记预测和真实类别
df = pd.DataFrame({
    "预测类别": [idx_to_class[pred] for pred in all_predictions],
    "真实类别": [idx_to_class[label] for label in all_labels],
})

# 找出错误的预测
incorrect_df = df[df["预测类别"] != df["真实类别"]]

# 筛选出属于高错误率类别的错误样本
high_error_incorrect = df.loc[incorrect_df.index]
high_error_incorrect = high_error_incorrect[high_error_incorrect["真实类别"].isin([idx_to_class[label] for label in top_error_labels])]

# 设置每个类别展示的错误样本数量
num_samples_per_class = 5

# 随机选择每个类别的错误样本
samples_to_display = []
for label in top_error_labels:
    class_name = idx_to_class[label]
    class_errors = high_error_incorrect[high_error_incorrect["真实类别"] == class_name]
    if len(class_errors) > 0:
        sampled = class_errors.sample(n=min(num_samples_per_class, len(class_errors)), random_state=42)
        samples_to_display.extend(sampled.index.tolist())

# 展示错误分类的图像
plt.figure(figsize=(20, 4 * top_n))
for i, idx in enumerate(samples_to_display, 1):
    row = df.iloc[idx]
    img = all_images[idx]
    img = img.permute(1, 2, 0).numpy()
    img = (img * np.array([0.26862954, 0.26130258, 0.27577711]) + 
           np.array([0.48145466, 0.4578275, 0.40821073]))
    img = np.clip(img, 0, 1)
    
    plt.subplot(top_n, num_samples_per_class, i)
    plt.imshow(img)
    plt.title(f"预测: {row['预测类别']}\n真实: {row['真实类别']}")
    plt.axis('off')
plt.suptitle("错误率最高类别的错误分类案例")
plt.show()

### 构建混淆矩阵

In [14]:
# 构建混淆矩阵
# 选择高错误率类别的标签和预测
filtered_labels = []
filtered_preds = []
for pred, label in zip(all_predictions, all_labels):
    if label in top_error_labels:
        filtered_labels.append(label)
        filtered_preds.append(pred)

# 生成混淆矩阵
cm = confusion_matrix(filtered_labels, filtered_preds, labels=top_error_labels)

# 绘制热力图
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', 
            xticklabels=[idx_to_class[label] for label in top_error_labels],
            yticklabels=[idx_to_class[label] for label in top_error_labels],
            cmap='Blues')
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.title('高错误率类别的混淆矩阵')
plt.show()

### 总结

In [15]:
# 总结
print(f"\n总样本数: {total}")
print(f"正确预测数: {correct}")
print(f"错误预测数: {total - correct}")
print(f"零样本分类准确率: {accuracy*100:.2f}%")
print(f"错误率最高的类别及其错误率已展示。")