In [1]:
import plotly.graph_objects as go
import networkx as nx
import pandas as pd

In [30]:
# Load the edge list and the node feature table from CSV.
edge_list = pd.read_csv('../02_data/train_test/edge_list.csv')

# Read the features and discard the 'article_embedding' and 'label' columns
# in a single chained step.
node_features = (
    pd.read_csv('../02_data/train_test/node_features.csv')
    .drop(columns=['article_embedding', 'label'])
)

In [31]:
# Cast both identifier columns of the edge list to integers.
for id_column in ('chat_id', 'domain_index'):
    edge_list[id_column] = edge_list[id_column].astype(int)

# Round every numeric feature column to two decimals.
for feature_column in ('virality', 'pc1', 'avalanches', 'messages', 'chats'):
    node_features[feature_column] = node_features[feature_column].round(2)

# Fixed-seed random subset of 1000 edges.
edge_list_sample = edge_list.sample(n=1000, random_state=42)


In [32]:
# Display the feature table as it stands after the column drops and rounding above.
node_features

Unnamed: 0,domain_index,domain,virality,avalanches,messages,chats,year,pc1
0,4719,100giornidaleoni.it,0.95,1.32,1.44,1.21,2021,0.26
1,2064,100milefreepress.net,0.71,2.00,2.00,2.00,2022,0.60
2,5960,100percentfedup.com,0.70,6.08,9.92,4.97,2022,0.22
3,1389,1011now.com,0.83,2.32,2.85,2.01,2021,0.72
4,1405,10news.com,0.87,3.22,4.68,3.07,2021,0.71
...,...,...,...,...,...,...,...,...
6107,1450,zmescience.com,0.84,2.34,2.90,1.87,2021,0.89
6108,2895,zombie.news,1.00,1.00,1.00,1.00,2023,0.24
6109,4183,zonazealots.com,0.71,2.00,4.00,1.00,2021,0.65
6110,1848,zuerst.de,0.99,1.06,1.22,1.07,2022,0.68


In [None]:
# Build the chat-domain graph and render it as an interactive Plotly figure.
#
# chat_id and domain_index are both plain integers. Using them directly as
# node keys merges a chat and a domain that happen to share the same number
# into one node, and gives any chat whose id matches a domain_index that
# domain's tooltip and colour. Nodes are therefore keyed by (kind, id) tuples.
#
# NOTE(review): an earlier cell builds `edge_list_sample`, but this cell plots
# the full `edge_list` - confirm which of the two is intended here.
G = nx.Graph()

# Add edges (zip over the two columns instead of the much slower iterrows)
for chat_id, domain_index in zip(edge_list['chat_id'], edge_list['domain_index']):
    G.add_edge(('chat', int(chat_id)), ('domain', int(domain_index)))

# Add positions for nodes; the fixed seed makes the layout reproducible
pos = nx.spring_layout(G, seed=42)

# Add edges to the plot: each edge is an (x0, x1, None) / (y0, y1, None)
# triple, where None breaks the line so separate edges are not joined.
edge_x = []
edge_y = []
for source, target in G.edges():
    x0, y0 = pos[source]
    x1, y1 = pos[target]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=2, color='#888'),
    hoverinfo='none',
    mode='lines')

# One-off lookup table domain_index -> feature row, replacing a full scan of
# node_features for every node. The first row per domain_index wins, which is
# the row the previous `.iloc[0]` / `.values[0]` lookups selected.
features_by_domain = (
    node_features
    .drop_duplicates(subset='domain_index')
    .set_index('domain_index')
    .to_dict('index')
)

# Add nodes to the plot
node_x = []
node_y = []
node_text = []
node_color = []

for kind, node_id in G.nodes():
    x, y = pos[(kind, node_id)]
    node_x.append(x)
    node_y.append(y)

    # Only domain nodes can have features; chat ids are never looked up.
    feature = features_by_domain.get(node_id) if kind == 'domain' else None
    if feature is not None:
        tooltip = (
            f"Domain Index: {node_id}<br>"
            f"Domain: {feature['domain']}<br>"
            f"PC1: {feature['pc1']}<br>"
            f"Virality: {feature['virality']}<br>"
            f"Avalanches: {feature['avalanches']}<br>"
            f"Messages: {feature['messages']}<br>"
            f"Chats: {feature['chats']}<br>"
            f"Year: {int(feature['year'])}"
        )
        node_text.append(tooltip)
        node_color.append(feature['pc1'])
    else:
        # Nodes without features (all chats, and domains missing from
        # node_features) get a minimal tooltip and the neutral colour value 0.
        label = 'Domain Index' if kind == 'domain' else 'Chat ID'
        node_text.append(f"{label}: {node_id}")
        node_color.append(0)

node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    text=node_text,  # Add tooltips
    marker=dict(
        showscale=True,
        colorscale='YlGnBu',
        color=node_color,  # pc1 per domain node, 0 where no features exist
        size=10))

# Combine traces and plot
fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    title="Domains shared in Telegram-Chats",
                    showlegend=True,
                    hovermode='closest',
                    width=1200,
                    height=800))
fig.write_html('../03_plots/graph_plotly.html')
fig.show()