flag_dres_model.py
from typing import cast, List, Dict, Union
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, is_torch_npu_available


class FlagDRESModel:
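    '''
    Embedding model wrapper with a DRES-style (dense retrieval) interface:
    it loads a HuggingFace tokenizer/model and exposes encode_queries,
    encode_corpus and encode, each returning a numpy array of embeddings.
    '''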
def __init__(
self,
model_name_or_path: str = None,
pooling_method: str = 'cls',
normalize_embeddings: bool = True,
query_instruction_for_retrieval: str = None,
batch_size: int = 256,
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.model = AutoModel.from_pretrained(model_name_or_path)
self.query_instruction_for_retrieval = query_instruction_for_retrieval
self.normalize_embeddings = normalize_embeddings
self.pooling_method = pooling_method
self.batch_size = batch_size
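        # Prefer a CUDA GPU, then an NPU, otherwise fall back to CPU.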
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif is_torch_npu_available():
self.device = torch.device("npu")
else:
self.device = torch.device("cpu")
self.model = self.model.to(self.device)
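        # With more than one CUDA device, replicate the model with DataParallel
        # and scale the batch size by the number of GPUs.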
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
self.model = torch.nn.DataParallel(self.model)
self.batch_size = self.batch_size * num_gpus

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        '''
        This function is used for the retrieval task.
        If an instruction for queries is given, it is prepended to each query text.
        '''
if self.query_instruction_for_retrieval is not None:
input_texts = ['{}{}'.format(self.query_instruction_for_retrieval, q) for q in queries]
else:
input_texts = queries
return self.encode(input_texts)

    def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
        '''
        This function is used for the retrieval task.
        It encodes the corpus, given either as raw strings or as dicts with
        'title' and 'text' fields.
        '''
if isinstance(corpus[0], dict):
input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
else:
input_texts = corpus
return self.encode(input_texts)

    @torch.no_grad()
def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
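        '''
        Tokenize and embed sentences in batches of self.batch_size; returns an
        array of shape (len(sentences), hidden_size), L2-normalized when
        self.normalize_embeddings is True.
        '''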
self.model.eval()
all_embeddings = []
for start_index in tqdm(range(0, len(sentences), self.batch_size), desc="Batches", disable=len(sentences)<256):
sentences_batch = sentences[start_index:start_index + self.batch_size]
inputs = self.tokenizer(
sentences_batch,
padding=True,
truncation=True,
return_tensors='pt',
max_length=512,
).to(self.device)
last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
embeddings = self.pooling(last_hidden_state, inputs['attention_mask'])
if self.normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
embeddings = cast(torch.Tensor, embeddings)
all_embeddings.append(embeddings.cpu().numpy())
return np.concatenate(all_embeddings, axis=0)

    def pooling(self,
                last_hidden_state: torch.Tensor,
                attention_mask: torch.Tensor = None):
        if self.pooling_method == 'cls':
            # Use the representation of the first ([CLS]) token.
            return last_hidden_state[:, 0]
        elif self.pooling_method == 'mean':
            # Mean over token states, ignoring padding via the attention mask.
            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
            d = attention_mask.sum(dim=1, keepdim=True).float()
            return s / d
        else:
            raise NotImplementedError(f"Unsupported pooling method: {self.pooling_method}")
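

# --- Minimal usage sketch (not part of the original file) ---
# The checkpoint name and instruction string below are illustrative
# placeholders, not prescribed by this module; swap in whichever model and
# instruction you actually evaluate with.
if __name__ == "__main__":
    model = FlagDRESModel(
        model_name_or_path="BAAI/bge-large-en-v1.5",
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
        pooling_method="cls",
        batch_size=32,
    )
    q_emb = model.encode_queries(["what is dense retrieval?"])
    c_emb = model.encode_corpus([
        {"title": "Dense retrieval", "text": "Dense retrieval maps text to vectors."},
    ])
    # With normalized embeddings, the dot product equals cosine similarity.
    print(q_emb @ c_emb.T)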