Commit 7afbd5c

Update datasets (#309)

gabegma committed Dec 20, 2022
1 parent 8c1007e commit 7afbd5c
Showing 8 changed files with 60 additions and 33 deletions.
4 changes: 2 additions & 2 deletions azimuth/__init__.py
@@ -1,12 +1,12 @@
 from pathlib import Path

-from datasets import set_progress_bar_enabled
+from datasets import disable_progress_bar
 from datasets.utils.logging import set_verbosity_error

 from azimuth.utils.logs import set_logger_config

 PROJECT_ROOT = str(Path(__file__).parent)

 set_verbosity_error()
-set_progress_bar_enabled(False)
+disable_progress_bar()
 set_logger_config()
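The `datasets` 2.x line replaces `set_progress_bar_enabled(False)` with `disable_progress_bar()`, which is what this hunk adopts. If support for both major versions were needed (an assumption; the commit simply moves to the 2.x API), a minimal compatibility shim could look like this sketch:

```python
try:
    # datasets >= 2.0 exposes disable_progress_bar() directly.
    from datasets import disable_progress_bar
except ImportError:
    # Fallback for datasets 1.x, which only had set_progress_bar_enabled().
    from datasets import set_progress_bar_enabled

    def disable_progress_bar() -> None:
        set_progress_bar_enabled(False)

disable_progress_bar()
```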
3 changes: 1 addition & 2 deletions azimuth/routers/v1/utterances.py
@@ -356,9 +356,8 @@ def get_similar(
     item_scores = dict(
         source_ds.select([index])[f"neighbors_{neighbors_dataset_split_name}"][0][:limit]
     )
-    # NOTE: idx may be float in the HF Dataset.
     items: Dict[int, Dict] = {
-        idx: neighbors_ds_with_class_names[int(idx)] for idx in item_scores.keys()  # type: ignore
+        idx: neighbors_ds_with_class_names[int(idx)] for idx in item_scores.keys()
     }
     # Build utterances from `items`
     similar_utterances = [
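The removed NOTE explained why the `int(idx)` cast exists: in the older `datasets` release, the stored neighbor indices could deserialize as floats. The cast stays, but the `# type: ignore` is no longer needed. A toy illustration of the pattern, with hypothetical stand-in values:

```python
# Hypothetical stand-ins for the objects in the hunk above.
item_scores = {12.0: 0.98, 7.0: 0.91}  # neighbor index -> similarity score
neighbors_ds_with_class_names = {7: {"utterance": "hello"}, 12: {"utterance": "hi"}}

# int(idx) guards against float keys coming back from the HF dataset.
items = {idx: neighbors_ds_with_class_names[int(idx)] for idx in item_scores}
```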
44 changes: 33 additions & 11 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -36,7 +36,7 @@ spacy = "3.1.5"
en_core_web_sm = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0.tar.gz#egg=en_core_web_sm"}
fr_core_news_md = {url = "https://github.com/explosion/spacy-models/releases/download/fr_core_news_md-3.1.0/fr_core_news_md-3.1.0.tar.gz#egg=fr_core_news_md"}
nlpaug = "1.1.10"
datasets = "1.16.1"
datasets = "2.1.0"

# Misc
filelock = "^3.0.12"
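This pin is the change that drives everything else in the commit: `datasets` moves from 1.16.1 to 2.1.0, which renames the progress-bar helper and, per the removed TODO below, fixes the `add_column` length bug. A trivial post-install sanity check (an assumed snippet, not part of the commit):

```python
# Verify the environment picked up the pinned datasets version after `poetry install`.
import datasets

assert datasets.__version__ == "2.1.0", f"unexpected datasets version: {datasets.__version__}"
```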
5 changes: 1 addition & 4 deletions tests/test_dataset_manager.py
@@ -396,11 +396,8 @@ def test_custom_persistent_id(simple_text_config, a_text_dataset):
         dataset_split=a_text_dataset,
     )

-    # TODO Bug that will be fixed when updating datasets.
-    # 872 should be replaced by len(a_text_dataset).
-    dataset_length = 872
     a_text_dataset_new_col = a_text_dataset.add_column(
-        name="yukongold", column=[1] * dataset_length
+        name="yukongold", column=[1] * len(a_text_dataset)
     )
     with pytest.raises(ValueError, match="need to be unique"):
         DatasetSplitManager(
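Per the removed TODO, the `datasets` upgrade fixes the length check in `Dataset.add_column`, so the hard-coded `872` can be replaced by `len(a_text_dataset)`. A self-contained illustration on a toy dataset (not the test fixtures above):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"]})
# The new column must have exactly len(ds) entries.
ds = ds.add_column(name="yukongold", column=[1] * len(ds))
assert ds.column_names == ["text", "yukongold"]
```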
10 changes: 7 additions & 3 deletions tests/test_loading_resources.py
@@ -3,7 +3,7 @@
 # in the root directory of this source tree.
 import string
 from dataclasses import dataclass
-from typing import Callable, List
+from typing import Callable, Dict, List

 import numpy as np
 import torch
@@ -73,9 +73,13 @@ def load_intent_data(train_path, test_path, python_loader) -> Dataset:
     return load_dataset(python_loader, data_files={"train": train_path, "test": test_path})


-def load_file_dataset(*args, azimuth_config, **kwargs):
+def load_file_dataset(data_files: Dict[str, str], azimuth_config):
     # Load a file dataset and cast the label column as a ClassLabel.
-    ds_dict = load_dataset(*args, **kwargs)
+    # Train and test need to be loaded separately because they don't always share the same
+    # columns. Train sometimes doesn't have predictions; HF will complain if we load both together.
+    ds_dict = load_dataset("csv", data_files={"train": data_files["train"]})
+    ds_dict_test = load_dataset("csv", data_files={"test": data_files["test"]})
+    ds_dict.update(ds_dict_test)
     features: Features = [v.features for v in ds_dict.values()][0]
     if not isinstance(features[azimuth_config.columns.label], ClassLabel):
         # Get all classes from both sets and apply the same mapping to every dataset.
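The rewritten loader pins the `csv` builder and loads each split with its own `load_dataset` call, then merges the resulting dicts so mismatched columns across splits don't trip Hugging Face's schema check. A standalone sketch of the pattern, with hypothetical file paths:

```python
from datasets import load_dataset

# Each split gets its own call so their schemas are inferred independently.
ds_dict = load_dataset("csv", data_files={"train": "train.csv"})
ds_dict.update(load_dataset("csv", data_files={"test": "test.csv"}))  # DatasetDict is a dict
print(ds_dict)  # "train" and "test" splits, possibly with different columns
```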
24 changes: 15 additions & 9 deletions tests/test_modules/test_dataset_analysis/test_dataset_warnings.py
@@ -7,6 +7,7 @@

 import numpy as np
 import pytest
+from datasets import ClassLabel

 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.modules.dataset_analysis.dataset_warnings import DatasetWarningsModule
@@ -86,17 +87,22 @@ def test_compute_on_dataset_split(

 def add_rejection_class(mod, monkeypatch):
     # Adding a rejection class
-    eval_dm: DatasetSplitManager = mod.get_dataset_split_manager(DatasetSplitName.eval)
-    train_dm: DatasetSplitManager = mod.get_dataset_split_manager(DatasetSplitName.train)
-    eval_dm._base_dataset_split.features["label"].names.append("NO_INTENT")  # Should be 2
-    train_dm._base_dataset_split.features["label"].names.append("NO_INTENT")  # Should be 2
-    eval_dm._base_dataset_split = eval_dm._base_dataset_split.map(
-        lambda u, i: {"label": 2 if i % 10 == 0 else u["label"]}, with_indices=True
-    )
     dms = {
-        DatasetSplitName.eval: eval_dm,
-        DatasetSplitName.train: train_dm,
+        DatasetSplitName.eval: mod.get_dataset_split_manager(DatasetSplitName.eval),
+        DatasetSplitName.train: mod.get_dataset_split_manager(DatasetSplitName.train),
     }

+    existing_classes = dms[DatasetSplitName.eval].get_class_names(labels_only=True)
+    class_label = ClassLabel(num_classes=3, names=existing_classes + ["NO_INTENT"])
+    for dm in dms.values():
+        dm._base_dataset_split.features["label"] = class_label
+
+    dms[DatasetSplitName.eval]._base_dataset_split = dms[
+        DatasetSplitName.eval
+    ]._base_dataset_split.map(
+        lambda u, i: {"label": 2 if i % 10 == 0 else u["label"]}, with_indices=True
+    )
+
     # Modifying the config reset the ArtifactManager, which we do not want.
     monkeypatch.setattr(mod, "get_dataset_split_manager", lambda s: dms[s])
     mod.config.rejection_class = "NO_INTENT"
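The new code replaces the old in-place `names.append(...)` mutation with a fresh `ClassLabel` feature built from the existing class names. A self-contained sketch of the same idea on a toy dataset, using the public `Dataset.cast` route rather than the test's direct `features["label"]` assignment:

```python
from datasets import ClassLabel, Dataset, Features, Value

ds = Dataset.from_dict({"text": ["hi", "bye"], "label": [0, 1]})
# Extend the label space with a rejection class, mirroring the hunk above.
new_label = ClassLabel(num_classes=3, names=["greet", "leave", "NO_INTENT"])
ds = ds.cast(Features({"text": Value("string"), "label": new_label}))
# Relabel every 10th row as the rejection class, as the test's map() does.
ds = ds.map(lambda u, i: {"label": 2 if i % 10 == 0 else u["label"]}, with_indices=True)
assert ds.features["label"].names[2] == "NO_INTENT"
```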
1 change: 0 additions & 1 deletion tests/utils.py
@@ -79,7 +79,6 @@ def file_based_ds_from_paths(
 ):
     return {
         "class_name": "tests.test_loading_resources.load_file_dataset",
-        "args": ["csv"],
         "kwargs": {
             "data_files": {
                 "train": train,
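Since `load_file_dataset` no longer takes `*args`, the positional `"csv"` argument disappears from the test config. The resulting entry (hypothetical paths) maps directly onto the new `load_file_dataset(data_files, azimuth_config)` signature:

```python
# Hypothetical instance of the loader config built by file_based_ds_from_paths.
loader_cfg = {
    "class_name": "tests.test_loading_resources.load_file_dataset",
    # No "args": ["csv"] anymore; the csv builder is hard-coded in the loader.
    "kwargs": {"data_files": {"train": "train.csv", "test": "test.csv"}},
}
```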
