# Download selected text summarization datasets from Hugging Face


- wikihow/all: [repo](https://github.com/mahnazkoupaee/WikiHow-Dataset), paper: [WikiHow: A Large Scale Text Summarization Dataset](https://arxiv.org/abs/1810.09305), HF: [wikihow](https://huggingface.co/datasets/wikihow), requires manual download
- xsum: HF: [xsum](https://huggingface.co/datasets/xsum)
- cnn-dailymail: repo: [abisee/cnn-dailymail](https://github.com/abisee/cnn-dailymail), HF: [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail)
- samsum: paper: [SAMSum Corpus: A Human-annotated Dialogue Dataset for Abstractive Summarization](https://arxiv.org/abs/1911.12237), HF: [samsum](https://huggingface.co/datasets/samsum)
- scitldr: repo: [allenai/scitldr](https://github.com/allenai/scitldr), paper: [TLDR: Extreme Summarization of Scientific Documents](https://arxiv.org/abs/2004.15011), HF: [scitldr](https://huggingface.co/datasets/scitldr)
- billsum: HF: [billsum](https://huggingface.co/datasets/billsum)


Hugging Face datasets for summarization: https://huggingface.co/datasets?task_categories=task_categories:summarization

In [18]:
# load hf datasets csv file
from pathlib import Path
import pandas as pd

# from https://github.com/jordiclive/Summarization/blob/704c079892faa7f902710d8b68e781b856adfa5c/processing/hf_datasets.csv
datasets_info = pd.read_csv("hf_datasets/hf_datasets.csv")
datasets_info

Unnamed: 0,hf_dataset_key,source_key,target_key,flan_prompt
0,wikihow/all,text,headline,Produce an article summary including outlines ...
1,xsum/1.2.1,document,summary,"Given the following news article, summarize th..."
2,cnn_dailymail/3.0.0,article,highlights,Produce an article summary of the following ne...
3,samsum,dialogue,summary,Briefly summarize in third person the followin...
4,scitldr/AIC,source,target,"Given the following scientific article, provid..."
5,billsum,text,summary,Summarize the following proposed legislation (...
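
The `hf_dataset_key` column packs a dataset name and an optional configuration into one string separated by a slash (for example `xsum/1.2.1`, while `samsum` has no configuration). A minimal sketch of how such a key can be taken apart before it is passed to a loader; the helper name `split_dataset_key` is hypothetical and not part of the notebook:

```python
from typing import Optional, Tuple


def split_dataset_key(key: str) -> Tuple[str, Optional[str]]:
    """Split 'name/config' into (name, config); config is None when absent."""
    # partition() splits on the first slash only and reports whether it was found
    name, sep, config = key.partition("/")
    return name, (config if sep else None)


# keys taken from the table above
for key in ["wikihow/all", "xsum/1.2.1", "samsum"]:
    print(key, "->", split_dataset_key(key))
```

Returning `None` for a missing configuration lets the caller decide whether to pass a `name` argument to the loader at all.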


In [20]:
import datasets
import json
from typing import List
from sklearn.model_selection import train_test_split
import numpy as np
import random

cache_dir = str(Path('../hf_data_cache/').resolve())

additional_args = {
    'wikihow/all': {
        'data_dir': str(Path('../WikiHow-Dataset/').resolve())
    }
}

provenance_id_colums = {
    'wikihow/all': ['title'],
    'xsum/1.2.1': ['id'],
    'cnn_dailymail/3.0.0': ['id'],
    'samsum': ['id'],
    'scitldr/AIC': ['paper_id'],
    'billsum': ['title'],
}


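def convert_dataset(data, text_key, summary_key, output_dir: Path, output_prefix: str, id_colums: List[str], compression: str = "snappy"):
    """Write one split to `<output_prefix>.<compression>.parquet` with text, summary and provenance columns."""
    fn = output_dir / f'{output_prefix}.{compression}.parquet'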
    
    text, summary, provenance = [], [], []
    for idx, row in data.iterrows():
        t = row[text_key]
        if isinstance(t, np.ndarray):    # special list handling for scitldr
            if t.size == 1:
                t = t.item()
            else:
                t = ' '.join(t)
            assert isinstance(t, str)
        text.append(t)

        s = row[summary_key]
        if isinstance(s, np.ndarray):   # special array handling for scitldr
            if s.size == 1:
                s = s.item()
            else:
                # several reference summaries: keep one, chosen at random (unseeded, so runs differ)
                s = random.choice(s)

            assert isinstance(s, str)
        summary.append(s)
        p = { 'src': output_prefix }
        for col in id_colums:
            p[col] = row[col]
        provenance.append(json.dumps(p))
    
    text_, summary_, provenance_ = map(lambda x: pd.array(x, dtype="string"), (text, summary, provenance))
    df = pd.DataFrame({"text": text_, "summary": summary_, "provenance": provenance_})
    print(f'writing: {fn}')
    #print(df.head())
    df.to_parquet(
        fn, 
        engine="pyarrow",
        compression=compression
    )


# load datasets
for index, row in datasets_info.iterrows():
    dataset_name = row['hf_dataset_key']
    #if dataset_name != 'scitldr/AIC':       # for single dataset debugging
    #    continue

    extra_args = {}
    if dataset_name in additional_args:
        extra_args = additional_args[dataset_name]

    print(f'loading {dataset_name}')
    name = dataset_name.split("/")
    if len(name) > 1:
        data = datasets.load_dataset(name[0], name=name[1], cache_dir=cache_dir, **extra_args)
    else:
        data = datasets.load_dataset(name[0], cache_dir=cache_dir, **extra_args)

    # make sure every dataset has a validation set, sample one if missing
    splits = {}
    split_names = ['train', 'validation', 'test']
    min_num_val = 100   # min size of validation set
    
    if 'validation' not in data.keys():        
        train = data['train'].to_pandas()
        val_size = max(int(len(train)//30), min_num_val)
        print(f'Warning: Validation set missing for {dataset_name}, sampling synthetic validation (size: {val_size}).')
        train, val = train_test_split(train, test_size=val_size)
        splits['train'] = train
        splits['validation'] = val
    else:
        splits['train'] = data['train'].to_pandas()
        splits['validation'] = data['validation'].to_pandas()
    splits['test'] = data['test'].to_pandas()

    print('columns:', data['train'].column_names)

    for split_name, split_df in splits.items():
        print(f'dataset: {dataset_name}; split: {split_name}; rows: {len(split_df)/1000:.1f}k;')

        prefix = dataset_name.replace('/', '-')
        prefix = prefix + '_' + split_name

        output_dir = Path(f'./data/{name[0]}')
        if not output_dir.exists():
            print(f'creating directory: {output_dir}')
            output_dir.mkdir(parents=True, exist_ok=True)   # also create ./data if it is missing
        id_colums = provenance_id_colums[dataset_name]
        convert_dataset(split_df, row['source_key'], row['target_key'], output_dir, prefix, id_colums) 


loading wikihow/all


Using custom data configuration all-3003e4082f016f00
Found cached dataset wikihow (/media/koepf/data2/laion/hf_data_cache/wikihow/all-3003e4082f016f00/1.2.0/5343fc81d685acaa086c9cc19eb8706206cd1f8b315792b04c1d7b92091c305e)


  0%|          | 0/3 [00:00<?, ?it/s]

columns: ['text', 'headline', 'title']
dataset: wikihow/all; split: train; rows: 157.3k;
writing: data/wikihow/wikihow-all_train.snappy.parquet
dataset: wikihow/all; split: validation; rows: 5.6k;
writing: data/wikihow/wikihow-all_validation.snappy.parquet
dataset: wikihow/all; split: test; rows: 5.6k;
writing: data/wikihow/wikihow-all_test.snappy.parquet
loading xsum/1.2.1


Using custom data configuration 1.2.1
Found cached dataset xsum (/media/koepf/data2/laion/hf_data_cache/xsum/1.2.1/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934)


  0%|          | 0/3 [00:00<?, ?it/s]

columns: ['document', 'summary', 'id']
dataset: xsum/1.2.1; split: train; rows: 204.0k;
writing: data/xsum/xsum-1.2.1_train.snappy.parquet
dataset: xsum/1.2.1; split: validation; rows: 11.3k;
writing: data/xsum/xsum-1.2.1_validation.snappy.parquet
dataset: xsum/1.2.1; split: test; rows: 11.3k;
writing: data/xsum/xsum-1.2.1_test.snappy.parquet
loading cnn_dailymail/3.0.0


Found cached dataset cnn_dailymail (/media/koepf/data2/laion/hf_data_cache/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)


  0%|          | 0/3 [00:00<?, ?it/s]

columns: ['article', 'highlights', 'id']
dataset: cnn_dailymail/3.0.0; split: train; rows: 287.1k;
writing: data/cnn_dailymail/cnn_dailymail-3.0.0_train.snappy.parquet
dataset: cnn_dailymail/3.0.0; split: validation; rows: 13.4k;
writing: data/cnn_dailymail/cnn_dailymail-3.0.0_validation.snappy.parquet
dataset: cnn_dailymail/3.0.0; split: test; rows: 11.5k;
writing: data/cnn_dailymail/cnn_dailymail-3.0.0_test.snappy.parquet
loading samsum


Found cached dataset samsum (/media/koepf/data2/laion/hf_data_cache/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)


  0%|          | 0/3 [00:00<?, ?it/s]

columns: ['id', 'dialogue', 'summary']
dataset: samsum; split: train; rows: 14.7k;
writing: data/samsum/samsum_train.snappy.parquet
dataset: samsum; split: validation; rows: 0.8k;
writing: data/samsum/samsum_validation.snappy.parquet
dataset: samsum; split: test; rows: 0.8k;
writing: data/samsum/samsum_test.snappy.parquet
loading scitldr/AIC


Found cached dataset scitldr (/media/koepf/data2/laion/hf_data_cache/scitldr/AIC/0.0.0/79e0fa75961392034484808cfcc8f37deb15ceda153b798c92d9f621d1042fef)


  0%|          | 0/3 [00:00<?, ?it/s]

columns: ['source', 'source_labels', 'rouge_scores', 'paper_id', 'ic', 'target']
dataset: scitldr/AIC; split: train; rows: 2.0k;
writing: data/scitldr/scitldr-AIC_train.snappy.parquet
dataset: scitldr/AIC; split: validation; rows: 0.6k;
writing: data/scitldr/scitldr-AIC_validation.snappy.parquet
dataset: scitldr/AIC; split: test; rows: 0.6k;
writing: data/scitldr/scitldr-AIC_test.snappy.parquet
loading billsum


Found cached dataset billsum (/media/koepf/data2/laion/hf_data_cache/billsum/default/3.0.0/75cf1719d38d6553aa0e0714c393c74579b083ae6e164b2543684e3e92e0c4cc)


  0%|          | 0/3 [00:00<?, ?it/s]

columns: ['text', 'summary', 'title']
dataset: billsum; split: train; rows: 18.3k;
writing: data/billsum/billsum_train.snappy.parquet
dataset: billsum; split: validation; rows: 0.6k;
writing: data/billsum/billsum_validation.snappy.parquet
dataset: billsum; split: test; rows: 3.3k;
writing: data/billsum/billsum_test.snappy.parquet
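
The file names in the log above follow one scheme: the dataset key with `/` replaced by `-`, then `_<split>`, then the compression name and `.parquet`, inside a directory named after the dataset. Each row's `provenance` value is a JSON string holding that prefix under `src` plus the dataset's id columns. A minimal sketch of both conventions using only the standard library; the helper names `output_path` and `make_provenance` and the id value `example-id` are hypothetical placeholders, not taken from the notebook or the data:

```python
import json
from pathlib import Path
from typing import Dict, List


def output_path(dataset_key: str, split: str, compression: str = "snappy",
                root: Path = Path("data")) -> Path:
    """Rebuild the output file path for one dataset split."""
    name = dataset_key.split("/")[0]                      # directory: dataset name only
    prefix = dataset_key.replace("/", "-") + "_" + split  # e.g. scitldr-AIC_validation
    return root / name / f"{prefix}.{compression}.parquet"


def make_provenance(prefix: str, row: Dict[str, str], id_columns: List[str]) -> str:
    """Serialize the source prefix and the selected id columns as JSON."""
    record = {"src": prefix}
    for col in id_columns:
        record[col] = row[col]
    return json.dumps(record)


path = output_path("scitldr/AIC", "validation")
print(path.as_posix())

# "example-id" is a made-up placeholder id, not a real xsum record
prov = make_provenance("xsum-1.2.1_train", {"id": "example-id", "document": "..."}, ["id"])
print(json.loads(prov))
```

Storing provenance as a JSON string keeps the parquet schema identical across datasets even though their id columns differ (`id`, `paper_id`, `title`).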
