In [1]:
from datasets import (
    import_main_class,
    list_datasets,
    load_dataset,
    load_dataset_builder,
    prepare_module,
)
    
# Fetch the list of available datasets with their details, then display the
# id of the first entry alongside the entry itself.
ds_list = list_datasets(with_details=True)
ds_list[0].id, ds_list[0]

('acronym_identification',
 datasets.ObjectInfo(
 	id='acronym_identification',
 	description='Acronym identification training and development sets for the acronym identification task at SDU@AAAI-21.',
 	files=None
 ))

In [2]:
# List of fields we want to keep around from the DatasetInfo object;
# get_config_infos() copies exactly these keys out of each builder's info.
keep_info_fields = [
    "features",
    "config_name",
    "version",
    "splits",
]

def get_config_infos(name):
    """Return one dict of kept info fields per config of dataset ``name``.

    The config names are read from the ``BUILDER_CONFIGS`` of the dataset's
    builder class; when that list is empty a single unnamed config is assumed.
    Each returned dict holds the keys listed in ``keep_info_fields``, copied
    from the corresponding builder's ``info`` attributes.

    Args:
        name: dataset name passed to ``prepare_module`` and
            ``load_dataset_builder``.

    Returns:
        A list of dicts, in the same order as the builder's configs.
    """
    module_path, *_ = prepare_module(name, dataset=True)
    builder_cls = import_main_class(module_path, dataset=True)
    config_names = [builder_config.name for builder_config in builder_cls.BUILDER_CONFIGS]
    if not config_names:
        config_names = [None]

    def kept_fields(builder):
        # Only the whitelisted attributes of the builder's info are retained.
        info_fields = builder.info.__dict__
        return {field: info_fields[field] for field in keep_info_fields}

    if len(config_names) != 1:
        return [
            kept_fields(load_dataset_builder(name, config_name))
            for config_name in config_names
        ]
    # A single config is loaded without naming it explicitly.
    return [kept_fields(load_dataset_builder(name))]

In [3]:
# Collect the kept info fields for every config of two datasets and display both.
squad_configs = get_config_infos('squad')
glue_configs = get_config_infos('glue')
squad_configs, glue_configs

([{'features': {'id': Value(dtype='string', id=None),
    'title': Value(dtype='string', id=None),
    'context': Value(dtype='string', id=None),
    'question': Value(dtype='string', id=None),
    'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)},
   'config_name': 'plain_text',
   'version': 1.0.0,
   'splits': {'train': SplitInfo(name='train', num_bytes=79317110, num_examples=87599, dataset_name='squad'),
    'validation': SplitInfo(name='validation', num_bytes=10472653, num_examples=10570, dataset_name='squad')}}],
 [{'features': {'sentence': Value(dtype='string', id=None),
    'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),
    'idx': Value(dtype='int32', id=None)},
   'config_name': 'cola',
   'version': 1.0.0,
   'splits': {'test': SplitInfo(name='test', num_bytes=61049, num_examples=1063, dataset_name='glue'),
    'train': SplitInfo(name='t

In [4]:
def get_typed_features(features, ftype='string', parents=None):
    """Collect the paths of all features whose ``dtype`` equals ``ftype``.

    A feature matches when it exposes a ``dtype`` attribute equal to
    ``ftype``, or when it wraps such a feature in its ``feature`` attribute
    (as the ``Sequence`` entries in the config dumps do). Wrapped dicts and
    plain nested dicts of features are searched recursively; plain nested
    dicts used to be skipped silently.

    Args:
        features: mapping from feature name to feature description.
        ftype: dtype string to look for.
        parents: names of the enclosing features, used as the path prefix
            during recursion; leave as ``None`` for a top-level call.

    Returns:
        A list of paths, each a list of feature names from the top level
        down to the matching feature.
    """
    if parents is None:
        parents = []
    typed_features = []
    for name, feat in features.items():
        path = parents + [name]
        if isinstance(feat, dict):
            # Plain nested dict of features: descend into it.
            typed_features += get_typed_features(feat, ftype, path)
        elif hasattr(feat, 'dtype') and feat.dtype == ftype:
            typed_features.append(path)
        elif hasattr(feat, 'feature'):
            inner = feat.feature
            if hasattr(inner, 'dtype') and inner.dtype == ftype:
                # A sequence of matching values is reported at the sequence's own path.
                typed_features.append(path)
            elif isinstance(inner, dict):
                typed_features += get_typed_features(inner, ftype, path)
    return typed_features

def get_label_features(features, parents=None):
    """Collect the paths and class names of all label features.

    A label feature is any feature exposing a ``num_classes`` attribute
    (such as the ``ClassLabel`` entries in the config dumps); its ``names``
    are returned with its path. Features wrapped in a ``feature`` attribute
    are inspected as well: a wrapped label is reported at the wrapper's
    path, a wrapped dict is searched recursively, and any other wrapped
    value is ignored. Previously a wrapped non-dict value raised
    ``AttributeError`` because ``.items()`` was called on it. Plain nested
    dicts of features are searched recursively too.

    Args:
        features: mapping from feature name to feature description.
        parents: names of the enclosing features, used as the path prefix
            during recursion; leave as ``None`` for a top-level call.

    Returns:
        A list of ``(path, names)`` tuples, where ``path`` is a list of
        feature names and ``names`` is the label feature's ``names`` value.
    """
    if parents is None:
        parents = []
    label_features = []
    for name, feat in features.items():
        path = parents + [name]
        if isinstance(feat, dict):
            # Plain nested dict of features: descend into it.
            label_features += get_label_features(feat, path)
        elif hasattr(feat, 'num_classes'):
            label_features.append((path, feat.names))
        elif hasattr(feat, 'feature'):
            inner = feat.feature
            if hasattr(inner, 'num_classes'):
                # Sequence of labels (one label per element).
                label_features.append((path, inner.names))
            elif isinstance(inner, dict):
                label_features += get_label_features(inner, path)
            # Any other wrapped value (e.g. a sequence of strings) carries no labels.
    return label_features

In [5]:
# For each GLUE config, print the paths of its string-typed features and of
# its label features together with their class names.
for cfg in glue_configs:
    print("-----", cfg["config_name"])
    print("TEXT:   ", get_typed_features(cfg['features'], 'string'))
    print("LABELS: ", get_label_features(cfg['features']))

----- cola
TEXT:    [['sentence']]
LABELS:  [(['label'], ['unacceptable', 'acceptable'])]
----- sst2
TEXT:    [['sentence']]
LABELS:  [(['label'], ['negative', 'positive'])]
----- mrpc
TEXT:    [['sentence1'], ['sentence2']]
LABELS:  [(['label'], ['not_equivalent', 'equivalent'])]
----- qqp
TEXT:    [['question1'], ['question2']]
LABELS:  [(['label'], ['not_duplicate', 'duplicate'])]
----- stsb
TEXT:    [['sentence1'], ['sentence2']]
LABELS:  []
----- mnli
TEXT:    [['premise'], ['hypothesis']]
LABELS:  [(['label'], ['entailment', 'neutral', 'contradiction'])]
----- mnli_mismatched
TEXT:    [['premise'], ['hypothesis']]
LABELS:  [(['label'], ['entailment', 'neutral', 'contradiction'])]
----- mnli_matched
TEXT:    [['premise'], ['hypothesis']]
LABELS:  [(['label'], ['entailment', 'neutral', 'contradiction'])]
----- qnli
TEXT:    [['question'], ['sentence']]
LABELS:  [(['label'], ['entailment', 'not_entailment'])]
----- rte
TEXT:    [['sentence1'], ['sentence2']]
LABELS:  [(['label'], ['

In [19]:
# `torch` is not imported by the notebook's import cell, so import it here to
# keep this cell runnable on a fresh kernel (Restart & Run All).
import torch

# Load the cached mapping from dataset name to its config infos; later cells
# read it as name_to_configs[name]["configs"][i]["config_name" | "features" | "splits"].
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load a cache file you produced yourself.
name_to_configs = torch.load("name_to_configs_100.th")
list(name_to_configs.keys())

['acronym_identification',
 'ade_corpus_v2',
 'adversarial_qa',
 'aeslc',
 'afrikaans_ner_corpus',
 'ag_news',
 'ai2_arc',
 'air_dialogue',
 'ajgt_twitter_ar',
 'allegro_reviews',
 'allocine',
 'alt',
 'amazon_polarity',
 'amazon_reviews_multi',
 'amazon_us_reviews',
 'ambig_qa',
 'amttl',
 'anli',
 'app_reviews',
 'aqua_rat',
 'aquamuse',
 'ar_cov19',
 'ar_res_reviews',
 'ar_sarcasm',
 'arabic_billion_words',
 'arabic_pos_dialect',
 'arabic_speech_corpus',
 'arcd',
 'arsentd_lev',
 'art',
 'arxiv_dataset',
 'ascent_kb',
 'aslg_pc12',
 'asnq',
 'asset',
 'assin',
 'assin2',
 'atomic',
 'autshumato',
 'babi_qa',
 'banking77',
 'bbaw_egyptian',
 'bbc_hindi_nli',
 'bc2gm_corpus',
 'best2009',
 'bianet',
 'bible_para',
 'big_patent',
 'billsum',
 'bing_coronavirus_query_set',
 'biomrc',
 'blended_skill_talk',
 'blimp',
 'blog_authorship_corpus',
 'bn_hate_speech',
 'bookcorpus',
 'bookcorpusopen',
 'boolq',
 'bprec',
 'break_data',
 'brwac',
 'bsd_ja_en',
 'bswac',
 'c3',
 'c4',
 'cail2018

In [7]:
def get_text_to_analyze(
    name, text_path,
    config_name=None, split=None,
    max_items=20000, streaming=False
):
    """Gather the text found at ``text_path`` in the examples of a dataset.

    Args:
        name: dataset name; must be a key of ``name_to_configs``.
        text_path: list of field names to follow from each example down to
            the text, e.g. ``['annotations', 'answer']``.
        config_name: config to load; defaults to the first config recorded
            for ``name`` in ``name_to_configs``.
        split: split to read; defaults to ``'train'`` when the config has
            one, otherwise to the first split recorded for the config.
        max_items: maximum number of examples to read. ``None`` reads the
            whole split; zero or a negative value returns an empty list
            without loading the dataset.
        streaming: forwarded to ``load_dataset``.

    Returns:
        A flat list with the values found at ``text_path``; fields holding
        lists are flattened along the way.

    Raises:
        ValueError: if ``config_name`` is not recorded for ``name``.
    """
    ### default arguments
    configs = name_to_configs[name]["configs"]
    if config_name is None:
        config_name = configs[0]["config_name"]
        print(f"using default config: {config_name}")
    # TODO - fix name_to_configs to have a dict instead of a list to avoid the following
    config = next((cfg for cfg in configs if cfg["config_name"] == config_name), None)
    if config is None:
        known_names = [cfg["config_name"] for cfg in configs]
        raise ValueError(
            f"unknown config {config_name!r} for dataset {name!r}; available: {known_names}"
        )
    if split is None:
        split = 'train' if 'train' in config["splits"] else list(config["splits"])[0]
        print(f"using default split: {split}")
    if max_items is not None and max_items <= 0:
        # Nothing was requested: skip the (potentially expensive) dataset load.
        return []
    ### get text from dataset
    print(f"running -- load_dataset({name}, {config_name}, streaming={streaming})")
    dataset = load_dataset(name, config_name, streaming=streaming)
    text_list = []
    for example_ct, example in enumerate(dataset[split], start=1):
        # robustly handle fields that contain lists of text
        item_list = [example]
        for field_name in text_path:
            item_list = [
                next_item
                for item in item_list
                for next_item in (item[field_name] if isinstance(item[field_name], list) else [item[field_name]])
            ]
        for item in item_list:
            if isinstance(item, list):
                text_list.extend(item)
            else:
                text_list.append(item)
        # Check the limit after processing so no extra example is pulled from a stream.
        if max_items is not None and example_ct >= max_items:
            break
    return text_list

In [8]:
# Take the first stored config of ag_news and list the paths of its string features.
cfg = name_to_configs["ag_news"]["configs"][0]
get_typed_features(cfg['features'], 'string')

[['text']]

In [9]:
# Read the 'text' field of at most 20 ag_news examples in streaming mode,
# using the default config and split.
get_text_to_analyze(
    "ag_news", ['text'],
    max_items=20, streaming=True
)

using default config: default
using default split: train
running -- load_dataset(ag_news, default, streaming=True)


Using custom data configuration default


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




["Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.',
 "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.",
 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.',
 'Oil prices soar to all-time record, posing new menace to US economy (A

In [10]:
# Inspect the first stored config entry of ambig_qa (it has nested features).
name_to_configs["ambig_qa"]["configs"][0]

{'features': {'id': Value(dtype='string', id=None),
  'question': Value(dtype='string', id=None),
  'annotations': Sequence(feature={'type': Value(dtype='string', id=None), 'answer': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'qaPairs': Sequence(feature={'question': Value(dtype='string', id=None), 'answer': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}, length=-1, id=None)}, length=-1, id=None)},
 'config_name': 'light',
 'version': 1.0.0,
 'splits': {'train': SplitInfo(name='train', num_bytes=2739732, num_examples=10036, dataset_name='ambig_qa'),
  'validation': SplitInfo(name='validation', num_bytes=805808, num_examples=2002, dataset_name='ambig_qa')}}

In [11]:
# List the paths of the string features of ambig_qa's first config, including nested ones.
cfg = name_to_configs["ambig_qa"]["configs"][0]
get_typed_features(cfg['features'], 'string')

[['id'],
 ['question'],
 ['annotations', 'type'],
 ['annotations', 'answer'],
 ['annotations', 'qaPairs', 'question'],
 ['annotations', 'qaPairs', 'answer']]

In [12]:
# Top-level field: read 'question' from at most 20 ambig_qa examples (non-streaming).
get_text_to_analyze(
    "ambig_qa", ['question'],
    max_items=20, streaming=False
)

using default config: light
using default split: train
running -- load_dataset(ambig_qa, light, streaming=False)


Reusing dataset ambig_qa (/home/yjernite/.cache/huggingface/datasets/ambig_qa/light/1.0.0/6425acf3572d4caf508123e5443753ab7ff415564753ae326ae801a0a1aa155e)


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




['When did the simpsons first air on television?',
 'Who played george washington in the john adams series?',
 'What is the legal age of marriage in usa?',
 'Who starred in barefoot in the park on broadway?',
 'When did the manhattan project began and end?',
 'When did the frozen ride open at epcot?',
 'Name the landforms that form the boundaries of the peninsular plateau?',
 'When was the last time uga won a national championship?',
 'Who sing play that funky music white boy?',
 'When was the first airplane used in war?',
 "What color is a negative benedict's test?",
 'Consubstantial with the father in the creed means what?',
 "What was elvis presley's first uk number 1?",
 'Voice of the snake in the jungle book?',
 'Where are the majority of cases heard in the united states?',
 'When is magnus chase book 3 coming out?',
 'Who has more super bowl wins afc or nfc?',
 'Which is the most recent state to have joined the united states of america?',
 'When was the song believer by imagine d

In [13]:
# Two-level path: read the 'answer' entries nested under 'annotations'
# from at most 20 ambig_qa examples (non-streaming).
get_text_to_analyze(
    "ambig_qa", ['annotations', 'answer'],
    max_items=20, streaming=False
)

using default config: light
using default split: train
running -- load_dataset(ambig_qa, light, streaming=False)


Reusing dataset ambig_qa (/home/yjernite/.cache/huggingface/datasets/ambig_qa/light/1.0.0/6425acf3572d4caf508123e5443753ab7ff415564753ae326ae801a0a1aa155e)


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




['David Morse',
 'June 21, 2016',
 'Aravali Range, Satpura Range, Vindhyan Range',
 'Blériot XI',
 'Nieuport IV',
 'clear blue',
 'blue',
 'deep-blue',
 'All Shook Up',
 'State courts',
 'State court',
 'October 3, 2017',
 'tied',
 'The Islands of Aloha',
 'The Aloha State',
 'Hawaii',
 'Paradise of the Pacific',
 'February 1, 2017']

In [14]:
# Three-level path: read the 'answer' entries under 'annotations' -> 'qaPairs'
# from at most 20 ambig_qa examples, this time in streaming mode.
get_text_to_analyze(
    "ambig_qa", ['annotations', 'qaPairs', 'answer'],
    max_items=20, streaming=True
)

using default config: light
using default split: train
running -- load_dataset(ambig_qa, light, streaming=True)


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




['April 19, 1987',
 'December 17, 1989',
 '18 years of age',
 '18',
 '19',
 '21',
 '0',
 'Elizabeth Ashley',
 'Kurt Kasznar',
 'Mildred Natwick',
 'Robert Redford',
 'Herbert Edelman',
 'Joseph Keating',
 'Began 1939, end 1946',
 'Began 1942, end 1946',
 '1980',
 '2009',
 '1990',
 '2005',
 '2016',
 '2018',
 '2019',
 '2014',
 'Rob Parissi',
 'Vanilla Ice',
 'Roxanne',
 'common humanity which is shared by all human persons.',
 'of the same being',
 'nature of God in Christianity',
 'Scarlett Johansson',
 'Sterling Holloway',
 'Joseph J Terry',
 'May 10, 2018',
 'May 25, 2018']