In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from dreamsim import dreamsim
from rating import preprocess_image

# Load the DreamSim model once at import time so every helper below can reuse it.
print("正在加载DreamSim模型...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dreamsim() returns a (model, preprocess) pair rather than the bare model;
# unpacking it keeps `model` callable as model(img_a, img_b) in the helpers below.
model, dreamsim_preprocess = dreamsim(pretrained=True, device=device)
print(f"模型已加载到设备: {device}")

# Compute the DreamSim score for a pair of images
def calculate_similarity(img1_path, img2_path):
    """Return the DreamSim score between two image files as a Python float.

    Both paths go through ``preprocess_image``; if either call returns ``None``
    the pair is skipped and ``None`` is returned.

    NOTE(review): DreamSim documents its output as a perceptual *distance*
    (lower = more similar), even though this file labels it "similarity".

    Args:
        img1_path: Path to the first image.
        img2_path: Path to the second image.

    Returns:
        The model's score as a float, or ``None`` if preprocessing failed.
    """
    img1_tensor = preprocess_image(img1_path)
    img2_tensor = preprocess_image(img2_path)

    if img1_tensor is None or img2_tensor is None:
        return None

    def _as_batch(t):
        # Add a batch axis unless the preprocessor already returned a batched
        # (N, C, H, W) tensor, and move it to the model's device so a CPU tensor
        # is never fed to a CUDA model.
        if t.dim() != 4:
            t = t.unsqueeze(0)
        return t.to(device)

    with torch.no_grad():
        similarity = model(_as_batch(img1_tensor), _as_batch(img2_tensor))

    return similarity.item()

# Visual check for a single image pair
def test_image_pair(img1_path, img2_path):
    """Score one image pair with DreamSim, print the score, and show both images.

    NOTE(review): DreamSim documents its output as a perceptual *distance*
    (lower = more similar); the labels below present it as "相似度".

    Args:
        img1_path: Path to the first image.
        img2_path: Path to the second image.

    Returns:
        The score from ``calculate_similarity``, or ``None`` if it failed
        (an error message is printed in that case).
    """
    similarity = calculate_similarity(img1_path, img2_path)

    if similarity is None:
        print("计算相似度失败")
        return None

    print(f"相似度: {similarity:.4f}")

    # Show the two images side by side.
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, path, title in zip(axes, (img1_path, img2_path), ("图像 1", "图像 2")):
        # The context manager releases the file handle once imshow has
        # converted the PIL image into an array.
        with Image.open(path) as img:
            ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(f"DreamSim 相似度: {similarity:.4f}")
    plt.tight_layout()
    plt.show()
    return similarity

# Batch-compare images from two folders
def _list_images(folder, extensions):
    """Return files directly in *folder* whose suffix (case-insensitive) is in *extensions*, sorted."""
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        p for p in folder.iterdir() if p.is_file() and p.suffix.lower() in wanted
    )


def test_folder(folder1, folder2, limit=10, extensions=(".jpg", ".png")):
    """Compute DreamSim scores for image pairs drawn from two folders and plot them.

    Images are collected from each folder (non-recursively), sorted by path,
    and paired by position. When both folders contain the same file names,
    matching names are therefore compared with each other. At most ``limit``
    pairs are evaluated.

    NOTE(review): DreamSim documents its output as a perceptual *distance*
    (lower = more similar); the "相似度" labels below present it as-is.

    Args:
        folder1: First folder (str or Path).
        folder2: Second folder (str or Path).
        limit: Maximum number of pairs to evaluate; ``None`` evaluates all pairs.
        extensions: File suffixes treated as images, matched case-insensitively.

    Returns:
        The list of successfully computed scores, or ``None`` if a folder is
        missing, no images were found, or every pair failed.
    """
    folder1 = Path(folder1)
    folder2 = Path(folder2)

    if not folder1.is_dir() or not folder2.is_dir():
        print("文件夹不存在")
        return

    # Collect every wanted extension. Note that `glob(a) or glob(b)` would only
    # ever yield the first glob, because generator objects are always truthy.
    files1 = _list_images(folder1, extensions)
    files2 = _list_images(folder2, extensions)

    if not files1 or not files2:
        print("文件夹中没有找到图像")
        return

    # zip() truncates to the shorter list; the slice then applies the limit.
    pairs = list(zip(files1, files2))[:limit]

    similarities = []
    for img1, img2 in tqdm(pairs, desc="计算相似度"):
        sim = calculate_similarity(str(img1), str(img2))
        if sim is not None:
            similarities.append(sim)

    if not similarities:
        print("没有成功计算的相似度")
        return

    # Report the mean and plot the score distribution.
    avg_sim = np.mean(similarities)
    print(f"平均相似度: {avg_sim:.4f}")

    plt.figure(figsize=(10, 6))
    plt.hist(similarities, bins=20, alpha=0.7)
    plt.axvline(avg_sim, color='r', linestyle='dashed', linewidth=2)
    plt.title(f"相似度分布 (平均: {avg_sim:.4f})")
    plt.xlabel("相似度")
    plt.ylabel("频率")
    plt.grid(True, alpha=0.3)
    plt.show()

    return similarities

# Example usage when this file is executed directly (or run as a notebook cell).
if __name__ == "__main__":
    # Single-pair check.
    pair = ("./img1.png", "./img2.png")
    test_image_pair(*pair)

    # Folder-level example (uncomment to run):
    # test_folder("./folder1", "./folder2", limit=5)
    
