Fix/ner csv export #659

Merged · 4 commits · Jul 21, 2023
.pre-commit-config.yaml (1 addition, 1 deletion)
@@ -27,7 +27,7 @@ repos:
hooks:
- id: pydocstyle
args: [
"--add-ignore=D100,D104,D105,D400,D415",
"--add-ignore=D100,D104,D105,D400,D415,D419",
"--add-select=D417",
"--convention=google"
]
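For context, D419 is pydocstyle's "Docstring is empty" check; adding it to --add-ignore lets bare placeholder docstrings, such as the one NEROutput.__getitem__ keeps in this PR, pass the pre-commit hook. A minimal sketch of what the rule would otherwise flag (the function below is hypothetical):

def first_match(predictions, word):
    """"""  # D419: "Docstring is empty" -- no longer reported after this change
    return next((p for p in predictions if p.span.word == word), None)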
langtest/datahandler/datasource.py (22 additions, 11 deletions)
@@ -1,5 +1,6 @@
import csv
import importlib
from collections import defaultdict
import os
import re
from abc import ABC, abstractmethod
@@ -373,11 +374,10 @@ def export_data(self, data: List[Sample], output_path: str):
output_path (str):
path to save the data to
"""
temp_id = None
otext = ""
for i in data:
text, temp_id = Formatter.process(i, output_format="conll", temp_id=temp_id)
otext += text
text = Formatter.process(i, output_format="conll")
otext += text + "\n"

with open(output_path, "wb") as fwriter:
fwriter.write(bytes(otext, encoding="utf-8"))
@@ -534,15 +534,26 @@ def export_data(self, data: List[Sample], output_path: str):
output_path (str):
path to save the data to
"""
temp_id = None
otext = ""
if self.task == "ner":
for i in data:
text, temp_id = Formatter.process(i, output_format="csv", temp_id=temp_id)
otext += text

with open(output_path, "wb") as fwriter:
fwriter.write(bytes(otext, encoding="utf-8"))
final_data = defaultdict(list)
for elt in data:
tokens, labels, testcase_tokens, testcase_labels = Formatter.process(
elt, output_format="csv"
)
final_data["text"].append(tokens)
final_data["ner"].append(labels)
final_data["testcase_text"].append(testcase_tokens)
final_data["testcase_labels"].append(testcase_labels)

if (
sum([len(labels) for labels in final_data["testcase_labels"]])
* sum([len(tokens) for tokens in final_data["testcase_text"]])
== 0
):
final_data.pop("testcase_text")
final_data.pop("testcase_labels")

pd.DataFrame(data=final_data).to_csv(output_path, index=False)

elif self.task == "text-classification":
rows = []
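To make the new CSV export path concrete, here is a minimal runnable sketch; the processed list is hypothetical data standing in for what Formatter.process(elt, output_format="csv") returns per sample:

from collections import defaultdict

import pandas as pd

# Hypothetical output of Formatter.process for one sample:
# (tokens, labels, testcase_tokens, testcase_labels)
processed = [
    (["My", "name", "is", "Jean"], ["O", "O", "O", "B-PER"],
     ["My", "name", "is", "Marie"], ["O", "O", "O", "B-PER"]),
]

final_data = defaultdict(list)
for tokens, labels, testcase_tokens, testcase_labels in processed:
    final_data["text"].append(tokens)
    final_data["ner"].append(labels)
    final_data["testcase_text"].append(testcase_tokens)
    final_data["testcase_labels"].append(testcase_labels)

# Mirrors the zero-product check above: drop the test-case columns
# when either one is empty across all rows.
if sum(len(l) for l in final_data["testcase_labels"]) * sum(
    len(t) for t in final_data["testcase_text"]
) == 0:
    final_data.pop("testcase_text")
    final_data.pop("testcase_labels")

pd.DataFrame(data=final_data).to_csv("ner_export.csv", index=False)
# One row per sample; pandas stringifies the list-valued cells.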
langtest/datahandler/format.py (74 additions, 89 deletions)
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import Tuple

from typing import Tuple, List, Union
import re
from ..utils.custom_types import Sample


@@ -91,11 +91,10 @@ class SequenceClassificationOutputFormatter(BaseFormatter, ABC):

@staticmethod
def to_csv(sample: Sample) -> str:
"""
Convert a Sample object into a row for exporting.
"""Convert a Sample object into a row for exporting.

Args:
Sample :
Sample:
Sample object to convert.

Returns:
@@ -118,115 +117,101 @@ class NEROutputFormatter(BaseFormatter):

@staticmethod
def to_csv(
sample: Sample, delimiter: str = ",", temp_id: int = None
) -> Tuple[str, int]:
sample: Sample, delimiter: str = ","
) -> Tuple[List[str], List[str], List[str], List[str]]:
"""Converts a custom type to a CSV string.

Args:
sample (Sample):
The input sample containing the `NEROutput` object to convert.
delimiter (str):
The delimiter character to use in the CSV string.
temp_id (int):
A temporary ID to use for grouping entities by document.

Returns:
Tuple[str, int]:
The CSV or CoNLL string representation of the `NEROutput` object along with the document id
Tuple[List[str], List[str], List[str], List[str]]:
tuple containing the list of tokens of the original sentence, the list of
labels of the original sentence, the list of tokens for the perturbed sentence
and the labels of the perturbed sentence.
"""
text = ""
test_case = sample.test_case
original = sample.original

words = re.finditer(r"([^\s]+)", original)
tokens, labels = [], []

for word in words:
tokens.append(word.group())
match = sample.expected_results[word.group()]
labels.append(match.entity if match is not None else "O")

assert len([label for label in labels if label != "O"]) == len(
sample.expected_results
)

if test_case:
test_case_items = test_case.split()
norm_test_case_items = test_case.lower().split()
norm_original_items = original.lower().split()
temp_len = 0
for jdx, item in enumerate(norm_test_case_items):
if item in norm_original_items and jdx >= norm_original_items.index(item):
oitem_index = norm_original_items.index(item)
j = sample.expected_results.predictions[oitem_index + temp_len]
if temp_id != j.doc_id and jdx == 0:
text += f"{j.doc_name}\n\n"
temp_id = j.doc_id
text += f"{test_case_items[jdx]}{delimiter}{j.pos_tag}{delimiter}{j.chunk_tag}{delimiter}{j.entity}\n"
norm_original_items.pop(oitem_index)
temp_len += 1
else:
o_item = norm_original_items[jdx - temp_len]
letters_count = len(set(o_item) - set(item))
if len(norm_test_case_items) == len(
norm_original_items
) or letters_count < len(o_item):
tl = sample.expected_results.predictions[jdx]
text += f"{test_case_items[jdx]}{delimiter}{tl.pos_tag}{delimiter}{tl.chunk_tag}{delimiter}{tl.entity}\n"
else:
text += f"{test_case_items[jdx]}{delimiter}O{delimiter}O{delimiter}O\n"
text += "\n"
test_case_words = re.finditer(r"([^\s]+)", test_case)
test_case_tokens, test_case_labels = [], []

else:
for j in sample.expected_results.predictions:
if temp_id != j.doc_id:
text += f"{j.doc_name}\n\n"
temp_id = j.doc_id
text += f"{j.span.word}{delimiter}{j.pos_tag}{delimiter}{j.chunk_tag}{delimiter}{j.entity}\n"
text += "\n"
return text, temp_id
for word in test_case_words:
test_case_tokens.append(word.group())
match = sample.actual_results[word.group()]
test_case_labels.append(match.entity if match is not None else "O")

assert len([label for label in test_case_labels if label != "O"]) == len(
sample.actual_results
)
return tokens, labels, test_case_tokens, test_case_labels
return tokens, labels, [], []
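A minimal sketch of the tokenize-and-label loop above, with a plain dict standing in for the NEROutput string lookup (the sentence and entities are hypothetical):

import re

expected = {"Jean": "B-PER", "Paris": "B-LOC"}  # stand-in for sample.expected_results

tokens, labels = [], []
for word in re.finditer(r"([^\s]+)", "Jean lives in Paris"):
    tokens.append(word.group())
    labels.append(expected.get(word.group(), "O"))  # "O" when no entity matches

assert tokens == ["Jean", "lives", "in", "Paris"]
assert labels == ["B-PER", "O", "O", "B-LOC"]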

@staticmethod
def to_conll(sample: Sample, temp_id: int = None) -> Tuple[str, int]:
def to_conll(
sample: Sample, writing_mode: str = "ignore"
) -> Union[str, Tuple[str, str]]:
"""Converts a custom type to a CoNLL string.

Args:
sample (Sample):
The input sample containing the `NEROutput` object to convert.
temp_id (int):
A temporary ID to use for grouping entities by document.
writing_mode (str):
what to do with the actual results if a test case is present:
- ignore: simply ignores the actual_results
- append: appends the formatted actual_results to the original ones
- separate: returns a formatted string for the original sentence and one for
the perturbed sentence

Returns:
The CoNLL string representation of the custom type.
"""
text = ""
assert writing_mode in [
"ignore",
"append",
"separate",
], f"writing_mode: {writing_mode} not supported."

text, text_perturbed = "", ""
test_case = sample.test_case
original = sample.original
if test_case:
test_case_items = test_case.split()
norm_test_case_items = test_case.lower().split()
norm_original_items = original.lower().split()
temp_len = 0
for jdx, item in enumerate(norm_test_case_items):
try:
if item in norm_original_items and jdx >= norm_original_items.index(
item
):
oitem_index = norm_original_items.index(item)
j = sample.expected_results.predictions[oitem_index + temp_len]
if temp_id != j.doc_id and jdx == 0:
text += f"{j.doc_name}\n\n"
temp_id = j.doc_id
text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
norm_original_items.pop(oitem_index)
temp_len += 1
else:
o_item = sample.expected_results.predictions[jdx].span.word
letters_count = len(set(item) - set(o_item))
if (
len(norm_test_case_items) == len(original.lower().split())
or letters_count < 2
):
tl = sample.expected_results.predictions[jdx]
text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
else:
text += f"{test_case_items[jdx]} O O O\n"
except IndexError:
text += f"{test_case_items[jdx]} O O O\n"
text += "\n"

else:
for j in sample.expected_results.predictions:
if temp_id != j.doc_id:
text += f"{j.doc_name}\n\n"
temp_id = j.doc_id
text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
text += "\n"
return text, temp_id
words = re.finditer(r"([^\s]+)", original)

for word in words:
token = word.group()
match = sample.expected_results[word.group()]
label = match.entity if match is not None else "O"
text += f"{token} -X- -X- {label}\n"

if test_case and writing_mode != "ignore":
words = re.finditer(r"([^\s]+)", test_case)

for word in words:
token = word.group()
match = sample.actual_results[word.group()]
label = match.entity if match is not None else "O"
if writing_mode == "append":
text += f"{token} -X- -X- {label}\n"
elif writing_mode == "separate":
text_perturbed += f"{token} -X- -X- {label}\n"

if writing_mode == "separate":
return text, text_perturbed
return text
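A compressed sketch of the three writing_mode behaviors, using hypothetical (token, label) pairs in place of a full Sample:

from typing import List, Tuple, Union

def conll_sketch(
    orig: List[Tuple[str, str]],
    perturbed: List[Tuple[str, str]],
    writing_mode: str = "ignore",
) -> Union[str, Tuple[str, str]]:
    # Mirrors to_conll's branching on writing_mode
    assert writing_mode in ["ignore", "append", "separate"]
    text = "".join(f"{tok} -X- -X- {lab}\n" for tok, lab in orig)
    text_perturbed = "".join(f"{tok} -X- -X- {lab}\n" for tok, lab in perturbed)
    if writing_mode == "append":
        return text + text_perturbed
    if writing_mode == "separate":
        return text, text_perturbed
    return text  # "ignore": the perturbed sentence is dropped

print(conll_sketch([("Jean", "B-PER")], [("Marie", "B-PER")], writing_mode="append"))
# Jean -X- -X- B-PER
# Marie -X- -X- B-PER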
langtest/utils/custom_types/output.py (16 additions, 20 deletions)
@@ -5,9 +5,7 @@


class SequenceClassificationOutput(BaseModel):
"""
Output model for text classification tasks.
"""
"""Output model for text classification tasks."""

predictions: List[SequenceLabel]

@@ -68,9 +66,7 @@ def __str__(self) -> str:


class NEROutput(BaseModel):
"""
Output model for NER tasks.
"""
"""Output model for NER tasks."""

predictions: List[NERPrediction]

@@ -84,11 +80,16 @@ def __len__(self):
return len(self.predictions)

def __getitem__(
self, item: Union[Span, int]
self, item: Union[Span, int, str]
) -> Optional[Union[List[NERPrediction], NERPrediction]]:
""""""
if isinstance(item, int):
return self.predictions[item]
elif isinstance(item, str):
for pred in self.predictions:
if pred.span.word == item:
return pred
return None
elif isinstance(item, Span):
for prediction in self.predictions:
if prediction.span == item:
@@ -98,8 +99,7 @@ def __getitem__(
return [self.predictions[i] for i in range(item.start, item.stop)]

def to_str_list(self) -> str:
"""
Converts predictions into a list of strings.
"""Converts predictions into a list of strings.

Returns:
List[str]: predictions in form of a list of strings.
@@ -122,28 +122,24 @@ def __eq__(self, other: "NEROutput"):


class TranslationOutput(BaseModel):
"""
Output model for translation tasks.
"""
"""Output model for translation tasks."""

translation_text: str # Changed from List[str] to str

def to_str_list(self) -> List[str]:
"""
Returns the translation_text as a list of strings.
"""Formatting helper

Returns:
List[str]: the translation_text as a list of strings.
"""
return [self.translation_text] # Wrap self.translation_text in a list

def __str__(self):
"""
String representation of TranslationOutput.
"""
"""String representation of TranslationOutput."""
return self.translation_text # Return translation_text directly

def __eq__(self, other):
"""
Equality comparison method.
"""
"""Equality comparison method."""
if isinstance(other, TranslationOutput):
return self.translation_text == other.translation_text
if isinstance(other, list):
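The str branch added to __getitem__ is what the formatters above rely on (sample.expected_results[word.group()]). A self-contained sketch of that lookup, with lightweight stand-ins for Span and NERPrediction:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Span:
    word: str

@dataclass
class Pred:  # stand-in for NERPrediction
    span: Span
    entity: str

def lookup(predictions: List[Pred], item: str) -> Optional[Pred]:
    # Mirrors NEROutput.__getitem__ for str keys: first span-word match wins
    for pred in predictions:
        if pred.span.word == item:
            return pred
    return None

preds = [Pred(Span("Jean"), "B-PER"), Pred(Span("Paris"), "B-LOC")]
assert lookup(preds, "Paris").entity == "B-LOC"
assert lookup(preds, "Berlin") is None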
tests/test_augmentation.py (6 additions, 6 deletions)
@@ -221,16 +221,16 @@ def test_fix(self):
"My -X- -X- O",
"name -X- -X- O",
"is -X- -X- O",
"Jean NN NN B-PER",
"- NN NN I-PER",
"Pierre NN NN I-PER",
"Jean -X- -X- B-PER",
"- -X- -X- I-PER",
"Pierre -X- -X- I-PER",
"and -X- -X- O",
"I -X- -X- O",
"am -X- -X- O",
"from -X- -X- O",
"New NN NN B-LOC",
"York NN NN I-LOC",
"City NN NN I-LOC",
"New -X- -X- B-LOC",
"York -X- -X- I-LOC",
"City -X- -X- I-LOC",
]
generator = TemplaticAugment(
templates=["My name is {PER} and I am from {LOC}"], task="ner"