feat: update collator to set different max_len (#11)
LongxingTan committed Mar 12, 2024
1 parent 7668fc9 commit 13330dd
Showing 6 changed files with 124 additions and 22 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -118,6 +118,7 @@ from retrievals import AutoModelForEmbedding, AutoModelForMatch

query_texts = []
passage_texts = []
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
query_embeddings = model.encode(query_texts, convert_to_tensor=True)
passage_embeddings = model.encode(passage_texts, convert_to_tensor=True)
@@ -128,7 +129,15 @@ dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings,

**Search by Faiss**
```python
from retrievals import AutoModelForEmbedding, AutoModelForMatch

sentences = ['A woman is reading.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
model.build_index(sentences)

matcher = AutoModelForMatch()
results = matcher.faiss_search("He plays guitar.")
```

**Rerank**
79 changes: 68 additions & 11 deletions src/retrievals/data/collator.py
@@ -5,11 +5,30 @@


class PairCollator(DataCollatorWithPadding):
    def __init__(self, tokenizer, max_length: Optional[int] = None) -> None:
    def __init__(
        self,
        tokenizer,
        max_length: Optional[int] = None,
        query_max_length: Optional[int] = None,
        passage_max_length: Optional[int] = None,
    ) -> None:
        self.tokenizer = tokenizer
        if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.max_length = max_length or tokenizer.model_max_length

        self.query_max_length: int
        self.passage_max_length: int
        if query_max_length:
            self.query_max_length = query_max_length
        elif max_length:
            self.query_max_length = max_length
        else:
            self.query_max_length = tokenizer.model_max_length

        if passage_max_length:
            self.passage_max_length = passage_max_length
        elif max_length:
            self.passage_max_length = max_length
        else:
            self.passage_max_length = tokenizer.model_max_length

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        query_texts = [feature["query"] for feature in features]
@@ -18,14 +37,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        query_inputs = self.tokenizer(
            query_texts,
            padding=True,
            max_length=self.max_length,
            max_length=self.query_max_length,
            truncation=True,
            return_tensors="pt",
        )
        pos_inputs = self.tokenizer(
            pos_texts,
            padding=True,
            max_length=self.max_length,
            max_length=self.passage_max_length,
            truncation=True,
            return_tensors="pt",
        )
@@ -34,11 +53,30 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:


class TripletCollator(DataCollatorWithPadding):
    def __init__(self, tokenizer, max_length: Optional[int] = None) -> None:
    def __init__(
        self,
        tokenizer,
        max_length: Optional[int] = None,
        query_max_length: Optional[int] = None,
        passage_max_length: Optional[int] = None,
    ) -> None:
        self.tokenizer = tokenizer
        if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.max_length = max_length or tokenizer.model_max_length

        self.query_max_length: int
        self.passage_max_length: int
        if query_max_length:
            self.query_max_length = query_max_length
        elif max_length:
            self.query_max_length = max_length
        else:
            self.query_max_length = tokenizer.model_max_length

        if passage_max_length:
            self.passage_max_length = passage_max_length
        elif max_length:
            self.passage_max_length = max_length
        else:
            self.passage_max_length = tokenizer.model_max_length

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        query_texts = [feature["query"] for feature in features]
@@ -53,21 +91,21 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        query_inputs = self.tokenizer(
            query_texts,
            padding=True,
            max_length=self.max_length,
            max_length=self.query_max_length,
            truncation=True,
            return_tensors="pt",
        )
        pos_inputs = self.tokenizer(
            pos_texts,
            padding=True,
            max_length=self.max_length,
            max_length=self.passage_max_length,
            truncation=True,
            return_tensors="pt",
        )  # ["input_ids"]
        neg_inputs = self.tokenizer(
            neg_texts,
            padding=True,
            max_length=self.max_length,
            max_length=self.passage_max_length,
            truncation=True,
            return_tensors="pt",
        )  # ["input_ids"]
@@ -80,11 +118,30 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:


class RerankCollator(DataCollatorWithPadding):
    def __init__(self, tokenizer, max_length: Optional[int] = None):
    def __init__(
        self,
        tokenizer,
        max_length: Optional[int] = None,
        query_max_length: Optional[int] = None,
        passage_max_length: Optional[int] = None,
    ):
        self.tokenizer = tokenizer
        if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.max_length = max_length or tokenizer.model_max_length

        self.query_max_length: int
        self.passage_max_length: int
        if query_max_length:
            self.query_max_length = query_max_length
        elif max_length:
            self.query_max_length = max_length
        else:
            self.query_max_length = tokenizer.model_max_length

        if passage_max_length:
            self.passage_max_length = passage_max_length
        elif max_length:
            self.passage_max_length = max_length
        else:
            self.passage_max_length = tokenizer.model_max_length

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:

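With this change, queries and passages can be truncated to different lengths. A minimal sketch of how the new arguments might be used (the tokenizer checkpoint, import path, and toy feature dicts below are illustrative assumptions, not part of this commit):

```python
from transformers import AutoTokenizer

from retrievals.data.collator import TripletCollator  # import path assumed from the file layout

# any Hugging Face tokenizer works; this checkpoint is only an example
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# queries are usually short and passages long, so give each its own budget
collator = TripletCollator(tokenizer, query_max_length=32, passage_max_length=256)

# toy features; "query"/"pos"/"neg" keys assumed from the collator's __call__
features = [
    {
        "query": "what is faiss",
        "pos": "Faiss is a library for efficient similarity search.",
        "neg": "Bananas are yellow.",
    }
]
batch = collator(features)
print(batch.keys())
```

Passing only `max_length` keeps the previous behaviour of a single shared budget for both queries and passages.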
31 changes: 23 additions & 8 deletions src/retrievals/models/embedding_auto.py
@@ -9,7 +9,6 @@
from torch.utils.data import DataLoader, Dataset
from tqdm.autonotebook import trange
from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
@@ -361,8 +360,8 @@ def encode_from_text(

        return all_embeddings

    def build_index(self, inputs, use_gpu: bool = True):
        embeddings = self.encode(inputs)
    def build_index(self, inputs, batch_size: int = 64, use_gpu: bool = True):
        embeddings = self.encode(inputs, batch_size=batch_size)
        embeddings = np.asarray(embeddings, dtype=np.float32)
        index = faiss.IndexFlatL2(len(embeddings[0]))
        if use_gpu:
@@ -373,6 +372,15 @@ def build_index(self, inputs, use_gpu: bool = True):
        index.add(embeddings)
        return index
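The `batch_size` argument added here is forwarded to `encode`, so the README's Faiss example can control the encoding batches explicitly. A minimal, untested sketch (checkpoint name and sentences reused from the README above):

```python
from retrievals import AutoModelForEmbedding

sentences = ['A woman is reading.', 'A man is playing a guitar.']
model = AutoModelForEmbedding("sentence-transformers/all-MiniLM-L6-v2")
model.build_index(sentences, batch_size=128, use_gpu=False)  # encode the corpus in batches of 128
```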

    def add_to_index(self):
        return

    def search(self):
        return

    def similarity(self, queries: Union[str, List[str]], keys: Union[str, List[str], ndarray]):
        return

    def save(self):
        pass

@@ -473,13 +481,23 @@ def forward(
        return pooled_output1, pooled_output2


def unsorted_segment_mean(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    result_shape = (num_segments, data.size(1))
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_full(result_shape, 0)  # zero-initialized accumulator for per-segment sums
    count = data.new_full(result_shape, 0)
    result.scatter_add_(0, segment_ids, data)
    count.scatter_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)
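For intuition, `unsorted_segment_mean` averages the rows of `data` that share a segment id. A quick sketch of the expected output (the import path is an assumption based on this file's location):

```python
import torch

from retrievals.models.embedding_auto import unsorted_segment_mean  # import path assumed

data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
segment_ids = torch.tensor([0, 0, 1, 1])  # rows 0-1 -> segment 0, rows 2-3 -> segment 1

pooled = unsorted_segment_mean(data, segment_ids, num_segments=2)
# tensor([[2., 3.],
#         [6., 7.]])
```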


class ListwiseModel(AutoModelForEmbedding):
    """
    Listwise embedding model: embeddings are pooled per group of rows sharing a
    `segment_id` (see `unsorted_segment_mean`).
    """

    def __init__(self) -> None:
        super().__init__()
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def forward(
        self,
@@ -490,6 +508,3 @@ def forward(
        return_dict: Optional[bool] = None,
    ):
        return

    def apply_listwise_pooling(self, data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int):
        return
2 changes: 1 addition & 1 deletion src/retrievals/models/rerank.py
@@ -97,6 +97,6 @@ def save(self, path):

    def save_pretrained(self, path):
        """
        Same function as save
        Same as `save`
        """
        return self.save(path)
2 changes: 1 addition & 1 deletion src/retrievals/version.py
@@ -2,7 +2,7 @@

import sys

__version__ = '0.0.0.dev2'
__version__ = '0.0.1'
short_version = __version__


23 changes: 22 additions & 1 deletion tests/test_models/test_embedding_auto.py
@@ -2,9 +2,15 @@
from unittest import TestCase

import numpy as np
import torch
from transformers import AutoConfig

from src.retrievals.models.embedding_auto import AutoModelForEmbedding
from src.retrievals.models.embedding_auto import (
    AutoModelForEmbedding,
    ListwiseModel,
    PairwiseModel,
    unsorted_segment_mean,
)

from .test_modeling_common import (
    ModelTesterMixin,
@@ -150,3 +156,18 @@ def setUp(self) -> None:

    def test_pairwise_model(self):
        pass


class ListwiseModelTest(TestCase):
    def setUp(self) -> None:
        pass
        # self.model = ListwiseModel()

    def test_unsorted_segment_mean(self):
        input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
        segment_ids = torch.tensor([0, 0, 1, 1])
        num_segments = 2

        list_pool = unsorted_segment_mean(input_tensor, segment_ids, num_segments)
        expected = torch.tensor([[2.0, 3.0], [6.0, 7.0]])
        self.assertTrue(torch.allclose(list_pool, expected))
