<a href="https://colab.research.google.com/github/MoqiSheng/MoqiSheng.github.io/blob/main/kmeans0206.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Mount Google Drive (Colab only) so the zipped dataset under MyDrive is reachable
# at /content/drive/MyDrive/... for the extraction cell.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [3]:
!unzip /content/drive/MyDrive/UrbanASIFpr/data.zip -d /content/data

Archive:  /content/drive/MyDrive/UrbanASIFpr/data.zip
   creating: /content/data/industrial/
  inflating: /content/data/industrial/1000.png  
  inflating: /content/data/industrial/1001.png  
  inflating: /content/data/industrial/1004.png  
  inflating: /content/data/industrial/1008.png  
  inflating: /content/data/industrial/1009.png  
  inflating: /content/data/industrial/1013.png  
  inflating: /content/data/industrial/1017.png  
  inflating: /content/data/industrial/1020.png  
  inflating: /content/data/industrial/1022.png  
  inflating: /content/data/industrial/1026.png  
  inflating: /content/data/industrial/1033.png  
  inflating: /content/data/industrial/1051.png  
  inflating: /content/data/industrial/1054.png  
  inflating: /content/data/industrial/1055.png  
  inflating: /content/data/industrial/1059.png  
  inflating: /content/data/industrial/1063.png  
  inflating: /content/data/industrial/1064.png  
  inflating: /content/data/industrial/1067.png  
  inflating: /content/dat

In [16]:
import os
import torch
from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
from tqdm import tqdm
from sklearn.cluster import KMeans
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

# Paths used below are relative to the kernel's working directory (/content on Colab).
# `__file__` is undefined inside a notebook, so there is no script directory to chdir into.

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained checkpoint google/vit-base-patch16-224-in21k and its matching preprocessor.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)

# 定义特征提取函数 extract_features
def _image_sort_key(path):
    """Sort key: files with an integer stem first (numeric order), then all others (lexicographic)."""
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return (0, int(stem), "")
    except ValueError:
        # Non-numeric names used to crash the whole run; now they simply sort last.
        return (1, 0, stem)


def _embed_image(image_path):
    """Return the L2-normalised [CLS] embedding of one image as a CPU tensor of shape (1, hidden)."""
    with Image.open(image_path) as img:
        # Force 3-channel RGB. RGBA, palette or greyscale PNGs may otherwise be
        # rejected by the ViT preprocessor. Closing the file avoids leaking handles.
        image = img.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0]  # [CLS] token features
        cls_embedding = cls_embedding / cls_embedding.norm(dim=-1, keepdim=True)
    return cls_embedding.cpu()


def _cluster_results(image_features, image_ids, num_clusters):
    """Run KMeans on unit-norm embeddings and build the per-image results table.

    Returns a DataFrame with columns Cluster_ID, Center_Image_ID, Image_ID and
    Cosine_Similarity. There is one row per image, grouped by cluster in
    ascending image order.
    """
    n_clusters = min(num_clusters, len(image_ids))
    if n_clusters < num_clusters:
        # KMeans cannot fit more clusters than samples; clamp instead of raising.
        print(f"Only {len(image_ids)} images available; reducing num_clusters "
              f"from {num_clusters} to {n_clusters}")

    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(image_features)
    labels = kmeans.labels_

    # For unit-norm rows, the highest cosine similarity to a centroid is also
    # the smallest Euclidean distance to it, i.e. the KMeans-nearest member.
    center_similarity = cosine_similarity(image_features, kmeans.cluster_centers_)

    rows = []
    for cluster_id in range(n_clusters):
        members = np.flatnonzero(labels == cluster_id)
        if members.size == 0:
            # Defensive: a cluster with no members (e.g. fewer distinct embeddings
            # than clusters) would otherwise make np.argmax raise.
            continue
        closest_idx = members[np.argmax(center_similarity[members, cluster_id])]
        closest_id = image_ids[closest_idx]
        # One vectorised call instead of one cosine_similarity call per image.
        member_similarity = cosine_similarity(
            image_features[members], image_features[closest_idx:closest_idx + 1]
        )[:, 0]
        rows.extend(
            [cluster_id, closest_id, image_ids[idx], sim]
            for idx, sim in zip(members, member_similarity)
        )

    return pd.DataFrame(
        rows, columns=["Cluster_ID", "Center_Image_ID", "Image_ID", "Cosine_Similarity"]
    )


def extract_features(image_folder, output_file, num_clusters=5, csv_file="transportation.csv"):
    """Embed every image in a folder with ViT, cluster the embeddings and export the results.

    Each ``.jpg``/``.png`` file in ``image_folder`` is passed through the
    module-level ``feature_extractor`` and ``model``. The L2-normalised [CLS]
    token of the last hidden state becomes that image's embedding.

    Args:
        image_folder: Directory containing the images. Files whose stem is an
            integer (e.g. ``1000.png``) are ordered numerically. Any other
            names follow, in lexicographic order.
        output_file: Path passed to ``torch.save`` for the stacked embedding
            matrix. This is a NumPy array with one row per successfully
            processed image.
        num_clusters: Requested number of KMeans clusters. It is clamped to the
            number of processed images.
        csv_file: Path of the CSV with per-image cluster results. Defaults to
            ``'transportation.csv'``, the name previously hard-coded here.

    The CSV contains these columns:
        - ``Cluster_ID``
        - ``Center_Image_ID``: file name of the member closest to the cluster
          centroid by cosine similarity.
        - ``Image_ID``: file name of the image.
        - ``Cosine_Similarity``: similarity between the image and that
          representative member.

    Images that fail to load or embed are reported and skipped. Returns None.
    """
    image_paths = sorted(
        (
            os.path.join(image_folder, name)
            for name in os.listdir(image_folder)
            if name.lower().endswith(('.jpg', '.png'))
        ),
        key=_image_sort_key,
    )

    image_features_list = []
    image_ids = []  # file name of each embedded image, aligned with image_features_list
    for image_path in tqdm(image_paths, desc=f"Processing {image_folder}"):
        try:
            image_features_list.append(_embed_image(image_path))
            image_ids.append(os.path.basename(image_path))
        except Exception as e:
            # Best-effort: one unreadable image should not abort the whole folder.
            print(f"Error processing {image_path}: {e}")

    if not image_features_list:
        print(f"No valid images found in {image_folder}")
        return

    image_features = torch.cat(image_features_list, dim=0).numpy()
    # NOTE(review): this pickles a NumPy array via torch.save. Loading it with
    # torch.load(..., weights_only=True) may reject non-tensor objects, so
    # confirm how downstream code loads this file.
    torch.save(image_features, output_file)
    print(f"Features saved to {output_file}")

    results = _cluster_results(image_features, image_ids, num_clusters)
    results.to_csv(csv_file, index=False)
    # Report the path actually written (the old message named a different file).
    print(f"Cluster results saved to {csv_file}")

# Embed, save and cluster the images of the single ./data/transportation folder (6 clusters).
extract_features(
    image_folder='./data/transportation',
    output_file='anchor_image.pt',
    num_clusters=6,
)

Processing ./data/transportation: 100%|██████████| 61/61 [00:00<00:00, 65.44it/s]


Features saved to anchor_image.pt
Cluster results saved to cluster_results_with_centers.csv


In [None]:
# Runs an external script, /content/emb.py, whose source is not part of this notebook.
# NOTE(review): the saved output below is a failed run that ends in a urllib3/http.client
# traceback (likely a network request). Re-run or investigate before relying on its results.
!python /content/emb.py

2025-01-30 04:56:10.129882: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738212970.152757    2971 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738212970.160004    2971 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/urllib3/connectionpool.py", line 534, in _make_request
    response = conn.getresponse()
               ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/urllib3/connection.py", line 516, in getresponse
    httplib_response = super().getresponse()
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/http/client.py",