# Interactive Sankey Plot Generation
Below, you can find an interactive Sankey plot generator for our substitution strategies.

In [5]:
import json
from collections import Counter

import ipywidgets as widgets
import plotly.graph_objects as go
from ipywidgets import interact, interactive, fixed, interact_manual

In [6]:
# Named groups of related relation labels. The plotting function `f` draws
# only the flows that start at a label from the selected group, and colours
# each label by its position inside the group.
critical_relation_names = {
    'residence': [
        'per:countries_of_residence',
        'per:cities_of_residence',
        'per:stateorprovinces_of_residence',
    ],
    'headquarter': [
        'org:country_of_headquarters',
        'org:city_of_headquarters',
        'org:stateorprovince_of_headquarters',
    ],
    'death': [
        'per:city_of_death',
        'per:stateorprovince_of_death',
        'per:country_of_death',
    ],
    'birth': [
        'per:city_of_birth',
        'per:stateorprovince_of_birth',
        'per:country_of_birth',
        'per:origin',
    ],
    'name': [
        'org:alternate_names',
        'per:alternate_names',
    ],
    'religion': [
        'per:religion',
        'org:political/religious_affiliation',
    ],
    'member': [
        'org:member_of',
        'org:top_members/employees',
        'per:employee_of',
    ],
}

# Every relation label that can appear as a prediction. The ORDER matters:
# `f` uses a label's position in this list as its Sankey node id (and
# position + len(relations) for the matching right-hand node).
relations = [
    'no_relation',
    # organisation relations
    'org:alternate_names',
    'org:city_of_headquarters',
    'org:country_of_headquarters',
    'org:dissolved',
    'org:founded',
    'org:founded_by',
    'org:member_of',
    'org:members',
    'org:number_of_employees/members',
    'org:parents',
    'org:political/religious_affiliation',
    'org:shareholders',
    'org:stateorprovince_of_headquarters',
    'org:subsidiaries',
    'org:top_members/employees',
    'org:website',
    # person relations
    'per:age',
    'per:alternate_names',
    'per:cause_of_death',
    'per:charges',
    'per:children',
    'per:cities_of_residence',
    'per:city_of_birth',
    'per:city_of_death',
    'per:countries_of_residence',
    'per:country_of_birth',
    'per:country_of_death',
    'per:date_of_birth',
    'per:date_of_death',
    'per:employee_of',
    'per:origin',
    'per:other_family',
    'per:parents',
    'per:religion',
    'per:schools_attended',
    'per:siblings',
    'per:spouse',
    'per:stateorprovince_of_birth',
    'per:stateorprovince_of_death',
    'per:stateorprovinces_of_residence',
    'per:title',
]

# Highlight palette for Sankey nodes and links. `f` indexes it with a
# relation's position inside the selected critical-relation group.
color_bar = [
    '#FFD1DC',  # pastel pink
    '#E6E6FA',  # pale lavender
    '#B0E0E6',  # powder blue
    '#FFDAB9',  # light peach
    '#ADD8E6',  # baby blue
    '#FF6B6B',  # muted coral
    '#B0C4DE',  # light steel blue
    '#C8A2C8',  # lilac
    '#98FB98',  # mint green
    '#FFB6C1',  # blush pink
    '#87CEEB',  # sky blue
    '#FFFFE0',  # pale yellow
    '#A9D18E',  # pale green
]

# Substitution strategy name -> short tag that `f` embeds in the name of the
# adversarial result file it looks up.
name_2_sub = dict(zip(
    ('same-role', 'same-type', 'diff-type', 'masked'),
    ('sub1', 'sub2', 'sub3', 'sub4'),
))

# Prediction results consumed by `f`, which reads
# sample['test_preds'][model] and sample['adv_preds'][<file name>][model]
# from every value of the loaded mapping.
results_file_path = './results/mapping_preds_anonymous.json'
# Use a context manager so the file handle is closed deterministically, and
# read as UTF-8 explicitly so loading does not depend on the platform locale.
with open(results_file_path, encoding='utf-8') as results_file:
    results_data = json.load(results_file)

In [7]:
def f(strategy, subj, obj, model, relation_set):
    """Show a Sankey diagram of how one model's predictions change under substitution.

    Left-hand nodes are the model's ``test_preds`` labels (column header
    "standard"); right-hand nodes are its ``adv_preds`` labels for the selected
    substitution file (column header "adversarial"). Only flows that start at a
    relation listed in ``critical_relation_names[relation_set]`` are drawn.

    Args:
        strategy: Key of ``name_2_sub`` selecting the substitution strategy.
        subj: Truthy to include 'subj' in the substituted positions.
        obj: Truthy to include 'obj' in the substituted positions.
        model: Key under which each sample stores this model's predictions.
        relation_set: Key of ``critical_relation_names`` selecting the group
            of relations to visualise.

    Returns:
        None. The figure is displayed and also written to ``fig_test.svg`` in
        the current working directory (overwritten on every call). When
        neither ``subj`` nor ``obj`` is selected, an empty figure is displayed
        and nothing is written.

    Raises:
        KeyError: If a lookup by ``strategy``, ``relation_set``, ``model`` or
            the derived file name fails.
        ValueError: If a predicted label is not contained in ``relations``.
    """
    # Position tag used in the result file names: 'subj', 'obj' or 'subj+obj'.
    sub_pos = '+'.join(
        position for position, selected in (('subj', subj), ('obj', obj)) if selected
    )
    if not sub_pos:
        # No position selected: there is no matching result file, show a blank figure.
        go.Figure().show()
        return

    filename = f'controlled_tacred_test_{name_2_sub[strategy]}_{sub_pos}.json'
    critical = critical_relation_names[relation_set]
    n_relations = len(relations)
    gray = 'rgba(211, 211, 211, .7)'

    # Count every (standard prediction, adversarial prediction) pair.
    flow_counts = Counter(
        (sample['test_preds'][model], sample['adv_preds'][filename][model])
        for sample in results_data.values()
    )

    # One link per pair whose standard prediction is a critical relation.
    # Sorting the pairs keeps the link order stable between calls. Right-hand
    # node ids are offset by len(relations) because the label list is doubled.
    source, target, value = [], [], []
    for (standard_rel, adversarial_rel), count in sorted(flow_counts.items()):
        if standard_rel not in critical:
            continue
        source.append(relations.index(standard_rel))
        target.append(relations.index(adversarial_rel) + n_relations)
        value.append(count)

    # Nodes are gray unless they belong to a critical relation that occurs as
    # a source; those get the same highlight colour on both sides.
    colors = [gray] * (2 * n_relations)
    highlighted = set(source)
    for idx in highlighted:
        highlight = color_bar[critical.index(relations[idx])]
        colors[idx] = highlight
        colors[idx + n_relations] = highlight

    # A link takes its source node's colour when its target relation is also
    # one of the highlighted relations; every other link stays gray.
    color_relation = [
        colors[src] if tgt - n_relations in highlighted else gray
        for src, tgt in zip(source, target)
    ]

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            # Labels are listed twice: left-hand nodes first, then right-hand nodes.
            label=relations + relations,
            color=colors,
        ),
        link=dict(
            source=source,
            target=target,
            value=value,
            color=color_relation,
        ))])

    # Column headers above the left and right node columns.
    column_labels = ["standard", "adversarial"]
    for x_coordinate, column_name in enumerate(column_labels):
        fig.add_annotation(
            x=x_coordinate,
            y=1.15,
            text=column_name,
            showarrow=False,
            font=dict(
                size=12,
                color="black"
            ),
            align="center",
        )

    # Title with the selected configuration and the number of plotted samples;
    # transparent backgrounds so the exported SVG blends into any page.
    fig.update_layout(
        title_text=f'{strategy} {sub_pos}: {model} <br> <sup>{", ".join(critical)} ({sum(value)} samples)</sup> <br><br>',
        font_size=15,
        title_x=0.5,
        title_y=0.95,
        width=900,
        height=400,
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
    )

    fig.show()
    fig.write_image("fig_test.svg")

In [None]:
# Build the interactive controls for `f`. The strategy and relation-set
# options are derived from the dictionaries `f` indexes into, so the dropdowns
# can never offer a value that `f` cannot look up. Dict insertion order
# determines the option order.
interact(
    f,
    strategy=list(name_2_sub),
    subj=True,
    obj=True,
    model=["LUKE", "SpanBERT", "SURE", "TYP_marker", "UniST", "NLI_w", "NLI_wo"],
    relation_set=list(critical_relation_names),
)