In [2]:
import json
import matplotlib.pyplot as plt
import operator

In [3]:
def get_mentions_of_cluster(dataset, cluster_id):
    """
    Select the mentions that belong to one coreference chain.

    :param dataset: iterable of mention dicts, each with a 'coref_chain' key
    :param cluster_id: chain id to look for
    :return: list of the mentions whose 'coref_chain' equals cluster_id, in input order
    """
    return [entry for entry in dataset if entry['coref_chain'] == cluster_id]


def get_all_chains(mentions):
    """
    Group mentions by their coreference chain.

    :param mentions: iterable of mention dicts, each with a 'coref_chain' key
    :return: dict mapping each chain id to the list of its mentions (input order is kept)
    """
    by_chain = {}
    for entry in mentions:
        # setdefault creates the empty list the first time a chain id is seen
        by_chain.setdefault(entry['coref_chain'], []).append(entry)

    return by_chain


def get_cluster_by_mention_num(clusters, num):
    """
    Find the clusters that contain exactly `num` mentions.

    :param clusters: dict mapping a cluster id to its collection of mentions
    :param num: required number of mentions
    :return: list of the matching cluster ids, in the dict's iteration order
    """
    return [cluster_id for cluster_id, members in clusters.items() if len(members) == num]


def get_gold_within_doc(mentions):
    """
    Group mentions by (coreference chain, document) pair.

    The group key is the string '<coref_chain>_<doc_id>'.

    :param mentions: iterable of mention dicts with 'coref_chain' and 'doc_id' keys
    :return: dict mapping each '<coref_chain>_<doc_id>' key to the list of its mentions
    """
    per_doc = {}
    for entry in mentions:
        key = entry['coref_chain'] + '_' + entry['doc_id']
        per_doc.setdefault(key, []).append(entry)

    return per_doc



def get_metainfo(clusters):
    """
    Print how many clusters exist for each cluster size.

    One line is printed per distinct size, in increasing order of size.

    :param clusters: dict mapping a cluster id to its collection of mentions
    :return: None
    """
    size_histogram = {}
    for members in clusters.values():
        size = len(members)
        if size in size_histogram:
            size_histogram[size] += 1
        else:
            size_histogram[size] = 1

    for size in sorted(size_histogram):
        print("There are {} clusters with {} mentions".format(size_histogram[size], size))

def extract_mention_text(cluster):
    """
    Collect the text of every mention in a cluster.

    :param cluster: iterable of mention dicts, each with a 'MENTION_TEXT' key
    :return: list of the 'MENTION_TEXT' values, in iteration order
    """
    # Collect into the result list; appending to the loop variable (a mention dict)
    # raised AttributeError as soon as the cluster was non-empty.
    # NOTE(review): other cells read the mention text from 'tokens_str' - confirm that
    # 'MENTION_TEXT' is really present in the loaded mentions.
    return [mention['MENTION_TEXT'] for mention in cluster]


def get_pie_chart(values, labels):
    """
    Display a pie chart of `values` with a legend built from `labels`.

    :param values: wedge sizes passed to plt.pie
    :param labels: legend entries, one per wedge
    :return: None
    """
    # Only the wedge patches are needed for the legend; the text artists are ignored
    wedges, _ = plt.pie(values, shadow=True, startangle=90)
    plt.legend(wedges, labels, loc="best")
    # Equal axis scaling keeps the pie circular
    plt.axis('equal')
    plt.show()


def within_to_cross(within_doc_cluster):
    """
    Regroup within-document cluster ids under their cross-document chain name.

    The chain name is the part of the id before the first '_'. Ids whose name is
    'INTRA' or 'Singleton' are left out.

    :param within_doc_cluster: iterable of within-document cluster ids (e.g. a dict keyed by them)
    :return: dict mapping each chain name to the list of its within-document ids
    """
    grouped = {}
    for within_id in within_doc_cluster:
        prefix = within_id.split('_', 1)[0]
        if prefix == 'INTRA' or prefix == 'Singleton':
            continue
        grouped.setdefault(prefix, []).append(within_id)

    return grouped


def find_most_popular_word(clusters, within_doc_cluster):
    """
    Find the mention text shared by the largest number of the given clusters.

    Each cluster counts a text at most once, however often it repeats inside it.

    :param clusters: collection of within-document cluster ids to inspect
    :param within_doc_cluster: dict mapping a within-document cluster id to its mentions
                               (dicts with a 'MENTION_TEXT' key)
    :return: tuple (text, fraction of `clusters` that contain this text)
    """
    cluster_hits = {}
    for cluster_id in clusters:
        # Distinct texts of this cluster
        distinct_texts = {entry['MENTION_TEXT'] for entry in within_doc_cluster[cluster_id]}
        for text in distinct_texts:
            cluster_hits[text] = cluster_hits.get(text, 0) + 1

    top_text, top_count = max(cluster_hits.items(), key=operator.itemgetter(1))
    return top_text, top_count / len(clusters)




def get_prob(within_doc_cluster):
    """
    Average coverage of the most popular mention text over the cross-document chains.

    For every chain returned by within_to_cross, the coverage reported by
    find_most_popular_word is weighted by the number of within-document clusters
    in that chain.

    :param within_doc_cluster: dict mapping a within-document cluster id to its mentions
    :return: weighted coverage sum divided by the total number of within-document clusters
    """
    weighted_coverage = 0
    total_within = 0
    for within_ids in within_to_cross(within_doc_cluster).values():
        # Only the coverage is needed; the word itself is ignored
        _, coverage = find_most_popular_word(within_ids, within_doc_cluster)
        size = len(within_ids)
        total_within += size
        weighted_coverage += coverage * size

    return weighted_coverage / total_within

def get_distinct_mentions(mentions):
    """
    Return the distinct mention strings.

    :param mentions: iterable of mention dicts, each with a 'tokens_str' key
    :return: list of the unique 'tokens_str' values (order not guaranteed)
    """
    unique_texts = {entry['tokens_str'] for entry in mentions}
    return list(unique_texts)
        

# Select data to explore (ECB+ or MEANTIME)

In [4]:
# Directory holding the gold mention files; the next cell reads
# all_entity_gold_mentions.json and all_event_gold_mentions.json from it.
data = 'ecb_data'

In [5]:
# Load the gold entity and event mentions from the selected data directory.
# The encoding is passed explicitly: without it open() decodes with the locale's
# preferred encoding, which is not UTF-8 on every platform.
with open(data + '/all_entity_gold_mentions.json', 'r', encoding='utf-8') as f:
    entity_mentions = json.load(f)

with open(data + '/all_event_gold_mentions.json', 'r', encoding='utf-8') as f:
    event_mentions = json.load(f)

print('{} entity mentions'.format(len(entity_mentions)))
print('{} event mentions'.format(len(event_mentions)))

8289 entity mentions
6833 event mentions


In [6]:
# Group the gold mentions in two ways:
#   * get_all_chains      -> one group per coreference chain over the whole corpus
#   * get_gold_within_doc -> one group per (coreference chain, document) pair
event_cross_clusters = get_all_chains(event_mentions)
event_within_clusters = get_gold_within_doc(event_mentions)
entity_cross_clusters = get_all_chains(entity_mentions)
entity_within_clusters = get_gold_within_doc(entity_mentions)

# Number of groups produced by each grouping
print('Event Cross chains: {}'.format(len(event_cross_clusters)))
print('Event Within chains: {}'.format(len(event_within_clusters)))

print('Entity Cross chains: {}'.format(len(entity_cross_clusters)))
print('Entity Within chains: {}'.format(len(entity_within_clusters)))

Event Cross chains: 2741
Event Within chains: 5496
Entity Cross chains: 2221
Entity Within chains: 5850


In [7]:
# Count the mentions whose coreference chain id starts with 'Singleton'
entity_singleton = sum(1 for m in entity_mentions if m["coref_chain"].startswith('Singleton'))
event_singleton = sum(1 for m in event_mentions if m["coref_chain"].startswith('Singleton'))
print('Number of entity singleton: {}'.format(entity_singleton))
print('Number of event singleton: {}'.format(event_singleton))

Number of entity singleton: 1231
Number of event singleton: 1775


# Explore dominant mention method

In [8]:
import operator

def extract_dominant_mention(cluster_id, entity=True):
    """
    Find the mention string of a cluster that appears in the largest number of documents.

    :param cluster_id: chain id to look up in entity_cross_clusters (entity=True)
                       or event_cross_clusters (entity=False)
    :param entity: choose the entity clusters (True) or the event clusters (False)
    :return: tuple (dominant string, number of documents containing it,
             number of distinct documents in the cluster, ratio of the two)
    """
    source = entity_cross_clusters if entity else event_cross_clusters
    docs_per_text = {}
    cluster_docs = set()
    for member in source[cluster_id]:
        text = member['tokens_str']
        doc_id = member['doc_id']
        docs_per_text.setdefault(text, set()).add(doc_id)
        cluster_docs.add(doc_id)

    # Number of distinct documents in which each mention string occurs
    doc_counts = {text: len(doc_ids) for text, doc_ids in docs_per_text.items()}

    top_text, top_count = max(doc_counts.items(), key=operator.itemgetter(1))
    total_docs = len(cluster_docs)
    return top_text, top_count, total_docs, top_count / total_docs


def compute_statistics(entity=True, weight_avg=False):
    """
    Summarise how well the dominant mention covers the clusters with more than one mention.

    :param entity: use entity_cross_clusters (True) or event_cross_clusters (False)
    :param weight_avg: if True, weight each cluster's coverage by its number of mentions;
                       otherwise every cluster counts once
    :return: tuple (average coverage of the dominant mention,
             fraction of the considered clusters whose dominant mention appears in
             every document of the cluster,
             ambiguity = 1 - distinct dominant strings / considered clusters)
    """
    source = entity_cross_clusters if entity else event_cross_clusters
    coverage_sum = 0
    weight_sum = 0
    full_coverage = 0
    dominant_texts = []
    for cluster_id, members in source.items():
        # Clusters with a single mention are left out of these statistics
        if len(members) <= 1:
            continue

        dominant, _, _, percentage = extract_dominant_mention(cluster_id, entity)
        dominant_texts.append(dominant)

        weight = len(members) if weight_avg else 1
        coverage_sum += percentage * weight
        weight_sum += weight

        if percentage == 1:
            full_coverage += 1

    # One dominant string was recorded per considered cluster
    considered = len(dominant_texts)
    ambiguity = 1 - len(set(dominant_texts)) / considered
    return coverage_sum / weight_sum, full_coverage / considered, ambiguity


In [9]:
# Try the dominant-mention extraction on one hard-coded cluster id
# (entity defaults to True, so the id is looked up in entity_cross_clusters).
dominant, num, total_doc, percentage = extract_dominant_mention('HUM16236184328979740')
print("Dominant mention: {}".format(dominant))
# num = documents containing the dominant string, total_doc = distinct documents of the cluster
print("Appear in {} documents on {}, percentage: {}".format(num, total_doc, percentage))

Dominant mention: Tara Reid
Appear in 16 documents on 18, percentage: 0.8888888888888888


In [10]:
# Each compute_statistics call returns, for the clusters with more than one mention:
#   (average coverage of the dominant mention,
#    fraction of clusters whose dominant mention appears in every document of the cluster,
#    ambiguity = 1 - distinct dominant strings / number of clusters).
# weight_avg=True weights each cluster's coverage by its number of mentions.
stat_entity, all_docs_entity, ambiguity_entity = compute_statistics(weight_avg=False)
# NOTE(review): 'all_docs_entity_weigh' looks like a typo for 'all_docs_entity_weight';
# the name is kept because other cells may refer to it.
stat_entity_weight, all_docs_entity_weigh, ambiguity_entity_weight = compute_statistics(weight_avg=True)
stat_event, all_docs_event, ambiguity_event = compute_statistics(entity=False, weight_avg=False)
stat_event_weight, all_docs_event_weight, ambiguity_event_weight = compute_statistics(entity=False, weight_avg=True)

# The first two printed numbers are the unweighted and the mention-weighted average coverage
# (ratios between 0 and 1, not document counts); "All docs" is the full-coverage fraction.
print("Entity: \nNumber of document on average: {}, \nWeight average: {}, \nAmbiguity {}\nAll docs: {}".
      format(stat_entity, stat_entity_weight, ambiguity_entity, all_docs_entity))
print()
print("Event: \nNumber of document on average: {}, \nWeight average: {}, \nAmbiguity {}\nAll docs: {}".
      format(stat_event, stat_event_weight, ambiguity_event, all_docs_event))

Entity: 
Number of document on average: 0.7049065438768808, 
Weight average: 0.6646222766185738, 
Ambiguity 0.18354430379746833
All docs: 0.3556962025316456

Event: 
Number of document on average: 0.7175639288110573, 
Weight average: 0.6596520022553044, 
Ambiguity 0.3268698060941828
All docs: 0.36149584487534625


# How many clusters actually span multiple documents

In [13]:
# Keep only the entity chains whose mentions come from more than one document
entity_cross_doc_clusters = {}
for chain, mentions in entity_cross_clusters.items():
    docs = list({m['doc_id'] for m in mentions})
    if len(docs) <= 1:
        continue
    entity_cross_doc_clusters[chain] = {'num_of_docs': len(docs), 'num_of_mentions': len(mentions)}

# Same selection for the event chains
event_cross_doc_clusters = {}
for chain, mentions in event_cross_clusters.items():
    docs = list({m['doc_id'] for m in mentions})
    if len(docs) <= 1:
        continue
    event_cross_doc_clusters[chain] = {'num_of_docs': len(docs), 'num_of_mentions': len(mentions)}

print('Entity - Number of clusters across multiple documents: {} on {}'.
      format(len(entity_cross_doc_clusters), len(entity_cross_clusters)))

print('Event - Number of clusters across multiple documents: {} on {}'.
      format(len(event_cross_doc_clusters), len(event_cross_clusters)))

Entity - Number of clusters across multiple documents: 702 on 2221
Event - Number of clusters across multiple documents: 669 on 2741
