# Training GNNs to detect patient 0

## Imports

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import ndlib.models.epidemics as ep
import ndlib.models.ModelConfig as mc
import ndlib
import numpy as np

## Training Data

In [129]:
# Generate 50 graphs, each with 50 nodes
NUM_NODES = 50
NUM_EDGES = 100
NUM_ITERATIONS = 10
NUM_TRAINING = 500
BETA = 0.15
GAMMA = 0

graphs = np.zeros((NUM_TRAINING, NUM_NODES, NUM_NODES))
nodes_statuses = []
initial_infected = []
for i in range(NUM_TRAINING):
    graph = nx.gnm_random_graph(NUM_NODES, NUM_EDGES)
    p0= np.random.randint(NUM_NODES)

    config = mc.Configuration()
    config.add_model_initial_configuration("Infected", [p0])
    config.add_model_parameter("beta", BETA)
    config.add_model_parameter("gamma", GAMMA)

    model = ep.SIRModel(graph)
    model.set_initial_status(config)

    iterations = model.iteration_bunch(NUM_ITERATIONS)
    statuses = [iteration['status'] for iteration in iterations]
    union_of_statuses = {k:0 for k in range(NUM_NODES)}
    for status in statuses:
        union_of_statuses.update(status)

    adj_matr = np.array(nx.adjacency_matrix(graph).todense())
    graphs[i,:,:] = adj_matr  # Convert sparse matrix to dense matrix
    nodes_statuses.append([v for _,v in sorted(union_of_statuses.items(), key=lambda item: item[0])])
    initial_infected.append([1 if i==p0 else 0 for i in range(NUM_NODES)])


  adj_matr = np.array(nx.adjacency_matrix(graph).todense())


In [94]:
# Define the GCN layer
class GraphConvolutionLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphConvolutionLayer, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))

        # Initialize learnable parameters
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj_matrix):
        # Perform graph convolution
        support = torch.mm(x, self.weight)
        output = torch.mm(adj_matrix, support)  # Propagate information through the graph
        output = output + self.bias
        return output


In [122]:

class GraphConvolutionalNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GraphConvolutionalNetwork, self).__init__()
        self.gcn1 = GraphConvolutionLayer(input_dim, hidden_dim)
        self.gcn2 = GraphConvolutionLayer(hidden_dim, output_dim)

    def forward(self, x, adj_matrix):
        x = F.relu(self.gcn1(x, adj_matrix))
        x = self.gcn2(x, adj_matrix)
        return x

# Define a simple training loop
def train_model(model, features, adj_matrix, labels, num_epochs, learning_rate):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(features, adj_matrix)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
    # print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

In [128]:

# Example usage
# Load your graph data, features, adjacency matrix, and labels here
# Replace the placeholders below with your data
num_nodes = NUM_NODES
num_features = 1
num_classes = NUM_NODES
model = GraphConvolutionalNetwork(num_features, 16, num_classes)
for epoch in range(100):
    epoch_loss = 0
    for i in range(NUM_TRAINING):
        features = torch.tensor(nodes_statuses[i]).reshape((NUM_NODES,1)).float()
        adjacency_matrix = torch.tensor(graphs[i]).float()
        labels = torch.tensor(initial_infected[i])
        epoch_loss += train_model(model, features, adjacency_matrix, labels, num_epochs=1, learning_rate=0.01)
    print(f"Avergae loss of epoch {epoch}: {epoch_loss/NUM_TRAINING}")



Avergae loss of epoch 0: 0.22527151116728783
Avergae loss of epoch 1: 0.15915870533138513
Avergae loss of epoch 2: 0.17424815061688423
Avergae loss of epoch 3: 0.1506976108700037
Avergae loss of epoch 4: 0.15587790777534247
Avergae loss of epoch 5: 0.18814101070538164
Avergae loss of epoch 6: 0.18447222628071905
Avergae loss of epoch 7: 0.1691798877827823
Avergae loss of epoch 8: 0.20190815620869398
Avergae loss of epoch 9: 0.17927497444674373
Avergae loss of epoch 10: 0.17957153210043908
Avergae loss of epoch 11: 0.16776465348526837
Avergae loss of epoch 12: 0.1644793018065393
Avergae loss of epoch 13: 0.16751183831691743
Avergae loss of epoch 14: 0.17249364118650556
Avergae loss of epoch 15: 0.1742847158908844
Avergae loss of epoch 16: 0.1806300436295569
Avergae loss of epoch 17: 0.18660266072303056
Avergae loss of epoch 18: 0.1952404471002519
Avergae loss of epoch 19: 0.20462215889245272
Avergae loss of epoch 20: 0.21529127330332995
Avergae loss of epoch 21: 0.22524761686474085
Aver