In [71]:
import dill
import networkx as nx
import numpy as np
import pandas as pd
from torch_geometric.utils import to_networkx
from tqdm.autonotebook import tqdm

In [66]:
def get_distance(graph_data):
    """Return the size of the shortest path linking the n1 and n2 node sets.

    The graph is converted with ``to_networkx`` and made undirected, so edge
    direction is ignored. Every node selected by ``graph_data.n1_mask`` is
    compared with every node selected by ``graph_data.n2_mask`` and the
    smallest result over all pairs is returned.

    Note that the value is the number of *nodes* on the shortest path, i.e.
    the hop count plus one: a node shared by both masks yields 1, two
    adjacent nodes yield 2. Subtract 1 to obtain the edge distance.

    Args:
        graph_data: Graph object accepted by ``to_networkx`` that exposes
            ``n1_mask`` and ``n2_mask`` (presumably boolean node masks --
            TODO confirm against the dataset builder).

    Returns:
        The smallest node count (int) over all reachable n1/n2 pairs, or the
        float sentinel ``1e10`` when no pair could be measured: either mask
        selects no node, or no n1 node is connected to any n2 node.
    """
    # Sentinel kept at its historical value so existing callers that compare
    # against it, or filter on the masks, keep working.
    no_distance = 1e10

    nx_graph = to_networkx(graph_data).to_undirected()
    n1_indices = np.nonzero(graph_data.n1_mask).flatten().tolist()
    n2_indices = np.nonzero(graph_data.n2_mask).flatten().tolist()

    min_distance = no_distance
    for n1_index in n1_indices:
        # One breadth-first search per source node gives the hop count to
        # every reachable node, instead of one search per (n1, n2) pair.
        hops_from_n1 = nx.single_source_shortest_path_length(nx_graph, n1_index)
        for n2_index in n2_indices:
            hops = hops_from_n1.get(n2_index)
            if hops is None:
                # n2 is unreachable from n1; skip the pair rather than
                # letting a NetworkXNoPath error abort the whole run.
                continue
            # +1 converts hops (edges) to the node count that
            # len(nx.shortest_path(...)) used to report.
            path_len = hops + 1
            if path_len < min_distance:
                min_distance = path_len
    return min_distance

In [73]:
datasets = ["risec", "japflow", "chemu", "mscorpus"]

# Root directory that holds one sub-directory per dataset.
DATA_ROOT = "/projects/flow_graphs/data"
# Value get_distance returns when it could not measure any n1/n2 pair.
NO_DISTANCE = 1e10

dep_distances = []
amr_distances = []

# Graph field on each instance -> list collecting that graph type's records.
# Both graph types go through the same guard so neither list can pick up the
# sentinel and skew the mean/std computed from it.
distance_records = {"dep_data": dep_distances, "amr_data": amr_distances}

for dataset in tqdm(datasets):
    # NOTE(review): dill.load can execute arbitrary code while unpickling;
    # only point this at trusted files.
    with open(f"{DATA_ROOT}/{dataset}/data_amr.dill", "rb") as f:
        data = dill.load(f)

    for split in ["train", "dev", "test"]:
        instances = data[split]["rels"]

        for instance in tqdm(instances):
            for graph_key, records in distance_records.items():
                graph_data = instance[graph_key]

                # Without at least one node in each mask there is nothing to
                # measure, so skip before building the networkx graph.
                has_n1 = graph_data.n1_mask.sum() > 0
                has_n2 = graph_data.n2_mask.sum() > 0
                if not (has_n1 and has_n2):
                    continue

                distance = get_distance(graph_data)
                if distance >= NO_DISTANCE:
                    # No measurable pair; keep the sentinel out of the stats.
                    continue

                records.append({
                    "dataset": dataset,
                    "split": split,
                    "distance": distance
                })
    

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/3689 [00:01<?, ?it/s]

  0%|          | 0/1689 [00:00<?, ?it/s]

  0%|          | 0/2213 [00:00<?, ?it/s]

  0%|          | 0/13958 [00:00<?, ?it/s]

  0%|          | 0/1745 [00:00<?, ?it/s]

  0%|          | 0/1745 [00:00<?, ?it/s]

  0%|          | 0/11411 [00:00<?, ?it/s]

  0%|          | 0/2885 [00:00<?, ?it/s]

  0%|          | 0/3332 [00:00<?, ?it/s]

  0%|          | 0/12330 [00:00<?, ?it/s]

  0%|          | 0/2287 [00:00<?, ?it/s]

  0%|          | 0/3782 [00:00<?, ?it/s]

In [75]:
# Mean and standard deviation of the dependency-graph distances for every
# (dataset, split) pair.
dep_distance_frame = pd.DataFrame(dep_distances)
dep_distance_frame.groupby(["dataset", "split"]).agg(["mean", "std"])

Unnamed: 0_level_0,Unnamed: 1_level_0,distance,distance
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std
dataset,split,Unnamed: 2_level_2,Unnamed: 3_level_2
chemu,dev,3.092548,1.332706
chemu,test,3.056122,1.385964
chemu,train,3.060643,1.362643
japflow,dev,3.20745,1.86021
japflow,test,3.273352,1.952799
japflow,train,3.236137,1.890225
mscorpus,dev,2.804985,1.352394
mscorpus,test,2.812797,1.42534
mscorpus,train,2.824574,1.444335
risec,dev,2.746004,1.229099


In [76]:
# Mean and standard deviation of the AMR-graph distances for every
# (dataset, split) pair.
amr_distance_frame = pd.DataFrame(amr_distances)
amr_distance_frame.groupby(["dataset", "split"]).agg(["mean", "std"])

Unnamed: 0_level_0,Unnamed: 1_level_0,distance,distance
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std
dataset,split,Unnamed: 2_level_2,Unnamed: 3_level_2
chemu,dev,3.396344,2.350579
chemu,test,3.577506,2.268459
chemu,train,3.402605,2.250643
japflow,dev,1.366268,0.990138
japflow,test,1.37406,1.015147
japflow,train,1.367777,0.999671
mscorpus,dev,3.028523,2.083688
mscorpus,test,2.845283,2.047945
mscorpus,train,2.912596,2.084232
risec,dev,2.750192,0.829185
