In [None]:
import os
# Hugging Face `transformers` resolves its cache directory when it is imported,
# so the variable is set here, before the import cell.
# NOTE(review): newer transformers releases prefer HF_HOME / HF_HUB_CACHE over
# TRANSFORMERS_CACHE — confirm against the installed version.
os.environ.update(TRANSFORMERS_CACHE='/root/data/transformers/model_zoo')

In [None]:
import numpy as np
import faiss
import torch
from pathlib import Path
from tqdm import tqdm
from transformers import AlbertModel, BertTokenizer

In [None]:
class AlbertDocEncoder(object):
    """Encode a document into one dense vector with a pretrained ALBERT model.

    The document vector is the token-wise mean of ``hidden_states[1] +
    hidden_states[-1]`` (in Hugging Face models index 0 is the embedding
    output, so this pairs the first encoder layer with the last one). The sum
    is not divided by 2. A uniform scale factor does not change L2
    nearest-neighbour ranking, and whitening cancels it.

    Args:
        pretrained: Hugging Face model id or local path for both the
            tokenizer and the model.
    """

    def __init__(self, pretrained='voidful/albert_chinese_tiny'):
        # NOTE(review): `mirror` is a legacy `from_pretrained` option that newer
        # transformers releases deprecate/ignore — confirm for the installed version.
        self.tokenizer = BertTokenizer.from_pretrained(pretrained, mirror='tuna')
        self.model = AlbertModel.from_pretrained(pretrained, mirror='tuna')

    def encode_doc(self, doc, max_length=None):
        """Return the pooled embedding of ``doc``.

        Args:
            doc: input text.
            max_length: maximum number of tokens fed to the model; longer
                documents are truncated instead of overflowing the model's
                position embeddings. Defaults to the model config's
                ``max_position_embeddings`` when available, otherwise 512.

        Returns:
            A ``torch.Tensor``; ``.squeeze()`` drops the batch axis, so for a
            single document this is 1-D.

        Raises:
            ValueError: if ``doc`` produces no tokens (e.g. an empty line).
                Averaging over an empty sequence would give a NaN vector.
        """
        if max_length is None:
            config = getattr(self.model, 'config', None)
            max_length = getattr(config, 'max_position_embeddings', 512)
        # No [CLS]/[SEP]: the vector averages only the document's own tokens.
        input_ids = self.tokenizer.encode(
            doc,
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        if input_ids.shape[-1] == 0:
            raise ValueError('document produced no tokens: %r' % (doc,))
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = self.model(input_ids, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        pooled = torch.mean(hidden_states[1] + hidden_states[-1], dim=1)
        return pooled.squeeze()


class FaissIndexer(object):
    """Thin wrapper around a Faiss index created with ``faiss.index_factory``.

    Args:
        embed_dim: dimensionality of the vectors that will be indexed.
        param: index-factory description string, e.g. ``'HNSW64'``.
        measure: Faiss metric constant, e.g. ``faiss.METRIC_INNER_PRODUCT``
            or ``faiss.METRIC_L2``.
    """

    def __init__(self, embed_dim, param='HNSW64', measure=faiss.METRIC_INNER_PRODUCT):
        self.index = faiss.index_factory(embed_dim, param, measure)

    @staticmethod
    def _as_float32_matrix(vecs):
        # Faiss' Python bindings expect a C-contiguous float32 2-D array;
        # a single 1-D vector is treated as a batch of one.
        return np.ascontiguousarray(np.atleast_2d(vecs), dtype=np.float32)

    def build_index(self, vecs):
        """Add ``vecs`` (one vector per row) to the index."""
        self.index.add(self._as_float32_matrix(vecs))

    def most_similar(self, query_vecs, top_k=10):
        """Return the ids of the nearest neighbours of the first query vector.

        Only the first row of ``query_vecs`` is searched, because only its
        result is returned. Faiss pads with ``-1`` when it finds fewer than
        ``top_k`` neighbours (e.g. ``top_k`` larger than the index). Those
        placeholders are dropped, so callers never index data with ``-1``.

        Returns:
            1-D integer array of at most ``top_k`` ids, best match first.
        """
        queries = self._as_float32_matrix(query_vecs)[:1]
        _, ind = self.index.search(queries, k=top_k)
        ids = ind[0]
        return ids[ids >= 0]


class WhiteTransform(object):
    """BERT-whitening for sentence/document vectors.

    Fitting:
      - ``mu`` is the column mean of ``vecs``.
      - ``cov`` is their population covariance (divided by n).
      - The whitening matrix is ``W = U diag(1 / sqrt(s + eps))``, built from
        the SVD ``cov = U diag(s) V^T``.
      - ``W`` keeps only its first ``n_components`` columns. NumPy returns
        singular values in descending order, so these are the
        highest-variance directions.

    Args:
        vecs: array of shape (n_samples, embed_dim) used to fit the transform.
        n_components: number of output dimensions (``None`` keeps all).
        eps: small constant that prevents zero singular values (a
            rank-deficient covariance) and zero-norm projections from
            producing inf/NaN.

    Raises:
        ValueError: if ``vecs`` is not a non-empty 2-D array, or
            ``n_components`` is not positive.
    """

    eps = 1e-8  # class-level default; overridden per instance in __init__

    def __init__(self, vecs, n_components=128, eps=1e-8):
        vecs = np.asarray(vecs)
        if vecs.ndim != 2 or vecs.shape[0] == 0:
            raise ValueError('vecs must be a non-empty 2-D array, got shape %s' % (vecs.shape,))
        self.eps = eps
        self.embed_in = vecs.shape[1]
        self.mu = self.moving_mean(vecs)
        self.cov = self.moving_cov(vecs)  # reads self.mu, so it must come after moving_mean
        self.W = self.get_kernel(n_components)

    def moving_mean(self, vecs):
        """Return the column mean of ``vecs`` as a (1, embed_dim) array.

        The result dtype is at least float32. An empty input yields zeros.
        """
        vecs = np.asarray(vecs)
        dtype = np.result_type(vecs.dtype, np.float32)
        if vecs.shape[0] == 0:
            return np.zeros((1, vecs.shape[1]), dtype=dtype)
        return vecs.mean(axis=0, keepdims=True, dtype=dtype)

    def moving_cov(self, vecs):
        """Return the population covariance of ``vecs`` around ``self.mu``.

        Shape is (embed_dim, embed_dim), normalised by n (not n - 1).
        Requires ``self.mu`` to be set. An empty input yields zeros.
        """
        vecs = np.asarray(vecs)
        n_rows, dim = vecs.shape
        dtype = np.result_type(vecs.dtype, self.mu.dtype, np.float32)
        if n_rows == 0:
            return np.zeros((dim, dim), dtype=dtype)
        centered = (vecs - self.mu).astype(dtype, copy=False)
        return centered.T.dot(centered) / n_rows

    def get_kernel(self, n_components):
        """Return the whitening matrix of shape (embed_dim, n_components)."""
        if n_components is not None and n_components <= 0:
            raise ValueError('n_components must be positive, got %r' % (n_components,))
        u, s, _ = np.linalg.svd(self.cov)
        # Column-wise scaling, equivalent to u @ diag(1 / sqrt(s + eps));
        # eps keeps zero singular values from becoming inf.
        W = u * (1.0 / np.sqrt(s + self.eps))
        return W[:, :n_components]

    def transform_vecs(self, vecs):
        """Centre with ``mu``, project with ``W`` and L2-normalise each row.

        A row whose projected norm is below ``eps`` is divided by ``eps``
        instead of by its norm, so an all-zero projection stays zero rather
        than becoming NaN.
        """
        projected = (np.asarray(vecs) - self.mu).dot(self.W)
        norms = np.linalg.norm(projected, axis=1, keepdims=True)
        return projected / np.maximum(norms, self.eps)

In [None]:
# Checkpoint used for encoding. It is passed explicitly, so editing it here
# actually changes the model that gets loaded.
pretrained = 'voidful/albert_chinese_tiny'
albert_enc = AlbertDocEncoder(pretrained=pretrained)

In [None]:
data_path = Path('../../data/tnews/')
# One document per line. read_text opens and closes the file itself, so no handle is leaked.
lines = (data_path / 'pretrain_data_5k.txt').read_text(encoding='utf8').splitlines()
print(len(lines))
lines[:3]

In [None]:
n_sample = len(lines)
# Take the vector width from the loaded model instead of hard-coding it
# (312 for albert_chinese_tiny), so a different checkpoint cannot silently mismatch.
embed_dim = albert_enc.model.config.hidden_size
embed_vecs = np.zeros((n_sample, embed_dim), dtype=np.float32)
for i, line in enumerate(tqdm(lines, desc='encoding')):
    embed_vecs[i] = albert_enc.encode_doc(line).numpy()

In [None]:
# Fit BERT-whitening on the document vectors, keeping 128 output dimensions.
whiter = WhiteTransform(vecs=embed_vecs, n_components=128)

In [None]:
# Centre, project and L2-normalise every document vector with the fitted transform.
embed_vecs_white = whiter.transform_vecs(vecs=embed_vecs)

In [None]:
# Baseline index: HNSW graph over the raw (un-whitened) vectors, ranked by L2 distance.
index_origin = FaissIndexer(embed_vecs.shape[1], 'HNSW64', faiss.METRIC_L2)
index_origin.build_index(vecs=embed_vecs)

In [None]:
# Whitened vectors are L2-normalised, so inner product ranks by cosine similarity.
index_white = FaissIndexer(embed_vecs_white.shape[1], 'HNSW32', faiss.METRIC_INNER_PRODUCT)
index_white.build_index(vecs=embed_vecs_white)

In [None]:
# Baseline: nearest neighbours of the query in the raw-vector (L2) index.
q_vec = albert_enc.encode_doc('比亚迪4月销量劲增20%').numpy().reshape((1, -1))
for doc_id in index_origin.most_similar(q_vec, top_k=10):
    print(lines[doc_id])

In [None]:
# Query the whitened (inner-product) index. The query vector goes through the
# same whitening transform as the indexed documents.
# NOTE(review): this query ends with a full-width '？' that the baseline query
# does not — confirm the difference is intentional when comparing the two result lists.
q_vec = albert_enc.encode_doc('比亚迪4月销量劲增20%？').numpy().reshape((1, -1))
q_vec_white = whiter.transform_vecs(q_vec)
for doc_id in index_white.most_similar(q_vec_white, top_k=10):
    print(lines[doc_id])