In [1]:
# coding=utf-8
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from typing import List
from threading import Lock


class APISemanticEmbeddingManagerV2:
    """Turn a (title, content) text pair into a single L2-normalised embedding vector.

    The pair is tokenised as one two-segment input, run through a Hugging Face
    ``BertModel`` in eval mode, and element ``[1]`` of the model output (the pooled
    output for ``BertModel``) is normalised and returned as a plain Python list.
    """

    def __init__(self, huggingface_model_path, max_length: int = 512):
        """Load tokenizer and model weights and move the model to GPU when available.

        Args:
            huggingface_model_path: Path or hub id passed to ``from_pretrained``.
            max_length: Token length every input is padded/truncated to. Defaults to
                512, the value that used to be hard-coded. It must not exceed the
                model's maximum sequence length.
        """
        # Guards tokenizer calls; presumably the tokenizer is not assumed thread-safe.
        self.lock = Lock()
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = BertTokenizer.from_pretrained(huggingface_model_path)
        self.model = BertModel.from_pretrained(huggingface_model_path)
        self.model.to(self.device)
        # Inference only: disables dropout so repeated calls give the same vector.
        self.model.eval()

    def _preprocess(self, title_str: str, content_str: str):
        """Tokenise the pair and return the encoding tensors placed on ``self.device``.

        Every input is padded or truncated to exactly ``self.max_length`` tokens.
        """
        with self.lock:
            input_encoding = self.tokenizer(title_str, content_str,
                                            padding="max_length", truncation=True,
                                            max_length=self.max_length, return_tensors="pt")
        return {name: tensor.to(self.device) for name, tensor in input_encoding.items()}

    def calc_semantic_embedding(self, title_str: str, content_str: str) -> List:
        """Return the L2-normalised embedding of the pair as a list of floats.

        Args:
            title_str: First text segment.
            content_str: Second text segment; callers in this notebook pass ``""``
                for title-only queries.

        Returns:
            A flat list of Python floats with unit L2 norm.
        """
        input_encoding = self._preprocess(title_str, content_str)
        with torch.no_grad():
            model_output = self.model(**input_encoding)
            # [1] selects the second model output, [0] the only row of the batch.
            pooled = model_output[1][0]
            unit_vector = F.normalize(pooled, p=2, dim=-1)
        # No .detach() needed: tensors created under no_grad carry no autograd graph.
        return unit_vector.cpu().numpy().tolist()


import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

if __name__ == '__main__':
    # Smoke test: embed a handful of news (title, content) pairs and print the
    # cosine similarity between all of them and the short/title-only ones.
    manager = APISemanticEmbeddingManagerV2("../hfl/chinese-roberta-wwm-ext")
    texts = [
        "罚510亿？阿里巴巴被反垄断调查后，暴跌7000亿，还面临天价罚单",
        "重磅！阿里巴巴涉嫌垄断被立案调查！人民日报：加强反垄断监管是为了更好发展",
        "人脸识别成售楼处“标配” 开发商为何紧盯购房人的脸不放?",
        "售楼处“人脸识别”的那些秘密",
        "售楼处 人脸识别",
        "人脸识别 售楼处",
        "人脸识别",
        "售楼处",
        "《007：无暇赴死》口碑解禁 英国预售仅次于《复联4》"
    ]
    # contents[k] is the body paired with texts[k]; "" marks a title-only query.
    contents = [
        "近来蚂蚁接连被约谈的消息引起极大热议，也给阿里巴巴造成巨大影响，在阿里巴巴被反垄断调查后，其股价暴跌7000亿。不仅是股价发生大震动，其还要面临天价罚单，一旦阿里巴巴“二选一”被认定为违反《反垄断法》，其将被责令停止违法行为，且并处罚上一年销售额1%以上或10%以下的罚款。阿里巴巴去年销售收入约为5097亿，也就是说最多罚款将达到510亿，此前反垄断罚款最高纪录是高通公司，在2015年时被罚款",
        "据新华视点消息，中国人民银行、中国银保监会、中国证监会、国家外汇管理局将于近日约谈蚂蚁集团，督促指导蚂蚁集团按照市场化、法治化原则，落实金融监管、公平竞争和保护消费者合法权益等要求，规范金融业务经营与发展。",
        "“为保护个人信息，戴着头盔去看房。”日前，一男子戴着头盔看楼盘的短视频在网上热传，售楼处安装人脸识别系统抓取个人信息的问题受到各界关注。据媒体报道，目前，人脸识别系统已广泛应用于售楼处，甚至成为“标配”。楼市是否真的进入“看脸”时代？开发商为何紧盯购房人的脸不放？“露脸”会对买房产生什么影响？近日，记者走访北京各大楼盘发现，在各售楼处纷纷加装视网膜识别摄像头背后，楼市营销格局正在发生变化，房企对渠道的依赖不断加深，利用人脸识别等技术手段开展的“渠道暗战”愈演愈烈，渠道佣金返点最高已达两位数。",
        "嗯，前段时间，有个新闻上了热搜，大致意思是“被人脸识别拍到就得多花30万？有人被迫戴头盔看房”。此事爆出之后，事件持续发酵，引起了关于售楼处人脸识别的各种讨论，甚至于延伸到小区入口和学校入口处人脸识别进入的讨论，乃至上至央媒下至自媒都各种上阵讨论。",
        "",
        "",
        "",
        "",
        "时光网讯  《007：无暇赴死》凌晨在英国举行全球首映式，媒体口碑也解禁，烂番茄93%（27个评价）暂居丹尼尔·克雷格系列第二；MTC综合评分77（16个评价）。影片将与9月30日在英国率先开画，预售表现仅次于2019年的《复仇者联盟4：终局之战》。"
    ]

    # zip() would silently drop unpaired items, so fail loudly on a mismatch.
    if len(texts) != len(contents):
        raise ValueError("texts and contents must have the same number of items")

    # One embedding per (title, content) pair, in list order.
    pair_embeddings = [
        manager.calc_semantic_embedding(title, content)
        for title, content in zip(texts, contents)
    ]

    # Rows: every pair. Columns: pairs 4 onward (the title-only queries plus the
    # unrelated film article), so the matrix shows how the short queries match up.
    results_1 = np.array(pair_embeddings)
    results_2 = np.array(pair_embeddings[4:])
    print(cosine_similarity(results_1, results_2))


[[0.82664777 0.83107226 0.82946419 0.85537636 0.86511823]
 [0.84192467 0.84451377 0.83700384 0.87114459 0.84520511]
 [0.9609179  0.95892917 0.93561128 0.93760429 0.83096129]
 [0.95211838 0.95113822 0.93140277 0.93078833 0.81978036]
 [1.         0.99615201 0.974429   0.96759863 0.81685313]
 [0.99615201 1.         0.98205994 0.96885033 0.8183877 ]
 [0.974429   0.98205994 1.         0.93145926 0.82015003]
 [0.96759863 0.96885033 0.93145926 1.         0.84373261]
 [0.81685313 0.8183877  0.82015003 0.84373261 1.        ]]


In [None]:
------- MongoDB setup (shell and mongo-shell commands; not Python, do not run in the notebook kernel)
docker run -d -p 27017:27017 -e MONGO_INITDB_ROOT_USERNAME=root -e MONGO_INITDB_ROOT_PASSWORD=ai2021 --name mongodb_test_1 mongo:5.0.3
docker exec -it mongodb_test_1 bash
mongo -u root -p ai2021
use admin
db.createUser({user: "bzl", pwd: "ai2021", roles: [{role: "readWrite", db: "test_jimmy"}]})

In [4]:
from urllib.parse import quote_plus
from pymongo import MongoClient
import json
import random
import re
from tqdm import tqdm

def get_mongo_db():
    """Connect to MongoDB and return the database handle used by this notebook.

    Connection settings come from environment variables, falling back to the
    values that used to be hard-coded so existing runs keep working:

    - ``MONGO_USER`` (default ``"bzl"``)
    - ``MONGO_PASSWORD`` (default: the legacy password)
    - ``MONGO_HOST`` as ``host:port`` (default ``"10.8.204.89:27017"``)
    - ``MONGO_AUTH_DB``, used as ``authSource`` (default ``"admin"``)
    - ``MONGO_DB_NAME``, the database returned (default ``"test_jimmy"``)

    Returns:
        The result of indexing a ``MongoClient`` with the database name.
    """
    import os  # local import keeps this cell runnable on its own

    mongo_user = os.environ.get("MONGO_USER", "bzl")
    # NOTE(security): the fallback password is kept only for backward
    # compatibility; set MONGO_PASSWORD and remove this default before sharing.
    mongo_password = os.environ.get("MONGO_PASSWORD", "ai2021")
    mongo_host = os.environ.get("MONGO_HOST", "10.8.204.89:27017")
    mongo_auth_db = os.environ.get("MONGO_AUTH_DB", "admin")
    mongo_db_name = os.environ.get("MONGO_DB_NAME", "test_jimmy")

    # User and password are percent-encoded so reserved characters survive in the URI.
    mongo_uri = "mongodb://{}:{}@{}/?authSource={}".format(
        quote_plus(mongo_user), quote_plus(mongo_password), mongo_host, mongo_auth_db)
    faiss_mongo_client = MongoClient(host=mongo_uri)
    return faiss_mongo_client[mongo_db_name]


# Ingest: embed every news record from the JSON dump and store it in MongoDB.
faiss_mongo_db = get_mongo_db()
news_info_collection = faiss_mongo_db["news_info"]
# news_info_collection.drop()

with open("./data_collect_news_info_20211008.json", "r", encoding="utf-8") as f:
    json_data = json.load(f)
# NOTE: the shuffle is unseeded, so the "i" assigned to each record below
# differs between runs.
random.shuffle(json_data)

manager = APISemanticEmbeddingManagerV2("../hfl/chinese-roberta-wwm-ext")
json_data_len = len(json_data)

# Compiled once; matches runs of characters in the range U+4E00..U+9FA5.
chinese_run_pattern = re.compile("[\u4e00-\u9fa5]+")
# Only the first few failures are printed so a systematic error cannot flood the output.
max_logged_failures = 20
failed_count = 0

pbar = tqdm(json_data, position=0, leave=True)

for i, x in enumerate(pbar):
    try:
        title = x["title"]
        paragraphs = [str(paragraph).strip() for paragraph in x["content"]]
        # Keep paragraphs where more than half of the characters match the
        # pattern; the 1e-5 avoids dividing by zero on empty paragraphs.
        paragraphs = [
            paragraph for paragraph in paragraphs
            if sum(len(run) for run in chinese_run_pattern.findall(paragraph))
            / (len(paragraph) + 1e-5) > 0.5
        ]
        content = "\n".join(paragraphs)
        embedding = manager.calc_semantic_embedding(title, content)
        x_data = {
            "i": i,  # position in the shuffled list; later used as the FAISS id
            "title": title,
            "content": content,
            "embedding": embedding
        }
        news_info_collection.insert_one(x_data)
    except Exception as exc:
        # Best effort: a bad record is skipped, but no longer silently.
        # KeyboardInterrupt/SystemExit are not caught, so the loop can be stopped.
        failed_count += 1
        if failed_count <= max_logged_failures:
            tqdm.write("skipped record {}: {!r}".format(i, exc))

if failed_count:
    print("skipped {} of {} records because of errors".format(failed_count, json_data_len))


100%|██████████| 178598/178598 [2:28:45<00:00, 20.01it/s]  


In [5]:
#build index
from urllib.parse import quote_plus
from pymongo import MongoClient
import json
import random
import re
from tqdm import tqdm
import numpy as np
from math import sqrt, ceil
import faiss


def get_mongo_db():
    """Connect to MongoDB and return the database handle used by this notebook.

    Connection settings come from environment variables, falling back to the
    values that used to be hard-coded so existing runs keep working:

    - ``MONGO_USER`` (default ``"bzl"``)
    - ``MONGO_PASSWORD`` (default: the legacy password)
    - ``MONGO_HOST`` as ``host:port`` (default ``"10.8.204.89:27017"``)
    - ``MONGO_AUTH_DB``, used as ``authSource`` (default ``"admin"``)
    - ``MONGO_DB_NAME``, the database returned (default ``"test_jimmy"``)

    Returns:
        The result of indexing a ``MongoClient`` with the database name.
    """
    import os  # local import keeps this cell runnable on its own

    mongo_user = os.environ.get("MONGO_USER", "bzl")
    # NOTE(security): the fallback password is kept only for backward
    # compatibility; set MONGO_PASSWORD and remove this default before sharing.
    mongo_password = os.environ.get("MONGO_PASSWORD", "ai2021")
    mongo_host = os.environ.get("MONGO_HOST", "10.8.204.89:27017")
    mongo_auth_db = os.environ.get("MONGO_AUTH_DB", "admin")
    mongo_db_name = os.environ.get("MONGO_DB_NAME", "test_jimmy")

    # User and password are percent-encoded so reserved characters survive in the URI.
    mongo_uri = "mongodb://{}:{}@{}/?authSource={}".format(
        quote_plus(mongo_user), quote_plus(mongo_password), mongo_host, mongo_auth_db)
    faiss_mongo_client = MongoClient(host=mongo_uri)
    return faiss_mongo_client[mongo_db_name]


# Build an IVF-Flat inner-product index over every stored embedding and save it to disk.
faiss_mongo_db = get_mongo_db()
news_info_collection = faiss_mongo_db["news_info"]

# Fetch only the fields needed for indexing; titles and contents stay in MongoDB.
id_list = []
embedding_list = []
for x in news_info_collection.find({}, {"i": 1, "embedding": 1}):
    id_list.append(x["i"])
    embedding_list.append(x["embedding"])

# Without this guard an empty collection fails below with an opaque
# "tuple index out of range" on embeddings.shape[1].
if not embedding_list:
    raise ValueError("news_info collection has no embeddings; run the ingestion cell first")

# The index is fed float32 vectors and int64 ids.
embeddings = np.array(embedding_list, dtype=np.float32)
ids = np.array(id_list, dtype=np.int64)

d = embeddings.shape[1]  # embedding dimensionality
# Number of IVF clusters: the square root of the collection size, rounded up.
nlist = ceil(sqrt(embeddings.shape[0]))
quantizer = faiss.IndexFlatIP(d)

index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
# Number of clusters probed per query.
index.nprobe = 10
print("索引训练开始")
index.train(embeddings)
print("索引训练完成")
# Store the Mongo "i" values as FAISS ids so search hits can be looked up in MongoDB.
index.add_with_ids(embeddings, ids)
index_file_path = "./faiss.index"
faiss.write_index(index, index_file_path)

索引训练开始
索引训练完成


In [6]:
#test search
import faiss
import numpy as np
from urllib.parse import quote_plus
from pymongo import MongoClient

def get_mongo_db():
    """Connect to MongoDB and return the database handle used by this notebook.

    Connection settings come from environment variables, falling back to the
    values that used to be hard-coded so existing runs keep working:

    - ``MONGO_USER`` (default ``"bzl"``)
    - ``MONGO_PASSWORD`` (default: the legacy password)
    - ``MONGO_HOST`` as ``host:port`` (default ``"10.8.204.89:27017"``)
    - ``MONGO_AUTH_DB``, used as ``authSource`` (default ``"admin"``)
    - ``MONGO_DB_NAME``, the database returned (default ``"test_jimmy"``)

    Returns:
        The result of indexing a ``MongoClient`` with the database name.
    """
    import os  # local import keeps this cell runnable on its own

    mongo_user = os.environ.get("MONGO_USER", "bzl")
    # NOTE(security): the fallback password is kept only for backward
    # compatibility; set MONGO_PASSWORD and remove this default before sharing.
    mongo_password = os.environ.get("MONGO_PASSWORD", "ai2021")
    mongo_host = os.environ.get("MONGO_HOST", "10.8.204.89:27017")
    mongo_auth_db = os.environ.get("MONGO_AUTH_DB", "admin")
    mongo_db_name = os.environ.get("MONGO_DB_NAME", "test_jimmy")

    # User and password are percent-encoded so reserved characters survive in the URI.
    mongo_uri = "mongodb://{}:{}@{}/?authSource={}".format(
        quote_plus(mongo_user), quote_plus(mongo_password), mongo_host, mongo_auth_db)
    faiss_mongo_client = MongoClient(host=mongo_uri)
    return faiss_mongo_client[mongo_db_name]


# Search test: embed a query, find its nearest neighbours in the saved index,
# and print the matching MongoDB documents.
faiss_mongo_db = get_mongo_db()
news_info_collection = faiss_mongo_db["news_info"]

index_file_path = "./faiss.index"
index = faiss.read_index(index_file_path)

manager = APISemanticEmbeddingManagerV2("../hfl/chinese-roberta-wwm-ext")
search_term = "马斯克 SpaceX 火箭"
# Title-only query: the content segment is left empty.
search_embedding = manager.calc_semantic_embedding(search_term,"")
# A single query row, as float32.
xq = np.array([search_embedding], dtype=np.float32)
# d holds the similarity scores, i the stored ids.
d, i = index.search(xq, k=20)

print(d.tolist())
print(i.tolist())

for faiss_idx in i.tolist()[0]:
    # FAISS pads the result with -1 when it finds fewer than k neighbours.
    if faiss_idx < 0:
        continue
    x = news_info_collection.find_one({"i": faiss_idx})
    print("--------------")
    print(faiss_idx)
    # find_one returns None when the index and the collection are out of sync;
    # report the gap instead of crashing on x["title"].
    if x is None:
        print("no document with this id in news_info")
        continue
    print("title:\n", x["title"])
    print("content:\n", x["content"][:500])

[[0.9605381488800049, 0.9567288756370544, 0.9550487995147705, 0.9548767805099487, 0.9539167881011963, 0.9527953267097473, 0.9522258043289185, 0.952174186706543, 0.951046347618103, 0.9510403275489807, 0.9508522152900696, 0.9505863785743713, 0.9504168033599854, 0.9502747654914856, 0.9501602053642273, 0.9497288465499878, 0.949600875377655, 0.9494177103042603, 0.9489873647689819, 0.9489282369613647]]
[[4137, 69617, 76205, 131335, 172138, 170186, 74129, 30677, 130818, 91943, 48711, 85508, 55374, 64618, 29268, 112548, 26851, 1170, 57768, 165036]]
--------------
4137
title:
 马斯克：“星际飞船”首个轨道堆栈将在几周内准备好
content:
 埃隆·马斯克日前表示，“星际飞船”火箭的首个轨道堆栈将在未来几周内准备就绪，只需等待监管部门的批准。今年5月，SpaceX公司成功地将其Starship原型机SN15着陆，这是一种可重复使用的重型运载火箭，可以将宇航员和大型货物有效载荷送上月球和火星。（新浪财经）
--------------
69617
title:
 我国正在建的空间站，未来在太空都能做什么？
content:
 我国空间站建造已于今年拉开帷幕，定于2022年在轨建造完毕，实现中国载人航天工程三步走发展战略第三步的任务目标。
早在我国载人航天工程立项初期就流行着一句话：“造船为建站，建站为应用”。意思是说，造宇宙飞船是为了给空间站提供天地往返运输服务，而造空间站则是为了进行载人空间应用。所以，开展长期有人照料的大规模空间应用是建造空间站的最终目的。那么，位于太空的空间站究竟能做些什么？今天我们约请全