In [36]:
import random
random.seed(0)

import polars as pl
from datasets import load_dataset
from tqdm import tqdm

# Load the semantically-augmented FEVER NLI dataset and show the two splits used below.
# (The original cell kept `download_mode="force_redownload"` as a commented-out option.)
dataset = load_dataset("tommasobonomo/sem_augmented_fever_nli")
for split_name in ('validation', 'test'):
    print(dataset[split_name])

Dataset({
    features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
    num_rows: 2288
})
Dataset({
    features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
    num_rows: 2287
})


In [22]:
def samples_to_skip(path):
    """Read the sample ids listed in the ``cid`` column of a CSV file.

    Prints how many ids were read from ``path``.

    Args:
        path: Location of a CSV file readable by ``pl.read_csv`` that
            contains a ``cid`` column.

    Returns:
        list: The values of the ``cid`` column in file order (duplicates are
        kept; callers de-duplicate if they need to).
    """
    df = pl.read_csv(path)
    # Take the column as a Series rather than going through
    # ``to_numpy().squeeze()``: squeezing a one-row frame gives a 0-d array,
    # on which ``len()`` raises and ``tolist()`` returns a bare scalar.
    to_skip = df.get_column('cid').to_list()
    print(f"{len(to_skip)} samples to be skipped from '{path}'.")
    return to_skip

# Gather the ids from both previously-sampled CSV files into one flat list
# (files are read in this order, so the per-file messages print in the same order).
to_skip = [
    cid
    for csv_path in ("fever_test.filtered.sampled.csv", "200_FEVER_dev_samples.csv")
    for cid in samples_to_skip(csv_path)
]
# Report the number of distinct ids, since the two files may overlap.
print(len(set(to_skip)), "total samples to skip.")

150 samples to be skipped from 'fever_test.filtered.sampled.csv'.
210 samples to be skipped from '200_FEVER_dev_samples.csv'.
360 total samples to skip.


In [37]:
# Column-oriented accumulator for the samples that survive filtering.
# 'cid' is filled from the dataset's 'id' field; 'wsd' and 'srl' are stored
# as their string representation.
to_keep = {
    'cid': [],
    'premise': [],
    'hypothesis': [],
    'label': [],
    'wsd': [],
    'srl': []
}
# Build the lookup once: membership in a set is constant-time, whereas testing
# against the `to_skip` list rescans it for every single sample.
skip_ids = set(to_skip)
for split in ['test', 'validation']:
    for sample in tqdm(dataset[split], desc=f"> processing {split} data"):
        # Discard pairs where either side is empty.
        if sample['premise'] == '' or sample['hypothesis'] == '':
            continue
        # Ids are cast to int before the lookup (presumably the skip CSVs hold
        # integer cids while the dataset exposes ids as strings -- confirm).
        if int(sample['id']) in skip_ids:
            continue
        to_keep['cid'].append(sample['id'])
        to_keep['premise'].append(sample['premise'])
        to_keep['hypothesis'].append(sample['hypothesis'])
        to_keep['label'].append(sample['label'])
        to_keep['wsd'].append(str(sample['wsd']))
        to_keep['srl'].append(str(sample['srl']))


to_keep_df = pl.from_dict(to_keep)
print(f"{to_keep_df.height}/{dataset['test'].num_rows + dataset['validation'].num_rows} filtered samples.")

> processing test data: 100%|██████████| 2287/2287 [00:01<00:00, 1219.61it/s]
> processing validation data: 100%|██████████| 2288/2288 [00:01<00:00, 1195.60it/s]

4215/4575 filtered samples.





In [38]:
# Draw the same number of rows (35, fixed seed) for each of the three labels
# so that the resulting subset is class-balanced.
sampled_e, sampled_n, sampled_c = (
    to_keep_df.filter(pl.col('label') == label_name).sample(n=35, seed=42)
    for label_name in ('ENTAILMENT', 'NEUTRAL', 'CONTRADICTION')
)
# Preview each per-label sample.
for per_label_sample in (sampled_e, sampled_n, sampled_c):
    print(per_label_sample.head())
# Stack the three samples, then shuffle every column with the same seed.
stacked = pl.concat([sampled_e, sampled_n, sampled_c])
concat = stacked.select(pl.all().shuffle(seed=42))
print(concat)

shape: (5, 6)
┌────────┬───────────────────────┬──────────────────────┬────────────┬───────────────┬─────────────┐
│ cid    ┆ premise               ┆ hypothesis           ┆ label      ┆ wsd           ┆ srl         │
│ ---    ┆ ---                   ┆ ---                  ┆ ---        ┆ ---           ┆ ---         │
│ str    ┆ str                   ┆ str                  ┆ str        ┆ str           ┆ str         │
╞════════╪═══════════════════════╪══════════════════════╪════════════╪═══════════════╪═════════════╡
│ 18346  ┆ Brie Larson . As a    ┆ Brie Larson was in a ┆ ENTAILMENT ┆ {'premise':   ┆ {'premise': │
│        ┆ teenager , …          ┆ starring …           ┆            ┆ [{'index': 0, ┆ {'tokens':  │
│        ┆                       ┆                      ┆            ┆ 'tex…         ┆ [{'inde…    │
│ 228324 ┆ Island Records . It   ┆ Jamaica is the place ┆ ENTAILMENT ┆ {'premise':   ┆ {'premise': │
│        ┆ was founde…           ┆ where Isl…           ┆            ┆ [{'ind

In [39]:
# Ask interactively for the destination file, then dump the shuffled sample as CSV.
output_name = input("type file name: ")
concat.write_csv(output_name, separator=',')
print("csv written.")

csv written.
