# Training GNNs to detect patient 0

## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import ndlib.models.epidemics as ep
import ndlib.models.ModelConfig as mc
import ndlib
import numpy as np

import plotly.graph_objects as go

## Training Data

In [117]:
# Generate 50 graphs, each with 50 nodes
NUM_NODES = 50
NUM_EDGES = 100
NUM_ITERATIONS = 3
NUM_TRAINING = 10000
BETA = 0.15
GAMMA = 0
TRAINING_SPLIT = 0.7
USE_BA_GRAPH = True

graphs = np.zeros((NUM_TRAINING, NUM_NODES, NUM_NODES))
nodes_statuses = []
initial_infected = []
for i in range(NUM_TRAINING):
    # if USE_BA_GRAPH:
    #     graph = nx.barabasi_albert_graph(NUM_NODES, NUM_EDGES)
    graph = nx.gnm_random_graph(NUM_NODES, NUM_EDGES)
    p0= np.random.randint(NUM_NODES)

    config = mc.Configuration()
    config.add_model_initial_configuration("Infected", [p0])
    config.add_model_parameter("beta", BETA)
    config.add_model_parameter("gamma", GAMMA)

    model = ep.SIRModel(graph)
    model.set_initial_status(config)

    iterations = model.iteration_bunch(NUM_ITERATIONS)
    statuses = [iteration['status'] for iteration in iterations]
    union_of_statuses = {k:0 for k in range(NUM_NODES)}
    for status in statuses:
        union_of_statuses.update(status)

    adj_matr = np.array(nx.adjacency_matrix(graph).todense())
    graphs[i,:,:] = adj_matr  # Convert sparse matrix to dense matrix
    nodes_statuses.append([v for _,v in sorted(union_of_statuses.items(), key=lambda item: item[0])])
    initial_infected.append([1 if i==p0 else 0 for i in range(NUM_NODES)])


adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.



## Graph Visualization

In [118]:
G = graph
edge_x = []
edge_y = []
positions = nx.spring_layout(G)
for edge in G.edges():
    x0, y0 = positions[edge[0]]
    x1, y1 = positions[edge[1]]
    edge_x.append(x0)
    edge_x.append(x1)
    edge_x.append(None)
    edge_y.append(y0)
    edge_y.append(y1)
    edge_y.append(None)

if True:
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = positions[node]
        node_x.append(x)
        node_y.append(y)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            # colorscale options
            #'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
            #'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
            #'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=[],
            # size = 10,
            colorbar=dict(
                thickness=15,
                title='Node Connections',
                xanchor='left',
                titleside='right'
            ),
            line=dict(
                color='Black',
                width=2
            )))

    node_adjacencies = []
    node_text = []
    for node, adjacencies in enumerate(G.adjacency()):
        node_adjacencies.append(len(adjacencies[1]))
        node_text.append('# of connections: '+str(len(adjacencies[1])))

    avg_degree = {
        0: np.average([adj for i, adj in enumerate(node_adjacencies) if not nodes_statuses[-1][i]]),
        1: np.average([adj for i, adj in enumerate(node_adjacencies) if nodes_statuses[-1][i]]),
        2: np.average([adj for i, adj in enumerate(node_adjacencies)])
    }

    # node_trace.marker.color = node_adjacencies
    node_trace.marker.color = nodes_statuses[-1]
    node_trace.marker.size = [20 if nodes_statuses[-1][i] else 10 for i in range(NUM_NODES)]
    node_trace.text = node_text

    fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    title=f'Network Snapshot of Graph with {NUM_NODES} Nodes and {NUM_EDGES} Edges',
                    titlefont_size=16,
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20,l=5,r=5,t=40),
                    annotations=[ dict(
                        text=f'Average degree of nodes is {avg_degree[2]} and average infected degree is  {avg_degree[1]:2.2f}',
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002 ), ],
                        
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                    )
    fig.show()

In [101]:
class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features, use_bias=True):
        super(GCNLayer, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(torch.zeros(size=(in_features, out_features))))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(torch.zeros(size=(out_features,))))
        else:
            self.register_parameter('bias', None)

        self.initialize_weights()

    def initialize_weights(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        x = x @ self.weight
        if self.bias is not None:
            x += self.bias

        return torch.sparse.mm(adj, x)


In [105]:

class GraphConvolutionalNetwork(nn.Module):
    def __init__(self, node_features, hidden_dim, num_classes, dropout, use_bias=True):
        super(GraphConvolutionalNetwork, self).__init__()
        self.gcn_1 = GCNLayer(node_features, hidden_dim, use_bias)
        self.gcn_2 = GCNLayer(hidden_dim, hidden_dim, use_bias)
        self.gcn_3 = GCNLayer(hidden_dim, num_classes, use_bias)
        self.dropout = nn.Dropout(p=dropout)

    def initialize_weights(self):
        self.gcn_1.initialize_weights()
        self.gcn_2.initialize_weights()
        self.gcn_3.initialize_weights()

    def forward(self, x, adj):
        x = F.relu(self.gcn_1(x, adj))
        x = self.dropout(x)
        x = F.relu(self.gcn_2(x, adj))
        x = self.dropout(x)
        x = self.gcn_3(x, adj)
        return x

# Define a simple training loop
def train_model(model, features, adj_matrix, labels, learning_rate):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    model.train()
    optimizer.zero_grad()
    outputs = model(features, adj_matrix)
    loss = criterion(outputs.flatten(), labels.float())
    loss.backward()
    optimizer.step()
    return loss.item()

In [120]:
num_nodes = NUM_NODES
num_features = 1
num_classes = NUM_NODES
model = GraphConvolutionalNetwork(num_features, 16, 1, dropout=0)
overall_training_loss = []
overall_testing_loss = []
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
    epoch_loss = 0
    for i in range(int(NUM_TRAINING * TRAINING_SPLIT)):
        features = torch.tensor(nodes_statuses[i]).reshape((NUM_NODES,1)).float()
        print(features.shape)
        # features = torch.tensor(np.ones((NUM_NODES,1))).float()
        adjacency_matrix = torch.tensor(graphs[i]).float()
        labels = torch.tensor(initial_infected[i])
        epoch_loss += train_model(model, features, adjacency_matrix, labels, learning_rate=0.001)
    # print(f"Average training loss of epoch {epoch}: {epoch_loss/(NUM_TRAINING * TRAINING_SPLIT)}")


    test_loss = 0
    correct = 0
    total = 0
    model.eval()
    avg_infected_nodes = 0
    for i in range(int(NUM_TRAINING*TRAINING_SPLIT), NUM_TRAINING):
        features = torch.tensor(nodes_statuses[i]).reshape((NUM_NODES,1)).float()
        adjacency_matrix = torch.tensor(graphs[i]).float()
        labels = torch.tensor(initial_infected[i])
        with torch.no_grad():
            output = model(features, adjacency_matrix)
            test_loss += criterion(output.flatten(), labels.float()).item()
            predicted = torch.argmax(output.flatten())
            total += 1
            correct += predicted == torch.argmax(labels)
            avg_infected_nodes += torch.sum(features)
    overall_testing_loss.append(test_loss)
    print(f"Avergae error of epoch {epoch}: {test_loss/total}")
    print(f'Accuracy of the network on the test data for epoch {epoch}: {100 * correct / total} %')



torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])
torch.Size([50, 1])


KeyboardInterrupt: 

In [84]:
features = torch.tensor(nodes_statuses[1]).reshape((NUM_NODES,1)).float()
adjacency_matrix = torch.tensor(graphs[1]).float()
labels = torch.tensor(initial_infected[1])

model.eval()
with torch.no_grad():
    output = model.forward(features, adjacency_matrix)
    print(output)

tensor([[21.9224, 18.9971,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [24.1911, 21.3567,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [21.9224, 18.9971,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [21.9224, 18.9971,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [24.1911, 21.3567,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [24.1911, 21.3567,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])


In [85]:
import plotly.graph_objs as go
from plotly.offline import iplot

# Create traces
trace0 = go.Scatter(
    x = list(range(100)),
    y = overall_training_loss,
    mode = 'lines',
    name = 'Training Loss'
)

trace1 = go.Scatter(
    x = list(range(100)),
    y = overall_testing_loss,
    mode = 'lines',
    name = 'Testing Loss'
)

data = [trace0, trace1]

# Edit the layout
layout = dict(title = 'Training and Testing Loss over Epochs',
              xaxis = dict(title = 'Epoch'),
              yaxis = dict(title = 'Loss'),
)

fig = go.Figure(data=data, layout=layout)
iplot(fig, filename='line-mode')
