# Training GNNs to detect patient 0

## Imports

In [62]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import ndlib.models.epidemics as ep
import ndlib.models.ModelConfig as mc
import ndlib
import numpy as np

import plotly.graph_objects as go

## Training Data

In [63]:
# Generate 50 graphs, each with 50 nodes
NUM_NODES = 50
NUM_EDGES = 100
NUM_ITERATIONS = 10
NUM_TRAINING = 10000
BETA = 0.15
GAMMA = 0
TRAINING_SPLIT = 0.7

graphs = np.zeros((NUM_TRAINING, NUM_NODES, NUM_NODES))
nodes_statuses = []
initial_infected = []
for i in range(NUM_TRAINING):
    graph = nx.gnm_random_graph(NUM_NODES, NUM_EDGES)
    p0= np.random.randint(NUM_NODES)

    config = mc.Configuration()
    config.add_model_initial_configuration("Infected", [p0])
    config.add_model_parameter("beta", BETA)
    config.add_model_parameter("gamma", GAMMA)

    model = ep.SIRModel(graph)
    model.set_initial_status(config)

    iterations = model.iteration_bunch(NUM_ITERATIONS)
    statuses = [iteration['status'] for iteration in iterations]
    union_of_statuses = {k:0 for k in range(NUM_NODES)}
    for status in statuses:
        union_of_statuses.update(status)

    adj_matr = np.array(nx.adjacency_matrix(graph).todense())
    graphs[i,:,:] = adj_matr  # Convert sparse matrix to dense matrix
    nodes_statuses.append([v for _,v in sorted(union_of_statuses.items(), key=lambda item: item[0])])
    initial_infected.append([1 if i==p0 else 0 for i in range(NUM_NODES)])




## Graph Visualization

In [64]:
G = graph
edge_x = []
edge_y = []
positions = nx.spring_layout(G)
for edge in G.edges():
    x0, y0 = positions[edge[0]]
    x1, y1 = positions[edge[1]]
    edge_x.append(x0)
    edge_x.append(x1)
    edge_x.append(None)
    edge_y.append(y0)
    edge_y.append(y1)
    edge_y.append(None)

if True:
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = positions[node]
        node_x.append(x)
        node_y.append(y)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            # colorscale options
            #'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
            #'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
            #'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=[],
            # size = 10,
            colorbar=dict(
                thickness=15,
                title='Node Connections',
                xanchor='left',
                titleside='right'
            ),
            line=dict(
                color='Black',
                width=2
            )))

    node_adjacencies = []
    node_text = []
    for node, adjacencies in enumerate(G.adjacency()):
        node_adjacencies.append(len(adjacencies[1]))
        node_text.append('# of connections: '+str(len(adjacencies[1])))

    avg_degree = {
        0: np.average([adj for i, adj in enumerate(node_adjacencies) if not nodes_statuses[-1][i]]),
        1: np.average([adj for i, adj in enumerate(node_adjacencies) if nodes_statuses[-1][i]]),
        2: np.average([adj for i, adj in enumerate(node_adjacencies)])
    }

    node_trace.marker.color = node_adjacencies
    # node_trace.marker.color = nodes_statuses[-1]
    node_trace.marker.size = [20 if nodes_statuses[-1][i] else 10 for i in range(NUM_NODES)]
    node_trace.text = node_text

    fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    title=f'Network Snapshot of Graph with {NUM_NODES} Nodes and {NUM_EDGES} Edges',
                    titlefont_size=16,
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20,l=5,r=5,t=40),
                    annotations=[ dict(
                        text=f'Average degree of nodes is {avg_degree[2]} and average infected degree is  {avg_degree[1]:2.2f}',
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002 ), ],
                        
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                    )
    fig.show()

In [65]:
# Define the GCN layer
class GraphConvolutionLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphConvolutionLayer, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))

        # Initialize learnable parameters
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj_matrix):
        # Perform graph convolution
        support = torch.mm(x, self.weight)
        output = torch.mm(adj_matrix, support)  # Propagate information through the graph
        output = output + self.bias
        return output


In [77]:

class GraphConvolutionalNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super(GraphConvolutionalNetwork, self).__init__()
        self.gcn1 = GraphConvolutionLayer(input_dim, hidden_dim1)
        self.gcn2 = GraphConvolutionLayer(hidden_dim1, hidden_dim2)
        self.gcn3 = GraphConvolutionLayer(hidden_dim2, output_dim)

    def forward(self, x, adj_matrix):
        x = F.relu(self.gcn1(x, adj_matrix))
        x = F.relu(self.gcn2(x, adj_matrix))
        x = F.relu(self.gcn3(x, adj_matrix))
        return x

# Define a simple training loop
def train_model(model, features, adj_matrix, labels, num_epochs, learning_rate):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(features, adj_matrix)
    # print("OUTPUTS")
    # print(outputs)
    # print("ACTUAL INDEX:")
    # print(torch.argmax(labels))
    # print("PREDICTED INDEX:")
    # print(torch.argmax(outputs))
    # print("MAX:")
    # print(torch.max(outputs))
    # print()
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
    # print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

In [83]:

# Example usage
# Load your graph data, features, adjacency matrix, and labels here
# Replace the placeholders below with your data
num_nodes = NUM_NODES
num_features = 1
num_classes = NUM_NODES
model = GraphConvolutionalNetwork(num_features, 16, 16, num_classes)
overall_training_loss = []
overall_testing_loss = []
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
    epoch_loss = 0
    for i in range(int(NUM_TRAINING * TRAINING_SPLIT)):
        features = torch.tensor(nodes_statuses[i]).reshape((NUM_NODES,1)).float()
        adjacency_matrix = torch.tensor(graphs[i]).float()
        labels = torch.tensor(initial_infected[i])
        epoch_loss += train_model(model, features, adjacency_matrix, labels, num_epochs=1, learning_rate=0.01)
    print(f"Average loss of epoch {epoch}: {epoch_loss/NUM_TRAINING}")


    test_loss = 0
    correct = 0
    total = 0
    model.eval()
    for i in range(int(NUM_TRAINING*TRAINING_SPLIT), NUM_TRAINING):
        features = torch.tensor(nodes_statuses[i]).reshape((NUM_NODES,1)).float()
        adjacency_matrix = torch.tensor(graphs[i]).float()
        labels = torch.tensor(initial_infected[i])
        with torch.no_grad():
            output = model(features, adjacency_matrix)
            # print("LABELS")
            # print(labels)
            test_loss += criterion(output, labels).item()
            _, predicted = torch.max(output.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        break
    overall_testing_loss.append(test_loss)
    print(f"Avergae error of epoch {epoch}: {test_loss/(NUM_TRAINING * 0.3)}")
    print('Accuracy of the network on the test data for epoch {}: {} %'.format(epoch, 100 * correct / total))



Average loss of epoch 0: 1.0151937830645592
Avergae error of epoch 0: 7.407642900943757e-05
Accuracy of the network on the test data for epoch 0: 98.0 %
Average loss of epoch 1: 1.313190403091535
Avergae error of epoch 1: 0.0001798667907714844
Accuracy of the network on the test data for epoch 1: 84.0 %
Average loss of epoch 2: 1.7860443374704569
Avergae error of epoch 2: 2.100808670123418e-05
Accuracy of the network on the test data for epoch 2: 98.0 %
Average loss of epoch 3: 1.1452269573973493
Avergae error of epoch 3: 1.2989434103171031e-05
Accuracy of the network on the test data for epoch 3: 98.0 %
Average loss of epoch 4: 0.30278340735398235
Avergae error of epoch 4: 3.689868251482646e-05
Accuracy of the network on the test data for epoch 4: 98.0 %
Average loss of epoch 5: 0.0733756133697927
Avergae error of epoch 5: 3.6818936467170715e-05
Accuracy of the network on the test data for epoch 5: 98.0 %
Average loss of epoch 6: 0.07335322162359953
Avergae error of epoch 6: 3.6591060

In [84]:
features = torch.tensor(nodes_statuses[1]).reshape((NUM_NODES,1)).float()
adjacency_matrix = torch.tensor(graphs[1]).float()
labels = torch.tensor(initial_infected[1])

model.eval()
with torch.no_grad():
    output = model.forward(features, adjacency_matrix)
    print(output)

tensor([[21.9224, 18.9971,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [24.1911, 21.3567,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [21.9224, 18.9971,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [21.9224, 18.9971,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [24.1911, 21.3567,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [24.1911, 21.3567,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])


In [85]:
import plotly.graph_objs as go
from plotly.offline import iplot

# Create traces
trace0 = go.Scatter(
    x = list(range(100)),
    y = overall_training_loss,
    mode = 'lines',
    name = 'Training Loss'
)

trace1 = go.Scatter(
    x = list(range(100)),
    y = overall_testing_loss,
    mode = 'lines',
    name = 'Testing Loss'
)

data = [trace0, trace1]

# Edit the layout
layout = dict(title = 'Training and Testing Loss over Epochs',
              xaxis = dict(title = 'Epoch'),
              yaxis = dict(title = 'Loss'),
)

fig = go.Figure(data=data, layout=layout)
iplot(fig, filename='line-mode')
