Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSK-1191 Added support of scan Transformation Function to the Hub #1791

Merged
merged 7 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions docs/knowledge/catalogs/transformation-function-catalog/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ Transformation functions
- **Text transformation function**

- :func:`~.giskard.functions.transformation.keyboard_typo_transformation`
- :func:`~.giskard.functions.transformation.uppercase_transformation`
- :func:`~.giskard.functions.transformation.lowercase_transformation`
- :func:`~.giskard.functions.transformation.strip_punctuation`
- :func:`~.giskard.functions.transformation.text_uppercase`
- :func:`~.giskard.functions.transformation.text_lowercase`
- :func:`~.giskard.functions.transformation.text_title_case`
- :func:`~.giskard.functions.transformation.text_typo`
- :func:`~.giskard.functions.transformation.text_typo_from_ocr`
- :func:`~.giskard.functions.transformation.text_punctuation_removal`
- :func:`~.giskard.functions.transformation.text_accent_removal`
- :func:`~.giskard.functions.transformation.text_gender_switch`
- :func:`~.giskard.functions.transformation.text_number_to_word`
- :func:`~.giskard.functions.transformation.text_religion_switch`
- :func:`~.giskard.functions.transformation.text_nationality_switch`
- :func:`~.giskard.functions.transformation.text_typo_from_speech`
16 changes: 12 additions & 4 deletions docs/reference/transformation-functions/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@ Textual transformation functions
--------------------------------

.. autofunction:: giskard.functions.transformation.keyboard_typo_transformation
.. autofunction:: giskard.functions.transformation.uppercase_transformation
.. autofunction:: giskard.functions.transformation.lowercase_transformation
.. autofunction:: giskard.functions.transformation.strip_punctuation
.. autofunction:: giskard.functions.transformation.change_writing_style
.. autofunction:: giskard.functions.transformation.text_uppercase
.. autofunction:: giskard.functions.transformation.text_lowercase
.. autofunction:: giskard.functions.transformation.text_title_case
.. autofunction:: giskard.functions.transformation.text_typo
.. autofunction:: giskard.functions.transformation.text_typo_from_ocr
.. autofunction:: giskard.functions.transformation.text_punctuation_removal
.. autofunction:: giskard.functions.transformation.text_accent_removal
.. autofunction:: giskard.functions.transformation.text_gender_switch
.. autofunction:: giskard.functions.transformation.text_number_to_word
.. autofunction:: giskard.functions.transformation.text_religion_switch
.. autofunction:: giskard.functions.transformation.text_nationality_switch
.. autofunction:: giskard.functions.transformation.text_typo_from_speech


Special transformations used by the scan
Expand Down
107 changes: 78 additions & 29 deletions giskard/functions/transformation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import random
import re
import string

import numpy as np
import pandas as pd
from scipy.stats import median_abs_deviation

from giskard.registry.transformation_function import transformation_function

from ..datasets import Dataset
from ..llm import get_default_client
from ..registry.transformation_function import transformation_function

nearbykeys = {
"a": ["q", "w", "s", "x", "z"],
Expand Down Expand Up @@ -71,39 +69,90 @@ def keyboard_typo_transformation(text: str, rate: float = 0.1) -> str:
return " ".join(words)


@transformation_function(name="To uppercase", tags=["text"], cell_level=True)
def uppercase_transformation(text: str) -> str:
    """
    Return ``text`` converted to uppercase, or NaN when ``text`` is null
    """
    # Null cells (None / NaN) have no ``upper`` to call, so they map to NaN
    if pd.isnull(text):
        return np.nan
    return text.upper()
@transformation_function(name="Transform to uppercase", row_level=False)
def text_uppercase(data: pd.DataFrame, column: str):
    """
    Run the scan's ``TextUppercase`` transformation on ``column`` of ``data``

    :param data: dataframe handed to ``TextUppercase.execute``
    :param column: column name passed to the ``TextUppercase`` constructor
    :return: the result of ``TextUppercase(column).execute(data)``
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextUppercase

    uppercase = TextUppercase(column)
    return uppercase.execute(data)

@transformation_function(name="To lowercase", tags=["text"], cell_level=True)
def lowercase_transformation(text: str) -> str:
    """
    Return ``text`` converted to lowercase, or NaN when ``text`` is null
    """
    # Null cells (None / NaN) have no ``lower`` to call, so they map to NaN
    if pd.isnull(text):
        return np.nan
    return text.lower()

@transformation_function(name="Transform to lowercase", row_level=False)
def text_lowercase(data: pd.DataFrame, column: str):
from ..scanner.robustness.text_transformations import TextLowercase

@transformation_function(name="Strip punctuation", tags=["text"], cell_level=True)
def strip_punctuation(text: str) -> str:
"""
Remove all punctuation symbols (e.g., ., !, ?) from the text of the column 'column_name'
"""
if pd.isnull(text):
return text
return TextLowercase(column).execute(data)


@transformation_function(name="Transform to title case", row_level=False)
def text_title_case(data: pd.DataFrame, column: str):
    """
    Run the scan's ``TextTitleCase`` transformation on ``column`` of ``data``

    :param data: dataframe handed to ``TextTitleCase.execute``
    :param column: column name passed to the ``TextTitleCase`` constructor
    :return: the result of ``TextTitleCase(column).execute(data)``
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextTitleCase

    title_case = TextTitleCase(column)
    return title_case.execute(data)


@transformation_function(name="Add typos", row_level=False)
def text_typo(data: pd.DataFrame, column: str, rate: float = 0.05, min_length: int = 10, rng_seed: int = 1729):
    """
    Run the scan's ``TextTypoTransformation`` on ``column`` of ``data``

    :param data: dataframe handed to ``TextTypoTransformation.execute``
    :param column: first constructor argument (name of the column to transform)
    :param rate: second constructor argument; presumably the typo rate (to confirm against the class)
    :param min_length: third constructor argument; presumably a minimum text length (to confirm)
    :param rng_seed: fourth constructor argument; presumably the random seed (to confirm)
    :return: the result of ``execute(data)`` on the configured transformation
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextTypoTransformation

    # Arguments are forwarded positionally, in the same order as the original call
    typo = TextTypoTransformation(column, rate, min_length, rng_seed)
    return typo.execute(data)


@transformation_function(name="Add typos from OCR", row_level=False)
def text_typo_from_ocr(data: pd.DataFrame, column: str, rate: float = 0.05, min_length: int = 10, rng_seed: int = 1729):
    """
    Run the scan's ``TextFromOCRTypoTransformation`` on ``column`` of ``data``

    :param data: dataframe handed to ``TextFromOCRTypoTransformation.execute``
    :param column: first constructor argument (name of the column to transform)
    :param rate: second constructor argument; presumably the typo rate (to confirm against the class)
    :param min_length: third constructor argument; presumably a minimum text length (to confirm)
    :param rng_seed: fourth constructor argument; presumably the random seed (to confirm)
    :return: the result of ``execute(data)`` on the configured transformation
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextFromOCRTypoTransformation

    # Arguments are forwarded positionally, in the same order as the original call
    ocr_typo = TextFromOCRTypoTransformation(column, rate, min_length, rng_seed)
    return ocr_typo.execute(data)


@transformation_function(name="Punctuation Removal", row_level=False)
def text_punctuation_removal(data: pd.DataFrame, column: str):
    """
    Run the scan's ``TextPunctuationRemovalTransformation`` on ``column`` of ``data``

    :param data: dataframe handed to ``TextPunctuationRemovalTransformation.execute``
    :param column: column name passed to the transformation constructor
    :return: the result of ``execute(data)`` on the configured transformation
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextPunctuationRemovalTransformation

    punctuation_removal = TextPunctuationRemovalTransformation(column)
    return punctuation_removal.execute(data)


@transformation_function(name="Accent Removal", row_level=False)
def text_accent_removal(data: pd.DataFrame, column: str, rate: float = 1.0, rng_seed: int = 1729):
    """
    Run the scan's ``TextAccentRemovalTransformation`` on ``column`` of ``data``

    :param data: dataframe handed to ``TextAccentRemovalTransformation.execute``
    :param column: first constructor argument (name of the column to transform)
    :param rate: second constructor argument; presumably the removal rate (to confirm against the class)
    :param rng_seed: third constructor argument; presumably the random seed (to confirm)
    :return: the result of ``execute(data)`` on the configured transformation
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextAccentRemovalTransformation

    # Arguments are forwarded positionally, in the same order as the original call
    accent_removal = TextAccentRemovalTransformation(column, rate, rng_seed)
    return accent_removal.execute(data)


@transformation_function(name="Switch Gender", row_level=False, needs_dataset=True)
def text_gender_switch(dataset: Dataset, column: str):
    """
    Run the scan's ``TextGenderTransformation`` on ``column`` of ``dataset``

    Registered with ``needs_dataset=True``, so the first argument is a ``Dataset``
    rather than a dataframe.

    :param dataset: dataset handed to ``TextGenderTransformation.execute``
    :param column: column name passed to the transformation constructor
    :return: the result of ``execute(dataset)`` on the configured transformation
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextGenderTransformation

    gender_switch = TextGenderTransformation(column)
    return gender_switch.execute(dataset)


@transformation_function(name="Transform numbers to words", row_level=False, needs_dataset=True)
def text_number_to_word(dataset: Dataset, column: str):
    """
    Run the scan's ``TextNumberToWordTransformation`` on ``column`` of ``dataset``

    Registered with ``needs_dataset=True``, so the first argument is a ``Dataset``
    rather than a dataframe.

    :param dataset: dataset handed to ``TextNumberToWordTransformation.execute``
    :param column: column name passed to the transformation constructor
    :return: the result of ``execute(dataset)`` on the configured transformation
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextNumberToWordTransformation

    number_to_word = TextNumberToWordTransformation(column)
    return number_to_word.execute(dataset)


@transformation_function(name="Switch Religion", row_level=False, needs_dataset=True)
def text_religion_switch(dataset: Dataset, column: str):
    """
    Run the scan's ``TextReligionTransformation`` on ``column`` of ``dataset``

    Registered with ``needs_dataset=True``, so the first argument is a ``Dataset``
    rather than a dataframe.

    :param dataset: dataset handed to ``TextReligionTransformation.execute``
    :param column: column name passed to the transformation constructor
    :return: the result of ``execute(dataset)`` on the configured transformation
    """
    # Function-scope import, presumably to avoid an import cycle with the scanner package (to confirm)
    from ..scanner.robustness.text_transformations import TextReligionTransformation

    religion_switch = TextReligionTransformation(column)
    return religion_switch.execute(dataset)


@transformation_function(
name="Switch countries from high- to low-income and vice versa", row_level=False, needs_dataset=True
)
def text_nationality_switch(dataset: Dataset, column: str):
from ..scanner.robustness.text_transformations import TextNationalityTransformation

split_urls_from_text = gruber.split(text)
return TextNationalityTransformation(column).execute(dataset)

# The non-URLs are always even-numbered entries in the list and the URLs are odd-numbered.
for i in range(0, len(split_urls_from_text), 2):
split_urls_from_text[i] = split_urls_from_text[i].translate(str.maketrans("", "", string.punctuation))

stripped_text = "".join(split_urls_from_text)
@transformation_function(name="Add text from speech typos", row_level=False, needs_dataset=True)
def text_typo_from_speech(dataset: Dataset, column: str, rng_seed: int = 1729, min_length: int = 10):
from ..scanner.robustness.text_transformations import TextFromSpeechTypoTransformation

return stripped_text
return TextFromSpeechTypoTransformation(column, rng_seed, min_length).execute(dataset)


@transformation_function(name="Change writing style", row_level=False, tags=["text"])
Expand Down
20 changes: 15 additions & 5 deletions giskard/registry/transformation_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class TransformationFunction(RegistryArtifact[DatasetProcessFunctionMeta]):
def _get_name(cls) -> str:
return "transformations"

def __init__(self, func: Optional[TransformationFunctionType], row_level=True, cell_level=False):
def __init__(
self, func: Optional[TransformationFunctionType], row_level=True, cell_level=False, needs_dataset=False
):
self.func = func
self.row_level = row_level
self.cell_level = cell_level
self.needs_dataset = needs_dataset

test_uuid = get_object_uuid(func)
meta = tests_registry.get_test(test_uuid)
Expand Down Expand Up @@ -99,6 +102,7 @@ def transformation_function(
cell_level=False,
name=None,
tags: Optional[List[str]] = None,
needs_dataset=False,
):
"""
Decorator that registers a function as a transformation function and returns a TransformationFunction instance.
Expand Down Expand Up @@ -130,19 +134,25 @@ def inner(func: Union[TransformationFunctionType, Type[TransformationFunction]])

if inspect.isclass(func) and issubclass(func, TransformationFunction):
return func
return _wrap_transformation_function(func, row_level, cell_level)()
return _wrap_transformation_function(func, row_level, cell_level, needs_dataset)()

if callable(_fn):
return functools.wraps(_fn)(inner(_fn))
else:
return inner


def _wrap_transformation_function(original: Callable, row_level: bool, cell_level: bool):
transformation_fn = functools.wraps(original)(TransformationFunction(original, row_level, cell_level))
def _wrap_transformation_function(original: Callable, row_level: bool, cell_level: bool, needs_dataset: bool):
transformation_fn = functools.wraps(original)(
TransformationFunction(original, row_level, cell_level, needs_dataset)
)

if not cell_level:
validate_arg_type(transformation_fn, 0, pd.Series if row_level else pd.DataFrame)
from ..datasets import Dataset

validate_arg_type(
transformation_fn, 0, pd.Series if row_level else (Dataset if needs_dataset else pd.DataFrame)
)

drop_arg(transformation_fn, 0)

Expand Down
8 changes: 3 additions & 5 deletions giskard/scanner/robustness/text_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
class TextTransformation(TransformationFunction):
name: str

def __init__(self, column):
super().__init__(None, row_level=False, cell_level=False)
def __init__(self, column, needs_dataset=False):
super().__init__(None, row_level=False, cell_level=False, needs_dataset=needs_dataset)
self.column = column
self.meta = DatasetProcessFunctionMeta(type="TRANSFORMATION")
self.meta.uuid = get_object_uuid(self)
Expand Down Expand Up @@ -212,10 +212,8 @@ def make_perturbation(self, text):


class TextLanguageBasedTransformation(TextTransformation):
needs_dataset = True

def __init__(self, column, rng_seed=1729):
super().__init__(column)
super().__init__(column, needs_dataset=True)
self._lang_dictionary = dict()
self._load_dictionaries()
self.rng = np.random.default_rng(seed=rng_seed)
Expand Down
32 changes: 32 additions & 0 deletions tests/functions/test_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pandas as pd

from giskard import Dataset


def _dataset_from_dict(data):
    """Wrap ``data`` in a pandas DataFrame and return it as a target-less giskard ``Dataset``."""
    frame = pd.DataFrame(data)
    return Dataset(frame, target=None)


def test_gender_transformation():
    """Check ``text_gender_switch`` output on a handful of English and French sentences."""
    sentences = [
        "We just got this and my daughter loves it. She has played it several times.",
        "It did not work.",
        "“They pushed the feature just 1 minute before the user test”",
        "He hates doing user tests! for his company",
        "Il déteste faire les tests en langue française",
    ]
    # Expected outputs, index-aligned with ``sentences``; compared case-insensitively below
    expected = [
        "We just got this and my son loves it. He has played it several times.",
        "It did not work.",
        "“They pushed the feature just 1 minute before the user test”",
        "She hates doing user tests! for her company",
        "Elle déteste faire les tests en langue français",
    ]
    dataset = _dataset_from_dict({"text": sentences})

    from giskard.functions.transformation import text_gender_switch

    transformed = dataset.transform(text_gender_switch(column="text"))

    transformed_text = transformed.df.text.str.lower().values
    for position, sentence in enumerate(expected):
        assert transformed_text[position] == sentence.lower()
Loading