Merge pull request #97 from KennethEnevoldsen/add-danfever
Add DanFEVER
KennethEnevoldsen committed Jan 26, 2024
2 parents ec412da + a572962 commit 801753f
Showing 8 changed files with 101 additions and 8 deletions.
@@ -0,0 +1 @@
{"task_name":"DanFEVER","task_description":"A Danish dataset intended for misinformation research. It follows the same format as the English FEVER dataset.","task_version":"1.1.1","time_of_run":"2024-01-25T16:46:10.510144","scores":{"da":{"ndcg_at_1":0.28982,"ndcg_at_3":0.36722,"ndcg_at_5":0.37753,"ndcg_at_10":0.38335,"ndcg_at_100":0.38781,"ndcg_at_1000":0.388,"map_at_1":0.28974,"map_at_3":0.34882,"map_at_5":0.35458,"map_at_10":0.35702,"map_at_100":0.35811,"map_at_1000":0.35812,"recall_at_1":0.28974,"recall_at_3":0.42013,"recall_at_5":0.445,"recall_at_10":0.46273,"recall_at_100":0.48188,"recall_at_1000":0.48329,"precision_at_1":0.28982,"precision_at_3":0.14007,"precision_at_5":0.08903,"precision_at_10":0.0463,"precision_at_100":0.00482,"precision_at_1000":0.00048,"mrr_at_1":0.28982,"mrr_at_3":0.34889,"mrr_at_5":0.35463,"mrr_at_10":0.35709,"mrr_at_100":0.35815,"mrr_at_1000":0.35816}},"main_score":"ndcg_at_10"}
@@ -0,0 +1 @@
{"task_name":"DanFEVER","task_description":"A Danish dataset intended for misinformation research. It follows the same format as the English FEVER dataset.","task_version":"1.1.1","time_of_run":"2024-01-25T16:46:56.185726","scores":{"da":{"ndcg_at_1":0.21732,"ndcg_at_3":0.26397,"ndcg_at_5":0.27362,"ndcg_at_10":0.28184,"ndcg_at_100":0.29779,"ndcg_at_1000":0.30425,"map_at_1":0.21732,"map_at_3":0.25292,"map_at_5":0.25828,"map_at_10":0.26166,"map_at_100":0.26466,"map_at_1000":0.26488,"recall_at_1":0.21732,"recall_at_3":0.29578,"recall_at_5":0.31916,"recall_at_10":0.34466,"recall_at_100":0.42225,"recall_at_1000":0.47434,"precision_at_1":0.21732,"precision_at_3":0.09859,"precision_at_5":0.06383,"precision_at_10":0.03447,"precision_at_100":0.00423,"precision_at_1000":0.00047,"mrr_at_1":0.21732,"mrr_at_3":0.25297,"mrr_at_5":0.2583,"mrr_at_10":0.26168,"mrr_at_100":0.26468,"mrr_at_1000":0.2649}},"main_score":"ndcg_at_10"}
2 changes: 1 addition & 1 deletion src/seb/interfaces/mteb_task.py
@@ -96,7 +96,7 @@ def evaluate(self, model: Encoder) -> TaskResult:
            task_description=self.description,
            task_version=self.version,
            time_of_run=time_of_run,
            scores=scores,
            scores=scores, # type: ignore
            main_score=self.main_score,
        )

11 changes: 6 additions & 5 deletions src/seb/registered_models/fasttext.py
@@ -1,5 +1,6 @@
from collections.abc import Sequence
from functools import partial
from typing import Any

import numpy as np
import torch
@@ -10,8 +11,8 @@

class FastTextModel(seb.Encoder):
    def __init__(self, model_name: str, lang: str) -> None:
        import fasttext
        import fasttext.util
        import fasttext # type: ignore
        import fasttext.util # type: ignore

        fasttext.util.download_model(self.lang, if_exists="ignore")
        self.model = fasttext.load_model(self.model_name)
@@ -22,16 +23,16 @@ def get_embedding_dim(self) -> int:
        v = self.encode(["get emb dim"])
        return v.shape[1]

    def encode(
    def encode( # type: ignore
        self,
        sentences: Sequence[str],
        **kwargs: dict, # noqa: ARG002
        **kwargs: Any, # noqa: ARG002
    ) -> torch.Tensor:
        embeddings = []
        for sentence in sentences:
            # This is to appease FastText as they made the function err
            # if there's a \n in the sentence.
            sentence = " ".join(sentence.split())
            sentence = " ".join(sentence.split()) # noqa
            sentence_embedding = self.model.get_sentence_vector(sentence)
            embeddings.append(sentence_embedding)
        return torch.tensor(np.stack(embeddings))
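
The whitespace handling in encode above exists because, as the inline comment notes, fasttext's get_sentence_vector errors when a sentence contains a newline. A small, self-contained illustration of what the normalisation does (the example sentence is made up):

# Collapse all whitespace (including "\n" and "\t") into single spaces,
# exactly as encode() does before calling get_sentence_vector.
sentence = "En sætning\nmed et linjeskift\tog en tabulator."
normalised = " ".join(sentence.split())
print(normalised)  # -> "En sætning med et linjeskift og en tabulator."
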
10 changes: 10 additions & 0 deletions src/seb/registered_tasks/danish.py
@@ -63,3 +63,13 @@ def create_da_political_comments() -> Task:
task.domain = ["social"]
task.reference = "https://huggingface.co/datasets/danish_political_comments" # TODO: Make a PR for MTEB to add this reference
return task


@tasks.register("DanFEVER")
def create_dan_fever() -> Task:
    from .mteb_retrieval import DanFever

    task = MTEBTask(DanFever())
    task.name = "DanFEVER"
    task.domain = ["wiki", "non-fiction"]
    return task
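
For readers unfamiliar with the registration pattern used above, here is a minimal, self-contained sketch of how a decorator-based task registry can work. It is illustrative only; the actual tasks registry in seb may be implemented differently, and the factory below returns a plain string rather than a Task so the sketch runs on its own.

from typing import Any, Callable


class Registry:
    """Illustrative decorator-based registry; not the actual seb implementation."""

    def __init__(self) -> None:
        self._factories: dict[str, Callable[[], Any]] = {}

    def register(self, name: str) -> Callable[[Callable[[], Any]], Callable[[], Any]]:
        def decorator(factory: Callable[[], Any]) -> Callable[[], Any]:
            self._factories[name] = factory
            return factory  # the decorated factory is returned unchanged

        return decorator

    def create(self, name: str) -> Any:
        return self._factories[name]()


tasks = Registry()


@tasks.register("DanFEVER")
def create_dan_fever() -> str:
    return "DanFEVER task"


print(tasks.create("DanFEVER"))  # -> DanFEVER task
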
81 changes: 81 additions & 0 deletions src/seb/registered_tasks/mteb_retrieval.py
@@ -0,0 +1,81 @@
from typing import Any

import datasets
from mteb.abstasks import AbsTaskRetrieval


class DanFever(AbsTaskRetrieval):
    @property
    def description(self) -> dict[str, Any]:
        return {
            "name": "DanFEVER",
            "hf_hub_name": "strombergnlp/danfever",
            "description": "A Danish dataset intended for misinformation research. It follows the same format as the English FEVER dataset.",
            "reference": "https://aclanthology.org/2021.nodalida-main.47/",
            "type": "Retrieval",
            "category": "p2p",
            "eval_splits": ["train"],
            "eval_langs": ["da"],
            "main_score": "ndcg_at_10",
            "revision": "5d01e3f6a661d48e127ab5d7e3aaa0dc8331438a",
        }

    def load_data(self, **kwargs: dict): # noqa: ARG002
        """
        Load dataset from HuggingFace hub
        """
        if self.data_loaded:
            return

        self.dataset: datasets.DatasetDict = datasets.load_dataset(
            self.description["hf_hub_name"],
            revision=self.description.get("revision"),
        ) # type: ignore

        self.dataset_transform()
        self.data_loaded = True

    def dataset_transform(self) -> None:
        """
        Transform the dataset into a retrieval dataset, which has the following attributes:

        self.corpus = Dict[doc_id, Dict[str, str]]  # doc_id => dict with document data like title and text
        self.queries = Dict[query_id, str]  # query_id => query
        self.relevant_docs = Dict[query_id, Dict[doc_id, score]]
        """
        self.corpus = {}
        self.relevant_docs = {}
        self.queries = {}
        text2id = {}

        for split in self.dataset:
            self.corpus[split] = {}
            self.relevant_docs[split] = {}
            self.queries[split] = {}

            ds = self.dataset[split]
            claims = ds["claim"]
            evidences = ds["evidence_extract"]
            labels = ds["label"]
            class_labels = ds.features["label"].names

            for claim, evidence, label_id in zip(claims, evidences, labels):
                claim_is_supported = class_labels[label_id] == "Supported"

                sim = 1 if claim_is_supported else 0  # 0 for refuted claims - is that what we want?

                if claim not in text2id:
                    text2id[claim] = str(len(text2id))
                if evidence not in text2id:
                    text2id[evidence] = str(len(text2id))

                claim_id = str(text2id[claim])
                evidence_id = str(text2id[evidence])

                self.queries[split][claim_id] = claim
                self.corpus[split][evidence_id] = {"title": "", "text": evidence}

                if claim_id not in self.relevant_docs[split]:
                    self.relevant_docs[split][claim_id] = {}

                self.relevant_docs[split][claim_id][evidence_id] = sim
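
To make the transform concrete, here is a toy illustration (hypothetical data, not taken from DanFEVER) of the structures dataset_transform builds for one split:

# Queries map claim ids to claim texts, the corpus maps evidence ids to
# documents, and relevant_docs links each claim to its evidence with score 1
# only when the claim is labelled "Supported" (0 otherwise).
queries = {"train": {"0": "H.C. Andersen blev født i Odense."}}
corpus = {"train": {"1": {"title": "", "text": "H.C. Andersen blev født i Odense i 1805."}}}
relevant_docs = {"train": {"0": {"1": 1}}}  # claim "0" is supported by evidence "1"
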
2 changes: 1 addition & 1 deletion src/seb/result_dataclasses.py
@@ -46,7 +46,7 @@ def get_main_score(self, lang: Optional[Iterable[str]] = None) -> float:
            lang = self.scores.keys()

        for l in lang:
            main_scores.append(self.scores[l][self.main_score])
            main_scores.append(self.scores[l][self.main_score]) # type: ignore

        return sum(main_scores) / len(main_scores)

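
A small worked illustration of get_main_score as shown above: given per-language score dicts (the values here are hypothetical), the result is the mean of the task's main metric over the requested languages:

# Hypothetical per-language scores for a task whose main_score is "ndcg_at_10".
scores = {"da": {"ndcg_at_10": 0.38}, "nb": {"ndcg_at_10": 0.30}}
main_score = "ndcg_at_10"

langs = list(scores.keys())  # or a caller-supplied subset of languages
main_scores = [scores[lang][main_score] for lang in langs]
print(sum(main_scores) / len(main_scores))  # -> 0.34
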
1 change: 0 additions & 1 deletion tests/cli/test_cli.py
@@ -33,7 +33,6 @@ def to_command(self, output_path: Path) -> list[str]:
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.550, tasks=["DKHate"]),
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.525, tasks=["DKHate", "ScaLA"]),
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.50, tasks=["DKHate", "ScaLA"], languages=["sv", "nn", "nb"]),
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.423, languages=["da"]),
BenchmarkCliTestInput(
"test_model", np.nan, code_path=(test_dir / "benchmark_cli_code_inject.py"), tasks=["test-encode-task"], ignore_cache=True
),
