In [None]:
import pandas as pd
import joblib
import dash
from dash.dependencies import Input, Output, State
from dash import dcc, html
import plotly.express as px
import networkx as nx
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import numpy as np

model = joblib.load('random_forest_model.pkl')
transactions = pd.read_csv('combined.csv')  
edges = pd.read_csv('elliptic_txs_edgelist.csv')
transactions_cleaned= transactions.drop(columns=['timestep', 'txId'])

df = transactions[['txId', 'timestep']]
edges = edges.rename(columns={'txId': 'txId1', 'class': 'txId2'})

#Get risk scores
risk_scores = model.predict_proba(transactions_cleaned)[:, 1]
df['risk_score'] = risk_scores

#Predict label
predicted_labels = model.predict(transactions_cleaned)
df['type'] = ['illicit' if label == 1 else 'licit' for label in predicted_labels]
#top_transactions = df.sort_values(by='risk_score', ascending=False).head(5)
#print(df)

In [17]:
# Get top 5 transactions and their risk scores
top_txIds = df['txId'].values
top_risk_scores = df['risk_score'].values
top_types = df['type'].values

filtered_edges = edges[edges['txId1'].isin(top_txIds) | edges['txId2'].isin(top_txIds)]

# Create a smaller graph using the filtered edges
G = nx.from_pandas_edgelist(filtered_edges, 'txId1', 'txId2')

pos = nx.spring_layout(G)  

node_x = []
node_y = []
node_color = []
node_size = []
for node in G.nodes():
    node_x.append(pos[node][0])
    node_y.append(pos[node][1])
    
    # Color and size based on the type and risk score of top transactions
    if node in top_txIds:
        idx = list(top_txIds).index(node)
        if top_types[idx] == 'illicit':
            color = 'red' 
        elif top_types[idx] == 'licit':
            color = 'green'
        else:
            print(f"Unexpected type for node {node}: {type}")
            color = 'blue'     
        size = 20 + top_risk_scores[idx] * 30  
    else:
        color = 'blue'
        size = 10
        
    node_color.append(color)
    node_size.append(size)

# Create edge trace
edge_x = []
edge_y = []
for edge in G.edges():
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    edge_x.append(x0)
    edge_x.append(x1)
    edge_x.append(None)
    edge_y.append(y0)
    edge_y.append(y1)
    edge_y.append(None)

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=1, color='gray'),
    hoverinfo='none',
    mode='lines')

# Create node trace
node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color=node_color,
        size=node_size,
        line_width=2))

# Add hover information
node_text = []
for node in G.nodes():
    if node in top_txIds:
        idx = list(top_txIds).index(node)
        text = f"txId: {node}, Risk Score: {top_risk_scores[idx]:.2f}, Type: {top_types[idx]}"
    else:
        text = f"txId: {node}"
    node_text.append(text)

node_trace.text = node_text

# Plot the figure
fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    title='Transaction Graph',
                    titlefont_size=16,
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=0, l=0, r=0, t=40),
                    xaxis=dict(showgrid=False, zeroline=False),
                    yaxis=dict(showgrid=False, zeroline=False)))

fig.show()


In [6]:
top_txIds = df['txId'].values
top_risk_scores = df['risk_score'].values
top_types = df['type'].values

filtered_edges = edges[edges['txId1'].isin(top_txIds) | edges['txId2'].isin(top_txIds)]

G = nx.from_pandas_edgelist(filtered_edges, 'txId1', 'txId2')

for node in G.nodes():
    G.nodes[node]['type'] = 'illicit' if node in top_txIds[:5] else 'licit'  # Example logic

pos = nx.spring_layout(G)

# Assigning example node attributes
for node in G.nodes():
    if node in top_txIds:
        idx = list(top_txIds).index(node)
        G.nodes[node]['type'] = top_types[idx]  
    else:
        G.nodes[node]['type'] = 'unknown'

# Start Dash app
app = dash.Dash(__name__)

# Function to build the graph visualization
def create_graph(selected_node=None):
    edge_x = []
    edge_y = []

    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    node_x = []
    node_y = []
    node_color = []
    node_size = []
    node_text = []

    for node in G.nodes():
        node_x.append(pos[node][0])
        node_y.append(pos[node][1])

        # Determine node color and size based on the type and risk score of top transactions
        if node in top_txIds:
            idx = list(top_txIds).index(node)
            if top_types[idx] == 'illicit':
                color = 'red'
            elif top_types[idx] == 'licit':
                color = 'green'
            else:
                print(f"Unexpected type for node {node}: {top_types[idx]}")
                color = 'blue'  

            risk_score = top_risk_scores[idx]
            size = (20 + risk_score * 50) 
        else:
            color = 'blue'
            size = 10  

        node_color.append(color)
        node_size.append(size)
        node_text.append(f"Node {node}, Type: {G.nodes[node].get('type', 'unknown')}")


    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            color=node_color,
            size=node_size,
            line_width=2
        )
    )

    return [edge_trace, node_trace]

app.layout = html.Div([
    dcc.Graph(
        id='graph',
        config={'displayModeBar': True},  
        figure={
            'data': create_graph(),
            'layout': go.Layout(
                showlegend=False,
                hovermode='closest',
                margin=dict(b=0, l=0, r=0, t=0),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                dragmode='zoom'  # Enable zooming
            )
        }
    )
])


@app.callback(
    Output('graph', 'figure'),
    [Input('graph', 'clickData')],
    [State('graph', 'figure')]
)
def update_graph_on_click(clickData, figure):
    # Capture the current zoom level (axis ranges) before updating the graph
    xaxis_range = figure['layout']['xaxis']['range'] if figure else None
    yaxis_range = figure['layout']['yaxis']['range'] if figure else None

    if clickData:
        try:
            clicked_node_text = clickData['points'][0]['text']
            clicked_node = int(clicked_node_text.split()[1])  
            
            if G.nodes[clicked_node]['type'] == 'illicit':
                return {
                    'data': create_graph(clicked_node),  
                    'layout': go.Layout(
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=0, l=0, r=0, t=0),
                        xaxis=dict(showgrid=False, zeroline=False, range=xaxis_range),  
                        yaxis=dict(showgrid=False, zeroline=False, range=yaxis_range)   
                    )
                }
        except (IndexError, KeyError, ValueError) as e:
            print(f"Error processing clicked node: {e}")

    # Default behavior if no node is clicked or an error occurs
    return {
        'data': create_graph(),
        'layout': go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False, range=xaxis_range),  
            yaxis=dict(showgrid=False, zeroline=False, range=yaxis_range)   
        )
    }


if __name__ == '__main__':
    app.run_server(debug=True)

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[6], line 119, in update_graph_on_click(
    clickData=None,
    figure={'data': [{'hoverinfo': 'none', 'line': {'color': '#888', 'width': 0.5}, 'mode': 'lines', 'type': 'scatter', 'x': [0.16681092977523804, 0.1662999540567398, None, 0.1662999540567398, 0.19614048302173615, None, 0.1662999540567398, 0.1820075958967209, None, 0.1662999540567398, 0.18764635920524597, None, 0.1662999540567398, 0.17140796780586243, None, 0.1662999540567398, 0.2091006636619568, None, 0.1662999540567398, 0.20811131596565247, ...], 'y': [-0.3567316234111786, -0.30921322107315063, None, -0.30921322107315063, -0.3310110867023468, None, -0.30921322107315063, -0.33840319514274597, None, -0.30921322107315063, -0.3179001212120056, None, -0.30921322107315063, -0.3365485966205597, None, -0.30921322107315063, -0.31739622354507446, None, -0.30921322107315063, -0.