# Sample script to create data for the NLI Data Sanity Check
https://github.com/Helsinki-NLP/nli-data-sanity-check


In [1]:
!pip install datasets



In [2]:
import datasets

In [3]:
# Load MultiNLI through the `datasets` library (a cached copy is reused when
# present). The returned object is indexed by split name in the next cell.
nli_data = datasets.load_dataset('multi_nli')

Using custom data configuration default
Reusing dataset multi_nli (/root/.cache/huggingface/datasets/multi_nli/default/0.0.0/3248359997b13e6ccd296f42420b31c107ba6859b742ed6af1dce0f1544c9ec1)


In [4]:
# Training split of MultiNLI.
train_data = nli_data['train']
# Both `dev` and `test` refer to the matched validation split.
# NOTE(review): `test` is the same split as `dev`; confirm that this is
# intended and that e.g. 'validation_mismatched' was not meant for one of them.
dev = test = nli_data['validation_matched']

In [5]:
# Pull the hypothesis / premise / label columns out of the training split.
train_hypothesis, train_premise, train_label = (
  train_data[column] for column in ('hypothesis', 'premise', 'label')
)

In [6]:
# Pull the hypothesis / premise / label columns out of the dev split.
dev_hypothesis, dev_premise, dev_label = (
  dev[column] for column in ('hypothesis', 'premise', 'label')
)

In [7]:
import nltk
# Resources needed by nltk.word_tokenize and nltk.pos_tag(..., tagset='universal').
# Recent NLTK releases (3.9+) look up the tokenizer and tagger models under the
# names 'punkt_tab' and 'averaged_perceptron_tagger_eng'; the older names are
# kept as well so the cell also works with earlier NLTK versions.
nltk.download([
  'universal_tagset',
  'punkt',
  'punkt_tab',
  'averaged_perceptron_tagger',
  'averaged_perceptron_tagger_eng',
])

[nltk_data] Downloading package universal_tagset to /root/nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

### Define whether to corrupt the evaluation data (`dataset = 'dev'`) or the training data (`dataset = 'train'`).

In [8]:
# Which split to corrupt: 'dev' for the evaluation split, 'train' for the training split.
dataset = 'train'

In [9]:
# Pick the split to corrupt. 'dev' and 'test' both select the held-out split
# stored in `test`; 'train' selects the training split. Any other value is
# rejected instead of silently falling back to the training data.
if dataset in ('dev', 'test'):
  source_split = test
elif dataset == 'train':
  source_split = train_data
else:
  raise ValueError(
    "dataset must be 'dev', 'test' or 'train', got " + repr(dataset)
  )

premise = source_split['premise']
hypothesis = source_split['hypothesis']
goldlabels = source_split['label']

In [10]:

from tqdm import tqdm

# Tokenize and POS-tag every premise with the universal tagset, keeping the
# tokens and their tags in two parallel lists of lists.
tokenized_prem = []
prem_labels = []
for sentence in tqdm(premise):
  tagged = nltk.pos_tag(nltk.word_tokenize(sentence), tagset='universal')
  tokenized_prem.append([token for token, tag in tagged])
  prem_labels.append([tag for token, tag in tagged])

100%|██████████| 392702/392702 [12:11<00:00, 536.77it/s]


In [11]:
# Tokenize and POS-tag every hypothesis with the universal tagset, keeping the
# tokens and their tags in two parallel lists of lists.
tokenized_hypo = []
hypo_labels = []
for sentence in tqdm(hypothesis):
  tagged = nltk.pos_tag(nltk.word_tokenize(sentence), tagset='universal')
  tokenized_hypo.append([token for token, tag in tagged])
  hypo_labels.append([tag for token, tag in tagged])

100%|██████████| 392702/392702 [06:39<00:00, 983.44it/s] 


In [12]:
def corrupt(POS, tokenized_hypo, hypo_labels, tokenized_prem, prem_labels):
  """Remove every token tagged `POS` from the hypotheses and the premises.

  Args:
    POS: the tag to remove; a token is dropped when its tag equals this value.
    tokenized_hypo: list of hypothesis token lists.
    hypo_labels: list of tag lists, parallel to `tokenized_hypo`.
    tokenized_prem: list of premise token lists.
    prem_labels: list of tag lists, parallel to `tokenized_prem`.

  Returns:
    (prem_list, hypo_list): the premises and hypotheses with the matching
    tokens removed. A sentence may become an empty list.

  Prints the number of tokens removed from the premises and from the hypotheses.
  """
  def drop_tagged(sentences, tag_sequences):
    # Filter one collection of sentences; also count the dropped tokens.
    filtered = []
    dropped = 0
    for tokens, tags in zip(sentences, tag_sequences):
      pairs = list(zip(tokens, tags))
      kept = [token for token, tag in pairs if tag != POS]
      dropped += len(pairs) - len(kept)
      filtered.append(kept)
    return filtered, dropped

  hypo_list, count_hypo = drop_tagged(tokenized_hypo, hypo_labels)
  prem_list, count_prem = drop_tagged(tokenized_prem, prem_labels)

  print(f'TOKENS REMOVED ({POS}):')
  print('prem: ' + str(count_prem))
  print('hypo: ' + str(count_hypo))
  return prem_list, hypo_list

In [13]:
def write_file(POS, prem_list, hypo_list, goldlabels):
  """Write the corrupted sentence pairs to 'MNLI-<POS>.tsv' in the working directory.

  Args:
    POS: tag name used to build the output filename.
    prem_list: list of premise token lists.
    hypo_list: list of hypothesis token lists, parallel to `prem_list`.
    goldlabels: label ids parallel to the sentences; 0, 1 and 2 (int or str)
      are written as 'entailment', 'neutral' and 'contradiction'.

  Pairs in which the premise or the hypothesis has no tokens left are skipped.
  The `index` column counts only the rows actually written.

  Raises:
    ValueError: if a pair that would be written carries a label other than
      0, 1 or 2 (for example -1), instead of silently writing it as
      'contradiction'.
  """
  label_names = {'0': 'entailment', '1': 'neutral', '2': 'contradiction'}
  filename = 'MNLI-' + POS + '.tsv'
  # Explicit UTF-8 so non-ASCII tokens are written identically on every
  # platform, independent of the locale's default encoding.
  with open(filename, 'w', encoding='utf-8') as tsv_file:
    tsv_file.write('index\tsentence1\tsentence2\tgold_label\n')
    row_index = 0
    for pre, hyp, lab in zip(prem_list, hypo_list, goldlabels):
      # Skip pairs where corruption removed every token of one side.
      if len(pre) == 0 or len(hyp) == 0:
        continue
      label_key = str(lab)
      if label_key not in label_names:
        raise ValueError('unknown gold label: ' + repr(lab))
      tsv_file.write(
        str(row_index) + '\t' + ' '.join(pre) + '\t' + ' '.join(hyp)
        + '\t' + label_names[label_key] + '\n'
      )
      row_index += 1

In [14]:
# Produce one corrupted TSV file per POS tag listed here.
POS_TAGS = ('NOUN', 'VERB', 'PRON', 'ADJ', 'ADV', 'CONJ', 'NUM', 'DET')
for pos in POS_TAGS:
  prem_list, hypo_list = corrupt(
    POS=pos,
    tokenized_hypo=tokenized_hypo,
    hypo_labels=hypo_labels,
    tokenized_prem=tokenized_prem,
    prem_labels=prem_labels,
  )
  write_file(POS=pos, prem_list=prem_list, hypo_list=hypo_list, goldlabels=goldlabels)

TOKENS REMOVED (NOUN):
prem: 2228780
hypo: 1090814
TOKENS REMOVED (VERB):
prem: 1474454
hypo: 886597
TOKENS REMOVED (PRON):
prem: 543968
hypo: 301293
TOKENS REMOVED (ADJ):
prem: 677095
hypo: 302652
TOKENS REMOVED (ADV):
prem: 492895
hypo: 237250
TOKENS REMOVED (CONJ):
prem: 320210
hypo: 76466
TOKENS REMOVED (NUM):
prem: 119587
hypo: 44289
TOKENS REMOVED (DET):
prem: 886966
hypo: 483238
