<a href="https://colab.research.google.com/github/MoqiSheng/MoqiSheng.github.io/blob/main/baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 连接 Google Drive
from google.colab import drive
drive.mount('/content/drive')

# 解压 anchor_pool.zip 到 images/anchor_pool 文件夹
!unzip /content/drive/MyDrive/images/anchor_pool.zip -d /content/images/anchor_pool

# 解压 candidate_images.zip 到 images/candidate_images 文件夹
!unzip /content/drive/MyDrive/images/candidate_images.zip -d /content/images/candidate_images

Mounted at /content/drive
Archive:  /content/drive/MyDrive/images/anchor_pool.zip
checkdir:  cannot create extraction directory: /content/images/anchor_pool
           No such file or directory
Archive:  /content/drive/MyDrive/images/candidate_images.zip
checkdir:  cannot create extraction directory: /content/images/candidate_images
           No such file or directory


In [4]:
# 在 /content/ 下创建 images 文件夹
!mkdir -p /content/images

# 解压 anchor_pool.zip 到 /content/images/anchor_pool 文件夹
!unzip /content/drive/MyDrive/images/anchor_pool.zip -d /content/images/anchor_pool

# 解压 candidate_images.zip 到 /content/images/candidate_images 文件夹
!unzip /content/drive/MyDrive/images/candidate_images.zip -d /content/images/candidate_images

Archive:  /content/drive/MyDrive/images/anchor_pool.zip
  inflating: /content/images/anchor_pool/1002.png  
  inflating: /content/images/anchor_pool/1004.png  
  inflating: /content/images/anchor_pool/1010.png  
  inflating: /content/images/anchor_pool/1013.png  
  inflating: /content/images/anchor_pool/1029.png  
  inflating: /content/images/anchor_pool/1034.png  
  inflating: /content/images/anchor_pool/1043.png  
  inflating: /content/images/anchor_pool/1048.png  
  inflating: /content/images/anchor_pool/105.png  
  inflating: /content/images/anchor_pool/1052.png  
  inflating: /content/images/anchor_pool/1057.png  
  inflating: /content/images/anchor_pool/1058.png  
  inflating: /content/images/anchor_pool/1059.png  
  inflating: /content/images/anchor_pool/1069.png  
  inflating: /content/images/anchor_pool/1079.png  
  inflating: /content/images/anchor_pool/1086.png  
  inflating: /content/images/anchor_pool/1092.png  
  inflating: /content/images/anchor_pool/1096.png  
  inflati

In [7]:
import os
import torch
import glob
from PIL import Image
import numpy as np
from tqdm import tqdm
from transformers import ViTImageProcessor, CLIPProcessor, CLIPModel
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.preprocessing import normalize

def setup_directories():
    """创建必要的文件夹"""
    os.makedirs("image_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/anchor_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/candidate_embeddings", exist_ok=True)

def load_models():
    """从网络加载DINOv2、CLIP、ResNet101和Places365模型"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models_dict = {}

    print("加载DINOv2模型...")
    dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', trust_repo=True).to(device).eval()
    dinov2_processor = ViTImageProcessor.from_pretrained('facebook/dinov2-base')
    models_dict['dinov2'] = (dinov2, dinov2_processor)

    print("加载CLIP模型...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    models_dict['clip'] = (clip_model, clip_processor)

    print("加载ResNet101模型...")
    resnet101 = models.resnet101(pretrained=True).to(device).eval()
    resnet101 = torch.nn.Sequential(*list(resnet101.children())[:-1])
    models_dict['resnet101'] = (resnet101, None)

    print("加载Places365模型...")
    try:
        # 尝试加载Places365的ResNet50模型
        places365 = models.resnet50(pretrained=False).to(device)
        # 下载Places365预训练权重
        checkpoint = torch.hub.load_state_dict_from_url(
            url="http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar",
            map_location=device
        )
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        places365.load_state_dict(state_dict)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)
    except Exception as e:
        print(f"加载Places365模型失败: {e}")
        print("使用标准ResNet50作为备用...")
        places365 = models.resnet50(pretrained=True).to(device)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)

    return models_dict, device

def extract_id(filename):
    """从文件名中提取数字ID"""
    basename = os.path.basename(filename)
    name_without_ext = os.path.splitext(basename)[0]
    match = re.search(r'(\d+)', name_without_ext)
    return int(match.group(1)) if match else float('inf')

def get_sorted_images(directory):
    """获取指定目录下按ID排序的图像文件列表"""
    image_files = []
    for ext in ['*.png', '*.jpg', '*.jpeg']:
        image_files.extend(glob.glob(os.path.join(directory, ext)))
    return sorted(image_files, key=extract_id)

def encode_images(image_files, model, processor, device, output_dir, model_name, batch_size=16):
    """批量编码图像并保存结果"""
    if not image_files:
        print(f"警告: {output_dir} 没有找到图像文件")
        return

    print(f"使用 {model_name} 处理 {len(image_files)} 张图像...")
    all_embeddings = {}

    for i in tqdm(range(0, len(image_files), batch_size)):
        batch_files = image_files[i:i+batch_size]
        batch_images = []
        batch_ids = []

        for img_file in batch_files:
            try:
                img = Image.open(img_file).convert('RGB')
                batch_images.append(img)
                img_id = extract_id(img_file)
                batch_ids.append(img_id)
            except Exception as e:
                print(f"处理图像 {img_file} 出错: {e}")

        if not batch_images:
            continue

        with torch.no_grad():
            if model_name in ['dinov2', 'clip']:
                inputs = processor(images=batch_images, return_tensors="pt", padding=True)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                if model_name == 'dinov2':
                    outputs = model(**inputs)
                    embeddings = outputs[0][:, 0, :].cpu().numpy()
                else:
                    outputs = model.get_image_features(**inputs)
                    embeddings = outputs.cpu().numpy()
            else:
                inputs = torch.stack([
                    transforms.ToTensor()(img.resize((224, 224))) for img in batch_images
                ])
                inputs = transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )(inputs).to(device)
                embeddings = model(inputs).squeeze(-1).squeeze(-1).cpu().numpy()

            normalized_embeddings = normalize(embeddings, axis=1)

            for j, img_id in enumerate(batch_ids):
                all_embeddings[img_id] = torch.tensor(normalized_embeddings[j], dtype=torch.float32)

    sorted_ids = sorted(all_embeddings.keys())
    sorted_embeddings = torch.stack([all_embeddings[img_id] for img_id in sorted_ids])

    embedding_file = os.path.join(output_dir, f"{model_name}_image_emb.pt")
    id_mapping_file = os.path.join(output_dir, f"{model_name}_image_id.pt")

    torch.save(sorted_embeddings, embedding_file)
    torch.save(sorted_ids, id_mapping_file)

    print(f"{model_name} 嵌入已保存到 {embedding_file}")
    print(f"{model_name} ID映射已保存到 {id_mapping_file}")
    print(f"{model_name} 嵌入形状: {sorted_embeddings.shape}")

def main():
    setup_directories()
    models_dict, device = load_models()

    # parent_dir = os.path.dirname(os.getcwd())
    anchor_dir = os.path.join("images", "anchor_pool")
    anchor_images = get_sorted_images(anchor_dir)
    print(f"找到 {len(anchor_images)} 张锚点图像")

    predict_dir = os.path.join("images", "candidate_images")
    predict_images = get_sorted_images(predict_dir)
    print(f"找到 {len(predict_images)} 张预测图像")

    for model_name, (model, processor) in models_dict.items():
        encode_images(
            anchor_images,
            model,
            processor,
            device,
            "image_embeddings/anchor_embeddings",
            model_name
        )
        encode_images(
            predict_images,
            model,
            processor,
            device,
            "image_embeddings/candidate_embeddings",
            model_name
        )

    print("所有图像编码完成!")

if __name__ == "__main__":
    main()

加载DINOv2模型...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


加载CLIP模型...
加载ResNet101模型...




加载Places365模型...




加载Places365模型失败: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([365, 2048]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).
	size mismatch for fc.bias: copying a param with shape torch.Size([365]) from checkpoint, the shape in current model is torch.Size([1000]).
使用标准ResNet50作为备用...




找到 203 张锚点图像
找到 1132 张预测图像
使用 dinov2 处理 203 张图像...


  return self.preprocess(images, **kwargs)
  0%|          | 0/13 [00:00<?, ?it/s]


ValueError: The `size` dictionary must contain the keys `height` and `width`. Got dict_keys(['shortest_edge'])

In [8]:
import os
import torch
import glob
from PIL import Image
import numpy as np
from tqdm import tqdm
from transformers import ViTImageProcessor, CLIPProcessor, CLIPModel
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.preprocessing import normalize

def setup_directories():
    """创建必要的文件夹"""
    os.makedirs("image_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/anchor_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/candidate_embeddings", exist_ok=True)

def load_models():
    """从网络加载DINOv2、CLIP、ResNet101和Places365模型"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models_dict = {}

    print("加载DINOv2模型...")
    dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', trust_repo=True).to(device).eval()
    # 显式设置size参数以避免shortest_edge问题
    dinov2_processor = ViTImageProcessor.from_pretrained('facebook/dinov2-base', size={'height': 224, 'width': 224})
    models_dict['dinov2'] = (dinov2, dinov2_processor)

    print("加载CLIP模型...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    models_dict['clip'] = (clip_model, clip_processor)

    print("加载ResNet101模型...")
    resnet101 = models.resnet101(pretrained=True).to(device).eval()
    resnet101 = torch.nn.Sequential(*list(resnet101.children())[:-1])
    models_dict['resnet101'] = (resnet101, None)

    print("加载Places365模型...")
    try:
        places365 = models.resnet50(pretrained=False).to(device)
        checkpoint = torch.hub.load_state_dict_from_url(
            url="http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar",
            map_location=device
        )
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        places365.load_state_dict(state_dict)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)
    except Exception as e:
        print(f"加载Places365模型失败: {e}")
        print("使用标准ResNet50作为备用...")
        places365 = models.resnet50(pretrained=True).to(device)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)

    return models_dict, device

def extract_id(filename):
    """从文件名中提取数字ID"""
    basename = os.path.basename(filename)
    name_without_ext = os.path.splitext(basename)[0]
    match = re.search(r'(\d+)', name_without_ext)
    return int(match.group(1)) if match else float('inf')

def get_sorted_images(directory):
    """获取指定目录下按ID排序的图像文件列表"""
    image_files = []
    for ext in ['*.png', '*.jpg', '*.jpeg']:
        image_files.extend(glob.glob(os.path.join(directory, ext)))
    return sorted(image_files, key=extract_id)

def encode_images(image_files, model, processor, device, output_dir, model_name, batch_size=16):
    """批量编码图像并保存结果"""
    if not image_files:
        print(f"警告: {output_dir} 没有找到图像文件")
        return

    print(f"使用 {model_name} 处理 {len(image_files)} 张图像...")
    all_embeddings = {}

    for i in tqdm(range(0, len(image_files), batch_size)):
        batch_files = image_files[i:i+batch_size]
        batch_images = []
        batch_ids = []

        for img_file in batch_files:
            try:
                img = Image.open(img_file).convert('RGB')
                batch_images.append(img)
                img_id = extract_id(img_file)
                batch_ids.append(img_id)
            except Exception as e:
                print(f"处理图像 {img_file} 出错: {e}")

        if not batch_images:
            continue

        try:
            with torch.no_grad():
                if model_name in ['dinov2', 'clip']:
                    # 移除padding=True，因为ViT不需要
                    inputs = processor(images=batch_images, return_tensors="pt")
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                    if model_name == 'dinov2':
                        outputs = model(**inputs)
                        embeddings = outputs[0][:, 0, :].cpu().numpy()
                    else:
                        outputs = model.get_image_features(**inputs)
                        embeddings = outputs.cpu().numpy()
                else:
                    inputs = torch.stack([
                        transforms.ToTensor()(img.resize((224, 224))) for img in batch_images
                    ])
                    inputs = transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    )(inputs).to(device)
                    embeddings = model(inputs).squeeze(-1).squeeze(-авис

                    normalized_embeddings = normalize(embeddings, axis=1)

                    for j, img_id in enumerate(batch_ids):
                        all_embeddings[img_id] = torch.tensor(normalized_embeddings[j], dtype=torch.float32)

        except Exception as e:
            print(f"批次 {i//batch_size} 处理出错: {e}")
            continue

    sorted_ids = sorted(all_embeddings.keys())
    sorted_embeddings = torch.stack([all_embeddings[img_id] for img_id in sorted_ids])

    embedding_file = os.path.join(output_dir, f"{model_name}_image_emb.pt")
    id_mapping_file = os.path.join(output_dir, f"{model_name}_image_id.pt")

    torch.save(sorted_embeddings, embedding_file)
    torch.save(sorted_ids, id_mapping_file)

    print(f"{model_name} 嵌入已保存到 {embedding_file}")
    print(f"{model_name} ID映射已保存到 {id_mapping_file}")
    print(f"{model_name} 嵌入形状: {sorted_embeddings.shape}")

def main():
    setup_directories()
    models_dict, device = load_models()

    parent_dir = os.path.dirname(os.getcwd())
    anchor_dir = os.path.join("images", "anchor_pool")
    anchor_images = get_sorted_images(anchor_dir)
    print(f"找到 {len(anchor_images)} 张锚点图像")

    predict_dir = os.path.join("images", "candidate_images")
    predict_images = get_sorted_images(predict_dir)
    print(f"找到 {len(predict_images)} 张预测图像")

    for model_name, (model, processor) in models_dict.items():
        encode_images(
            anchor_images,
            model,
            processor,
            device,
            "image_embeddings/anchor_embeddings",
            model_name
        )
        encode_images(
            predict_images,
            model,
            processor,
            device,
            "image_embeddings/candidate_embeddings",
            model_name
        )

    print("所有图像编码完成!")

if __name__ == "__main__":
    main()

SyntaxError: '(' was never closed (<ipython-input-8-30c725b1ebc1>, line 119)

In [9]:
import os
import torch
import glob
from PIL import Image
import numpy as np
from tqdm import tqdm
from transformers import ViTImageProcessor, CLIPProcessor, CLIPModel
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.preprocessing import normalize

def setup_directories():
    """创建必要的文件夹"""
    os.makedirs("image_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/anchor_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/candidate_embeddings", exist_ok=True)

def load_models():
    """从网络加载DINOv2、CLIP、ResNet101和Places365模型"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models_dict = {}

    print("加载DINOv2模型...")
    dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', trust_repo=True).to(device).eval()
    dinov2_processor = ViTImageProcessor.from_pretrained('facebook/dinov2-base', size={'height': 224, 'width': 224})
    models_dict['dinov2'] = (dinov2, dinov2_processor)

    print("加载CLIP模型...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    models_dict['clip'] = (clip_model, clip_processor)

    print("加载ResNet101模型...")
    resnet101 = models.resnet101(pretrained=True).to(device).eval()
    resnet101 = torch.nn.Sequential(*list(resnet101.children())[:-1])
    models_dict['resnet101'] = (resnet101, None)

    print("加载Places365模型...")
    try:
        places365 = models.resnet50(pretrained=False).to(device)
        checkpoint = torch.hub.load_state_dict_from_url(
            url="http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar",
            map_location=device
        )
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        places365.load_state_dict(state_dict)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)
    except Exception as e:
        print(f"加载Places365模型失败: {e}")
        print("使用标准ResNet50作为备用...")
        places365 = models.resnet50(pretrained=True).to(device)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)

    return models_dict, device

def extract_id(filename):
    """从文件名中提取数字ID"""
    basename = os.path.basename(filename)
    name_without_ext = os.path.splitext(basename)[0]
    match = re.search(r'(\d+)', name_without_ext)
    return int(match.group(1)) if match else float('inf')

def get_sorted_images(directory):
    """获取指定目录下按ID排序的图像文件列表"""
    image_files = []
    for ext in ['*.png', '*.jpg', '*.jpeg']:
        image_files.extend(glob.glob(os.path.join(directory, ext)))
    return sorted(image_files, key=extract_id)

def encode_images(image_files, model, processor, device, output_dir, model_name, batch_size=16):
    """批量编码图像并保存结果"""
    if not image_files:
        print(f"警告: {output_dir} 没有找到图像文件")
        return

    print(f"使用 {model_name} 处理 {len(image_files)} 张图像...")
    all_embeddings = {}

    for i in tqdm(range(0, len(image_files), batch_size)):
        batch_files = image_files[i:i+batch_size]
        batch_images = []
        batch_ids = []

        for img_file in batch_files:
            try:
                img = Image.open(img_file).convert('RGB')
                batch_images.append(img)
                img_id = extract_id(img_file)
                batch_ids.append(img_id)
            except Exception as e:
                print(f"处理图像 {img_file} 出错: {e}")

        if not batch_images:
            continue

        with torch.no_grad():
            if model_name in ['dinov2', 'clip']:
                # 为DINOv2和CLIP移除padding参数，并确保正确处理
                inputs = processor(images=batch_images, return_tensors="pt", do_resize=True, do_normalize=True)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                if model_name == 'dinov2':
                    outputs = model(**inputs)
                    embeddings = outputs[0][:, 0, :].cpu().numpy()
                else:
                    outputs = model.get_image_features(**inputs)
                    embeddings = outputs.cpu().numpy()
            else:
                inputs = torch.stack([
                    transforms.ToTensor()(img.resize((224, 224))) for img in batch_images
                ])
                inputs = transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )(inputs).to(device)
                embeddings = model(inputs).squeeze(-1).squeeze(-1).cpu().numpy()

            normalized_embeddings = normalize(embeddings, axis=1)

            for j, img_id in enumerate(batch_ids):
                all_embeddings[img_id] = torch.tensor(normalized_embeddings[j], dtype=torch.float32)

    sorted_ids = sorted(all_embeddings.keys())
    sorted_embeddings = torch.stack([all_embeddings[img_id] for img_id in sorted_ids])

    embedding_file = os.path.join(output_dir, f"{model_name}_image_emb.pt")
    id_mapping_file = os.path.join(output_dir, f"{model_name}_image_id.pt")

    torch.save(sorted_embeddings, embedding_file)
    torch.save(sorted_ids, id_mapping_file)

    print(f"{model_name} 嵌入已保存到 {embedding_file}")
    print(f"{model_name} ID映射已保存到 {id_mapping_file}")
    print(f"{model_name} 嵌入形状: {sorted_embeddings.shape}")

def main():
    setup_directories()
    models_dict, device = load_models()

    parent_dir = os.path.dirname(os.getcwd())
    anchor_dir = os.path.join("images", "anchor_pool")
    anchor_images = get_sorted_images(anchor_dir)
    print(f"找到 {len(anchor_images)} 张锚点图像")

    predict_dir = os.path.join("images", "candidate_images")
    predict_images = get_sorted_images(predict_dir)
    print(f"找到 {len(predict_images)} 张预测图像")

    for model_name, (model, processor) in models_dict.items():
        encode_images(
            anchor_images,
            model,
            processor,
            device,
            "image_embeddings/anchor_embeddings",
            model_name
        )
        encode_images(
            predict_images,
            model,
            processor,
            device,
            "image_embeddings/candidate_embeddings",
            model_name
        )

    print("所有图像编码完成!")

if __name__ == "__main__":
    main()

加载DINOv2模型...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


加载CLIP模型...
加载ResNet101模型...




加载Places365模型...




加载Places365模型失败: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([365, 2048]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).
	size mismatch for fc.bias: copying a param with shape torch.Size([365]) from checkpoint, the shape in current model is torch.Size([1000]).
使用标准ResNet50作为备用...




找到 203 张锚点图像
找到 1132 张预测图像
使用 dinov2 处理 203 张图像...


  0%|          | 0/13 [00:00<?, ?it/s]


TypeError: DinoVisionTransformer.forward_features() got an unexpected keyword argument 'pixel_values'

In [10]:
import os
import torch
import glob
from PIL import Image
import numpy as np
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.preprocessing import normalize

def setup_directories():
    """创建必要的文件夹"""
    os.makedirs("image_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/anchor_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/candidate_embeddings", exist_ok=True)

def load_models():
    """从网络加载DINOv2、CLIP、ResNet101和Places365模型"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models_dict = {}

    print("加载DINOv2模型...")
    dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', trust_repo=True).to(device).eval()
    # DINOv2需要自定义预处理，不使用ViTImageProcessor
    dinov2_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    models_dict['dinov2'] = (dinov2, dinov2_transform)

    print("加载CLIP模型...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    models_dict['clip'] = (clip_model, clip_processor)

    print("加载ResNet101模型...")
    resnet101 = models.resnet101(pretrained=True).to(device).eval()
    resnet101 = torch.nn.Sequential(*list(resnet101.children())[:-1])
    models_dict['resnet101'] = (resnet101, None)

    print("加载Places365模型...")
    try:
        places365 = models.resnet50(pretrained=False).to(device)
        checkpoint = torch.hub.load_state_dict_from_url(
            url="http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar",
            map_location=device
        )
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        places365.load_state_dict(state_dict)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)
    except Exception as e:
        print(f"加载Places365模型失败: {e}")
        print("使用标准ResNet50作为备用...")
        places365 = models.resnet50(pretrained=True).to(device)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)

    return models_dict, device

def extract_id(filename):
    """从文件名中提取数字ID"""
    basename = os.path.basename(filename)
    name_without_ext = os.path.splitext(basename)[0]
    match = re.search(r'(\d+)', name_without_ext)
    return int(match.group(1)) if match else float('inf')

def get_sorted_images(directory):
    """获取指定目录下按ID排序的图像文件列表"""
    image_files = []
    for ext in ['*.png', '*.jpg', '*.jpeg']:
        image_files.extend(glob.glob(os.path.join(directory, ext)))
    return sorted(image_files, key=extract_id)

def encode_images(image_files, model, processor, device, output_dir, model_name, batch_size=16):
    """批量编码图像并保存结果"""
    if not image_files:
        print(f"警告: {output_dir} 没有找到图像文件")
        return

    print(f"使用 {model_name} 处理 {len(image_files)} 张图像...")
    all_embeddings = {}

    for i in tqdm(range(0, len(image_files), batch_size)):
        batch_files = image_files[i:i+batch_size]
        batch_images = []
        batch_ids = []

        for img_file in batch_files:
            try:
                img = Image.open(img_file).convert('RGB')
                batch_images.append(img)
                img_id = extract_id(img_file)
                batch_ids.append(img_id)
            except Exception as e:
                print(f"处理图像 {img_file} 出错: {e}")

        if not batch_images:
            continue

        with torch.no_grad():
            if model_name == 'dinov2':
                # DINOv2使用torchvision的transform处理
                inputs = torch.stack([processor(img) for img in batch_images]).to(device)
                outputs = model(inputs)
                embeddings = outputs[:, 0, :].cpu().numpy()  # 获取[CLS] token
            elif model_name == 'clip':
                inputs = processor(images=batch_images, return_tensors="pt", do_resize=True, do_normalize=True)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model.get_image_features(**inputs)
                embeddings = outputs.cpu().numpy()
            else:
                inputs = torch.stack([
                    transforms.ToTensor()(img.resize((224, 224))) for img in batch_images
                ])
                inputs = transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )(inputs).to(device)
                embeddings = model(inputs).squeeze(-1).squeeze(-1).cpu().numpy()

            normalized_embeddings = normalize(embeddings, axis=1)

            for j, img_id in enumerate(batch_ids):
                all_embeddings[img_id] = torch.tensor(normalized_embeddings[j], dtype=torch.float32)

    sorted_ids = sorted(all_embeddings.keys())
    sorted_embeddings = torch.stack([all_embeddings[img_id] for img_id in sorted_ids])

    embedding_file = os.path.join(output_dir, f"{model_name}_image_emb.pt")
    id_mapping_file = os.path.join(output_dir, f"{model_name}_image_id.pt")

    torch.save(sorted_embeddings, embedding_file)
    torch.save(sorted_ids, id_mapping_file)

    print(f"{model_name} 嵌入已保存到 {embedding_file}")
    print(f"{model_name} ID映射已保存到 {id_mapping_file}")
    print(f"{model_name} 嵌入形状: {sorted_embeddings.shape}")

def main():
    setup_directories()
    models_dict, device = load_models()

    parent_dir = os.path.dirname(os.getcwd())
    anchor_dir = os.path.join("images", "anchor_pool")
    anchor_images = get_sorted_images(anchor_dir)
    print(f"找到 {len(anchor_images)} 张锚点图像")

    predict_dir = os.path.join("images", "candidate_images")
    predict_images = get_sorted_images(predict_dir)
    print(f"找到 {len(predict_images)} 张预测图像")

    for model_name, (model, processor) in models_dict.items():
        encode_images(
            anchor_images,
            model,
            processor,
            device,
            "image_embeddings/anchor_embeddings",
            model_name
        )
        encode_images(
            predict_images,
            model,
            processor,
            device,
            "image_embeddings/candidate_embeddings",
            model_name
        )

    print("所有图像编码完成!")

if __name__ == "__main__":
    main()

加载DINOv2模型...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


加载CLIP模型...
加载ResNet101模型...
加载Places365模型...
加载Places365模型失败: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([365, 2048]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).
	size mismatch for fc.bias: copying a param with shape torch.Size([365]) from checkpoint, the shape in current model is torch.Size([1000]).
使用标准ResNet50作为备用...
找到 203 张锚点图像
找到 1132 张预测图像
使用 dinov2 处理 203 张图像...


  0%|          | 0/13 [00:01<?, ?it/s]


IndexError: too many indices for tensor of dimension 2

In [11]:
import os
import torch
import glob
from PIL import Image
import numpy as np
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.preprocessing import normalize

def setup_directories():
    """创建必要的文件夹"""
    os.makedirs("image_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/anchor_embeddings", exist_ok=True)
    os.makedirs("image_embeddings/candidate_embeddings", exist_ok=True)

def load_models():
    """从网络加载DINOv2、CLIP、ResNet101和Places365模型"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models_dict = {}

    print("加载DINOv2模型...")
    dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', trust_repo=True).to(device).eval()
    dinov2_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    models_dict['dinov2'] = (dinov2, dinov2_transform)

    print("加载CLIP模型...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    models_dict['clip'] = (clip_model, clip_processor)

    print("加载ResNet101模型...")
    resnet101 = models.resnet101(pretrained=True).to(device).eval()
    resnet101 = torch.nn.Sequential(*list(resnet101.children())[:-1])
    models_dict['resnet101'] = (resnet101, None)

    print("加载Places365模型...")
    try:
        places365 = models.resnet50(pretrained=False).to(device)
        checkpoint = torch.hub.load_state_dict_from_url(
            url="http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar",
            map_location=device
        )
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        places365.load_state_dict(state_dict)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)
    except Exception as e:
        print(f"加载Places365模型失败: {e}")
        print("使用标准ResNet50作为备用...")
        places365 = models.resnet50(pretrained=True).to(device)
        places365 = torch.nn.Sequential(*list(places365.children())[:-1]).eval()
        models_dict['places365'] = (places365, None)

    return models_dict, device

def extract_id(filename):
    """从文件名中提取数字ID"""
    basename = os.path.basename(filename)
    name_without_ext = os.path.splitext(basename)[0]
    match = re.search(r'(\d+)', name_without_ext)
    return int(match.group(1)) if match else float('inf')

def get_sorted_images(directory):
    """获取指定目录下按ID排序的图像文件列表"""
    image_files = []
    for ext in ['*.png', '*.jpg', '*.jpeg']:
        image_files.extend(glob.glob(os.path.join(directory, ext)))
    return sorted(image_files, key=extract_id)

def encode_images(image_files, model, processor, device, output_dir, model_name, batch_size=16):
    """批量编码图像并保存结果"""
    if not image_files:
        print(f"警告: {output_dir} 没有找到图像文件")
        return

    print(f"使用 {model_name} 处理 {len(image_files)} 张图像...")
    all_embeddings = {}

    for i in tqdm(range(0, len(image_files), batch_size)):
        batch_files = image_files[i:i+batch_size]
        batch_images = []
        batch_ids = []

        for img_file in batch_files:
            try:
                img = Image.open(img_file).convert('RGB')
                batch_images.append(img)
                img_id = extract_id(img_file)
                batch_ids.append(img_id)
            except Exception as e:
                print(f"处理图像 {img_file} 出错: {e}")

        if not batch_images:
            continue

        with torch.no_grad():
            if model_name == 'dinov2':
                inputs = torch.stack([processor(img) for img in batch_images]).to(device)
                outputs = model(inputs)
                embeddings = outputs.cpu().numpy()  # DINOv2直接输出[batch_size, embedding_dim]
            elif model_name == 'clip':
                inputs = processor(images=batch_images, return_tensors="pt", do_resize=True, do_normalize=True)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model.get_image_features(**inputs)
                embeddings = outputs.cpu().numpy()
            else:
                inputs = torch.stack([
                    transforms.ToTensor()(img.resize((224, 224))) for img in batch_images
                ])
                inputs = transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )(inputs).to(device)
                embeddings = model(inputs).squeeze(-1).squeeze(-1).cpu().numpy()

            normalized_embeddings = normalize(embeddings, axis=1)

            for j, img_id in enumerate(batch_ids):
                all_embeddings[img_id] = torch.tensor(normalized_embeddings[j], dtype=torch.float32)

    sorted_ids = sorted(all_embeddings.keys())
    sorted_embeddings = torch.stack([all_embeddings[img_id] for img_id in sorted_ids])

    embedding_file = os.path.join(output_dir, f"{model_name}_image_emb.pt")
    id_mapping_file = os.path.join(output_dir, f"{model_name}_image_id.pt")

    torch.save(sorted_embeddings, embedding_file)
    torch.save(sorted_ids, id_mapping_file)

    print(f"{model_name} 嵌入已保存到 {embedding_file}")
    print(f"{model_name} ID映射已保存到 {id_mapping_file}")
    print(f"{model_name} 嵌入形状: {sorted_embeddings.shape}")

def main():
    setup_directories()
    models_dict, device = load_models()

    parent_dir = os.path.dirname(os.getcwd())
    anchor_dir = os.path.join("images", "anchor_pool")
    anchor_images = get_sorted_images(anchor_dir)
    print(f"找到 {len(anchor_images)} 张锚点图像")

    predict_dir = os.path.join("images", "candidate_images")
    predict_images = get_sorted_images(predict_dir)
    print(f"找到 {len(predict_images)} 张预测图像")

    for model_name, (model, processor) in models_dict.items():
        encode_images(
            anchor_images,
            model,
            processor,
            device,
            "image_embeddings/anchor_embeddings",
            model_name
        )
        encode_images(
            predict_images,
            model,
            processor,
            device,
            "image_embeddings/candidate_embeddings",
            model_name
        )

    print("所有图像编码完成!")

if __name__ == "__main__":
    main()

加载DINOv2模型...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


加载CLIP模型...
加载ResNet101模型...
加载Places365模型...
加载Places365模型失败: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([365, 2048]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).
	size mismatch for fc.bias: copying a param with shape torch.Size([365]) from checkpoint, the shape in current model is torch.Size([1000]).
使用标准ResNet50作为备用...
找到 203 张锚点图像
找到 1132 张预测图像
使用 dinov2 处理 203 张图像...


100%|██████████| 13/13 [00:02<00:00,  4.45it/s]


dinov2 嵌入已保存到 image_embeddings/anchor_embeddings/dinov2_image_emb.pt
dinov2 ID映射已保存到 image_embeddings/anchor_embeddings/dinov2_image_id.pt
dinov2 嵌入形状: torch.Size([203, 768])
使用 dinov2 处理 1132 张图像...


100%|██████████| 71/71 [00:16<00:00,  4.33it/s]


dinov2 嵌入已保存到 image_embeddings/candidate_embeddings/dinov2_image_emb.pt
dinov2 ID映射已保存到 image_embeddings/candidate_embeddings/dinov2_image_id.pt
dinov2 嵌入形状: torch.Size([1132, 768])
使用 clip 处理 203 张图像...


100%|██████████| 13/13 [00:02<00:00,  5.27it/s]


clip 嵌入已保存到 image_embeddings/anchor_embeddings/clip_image_emb.pt
clip ID映射已保存到 image_embeddings/anchor_embeddings/clip_image_id.pt
clip 嵌入形状: torch.Size([203, 512])
使用 clip 处理 1132 张图像...


100%|██████████| 71/71 [00:12<00:00,  5.61it/s]


clip 嵌入已保存到 image_embeddings/candidate_embeddings/clip_image_emb.pt
clip ID映射已保存到 image_embeddings/candidate_embeddings/clip_image_id.pt
clip 嵌入形状: torch.Size([1132, 512])
使用 resnet101 处理 203 张图像...


100%|██████████| 13/13 [00:02<00:00,  5.91it/s]


resnet101 嵌入已保存到 image_embeddings/anchor_embeddings/resnet101_image_emb.pt
resnet101 ID映射已保存到 image_embeddings/anchor_embeddings/resnet101_image_id.pt
resnet101 嵌入形状: torch.Size([203, 2048])
使用 resnet101 处理 1132 张图像...


100%|██████████| 71/71 [00:11<00:00,  6.42it/s]


resnet101 嵌入已保存到 image_embeddings/candidate_embeddings/resnet101_image_emb.pt
resnet101 ID映射已保存到 image_embeddings/candidate_embeddings/resnet101_image_id.pt
resnet101 嵌入形状: torch.Size([1132, 2048])
使用 places365 处理 203 张图像...


100%|██████████| 13/13 [00:01<00:00,  6.83it/s]


places365 嵌入已保存到 image_embeddings/anchor_embeddings/places365_image_emb.pt
places365 ID映射已保存到 image_embeddings/anchor_embeddings/places365_image_id.pt
places365 嵌入形状: torch.Size([203, 2048])
使用 places365 处理 1132 张图像...


100%|██████████| 71/71 [00:10<00:00,  6.68it/s]

places365 嵌入已保存到 image_embeddings/candidate_embeddings/places365_image_emb.pt
places365 ID映射已保存到 image_embeddings/candidate_embeddings/places365_image_id.pt
places365 嵌入形状: torch.Size([1132, 2048])
所有图像编码完成!





In [13]:
# 移动 /content/images/image_embeddings 文件夹下的文件到 Google Drive 的指定路径
# 假设目标路径为 Google Drive 的 MyDrive/images/image_embeddings
!mkdir -p /content/drive/MyDrive/images/image_embeddings
!mv /content/image_embeddings/* /content/drive/MyDrive/images/image_embeddings/