<a href="https://colab.research.google.com/github/Helsinki-NLP/nli-data-sanity-check/blob/main/corrupt_nli_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

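# Script to create corrupted test data for the NLI Data Sanity Check
Project repository: https://github.com/Helsinki-NLP/nli-data-sanity-check

For each universal POS tag listed below, this notebook removes every token with that tag from the premises and hypotheses in `test.jsonl` and writes the result to `test-<POS>.jsonl`.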



In [None]:
from tqdm import tqdm
import json
from nltk.tokenize.treebank import TreebankWordDetokenizer
import nltk
nltk.download(['universal_tagset', 'punkt','averaged_perceptron_tagger'])

[nltk_data] Downloading package universal_tagset to /root/nltk_data...
[nltk_data]   Unzipping taggers/universal_tagset.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


True

In [None]:
def dump_jsonl(data, output_path, append=False):
    """
    Serialize each element of ``data`` as one JSON value per line.

    Args:
        data: iterable of JSON-serializable objects.
        output_path: destination file path (written as UTF-8).
        append: if True, new lines are added after the existing file
            contents; otherwise the file is truncated first.
    """
    file_mode = 'w'
    if append:
        file_mode = 'a+'
    with open(output_path, file_mode, encoding='utf-8') as out_file:
        # ensure_ascii=False keeps non-ASCII characters readable in the output.
        out_file.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )

def load_jsonl(input_path) -> list:
    """
    Read a list of objects from a JSON lines file.

    Each non-blank line is parsed as one JSON value. Blank or
    whitespace-only lines (e.g. an extra trailing newline) are skipped.

    Args:
        input_path: path to a UTF-8 encoded JSON lines file.

    Returns:
        list of the decoded objects, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            # strip() removes the line terminator (\n or \r\n) and any
            # surrounding whitespace; it does not touch other characters.
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    print('Loaded {} records from {}'.format(len(data), input_path))
    return data


In [None]:
def corrupt(POS, data):
  """
  Remove every token tagged with a given universal POS tag from each example.

  Premise and hypothesis are tokenized with ``nltk.word_tokenize``, tagged
  with ``nltk.pos_tag(..., tagset='universal')``, filtered, and re-joined
  with NLTK's ``TreebankWordDetokenizer``.

  Args:
    POS: universal POS tag whose tokens are removed, e.g. 'NOUN' or 'VERB'.
    data: iterable of dicts with 'premise' and 'hypothesis' string fields.
      It is only read, never modified.

  Returns:
    A tuple ``(count_prem, count_hypo, premises, hypos)``: the total number
    of tokens removed from all premises / all hypotheses, and the corrupted
    premise / hypothesis strings in the same order as ``data``.
  """
  # A single detokenizer is reused for every sentence instead of being
  # rebuilt twice per record.
  detokenizer = TreebankWordDetokenizer()

  def _strip_pos(text):
    # Return (text with POS tokens removed, number of tokens removed).
    tagged = nltk.pos_tag(nltk.word_tokenize(text), tagset='universal')
    kept = [token for token, tag in tagged if tag != POS]
    return detokenizer.detokenize(kept), len(tagged) - len(kept)

  count_prem = 0
  count_hypo = 0
  premises = []
  hypos = []

  for line in data:
    premise, removed = _strip_pos(line['premise'])
    premises.append(premise)
    count_prem += removed

    hypothesis, removed = _strip_pos(line['hypothesis'])
    hypos.append(hypothesis)
    count_hypo += removed

  return count_prem, count_hypo, premises, hypos

In [None]:
# Original (uncorrupted) test set; later cells read from it.
data = load_jsonl(input_path='test.jsonl')

Loaded 1200 records from test.jsonl


In [None]:
print('Removing tokens...')
print('No of tokens removed: \n')
print('POS,premises,hypotheses,total')
# Every POS variant is derived from the untouched `data`, so each output file
# differs from the original test set by exactly one POS class. Writing the
# corrupted text back into `data` would make the removals cumulative
# (e.g. test-VERB.jsonl would also be missing all nouns).
for pos in ['NOUN', 'VERB', 'PRON', 'ADJ', 'ADV', 'NUM', 'CONJ', 'DET']:
  count_prem, count_hypo, corrupted_premises, corrupted_hypos = corrupt(pos, data)
  corrupted = [
      {**line, 'premise': prem, 'hypothesis': hypo}
      for line, prem, hypo in zip(data, corrupted_premises, corrupted_hypos)
  ]
  dump_jsonl(corrupted, 'test-' + pos + '.jsonl')

  print(f'{pos},{count_prem},{count_hypo},{count_prem+count_hypo} ')

NOUN,23086,4033,27119 
VERB,11281,2258,13539 
PRON,4152,446,4598 
ADJ,3525,625,4150 
ADV,2898,470,3368 
NUM,1737,286,2023 
CONJ,2073,142,2215 
DET,7167,1406,8573 
