In [None]:
from openai import OpenAI
import torch
import clip
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

In [None]:
# Configure the OpenAI-compatible client that talks to the DeepSeek endpoint.
# The API key is read from the DEEPSEEK_API_KEY environment variable so that
# no secret is stored in (or committed with) the notebook.
_deepseek_api_key = os.environ.get("DEEPSEEK_API_KEY")
if not _deepseek_api_key:
    raise RuntimeError(
        "DEEPSEEK_API_KEY is not set; export it before running this notebook."
    )

clients = OpenAI(
    api_key=_deepseek_api_key,
    base_url="https://api.deepseek.com/v1",
)

In [None]:
# Initialise the CLIP model: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the model together with its matching image preprocessing callable.
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
class Dataset:
    """Map-style image dataset read from one sub-folder per class.

    ``data_path/<class_name>/`` is listed for every entry of ``classes``. Each
    regular, non-hidden file found there becomes one sample whose label is the
    index of its class in ``classes``. Entries are visited in sorted order so
    the sample order is reproducible across runs and platforms.

    Images are opened lazily in ``__getitem__`` and passed through the
    module-level ``preprocess`` callable.
    """

    def __init__(self, data_path, classes):
        """Index the image files of every class folder.

        Args:
            data_path: Root directory containing one folder per class.
            classes: Sequence of class (folder) names; the position of a name
                is used as its integer label.

        Raises:
            FileNotFoundError: If a class folder does not exist (raised by
                ``os.listdir``).
        """
        self.data_path = data_path
        self.classes = classes
        self.images = []
        self.labels = []

        for label, class_name in enumerate(classes):
            class_path = os.path.join(data_path, class_name)
            # os.listdir gives no ordering guarantee; sort for determinism.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                # Skip sub-directories and hidden entries (.ipynb_checkpoints,
                # .DS_Store, ...) that cannot be opened as images.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.images.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(preprocess(image), label)`` for the sample at ``idx``."""
        # The context manager releases the file handle once preprocessing has
        # consumed the image.
        with Image.open(self.images[idx]) as img:
            image = preprocess(img)
        label = self.labels[idx]
        return image, label

In [None]:
def generate_class_descriptions(classes, method="simple"):
    """Ask the LLM for one short keyword description per class.

    Args:
        classes: Iterable of class names to describe.
        method: Prompt style.
            "simple" - describe the visual characteristics of the class alone.
            "contrastive" - describe how the class differs visually from the
            other entries of ``classes``.

    Returns:
        dict mapping each class name to the text content of the model's reply,
        in the order of ``classes``.

    Raises:
        ValueError: If ``method`` is neither "simple" nor "contrastive". The
            check runs before any request is sent.
    """
    if method not in ("simple", "contrastive"):
        raise ValueError(
            f"Unknown method {method!r}; expected 'simple' or 'contrastive'."
        )

    # Materialise once: the contrastive prompt iterates the classes again for
    # every class, which would exhaust a one-shot iterable.
    classes = list(classes)
    descriptions = {}

    for class_name in classes:
        if method == "simple":
            prompt = (
                f"Generate a detailed description of what a '{class_name}' looks like, "
                "focusing on visual characteristics that would help recognize it in an image. "
                "Only keywords are allowed, not sentences, no more than 10 words in total."
            )
        else:  # "contrastive" (validated above)
            other_classes = [c for c in classes if c != class_name]
            prompt = (
                f"Describe how to distinguish a '{class_name}' from {', '.join(other_classes)}. "
                "Focus on visual differences. "
                "Only keywords are allowed, not sentences, no more than 10 words in total."
            )

        response = clients.chat.completions.create(
            model="deepseek-chat",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )

        descriptions[class_name] = response.choices[0].message.content

    return descriptions

In [None]:
def evaluate_clip(dataset, text_descriptions=None, number_descriptions=None):
    """Classify every image of ``dataset`` with CLIP and score the predictions.

    One text prompt is built per entry of ``dataset.classes``; each image is
    assigned the class whose prompt embedding is most similar to the image
    embedding.

    Args:
        dataset: Object exposing ``classes`` and usable with
            ``torch.utils.data.DataLoader`` to yield ``(image, label)`` batches.
        text_descriptions: None - build prompts from the raw class names.
            dict - use ``text_descriptions[class_name]`` as the prompt of each
            class; every name in ``dataset.classes`` must be a key.
        number_descriptions: Optional count inserted into the class-name
            prompt ("a photo of N <class>[s]"). Only used when
            ``text_descriptions`` is None.

    Returns:
        ``(accuracy, cm)``: the value of ``accuracy_score`` and the confusion
        matrix with one row/column per entry of ``dataset.classes``.

    Raises:
        ValueError: If ``text_descriptions`` lacks a class of the dataset, or
            if the dataset yields no samples.
    """
    class_names = list(dataset.classes)

    if text_descriptions is None:
        # Prompts built from the raw class names.
        if number_descriptions is None:
            prompts = [f"a photo of a {c}" for c in class_names]
        else:
            plural = 's' if number_descriptions > 1 else ''
            prompts = [
                f"a photo of {number_descriptions} {c}{plural}" for c in class_names
            ]
        text_inputs = clip.tokenize(prompts).to(device)
    else:
        # Prompts taken from the generated descriptions. Look them up by class
        # name so that row i of the text features always belongs to label i,
        # whatever the key order (or extra keys) of the dict.
        missing = [c for c in class_names if c not in text_descriptions]
        if missing:
            raise ValueError(f"text_descriptions has no entry for: {missing}")
        prompts = [text_descriptions[c] for c in class_names]
        # Free-form LLM text can exceed CLIP's context length; truncate
        # instead of letting tokenize raise.
        text_inputs = clip.tokenize(prompts, truncate=True).to(device)

    # Encode and L2-normalise the text prompts.
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Predict every image, batch by batch.
    all_preds = []
    all_labels = []

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

    for images, labels in tqdm(dataloader):
        images = images.to(device)

        # Encode and L2-normalise the image batch.
        with torch.no_grad():
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled similarity between every image and every prompt.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        preds = similarity.argmax(dim=-1).cpu().numpy()

        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        raise ValueError("dataset yielded no samples; nothing to evaluate")

    # Score the predictions. Passing the full label list keeps the matrix at
    # len(classes) x len(classes) even if a class never occurs.
    accuracy = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

    return accuracy, cm

In [None]:
def plot_confusion_matrix(cm, classes, title):
    """Draw ``cm`` as an annotated blue heat map and show it.

    ``classes`` supplies the tick labels of both axes; the y axis is labelled
    as the true label and the x axis as the predicted label. Every cell is
    annotated with its integer count, in white when the count is above half of
    the largest entry of ``cm`` and in black otherwise.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()

    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)

    # Cells darker than this threshold get white text so the count stays legible.
    half_max = cm.max() / 2.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row, col in np.ndindex(n_rows, n_cols):
        count = cm[row, col]
        if count > half_max:
            text_color = "white"
        else:
            text_color = "black"
        # Matplotlib text coordinates are (x, y) = (column, row).
        plt.text(col, row, format(count, 'd'),
                 ha="center", va="center",
                 color=text_color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

In [None]:
def main(classes=("dog", "cat"), data_path="imgs", show_confusion_matrices=False):
    """Compare CLIP accuracy for three kinds of text prompts and plot the result.

    The dataset is evaluated with (1) prompts built from the raw class names,
    (2) LLM-generated "simple" descriptions and (3) LLM-generated "contrastive"
    descriptions. Each accuracy is printed and all three are shown in a bar
    chart.

    Args:
        classes: Class names; each must be a sub-folder of ``data_path``.
        data_path: Root folder of the dataset, organised as one folder per class.
        show_confusion_matrices: When True, also plot the confusion matrix of
            every run.
    """
    classes = list(classes)

    # Build the dataset once and reuse it for every run.
    dataset = Dataset(data_path, classes)

    # (bar label, printed label, LLM prompt method or None, confusion-matrix title)
    runs = [
        ("Baseline", "Baseline", None, "Baseline CLIP Confusion Matrix"),
        ("Simple", "Simple Descriptions", "simple", "CLIP with Simple Descriptions"),
        ("Contrastive", "Contrastive Descriptions", "contrastive",
         "CLIP with Contrastive Descriptions"),
    ]

    methods = []
    accuracies = []
    for bar_label, printed_label, llm_method, cm_title in runs:
        # The baseline uses the raw class names; the other runs ask the LLM.
        if llm_method is None:
            descriptions = None
        else:
            descriptions = generate_class_descriptions(classes, method=llm_method)

        acc, cm = evaluate_clip(dataset, descriptions)
        print(f"{printed_label} Accuracy: {acc:.4f}")
        if show_confusion_matrices:
            plot_confusion_matrix(cm, classes, cm_title)

        methods.append(bar_label)
        accuracies.append(acc)

    # Side-by-side comparison of the accuracies.
    plt.figure(figsize=(10, 6))
    plt.bar(methods, accuracies)
    plt.title("Comparison of CLIP Performance with Different Text Descriptions")
    plt.ylabel("Accuracy")
    plt.ylim(0, 1.2)  # head-room above 1.0 for the value labels
    for i, acc in enumerate(accuracies):
        plt.text(i, acc + 0.02, f"{acc:.4f}", ha='center')
    plt.tight_layout()
    plt.show()

In [None]:
# Run the prompt comparison defined in main() (baseline vs. LLM-generated descriptions).
main()

In [None]:
def test_number(classes=("dog", "cat"), data_path="imgs", counts=None):
    """Compare CLIP accuracy when the prompt states different object counts.

    For every entry of ``counts`` the dataset is evaluated with class-name
    prompts that include that number (``evaluate_clip(..., number_descriptions=n)``),
    and the resulting accuracies are shown in a bar chart.

    Args:
        classes: Class names; each must be a sub-folder of ``data_path``.
        data_path: Root folder of the dataset, organised as one folder per class.
        counts: Mapping from bar label to the number used in the prompt.
            Defaults to ``{"One": 1, "Two": 2, "Three": 3, "Ten": 10}``.
    """
    if counts is None:
        counts = {"One": 1, "Two": 2, "Three": 3, "Ten": 10}
    classes = list(classes)

    # Build the dataset once and reuse it for every count.
    dataset = Dataset(data_path, classes)

    methods = list(counts)
    accuracies = []
    for label in methods:
        # Only the accuracy is plotted; the confusion matrix is not needed here.
        acc, _ = evaluate_clip(dataset, number_descriptions=counts[label])
        accuracies.append(acc)

    # Side-by-side comparison of the accuracies.
    plt.figure(figsize=(10, 6))
    plt.bar(methods, accuracies)
    plt.title("Comparison of CLIP Performance with Different Number Descriptions")
    plt.ylabel("Accuracy")
    plt.ylim(0, 1.2)  # head-room above 1.0 for the value labels
    for i, acc in enumerate(accuracies):
        plt.text(i, acc + 0.02, f"{acc:.4f}", ha='center')
    plt.tight_layout()
    plt.show()

In [None]:
# Run the object-count prompt comparison defined in test_number().
test_number()