<a href="https://colab.research.google.com/github/MoqiSheng/MoqiSheng.github.io/blob/main/0415_%E4%B8%8D%E5%90%8C%E6%96%87%E6%9C%AC%E7%BC%96%E7%A0%81%E5%99%A8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import torch
import pandas as pd
import json
import numpy as np
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel, BertTokenizer, BertModel
from sklearn.preprocessing import normalize
import re

def setup_directories():
    """Create the output folders for anchor and candidate embeddings.

    Paths are relative to the current working directory. Folders that
    already exist are left as they are.
    """
    output_folders = (
        "vit_sentencetransformer/anchor_embeddings",
        "vit_sentencetransformer/candidate_embeddings",
    )
    for folder in output_folders:
        os.makedirs(folder, exist_ok=True)

def extract_id(id_str):
    """Return the first run of digits found in ``id_str`` as an int.

    The argument is converted with ``str()`` before searching, so non-string
    IDs are accepted. When the text contains no digit, ``float('inf')`` is
    returned so that such entries sort after every numeric ID.
    """
    match = re.search(r'(\d+)', str(id_str))
    return int(match.group(1)) if match else float('inf')

def load_models():
    """Load every text encoder used in the comparison.

    Returns:
        dict: maps the short model name (``"clip_vit_l"``, ``"clip_vit_b"``,
        ``"bert"``) to a ``(model, preprocessor)`` tuple. The preprocessor is
        a ``CLIPProcessor`` for the CLIP entries and a ``BertTokenizer`` for
        BERT.
    """
    print("加载编码模型...")

    loaded = {}

    # Both CLIP variants follow the same recipe: model first, then processor.
    clip_checkpoints = (
        ("clip_vit_l", "openai/clip-vit-large-patch14"),
        ("clip_vit_b", "openai/clip-vit-base-patch32"),
    )
    for short_name, checkpoint in clip_checkpoints:
        clip_model = CLIPModel.from_pretrained(checkpoint)
        clip_processor = CLIPProcessor.from_pretrained(checkpoint)
        loaded[short_name] = (clip_model, clip_processor)

    # BERT: the tokenizer is fetched before the encoder weights.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encoder = BertModel.from_pretrained("bert-base-uncased")
    loaded["bert"] = (encoder, tokenizer)

    return loaded

def encode_text_with_model(model_info, texts, model_name):
    """Encode a batch of texts with one of the supported models.

    Args:
        model_info: ``(model, preprocessor)`` tuple -- a CLIP model with its
            processor, or a BERT model with its tokenizer.
        texts: list of strings, encoded together as one padded batch.
        model_name: ``"clip_vit_l"``, ``"clip_vit_b"`` or ``"bert"``; selects
            the encoding path.

    Returns:
        numpy.ndarray: one embedding row per input text. No normalization is
        applied here.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # Fail loudly on an unknown name. Falling through used to return None,
    # which only surfaced later as an unrelated TypeError in the callers.
    if model_name not in ("clip_vit_l", "clip_vit_b", "bert"):
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected 'clip_vit_l', 'clip_vit_b' or 'bert'"
        )

    if model_name == "bert":
        model, tokenizer = model_info
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        # Use the hidden state of the first token ([CLS]) as the text embedding.
        embeddings = outputs.last_hidden_state[:, 0, :]
        return embeddings.cpu().numpy()

    # Remaining names are the CLIP variants.
    model, processor = model_info
    inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        embeddings = model.get_text_features(**inputs)
    return embeddings.cpu().numpy()

def encode_anchor_texts(models, csv_path, output_dir):
    """Embed the anchor descriptions with every model and save the results.

    Rows are ordered by the numeric part of their ``ID`` before encoding, so
    row ``i`` of each saved embedding tensor corresponds to entry ``i`` of
    the saved ID list.

    Args:
        models: dict mapping a model name to the ``model_info`` tuple that
            ``encode_text_with_model`` expects.
        csv_path: path of a CSV file with ``ID`` and ``description`` columns.
        output_dir: directory that receives ``anchor_text_emb_<model>.pt``
            and ``anchor_text_id_<model>.pt`` for each model.

    Returns:
        None. A missing column or any exception is reported with ``print``
        instead of being raised.
    """
    print(f"读取锚点文本数据: {csv_path}")

    try:
        df = pd.read_csv(csv_path)

        # Bail out early when a required column is absent.
        available = df.columns
        if 'description' not in available:
            print("错误: CSV文件中没有'description'列")
            return
        if 'ID' not in available:
            print("错误: CSV文件中没有'ID'列")
            return

        # Sort by the numeric part of the ID, then renumber the rows.
        df['numeric_id'] = df['ID'].apply(extract_id)
        df = df.sort_values('numeric_id')
        df = df.reset_index(drop=True)

        print(f"处理 {len(df)} 条锚点文本...")

        raw_texts = df['description'].tolist()
        ordered_ids = df['ID'].tolist()

        for model_name, model_info in models.items():
            print(f"正在使用 {model_name} 编码锚点文本...")

            unit_vectors = []
            for raw_text in tqdm(raw_texts):
                # Missing (NaN) or empty descriptions are encoded as "".
                if pd.isna(raw_text) or not raw_text:
                    text = ""
                else:
                    text = raw_text
                vector = encode_text_with_model(model_info, [text], model_name)[0]
                # normalize() expects 2-D input, hence the wrap and unwrap.
                unit_vector = normalize([vector])[0]
                unit_vectors.append(torch.tensor(unit_vector, dtype=torch.float32))

            text_embeddings_tensor = torch.stack(unit_vectors)

            embedding_file = os.path.join(output_dir, f"anchor_text_emb_{model_name}.pt")
            id_file = os.path.join(output_dir, f"anchor_text_id_{model_name}.pt")

            torch.save(text_embeddings_tensor, embedding_file)
            torch.save(ordered_ids, id_file)

            print(f"{model_name} 锚点文本嵌入已保存到 {embedding_file}")
            print(f"{model_name} 锚点文本ID已保存到 {id_file}")
            print(f"嵌入形状: {text_embeddings_tensor.shape}")

    except Exception as e:
        print(f"处理锚点文本时出错: {e}")

def encode_predict_texts(models, json_path, output_dir):
    """Embed every urban object type (UOT) of the taxonomy with each model.

    For each ``(category, uot)`` pair the prompt templates are filled in and
    encoded as one batch. The sentence embeddings are L2-normalized,
    averaged, and the average is normalized again to give one vector per UOT.

    Args:
        models: dict mapping a model name to the ``model_info`` tuple that
            ``encode_text_with_model`` expects.
        json_path: path of a UTF-8 JSON file whose top-level object maps each
            category name to a collection of UOT names.
        output_dir: directory that receives, per model,
            ``candidate_text_emb_<model>.pt`` (stacked float32 embeddings)
            and ``candidate_text_metadata_<model>.pt`` (a dict with parallel
            ``'categories'`` and ``'uots'`` lists).

    Returns:
        None. Any exception is reported with ``print`` instead of being
        raised.
    """
    print(f"读取城市分类数据: {json_path}")

    # Prompt templates; the two slots are filled with (category, uot).
    urbanclip_templates = [
        "{} area featuring {}.",
        "{} area featuring {} with cars.",
        "{} area featuring {} with parking lot.",
        "{} area featuring {} on the road.",
        "{} area featuring {} with many trees.",
        "{} area featuring {} in city."
    ]

    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            urban_taxonomy = json.load(f)

        for model_name, model_info in models.items():
            print(f"正在使用 {model_name} 为所有城市对象类型生成嵌入...")

            embeddings = []
            categories = []
            uots = []

            for category, category_uots in urban_taxonomy.items():
                for uot in tqdm(category_uots, desc=f"Processing {category}"):
                    sentences = [template.format(category, uot) for template in urbanclip_templates]
                    sentence_embeddings = encode_text_with_model(model_info, sentences, model_name)
                    # Normalize each sentence first so every template weighs
                    # equally in the mean, then re-normalize the mean itself.
                    normalized_embeddings = normalize(sentence_embeddings, axis=1)
                    avg_embedding = np.mean(normalized_embeddings, axis=0)
                    final_embedding = normalize([avg_embedding])[0]
                    final_embedding_tensor = torch.tensor(final_embedding, dtype=torch.float32)

                    # The three lists stay index-aligned.
                    embeddings.append(final_embedding_tensor)
                    categories.append(category)
                    uots.append(uot)

            embeddings_tensor = torch.stack(embeddings)

            embedding_file = os.path.join(output_dir, f"candidate_text_emb_{model_name}.pt")
            metadata_file = os.path.join(output_dir, f"candidate_text_metadata_{model_name}.pt")

            torch.save(embeddings_tensor, embedding_file)
            torch.save({'categories': categories, 'uots': uots}, metadata_file)

            print(f"{model_name} 待预测文本嵌入已保存到 {embedding_file}")
            print(f"{model_name} 元数据已保存到 {metadata_file}")
            print(f"嵌入形状: {embeddings_tensor.shape}")
            print(f"总计处理了 {len(embeddings)} 个城市对象类型")

    except Exception as e:
        print(f"处理待预测文本时出错: {e}")

def main():
    """Run the whole pipeline: create folders, load models, encode both text sets.

    The inputs ``anchor_descriptions.csv`` and ``urban_taxonomy.json`` are
    read from the current working directory; outputs go under
    ``vit_sentencetransformer/``.
    """
    setup_directories()

    models = load_models()

    # Plain relative paths; wrapping one component in os.path.join did nothing.
    anchor_csv_path = "anchor_descriptions.csv"
    urban_taxonomy_path = "urban_taxonomy.json"

    encode_anchor_texts(models, anchor_csv_path, "vit_sentencetransformer/anchor_embeddings")
    encode_predict_texts(models, urban_taxonomy_path, "vit_sentencetransformer/candidate_embeddings")

    print("所有文本编码完成!")

if __name__ == "__main__":
    main()

加载编码模型...
读取锚点文本数据: anchor_descriptions.csv
处理 203 条锚点文本...
正在使用 clip_vit_l 编码锚点文本...


100%|██████████| 203/203 [00:09<00:00, 20.45it/s]


clip_vit_l 锚点文本嵌入已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_emb_clip_vit_l.pt
clip_vit_l 锚点文本ID已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_id_clip_vit_l.pt
嵌入形状: torch.Size([203, 768])
正在使用 clip_vit_b 编码锚点文本...


100%|██████████| 203/203 [00:05<00:00, 35.68it/s]


clip_vit_b 锚点文本嵌入已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_emb_clip_vit_b.pt
clip_vit_b 锚点文本ID已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_id_clip_vit_b.pt
嵌入形状: torch.Size([203, 512])
正在使用 bert 编码锚点文本...


100%|██████████| 203/203 [00:10<00:00, 19.61it/s]


bert 锚点文本嵌入已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_emb_bert.pt
bert 锚点文本ID已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_id_bert.pt
嵌入形状: torch.Size([203, 768])
读取城市分类数据: urban_taxonomy.json
正在使用 clip_vit_l 为所有城市对象类型生成嵌入...


Processing residential: 100%|██████████| 49/49 [00:03<00:00, 14.77it/s]
Processing commercial: 100%|██████████| 78/78 [00:05<00:00, 15.22it/s]
Processing hotel: 100%|██████████| 17/17 [00:01<00:00, 16.04it/s]
Processing industrial: 100%|██████████| 22/22 [00:01<00:00, 15.72it/s]
Processing education: 100%|██████████| 32/32 [00:02<00:00, 15.83it/s]
Processing health care: 100%|██████████| 22/22 [00:01<00:00, 13.95it/s]
Processing civic, governmental and cultural: 100%|██████████| 46/46 [00:03<00:00, 13.63it/s]
Processing sports and recreation: 100%|██████████| 24/24 [00:01<00:00, 15.80it/s]
Processing outdoors and natural: 100%|██████████| 38/38 [00:02<00:00, 15.39it/s]
Processing transportation: 100%|██████████| 26/26 [00:01<00:00, 15.82it/s]


clip_vit_l 待预测文本嵌入已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_emb_clip_vit_l.pt
clip_vit_l 元数据已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_metadata_clip_vit_l.pt
嵌入形状: torch.Size([354, 768])
总计处理了 354 个城市对象类型
正在使用 clip_vit_b 为所有城市对象类型生成嵌入...


Processing residential: 100%|██████████| 49/49 [00:01<00:00, 25.20it/s]
Processing commercial: 100%|██████████| 78/78 [00:03<00:00, 23.71it/s]
Processing hotel: 100%|██████████| 17/17 [00:00<00:00, 25.79it/s]
Processing industrial: 100%|██████████| 22/22 [00:00<00:00, 27.92it/s]
Processing education: 100%|██████████| 32/32 [00:01<00:00, 26.77it/s]
Processing health care: 100%|██████████| 22/22 [00:00<00:00, 26.11it/s]
Processing civic, governmental and cultural: 100%|██████████| 46/46 [00:01<00:00, 25.56it/s]
Processing sports and recreation: 100%|██████████| 24/24 [00:00<00:00, 24.55it/s]
Processing outdoors and natural: 100%|██████████| 38/38 [00:01<00:00, 24.43it/s]
Processing transportation: 100%|██████████| 26/26 [00:01<00:00, 22.97it/s]


clip_vit_b 待预测文本嵌入已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_emb_clip_vit_b.pt
clip_vit_b 元数据已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_metadata_clip_vit_b.pt
嵌入形状: torch.Size([354, 512])
总计处理了 354 个城市对象类型
正在使用 bert 为所有城市对象类型生成嵌入...


Processing residential: 100%|██████████| 49/49 [00:03<00:00, 15.29it/s]
Processing commercial: 100%|██████████| 78/78 [00:04<00:00, 17.15it/s]
Processing hotel: 100%|██████████| 17/17 [00:00<00:00, 17.49it/s]
Processing industrial: 100%|██████████| 22/22 [00:01<00:00, 17.70it/s]
Processing education: 100%|██████████| 32/32 [00:01<00:00, 16.04it/s]
Processing health care: 100%|██████████| 22/22 [00:01<00:00, 14.39it/s]
Processing civic, governmental and cultural: 100%|██████████| 46/46 [00:02<00:00, 15.83it/s]
Processing sports and recreation: 100%|██████████| 24/24 [00:01<00:00, 16.65it/s]
Processing outdoors and natural: 100%|██████████| 38/38 [00:02<00:00, 17.18it/s]
Processing transportation: 100%|██████████| 26/26 [00:01<00:00, 17.44it/s]

bert 待预测文本嵌入已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_emb_bert.pt
bert 元数据已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_metadata_bert.pt
嵌入形状: torch.Size([354, 768])
总计处理了 354 个城市对象类型
所有文本编码完成!





In [4]:

import os
import torch
import pandas as pd
import json
import numpy as np
from tqdm import tqdm
import tensorflow_hub as hub
from sklearn.preprocessing import normalize
import re

def setup_directories():
    """Make sure the two embedding output folders exist.

    The folders are created relative to the current working directory;
    calling this again when they already exist is harmless.
    """
    anchor_dir = "vit_sentencetransformer/anchor_embeddings"
    candidate_dir = "vit_sentencetransformer/candidate_embeddings"
    for target in (anchor_dir, candidate_dir):
        os.makedirs(target, exist_ok=True)

def extract_id(id_str):
    """Pull the numeric part out of an ID value.

    The value is stringified and searched for its first run of digits, which
    is returned as an int. IDs without any digit yield ``float('inf')`` so
    they end up last when sorting by the result.
    """
    found = re.search(r'(\d+)', str(id_str))
    if found is None:
        # No digits at all: push this entry to the end of the ordering.
        return float('inf')
    return int(found.group(1))

def load_model():
    """Load the Universal Sentence Encoder (version 4) from TensorFlow Hub.

    Returns:
        The object returned by ``hub.load``; it is called directly on a list
        of texts by ``encode_text_with_use``.
    """
    print("加载USE模型...")
    return hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")

def encode_text_with_use(model, texts):
    """Encode ``texts`` with the USE model.

    Args:
        model: callable that takes the list of texts and returns an object
            exposing ``.numpy()``.
        texts: list of strings to encode in one call.

    Returns:
        Whatever ``.numpy()`` yields for the model output.
    """
    return model(texts).numpy()

def encode_anchor_texts(model, csv_path, output_dir):
    """Embed the anchor descriptions with USE and store them as ``.pt`` files.

    Rows are ordered by the numeric part of their ``ID`` before encoding, so
    row ``i`` of the saved embedding tensor corresponds to entry ``i`` of the
    saved ID list.

    Args:
        model: USE model as accepted by ``encode_text_with_use``.
        csv_path: path of a CSV file with ``ID`` and ``description`` columns.
        output_dir: directory that receives ``anchor_text_emb_use.pt`` and
            ``anchor_text_id_use.pt``.

    Returns:
        None. A missing column or any exception is reported with ``print``
        instead of being raised.
    """
    print(f"读取锚点文本数据: {csv_path}")

    try:
        df = pd.read_csv(csv_path)

        # Both columns are required; stop at the first one that is missing.
        has_description = 'description' in df.columns
        if not has_description:
            print("错误: CSV文件中没有'description'列")
            return

        has_id = 'ID' in df.columns
        if not has_id:
            print("错误: CSV文件中没有'ID'列")
            return

        # Order by the numeric part of the ID and renumber the rows.
        df['numeric_id'] = df['ID'].apply(extract_id)
        df = df.sort_values('numeric_id')
        df = df.reset_index(drop=True)

        print(f"处理 {len(df)} 条锚点文本...")

        anchor_texts = df['description'].tolist()

        print("正在使用USE编码锚点文本...")
        row_tensors = []

        for anchor_text in tqdm(anchor_texts):
            # NaN or otherwise empty descriptions are encoded as "".
            is_blank = pd.isna(anchor_text) or not anchor_text
            cleaned = "" if is_blank else anchor_text
            raw_vector = encode_text_with_use(model, [cleaned])[0]
            # normalize() expects 2-D input, hence the wrap and unwrap.
            unit_vector = normalize([raw_vector])[0]
            # Stored as a float32 PyTorch tensor.
            row_tensors.append(torch.tensor(unit_vector, dtype=torch.float32))

        text_embeddings_tensor = torch.stack(row_tensors)

        embedding_file = os.path.join(output_dir, "anchor_text_emb_use.pt")
        id_file = os.path.join(output_dir, "anchor_text_id_use.pt")

        torch.save(text_embeddings_tensor, embedding_file)
        torch.save(df['ID'].tolist(), id_file)

        print(f"USE锚点文本嵌入已保存到 {embedding_file}")
        print(f"USE锚点文本ID已保存到 {id_file}")
        print(f"嵌入形状: {text_embeddings_tensor.shape}")

    except Exception as e:
        print(f"处理锚点文本时出错: {e}")

def encode_predict_texts(model, json_path, output_dir):
    """Embed every urban object type (UOT) of the taxonomy with USE.

    For each ``(category, uot)`` pair the prompt templates are filled in and
    encoded as one batch. The sentence embeddings are L2-normalized,
    averaged, and the average is normalized again to give one vector per UOT.

    Args:
        model: USE model as accepted by ``encode_text_with_use``.
        json_path: path of a UTF-8 JSON file whose top-level object maps each
            category name to a collection of UOT names.
        output_dir: directory that receives ``candidate_text_emb_use.pt``
            (stacked float32 embeddings) and
            ``candidate_text_metadata_use.pt`` (a dict with parallel
            ``'categories'`` and ``'uots'`` lists).

    Returns:
        None. Any exception is reported with ``print`` instead of being
        raised.
    """
    print(f"读取城市分类数据: {json_path}")

    # Prompt templates; the two slots are filled with (category, uot).
    urbanclip_templates = [
        "{} area featuring {}.",
        "{} area featuring {} with cars.",
        "{} area featuring {} with parking lot.",
        "{} area featuring {} on the road.",
        "{} area featuring {} with many trees.",
        "{} area featuring {} in city."
    ]

    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            urban_taxonomy = json.load(f)

        print("正在使用USE为所有城市对象类型生成嵌入...")

        embeddings = []
        categories = []
        uots = []

        for category, category_uots in urban_taxonomy.items():
            for uot in tqdm(category_uots, desc=f"Processing {category}"):
                sentences = [template.format(category, uot) for template in urbanclip_templates]
                sentence_embeddings = encode_text_with_use(model, sentences)
                # Normalize each sentence first so every template weighs
                # equally in the mean, then re-normalize the mean itself.
                normalized_embeddings = normalize(sentence_embeddings, axis=1)
                avg_embedding = np.mean(normalized_embeddings, axis=0)
                final_embedding = normalize([avg_embedding])[0]
                # Convert to a float32 PyTorch tensor.
                final_embedding_tensor = torch.tensor(final_embedding, dtype=torch.float32)

                # The three lists stay index-aligned.
                embeddings.append(final_embedding_tensor)
                categories.append(category)
                uots.append(uot)

        embeddings_tensor = torch.stack(embeddings)

        embedding_file = os.path.join(output_dir, "candidate_text_emb_use.pt")
        metadata_file = os.path.join(output_dir, "candidate_text_metadata_use.pt")

        torch.save(embeddings_tensor, embedding_file)
        torch.save({'categories': categories, 'uots': uots}, metadata_file)

        print(f"USE待预测文本嵌入已保存到 {embedding_file}")
        print(f"USE元数据已保存到 {metadata_file}")
        print(f"嵌入形状: {embeddings_tensor.shape}")
        print(f"总计处理了 {len(embeddings)} 个城市对象类型")

    except Exception as e:
        print(f"处理待预测文本时出错: {e}")

def convert_npy_to_pt(npy_file, pt_file):
    """Re-save a ``.npy`` array as a float32 PyTorch tensor file.

    Args:
        npy_file: path of the NumPy file to read.
        pt_file: path the tensor is written to with ``torch.save``.

    Returns:
        None. Success and failure are both reported with ``print``; no
        exception is raised.
    """
    try:
        as_tensor = torch.tensor(np.load(npy_file), dtype=torch.float32)
        torch.save(as_tensor, pt_file)
        print(f"已将 {npy_file} 转换为 {pt_file}")
    except Exception as e:
        print(f"转换 {npy_file} 时出错: {e}")

def main():
    """Run the USE pipeline: create folders, load the model, encode both text sets.

    The inputs ``anchor_descriptions.csv`` and ``urban_taxonomy.json`` are
    read from the current working directory; outputs go under
    ``vit_sentencetransformer/``.
    """
    setup_directories()

    model = load_model()

    # Plain relative paths; wrapping one component in os.path.join did nothing.
    anchor_csv_path = "anchor_descriptions.csv"
    urban_taxonomy_path = "urban_taxonomy.json"

    encode_anchor_texts(model, anchor_csv_path, "vit_sentencetransformer/anchor_embeddings")
    encode_predict_texts(model, urban_taxonomy_path, "vit_sentencetransformer/candidate_embeddings")

    # Optional: convert existing .npy files to .pt, for example:
    # npy_file = "path/to/your/file.npy"
    # pt_file = "path/to/your/file.pt"
    # convert_npy_to_pt(npy_file, pt_file)

    print("所有文本编码完成!")

if __name__ == "__main__":
    main()

加载USE模型...
读取锚点文本数据: anchor_descriptions.csv
处理 203 条锚点文本...
正在使用USE编码锚点文本...


100%|██████████| 203/203 [00:01<00:00, 144.50it/s]


USE锚点文本嵌入已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_emb_use.pt
USE锚点文本ID已保存到 vit_sentencetransformer/anchor_embeddings/anchor_text_id_use.pt
嵌入形状: torch.Size([203, 512])
读取城市分类数据: urban_taxonomy.json
正在使用USE为所有城市对象类型生成嵌入...


Processing residential: 100%|██████████| 49/49 [00:00<00:00, 185.49it/s]
Processing commercial: 100%|██████████| 78/78 [00:00<00:00, 186.50it/s]
Processing hotel: 100%|██████████| 17/17 [00:00<00:00, 191.86it/s]
Processing industrial: 100%|██████████| 22/22 [00:00<00:00, 187.25it/s]
Processing education: 100%|██████████| 32/32 [00:00<00:00, 194.35it/s]
Processing health care: 100%|██████████| 22/22 [00:00<00:00, 190.18it/s]
Processing civic, governmental and cultural: 100%|██████████| 46/46 [00:00<00:00, 190.12it/s]
Processing sports and recreation: 100%|██████████| 24/24 [00:00<00:00, 189.56it/s]
Processing outdoors and natural: 100%|██████████| 38/38 [00:00<00:00, 192.68it/s]
Processing transportation: 100%|██████████| 26/26 [00:00<00:00, 194.28it/s]

USE待预测文本嵌入已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_emb_use.pt
USE元数据已保存到 vit_sentencetransformer/candidate_embeddings/candidate_text_metadata_use.pt
嵌入形状: torch.Size([354, 512])
总计处理了 354 个城市对象类型
所有文本编码完成!



