Import dependencies:

In [2]:
from typing import Tuple, List, Iterable
from pydot import Dot, graph_from_dot_data, Edge
from graphviz.graphs import BaseGraph
from graphviz import Source
import amrlib
from amrlib.graph_processing.amr_plot import AMRPlot
import numpy as np
import pandas as pd
import csv, pickle
from tqdm.notebook import tqdm

Extract nodes and edges from AMR graphs.

In [47]:
#cite: https://stackoverflow.com/questions/47426249/finding-list-of-edges-in-graphviz-in-python 
def get_graph_dot_obj(graph_spec) -> List[Dot]:
    """Get a dot (graphs) object list from a variety 
    of possible sources (postelizing inputs here)"""
    _original_graph_spec = graph_spec
    if isinstance(graph_spec, (BaseGraph, Source)):
        # get the source (str) from a graph object
        graph_spec = graph_spec.source
    if isinstance(graph_spec, str):
        # get a dot-graph from dot string data
        graph_spec = graph_from_dot_data(graph_spec)
    # make sure we have a list of Dot objects now
    assert isinstance(graph_spec, list) and all(
        isinstance(x, Dot) for x in graph_spec
    ), (
        f"Couldn't get a proper dot object list from: {_original_graph_spec}. "
        f"At this point, we should have a list of Dot objects, but was: {graph_spec}"
    )
    return graph_spec

def get_edges(graph_spec, label = False):
    """Get a list of edges for a given graph (or list of lists thereof).
    If ``postprocess_edges`` is ``None`` the function will return ``pydot.Edge`` objects from
    which you can extract any information you want.
    By default though, it is set to extract the node pairs for the edges, and you can
    replace with any function that takes ``pydot.Edge`` as an input.
    """
    graphs = get_graph_dot_obj(graph_spec)
    n_graphs = len(graphs)
    if n_graphs > 1:
        return [get_edges(graph) for graph in graphs]
    elif n_graphs == 0:
        raise ValueError(f"Your input had no graphs")
    else:
        graph = graphs[0]
        edges = graph.get_edges()
        edges_list = []
        if not label:
            for edge in edges:
                r1, r2 = graph.get_node(edge.get_source())[0].get_label().strip('\"').strip('\\').strip('\"'), graph.get_node(edge.get_destination())[0].get_label().strip('\"').strip('\\').strip('\"')
                if '/' in r1:
                    r1 = r1.split('/')[1]
                elif '\\' in r1:
                    r1 = r1.split('\\')[0]
                
                if '/' in r2:
                    r2 = r2.split('/')[1]
                elif '\\' in r1:
                    r2 = r2.split('\\')[0]

                edges_list.append([r1,r2])
        else:
            for edge in edges:
                r1, r2, r3 = graph.get_node(edge.get_source())[0].get_label().strip('\"').strip('\\').strip('\"'), graph.get_node(edge.get_destination())[0].get_label().strip('\"').strip('\\').strip('\"'), edge.get_label().strip('\"')[1:]
                if '/' in r1:
                    r1 = r1.split('/')[1]
                elif '\\' in r1:
                    r1 = r1.split('\\')[0]
                
                if '/' in r2:
                    r2 = r2.split('/')[1]
                elif '\\' in r1:
                    print("called")
                    r2 = r2.split('\\')[0]

                edges_list.append([r1,r2,r3])
        
        return edges_list

Save large intermediate results (Only used for the first run). 

In [57]:
with open('train_AMR.csv', 'r') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    next(csv_reader, None)
    g_train = []
    for row in tqdm(csv_reader, total=1600):
        AP = AMRPlot()
        AP.build_from_graph(entry = row)
        edges = get_edges(AP.graph, label=True)
        g_train.append(edges)

    np.save('g_train',g_train)
    print(g_train[1])

len(g_train)

  0%|          | 0/1600 [00:00<?, ?it/s]

ignoring epigraph data for duplicate triple: ('c2', ':ARG0', 'p')
ignoring epigraph data for duplicate triple: ('h', ':ARG0', 'y')
ignoring epigraph data for duplicate triple: ('c3', ':ARG1', 'w')
ignoring epigraph data for duplicate triple: ('c', ':name', 'n')
ignoring epigraph data for duplicate triple: ('c', ':ARG0', 'p2')


[['multi-sentence', 'acquire-01', 'snt1'], ['acquire-01', 'country', 'ARG0'], ['country', 'name', 'name'], ['acquire-01', 'stake', 'ARG1'], ['stake', 'company', 'mod'], ['company', 'name', 'name'], ['stake', 'percentage-entity', 'quant'], ['stake', 'mean-01', 'ARG1-of'], ['mean-01', 'approximately', 'ARG2'], ['approximately', 'share-01', 'op1'], ['multi-sentence', 'and', 'snt2'], ['and', 'price-01', 'op1'], ['price-01', 'share-01', 'ARG1'], ['price-01', 'monetary-quantity', 'ARG2'], ['monetary-quantity', 'low-04', 'ARG1-of'], ['low-04', 'very', 'degree'], ['and', 'lose-02', 'op2'], ['lose-02', 'stock', 'ARG0'], ['lose-02', 'percentage-entity', 'ARG1'], ['percentage-entity', 'include-91', 'ARG3-of'], ['include-91', 'value', 'ARG2'], ['lose-02', 'since', 'time'], ['since', 'begin-01', 'op1'], ['begin-01', 'year', 'ARG1'], ['lose-02', 'since', 'time'], ['since', 'struggle-02', 'op1'], ['struggle-02', 'industry', 'ARG0'], ['industry', 'cruise-01', 'mod'], ['struggle-02', 'company', 'ARG1']

  arr = np.asanyarray(arr)


(1600, 0)

In [61]:
with open('test_AMR.csv', 'r') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    next(csv_reader, None)
    g_test = []
    for row in tqdm(csv_reader, total=400):
        AP = AMRPlot()
        AP.build_from_graph(entry = row)
        edges = get_edges(AP.graph, label=True)
        g_test.append(edges)

    np.save('g_test',g_test)
    print(g_test[1])

len(g_test)

  0%|          | 0/400 [00:00<?, ?it/s]

ignoring epigraph data for duplicate triple: ('b', ':mod', 'p3')
ignoring epigraph data for duplicate triple: ('c', ':ARG0', 'h')
ignoring epigraph data for duplicate triple: ('w2', ':ARG0', 'p')


[['create-01', 'company', 'ARG0'], ['company', 'name', 'name'], ['create-01', 'job', 'ARG1'], ['job', 'new-01', 'ARG1-of'], ['job', 'multiple', 'quant'], ['create-01', 'footprint', 'location'], ['footprint', 'company', 'poss'], ['footprint', 'country', 'location'], ['country', 'name', 'name'], ['create-01', 'increase-01', 'purpose'], ['increase-01', 'company', 'ARG0'], ['increase-01', 'effort-01', 'ARG1'], ['effort-01', 'company', 'ARG0'], ['effort-01', 'and', 'ARG1'], ['and', 'logistics', 'op1'], ['and', 'distribute-01', 'op2'], ['increase-01', 'surge-01', 'time'], ['surge-01', 'demand-01', 'ARG1'], ['surge-01', 'crisis', 'prep-amid'], ['crisis', 'coronavirus', 'mod'], ['name', 'Aldi', 'op1'], ['multiple', '1000', 'op1'], ['name', 'UK', 'op1']]


  arr = np.asanyarray(arr)


400

In [3]:
gtrs = np.load('g_train.npy',allow_pickle=True)
gtes = np.load('g_test.npy',allow_pickle=True)
gall = np.concatenate((gtrs, gtes), axis=0)
gtrs.shape, gtes.shape, gall.shape, type(gtrs)

((1600,), (400,), (2000,), numpy.ndarray)

Generate inventories for words and edges.

In [4]:
word_set = list({ts[i] for g in gall for ts in g for i in range(2)})
edge_set = list({ts[2] for g in gall for ts in g})
word_set.sort()
edge_set.sort()
word_to_id = dict(zip(word_set,[i for i in range(len(word_set))]))
edge_to_id = dict(zip(edge_set,[i for i in range(len(edge_set))]))
Vsize, Esize = len(word_to_id), len(edge_to_id)
Vsize, Esize

(5138, 109)

In [11]:
# for a single tweet amr
edges = gtrs[0]
print(edges,"\n")
nodes = list({edge[i] for edge in edges for i in range(2)})
nodes_to_id = dict(zip(nodes,[i for i in range(len(nodes))]))
print(nodes_to_id,"\n")
edge_index = [[nodes_to_id[edge[0]] for edge in edges], [nodes_to_id[edge[1]] for edge in edges]]
x, edge_attr = [], []
for node in nodes_to_id.keys():
    vector = np.zeros(Vsize)
    # one-hot vector
    vector[word_to_id[node]] = 1.0
    x.append(vector)

for edge in edges:
    vector = np.zeros(Esize)
    # one-hot vector
    vector[edge_to_id[edge[2]]] = 1.0
    edge_attr.append(vector)

print(edge_index, "\n")
print(np.array(edge_index).shape)
print(np.array(x).shape)
print(np.array(edge_attr).shape)

[['possible-01', 'work-01', 'ARG1'], ['work-01', 'you', 'ARG0'], ['work-01', 'hard-02', 'ARG1-of'], ['hard-02', 'have-degree-91', 'ARG2-of'], ['have-degree-91', 'work-01', 'ARG1'], ['have-degree-91', 'too', 'ARG3'], ['possible-01', 'find-01', 'condition'], ['find-01', 'you', 'ARG0'], ['find-01', 'dream-01', 'ARG1'], ['dream-01', 'you', 'ARG0'], ['dream-01', 'and', 'ARG1'], ['and', 'product', 'op1'], ['product', 'clean-01', 'purpose'], ['and', 'index', 'op2'], ['index', 'market', 'mod'], ['market', 'stock', 'mod']] 

{'find-01': 0, 'you': 1, 'possible-01': 2, 'have-degree-91': 3, 'product': 4, 'work-01': 5, 'clean-01': 6, 'index': 7, 'and': 8, 'market': 9, 'stock': 10, 'dream-01': 11, 'too': 12, 'hard-02': 13} 

[[2, 5, 5, 13, 3, 3, 2, 0, 0, 11, 11, 8, 4, 8, 7, 9], [5, 1, 13, 3, 5, 12, 0, 1, 11, 1, 8, 4, 6, 7, 9, 10]] 

(2, 16)
(14, 5138)
(16, 109)
