In [None]:
import torch
from dreamsim import dreamsim
from rating import preprocess_image

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = dreamsim(pretrained=True, device=device)

img1 = preprocess_image("./img1.png")
img2 = preprocess_image("./img2.png")

similarity = model(img1, img2)
print(similarity)

In [None]:
# 导入必要的库
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import cv2

def display_comparison(img_path1, img_path2, similarity_score):
    """
    显示两张图片的对比和相似度分数
    """
    # 读取图片
    img1 = cv2.imread(img_path1)
    img2 = cv2.imread(img_path2)
    
    if img1 is None or img2 is None:
        print(f"无法读取图片: {img_path1 if img1 is None else img_path2}")
        return
    
    # 转换为RGB (OpenCV默认是BGR)
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
    
    # 调整图片大小以便于显示
    target_size = (224, 224)
    
    # 保持纵横比的调整大小函数
    def resize_keep_aspect(img, target_size):
        h, w = img.shape[:2]
        ratio = min(target_size[0] / w, target_size[1] / h)
        new_size = (int(w * ratio), int(h * ratio))
        resized = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
        
        # 创建白色画布
        canvas = np.ones((target_size[1], target_size[0], 3), dtype=np.uint8) * 255
        
        # 居中放置图像
        x_offset = (target_size[0] - new_size[0]) // 2
        y_offset = (target_size[1] - new_size[1]) // 2
        
        canvas[y_offset:y_offset+new_size[1], x_offset:x_offset+new_size[0]] = resized
        return canvas
    
    img1_resized = resize_keep_aspect(img1, target_size)
    img2_resized = resize_keep_aspect(img2, target_size)
    
    # 创建图表
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    
    # 显示图片
    ax1.imshow(img1_resized)
    ax1.set_title("图片1")
    ax1.axis('off')
    
    ax2.imshow(img2_resized)
    ax2.set_title("图片2")
    ax2.axis('off')
    
    # 添加相似度分数
    plt.suptitle(f"相似度分数: {similarity_score:.4f}", fontsize=16)
    
    plt.tight_layout()
    plt.show()

# 使用函数显示对比
display_comparison("./img1.png", "./img2.png", similarity.item())
