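# Split datasets by lexical overlap

For each GLUE sentence-pair task: score every pair by word overlap, pool the train and validation splits, sort by the score, and reload the sorted data as the lowest-overlap 80% and highest-overlap 20%.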

In [5]:
from collections import Counter
from pathlib import Path

import numpy as np
from datasets import load_dataset, concatenate_datasets

In [2]:
def add_lex_overlap(examples, field1="premise", field2="hypothesis"):
    """Add a lexical-overlap score column to a batch of sentence pairs.

    Written for ``Dataset.map(..., batched=True)``: ``examples`` maps column
    names to equal-length lists. For each pair (a, b) taken from
    ``examples[field1]`` and ``examples[field2]`` the score is

        2 * (# whitespace tokens of a that also occur in b) / (len(a) + len(b))

    Tokens come from ``str.split()`` and are matched case-sensitively.
    Every occurrence of a repeated token in ``a`` is counted, so the score can
    exceed 1 (e.g. a="x x x", b="x" gives 1.5).

    Args:
        examples: Batch dict containing (at least) the two text columns.
        field1: Column with the first sentence (e.g. the premise).
        field2: Column with the second sentence (e.g. the hypothesis).

    Returns:
        ``examples``, mutated in place with an added ``"overlap_score"`` list.
        A pair in which both sentences are empty scores 0.0 instead of
        raising ZeroDivisionError.
    """
    overlap_scores = []
    for first, second in zip(examples[field1], examples[field2]):
        first_tokens = first.split()
        second_tokens = second.split()
        n_total = len(first_tokens) + len(second_tokens)
        if n_total == 0:
            # Both sides empty: no tokens to compare, avoid dividing by zero.
            overlap_scores.append(0.0)
            continue
        # Set membership keeps this linear instead of len(a) * len(b) list scans.
        second_vocab = set(second_tokens)
        n_overlap = sum(1 for tok in first_tokens if tok in second_vocab)
        overlap_scores.append(n_overlap * 2 / n_total)
    examples["overlap_score"] = overlap_scores
    return examples

In [3]:
# Score every MNLI split, then pool train + validation_matched +
# validation_mismatched (the test splits are not included) and sort the pool
# by ascending lexical overlap.
ds = load_dataset("glue", "mnli")
ds = ds.map(
    lambda exs: add_lex_overlap(exs, "premise", "hypothesis"),
    batched=True)
# The sort column is passed positionally: the keyword was renamed from
# `column` to `column_names` in later `datasets` releases.
ds["concat"] = concatenate_datasets(
    [ds["train"], ds["validation_matched"], ds["validation_mismatched"]]
).sort("overlap_score")

Downloading:   0%|          | 0.00/7.78k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.47k [00:00<?, ?B/s]

Downloading and preparing dataset glue/mnli (download: 298.29 MiB, generated: 78.65 MiB, post-processed: Unknown size, total: 376.95 MiB) to /home/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /home/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/393 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

In [4]:
# Persist the overlap-sorted MNLI pool so later cells can reload it.
# Loading the directory with load_dataset("../data/mnli_processed") picks up
# this CSV automatically and exposes it as a single "train" split. The sorted
# rows can then be sliced by percentage, e.g. split="train[:80%]" for the
# first (lowest-overlap) 80%.
pth = Path("../data/mnli_processed")
# parents=True also creates ../data when it is missing; exist_ok=True keeps re-runs safe.
pth.mkdir(parents=True, exist_ok=True)
ds["concat"].to_csv(str(pth / "sort_by_overlap.csv"), index=False)

Creating CSV from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

81441112

In [11]:
# Scratch check: sum a dict's values with numpy. D is not used by the
# other cells in this notebook section; consider removing this cell.
D = {"a": 2, "b": 3}
np.sum([value for value in D.values()])

5

In [16]:
def print_ds_label_constituency(ds):
    """Print and return the label distribution of a dataset.

    Prints the total row count, then one line per distinct label with its
    share as a percentage, in order of first appearance. For continuous
    labels (e.g. STS-B scores) this is one line per distinct value.

    Args:
        ds: Anything where ``ds["label"]`` yields a sized sequence of hashable
            labels (a Hugging Face ``Dataset``, or a plain dict of lists).

    Returns:
        A plain dict mapping each label to its count, in first-seen order.
    """
    all_labels = ds["label"]
    # Counter keeps insertion order; convert back to dict so the return
    # value (and its notebook repr) stays a plain dict.
    labels_count = dict(Counter(all_labels))
    total = len(all_labels)
    print("Total: {}".format(total))
    for y, count in labels_count.items():
        print("    Label={}. Portion={:.4f}%".format(y, count / total * 100))
    return labels_count

In [14]:
# Reload the processed MNLI directory; its CSV shows up as one "train" split.
ds_full = load_dataset(path="../data/mnli_processed")
ds_full

Using custom data configuration mnli_processed-86c529b39c9441f5
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/mnli_processed-86c529b39c9441f5/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['hypothesis', 'idx', 'label', 'overlap_score', 'premise'],
        num_rows: 412349
    })
})

In [17]:
# Lower-overlap portion: the first 80% of rows of the overlap-sorted CSV.
ds_easy = load_dataset(path="../data/mnli_processed", split="train[:80%]")
print_ds_label_constituency(ds=ds_easy)
ds_easy

Using custom data configuration mnli_processed-86c529b39c9441f5
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/mnli_processed-86c529b39c9441f5/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Total: 329879
    Label=2. Portion=34.9283%
    Label=0. Portion=28.7145%
    Label=1. Portion=36.3573%


Dataset({
    features: ['hypothesis', 'idx', 'label', 'overlap_score', 'premise'],
    num_rows: 329879
})

In [18]:
# Higher-overlap portion: the remaining 20% of rows (complement of the 80% slice).
ds_hard = load_dataset(path="../data/mnli_processed", split="train[80%:]")
print_ds_label_constituency(ds=ds_hard)
ds_hard

Using custom data configuration mnli_processed-86c529b39c9441f5
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/mnli_processed-86c529b39c9441f5/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Total: 82470
    Label=0. Portion=52.2833%
    Label=1. Portion=20.8767%
    Label=2. Portion=26.8401%


Dataset({
    features: ['hypothesis', 'idx', 'label', 'overlap_score', 'premise'],
    num_rows: 82470
})

## Repeat on other datasets

In [19]:
# QQP: score question pairs, pool train + validation (test not included),
# sort by ascending lexical overlap and save under ../data/qqp_processed/.
ds = load_dataset("glue", "qqp")
ds = ds.map(
    lambda exs: add_lex_overlap(exs, "question1", "question2"),
    batched=True)
# Positional sort column works across `datasets` versions (the keyword was
# renamed from `column` to `column_names`).
ds["concat"] = concatenate_datasets(
    [ds["train"], ds["validation"]]
).sort("overlap_score")
pth = Path("../data/qqp_processed")
# parents=True also creates ../data when it is missing; exist_ok=True keeps re-runs safe.
pth.mkdir(parents=True, exist_ok=True)
print_ds_label_constituency(ds["concat"])
ds["concat"].to_csv(str(pth / "sort_by_overlap.csv"), index=False)

Downloading and preparing dataset glue/qqp (download: 39.76 MiB, generated: 106.55 MiB, post-processed: Unknown size, total: 146.32 MiB) to /home/zining/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading:   0%|          | 0.00/41.7M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /home/zining/.cache/huggingface/datasets/glue/qqp/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?ba/s]

  0%|          | 0/41 [00:00<?, ?ba/s]

  0%|          | 0/391 [00:00<?, ?ba/s]

Total: 404276
    Label=0. Portion=63.0789%
    Label=1. Portion=36.9211%


Creating CSV from Arrow format:   0%|          | 0/41 [00:00<?, ?ba/s]

59105841

In [21]:
# QQP: reload the overlap-sorted CSV as an 80/20 split and compare label balance.
ds_train = load_dataset(path="../data/qqp_processed", split="train[:80%]")
print_ds_label_constituency(ds=ds_train)
ds_eval = load_dataset(path="../data/qqp_processed", split="train[80%:]")
print_ds_label_constituency(ds=ds_eval)

Using custom data configuration qqp_processed-a68591ebafa4c184
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/qqp_processed-a68591ebafa4c184/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)
Using custom data configuration qqp_processed-a68591ebafa4c184
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/qqp_processed-a68591ebafa4c184/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Total: 323421
    Label=0. Portion=67.0386%
    Label=1. Portion=32.9614%
Total: 80855
    Label=1. Portion=52.7599%
    Label=0. Portion=47.2401%


{1: 42659, 0: 38196}

In [22]:
# QNLI: score question/sentence pairs, pool train + validation (test not
# included), sort by ascending lexical overlap and save under ../data/qnli_processed/.
ds = load_dataset("glue", "qnli")
ds = ds.map(
    lambda exs: add_lex_overlap(exs, "question", "sentence"),
    batched=True)
# Positional sort column works across `datasets` versions (the keyword was
# renamed from `column` to `column_names`).
ds["concat"] = concatenate_datasets(
    [ds["train"], ds["validation"]]
).sort("overlap_score")
pth = Path("../data/qnli_processed")
# parents=True also creates ../data when it is missing; exist_ok=True keeps re-runs safe.
pth.mkdir(parents=True, exist_ok=True)
ds["concat"].to_csv(str(pth / "sort_by_overlap.csv"), index=False)

Downloading and preparing dataset glue/qnli (download: 10.14 MiB, generated: 27.11 MiB, post-processed: Unknown size, total: 37.24 MiB) to /home/zining/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading:   0%|          | 0.00/10.6M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /home/zining/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?ba/s]

  0%|          | 0/6 [00:00<?, ?ba/s]

  0%|          | 0/6 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

27872641

In [23]:
# QNLI: reload the overlap-sorted CSV as an 80/20 split and compare label balance.
ds_train = load_dataset(path="../data/qnli_processed", split="train[:80%]")
print_ds_label_constituency(ds=ds_train)
ds_eval = load_dataset(path="../data/qnli_processed", split="train[80%:]")
print_ds_label_constituency(ds=ds_eval)

Using custom data configuration qnli_processed-5a840d7b1c97368a


Downloading and preparing dataset csv/qnli_processed to /home/zining/.cache/huggingface/datasets/csv/qnli_processed-5a840d7b1c97368a/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration qnli_processed-5a840d7b1c97368a


Dataset csv downloaded and prepared to /home/zining/.cache/huggingface/datasets/csv/qnli_processed-5a840d7b1c97368a/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.
Total: 88165
    Label=1. Portion=56.5928%
    Label=0. Portion=43.4072%


Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/qnli_processed-5a840d7b1c97368a/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Total: 22041
    Label=1. Portion=23.7376%
    Label=0. Portion=76.2624%


{1: 5232, 0: 16809}

In [24]:
# Apply the same score -> pool -> sort -> save -> 80/20 reload pipeline to the
# remaining GLUE pair tasks, which all name their text columns
# "sentence1"/"sentence2".
for ds_name in ["mrpc", "rte", "stsb", "wnli"]:
    ds = load_dataset("glue", ds_name)
    ds = ds.map(
        lambda exs: add_lex_overlap(exs, "sentence1", "sentence2"),
        batched=True)
    # Positional sort column works across `datasets` versions (the keyword
    # was renamed from `column` to `column_names`).
    ds["concat"] = concatenate_datasets(
        [ds["train"], ds["validation"]]
    ).sort("overlap_score")
    pth = Path(f"../data/{ds_name}_processed")
    # parents=True also creates ../data when it is missing; exist_ok=True keeps re-runs safe.
    pth.mkdir(parents=True, exist_ok=True)
    ds["concat"].to_csv(str(pth / "sort_by_overlap.csv"), index=False)

    # NOTE: STS-B labels are continuous similarity scores, so the report below
    # prints one line per distinct value rather than a class balance.
    ds_train = load_dataset(f"../data/{ds_name}_processed", split="train[:80%]")
    print_ds_label_constituency(ds_train)
    ds_eval = load_dataset(f"../data/{ds_name}_processed", split="train[80%:]")
    print_ds_label_constituency(ds_eval)

Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /home/zining/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


  0%|          | 0/3 [00:00<?, ?it/s]

Downloading: 0.00B [00:00, ?B/s]

Downloading: 0.00B [00:00, ?B/s]

Downloading: 0.00B [00:00, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /home/zining/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Using custom data configuration mrpc_processed-e7409f135deff7cb


Downloading and preparing dataset csv/mrpc_processed to /home/zining/.cache/huggingface/datasets/csv/mrpc_processed-e7409f135deff7cb/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration mrpc_processed-e7409f135deff7cb
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/mrpc_processed-e7409f135deff7cb/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Dataset csv downloaded and prepared to /home/zining/.cache/huggingface/datasets/csv/mrpc_processed-e7409f135deff7cb/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.
Total: 3261
    Label=0. Portion=38.4851%
    Label=1. Portion=61.5149%
Total: 815
    Label=1. Portion=91.6564%
    Label=0. Portion=8.3436%
Downloading and preparing dataset glue/rte (download: 680.81 KiB, generated: 1.83 MiB, post-processed: Unknown size, total: 2.49 MiB) to /home/zining/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading:   0%|          | 0.00/697k [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /home/zining/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Using custom data configuration rte_processed-cd550f822dd10ccc


Downloading and preparing dataset csv/rte_processed to /home/zining/.cache/huggingface/datasets/csv/rte_processed-cd550f822dd10ccc/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration rte_processed-cd550f822dd10ccc
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/rte_processed-cd550f822dd10ccc/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Dataset csv downloaded and prepared to /home/zining/.cache/huggingface/datasets/csv/rte_processed-cd550f822dd10ccc/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.
Total: 2214
    Label=1. Portion=51.7615%
    Label=0. Portion=48.2385%
Total: 553
    Label=1. Portion=40.8680%
    Label=0. Portion=59.1320%
Downloading and preparing dataset glue/stsb (download: 784.05 KiB, generated: 1.09 MiB, post-processed: Unknown size, total: 1.86 MiB) to /home/zining/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading:   0%|          | 0.00/803k [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /home/zining/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Using custom data configuration stsb_processed-a1ec77eecaa4ec3c


Downloading and preparing dataset csv/stsb_processed to /home/zining/.cache/huggingface/datasets/csv/stsb_processed-a1ec77eecaa4ec3c/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration stsb_processed-a1ec77eecaa4ec3c
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/stsb_processed-a1ec77eecaa4ec3c/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Dataset csv downloaded and prepared to /home/zining/.cache/huggingface/datasets/csv/stsb_processed-a1ec77eecaa4ec3c/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.
Total: 5799
    Label=0.2. Portion=1.7072%
    Label=0.0. Portion=8.7084%
    Label=3.2. Portion=4.3973%
    Label=1.2. Portion=2.7419%
    Label=0.6. Portion=2.6729%
    Label=2.8. Portion=3.3282%
    Label=1.0. Portion=4.1731%
    Label=1.6. Portion=2.8971%
    Label=1.8. Portion=3.0005%
    Label=0.4. Portion=3.1902%
    Label=0.8. Portion=3.1040%
    Label=3.0. Portion=5.5872%
    Label=2.6. Portion=3.2592%
    Label=1.4. Portion=3.2419%
    Label=2.3333333. Portion=0.0172%
    Label=4.2. Portion=2.4832%
    Label=4.4. Portion=2.0521%
    Label=2.2. Portion=3.2764%
    Label=3.4. Portion=3.9662%
    Label=2.0. Portion=3.6386%
    Label=3.6. Portion=3.1040%
    Label=4.8. Portion=1.6037%
    Label=0.67. Portion=0.0172%
    Label=0.75. Portion=0.3276%
    Label

Downloading:   0%|          | 0.00/29.0k [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /home/zining/.cache/huggingface/datasets/glue/wnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Using custom data configuration wnli_processed-8534c337abb3118a


Downloading and preparing dataset csv/wnli_processed to /home/zining/.cache/huggingface/datasets/csv/wnli_processed-8534c337abb3118a/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration wnli_processed-8534c337abb3118a
Reusing dataset csv (/home/zining/.cache/huggingface/datasets/csv/wnli_processed-8534c337abb3118a/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Dataset csv downloaded and prepared to /home/zining/.cache/huggingface/datasets/csv/wnli_processed-8534c337abb3118a/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.
Total: 565
    Label=0. Portion=51.8584%
    Label=1. Portion=48.1416%
Total: 141
    Label=1. Portion=50.3546%
    Label=0. Portion=49.6454%
