In [1]:
from collections import defaultdict

import datasets
import pyarrow.lib as pylib

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt


# Directory containing the MOCHA split files (train/dev/test JSON), given as a
# path relative to the notebook's working directory.
MOCHA_DIR_PATH = "../../datasets/mocha"
# IPython shell escape: list the directory contents; ``{MOCHA_DIR_PATH}`` is
# interpolated from the Python variable above.
!ls {MOCHA_DIR_PATH}

  from .autonotebook import tqdm as notebook_tqdm


dev.json	    minimal_pairs.json.sha1   train.json
dev.json.sha1	    test_no_labels.json       train.json.sha1
minimal_pairs.json  test_no_labels.json.sha1


```python


{
    'candidate': 'I want to help Luke feed.',
    'context': "There is one area I want to work on . Breast - feeding . Right now , Luke's addicted to the bottle . We were so eager to wean him off his nose tube that when he started taking a bottle , we made it our only goal to re - enforce that .",
    'metadata': {
        'scores': [1, 1, 1],
        'source': 'gpt2',
    },
    'question': 'What may be your reason for wanting to work on Breast - feeding ?',
    'reference': 'It could help my son .',
    'score': 1,
}

```

In [11]:
import json
import spacy
from collections import defaultdict
from statistics import mean
from tqdm import tqdm

# Shared spaCy English pipeline, loaded once with the tagger, parser and NER
# components disabled. In the code below it is only used as ``len(nlp(text))``
# to obtain token counts for the average-length statistics.
nlp = spacy.load("en_core_web_sm", disable=['tagger', 'parser', 'ner'])


### Count statistics
def num_passages(data: dict) -> int:
    """Return how many different ``context`` strings occur in ``data``.

    ``data`` maps an instance id to an instance dict, organized as follows:
    data = {
        'uuid1': {
            'candidate': "He's a child and it's a very rare thing.",
            'context': 'Somewhere in me I knew it all along , ...',
            'metadata': {'scores': [1], 'source': 'gpt2'},
            'question': "What's a possible reason the guy stares into the writer's eyes ?",
            'reference': 'Because he likes her a lot .',
            'score': 1,
        },
        ...,
        'uuidn': {
            'candidate': 'The kitten would have been killed.',
            'context': 'Her dog and another kitten kept trying to escape the house ...',
            'metadata': {'scores': [1], 'source': 'gpt2'},
            'question': "What might be different if the friend didn't give away kittens to homes with dogs ?",
            'reference': "Two of the kittens wouldn't have been killed",
            'score': 1,
        }
    }

    Only the ``context`` field of each instance is read; two instances share a
    passage exactly when their ``context`` strings are equal.
    """
    # A set keeps one copy of every distinct passage, so its size is the count.
    distinct_contexts = {instance['context'] for instance in data.values()}
    return len(distinct_contexts)


def num_ques_ref_pairs(data: dict) -> int:
    """Count distinct <context, question, ref> triples in the provided ``data``.

    We expect data to be organized as indicated in ``num_passages``.

    The three fields are compared as a tuple rather than as one concatenated
    string. Concatenation would merge different triples whose text happens to
    join to the same string (e.g. context ``"ab"`` + question ``"c"`` versus
    context ``"a"`` + question ``"bc"``) and under-count them.
    """
    distinct_triples = {
        (instance['context'], instance['question'], instance['reference'])
        for instance in data.values()
    }
    return len(distinct_triples)


def num_instances(data) -> int:
    """Return the size of ``data``, i.e. one count per stored example."""
    example_count = len(data)
    return example_count


def pct_ref_context_overlap(data) -> float:
    """Fraction of instances whose reference appears verbatim in the context.

    ``data`` maps an instance id to an instance dict with ``reference`` and
    ``context`` strings. An instance counts as overlapping when its
    ``reference`` is a substring of its ``context``.

    Despite the ``pct`` prefix, the result is a fraction in [0, 1] rounded to
    one decimal place (0.3 means roughly 30% of the instances).

    Returns ``nan`` when ``data`` is empty (for example when every instance was
    filtered out upstream) instead of raising ``StatisticsError``.
    """
    if not data:
        return float("nan")

    # ``True`` counts as 1 when summed, so this is the number of overlapping instances.
    num_overlapping = sum(
        instance["reference"] in instance["context"] for instance in data.values()
    )
    return round(num_overlapping / len(data), 1)

def pct_ref_question_overlap(data) -> float:
    """Fraction of instances whose reference appears verbatim in the question.

    ``data`` maps an instance id to an instance dict with ``reference`` and
    ``question`` strings. An instance counts as overlapping when its
    ``reference`` is a substring of its ``question``.

    Despite the ``pct`` prefix, the result is a fraction in [0, 1] rounded to
    one decimal place (0.3 means roughly 30% of the instances).

    Returns ``nan`` when ``data`` is empty (for example when every instance was
    filtered out upstream) instead of raising ``StatisticsError``.
    """
    if not data:
        return float("nan")

    # ``True`` counts as 1 when summed, so this is the number of overlapping instances.
    num_overlapping = sum(
        instance["reference"] in instance["question"] for instance in data.values()
    )
    return round(num_overlapping / len(data), 1)

    
### Average length statistics
def avg_passage_len(data) -> float:
    """Mean token count of the ``context`` field, rounded to one decimal.

    The length of an instance is ``len(nlp(text))``, so every token produced
    by the pipeline is counted (punctuation included).
    """
    token_counts = []
    for instance in data.values():
        passage_doc = nlp(instance['context'])
        token_counts.append(len(passage_doc))
    return round(mean(token_counts), 1)


def avg_question_len(data) -> float:
    """Mean token count of the ``question`` field, rounded to one decimal.

    The length of an instance is ``len(nlp(text))``, so every token produced
    by the pipeline is counted (punctuation included).
    """
    token_counts = []
    for instance in data.values():
        question_doc = nlp(instance['question'])
        token_counts.append(len(question_doc))
    return round(mean(token_counts), 1)


def avg_reference_len(data) -> float:
    """Mean token count of the ``reference`` field, rounded to one decimal.

    The length of an instance is ``len(nlp(text))``, so every token produced
    by the pipeline is counted (punctuation included).
    """
    token_counts = []
    for instance in data.values():
        reference_doc = nlp(instance['reference'])
        token_counts.append(len(reference_doc))
    return round(mean(token_counts), 1)


def avg_candidate_len(data) -> float:
    """Mean token count of the ``candidate`` field, rounded to one decimal.

    The length of an instance is ``len(nlp(text))``, so every token produced
    by the pipeline is counted (punctuation included).
    """
    token_counts = []
    for instance in data.values():
        candidate_doc = nlp(instance['candidate'])
        token_counts.append(len(candidate_doc))
    return round(mean(token_counts), 1)


def avg_candidate_agreement(data) -> float:
    """Average, over all instances, of the mean of ``metadata['scores']``.

    For each instance the values listed under ``metadata['scores']``
    (presumably the individual annotator ratings) are averaged first; the
    result is the mean of those per-instance averages, rounded to one decimal.
    """
    per_instance_means = []
    for instance in data.values():
        raw_scores = instance['metadata']["scores"]
        per_instance_means.append(mean(raw_scores))

    overall_mean = mean(per_instance_means)
    return round(overall_mean, 1)

    
def avg_candidate_score(data) -> float:
    """Mean of the ``score`` field across all instances, rounded to one decimal."""
    collected_scores = []
    for instance in data.values():
        collected_scores.append(instance['score'])
    return round(mean(collected_scores), 1)

def avg_num_annots_ques_ref(data: dict) -> float:
    """Compute avg number of annotations per <context, question, ref> triple
    in the provided ``data``.

    We expect data to be organized as indicated in ``num_passages``.

    Every instance counts as one annotation of its triple. Triples are keyed by
    the tuple of the three fields rather than by their concatenation, which
    would merge different triples whose text joins to the same string (e.g.
    context ``"ab"`` + question ``"c"`` versus context ``"a"`` + question
    ``"bc"``) and inflate the average.

    Returns the mean rounded to one decimal, or ``nan`` when ``data`` is empty
    instead of raising ``StatisticsError``.
    """
    annots_per_triple = {}

    for instance in data.values():
        triple = (instance['context'], instance['question'], instance['reference'])
        annots_per_triple[triple] = annots_per_triple.get(triple, 0) + 1

    if not annots_per_triple:
        return float("nan")

    return round(mean(list(annots_per_triple.values())), 1)


def get_statistics_for_split(file_path, compute_average_lengths=False, agreement_score=None):
    """Compute per-dataset statistics for one split file.

    The JSON file at ``file_path`` is expected to map each constituent dataset
    name to a dict of instances organized as indicated in ``num_passages``.

    ``compute_average_lengths``: when true, also compute the average token
    lengths (passage, question, reference, candidate) and the average
    candidate score / agreement. These tokenize every instance and are slow.

    ``agreement_score``: when truthy, only instances whose ``score`` is
    ``>= agreement_score`` are kept before any statistic is computed. ``None``
    (or 0) disables the filter. Filtering requires every instance to carry a
    ``score`` field.

    Returns a nested mapping ``statistics[dataset][stat_name]``. An extra
    ``'total'`` entry holds the sums of the three count statistics over all
    datasets. A dataset left with no instances after filtering only gets its
    (zero) counts; the averaged statistics are skipped for it because they are
    undefined on empty data.
    """
    # Close the file deterministically; JSON text is UTF-8 by specification.
    with open(file_path, encoding="utf-8") as split_file:
        data = json.load(split_file)
    statistics = defaultdict(lambda: defaultdict(int))

    # Compute statistics per constituent dataset
    for dataset in tqdm(data):
        instances = data[dataset]
        # Filter if agreement score
        if agreement_score:
            instances = {k: v for k, v in instances.items() if v["score"] >= agreement_score}

        # The averages below are undefined for an empty dataset, so they are
        # only computed when the filter left something to average over.
        has_instances = bool(instances)
        dataset_stats = statistics[dataset]
        total_stats = statistics['total']

        # Compute count statistics
        dataset_stats['num_passages'] = num_passages(instances)
        dataset_stats['num_ques_ref_pairs'] = num_ques_ref_pairs(instances)
        dataset_stats['num_instances'] = num_instances(instances)

        if has_instances:
            # Average num_annots
            dataset_stats['avg_annots_per_ques_ref_pair'] = avg_num_annots_ques_ref(instances)
            dataset_stats['pct_ref_cont_overlap'] = pct_ref_context_overlap(instances)
            dataset_stats['pct_ref_ques_overlap'] = pct_ref_question_overlap(instances)

        # Add count statistics to a total field
        for count_name in ('num_passages', 'num_ques_ref_pairs', 'num_instances'):
            total_stats[count_name] += dataset_stats[count_name]

        # Compute average length statistics
        if compute_average_lengths and has_instances:
            dataset_stats['avg_passage_len'] = avg_passage_len(instances)
            dataset_stats['avg_question_len'] = avg_question_len(instances)
            dataset_stats['avg_reference_len'] = avg_reference_len(instances)
            dataset_stats['avg_candidate_len'] = avg_candidate_len(instances)
            dataset_stats['avg_candidate_scores'] = avg_candidate_score(instances)
            dataset_stats['avg_candidate_agreement'] = avg_candidate_agreement(instances)

    return statistics

In [14]:
import pandas as pd 

# Statistics for the train split, including the average-length/score columns.
# ``.T`` turns the {dataset: {stat: value}} mapping into one row per dataset.
df_train = pd.DataFrame(get_statistics_for_split(f'{MOCHA_DIR_PATH}/train.json', compute_average_lengths=True)).T
# Side effect: copy the index-sorted table to the system clipboard.
df_train.sort_index().to_clipboard()
# Last expression of the cell: display the index-sorted table.
df_train.sort_index()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [04:16<00:00, 42.72s/it]


Unnamed: 0,num_passages,num_ques_ref_pairs,num_instances,avg_annots_per_ques_ref_pair,pct_ref_cont_overlap,pct_ref_ques_overlap,avg_passage_len,avg_question_len,avg_reference_len,avg_candidate_len,avg_candidate_scores,avg_candidate_agreement
cosmosqa,1064.0,1139.0,5033.0,4.4,0.0,0.0,72.8,10.8,7.5,8.8,2.2,2.2
drop,80.0,542.0,687.0,1.3,0.6,0.2,218.9,11.6,3.9,5.2,2.2,2.2
mcscript,462.0,2940.0,7210.0,2.5,0.1,0.0,197.1,7.8,4.3,4.1,2.6,2.6
narrativeqa,85.0,2249.0,7471.0,3.3,0.2,0.0,333.1,9.6,5.8,5.9,2.7,2.7
quoref,184.0,1098.0,3259.0,3.0,0.9,0.0,324.2,15.8,2.3,8.2,1.9,1.9
socialiqa,3075.0,3075.0,7409.0,2.4,0.0,0.0,15.7,7.2,3.9,3.9,2.4,2.4
total,4950.0,11043.0,31069.0,,,,,,,,,


In [15]:
# Statistics for the dev split, including the average-length/score columns.
# ``.T`` turns the {dataset: {stat: value}} mapping into one row per dataset.
df_dev = pd.DataFrame(get_statistics_for_split(f'{MOCHA_DIR_PATH}/dev.json', compute_average_lengths=True)).T
# Side effect: copy the index-sorted table to the system clipboard.
df_dev.sort_index().to_clipboard() 
# Last expression of the cell: display the index-sorted table.
df_dev.sort_index()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:31<00:00,  5.21s/it]


Unnamed: 0,num_passages,num_ques_ref_pairs,num_instances,avg_annots_per_ques_ref_pair,pct_ref_cont_overlap,pct_ref_ques_overlap,avg_passage_len,avg_question_len,avg_reference_len,avg_candidate_len,avg_candidate_scores,avg_candidate_agreement
cosmosqa,142.0,156.0,683.0,4.4,0.0,0.0,77.7,10.8,7.3,8.7,2.2,2.2
drop,10.0,76.0,97.0,1.3,0.7,0.3,197.7,12.1,3.6,6.7,2.0,2.0
mcscript,61.0,390.0,978.0,2.5,0.0,0.0,197.6,8.1,4.1,4.0,2.4,2.4
narrativeqa,11.0,277.0,890.0,3.2,0.2,0.0,348.5,9.5,4.9,5.7,2.6,2.6
quoref,24.0,123.0,344.0,2.8,0.9,0.0,336.2,14.4,2.3,8.0,1.9,1.9
socialiqa,414.0,414.0,1017.0,2.5,0.0,0.0,15.5,7.2,3.9,3.9,2.4,2.4
total,662.0,1436.0,4009.0,,,,,,,,,


In [16]:
# Statistics for the unlabeled test split. ``compute_average_lengths`` is left
# at its default (False), so only the count and overlap columns are produced.
df_test = pd.DataFrame(get_statistics_for_split(f'{MOCHA_DIR_PATH}/test_no_labels.json')).T
# Side effect: copy the index-sorted table to the system clipboard.
df_test.sort_index().to_clipboard()
# Last expression of the cell: display the index-sorted table.
df_test.sort_index()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 306.55it/s]


Unnamed: 0,num_passages,num_ques_ref_pairs,num_instances,avg_annots_per_ques_ref_pair,pct_ref_cont_overlap,pct_ref_ques_overlap
cosmosqa,212.0,226.0,1017.0,4.5,0.0,0.0
drop,17.0,117.0,152.0,1.3,0.5,0.3
mcscript,93.0,583.0,1409.0,2.4,0.1,0.0
narrativeqa,18.0,500.0,1707.0,3.4,0.2,0.0
quoref,38.0,180.0,509.0,2.8,0.8,0.0
socialiqa,611.0,611.0,1527.0,2.5,0.0,0.0
total,989.0,2217.0,6321.0,,,


## Analysis `score>=3`

An agreement score of 3 or above implies that the answer is either equivalent to or more correct than the reference. Let us see how the statistics change when we keep only those instances.

In [12]:
import pandas as pd 

# Train-split statistics restricted to instances with ``score >= 3``.
# NOTE: this rebinds ``df_train``, which an earlier cell uses for the
# unfiltered train statistics.
df_train = pd.DataFrame(get_statistics_for_split(f'{MOCHA_DIR_PATH}/train.json', 
                                                 compute_average_lengths=True, 
                                                 agreement_score=3)).T
# Side effect: copy the index-sorted table to the system clipboard.
df_train.sort_index().to_clipboard()
# Last expression of the cell: display the index-sorted table.
df_train.sort_index()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [01:48<00:00, 18.10s/it]


Unnamed: 0,num_passages,num_ques_ref_pairs,num_instances,avg_annots_per_ques_ref_pair,pct_ref_cont_overlap,pct_ref_ques_overlap,avg_passage_len,avg_question_len,avg_reference_len,avg_candidate_len,avg_candidate_scores,avg_candidate_agreement
cosmosqa,917.0,966.0,1752.0,1.8,0.0,0.0,73.1,10.8,7.4,7.5,4.3,4.3
drop,76.0,231.0,271.0,1.2,0.5,0.1,222.0,11.4,3.8,3.3,3.8,3.8
mcscript,455.0,2050.0,3340.0,1.6,0.1,0.0,197.6,7.6,4.1,3.8,4.2,4.2
narrativeqa,85.0,1811.0,3759.0,2.1,0.2,0.0,333.4,9.6,5.6,5.9,4.3,4.3
quoref,172.0,641.0,1072.0,1.7,0.8,0.0,329.1,16.4,2.5,5.6,3.5,3.5
socialiqa,2306.0,2306.0,3001.0,1.3,0.0,0.0,15.6,7.1,3.8,3.7,4.2,4.2
total,4011.0,8005.0,13195.0,,,,,,,,,


In [13]:
# Dev-split statistics restricted to instances with ``score >= 3``.
# NOTE: this rebinds ``df_dev``, which an earlier cell uses for the
# unfiltered dev statistics.
df_dev = pd.DataFrame(get_statistics_for_split(f'{MOCHA_DIR_PATH}/dev.json', compute_average_lengths=True, agreement_score=3)).T
# Side effect: copy the index-sorted table to the system clipboard.
df_dev.sort_index().to_clipboard()
# Last expression of the cell: display the index-sorted table.
df_dev.sort_index()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:11<00:00,  1.92s/it]


Unnamed: 0,num_passages,num_ques_ref_pairs,num_instances,avg_annots_per_ques_ref_pair,pct_ref_cont_overlap,pct_ref_ques_overlap,avg_passage_len,avg_question_len,avg_reference_len,avg_candidate_len,avg_candidate_scores,avg_candidate_agreement
cosmosqa,115.0,121.0,214.0,1.8,0.0,0.0,79.5,10.7,7.5,7.2,4.4,4.4
drop,10.0,24.0,26.0,1.1,0.6,0.4,210.7,12.8,3.1,4.0,4.3,4.3
mcscript,60.0,239.0,363.0,1.5,0.0,0.0,197.3,7.8,3.7,3.5,4.3,4.3
narrativeqa,11.0,205.0,379.0,1.8,0.2,0.0,351.6,9.5,4.7,5.3,4.2,4.2
quoref,21.0,56.0,93.0,1.7,0.8,0.0,341.7,14.8,2.6,4.6,3.5,3.5
socialiqa,303.0,303.0,363.0,1.2,0.0,0.0,15.7,7.1,3.8,3.7,4.2,4.2
total,520.0,948.0,1438.0,,,,,,,,,
