<a href="https://colab.research.google.com/github/MoqiSheng/MoqiSheng.github.io/blob/main/GeometricMedian_add0203.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!unzip /content/drive/MyDrive/UrbanASIFpro/Anchor_types.zip -d /content/Anchor_image

Archive:  /content/drive/MyDrive/UrbanASIFpro/Anchor_types.zip
replace /content/Anchor_image/civic, governmental and cultural/1162.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
import os
import torch
from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
from tqdm import tqdm
import numpy as np

# 获取脚本所在的目录路径并切换当前工作目录
# abspath = os.path.abspath(__file__)  # 获取脚本文件的绝对路径
# dname = os.path.dirname(abspath)     # 提取脚本所在目录的路径
# os.chdir(dname)                      # 切换当前工作目录到脚本所在的目录

# 加载预训练的 google/vit-base-patch16-224-in21k 模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Weiszfeld 算法实现
def compute_weiszfeld_median(vectors):
    """
    使用 Weiszfeld 算法计算加权几何中位数
    """
    vectors = torch.stack(vectors, dim=0)
    # 初始化几何中位数为第一个向量
    median = vectors[0]

    # Weiszfeld算法迭代
    epsilon = 1e-6
    max_iterations = 1000
    for _ in range(max_iterations):
        distances = torch.norm(vectors - median, dim=-1)  # 计算到每个向量的距离
        weights = 1.0 / (distances + epsilon)  # 防止除以零，加入一个小的偏移量
        weighted_vectors = (weights.unsqueeze(-1) * vectors).sum(dim=0)  # 加权平均
        new_median = weighted_vectors / weights.sum()  # 更新几何中位数

        # 如果变化小于阈值，则停止迭代
        if torch.norm(new_median - median) < epsilon:
            break
        median = new_median

    return median

# 定义特征提取函数 extract_features
def extract_features(image_folder):
    image_features_list = []
    # 读取 image_folder 中所有扩展名为 .jpg 或 .png 的文件，并将文件名按自然数顺序排序（例如，1.jpg、2.jpg、3.jpg 顺序排列）
    image_paths = sorted(
        [os.path.join(image_folder, img) for img in os.listdir(image_folder) if img.lower().endswith(('.jpg', '.png'))],
        key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
    )

    for image_path in tqdm(image_paths, desc=f"Processing {image_folder}"):
        try:
            # 加载并预处理图像
            image = Image.open(image_path)
            inputs = feature_extractor(images=image, return_tensors="pt").to(device)  # 使用 ViTFeatureExtractor 进行预处理

            # 获取图像特征
            with torch.no_grad():  # 禁用了梯度计算（不需要反向传播）
                outputs = model(**inputs)  # 获取模型输出
                image_features = outputs.last_hidden_state[:, 0]  # 取 [CLS] token 的特征
                image_features /= image_features.norm(dim=-1, keepdim=True)  # 归一化处理
                image_features_list.append(image_features.cpu())  # 将特征向量移动到 CPU ，然后将其追加到列表中
        except Exception as e:
            print(f"Error processing {image_path}: {e}")

    return image_features_list

# 遍历所有文件夹，计算几何中位数并保存结果
def process_folders(base_folder):
    for folder_name in os.listdir(base_folder):
        folder_path = os.path.join(base_folder, folder_name)
        if os.path.isdir(folder_path):
            print(f"Processing folder: {folder_name}")

            # 提取该文件夹内所有图片的特征
            image_features_list = extract_features(folder_path)

            if image_features_list:
                # 计算该文件夹的几何中位数
                weiszfeld_median = compute_weiszfeld_median(image_features_list)

                # 获取文件夹名称的第一个单词
                first_word = folder_name.split()[0]

                # 保存几何中位数，文件名为文件夹名的第一个单词
                output_file = os.path.join(base_folder, f"{first_word}.pt")
                torch.save(weiszfeld_median, output_file)
                print(f"Geometric median saved to {output_file}")
            else:
                print(f"No valid images found in {folder_path}")

# 调用函数，处理文件夹中的图片
process_folders('./Anchor_image')

Processing folder: civic, governmental and cultural


Processing ./Anchor_image/civic, governmental and cultural: 100%|██████████| 10/10 [00:00<00:00, 61.87it/s]


Geometric median saved to ./Anchor_image/civic,.pt
Processing folder: transportation


Processing ./Anchor_image/transportation: 100%|██████████| 6/6 [00:00<00:00, 64.71it/s]


Geometric median saved to ./Anchor_image/transportation.pt
Processing folder: hotel


Processing ./Anchor_image/hotel: 100%|██████████| 6/6 [00:00<00:00, 63.85it/s]


Geometric median saved to ./Anchor_image/hotel.pt
Processing folder: sports and recreation


Processing ./Anchor_image/sports and recreation: 100%|██████████| 5/5 [00:00<00:00, 63.98it/s]


Geometric median saved to ./Anchor_image/sports.pt
Processing folder: residential


Processing ./Anchor_image/residential: 100%|██████████| 39/39 [00:00<00:00, 63.50it/s]


Geometric median saved to ./Anchor_image/residential.pt
Processing folder: health care


Processing ./Anchor_image/health care: 100%|██████████| 7/7 [00:00<00:00, 63.45it/s]


Geometric median saved to ./Anchor_image/health.pt
Processing folder: industrial


Processing ./Anchor_image/industrial: 100%|██████████| 27/27 [00:00<00:00, 64.74it/s]


Geometric median saved to ./Anchor_image/industrial.pt
Processing folder: .ipynb_checkpoints


Processing ./Anchor_image/.ipynb_checkpoints: 0it [00:00, ?it/s]


No valid images found in ./Anchor_image/.ipynb_checkpoints
Processing folder: commercial


Processing ./Anchor_image/commercial: 100%|██████████| 30/30 [00:00<00:00, 64.46it/s]


Geometric median saved to ./Anchor_image/commercial.pt
Processing folder: education


Processing ./Anchor_image/education: 100%|██████████| 17/17 [00:00<00:00, 63.73it/s]


Geometric median saved to ./Anchor_image/education.pt
Processing folder: outdoors and natural


Processing ./Anchor_image/outdoors and natural: 100%|██████████| 5/5 [00:00<00:00, 64.91it/s]

Geometric median saved to ./Anchor_image/outdoors.pt





In [None]:
!unzip /content/drive/MyDrive/UrbanASIFpro/Anchor.zip -d /content/Data

Archive:  /content/drive/MyDrive/UrbanASIFpro/Anchor.zip
  inflating: /content/Data/1007.png  
  inflating: /content/Data/1037.png  
  inflating: /content/Data/1039.png  
  inflating: /content/Data/1052.png  
  inflating: /content/Data/1094.png  
  inflating: /content/Data/1095.png  
  inflating: /content/Data/1097.png  
  inflating: /content/Data/1104.png  
  inflating: /content/Data/111.png   
  inflating: /content/Data/1112.png  
  inflating: /content/Data/1117.png  
  inflating: /content/Data/1140.png  
  inflating: /content/Data/1142.png  
  inflating: /content/Data/115.png   
  inflating: /content/Data/1162.png  
  inflating: /content/Data/1172.png  
  inflating: /content/Data/118.png   
  inflating: /content/Data/1193.png  
  inflating: /content/Data/1194.png  
  inflating: /content/Data/1201.png  
  inflating: /content/Data/1202.png  
  inflating: /content/Data/1206.png  
  inflating: /content/Data/1216.png  
  inflating: /content/Data/1224.png  
  inflating: /content/Data/1271

In [None]:
import os
import torch
import pandas as pd
from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
from tqdm import tqdm
import re  # 导入正则表达式模块

# 加载预训练的 google/vit-base-patch16-224-in21k 模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# 加载 anchor.csv 并将其转换为 DataFrame
csv_file = './anchor.csv'
df = pd.read_csv(csv_file)
print(f"Loaded {len(df)} rows from {csv_file}")

# 定义特征提取函数 extract_features
def extract_features(image_folder, output_file):
    image_features_list = []

    # 读取 image_folder 中所有扩展名为 .jpg 或 .png 的文件，并将文件名按自然数顺序排序（例如，1.jpg、2.jpg、3.jpg 顺序排列）
    image_paths = sorted(
        [os.path.join(image_folder, img) for img in os.listdir(image_folder) if img.lower().endswith(('.jpg', '.png'))],
        key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
    )
    print(f"Found {len(image_paths)} images in {image_folder}")

    for image_path in tqdm(image_paths, desc=f"Processing {image_folder}"):
        try:
            # 获取图片ID
            image_id = os.path.splitext(os.path.basename(image_path))[0]

            # 从 CSV 文件中获取对应的 primary_function 列的第一个单词
            primary_function = df.loc[df['ID'] == int(image_id), 'primary_function'].values
            if primary_function.size == 0:
                print(f"Warning: No primary_function found for image ID {image_id}, skipping.")
                continue  # 如果没有找到对应的行，跳过该图片

            # 使用正则表达式去除标点符号，提取第一个单词
            # 正则表达式 '[^\w\s]' 用于去除所有非字母和空格字符
            first_word = re.sub(r'[^\w\s]', '', primary_function[0]).split()[0]  # 去掉标点符号并取第一个单词
            print(f"Image {image_id}: Primary function is {primary_function[0]}, using {first_word} as the embedding keyword.")

            # 加载并预处理图像
            image = Image.open(image_path)
            inputs = feature_extractor(images=image, return_tensors="pt").to(device)  # 使用 ViTFeatureExtractor 进行预处理

            # 获取图像特征
            with torch.no_grad():  # 禁用了梯度计算（不需要反向传播）
                outputs = model(**inputs)  # 获取模型输出
                image_features = outputs.last_hidden_state[:, 0]  # 取 [CLS] token 的特征
                image_features /= image_features.norm(dim=-1, keepdim=True)  # 归一化处理
                image_features *= 0.5  # 对图像特征向量进行 0.8 的缩放
                print(f"Image {image_id}: Extracted features and scaled by 0.8")

                # 加载与第一个单词对应的嵌入文件（.pt）
                embedding_file = f'./Anchor_image/{first_word}.pt'
                if os.path.exists(embedding_file):
                    word_embedding = torch.load(embedding_file).to(device)
                    image_features += 0.5 * word_embedding  # 加权叠加
                    print(f"Image {image_id}: Added embedding from {first_word}.pt with weight 0.2")
                else:
                    print(f"Warning: Embedding file {embedding_file} not found, skipping embedding addition.")

                image_features_list.append(image_features.cpu())  # 将特征向量移动到 CPU ，然后将其追加到列表中

        except Exception as e:
            print(f"Error processing {image_path}: {e}")

    # 将所有图像特征保存为矩阵
    if image_features_list:
        image_features = torch.cat(image_features_list, dim=0)  # 将存储在 image_features_list 中的所有图像特征向量拼接成一个大的张量
        torch.save(image_features, output_file)
        print(f"Features saved to {output_file}")
    else:
        print(f"No valid images found in {image_folder}")

# 调用特征提取函数，提取并保存每个城市的特征
extract_features('./Data', './imgs_anchor')