# Sample script to create data for the NLI Data Sanity Check
https://github.com/Helsinki-NLP/nli-data-sanity-check


In [1]:
!pip install datasets



In [2]:
import datasets

In [3]:
# Load the 'multi_nli' dataset through the `datasets` library; the result is
# indexed by split name in the cells below.
nli_data = datasets.load_dataset('multi_nli')

Using custom data configuration default
Reusing dataset multi_nli (/root/.cache/huggingface/datasets/multi_nli/default/0.0.0/3248359997b13e6ccd296f42420b31c107ba6859b742ed6af1dce0f1544c9ec1)


In [4]:
# Bind the splits used by the rest of the notebook.
train_data = nli_data['train']
# NOTE(review): `dev` and `test` both refer to 'validation_matched' --
# confirm that sharing one split for both is intended.
dev = test = nli_data['validation_matched']

In [5]:
# Extract the three columns of the training split (hypothesis, premise, label).
train_hypothesis, train_premise, train_label = (
    train_data[column] for column in ('hypothesis', 'premise', 'label')
)

In [6]:
# Extract the three columns of the test split (hypothesis, premise, label).
test_hypothesis, test_premise, test_label = (
    test[column] for column in ('hypothesis', 'premise', 'label')
)

In [7]:
# One [premise, hypothesis, label] list per test example.
lines = [list(example) for example in zip(test_premise, test_hypothesis, test_label)]

In [8]:
# Fetch the NLTK resources requested here; the later cells call
# nltk.word_tokenize and nltk.pos_tag(..., tagset='universal').
import nltk
nltk.download(['universal_tagset', 'punkt','averaged_perceptron_tagger'])

[nltk_data] Downloading package universal_tagset to /root/nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

### Define the POS and whether to corrupt the test data or train data.

In [9]:
# Part-of-speech tag to strip from the sentences. It is compared against the
# tags returned by nltk.pos_tag(..., tagset='universal') in the cells below.
POS = 'NOUN'
# Which split to corrupt: 'test' selects the test split; any other value
# selects the training split.
dataset = 'test'

In [10]:
# Select the premise/hypothesis sentences of the split chosen via `dataset`.
if dataset == 'test':
  premise = test['premise']
  hypothesis = test['hypothesis']
else:
  # The training split is bound to `train_data`; the previous `train['premise']`
  # referenced an undefined name and raised NameError on this branch.
  premise = train_data['premise']
  hypothesis = train_data['hypothesis']

In [11]:

from tqdm import tqdm
# Tokenise and POS-tag every premise. `tokenized_prem[i]` holds the tokens of
# sentence i and `prem_labels[i]` the universal-tagset tag of each token.
tokenized_prem, prem_labels = [], []
for sentence in tqdm(premise):
  tagged = nltk.pos_tag(nltk.word_tokenize(sentence), tagset='universal')
  tokenized_prem.append([token for token, _ in tagged])
  prem_labels.append([tag for _, tag in tagged])

100%|██████████| 9815/9815 [00:17<00:00, 557.15it/s]


In [12]:
# Tokenise and POS-tag every hypothesis. `tokenized_hypo[i]` holds the tokens
# of sentence i and `hypo_labels[i]` the universal-tagset tag of each token.
tokenized_hypo, hypo_labels = [], []
for sentence in tqdm(hypothesis):
  tagged = nltk.pos_tag(nltk.word_tokenize(sentence), tagset='universal')
  tokenized_hypo.append([token for token, _ in tagged])
  hypo_labels.append([tag for _, tag in tagged])

100%|██████████| 9815/9815 [00:09<00:00, 992.76it/s] 


In [13]:
def _strip_pos(token_lists, tag_lists, pos):
  """Drop every token tagged `pos` from each sentence.

  Returns a pair: the filtered sentences (one token list per input sentence)
  and the total number of tokens that were removed.
  """
  filtered = []
  removed = 0
  for tokens, tags in zip(token_lists, tag_lists):
    kept = []
    for token, tag in zip(tokens, tags):
      if tag == pos:
        removed += 1
      else:
        kept.append(token)
    filtered.append(kept)
  return filtered, removed


# Remove all tokens tagged POS from the hypotheses and the premises, keeping a
# count of how many tokens were dropped on each side.
adj_hypo, count_hypo = _strip_pos(tokenized_hypo, hypo_labels, POS)
adj_prem, count_prem = _strip_pos(tokenized_prem, prem_labels, POS)

print(f'TOKENS REMOVED ({POS}):')
print('prem: ' + str(count_prem))
print('hypo: ' + str(count_hypo))

TOKENS REMOVED (NOUN):
prem: 54700
hypo: 27182


In [14]:
# Write the POS-stripped sentence pairs to a tab-separated file with the
# columns index / sentence1 / sentence2 / gold_label.
filename = 'MNLI-'+ POS + '.tsv'

# The gold labels must come from the same split that was corrupted above.
# Previously `train_label` was always used, so corrupting the test split
# paired test sentences with training labels (zip silently truncated the
# longer list, hiding the mismatch).
labels = test_label if dataset == 'test' else train_label
assert len(adj_prem) == len(adj_hypo) == len(labels), \
    'sentences and labels are misaligned'

# str(label) '0' and '1' map to the names below; every other value is written
# as 'contradiction', exactly as before.
label_names = {'0': 'entailment', '1': 'neutral'}

# Explicit UTF-8 so the output does not depend on the platform default encoding.
with open(filename, 'w', encoding='utf-8') as adjfile:
  adjfile.write('index\tsentence1\tsentence2\tgold_label\n')
  i = 0  # running index over the rows actually written
  for pre, hyp, lab in zip(adj_prem, adj_hypo, labels):
    # Skip pairs where stripping the POS left either sentence empty.
    if len(hyp) == 0 or len(pre) == 0:
      continue
    gold = label_names.get(str(lab), 'contradiction')
    adjfile.write(str(i) + "\t" + ' '.join(pre) + "\t" + ' '.join(hyp) + '\t' + gold + '\n')
    i = i + 1