diff --git a/nlptest/nlptest.py b/nlptest/nlptest.py index 466981cea..8643fb27e 100644 --- a/nlptest/nlptest.py +++ b/nlptest/nlptest.py @@ -112,6 +112,26 @@ def __init__( data.get('subset', None) ) if data is not None else None + elif type(data) is dict and hub == "johnsnowlabs" and task == "text-classification": + self.data = HuggingFaceDataset(data['name']).load_data( + data.get('feature_column', 'text'), + data.get('target_column', 'label'), + data.get('split', 'test'), + data.get('subset', None) + ) if data is not None else None + + elif type(data) is dict and hub == "spacy" and task == "text-classification": + self.data = HuggingFaceDataset(data['name']).load_data( + data.get('feature_column', 'text'), + data.get('target_column', 'label'), + data.get('split', 'test'), + data.get('subset', None) + ) if data is not None else None + if model == 'textcat_imdb': + model = resource_filename("nlptest", "data/textcat_imdb") + else: + raise ValueError(f"Unsupported model '{model}'! Only 'textcat_imdb' is supported.") + elif data is None and (task, model, hub) not in self.DEFAULTS_DATASET.keys(): raise ValueError("You haven't specified any value for the parameter 'data' and the configuration you " "passed is not among the default ones. You need to either specify the parameter 'data' "