Skip to content

Commit

Permalink
prepare for rebase on main
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed May 8, 2024
1 parent 50e44f5 commit b2203d4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
25 changes: 24 additions & 1 deletion src/unitxt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import tempfile
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, List, Mapping, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import pandas as pd
from datasets import load_dataset as hf_load_dataset
Expand Down Expand Up @@ -474,3 +474,26 @@ def process(self):
return FixedFusion(
origins=self.sources, max_instances_per_origin=self.get_limit()
).process()


class LoadFromDictionary(Loader):
"""Allows loading data from dictionary of constants.
The loader can be used, for example, when debugging or working with small datasets.
Attributes:
data (Dict[str, List[Dict[str, Any]]]): a dictionary of constants from which the data will be loaded
Examples:
data = {
"train": {"input": "SomeInput1", "output": "SomeResult1"},
"test": {"input": "SomeInput2", "output": "SomeResult2"},
}
loader = LoadFromDictionary(data=data)
multi_stream = loader.process()
"""

data: Dict[str, List[Dict[str, Any]]]

def process(self) -> MultiStream:
return MultiStream.from_iterables(self.data)
25 changes: 24 additions & 1 deletion tests/library/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

import ibm_boto3
import pandas as pd
from unitxt.loaders import LoadCSV, LoadFromIBMCloud, LoadHF, MultipleSourceLoader
from unitxt.loaders import (
LoadCSV,
LoadFromDictionary,
LoadFromIBMCloud,
LoadHF,
MultipleSourceLoader,
)
from unitxt.logging_utils import get_logger

from tests.utils import UnitxtTestCase
Expand Down Expand Up @@ -228,3 +234,20 @@ def test_multiple_source_loader(self):
ms = loader()
assert len(dfs["test"]) == len(list(ms["test"]))
assert len(dfs["train"]) == len(list(ms["train"]))

def test_load_from_dictionary(self):
data = {
"train": [
{"input": "Input1", "output": "Result1"},
{"input": "Input2", "output": "Result2"},
],
"test": [
{"input": "Input3", "output": "Result3"},
],
}
loader = LoadFromDictionary(data=data)
streams = loader.process()

for split, instances in data.items():
for original_instance, stream_instance in zip(instances, streams[split]):
self.assertEqual(original_instance, stream_instance)

0 comments on commit b2203d4

Please sign in to comment.