-
Notifications
You must be signed in to change notification settings - Fork 31
/
gritlm.py
218 lines (201 loc) · 10.6 KB
/
gritlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from typing import Dict, List, Union, cast
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
class GritLM(torch.nn.Module):
    """Wrapper around a Hugging Face model used for text embedding and/or generation.

    In ``'embedding'`` mode an encoder-style model is loaded via ``AutoModel``
    (or ``T5EncoderModel`` for names containing ``gtr``/``t5``/``instructor``).
    In every other mode an ``AutoModelForCausalLM`` is loaded, its ``generate``
    is re-exported as ``self.generate``, and embeddings are taken from its
    ``model`` or ``transformer`` sub-module.
    """

    def __init__(
        self,
        model_name_or_path: str = None,
        mode: str = 'unified', # One of ['unified', 'embedding', 'generative']
        pooling_method: str = 'mean', # One of ['cls', 'lasttoken', 'mean', 'weightedmean']
        normalized: bool = True,
        projection: int = None,
        is_inference: bool = True,
        embed_eos: str = "",
        attn: str = 'bbcc',
        **kwargs, # Passed to the model, e.g. `attn_implementation`, `torch_dtype` etc.
    ) -> None:
        """Load the model (and, for inference, the tokenizer) and store the embedding options.

        Args:
            model_name_or_path: Identifier or path handed to ``from_pretrained``.
            mode: ``'embedding'`` loads an encoder via ``AutoModel``; any other
                value loads a causal LM.
            pooling_method: How ``pooling`` reduces token states to one vector.
            normalized: If true, ``encode`` L2-normalizes the embeddings.
            projection: If given, output size of a linear layer applied to the
                hidden states before pooling.
            is_inference: If true, also load the tokenizer, switch the model to
                eval mode and (unless a device map / 4-bit / 8-bit load was
                requested) move it to ``self.device``.
            embed_eos: String appended to every sentence in ``encode``; must be
                in the tokenizer vocabulary when non-empty.
            attn: Attention pattern; one of ``'bbcc'``, ``'cccc'``, ``'bb'``,
                ``'cc'`` or ``None``.
            **kwargs: Forwarded to the model's ``from_pretrained``.

        Raises:
            ValueError: If the causal LM exposes neither a ``model`` nor a
                ``transformer`` attribute, or if ``attn`` is unsupported.
        """
        super().__init__()
        if mode == 'embedding':
            if any(x in model_name_or_path for x in ('gtr', 't5', 'instructor')):
                # Somehow AutoModel does not pick the right one by default
                from transformers import T5EncoderModel
                self.model = T5EncoderModel.from_pretrained(model_name_or_path, **kwargs)
            else:
                self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs)
            # The encoder itself is called for embeddings (no sub-module lookup).
            self.embedding_attr = None
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs)
            self.generate = self.model.generate
            # Name of the sub-module that returns hidden states (without the LM head).
            if hasattr(self.model, 'model'): # LLama2 & Mistral
                self.embedding_attr = 'model'
            elif hasattr(self.model, 'transformer'): # GPT-Neo & GPT-J
                self.embedding_attr = 'transformer'
            else:
                raise ValueError("Could not find attribute to use for embedding: ", self.model)

        self.projection = torch.nn.Linear(
            in_features=self.model.config.hidden_size,
            out_features=int(projection),
            dtype=self.model.dtype
        ) if projection is not None else None
        self.normalized = normalized
        self.pooling_method = pooling_method

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.num_gpus = 1
        self.embed_eos = embed_eos
        self.attn = attn
        if (self.attn is not None) and self.attn not in ['bbcc', 'cccc', 'bb', 'cc']:
            raise ValueError(f"Mixed attention no longer supported: {self.attn}. Only bbcc, cccc, bb, cc are supported")

        print(f"Created GritLM: {self.model.dtype} dtype, {pooling_method} pool, {mode} mode, {attn} attn")

        if is_inference:
            # Padding side right is necessary for `embed_instruction` to index correctly
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side='right')
            if not self.tokenizer.pad_token and self.tokenizer.eos_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print('Set pad token to eos token: ' + self.tokenizer.pad_token)
            if self.embed_eos:
                assert self.embed_eos in self.tokenizer.vocab, f"EOS token {self.embed_eos} not in vocab"
            self.model.eval()
            # Only move the model manually when no device map / quantized loading was requested.
            manual_placement = (
                "device_map" not in kwargs
                and not kwargs.get("load_in_4bit", False)
                and not kwargs.get("load_in_8bit", False)
            )
            if manual_placement:
                self.model.to(self.device)
                # Parallelize embedding model
                if mode == 'embedding':
                    self.num_gpus = torch.cuda.device_count()
                    if self.num_gpus > 1:
                        print(f"----------Using {self.num_gpus} data-parallel GPUs----------")
                        self.model = torch.nn.DataParallel(self.model)

    def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray:
        """Used for encoding the queries of retrieval or reranking tasks.

        Thin wrapper: forwards ``queries`` and all keyword arguments to ``encode``
        and returns whatever ``encode`` returns.
        """
        return self.encode(queries, **kwargs)

    def encode_corpus(self, corpus: Union[List[str], str, List[Dict[str, str]]], **kwargs) -> np.ndarray:
        """Used for encoding the corpus of retrieval tasks.

        A single dict is wrapped in a list. A list whose first element is a dict
        is flattened to strings: ``title + " " + text`` when the document has a
        ``"title"`` key, otherwise just ``text``. The result is passed to
        ``encode`` together with ``**kwargs``.
        """
        if isinstance(corpus, dict):
            corpus = [corpus]
        if isinstance(corpus, list) and isinstance(corpus[0], dict):
            corpus = [
                doc["title"] + " " + doc["text"] if "title" in doc
                else doc["text"] for doc in corpus
            ]
        return self.encode(corpus, **kwargs)

    @torch.no_grad()
    def encode(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 256,
        max_length: int = 512,
        instruction: str = "",
        embed_instruction: bool = False,
        get_cache: bool = False,
        convert_to_tensor: bool = False,
        recast: bool = False,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> np.ndarray:
        """Embed one sentence or a list of sentences.

        Args:
            sentences: A string or list of strings. Each is embedded as
                ``instruction + sentence + self.embed_eos``.
            batch_size: Sentences per forward pass; multiplied by
                ``self.num_gpus`` when more than one GPU is used.
            max_length: Truncation length passed to the tokenizer.
            instruction: Prefix prepended to every sentence.
            embed_instruction: When exactly ``False`` and a mean-style pooling
                is used, the leading instruction tokens are masked out of the
                pooled embedding.
            get_cache: If true, ``use_cache=True`` is passed to the model and
                ``outputs[1]`` is returned alongside the embeddings. Only a
                single batch is supported in this case.
            convert_to_tensor: Return a ``torch.Tensor`` instead of a float32
                NumPy array.
            recast: Forwarded to ``pooling``.
            add_special_tokens: Forwarded to the tokenizer.
            **kwargs: Accepted but not used by this method.

        Returns:
            The stacked embeddings (the single embedding if ``sentences`` was a
            string). When ``get_cache`` is true, a tuple
            ``(embeddings, kv_caches)``.
        """
        if self.num_gpus > 1:
            batch_size *= self.num_gpus

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        # The instruction is identical for every batch, so its token count is
        # computed once here instead of once per batch.
        mask_instruction = bool(instruction) and (embed_instruction is False) and ("mean" in self.pooling_method)
        num_instruction_tokens = 0
        if mask_instruction:
            num_instruction_tokens = len(
                self.tokenizer(
                    instruction,
                    padding=False,
                    truncation=True,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )["input_ids"]
            )

        embed_model = getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model

        all_embeddings, all_kv_caches = [], []
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
            sentences_batch = [
                instruction + s + self.embed_eos for s in sentences[start_index:start_index + batch_size]
            ]
            # This will prepend the bos token if the tokenizer has `add_bos_token=True`
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
                add_special_tokens=add_special_tokens,
            ).to(self.device)

            if (self.attn is not None) and (self.attn[:2] == 'bb'):
                inputs["is_causal"] = False
            if get_cache:
                inputs['use_cache'] = True
            outputs = embed_model(**inputs)
            last_hidden_state = outputs[0]
            if get_cache:
                # Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`
                assert len(all_kv_caches) == 0, "Can only get cache for one batch at a time"
                all_kv_caches = outputs[1]

            if self.projection is not None:
                last_hidden_state = self.projection(last_hidden_state)
            if mask_instruction:
                # Remove instruction tokens from the embeddings by masking them
                inputs['attention_mask'][:, :num_instruction_tokens] = 0
            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'], recast=recast)
            # Normalize can change the dtype (https://discuss.pytorch.org/t/tensor-in-float16-is-transformed-into-float32-after-torch-norm/110891)
            if self.normalized:
                in_dtype = embeddings.dtype
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1).to(in_dtype)
            embeddings = cast(torch.Tensor, embeddings)
            if convert_to_tensor:
                all_embeddings.append(embeddings)
            else:
                # NumPy does not support bfloat16
                all_embeddings.append(embeddings.cpu().to(torch.float32).numpy())

        all_embeddings = (
            torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0)
        )
        if input_was_string:
            all_embeddings = all_embeddings[0]
        if get_cache:
            # The cache is returned exactly as the model produced it (no stacking/concatenation).
            return all_embeddings, all_kv_caches
        return all_embeddings

    def pooling(
        self, hidden_state: torch.Tensor, attention_mask: torch.Tensor = None, recast: bool = False
    ) -> torch.Tensor:
        """Reduce per-token hidden states to one vector per sequence.

        The reduction is selected by ``self.pooling_method``: ``'cls'`` (first
        position), ``'lasttoken'`` (last position whose mask is 1), ``'mean'``
        or ``'weightedmean'`` (mask-weighted average; the latter weights each
        attended token by its running position among attended tokens).

        Args:
            hidden_state: [b, n, d]
            attention_mask: [b, n]; ``None`` is treated as "attend to every
                position". The tensor passed in is never modified.
            recast: If true, cast the result back to ``hidden_state``'s dtype.

        Returns:
            Tensor of shape [b, d]. For mean-style pooling, rows whose mask is
            entirely zero yield a zero vector rather than NaN.

        Raises:
            NotImplementedError: If ``self.pooling_method`` is not recognized.
        """
        if attention_mask is None:
            attention_mask = torch.ones(
                hidden_state.shape[:2], dtype=torch.long, device=hidden_state.device
            )
        # In case the model is distributed across multiple devices; hidden_state may end up on diff device
        hidden_state = hidden_state.to(attention_mask.device)
        if self.pooling_method == 'cls':
            embedding = hidden_state[:, 0]
        elif self.pooling_method == 'lasttoken':
            b, n, d = hidden_state.size()
            # Get the last `1` in the attention mask of each item
            # Often it is just `gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1`
            # except when 1) There's all 1's 2) There's 0's before the 1's
            reversed_mask = torch.flip(attention_mask, dims=(1,))
            argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False)
            gather_indices = attention_mask.size(1) - argmax_reverse - 1
            # If there are empty sequences, where the index would become -1 it will crash so set them to 0
            gather_indices = torch.clamp(gather_indices, min=0)
            # Turn indices from shape [b] -> [b, 1, d]
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, d)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (b, 1, d)
            # Gather along the seq len: [b, n, d] -> [b, d]
            # Actually no need for the attention mask as we gather the last token where attn_mask=1 but
            # as some indices (which shouldn't be attended to) may be 0 due to clamp, use mask to ignore them again
            input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
            embedding = torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
        elif self.pooling_method in ['mean', 'weightedmean']:
            if self.pooling_method == 'weightedmean':
                # Out-of-place on purpose: an in-place `*=` would overwrite the caller's mask.
                attention_mask = attention_mask * attention_mask.cumsum(dim=1)  # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
            weighted_sum = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
            weight_total = attention_mask.sum(dim=1, keepdim=True).float()
            # Clamp so fully-masked rows give 0 / eps = 0 instead of 0 / 0 = NaN;
            # rows with at least one attended token have a total >= 1 and are unaffected.
            embedding = weighted_sum / torch.clamp(weight_total, min=1e-9)
        else:
            raise NotImplementedError(f"Unknown pooling method: {self.pooling_method}")
        # Recasting performs slightly worse but saves 50% space
        if recast:
            return embedding.to(hidden_state.dtype)
        return embedding