In [1]:
from sentence_transformers import SentenceTransformer

# Sentence encoder used to embed every news text in this notebook.
# NOTE(review): sentence-transformers lists the `*-nli-stsb-mean-tokens` checkpoints as
# deprecated (lower-quality embeddings) — confirm before reusing this id for new work.
MODEL_NAME = 'distilbert-base-nli-stsb-mean-tokens'
model = SentenceTransformer(MODEL_NAME)

In [3]:
from pathlib import Path
from utils import get_project_root, read_json

# Project-wide default settings; the data-loading cell reads data_config.data_path from it.
config_file = Path(get_project_root()).joinpath("config", "mind_rs_default.json")
default_config = read_json(config_file)

In [4]:
def load_news_from_file(news_file, news_title):
    """Add ``news_id -> "title abstract"`` entries from a tab-separated news file.

    Every non-blank line must have exactly 8 tab-separated columns:
    news id, category, subcategory, title, abstract, url, title entities,
    abstract entities. Only the id, title and abstract are used.

    Args:
        news_file: path of the UTF-8 ``.tsv`` file to read.
        news_title: mapping filled in place. Ids already present are left
            unchanged, so the first file that mentions an id wins.

    Returns:
        The same ``news_title`` object (returned for convenient reassignment).

    Raises:
        ValueError: if a non-blank line does not have exactly 8 columns.
    """
    with open(news_file, "r", encoding="utf-8") as rd:
        for line_no, text in enumerate(rd, start=1):
            text = text.rstrip("\r\n")
            if not text:
                # Tolerate blank lines (e.g. a trailing empty line) instead of failing the unpack.
                continue
            fields = text.split("\t")
            if len(fields) != 8:
                raise ValueError(
                    f"{news_file}:{line_no}: expected 8 tab-separated columns, got {len(fields)}"
                )
            nid, _vert, _subvert, title, abstract = fields[:5]
            if nid in news_title:
                continue
            news_title[nid] = title + " " + abstract
    return news_title

In [6]:
from collections import OrderedDict

# Merge the news.tsv of each split under <data_path>/large into one ordered
# news_id -> "title abstract" mapping. load_news_from_file skips ids it has
# already seen, so the earliest split in `phases` wins on duplicate ids.
root_path = Path(default_config["data_config"]["data_path"])
news_title_large = OrderedDict()
phases = ["train", "valid", "test"]
for phase in phases:
    file = root_path / "large" / phase / "news.tsv"
    if file.exists():
        news_title_large = load_news_from_file(file, news_title_large)
    else:
        print(f"skipping missing split: {file}")
# Without this check a wrong data_path yields an empty corpus that is encoded and saved silently.
assert news_title_large, f"no news loaded from any split under {root_path / 'large'}"


In [10]:
# Parallel lists (dicts keep insertion order): row i of the embedding matrix
# computed below belongs to news_ids[i].
news_ids = list(news_title_large)
sentences = [news_title_large[nid] for nid in news_ids]

In [11]:
# Encode every news text. convert_to_numpy=True (the library default) is spelled out
# because a later cell calls `.tolist()` on the result.
sentence_embedding = model.encode(sentences, convert_to_numpy=True)


In [21]:
def save_embedding_dict(embedding_dict: dict, saved_path):
    """Write ``embedding_dict`` to ``saved_path`` as plain UTF-8 text.

    Each entry becomes one line: the key, a single space, then the vector's
    values converted with ``str`` and joined by single spaces. An existing
    file at ``saved_path`` is overwritten.
    """
    with open(saved_path, "w", encoding="utf-8") as out_file:
        out_file.writelines(
            str(nid) + " " + " ".join(map(str, vector)) + "\n"
            for nid, vector in embedding_dict.items()
        )

In [23]:
# Output directory for the sentence embeddings. parents=True so a checkout
# without dataset/utils does not fail with FileNotFoundError.
sentence_embed_path = Path(get_project_root()) / "dataset/utils/sentence_embed"
sentence_embed_path.mkdir(parents=True, exist_ok=True)
# zip() silently truncates on a length mismatch, so confirm one embedding row per news id first.
assert len(news_ids) == len(sentence_embedding), (
    f"{len(news_ids)} news ids but {len(sentence_embedding)} embedding rows"
)
sentence_embedding_dict = dict(zip(news_ids, sentence_embedding.tolist()))
embed_file = sentence_embed_path / "distilbert.txt"

In [22]:
# Text format: one line per news id, the id followed by its space-separated vector values.
save_embedding_dict(embedding_dict=sentence_embedding_dict, saved_path=embed_file)

In [24]:
from utils import load_embedding_from_path

# Read the text file back with the project's loader as a smoke test of the saved format.
# NOTE(review): the result is never checked here. What `load_embedding_from_path` returns
# is defined in utils — confirm its type and add an assertion against sentence_embedding_dict.
test_embed = load_embedding_from_path(embed_file)

In [25]:
import torch

# Second copy of the same id -> vector mapping, serialized with torch.save.
# NOTE(review): this rebinds `embed_file` from distilbert.txt to distilbert.vec. Re-running
# the text-save cell after this one would write text output to the .vec path.
embed_file = sentence_embed_path.joinpath("distilbert.vec")
torch.save(sentence_embedding_dict, embed_file)

In [26]:
# Reload the torch-serialized copy and check that the round trip kept exactly the same news ids.
torch_embed = torch.load(embed_file)
assert torch_embed.keys() == sentence_embedding_dict.keys(), "torch round-trip lost or added news ids"