In [1]:
import pandas as pd
import networkx as nx
import plotly.graph_objects as go

### Prep data

In [5]:
def load_data():
    remedy_edges = pd.read_parquet("data/remedy_edges.parquet")
    source_nodes = pd.read_parquet("data/source_nodes.parquet")
    target_nodes = pd.read_parquet("data/target_nodes.parquet")

    unique_remedies = source_nodes.id.unique().tolist()
    unique_effects = target_nodes.id.unique().tolist()

    nodes = source_nodes.append(target_nodes).drop_duplicates()

    return unique_remedies, unique_effects, nodes, remedy_edges


def filter_on_node(
    filter_node: str,
    filter_type: str,
    nodes: pd.DataFrame,
    edges: pd.DataFrame,
):
    if filter_type == "Remedy":
        filtered_edges = edges[edges["from"].eq(filter_node)]
        filtered_nodes = nodes[
            nodes["id"].eq(filter_node) | nodes["id"].isin(filtered_edges["to"])
        ]
    elif filter_type == "Effect":
        filtered_edges = edges[edges["to"].eq(filter_node)]
        filtered_nodes = nodes[
            nodes["id"].eq(filter_node) | nodes["id"].isin(filtered_edges["from"])
        ]
    else:
        raise ValueError("Invalid filter type. Try 'Remedy' or 'Effect'.")
    return filtered_nodes, filtered_edges


def filter_on_edge_weights(
    ppmi_range: tuple,
    edge_count_range: tuple,
    nodes: pd.DataFrame,
    edges: pd.DataFrame,
):
    filtered_edges = edges[edges["ppmi"].between(ppmi_range[0], ppmi_range[1])]
    filtered_edges = filtered_edges[
        filtered_edges["edge_count"].between(edge_count_range[0], edge_count_range[1])
    ]
    filtered_nodes = nodes[
        nodes["id"].isin(filtered_edges["from"])
        | nodes["id"].isin(filtered_edges["to"])
    ]
    return filtered_nodes, filtered_edges


In [6]:
unique_remedies, unique_effects, nodes, remedy_edges = load_data()

In [7]:
filtered_nodes, filtered_edges = filter_on_node(
    "oxycodone", "Remedy", nodes, remedy_edges
)


### Testing grouping with dummy edges

In [8]:
self_join = filtered_nodes.merge(filtered_nodes, on='category')

In [9]:
self_join = self_join[self_join.id_x.ne(self_join.id_y)]

In [33]:
self_join_one_only = self_join.groupby('id_x').first().reset_index()

In [34]:
dummy_edges = self_join_one_only[["id_x", "id_y", "category"]]
dummy_edges.rename(
    columns={"id_x": "from", "id_y": "to", "category": "category_source"}, inplace=True
)
dummy_edges["category_target"] = dummy_edges["category_source"]
dummy_edges["edge_count"] = 1
dummy_edges["ppmi"] = 0.5




A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a

In [35]:
combined_edgelist = filtered_edges.append(dummy_edges)

In [36]:
G = nx.from_pandas_edgelist(
    combined_edgelist,
    source='from',
    target='to',
    edge_attr=True,
)

In [37]:
for col in ['count', 'label', 'count_log']:
    print(col)
    nx.set_node_attributes(G, pd.Series(filtered_nodes[col].values, index=filtered_nodes['id']).to_dict(), name=col)

count
label
count_log


In [38]:
pos = nx.drawing.layout.spring_layout(G)
nx.set_node_attributes(G, pos, name='pos')

### Create edge trace and node trace

In [39]:
edge_traces = []
edge_midpoint_x = []
edge_midpoint_y = []
edge_midpoint_text = []

for edge in G.edges():
    if G.edges[edge]['category_source'] != G.edges[edge]['category_target']:
        x0, y0 = G.nodes[edge[0]]['pos']
        x1, y1 = G.nodes[edge[1]]['pos']
        ppmi = G.edges[edge]['ppmi']
        edge_count = G.edges[edge]['edge_count']
        trace = go.Scatter(
            x=[x0, x1, None],
            y=[y0, y1, None],
            line=dict(width = ppmi / 2, color = 'gray'),
            mode='lines',
        )
        edge_traces.append(trace)

        edge_midpoint_x.append((x0 + x1) / 2)
        edge_midpoint_y.append((y0 + y1) / 2)
        edge_midpoint_text.append(f'# connections = {edge_count}<br>ppmi = {round(ppmi, 2)}')

edge_midpoint_trace = go.Scatter(
    x=edge_midpoint_x, 
    y=edge_midpoint_y,
    mode='markers',
    text=edge_midpoint_text,
    hoverinfo='text',
    marker=dict(
        color='grey',
        opacity=0,
        size=50
    )
)


In [40]:
node_x = []
node_y = []
node_color = []
node_size = []
node_text = []

for node in G.nodes():
    x, y = G.nodes[node]['pos']
    node_x.append(x)
    node_y.append(y)
    # if G.nodes[node]['label'] == 'EFFECT':
    #     node_color.append('coral')
    # elif G.nodes[node]['remedy_type'] == 'Remedy':
    #     node_color.append('cornflowerblue')
    # else:
    #     node_color.append('lightgreen')
    node_size.append(G.nodes[node]['count_log'] * 5)
    node_text.append(f"{node}<br>count = {G.nodes[node]['count']}")

node_trace = go.Scatter(
    x=node_x, 
    y=node_y,
    mode='markers',
    text=node_text,
    hoverinfo='text',
    marker=dict(
        color=node_color,
        size=node_size,
        line_width=2,
        opacity=1
    )
)

In [41]:
fig = go.Figure(
   layout=go.Layout(
      titlefont_size=16,
      showlegend=False,
      hovermode='closest',
      margin=dict(b=20,l=20,r=20,t=20),
      xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
      yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
   )
)

for trace in edge_traces:
   fig.add_trace(trace)
fig.add_trace(node_trace)
fig.add_trace(edge_midpoint_trace)

fig.show()