<a href="https://colab.research.google.com/github/MoqiSheng/MoqiSheng.github.io/blob/main/encode_texts_0313.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from sentence_transformers import SentenceTransformer
import os

def download_model_to_local():
    # 设置本地保存路径
    model_save_path = './sentence-transformer-model'

    # 确保目录存在
    os.makedirs(model_save_path, exist_ok=True)

    print("开始下载SentenceTransformer模型...")
    # 加载并下载模型
    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # 保存到本地
    model.save(model_save_path)
    print(f"模型已成功下载并保存到: {model_save_path}")

if __name__ == "__main__":
    download_model_to_local()

开始下载SentenceTransformer模型...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.4k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

模型已成功下载并保存到: ./sentence-transformer-model


In [2]:
!zip -r sentence-transformer-model.zip sentence-transformer-model

  adding: sentence-transformer-model/ (stored 0%)
  adding: sentence-transformer-model/1_Pooling/ (stored 0%)
  adding: sentence-transformer-model/1_Pooling/config.json (deflated 57%)
  adding: sentence-transformer-model/modules.json (deflated 62%)
  adding: sentence-transformer-model/2_Normalize/ (stored 0%)
  adding: sentence-transformer-model/model.safetensors (deflated 8%)
  adding: sentence-transformer-model/config.json (deflated 47%)
  adding: sentence-transformer-model/tokenizer_config.json (deflated 75%)
  adding: sentence-transformer-model/README.md (deflated 64%)
  adding: sentence-transformer-model/config_sentence_transformers.json (deflated 34%)
  adding: sentence-transformer-model/sentence_bert_config.json (deflated 4%)
  adding: sentence-transformer-model/vocab.txt (deflated 53%)
  adding: sentence-transformer-model/tokenizer.json (deflated 71%)
  adding: sentence-transformer-model/special_tokens_map.json (deflated 85%)


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!cp -r sentence-transformer-model.zip /content/drive/MyDrive/

In [5]:
ls sentence-transformer-model

[0m[01;34m1_Pooling[0m/    config_sentence_transformers.json  README.md                  tokenizer_config.json
[01;34m2_Normalize[0m/  model.safetensors                  sentence_bert_config.json  tokenizer.json
config.json   modules.json                       special_tokens_map.json    vocab.txt


In [6]:
import os
import torch
import pandas as pd
import json
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
import re

def setup_directories():
    """创建必要的文件夹"""
    os.makedirs("emb/anchor_embeddings", exist_ok=True)
    os.makedirs("emb/predict_embeddings", exist_ok=True)

def extract_id(id_str):
    """从ID字符串中提取数字ID"""
    match = re.search(r'(\d+)', str(id_str))
    if match:
        return int(match.group(1))
    return float('inf')  # 如果没有找到数字，返回无穷大使其排序在最后

def encode_anchor_texts(model, csv_path, output_dir):
    """编码锚点文本并保存结果"""
    print(f"读取锚点文本数据: {csv_path}")

    # 读取CSV文件
    try:
        df = pd.read_csv(csv_path)

        if 'description' not in df.columns:
            print("错误: CSV文件中没有'description'列")
            return

        if 'ID' not in df.columns:
            print("错误: CSV文件中没有'ID'列")
            return

        # 按ID排序
        df['numeric_id'] = df['ID'].apply(extract_id)
        df = df.sort_values('numeric_id').reset_index(drop=True)

        print(f"处理 {len(df)} 条锚点文本...")

        # 提取描述文本
        descriptions = df['description'].tolist()

        # 编码文本
        print("正在编码锚点文本...")
        text_embeddings = []

        for description in tqdm(descriptions):
            # 处理空值或NaN
            if pd.isna(description) or not description:
                # 使用空字符串替代，保持索引对齐
                embedding = model.encode("")
            else:
                embedding = model.encode(description)

            # 归一化嵌入
            normalized_embedding = normalize([embedding])[0]

            # 转换为PyTorch tensor并添加到列表
            text_embeddings.append(torch.tensor(normalized_embedding, dtype=torch.float32))

        # 堆叠所有嵌入成单个tensor
        text_embeddings_tensor = torch.stack(text_embeddings)

        # 保存嵌入和ID
        embedding_file = os.path.join(output_dir, "anchor_text_emb.pt")
        id_file = os.path.join(output_dir, "anchor_text_id.pt")

        torch.save(text_embeddings_tensor, embedding_file)
        torch.save(df['ID'].tolist(), id_file)

        print(f"锚点文本嵌入已保存到 {embedding_file}")
        print(f"锚点文本ID已保存到 {id_file}")
        print(f"嵌入形状: {text_embeddings_tensor.shape}")

    except Exception as e:
        print(f"处理锚点文本时出错: {e}")

def encode_predict_texts(model, json_path, output_dir):
    """编码待预测文本并保存结果"""
    print(f"读取城市分类数据: {json_path}")

    # 定义模板
    urbanclip_templates = [
        "{} area featuring {}.",
        "{} area featuring {} with cars.",
        "{} area featuring {} with parking lot.",
        "{} area featuring {} on the road.",
        "{} area featuring {} with many trees.",
        "{} area featuring {} in city."
    ]

    try:
        # 读取JSON文件
        with open(json_path, 'r') as f:
            urban_taxonomy = json.load(f)

        print("正在为所有城市对象类型生成嵌入...")

        # 存储数据
        embeddings = []
        categories = []
        uots = []

        # 遍历所有类别和UOT
        for category, category_uots in urban_taxonomy.items():
            for uot in tqdm(category_uots, desc=f"Processing {category}"):
                # 为每个UOT生成多个句子
                sentences = [template.format(category, uot) for template in urbanclip_templates]

                # 编码句子
                sentence_embeddings = model.encode(sentences)

                # 归一化每个句子的嵌入
                normalized_embeddings = normalize(sentence_embeddings, axis=1)

                # 计算归一化嵌入的平均值
                avg_embedding = np.mean(normalized_embeddings, axis=0)

                # 归一化最终的平均嵌入
                final_embedding = normalize([avg_embedding])[0]

                # 转换为PyTorch tensor
                final_embedding_tensor = torch.tensor(final_embedding, dtype=torch.float32)

                # 保存嵌入和元数据
                embeddings.append(final_embedding_tensor)
                categories.append(category)
                uots.append(uot)

        # 将所有嵌入堆叠成一个tensor
        embeddings_tensor = torch.stack(embeddings)

        # 保存嵌入和元数据
        embedding_file = os.path.join(output_dir, "predict_text_emb.pt")
        metadata_file = os.path.join(output_dir, "predict_text_metadata.pt")

        torch.save(embeddings_tensor, embedding_file)
        torch.save({'categories': categories, 'uots': uots}, metadata_file)

        print(f"待预测文本嵌入已保存到 {embedding_file}")
        print(f"元数据已保存到 {metadata_file}")
        print(f"嵌入形状: {embeddings_tensor.shape}")
        print(f"总计处理了 {len(embeddings)} 个城市对象类型")

    except Exception as e:
        print(f"处理待预测文本时出错: {e}")

def main():
    # 设置文件夹
    setup_directories()

    # 获取当前目录的上一级目录
    parent_dir = os.path.dirname(os.getcwd())
    # images_dir = os.path.join(parent_dir, "images")

    # 加载Sentence Transformer模型
    print("加载Sentence Transformer模型...")
    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # 锚点文本路径
    anchor_csv_path = os.path.join("emb", "anchor_images.csv")

    # 待预测文本路径
    urban_taxonomy_path = os.path.join("emb", "urban_taxonomy.json")

    # 编码锚点文本
    encode_anchor_texts(model, anchor_csv_path, "emb/anchor_embeddings")

    # 编码待预测文本
    encode_predict_texts(model, urban_taxonomy_path, "emb/predict_embeddings")

    print("所有文本编码完成!")

if __name__ == "__main__":
    main()

加载Sentence Transformer模型...
读取锚点文本数据: emb/anchor_images.csv
处理 165 条锚点文本...
正在编码锚点文本...


100%|██████████| 165/165 [00:27<00:00,  6.09it/s]


锚点文本嵌入已保存到 emb/anchor_embeddings/anchor_text_emb.pt
锚点文本ID已保存到 emb/anchor_embeddings/anchor_text_id.pt
嵌入形状: torch.Size([165, 768])
读取城市分类数据: emb/urban_taxonomy.json
正在为所有城市对象类型生成嵌入...


Processing residential: 100%|██████████| 49/49 [00:15<00:00,  3.10it/s]
Processing commercial: 100%|██████████| 78/78 [00:25<00:00,  3.04it/s]
Processing hotel: 100%|██████████| 17/17 [00:05<00:00,  2.96it/s]
Processing industrial: 100%|██████████| 22/22 [00:06<00:00,  3.33it/s]
Processing education: 100%|██████████| 32/32 [00:10<00:00,  3.06it/s]
Processing health care: 100%|██████████| 22/22 [00:06<00:00,  3.29it/s]
Processing civic, governmental and cultural: 100%|██████████| 46/46 [00:18<00:00,  2.49it/s]
Processing sports and recreation: 100%|██████████| 24/24 [00:08<00:00,  2.95it/s]
Processing outdoors and natural: 100%|██████████| 38/38 [00:13<00:00,  2.88it/s]
Processing transportation: 100%|██████████| 26/26 [00:09<00:00,  2.77it/s]

待预测文本嵌入已保存到 emb/predict_embeddings/predict_text_emb.pt
元数据已保存到 emb/predict_embeddings/predict_text_metadata.pt
嵌入形状: torch.Size([354, 768])
总计处理了 354 个城市对象类型
所有文本编码完成!



