Skip to content

Commit

Permalink
fix: embed
Browse files Browse the repository at this point in the history
  • Loading branch information
LongxingTan committed May 30, 2024
1 parent 5198cc5 commit 3f8b350
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/retrievals/models/embedding_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,14 @@ def from_pretrained(
**kwargs,
)

def save_pretrained(self, path: str):
def save_pretrained(self, path: str, safe_serialization: bool = True):
"""
Saves all model and tokenizer to path
"""
logger.info("Save model to {}".format(path))
state_dict = self.model.state_dict()
state_dict = type(state_dict)({k: v.clone().cpu() for k, v in state_dict.items()})
self.model.save_pretrained(path, state_dict=state_dict)
self.model.save_pretrained(path, state_dict=state_dict, safe_serialization=safe_serialization)
self.tokenizer.save_pretrained(path)

def push_to_hub(self, hub_model_id: str, private: bool = True, **kwargs):
Expand Down
17 changes: 5 additions & 12 deletions src/retrievals/models/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,23 +304,16 @@ def from_pretrained(
)
return reranker

def save(self, path: str):
def save_pretrained(self, path: str, safe_serialization: bool = True):
"""
Saves all model and tokenizer to path
"""
if path is None:
return

logger.info("Save model to {}".format(path))
self.model.save_pretrained(path)
state_dict = self.model.state_dict()
state_dict = type(state_dict)({k: v.clone().cpu() for k, v in state_dict.items()})
self.model.save_pretrained(path, state_dict=state_dict, safe_serialization=safe_serialization)
self.tokenizer.save_pretrained(path)

def save_pretrained(self, path: str):
"""
Same function to save
"""
return self.save(path)


class ColBERT(RerankModel):
def __init__(
Expand Down Expand Up @@ -446,7 +439,7 @@ def create_documents(self, query, documents, tokenizer):
res_merge_inputs_pids.append(pid)
return res_merge_inputs, res_merge_inputs_pids

def _merge_inputs(self, chunk1_raw, chunk2, sep_id):
def _merge_inputs(self, chunk1_raw, chunk2, sep_id: int):
chunk1 = deepcopy(chunk1_raw)

chunk1['input_ids'].append(sep_id)
Expand Down

0 comments on commit 3f8b350

Please sign in to comment.