In [1]:
import networkx as nx
import matplotlib.pyplot as plt
import json
import spacy
from spacy.tokens import Doc
from spanUtil import get_ent_spans

# Load the spaCy pipeline named 'en_core_web_trf'; it is called on a Doc further down.
parser = spacy.load('en_core_web_trf')

# Render matplotlib figures at 300 DPI.
plt.rcParams['figure.dpi'] = 300


def build_g(tokens, ents):
    """Build an undirected graph over tokens plus multi-token entity nodes.

    Every token becomes a node whose id is its position within ``tokens``
    (0, 1, 2, ...) and which carries ``text``, ``tag``, ``pos`` and ``dep``
    attributes. Each token is linked to its syntactic head by an edge labelled
    with the token's dependency relation. Every entity span covering more
    than one token gets an extra node (ids continue after the token ids)
    that is linked to each token inside the span with ``dep='ent_node'``.

    Args:
        tokens: Sliceable sequence of token objects exposing ``text``,
            ``tag_``, ``pos_``, ``dep_`` and ``head`` (e.g. a parsed spaCy
            ``Doc``). Every token's ``head`` must itself be in ``tokens``.
        ents: Iterable of ``(start, end)`` index pairs into ``tokens``;
            ``end`` is exclusive.

    Returns:
        networkx.Graph: The token/entity graph described above.
    """
    g = nx.Graph()

    # Token nodes: the node id is simply the token's position in `tokens`.
    node_of = {}
    for node_id, token in enumerate(tokens):
        g.add_node(
            node_id, text=token.text, tag=token.tag_,
            pos=token.pos_, dep=token.dep_)
        node_of[token] = node_id

    # Dependency edges; a token that is its own head (the root) gets no edge.
    for token in tokens:
        if token != token.head:
            g.add_edge(node_of[token], node_of[token.head], dep=token.dep_)

    # Entity nodes are numbered right after the token nodes.
    next_id = g.number_of_nodes()
    for ent_s, ent_e in ents:
        # A single-token span is already represented by its token node.
        if ent_e - ent_s <= 1:
            continue

        span_text = ' '.join(t.text for t in tokens[ent_s:ent_e])
        g.add_node(next_id, text=span_text, tag=None, pos=None, dep=None)

        # Token node ids equal token positions, so the span indices can be
        # used directly as edge targets.
        for member_id in range(ent_s, ent_e):
            g.add_edge(next_id, member_id, dep='ent_node')

        next_id += 1

    return g


def draw_graph(G, words_cnt):
    """Plot ``G`` with word nodes in gray and the remaining nodes in green.

    The first ``words_cnt`` nodes (in the graph's node order) are coloured
    gray and all later nodes green. An edge is gray only when both of its
    endpoints have ids below ``words_cnt``; otherwise it is green. Nodes are
    labelled with their ``text`` attribute.

    Args:
        G: Graph whose nodes carry a ``text`` attribute.
        words_cnt: Number of leading nodes that represent words.
    """
    word_color, extra_color = 'gray', 'green'

    layout = nx.kamada_kawai_layout(G)

    extra_cnt = len(G.nodes) - words_cnt
    node_colors = [word_color] * words_cnt + [extra_color] * extra_cnt

    edge_color = []
    for u, v in G.edges:
        touches_extra = u >= words_cnt or v >= words_cnt
        edge_color.append(extra_color if touches_extra else word_color)

    labels_by_node = nx.get_node_attributes(G, 'text')

    nx.draw_networkx(
        G,
        layout,
        node_size=40,
        labels=labels_by_node,
        font_size=4,
        node_color=node_colors,
        font_color='black',
        edge_color=edge_color,
    )
    plt.show()



In [2]:
# Read the dataset inside a context manager so the file handle is closed
# deterministically instead of being left open for the garbage collector.
with open('../data/datasets/conll04/conll04_test_clean.json') as dataset_file:
    dataset = json.load(dataset_file)

# Work on a single example from the dataset.
data = dataset[255]
words = data['tokens']

print(' '.join(words))

# Build a Doc directly from the dataset's token list and run the pipeline on it.
doc = Doc(parser.vocab, words=words)
tokens = parser(doc)

# Entity spans as (start, end) pairs, ordered by start index.
spans = sorted(get_ent_spans(tokens), key=lambda x: x[0])

print('extracted spans: ', spans)
print([' '.join(words[span_s: span_e]) for span_s, span_e in spans])

g = build_g(tokens, spans)

# Inspect the graph through its adjacency matrix.
A = nx.adjacency_matrix(g)

print(A.shape)

print(A.todense())

# Uncomment to render the graph:
# draw_graph(g, len(tokens))


COLUMBIA , S.C. _ Jesse Jackson on Sunday was touting his big win in his native South Carolina , while officials struggled to count the final ballots after a record turnout at Democratic caucuses.




extracted spans:  [(0, 1), (0, 3), (2, 3), (4, 6), (4, 5), (5, 6), (7, 8), (16, 17), (16, 18), (17, 18), (32, 34), (32, 33)]
['COLUMBIA', 'COLUMBIA , S.C.', 'S.C.', 'Jesse Jackson', 'Jesse', 'Jackson', 'Sunday', 'South', 'South Carolina', 'Carolina', 'Democratic caucuses.', 'Democratic']
(38, 38)
  (0, 1)	1
  (0, 2)	1
  (0, 3)	1
  (0, 9)	1
  (0, 34)	1
  (1, 0)	1
  (1, 34)	1
  (2, 0)	1
  (2, 34)	1
  (3, 0)	1
  (4, 5)	1
  (4, 35)	1
  (5, 4)	1
  (5, 9)	1
  (5, 35)	1
  (6, 7)	1
  (6, 9)	1
  (7, 6)	1
  (8, 9)	1
  (9, 0)	1
  (9, 5)	1
  (9, 6)	1
  (9, 8)	1
  (9, 12)	1
  (9, 18)	1
  :	:
  (26, 25)	1
  (27, 23)	1
  (27, 30)	1
  (28, 30)	1
  (29, 30)	1
  (30, 27)	1
  (30, 28)	1
  (30, 29)	1
  (30, 31)	1
  (31, 30)	1
  (31, 33)	1
  (32, 33)	1
  (32, 37)	1
  (33, 31)	1
  (33, 32)	1
  (33, 37)	1
  (34, 0)	1
  (34, 1)	1
  (34, 2)	1
  (35, 4)	1
  (35, 5)	1
  (36, 16)	1
  (36, 17)	1
  (37, 32)	1
  (37, 33)	1


In [45]:
# NOTE(review): this cell's execution count (45) is out of sequence, and the
# saved edge list below references node ids up to 40 while the adjacency matrix
# printed in In [2] is 38x38 -- the output appears to come from a different `g`.
# Re-run after a fresh kernel restart to confirm.
print(g.edges)

[(0, 1), (1, 2), (1, 3), (1, 5), (1, 11), (1, 18), (1, 31), (3, 4), (5, 10), (6, 10), (7, 10), (8, 9), (8, 35), (8, 38), (9, 10), (9, 32), (9, 35), (9, 38), (10, 32), (10, 35), (11, 13), (12, 13), (13, 14), (14, 17), (15, 17), (15, 40), (16, 17), (16, 37), (16, 40), (17, 37), (17, 40), (18, 19), (19, 20), (19, 21), (19, 22), (19, 23), (19, 30), (19, 33), (19, 36), (20, 33), (20, 36), (21, 36), (22, 36), (24, 25), (24, 34), (25, 30), (25, 34), (26, 30), (27, 28), (27, 39), (28, 30), (28, 39), (29, 30)]
