In [46]:
import torch

from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
)

# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def tokenize_input(inputStr, tokenizer, seq_length=512, reserve=20):
    """Encode a string into a fixed-length, right-padded tensor of token ids.

    Args:
        inputStr: Text to encode.
        tokenizer: Object exposing ``encode(text)`` that returns a sequence of
            token ids. If it has a non-None ``pad_token_id`` attribute, that id
            is used for padding; otherwise 0 is used.
        seq_length: Total length of the returned sequence (must be >= 1).
        reserve: Number of trailing positions that are always left as padding.
            The encoded text is truncated to ``seq_length - reserve`` tokens
            (never fewer than 0).

    Returns:
        A tuple ``(tokens, [token_length])`` where ``tokens`` is a
        ``torch.long`` tensor of shape ``(1, seq_length)`` and ``token_length``
        is the number of non-padding tokens, i.e. the index of the first
        padding position.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1 or ``reserve`` is
            negative.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")
    if reserve < 0:
        raise ValueError(f"reserve must be >= 0, got {reserve}")

    # Prefer the tokenizer's own padding id; fall back to 0 when it has none.
    pad_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_id is None:
        pad_id = 0

    # Clamp at 0 so a small seq_length never becomes a negative slice bound,
    # which would silently cut tokens from the end of the text instead.
    max_tokens = max(seq_length - reserve, 0)
    tokens = list(tokenizer.encode(inputStr))[:max_tokens]
    token_length = len(tokens)

    # Callers index the model output at ``token_length``; with reserve >= 1
    # that index always lands on a padding position inside the sequence.
    tokens.extend([pad_id] * (seq_length - token_length))
    tokens = torch.tensor(tokens, dtype=torch.long)
    return tokens.reshape(1, seq_length), [token_length]


# Cache directory handed to transformers when loading the config and weights.
TEMP="temp/"
# NOTE(review): absolute, machine-specific path; consider making it configurable.
bert_model_path = '/home/gzy/Documents/lora/models/roberta-m-s_12L_cn'

bert_config = BertConfig.from_pretrained(bert_model_path, cache_dir=TEMP)

# Load the masked-LM model from the PyTorch checkpoint (from_tf=False).
WRAPPED_MODEL = BertForMaskedLM.from_pretrained(
            bert_model_path,
            from_tf=False,
            config=bert_config,
            cache_dir=TEMP,
        )
# NOTE(review): gradients are enabled on every parameter although the visible
# cells only run inference; presumably a leftover from fine-tuning. Confirm
# before changing.
for param in WRAPPED_MODEL.parameters():
    param.requires_grad = True
WRAPPED_MODEL.eval()

# The checkpoint declares 'RobertaTokenizer', so transformers warns when it is
# loaded through BertTokenizer (see the cell output).
# NOTE(review): confirm BertTokenizer matches this checkpoint's vocabulary.
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
# Keep the embedding matrix size in sync with the tokenizer's vocabulary size.
WRAPPED_MODEL.resize_token_embeddings(len(tokenizer))

WRAPPED_MODEL.to(device)
# Report the device actually in use (previously printed the literal 'device').
print(device)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RobertaTokenizer'. 
The class this function is called from is 'BertTokenizer'.


device


In [47]:
# Smoke-test input: a short Chinese sentence ("I want to go to sleep").
# NOTE(review): `input` shadows the Python builtin of the same name; it is read
# by the following cell, so renaming it requires changing both cells together.
inputStr = '我想睡觉了'
input, lengths = tokenize_input(inputStr, tokenizer, seq_length=512)

In [48]:
# Forward pass for the sample input. Use the notebook-wide `device` instead of
# a hard-coded .cuda() call, and skip autograd bookkeeping because only the
# logits are inspected.
with torch.no_grad():
    out = WRAPPED_MODEL(input.to(device)).logits

In [49]:
# Inspect the shape of the logits; the recorded output is (1, 512, 21128).
out.shape

torch.Size([1, 512, 21128])

In [50]:
import numpy as np
from tqdm import tqdm
import pickle


# Load the stock records. np.load with allow_pickle=True unpickles the file,
# so only use it on trusted, locally produced data.
data = np.load('./data/eastmoney_full_stocks_list_nlu_tencent.pkl', allow_pickle=True)

# One dict per stock: ticker_id, ticker_name and its vector.
# (Previously named `all`, which shadowed the Python builtin.)
stock_vectors = []
# (index, error) pairs for records that could not be processed.
failed = []

for idx in tqdm(range(len(data))):
    try:
        record = data[idx]
        content = record['baike_content'] + record['baike_summary']
        tokens, token_length = tokenize_input(content, tokenizer, seq_length=512)
        # Inference only: no_grad avoids building an autograd graph per record.
        with torch.no_grad():
            output = WRAPPED_MODEL(tokens.to(device)).logits
        # The vector is the logits row at index token_length[0], i.e. the first
        # padding position returned by tokenize_input.
        vector = output[0, token_length[0]].detach().cpu().numpy()
        stock_vectors.append({
            'ticker_id': record['ticker_id'],
            'ticker_name': record['ticker_name'],
            'vector': vector
        })
    except Exception as exc:
        # Stay best-effort, but remember what was skipped instead of hiding it.
        failed.append((idx, repr(exc)))

print('vectors num', len(stock_vectors))
if failed:
    print('skipped records', len(failed), 'first failures:', failed[:5])
with open('./data_ignore/vectors_bert.pkl','wb') as f:
    pickle.dump(stock_vectors, f)

100%|██████████| 510/510 [00:31<00:00, 15.98it/s]

vectors num 510





### Query by vector matrix (cosine-similarity search)

In [53]:
import numpy as np
from tqdm import tqdm

# Read back the pickled list of per-stock records.
data = np.load('./data_ignore/vectors_bert.pkl', allow_pickle=True)

# Parallel lists: position i in each list describes the same stock.
ticker_names, ticker_ids, vectors = [], [], []

for record in tqdm(data):
    # Skip records that carry no vector.
    if 'vector' not in record.keys():
        continue
    vectors.append(record['vector'])
    ticker_ids.append(record['ticker_id'])
    ticker_names.append(record['ticker_name'])

# Stack the per-stock vectors into one matrix, one row per stock.
vectors = np.stack(vectors)

def similarity_vector_matrix(arr, brr):
    """Cosine similarity between a vector and every row of a matrix.

    Args:
        arr: Query vector (NumPy array).
        brr: 2-D NumPy array whose rows are compared against ``arr``.

    Returns:
        NumPy array with one similarity per row of ``brr``. Where the query or
        a row has zero norm the similarity is defined as 0 instead of NaN/inf,
        so such rows cannot end up first after a descending argsort.
    """
    numerators = arr.dot(brr.T)
    denominators = np.sqrt(np.sum(arr * arr)) * np.sqrt(np.sum(brr * brr, axis=1))

    # Divide only where the denominator is non-zero; other entries stay at 0.
    result = np.zeros(
        np.broadcast(numerators, denominators).shape,
        dtype=np.result_type(numerators, denominators),
    )
    np.divide(numerators, denominators, out=result, where=denominators != 0)
    return result

def stock_search(query, topk=10):
    """Print the names of the ``topk`` stocks most similar to ``query``.

    The query is encoded the same way as the stored stock vectors: the logits
    row at the first padding position of the tokenized text. It is then
    ranked against ``vectors`` with ``similarity_vector_matrix``.

    NOTE(review): using masked-LM logits at a padding position as a text
    embedding looks unusual; confirm this is intended. Changing it requires
    regenerating the stored vectors as well.

    Args:
        query: Text to search for.
        topk: Number of names to print; negative values are treated as 0.

    Returns:
        None. The list of ticker names is printed.
    """
    tokens, token_length = tokenize_input(query, tokenizer, seq_length=512)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = WRAPPED_MODEL(tokens.to(device)).logits
    vector = output[0, token_length[0]].detach().cpu().numpy()

    res = similarity_vector_matrix(vector, vectors)
    # argsort is ascending, so reverse it to get the most similar first.
    idxs = np.argsort(res)[::-1]

    # Clamp so a negative topk cannot slice "all but the last k" by accident.
    topk_idxs = idxs[:max(topk, 0)]
    names = [ticker_names[idx] for idx in topk_idxs]
    print(names)

100%|██████████| 510/510 [00:00<00:00, 369510.29it/s]


In [54]:
# Example query: "autonomous driving, new-energy vehicles".
# stock_search prints the ticker names of the 10 most similar stocks (default topk).
query = '自动驾驶，新能源汽车'
stock_search(query)

['京沪高铁', '锦江酒店', '阿尔特', '安科瑞', '中材科技', '北新建材', '华发股份', '康希诺', '星宇股份', '润阳科技']


In [55]:
# Example query: "movies, TV series, culture and arts".
query = '电影，电视剧，文化艺术'
stock_search(query)

['中国石化', '鄂尔多斯', '行动教育', '锦江酒店', '福莱特', '宁波银行', '京沪高铁', '北方华创', '报喜鸟', '曲美家居']


In [56]:
# Example query: "beer, barbecue, gathering with friends".
query = '啤酒，烧烤，朋友聚会'
stock_search(query)

['中国石化', '锦江酒店', '行动教育', '鄂尔多斯', '福莱特', '京沪高铁', '报喜鸟', '一品红', '探路者', '斯迪克']
