
Commit

Fix Text2Vec bug (#168)
Tongjilibo committed Apr 3, 2024
1 parent aa4c81a commit d104c44
Showing 5 changed files with 34 additions and 16 deletions.
15 changes: 10 additions & 5 deletions bert4torch/pipelines/base.py
@@ -1,4 +1,4 @@
from typing import List, Union, Dict
from typing import List, Union, Dict, Literal
import numpy as np
import os
import torch
@@ -11,21 +11,26 @@
class PipeLineBase:
'''Base class
'''
def __init__(self, checkpoint_path:str, device:str=None, **kwargs) -> None:
def __init__(self, checkpoint_path:str, device:str=None, tokenizer_type:Literal['b4t', 'hf']='b4t', **kwargs) -> None:
self.checkpoint_path = checkpoint_path
self.config_path = kwargs.get('config_path') or checkpoint_path
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device

if (tokenizer_type == 'b4t') and os.path.exists(os.path.join(self.checkpoint_path, 'vocab.txt')):
self.tokenizer_type = 'b4t'
else:
self.tokenizer_type = 'hf'
self.tokenizer = self.build_tokenizer()
self.model = self.build_model(kwargs)
self.config = self.model.config

def build_tokenizer(self):
vocab_path = os.path.join(self.checkpoint_path, 'vocab.txt')
if os.path.exists(vocab_path):
tokenizer = Tokenizer(vocab_path, do_lower_case=True)
# TODO: By default, prefer the built-in Tokenizer; if there is no vocab file, fall back to AutoTokenizer. This may change later.
if self.tokenizer_type == 'b4t':
tokenizer = Tokenizer(os.path.join(self.checkpoint_path, 'vocab.txt'), do_lower_case=True)
else:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.checkpoint_path)
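To summarize the change above: the pipeline now records a `tokenizer_type` and only uses bert4torch's own `Tokenizer` when a `vocab.txt` sits next to the checkpoint, otherwise falling back to `AutoTokenizer`. A minimal standalone sketch of that selection (the `bert4torch.tokenizers` import path is an assumption; the rest mirrors the diff):

```python
import os
from typing import Literal

def build_tokenizer(checkpoint_path: str, tokenizer_type: Literal['b4t', 'hf'] = 'b4t'):
    '''Sketch of the tokenizer fallback introduced in this commit.'''
    vocab_path = os.path.join(checkpoint_path, 'vocab.txt')
    if tokenizer_type == 'b4t' and os.path.exists(vocab_path):
        # a vocab.txt is present: use bert4torch's own WordPiece tokenizer
        from bert4torch.tokenizers import Tokenizer  # assumed import path
        return Tokenizer(vocab_path, do_lower_case=True)
    # no vocab.txt (or 'hf' explicitly requested): fall back to HuggingFace
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(checkpoint_path)
```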
18 changes: 11 additions & 7 deletions bert4torch/pipelines/text2vec.py
@@ -46,18 +46,22 @@ def encode(
all_embeddings = []
for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
sentences_batch = sentences_sorted[start_index: start_index + batch_size]
batch = self.tokenizer(sentences_batch, maxlen=max_seq_length)
batch_input = [torch.tensor(sequence_padding(item), dtype=torch.long, device=self.device) for item in batch]
output = self.model(batch_input)
if self.tokenizer_type == 'b4t':
batch_input = self.tokenizer(sentences_batch, max_length=max_seq_length, return_tensors='pt').to(self.device)
output = self.model(batch_input)
else:
batch_input = self.tokenizer(sentences_batch, max_length=max_seq_length, return_tensors='pt').to(self.device)
output = self.model(**batch_input)

last_hidden_state = output.get('last_hidden_state')
pooler = output.get('pooled_output')
if isinstance(batch_input, list):
attention_mask = (batch_input[0] != self.tokenizer._token_pad_id).long()
attention_mask = (batch_input[0] != self.tokenizer.pad_token_id).long()
elif isinstance(batch_input, torch.Tensor):
attention_mask = (batch_input != self.tokenizer._token_pad_id).long()
else:
raise TypeError('Args `batch_input` only support list(tensor)/tensor format')
attention_mask = (batch_input != self.tokenizer.pad_token_id).long()
else: # dict-like tokenizer output
attention_mask = batch_input.get('attention_mask', (batch_input['input_ids'] != self.tokenizer.pad_token_id).long())

pool_strategy = pool_strategy or self.pool_strategy
embs = get_pool_emb(last_hidden_state, pooler, attention_mask, pool_strategy, custom_layer)
if normalize_embeddings:
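For reference, the attention-mask resolution in the new code path above, pulled out as a standalone helper (the logic mirrors the merged diff; the function name is illustrative):

```python
import torch

def resolve_attention_mask(batch_input, pad_token_id: int) -> torch.Tensor:
    '''Sketch of the attention-mask fallback used in encode() above.'''
    if isinstance(batch_input, list):
        # list of tensors: the first entry is input_ids
        return (batch_input[0] != pad_token_id).long()
    elif isinstance(batch_input, torch.Tensor):
        return (batch_input != pad_token_id).long()
    else:
        # dict-like tokenizer output: prefer the mask the tokenizer already built
        return batch_input.get('attention_mask',
                               (batch_input['input_ids'] != pad_token_id).long())
```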
10 changes: 7 additions & 3 deletions bert4torch/snippets/hub.py
@@ -6,7 +6,7 @@
import json
from pathlib import Path
from typing import Union, Optional, Dict
from torch4keras.snippets import log_error_once, log_info_once, log_error, is_safetensors_available
from torch4keras.snippets import log_error_once, log_info_once, log_error, is_safetensors_available, check_file_modified


if os.environ.get('SAFETENSORS_FIRST', False):
@@ -163,7 +163,9 @@ def snapshot_download(
)
if resolved_file.endswith('config.json'):
storage_folder = os.path.dirname(resolved_file)
log_info_once(f'Download {repo_id} to {storage_folder}')
if check_file_modified(resolved_file, duration=2):
# if the file was downloaded within the last 2s, don't log
log_info_once(f'Download {repo_id} to {storage_folder}')
if os.path.exists(resolved_file + ".lock"):
os.remove(resolved_file + ".lock")
return storage_folder
@@ -194,7 +196,9 @@ def snapshot_download(
user_agent = user_agent,
endpoint = HF_ENDPOINT
)
log_info_once(f'Download {repo_id} to {resolved_file}')
if check_file_modified(resolved_file, duration=2):
# if the file was downloaded within the last 2s, don't log
log_info_once(f'Download {repo_id} to {resolved_file}')
except EntryNotFoundError:
log_error(
f"{repo_id} does not appear to have a file named {filename}. Checkout "
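The logging change relies on `check_file_modified` from torch4keras, whose exact behavior is not shown in this diff. A hypothetical mtime-based stand-in, written only to illustrate the kind of check involved (not the torch4keras API):

```python
import os
import time

def file_modified_within(path: str, duration: float = 2.0) -> bool:
    '''Hypothetical helper: True if `path` was modified within the last `duration` seconds.

    Illustrates the kind of mtime check used to decide whether the
    "Download ... to ..." message is worth logging; not the real torch4keras helper.
    '''
    if not os.path.exists(path):
        return False
    return (time.time() - os.path.getmtime(path)) <= duration
```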
2 changes: 2 additions & 0 deletions docs/History.md
@@ -1,5 +1,7 @@
## Update History

- **20240403**: Fix the Text2Vec bug
- **20240331**: Fix chatglm3 bug, fix the multi-file bug in save_pretrained, add CausalLMLoss, adjust the deepspeed argument-passing logic
- **20240317**: Fix the config_path bug, allow the num_key_value_heads parameter
- **20240316**: Add the get_weight_decay_optim_groups function, allow is_causal in attention, fix the repetition_penalty bug, split baichuan out of llama, new features from [torch4keras-v0.2.1](https://github.com/Tongjilibo/torch4keras/releases/tag/v0.2.1)
- **20240216**: The fastapi serving deployment allows offloading to cpu when idle, `build_transformer_model` allows downloading from hf, add the `FillMask` pipeline, add `SequenceClassificationTrainer`
5 changes: 4 additions & 1 deletion examples/basic/embedding/basic_bge.py
@@ -1,5 +1,8 @@
import os
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
# root_model_path = 'E:\pretrain_ckpt\embedding\BAAI@bge-large-en-v1.5'
root_model_path = 'E:\pretrain_ckpt\embedding\BAAI@bge-large-zh-v1.5'
root_model_path = 'BAAI/bge-large-zh-v1.5'
# root_model_path = '/data/pretrain_ckpt/embedding/BAAI--bge-large-zh-v1.5'

sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]
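A plausible continuation of this example, assuming the `Text2Vec` pipeline from `bert4torch.pipelines` (the import path, constructor, and call pattern are assumptions made for illustration):

```python
# Hypothetical continuation of basic_bge.py; the Text2Vec import path is assumed.
from bert4torch.pipelines import Text2Vec

model = Text2Vec(root_model_path)
embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
# with normalized embeddings, the dot product equals the cosine similarity
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
```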
