In [4]:
import spacy
from spacy import displacy
from contratospr.contracts.models import Document, Contract

from asgiref.sync import sync_to_async

# Load the spaCy "es_core_news_md" pipeline once; the same `nlp` object is
# reused for every document processed further down.
nlp = spacy.load("es_core_news_md")

def get_sample_data():
    """Collect a random sample of up to five top-level contracts.

    Only contracts without a parent whose document exists and has a non-null
    ``pages`` value are considered. The sample is drawn with a random
    ordering, so repeated calls return different contracts.

    Returns:
        list[dict]: One dict per sampled contract with the keys
        ``"contract"``, ``"document"``, ``"document_text"`` (the ``"text"``
        of every page joined with newlines, or ``""`` when there is no
        document) and ``"contractors"`` (a list of the related contractors).
    """
    contracts = (
        Contract.objects.filter(
            document__isnull=False, document__pages__isnull=False, parent=None
        )
        # Load each document in the main query and all contractors in one
        # additional query, rather than two extra queries per contract.
        .select_related("document")
        .prefetch_related("contractors")
        .order_by("?")[:5]
    )

    sample_data = []
    for contract in contracts:
        document = contract.document
        pages_text_str = ""
        if document:
            pages_text_str = "\n".join(page["text"] for page in document.pages)
        sample_data.append({
            "contract": contract,
            "document": document,
            "document_text": pages_text_str,
            "contractors": list(contract.contractors.all()),
        })
    return sample_data

def top_entities(doc, n):
    """Return the most frequently repeated entity texts found in ``doc``.

    Entities are grouped by their ``text``. Texts that occur only once are
    discarded; the remainder is ordered by descending count, with ties kept
    in order of first appearance.

    Args:
        doc: Object exposing an iterable ``ents`` whose items have a ``text``
            attribute (e.g. a processed spaCy document).
        n: Maximum number of entries to return (applied as a slice bound).

    Returns:
        list[tuple[str, int]]: ``(entity_text, count)`` pairs, at most ``n``.
    """
    # Imported locally so the function does not depend on file-level imports.
    from collections import Counter

    counts = Counter(ent.text for ent in doc.ents)
    repeated = [(text, count) for text, count in counts.items() if count > 1]
    # list.sort is stable, so equal counts keep their first-seen order.
    repeated.sort(key=lambda pair: pair[1], reverse=True)
    return repeated[:n]


# get_sample_data is a plain synchronous function; it is wrapped with
# sync_to_async and awaited at the top level of the notebook cell.
# NOTE(review): presumably required because Django rejects synchronous ORM
# access from inside the notebook's running event loop — confirm.
sample_data = await sync_to_async(get_sample_data)()

In [32]:
# from pyvis.network import Network
# net = Network(notebook=True, height='800px', width='1000px')

# for data in sample_data:
#     net.add_node(str(data["contract"]))
#     for contractor in data["contractors"]:
#         net.add_node(str(contractor))
#         net.add_edge(str(data["contract"]), str(contractor))
#     nlp.max_length = len(data["document_text"])
#     doc = nlp(data["document_text"])
#     results = top_entities(doc, 30)
#     for result, _ in results:
#         net.add_node(result)
#         net.add_edge(str(data["contract"]), result)
#         for contractor in data["contractors"]:
#             net.add_edge(result, str(contractor))
    
# net.show("mygraph.html")

import networkx as nx

# Build an undirected graph that links every sampled contract to its
# contractors and to the entities that recur in its document text; each
# recurring entity is additionally linked to all contractors of that contract.
G = nx.Graph()

for data in sample_data:
    contract = data["contract"]
    contractors = data["contractors"]
    text = data["document_text"]

    # Node ids are prefixed by kind so primary keys and entity texts from
    # different kinds of nodes can never collide.
    contract_id = f"contract:{contract.pk}"
    contractor_ids = [f"contractor:{contractor.pk}" for contractor in contractors]

    G.add_node(contract_id, label=str(contract))

    for contractor, contractor_id in zip(contractors, contractor_ids):
        G.add_node(contractor_id, label=str(contractor))
        G.add_edge(contract_id, contractor_id)

    # Only ever raise the pipeline's length limit. Assigning len(text)
    # unconditionally would shrink the shared `nlp` object's limit to the size
    # of the last document and break later calls on longer texts.
    if len(text) > nlp.max_length:
        nlp.max_length = len(text)
    doc = nlp(text)
    results = top_entities(doc, 15)

    for result, _ in results:
        entity_id = f"entity:{result}"
        G.add_node(entity_id, label=result)
        G.add_edge(contract_id, entity_id)
        for contractor_id in contractor_ids:
            G.add_edge(entity_id, contractor_id)

import plotly.graph_objects as go

# Node placement. Alternative layouts that can be swapped in:
# pos = nx.layout.shell_layout(G)
# pos = nx.layout.spectral_layout(G)
pos = nx.layout.random_layout(G)

# Copy each node's coordinates onto the node itself as a plain list under the
# "pos" attribute, where the plotting code below reads them.
for node, coords in pos.items():
    G.nodes[node]["pos"] = list(coords)

# Flatten all edges into two parallel coordinate lists. Every edge contributes
# its two endpoints followed by a None entry that separates it from the next.
edge_x = []
edge_y = []
for source, target in G.edges():
    x0, y0 = G.nodes[source]["pos"]
    x1, y1 = G.nodes[target]["pos"]
    edge_x.extend((x0, x1, None))
    edge_y.extend((y0, y1, None))

# Thin grey lines with hover disabled, one trace for all edges.
edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    mode="lines",
    hoverinfo="none",
    line={"width": 0.5, "color": "#888"},
)

# Collect the x/y coordinates of every node, in graph iteration order, from
# the "pos" attribute stored on each node.
node_x = []
node_y = []
for node in G.nodes():
    x, y = G.nodes[node]["pos"]
    node_x.append(x)
    node_y.append(y)

# Marker trace for the nodes. The marker colours start out empty and are
# filled in afterwards with the per-node connection counts.
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode="markers",
    hoverinfo="text",
    marker=dict(
        showscale=True,
        # colorscale options
        #'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
        #'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
        #'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
        colorscale="YlGnBu",
        reversescale=True,
        color=[],
        size=10,
        colorbar=dict(
            thickness=15,
            # `titleside` is deprecated and rejected by current Plotly
            # releases; the title side is set through the nested title dict.
            title=dict(text="Node Connections", side="right"),
            xanchor="left",
        ),
        line_width=2,
    ),
)

# Number of neighbours per node, in graph adjacency order.
neighbor_counts = {node: len(neighbors) for node, neighbors in G.adjacency()}

node_adjacencies = list(neighbor_counts.values())
# Hover text shows the node's label followed by its neighbour count.
node_hovertext = [
    f"{G.nodes[node]['label']} ({count})"
    for node, count in neighbor_counts.items()
]
# Intentionally left empty: only the hover text carries per-node information.
node_text = []

# Colour each marker by its neighbour count and attach the texts to the trace.
node_trace.marker.color = node_adjacencies
node_trace.text = node_text
node_trace.hovertext = node_hovertext

# Assemble the figure: edges are drawn first so the node markers sit on top.
# Grid lines, zero lines and tick labels are hidden on both axes.
hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)

fig = go.Figure(
    data=[edge_trace, node_trace],
    layout=go.Layout(
        # `titlefont_size` is deprecated and rejected by current Plotly
        # releases; the font size is set through the nested title dict.
        title=dict(text="", font=dict(size=16)),
        showlegend=False,
        hovermode="closest",
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
    ),
)
fig.show()