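# Split datasets

Score every MNLI premise–hypothesis pair by lexical overlap, pool the train and both validation splits sorted by that score, export the result to CSV, and reload it as the lowest-overlap 80% (`ds_easy`) and the highest-overlap 20% (`ds_hard`).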

In [1]:
from datasets import load_dataset, concatenate_datasets
import numpy as np
from pathlib import Path

In [2]:
def add_lex_overlap(examples, field1="premise", field2="hypothesis"):
    """Add a lexical-overlap score for every text pair in a batch.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` is a dict of
    column lists, and an ``"overlap_score"`` column is added to it.

    After whitespace tokenization, the score for a pair (p, h) is
    ``2 * N_overlap / (len(p_tokens) + len(h_tokens))``. Here ``N_overlap``
    counts the tokens of ``field1`` (with repetition) that also occur anywhere
    in ``field2``. Because repeated ``field1`` tokens are each counted, the
    score is not symmetric in its two fields and can exceed 1 (e.g. "a a a"
    vs "a" gives 1.5). A pair in which both texts have no tokens scores 0.0.

    Args:
        examples: batch dict containing string columns ``field1`` and ``field2``.
        field1: name of the first text column (tokens counted with repetition).
        field2: name of the second text column.

    Returns:
        The same ``examples`` dict, mutated in place to include
        ``"overlap_score"``: a list of floats, one per row.
    """
    overlap_scores = []
    for p, h in zip(examples[field1], examples[field2]):
        plist = p.split()
        hlist = h.split()
        # Set lookup instead of rescanning the hypothesis list for every premise token.
        hvocab = set(hlist)
        Np, Nh = len(plist), len(hlist)
        N_overlap = sum(1 for pw in plist if pw in hvocab)
        total = Np + Nh
        # Both texts empty/whitespace-only: avoid ZeroDivisionError, treat as no overlap.
        overlap_scores.append(N_overlap * 2 / total if total else 0.0)
    examples["overlap_score"] = overlap_scores
    return examples

In [3]:
# Load MNLI and add an overlap_score column to every split.
ds = load_dataset("glue", "mnli")
ds = ds.map(
    add_lex_overlap,
    batched=True,
    fn_kwargs={"field1": "premise", "field2": "hypothesis"},
)
# Pool train + both validation splits into one dataset, sorted by ascending
# overlap_score (the sort default). Other splits in `ds` (presumably the
# unlabeled test sets — confirm) are not included.
# The sort column is passed positionally: older `datasets` releases call this
# parameter `column`, newer ones `column_names`, so a keyword breaks on one of them.
ds["concat"] = concatenate_datasets(
    [ds["train"], ds["validation_matched"], ds["validation_mismatched"]]
).sort("overlap_score")

Reusing dataset glue (/h/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/5 [00:00<?, ?it/s]

Loading cached processed dataset at /h/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-a701d1b0a20e9577.arrow
Loading cached processed dataset at /h/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-153ec8ae57a036e8.arrow
Loading cached processed dataset at /h/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-a3e73ec55f6b3bfa.arrow
Loading cached processed dataset at /h/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e90eb19572f598ed.arrow
Loading cached processed dataset at /h/zining/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-034d73483a5a6b17.arrow
Loading cached sorted indices for dataset at /h/zining/.cache/huggingface/datase

In [4]:
# Export the sorted pool as a single CSV. When `load_dataset` is later pointed at
# this directory, it infers the CSV builder and (with no split mapping given)
# exposes the file as a "train" split. Percent slices such as split="train[:80%]"
# therefore select a prefix of the rows in ascending overlap_score order.
# NOTE(review): any other data file placed in this directory would also be picked
# up by those loads — keep the directory dedicated to this CSV.
out_csv = Path("../data/mnli_processed/sort_by_overlap.csv")
# Create the output directory first so the export does not fail on a fresh checkout.
out_csv.parent.mkdir(parents=True, exist_ok=True)
# Trailing `;` hides the returned byte count.
ds["concat"].to_csv(str(out_csv), index=False);

Creating CSV from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]

81441112

In [5]:
# Reload the exported directory; its single CSV file is exposed as the "train" split.
ds_full = load_dataset(path="../data/mnli_processed")
ds_full

Using custom data configuration mnli_processed-61339352ecd9ae84


Downloading and preparing dataset csv/mnli_processed to /h/zining/.cache/huggingface/datasets/csv/mnli_processed-61339352ecd9ae84/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /h/zining/.cache/huggingface/datasets/csv/mnli_processed-61339352ecd9ae84/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['hypothesis', 'idx', 'label', 'overlap_score', 'premise'],
        num_rows: 412349
    })
})

In [6]:
# First 80% of rows. The CSV was written in ascending overlap_score order, so
# these are the lowest-overlap pairs.
# NOTE(review): confirm that low lexical overlap is what "easy" is meant to denote.
ds_easy = load_dataset(path="../data/mnli_processed", split="train[:80%]")
ds_easy

Using custom data configuration mnli_processed-61339352ecd9ae84
Reusing dataset csv (/h/zining/.cache/huggingface/datasets/csv/mnli_processed-61339352ecd9ae84/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Dataset({
    features: ['hypothesis', 'idx', 'label', 'overlap_score', 'premise'],
    num_rows: 329879
})

In [7]:
# Remaining 20% of rows: the highest-overlap pairs, complementing ds_easy. In the
# recorded run the two splits cover all rows (329,879 + 82,470 = 412,349).
ds_hard = load_dataset(path="../data/mnli_processed", split="train[80%:]")
ds_hard

Using custom data configuration mnli_processed-61339352ecd9ae84
Reusing dataset csv (/h/zining/.cache/huggingface/datasets/csv/mnli_processed-61339352ecd9ae84/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


Dataset({
    features: ['hypothesis', 'idx', 'label', 'overlap_score', 'premise'],
    num_rows: 82470
})