Skip to content

Commit

Permalink
fixed type hints
Browse files: browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Jan 26, 2024
1 parent 141ed73 commit 0a5fd27
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 61 deletions.
26 changes: 9 additions & 17 deletions src/seb/interfaces/model.py
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,7 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Optional, Protocol,
runtime_checkable)
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable

from numpy.typing import ArrayLike
from pydantic import BaseModel
Expand Down Expand Up @@ -131,29 +130,22 @@ def encode(
"""
return self.model.encode(sentences, batch_size=batch_size, task=task, **kwargs)

def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
def encode_queries(self, queries: list[str], batch_size: int, **kwargs): # noqa
try:
return self.model.encode_queries(queries, batch_size=batch_size, **kwargs)
return self.model.encode_queries(queries, batch_size=batch_size, **kwargs) # type: ignore
except AttributeError:
return self.encode(queries, task=None, batch_size=batch_size, **kwargs)

def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs): # noqa
try:
return self.model.encode_corpus(corpus, batch_size=batch_size, **kwargs)
return self.model.encode_corpus(corpus, batch_size=batch_size, **kwargs) # type: ignore
except AttributeError:
sep = " "
if type(corpus) is dict:
if isinstance(corpus, dict):
sentences = [
(corpus["title"][i] + sep + corpus["text"][i]).strip()
if "title" in corpus
else corpus["text"][i].strip()
for i in range(len(corpus["text"]))
(corpus["title"][i] + sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() # type: ignore
for i in range(len(corpus["text"])) # type: ignore
]
else:
sentences = [
(doc["title"] + sep + doc["text"]).strip()
if "title" in doc
else doc["text"].strip()
for doc in corpus
]
sentences = [(doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
return self.encode(sentences, task=None, batch_size=batch_size, **kwargs)
12 changes: 0 additions & 12 deletions src/seb/interfaces/mteb_task.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,18 +12,6 @@
from .task import DescriptiveDatasetStats, Task


<<<<<<< HEAD
class MTEBTaskModel(Encoder):
def __init__(self, mteb_model: Encoder, task: Task) -> None:
self.mteb_model = mteb_model
self.task = task

def encode(self, texts: list[str], **kwargs: Any) -> np.ndarray: # type: ignore
return self.mteb_model.encode(texts, task=self.task, **kwargs) # type: ignore


=======
>>>>>>> f427a6da949fb2f7eb14f1014e9ae7a43df3a109
class MTEBTask(Task):
def __init__(self, mteb_task: AbsTask) -> None:
self.mteb_task = mteb_task
Expand Down
23 changes: 8 additions & 15 deletions src/seb/registered_models/cohere_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,32 +44,25 @@ def encode(
task: Optional[Task] = None,
**kwargs: Any, # noqa: ARG002
) -> torch.Tensor:
if task.task_type == "Classification":
if task and task.task_type == "Classification":
input_type = "classification"
elif task.task_type == "Clustering":
elif task and task.task_type == "Clustering":
input_type = "clustering"
else:
input_type = "search_document"
return self._embed(sentences, input_type=input_type)

def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
def encode_queries(self, queries: list[str], batch_size: int, **kwargs): # noqa
return self._embed(queries, input_type="search_query")

def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
if type(corpus) is dict:
def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs): # noqa
if isinstance(corpus, dict):
sentences = [
(corpus["title"][i] + self.sep + corpus["text"][i]).strip()
if "title" in corpus
else corpus["text"][i].strip()
for i in range(len(corpus["text"]))
(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() # type: ignore
for i in range(len(corpus["text"])) # type: ignore
]
else:
sentences = [
(doc["title"] + self.sep + doc["text"]).strip()
if "title" in doc
else doc["text"].strip()
for doc in corpus
]
sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
return self._embed(sentences, input_type="search_document")


Expand Down
23 changes: 8 additions & 15 deletions src/seb/registered_models/e5_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,35 +16,28 @@ def __init__(self, model_name: str, sep: str = " "):
self.mdl = SentenceTransformer(model_name)
self.sep = sep

def encode(
def encode( # type: ignore
self,
sentences: list[str],
*,
task: Task,
task: Task, # noqa: ARG002
batch_size: int = 32,
**kwargs: Any,
) -> ArrayLike:
return self.encode_queries(sentences, batch_size=batch_size, **kwargs)

def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
def encode_queries(self, queries: list[str], batch_size: int, **kwargs): # noqa
sentences = ["query: " + sentence for sentence in queries]
return self.mdl.encode(sentences, batch_size=batch_size, **kwargs)

def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
if type(corpus) is dict:
def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs): # noqa
if isinstance(corpus, dict):
sentences = [
(corpus["title"][i] + self.sep + corpus["text"][i]).strip()
if "title" in corpus
else corpus["text"][i].strip()
for i in range(len(corpus["text"]))
(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() # type: ignore
for i in range(len(corpus["text"])) # type: ignore
]
else:
sentences = [
(doc["title"] + self.sep + doc["text"]).strip()
if "title" in doc
else doc["text"].strip()
for doc in corpus
]
sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
sentences = ["passage: " + sentence for sentence in sentences]
return self.mdl.encode(sentences, batch_size=batch_size, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/seb/registered_models/openai_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -70,7 +70,7 @@ def embed(self, sentences: Sequence[str]) -> torch.Tensor:
vectors = [embedding.embedding for embedding in data]
return torch.tensor(vectors)

def encode(
def encode( # type: ignore
self,
sentences: Sequence[str],
*,
Expand Down
2 changes: 1 addition & 1 deletion src/seb/registered_models/translate_e5_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,7 +35,7 @@ def encode(
**kwargs: Any,
) -> torch.Tensor:
try:
src_lang = task.languages[0]
src_lang = task.languages[0] # type: ignore
except IndexError:
# Danish is the default fallback if no language is specified for the task.
src_lang = "da"
Expand Down

0 comments on commit 0a5fd27

Please sign in to comment.