In [None]:
import os.path as osp
from build.config import STORAGE_FOLDER

In [None]:
def human_format(num):
    """Format a number compactly with a K/M/B/T suffix (e.g. 12345 -> '12.3K').

    The value is first rounded to three significant digits. Magnitudes below
    10,000 are returned as a plain integer string (so 1234 -> '1230'); larger
    magnitudes are divided by 1000 until they drop below 1000 and receive the
    matching suffix.

    Args:
        num: int or float to format.

    Returns:
        str: the abbreviated representation.
    """
    suffixes = ['', 'K', 'M', 'B', 'T']
    num = float('{:.3g}'.format(num))
    # Test the magnitude, not the signed value, so large negatives are abbreviated too
    if abs(num) < 10000:
        return str(int(num))
    magnitude = 0
    # Stop at the largest known suffix instead of indexing past the end of the list
    while abs(num) >= 1000 and magnitude < len(suffixes) - 1:
        magnitude += 1
        num /= 1000.0
    # '{:f}' always prints six decimals; strip the trailing zeros and a dangling dot
    return '{}{}'.format('{:f}'.format(num).rstrip('0').rstrip('.'), suffixes[magnitude])

# Wikidata Dump Index Distribution

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import json

import pandas as pd

# Load the dump index: a JSON object mapping a dump identifier to its download URL
index_path = osp.join(STORAGE_FOLDER, 'wikidata_dumps_index.json')
with open(index_path) as f:
    idx = json.load(f)

# One (date, source) row per dump. Identifiers are sliced as YYYYMMDD
# (assumes 8-character keys — TODO confirm) and the source is derived from the URL prefix.
rows = []
for dump_id, url in idx.items():
    date_str = '-'.join((dump_id[:4], dump_id[4:6], dump_id[-2:]))
    if url.startswith('https://dumps'):
        source = 'Wikidata Repository'
    else:
        source = 'Internet Archive'
    rows.append((date_str, source))

df = pd.DataFrame(rows, columns=['Date', 'Source'])
df['Date'] = df['Date'].astype('datetime64[ns]')

# Histogram of dump dates, split by source, saved next to the notebook
plt.figure(dpi=400, figsize=(8,2.5))
sns.histplot(df, x='Date', hue='Source', bins=50)
plt.savefig('wikidata_dump_index.png', bbox_inches='tight')
plt.close()

# Type distribution Wikidata

In [None]:
from build.utils.wd import get_info_wikidata
import json
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd

def get_stats(version : str, top_k = 20, post : bool = False):
    """Return the ``top_k`` most frequent entity types of a dump plus an 'others' row.

    Args:
        version: dump identifier interpolated into the statistics file name
            (this notebook uses 'old' and 'new').
        top_k: number of individual types to keep; the counts of all remaining
            types are summed into a final 'others' row.
        post: if False, read the statistics of the raw dump
            (``process_json_dump_wikidata_<version>_json.json``); if True, read
            the post-preprocessing statistics (``preprocess_dump_<version>.json``).

    Returns:
        tuple: ``(stats, n_entities)`` where ``stats`` is a DataFrame with the
        columns 'Type' and 'Count' (most frequent first, 'others' last) and
        ``n_entities`` is the entity count stored in the statistics file.
    """
    if post:
        stats_file = 'resources/script_stats/preprocess_dump_%s.json' % version
        n_entities_key, type_count_key = 'post_n_wiki_entities', 'post_type_count'
    else:
        stats_file = 'resources/script_stats/process_json_dump_wikidata_%s_json.json' % version
        n_entities_key, type_count_key = 'n_entities', 'type_count'

    # Context manager so the file handle is closed as soon as the JSON is parsed
    with open(osp.join(STORAGE_FOLDER, stats_file)) as f:
        j_dict = json.load(f)
    n_entities = j_dict[n_entities_key]

    # Temporary solution: drop the "none" placeholder before ranking
    type_count = [(k, v) for k, v in j_dict[type_count_key].items() if k != "none"]
    ranked = sorted(type_count, key=lambda x : x[1], reverse=True)

    top_k_type_count = ranked[:top_k]
    others_count = sum(count for _, count in ranked[top_k:])

    # Only real type ids are looked up; the synthetic 'others' row needs no label
    top_ids = [type_id for type_id, _ in top_k_type_count]
    type2label = {k: v.get('name', k) for k, v in get_info_wikidata(top_ids, version=version).items()}

    # Fall back to the raw id when the lookup returned nothing for it
    stats = [(type2label.get(type_id, type_id), count) for type_id, count in top_k_type_count]
    stats.append(('others', others_count))
    stats = pd.DataFrame(data=stats, columns=['Type', 'Count'])
    return stats, n_entities

# Number of individual types per pie; everything else is merged into 'others'
top_k = 6
new_stats, pre_n_entities_new = get_stats('new', top_k=top_k)
old_stats, pre_n_entities_old = get_stats('old', top_k=top_k)
post_new_stats, post_n_entities_new = get_stats('new', post=True, top_k=top_k)
post_old_stats, post_n_entities_old = get_stats('old', post=True, top_k=top_k)

# 2x2 grid of pie subplots
fig = make_subplots(rows=2, cols=2, specs=[[{"type": "pie"} for _ in range(2)] for _ in range(2)])

# One entry per pie:
# (stats, title template, entity count, x domain, y domain, row, col, extra trace kwargs)
pie_panels = [
    (old_stats, 'Old Wikidata before preprocessing<br>(%s entities)', pre_n_entities_old,
     [0, 0.5], [0, 0.5], 1, 1, {'name': 'haha'}),
    (post_old_stats, 'Old Wikidata after preprocessing<br>(%s entities)', post_n_entities_old,
     [0.5, 1], [0, 0.5], 1, 2, {}),
    (new_stats, 'New Wikidata before preprocessing<br>(%s entities)', pre_n_entities_new,
     [0, 0.5], [0.5, 1], 2, 1, {}),
    (post_new_stats, 'New Wikidata after preprocessing<br>(%s entities)', post_n_entities_new,
     [0.5, 1], [0.5, 1], 2, 2, {}),
]
for panel_stats, title_fmt, n_ent, x_dom, y_dom, panel_row, panel_col, extra in pie_panels:
    fig.add_trace(
        go.Pie(values=panel_stats['Count'], labels=panel_stats['Type'],
               title=title_fmt % human_format(n_ent),
               domain=dict(x=x_dom, y=y_dom), **extra),
        row=panel_row, col=panel_col,
    )

fig.update_layout(margin={'l': 5, 'r': 5, 't': 5, 'b': 5})
fig.show()
fig.write_image('./wikidata_type_distribution.png', scale=4)

In [None]:
# Side-by-side table of the 'new' dump statistics before (new_stats) and after (post_new_stats) preprocessing
pd.concat([new_stats, post_new_stats], axis=1)

# Most popular relations

In [None]:
from build.utils.wd import get_info_wikidata
import json
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd

def get_stats(version : str, top_k = 20, post : bool = False):
    """Return the ``top_k`` most frequent relations of a dump.

    Note: this redefines the ``get_stats`` declared in an earlier cell; from
    here on the name refers to the relation variant.

    Args:
        version: dump identifier interpolated into the statistics file name
            (this notebook uses 'old' and 'new').
        top_k: number of relations to keep (no 'others' row is added).
        post: if False, read the statistics of the raw dump
            (``process_json_dump_wikidata_<version>_json.json``); if True, read
            the post-preprocessing statistics (``preprocess_dump_<version>.json``).

    Returns:
        tuple: ``(stats, n_relations)`` where ``stats`` is a DataFrame with the
        columns 'Relation' (label truncated to 20 characters, followed by '...'
        when it was longer) and 'Count', most frequent first, and
        ``n_relations`` is the number of distinct relations in the file.
    """
    if post:
        stats_file = 'resources/script_stats/preprocess_dump_%s.json' % version
        relation_count_key = 'post_relation_count'
    else:
        stats_file = 'resources/script_stats/process_json_dump_wikidata_%s_json.json' % version
        relation_count_key = 'relation_count'

    # Context manager so the file handle is closed as soon as the JSON is parsed
    with open(osp.join(STORAGE_FOLDER, stats_file)) as f:
        j_dict = json.load(f)
    relation_count = j_dict[relation_count_key]
    n_relations = len(relation_count)

    top_k_relation_count = sorted(relation_count.items(), key=lambda x : x[1], reverse=True)[:top_k]
    type2label = {k: v.get('name', k) for k, v in get_info_wikidata([x[0] for x in top_k_relation_count], version=version).items()}

    # Fall back to the raw id when the lookup returned nothing for it
    stats = [(type2label.get(rel_id, rel_id), count) for rel_id, count in top_k_relation_count]
    stats = pd.DataFrame(data=stats, columns=['Relation', 'Count'])
    # Keep axis labels short: at most 20 characters, with an ellipsis when truncated
    stats['Relation'] = stats['Relation'].apply(lambda x : x[:20] + ('...' if len(x) > 20 else ''))
    return stats, n_relations

# Number of relations shown in each bar chart
top_k = 15
new_stats, pre_n_relations_new = get_stats('new', top_k=top_k)
old_stats, pre_n_relations_old = get_stats('old', top_k=top_k)
post_new_stats, post_n_relations_new = get_stats('new', post=True, top_k=top_k)
post_old_stats, post_n_relations_old = get_stats('old', post=True, top_k=top_k)

# One entry per panel, in subplot order (row by row):
# (stats, title template, relation count, bar colour, row, col)
bar_panels = [
    (old_stats, 'Old Wikidata before preprocessing<br>(%s relations)', pre_n_relations_old, '#636EFA', 1, 1),
    (post_old_stats, 'Old Wikidata after preprocessing<br>(%s relations)', post_n_relations_old, '#636EFA', 1, 2),
    (new_stats, 'New Wikidata before preprocessing<br>(%s relations)', pre_n_relations_new, '#00CC96', 2, 1),
    (post_new_stats, 'New Wikidata after preprocessing<br>(%s relations)', post_n_relations_new, '#00CC96', 2, 2),
]

panel_titles = tuple(title_fmt % human_format(n_rel) for _, title_fmt, n_rel, _, _, _ in bar_panels)
fig = make_subplots(rows=2, cols=2, subplot_titles=panel_titles)

for panel_stats, _, _, bar_color, panel_row, panel_col in bar_panels:
    fig.add_trace(
        go.Bar(y=panel_stats['Count'], x=panel_stats['Relation'], marker_color=bar_color),
        row=panel_row, col=panel_col,
    )

# Leave room at the top for the subplot titles and hide the legend
fig.update_layout(margin={'l': 5, 'r': 5, 't': 50, 'b': 5}, showlegend=False)
fig.show()
fig.write_image('./wikidata_relation_distribution.png', scale=4, width=1500, height=630)

# Filtering process visualization

In [None]:
#### IMPORTANT : install graphviz for this cell to work using this command "sudo apt install graphviz"

import json
from build.utils.wd import db
from tree_plot import tree_plot

def generate_filtering_process_image(version : str):
    """Draw the successive filtering steps applied to one Wikidata dump as a tree image.

    Reads the two statistics files stored for ``version`` under
    ``STORAGE_FOLDER/resources/script_stats``, builds one "state" node per
    pipeline stage (remaining entities / groups / triples) and one "how many"
    node per step describing what that step removed, renders the tree with
    ``tree_plot`` and writes it to ``filtering_process_<version>.png`` in the
    current working directory.

    Node labels carry a 3-character prefix that ``transform_node`` strips
    before rendering; its second character selects the colour ('R' nodes are
    recoloured, 'B' nodes keep the default style).

    Args:
        version: dump identifier interpolated into the statistics file names
            (this notebook calls it with 'old' and 'new').

    Returns:
        None.

    Raises:
        KeyError: if an expected counter is missing from a statistics file.
    """
    def load_stats(relative_path):
        # Context manager: the handle is closed as soon as the JSON is parsed
        with open(osp.join(STORAGE_FOLDER, relative_path)) as f:
            return json.load(f)

    def pct(part, whole):
        # Share of `whole` represented by `part`, in percent
        return part / whole * 100

    # Loading stats
    stats = load_stats('resources/script_stats/preprocess_dump_%s.json' % version)
    stats0 = load_stats('resources/script_stats/process_json_dump_wikidata_%s_json.json' % version)

    # Counters of the raw dump
    n_entities0 = stats0['n_entities']
    n_groups0 = stats0['n_groups']
    n_triples0 = stats0['n_triples']
    n_relations0 = len(stats0['relation_count'])
    n_triples_banned_types = stats0['n_triples_banned_types']
    n_triples_rank_deprecated = stats0['n_triples_deprecated']
    n_group_post_deprecated_banned = stats0['n_groups_post']

    # Counters of the preprocessing run
    n_triples = stats['n_triples']
    n_irrelevant_entities = stats['n_irrelevant_entities']
    n_irrelevant_entities_triples = stats['n_irrelevant_entities_triples']
    n_novalue_triples = stats['n_novalue']
    n_somevalue_triples = stats['n_somevalue']
    # NOTE(review): the triple count of the correction step reads 'n_corrections',
    # the only correction counter used here — confirm it counts triples.
    n_corrections_triples = stats['n_corrections']
    n_temp_rel_pit_deprecation_triples = stats['n_temp_rel_pit_deprecation_all']
    n_triples_filtered_because_of_meta = stats['n_triples_filtered_because_of_meta']
    n_groups_filtered_because_of_meta = sum(stats['meta_relations_count'].values())
    n_restricted_triples = stats['n_restricted_triples']
    n_wiki_entities = stats['n_wiki_entities']
    n_relevant_entities = stats['n_relevant_entities']
    n_relevant_triples = stats['n_relevant_entities_triples']
    n_relevant_groups = stats['n_relevant_entities_groups']
    n_novalue_somevalue_group = stats['n_novalue_somevalue_group']
    n_restricted_group_removed = stats['n_restricted_group_removed']
    post_n_entities = stats['post_n_wiki_entities']
    post_n_triples = stats['n_triples_postproc']
    post_n_relations = len(stats['post_relation_count'])
    post_n_groups = sum(stats['post_relation_count'].values())
    n_empty_entities = stats['n_empty_entities']

    # Defining each node of the tree

    state0 = "#B#Entities = %s\nRelations = %s\nGroups = %s\nTriples = %s" % (human_format(n_entities0), human_format(n_relations0), human_format(n_groups0), human_format(n_triples0))

    # Step 1: banned types and deprecated rank
    n_removed_groups_banned = n_groups0 - n_group_post_deprecated_banned
    n_removed_triples_banned = n_triples_banned_types + n_triples_rank_deprecated
    n_triples_after_banned = n_triples0 - n_triples_banned_types - n_triples_rank_deprecated

    state_banned_deprecated = '#B#Entities = %s\nGroups = %s\nTriples = %s' % (human_format(n_entities0), human_format(n_group_post_deprecated_banned), human_format(n_triples_after_banned))
    how_many_banned_deprecated = '<R<<<u>Remove banned types<br></br>and rank deprecated</u><br></br>Groups = %s (-%.1f%%)<br></br>Triples = %s (-%.1f%%)>' % (human_format(n_removed_groups_banned), pct(n_removed_groups_banned, n_groups0), human_format(n_removed_triples_banned), pct(n_removed_triples_banned, n_triples0))

    # Step 2: keep only relevant Wikipedia entities
    n_entities_relevant_wiki = n_wiki_entities - n_irrelevant_entities
    n_triples_relevant_wiki = n_triples - n_irrelevant_entities_triples
    n_removed_entities_wiki = n_entities0 - n_entities_relevant_wiki
    n_removed_groups_wiki = n_group_post_deprecated_banned - n_relevant_groups
    n_removed_triples_wiki = n_triples_after_banned - n_triples_relevant_wiki

    state_relevant_wikipedia = '#B#Entities = %s\nGroups = %s\nTriples = %s' % (human_format(n_entities_relevant_wiki), human_format(n_relevant_groups), human_format(n_triples_relevant_wiki))
    how_many_irrelevant_wikipedia = '<R<<<u>Keep only relevant<br></br>Wikipedia entities</u><br></br>Entities = %s (-%.1f%%)<br></br>Groups = %s (-%.1f%%)<br></br>Triples = %s (-%.1f%%)>' % (human_format(n_removed_entities_wiki), pct(n_removed_entities_wiki, n_entities0), human_format(n_removed_groups_wiki), pct(n_removed_groups_wiki, n_group_post_deprecated_banned), human_format(n_removed_triples_wiki), pct(n_removed_triples_wiki, n_triples_after_banned))

    # Step 3: meta triples
    n_triples_after_meta = n_relevant_triples - n_triples_filtered_because_of_meta
    n_groups_after_meta = n_relevant_groups - n_groups_filtered_because_of_meta

    state_after_meta_removal = '#B#Entities = %s\nGroups = %s\nTriples = %s' % (human_format(n_relevant_entities), human_format(n_groups_after_meta), human_format(n_triples_after_meta))
    # The "Triples" figure shows the removed triple count (it used to repeat the group count)
    how_many_meta = '<R<<<u>Remove meta triples</u><br></br>Groups = %s (-%.1f%%)<br></br>Triples = %s (-%.1f%%)>' % (human_format(n_groups_filtered_because_of_meta), pct(n_groups_filtered_because_of_meta, n_relevant_groups), human_format(n_triples_filtered_because_of_meta), pct(n_triples_filtered_because_of_meta, n_relevant_triples))

    # Step 4: special values (novalue / somevalue)
    n_special_triples = n_novalue_triples + n_somevalue_triples
    n_triples_after_special = n_triples_after_meta - n_special_triples
    n_groups_after_special = n_groups_after_meta - n_novalue_somevalue_group

    state_after_special_value = '#B#Entities = %s\nGroups = %s\nTriples = %s' % (human_format(n_relevant_entities), human_format(n_groups_after_special), human_format(n_triples_after_special))
    how_many_special = '<R<<<u>Remove special values</u><br></br>Groups = %s (-%.1f%%)<br></br>Triples = %s (-%.1f%%)>' % (human_format(n_novalue_somevalue_group), pct(n_novalue_somevalue_group, n_groups_after_meta), human_format(n_special_triples), pct(n_special_triples, n_triples_after_meta))

    # Step 5: restricted triples
    n_triples_after_restricted = n_triples_after_special - n_restricted_triples
    n_groups_after_restricted = n_groups_after_special - n_restricted_group_removed

    state_after_restricted_removal = '#B#Entities = %s\nGroups = %s\nTriples = %s' % (human_format(n_relevant_entities), human_format(n_groups_after_restricted), human_format(n_triples_after_restricted))
    how_many_restricted = '<R<<<u>Remove restricted triples</u><br></br>Groups = %s (-%.1f%%)<br></br>Triples = %s (-%.1f%%)>' % (human_format(n_restricted_group_removed), pct(n_restricted_group_removed, n_groups_after_special), human_format(n_restricted_triples), pct(n_restricted_triples, n_triples_after_special))

    # Step 6: keep the most recent value of temporal functional relations
    n_triples_after_pit = n_triples_after_restricted - n_temp_rel_pit_deprecation_triples
    n_groups_after_pit = n_groups_after_restricted

    state_after_pit = '#B#Entities = %s\nGroups = %s\nTriples = %s' % (human_format(n_relevant_entities), human_format(n_groups_after_restricted), human_format(n_triples_after_pit))
    # The percentage is computed from this step's own removed triples (it used to reuse the restricted count)
    how_many_pit = '<R<<<u>Keep most recent value<br></br>(temporal functional relations)</u><br></br>Triples = %s (-%.1f%%)>' % (human_format(n_temp_rel_pit_deprecation_triples), pct(n_temp_rel_pit_deprecation_triples, n_triples_after_restricted))

    # Step 7: corrections
    state_after_corrections = '<B<<Entities = %s<br></br>Groups = %s<br></br>Triples = %s>' % (human_format(n_relevant_entities), human_format(n_groups_after_pit), human_format(n_triples_after_pit-n_corrections_triples))
    how_many_corrections = '<R<<<u>Correct errors in<br></br>temporal functional relation</u><br></br>Triples = %s (-%.1f%%)>' % (human_format(n_corrections_triples), pct(n_corrections_triples, n_triples_after_pit))

    # Step 8: final state, taken directly from the post-processing counters
    state_after_empty_removal = '#B#Entities = %s\nRelations = %s\nGroups = %s\nTriples = %s' % (human_format(post_n_entities), human_format(post_n_relations), human_format(post_n_groups), human_format(post_n_triples))
    # NOTE(review): the title says "relations" but the figure reported is an entity count — confirm the wording
    how_many_empty = '<R<<<u>Remove empty relations</u><br></br>Entities = %s (-%.1f%%)>' % (human_format(n_empty_entities), pct(n_empty_entities, n_relevant_entities))

    # Defining the structure of the tree

    tree = {
        state0: [
            {
                state_banned_deprecated : [
                    {state_relevant_wikipedia : [
                        {state_after_meta_removal : [
                            {state_after_special_value : [
                                {state_after_restricted_removal : [
                                    {state_after_pit : [{state_after_corrections : [state_after_empty_removal, how_many_empty]}, 
                                                        how_many_corrections]},
                                    how_many_pit
                                ]},
                                how_many_restricted
                            ]},
                            how_many_special
                        ]},
                        how_many_meta
                    ]},
                    how_many_irrelevant_wikipedia
                ]},
            how_many_banned_deprecated
        ]
    }

    # Plotting the tree
    def transform_node(node):
        # Strip the 3-character style prefix; its second character is the colour code
        name = node.get_label()
        color = name[1]
        name = name[3:]
        node.set_label(name)
        if color == 'R':
            node.set_color('#e74c3c')
            node.set_fillcolor("#f9c4be")

    weight_order = [5,5,5,5,5,5,5,5] + [0]*50
    graph = tree_plot(tree, transform_node, weight_order)

    graph.write('filtering_process_%s.png' % version, format='png')

# Render the filtering funnel of both dumps, 'old' first
for dump_version in ('old', 'new'):
    generate_filtering_process_image(dump_version)

# Proportion of new/both/old in entities/groups/triples

In [None]:
from build.utils.wd import get_info_wikidata
import json
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

# Load the diff statistics; the context manager closes the file once it is parsed
with open(osp.join(STORAGE_FOLDER, 'resources/script_stats/compute_diff.json')) as f:
    j_dict = json.load(f)

color = ['New', 'Old', 'Both']
x = ['Entities', 'Groups', 'Triples']
# Flat tuple ordered panel by panel: (new, old, both) for entities, then groups, then triples
y = tuple(
    j_dict['n_%s_%s' % (origin, unit)]
    for unit in ('entities', 'group', 'triples')
    for origin in ('new', 'old', 'both')
)
fig = make_subplots(rows=1, cols=3, subplot_titles=list(x))

# One bar chart per unit, each showing its New / Old / Both counts
for panel_idx, panel_name in enumerate(x):
    fig.add_trace(
        go.Bar(x=color, y=y[3 * panel_idx:3 * panel_idx + 3], name=panel_name),
        row=1, col=panel_idx + 1,
    )

# Same three colours in every panel so New / Old / Both are comparable across panels
for trace in fig.data:
    trace.marker.color = px.colors.qualitative.Plotly[:3]

fig.update_layout(showlegend=False, margin=dict(l=5, r=5, t=30, b=5))

fig.show()
fig.write_image('./wikidata_diff_novelty.png', scale=4, height=150, width=600)

# Proportion of physically new entities in new entities

In [None]:
import json
import plotly.graph_objects as go
import plotly.express as px

# Load the diff statistics; the context manager closes the file once it is parsed
with open(osp.join(STORAGE_FOLDER,'resources/script_stats/compute_diff.json')) as f:
    j_dict = json.load(f)

# Display names for the categories stored under 'physicality_count'
rename = {
    'new' : 'Novel',
    'old' : 'Not novel',
    'unk' : 'Unknown'
}
physicality_count = j_dict['physicality_count']
# Unknown categories keep their raw key instead of raising; an empty mapping yields an empty chart
x = [rename.get(category, category) for category in physicality_count]
y = list(physicality_count.values())

fig = go.Figure()
fig.add_trace(go.Bar(x=x,y=y))
fig.update_layout(showlegend=False, margin=dict(l=5, r=5, t=5, b=5))
fig.show()
fig.write_image('./physicality_stats.png', scale=4, height=100, width=300)

# Datatypes before and after preprocessing

In [None]:
from build.utils.wd import get_info_wikidata
import json
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd

def get_stats(version : str, post : bool = False, top_k=5):
    """Return the ``top_k`` most frequent datatypes of a dump plus an 'others' row.

    Note: this redefines the ``get_stats`` declared in earlier cells; from here
    on the name refers to the datatype variant.

    Args:
        version: dump identifier interpolated into the statistics file name
            (this notebook uses 'old' and 'new').
        post: if False, read the statistics of the raw dump, where datatype
            counters are top-level keys starting with 'datatype_'; if True,
            read the 'post_datatype_count' mapping of the preprocessing
            statistics.
        top_k: number of individual datatypes to keep; the counts of all
            remaining ones are summed into a final 'others' row.

    Returns:
        tuple: ``(stats, n_triples)`` where ``stats`` is a DataFrame with the
        columns 'Data type' and 'Count' (most frequent first, 'others' last)
        and ``n_triples`` is the triple count stored in the statistics file.
    """
    if post:
        stats_file = 'resources/script_stats/preprocess_dump_%s.json' % version
    else:
        stats_file = 'resources/script_stats/process_json_dump_wikidata_%s_json.json' % version

    # Context manager so the file handle is closed as soon as the JSON is parsed
    with open(osp.join(STORAGE_FOLDER, stats_file)) as f:
        j_dict = json.load(f)

    if post:
        n_triples = j_dict['n_triples_postproc']
        datatype_count = list(j_dict['post_datatype_count'].items())
    else:
        n_triples = j_dict['n_triples']
        # NOTE(review): keys are matched on the 9-character prefix 'datatype_' but
        # 11 characters are stripped — presumably the real prefix is two characters
        # longer; confirm against the statistics file.
        datatype_count = [(x[11:], y) for x, y in j_dict.items() if x.startswith('datatype_')]

    ranked = sorted(datatype_count, key=lambda x : x[1], reverse=True)
    stats = ranked[:top_k]
    stats.append(('others', sum(x[1] for x in ranked[top_k:])))
    stats = pd.DataFrame(data=stats, columns=['Data type', 'Count'])
    return stats, n_triples

# NOTE(review): top_k is assigned but not forwarded below, so get_stats runs with
# its own default — confirm whether top_k=top_k was intended.
top_k = 6
new_stats, pre_n_triples_new = get_stats('new')
old_stats, pre_n_triples_old = get_stats('old')
post_new_stats, post_n_triples_new = get_stats('new', post=True)
post_old_stats, post_n_triples_old = get_stats('old', post=True)

# 2x2 grid of pie subplots
fig = make_subplots(rows=2, cols=2, specs=[[{"type": "pie"} for _ in range(2)] for _ in range(2)])

# One entry per pie:
# (stats, title template, triple count, x domain, y domain, row, col, extra trace kwargs)
pie_panels = [
    (old_stats, 'Old Wikidata before preprocessing<br>(%s triples)', pre_n_triples_old,
     [0, 0.5], [0, 0.5], 1, 1, {'name': 'haha'}),
    (post_old_stats, 'Old Wikidata after preprocessing<br>(%s triples)', post_n_triples_old,
     [0.5, 1], [0, 0.5], 1, 2, {}),
    (new_stats, 'New Wikidata before preprocessing<br>(%s triples)', pre_n_triples_new,
     [0, 0.5], [0.5, 1], 2, 1, {}),
    (post_new_stats, 'New Wikidata after preprocessing<br>(%s triples)', post_n_triples_new,
     [0.5, 1], [0.5, 1], 2, 2, {}),
]
for panel_stats, title_fmt, n_trip, x_dom, y_dom, panel_row, panel_col, extra in pie_panels:
    fig.add_trace(
        go.Pie(values=panel_stats['Count'], labels=panel_stats['Data type'],
               title=title_fmt % human_format(n_trip),
               domain=dict(x=x_dom, y=y_dom), **extra),
        row=panel_row, col=panel_col,
    )

fig.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0})
fig.show()
fig.write_image('./wikidata_datatype_distribution.png', scale=4)

# WFD Data analysis

In [None]:
from build.verbalize_wikifactdiff.verbalize_wikifactdiff import SAVE_PATH
import json
import plotly.express as px
import plotly.graph_objects as go
from collections import Counter
import os

# Create the output folder for the WFD figures; exist_ok makes re-running the cell harmless
os.makedirs('./wfd_analysis', exist_ok=True)

In [None]:
# Read the dataset (one JSON object per line). The context manager closes the file
# and blank lines are skipped so a trailing newline cannot break json.loads.
with open(osp.join(STORAGE_FOLDER, "wikifactdiff.jsonl")) as f:
    dataset = [json.loads(line) for line in f if line.strip()]

In [None]:
# General relation distribution
relations = [example['relation']['label'] for example in dataset]
c = Counter(relations)
top_relations = [label for label, _ in c.most_common(20)]
# Everything outside the 20 most frequent labels is merged into a single 'other' bucket
relations = ['other' if label not in top_relations else label for label in relations]
# Most frequent first; 'other' is normally absent from `c`, so it counts as 0 and ends up last
relations = sorted(relations, key=c.__getitem__, reverse=True)

fig = go.Figure()
fig.add_trace(go.Histogram(x=relations))
fig.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0})
fig.show()
fig.write_image('./wfd_analysis/general_relation_distribution.png', scale=4)

In [None]:
# Physically new entities relation distribution
relations = [example['relation']['label'] for example in dataset if example['subject_is_ph_new']]
c = Counter(relations)
# Labels of the 20 most frequent relations, most frequent first
top_relations = list(dict(c.most_common(20)))
# Rare relations are merged into a single 'other' bucket
relations = [label if label in top_relations else 'other' for label in relations]
# Order the values by descending frequency so the bars come out sorted
relations = sorted(relations, key=lambda label: c[label], reverse=True)

relation_histogram = go.Histogram(x=relations)
fig = go.Figure(data=[relation_histogram])
fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
fig.show()
fig.write_image('./wfd_analysis/ph_new_relation_distribution.png', scale=4)

In [None]:
# Physically old entities relation distribution
relations = [example['relation']['label'] for example in dataset if not example['subject_is_ph_new']]
c = Counter(relations)
top_relations = [pair[0] for pair in c.most_common(20)]
# Replace every label outside the top 20 with 'other'
relations = ['other' if label not in top_relations else label for label in relations]
# Descending frequency; ties keep their original relative order
relations = sorted(relations, key=c.__getitem__, reverse=True)

fig = go.Figure()
fig.add_trace(go.Histogram(x=relations))
fig.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0})
fig.show()
fig.write_image('./wfd_analysis/ph_old_relation_distribution.png', scale=4)

In [None]:
# Replacement update relation distribution
def is_replace(x):
    """Tell whether an entry is a replacement update.

    An entry qualifies when the decisions of its objects are exactly one
    'learn' and one 'forget', with no other decision present.

    Args:
        x: dataset entry; ``x['objects']`` is an iterable of dicts that each
            have a 'decision' key.

    Returns:
        bool: True for a replacement update, False otherwise.
    """
    decision_tally = Counter(obj['decision'] for obj in x['objects'])
    return dict(decision_tally) == {'learn': 1, 'forget': 1}
# Relation labels of the replacement updates only
relations = [example['relation']['label'] for example in filter(is_replace, dataset)]
c = Counter(relations)
top_relations = [label for label, _ in c.most_common(20)]
# Merge everything outside the top 20 into 'other'
relations = [label if label in top_relations else 'other' for label in relations]
# Sort the values by descending frequency so the histogram bars come out ordered
relations = sorted(relations, key=lambda label: c[label], reverse=True)

fig = go.Figure()
fig.add_trace(go.Histogram(x=relations))
fig.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0})
fig.show()
fig.write_image('./wfd_analysis/replacement_relation_distribution.png', scale=4)

In [None]:
# Update scenario distribution
from collections import defaultdict
import pandas as pd

def format_scenario(counts: dict):
    """Spell out a decision-count mapping as an underscore-separated string.

    Keys are visited in sorted order and each key is repeated as many times as
    its count, e.g. ``{'learn': 2, 'forget': 1}`` -> ``'forget_learn_learn'``.
    Trailing underscores are stripped from the result.

    Args:
        counts: mapping from decision name to its number of occurrences.

    Returns:
        str: the formatted scenario name ('' for an empty mapping).
    """
    pieces = []
    for decision, times in sorted(counts.items(), key=lambda item: item[0]):
        pieces.append('_'.join([decision] * times))
    return '_'.join(pieces).rstrip('_')
            

count_rels = defaultdict(lambda : defaultdict(lambda : 0))

def gen():
    """Yield the dataset entries whose subject is not flagged as physically new."""
    for x in dataset:
        if x['subject_is_ph_new']:
            continue
        yield x

# A scenario is the multiset of decisions of an entry: a frozenset of (decision, count) pairs
scenarios = [frozenset(Counter([y['decision'] for y in x['objects']]).items()) for x in gen()]
c = Counter(scenarios)
top_scenarios = [x[0] for x in c.most_common(20)]
# Human-readable label: sorted decision names joined by '_'; rare scenarios become 'other'.
# NOTE(review): the label drops the counts, so scenarios that differ only in how often a
# decision occurs share one label — format_scenario looks intended for this but is unused.
scenarios = ['_'.join(sorted([y[0] for y in x])) if x in top_scenarios else 'other' for x in scenarios]

# Relation label of every kept entry, aligned index-by-index with `scenarios`
relation_labels = [x['relation']['label'] for x in gen()]

for s, r in zip(scenarios, relation_labels):
    count_rels[s][r] += 1
# Keep only the 4 most frequent relations of each scenario label
for k, v in count_rels.copy().items():
    count_rels[k] = [x[0] for x in sorted(v.items(), key=lambda y : y[1], reverse=True)][:4]

# Colour by relation when it is one of the scenario's top 4, otherwise 'other'
color = [r if r in count_rels[s] else 'other' for s, r in zip(scenarios, relation_labels)]

# Bars ordered by descending frequency, with 'other' last. The order goes through
# category_orders rather than sorting `scenarios`: sorting that list alone would
# break its alignment with `color`, and `c` is keyed by frozensets, not labels.
label_counts = Counter(scenarios)
scenario_order = [label for label, _ in label_counts.most_common() if label != 'other']
if 'other' in label_counts:
    scenario_order.append('other')

fig = px.histogram(
    pd.DataFrame(data = {
        'Scenario' : scenarios,
        'Relation' : color
    }),
    x = 'Scenario', color='Relation',
    category_orders={'Scenario': scenario_order},
)
fig.show()
fig.write_image('./wfd_analysis/scenario_distribution.png', scale=4)

In [None]:
# Display the current value of `c`, i.e. the Counter assigned by the most recently executed cell
c

In [None]:
# General datatype distribution
# Maps an object's 'description' to the category plotted (Year and Month count as Date)
dt2cat = dict(Date='Date', Entity='Entity', Year='Date', Month='Date', String='String', Quantity='Quantity')

# Only the first object of each entry is considered
objects = [example['objects'][0] for example in dataset]
# Objects carrying an 'id' are entities; the others are categorised by their description
datatypes = ['Entity' if 'id' in obj else dt2cat[obj['description']] for obj in objects]
c = Counter(datatypes)
top_datatypes = [name for name, _ in c.most_common(20)]
datatypes = ['other' if name not in top_datatypes else name for name in datatypes]
# Descending frequency so the bars come out sorted
datatypes = sorted(datatypes, key=c.__getitem__, reverse=True)

fig = go.Figure()
fig.add_trace(go.Histogram(x=datatypes))
fig.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0})
fig.show()
fig.write_image('./wfd_analysis/general_datatype_distribution.png', scale=4)

In [None]:
# Physically new entities datatype distribution
# Category plotted for each object 'description'; Year and Month are folded into Date
dt2cat = {
    name: ('Date' if name in ('Year', 'Month') else name)
    for name in ('Date', 'Entity', 'Year', 'Month', 'String', 'Quantity')
}

# First object of every entry whose subject is flagged as physically new
objects = [example['objects'][0] for example in dataset if example['subject_is_ph_new']]
datatypes = ['Entity' if 'id' in obj else dt2cat[obj['description']] for obj in objects]
c = Counter(datatypes)
top_datatypes = list(dict(c.most_common(20)))
datatypes = [name if name in top_datatypes else 'other' for name in datatypes]
# Most frequent category first
datatypes = sorted(datatypes, key=lambda name: c[name], reverse=True)

datatype_histogram = go.Histogram(x=datatypes)
fig = go.Figure(data=[datatype_histogram])
fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
fig.show()
fig.write_image('./wfd_analysis/ph_new_datatype_distribution.png', scale=4)

In [None]:
# Physically old entities datatype distribution
# Category plotted for each object 'description'; Year and Month are folded into Date
dt2cat = dict([
    ('Date', 'Date'),
    ('Entity', 'Entity'),
    ('Year', 'Date'),
    ('Month', 'Date'),
    ('String', 'String'),
    ('Quantity', 'Quantity'),
])

# First object of every entry whose subject is not flagged as physically new
objects = [example['objects'][0] for example in dataset if not example['subject_is_ph_new']]
# An object with an 'id' is an entity; otherwise its description decides the category
datatypes = ['Entity' if 'id' in obj else dt2cat[obj['description']] for obj in objects]
c = Counter(datatypes)
top_datatypes = [pair[0] for pair in c.most_common(20)]
datatypes = ['other' if name not in top_datatypes else name for name in datatypes]
datatypes = sorted(datatypes, key=c.__getitem__, reverse=True)

fig = go.Figure()
fig.add_trace(go.Histogram(x=datatypes))
fig.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0})
fig.show()
fig.write_image('./wfd_analysis/ph_old_datatype_distribution.png', scale=4)

In [None]:
# Replacement datatype distribution
# Category plotted for each object 'description'; Year and Month are folded into Date
dt2cat = dict(
    Date='Date',
    Entity='Entity',
    Year='Date',
    Month='Date',
    String='String',
    Quantity='Quantity',
)
def is_replace(x):
    """Return True when the entry's objects hold exactly one 'learn' and one 'forget' decision.

    Args:
        x: dataset entry; ``x['objects']`` is an iterable of dicts that each
            have a 'decision' key.

    Returns:
        bool: True only if no decision other than 'learn' and 'forget' occurs
        and each of the two occurs exactly once.
    """
    decisions = [obj['decision'] for obj in x['objects']]
    if set(decisions) != {'learn', 'forget'}:
        return False
    return decisions.count('learn') == 1 and decisions.count('forget') == 1
# First object of every replacement update
objects = [example['objects'][0] for example in filter(is_replace, dataset)]
# An object with an 'id' is an entity; otherwise its description decides the category
datatypes = ['Entity' if 'id' in obj else dt2cat[obj['description']] for obj in objects]
c = Counter(datatypes)
top_datatypes = [name for name, _ in c.most_common(20)]
datatypes = ['other' if name not in top_datatypes else name for name in datatypes]
# Most frequent category first
datatypes = sorted(datatypes, key=c.__getitem__, reverse=True)

fig = go.Figure()
fig.add_trace(go.Histogram(x=datatypes))
fig.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0})
fig.show()
fig.write_image('./wfd_analysis/replacement_datatype_distribution.png', scale=4)

In [None]:
# Replacement datatype distribution with 2k 'population'
import random
# Category plotted for each object 'description'; Year and Month are folded into Date
dt2cat = {
    name: ('Date' if name in ('Year', 'Month') else name)
    for name in ('Date', 'Entity', 'Year', 'Month', 'String', 'Quantity')
}
def is_replace(x):
    """Check whether an entry replaces one object by another.

    Args:
        x: dataset entry; ``x['objects']`` is an iterable of dicts that each
            have a 'decision' key.

    Returns:
        bool: True when exactly two distinct decisions occur and they are one
        'learn' and one 'forget'; False otherwise.
    """
    tally = Counter(obj['decision'] for obj in x['objects'])
    if len(tally) != 2:
        return False
    if tally.get('learn') != 1:
        return False
    return tally.get('forget') == 1
subset = [x for x in dataset if is_replace(x)]

# Undersample the 'population' relation by a factor of 14. A dedicated, seeded
# generator makes the figure identical on every run.
POPULATION_KEEP_PROBABILITY = 1/14
rng = random.Random(0)
objects = [
    x['objects'][0] for x in subset
    if x['relation']['label'] != 'population' or rng.random() < POPULATION_KEEP_PROBABILITY
]
# An object with an 'id' is an entity; otherwise its description decides the category
datatypes = [dt2cat[x['description']] if 'id' not in x else 'Entity' for x in objects]
c = Counter(datatypes)
top_datatypes = [x[0] for x in c.most_common(20)]
datatypes = [x if x in top_datatypes else 'other' for x in datatypes]
# Most frequent category first
datatypes.sort(key=lambda x : c[x], reverse=True)

fig = go.Figure(data=[go.Histogram(x=datatypes)])
fig.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image('./wfd_analysis/replacement_undersampled_datatype_distribution.png', scale=4)

In [None]:
# Replacement update relation distribution
def is_replace(x):
    """Return whether the entry is a one-for-one replacement.

    Args:
        x: dataset entry; ``x['objects']`` is an iterable of dicts that each
            have a 'decision' key.

    Returns:
        bool: True when the objects carry exactly two distinct decisions,
        'learn' and 'forget', each occurring once.
    """
    per_decision = Counter(obj['decision'] for obj in x['objects'])
    return len(per_decision) == 2 and all(per_decision.get(kind) == 1 for kind in ('learn', 'forget'))
import random

# Relation labels of the replacement updates, selected with the is_replace defined in
# this cell so the cell no longer depends on a `subset` left over from another cell
relations = [x['relation']['label'] for x in dataset if is_replace(x)]

# Undersample the 'population' relation by a factor of 14. A dedicated, seeded
# generator makes the figure identical on every run; the label is tested first so
# random numbers are drawn for 'population' entries only.
POPULATION_KEEP_PROBABILITY = 1/14
rng = random.Random(0)
relations = [x for x in relations if x != 'population' or rng.random() < POPULATION_KEEP_PROBABILITY]

c = Counter(relations)
top_relations = [x[0] for x in c.most_common(20)]
relations = [x if x in top_relations else 'other' for x in relations]
# Most frequent relation first
relations.sort(key=lambda x : c[x], reverse=True)

fig = go.Figure(data=[go.Histogram(x=relations)])
fig.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image('./wfd_analysis/undersampled_replacement_relation_distribution.png', scale=4)