diff --git a/src/seb/interfaces/model.py b/src/seb/interfaces/model.py
index 40005149..e68799b2 100644
--- a/src/seb/interfaces/model.py
+++ b/src/seb/interfaces/model.py
@@ -1,8 +1,7 @@
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (TYPE_CHECKING, Any, Callable, Optional, Protocol,
-                    runtime_checkable)
+from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable
 
 from numpy.typing import ArrayLike
 from pydantic import BaseModel
@@ -131,29 +130,22 @@ def encode(
         """
         return self.model.encode(sentences, batch_size=batch_size, task=task, **kwargs)
 
-    def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
+    def encode_queries(self, queries: list[str], batch_size: int, **kwargs):  # noqa
         try:
-            return self.model.encode_queries(queries, batch_size=batch_size, **kwargs)
+            return self.model.encode_queries(queries, batch_size=batch_size, **kwargs)  # type: ignore
         except AttributeError:
             return self.encode(queries, task=None, batch_size=batch_size, **kwargs)
 
-    def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
+    def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):  # noqa
         try:
-            return self.model.encode_corpus(corpus, batch_size=batch_size, **kwargs)
+            return self.model.encode_corpus(corpus, batch_size=batch_size, **kwargs)  # type: ignore
         except AttributeError:
             sep = " "
-            if type(corpus) is dict:
+            if isinstance(corpus, dict):
                 sentences = [
-                    (corpus["title"][i] + sep + corpus["text"][i]).strip()
-                    if "title" in corpus
-                    else corpus["text"][i].strip()
-                    for i in range(len(corpus["text"]))
+                    (corpus["title"][i] + sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip()  # type: ignore
+                    for i in range(len(corpus["text"]))  # type: ignore
                 ]
             else:
-                sentences = [
-                    (doc["title"] + sep + doc["text"]).strip()
-                    if "title" in doc
-                    else doc["text"].strip()
-                    for doc in corpus
-                ]
+                sentences = [(doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
             return self.encode(sentences, task=None, batch_size=batch_size, **kwargs)
diff --git a/src/seb/interfaces/mteb_task.py b/src/seb/interfaces/mteb_task.py
index 743b5b7b..b18a1512 100644
--- a/src/seb/interfaces/mteb_task.py
+++ b/src/seb/interfaces/mteb_task.py
@@ -12,18 +12,6 @@
 from .task import DescriptiveDatasetStats, Task
 
-<<<<<<< HEAD
-class MTEBTaskModel(Encoder):
-    def __init__(self, mteb_model: Encoder, task: Task) -> None:
-        self.mteb_model = mteb_model
-        self.task = task
-
-    def encode(self, texts: list[str], **kwargs: Any) -> np.ndarray:  # type: ignore
-        return self.mteb_model.encode(texts, task=self.task, **kwargs)  # type: ignore
-
-
-=======
->>>>>>> f427a6da949fb2f7eb14f1014e9ae7a43df3a109
 class MTEBTask(Task):
     def __init__(self, mteb_task: AbsTask) -> None:
         self.mteb_task = mteb_task
diff --git a/src/seb/registered_models/cohere_models.py b/src/seb/registered_models/cohere_models.py
index d11830db..32d2b5cd 100644
--- a/src/seb/registered_models/cohere_models.py
+++ b/src/seb/registered_models/cohere_models.py
@@ -44,32 +44,25 @@ def encode(
         task: Optional[Task] = None,
         **kwargs: Any,  # noqa: ARG002
     ) -> torch.Tensor:
-        if task.task_type == "Classification":
+        if task and task.task_type == "Classification":
             input_type = "classification"
-        elif task.task_type == "Clustering":
+        elif task and task.task_type == "Clustering":
             input_type = "clustering"
         else:
             input_type = "search_document"
         return self._embed(sentences, input_type=input_type)
 
-    def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
+    def encode_queries(self, queries: list[str], batch_size: int, **kwargs):  # noqa
         return self._embed(queries, input_type="search_query")
 
-    def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
-        if type(corpus) is dict:
+    def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):  # noqa
+        if isinstance(corpus, dict):
             sentences = [
-                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
-                if "title" in corpus
-                else corpus["text"][i].strip()
-                for i in range(len(corpus["text"]))
+                (corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip()  # type: ignore
+                for i in range(len(corpus["text"]))  # type: ignore
             ]
         else:
-            sentences = [
-                (doc["title"] + self.sep + doc["text"]).strip()
-                if "title" in doc
-                else doc["text"].strip()
-                for doc in corpus
-            ]
+            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
         return self._embed(sentences, input_type="search_document")
diff --git a/src/seb/registered_models/e5_models.py b/src/seb/registered_models/e5_models.py
index 278a1e36..76bc8e7a 100644
--- a/src/seb/registered_models/e5_models.py
+++ b/src/seb/registered_models/e5_models.py
@@ -16,35 +16,28 @@ def __init__(self, model_name: str, sep: str = " "):
         self.mdl = SentenceTransformer(model_name)
         self.sep = sep
 
-    def encode(
+    def encode(  # type: ignore
         self,
         sentences: list[str],
         *,
-        task: Task,
+        task: Task,  # noqa: ARG002
         batch_size: int = 32,
         **kwargs: Any,
     ) -> ArrayLike:
         return self.encode_queries(sentences, batch_size=batch_size, **kwargs)
 
-    def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
+    def encode_queries(self, queries: list[str], batch_size: int, **kwargs):  # noqa
         sentences = ["query: " + sentence for sentence in queries]
         return self.mdl.encode(sentences, batch_size=batch_size, **kwargs)
 
-    def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
-        if type(corpus) is dict:
+    def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):  # noqa
+        if isinstance(corpus, dict):
             sentences = [
-                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
-                if "title" in corpus
-                else corpus["text"][i].strip()
-                for i in range(len(corpus["text"]))
+                (corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip()  # type: ignore
+                for i in range(len(corpus["text"]))  # type: ignore
             ]
         else:
-            sentences = [
-                (doc["title"] + self.sep + doc["text"]).strip()
-                if "title" in doc
-                else doc["text"].strip()
-                for doc in corpus
-            ]
+            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
         sentences = ["passage: " + sentence for sentence in sentences]
         return self.mdl.encode(sentences, batch_size=batch_size, **kwargs)
diff --git a/src/seb/registered_models/openai_models.py b/src/seb/registered_models/openai_models.py
index c1eede25..7b74b9bf 100644
--- a/src/seb/registered_models/openai_models.py
+++ b/src/seb/registered_models/openai_models.py
@@ -70,7 +70,7 @@ def embed(self, sentences: Sequence[str]) -> torch.Tensor:
         vectors = [embedding.embedding for embedding in data]
         return torch.tensor(vectors)
 
-    def encode(
+    def encode(  # type: ignore
         self,
         sentences: Sequence[str],
         *,
diff --git a/src/seb/registered_models/translate_e5_models.py b/src/seb/registered_models/translate_e5_models.py
index 39835733..217bac06 100644
--- a/src/seb/registered_models/translate_e5_models.py
+++ b/src/seb/registered_models/translate_e5_models.py
@@ -35,7 +35,7 @@ def encode(
         **kwargs: Any,
     ) -> torch.Tensor:
         try:
-            src_lang = task.languages[0]
+            src_lang = task.languages[0]  # type: ignore
         except IndexError:
             # Danish is the default fallback if no language is specified for the task.
             src_lang = "da"
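
Note: the title/text corpus flattening that this patch reformats is duplicated near-verbatim in model.py, cohere_models.py, and e5_models.py. A minimal sketch of a shared helper that could replace all three copies (the name corpus_to_sentences and its placement are hypothetical, not part of this patch):

from typing import Union

def corpus_to_sentences(corpus: Union[dict[str, list[str]], list[dict[str, str]]], sep: str = " ") -> list[str]:
    """Flatten a corpus into plain strings, joining title and text when a title exists."""
    if isinstance(corpus, dict):
        # Column-oriented corpus: {"title": [...], "text": [...]}
        return [
            (corpus["title"][i] + sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip()
            for i in range(len(corpus["text"]))
        ]
    # Row-oriented corpus: [{"title": ..., "text": ...}, ...]
    return [(doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]

Each call site would then only add its model-specific prefix (e.g. "passage: " in e5_models.py) before encoding.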