In [2]:
import torch
from torch_geometric.data import Data
from torch.optim import Adam
from torch.nn import MSELoss

# SIRModel is defined inline below. If your GNN-based SIR model lives in its
# own module, import it instead and drop the local definition:
# from your_gnn_module import SIRModel

class SIRModel(torch.nn.Module):
    """Toy graph model with a single learnable spreading rate ``beta``.

    Each forward pass applies one neighbour-difference update: every node is
    shifted by ``beta`` times the summed difference between its in-neighbours'
    values and its own value.

    Note: despite the name, this does not track separate S/I/R compartments;
    it is a single linear update over node values.
    """

    def __init__(self, initial_beta=0.5):
        super().__init__()
        # float() so an integer argument (e.g. 1) still produces a floating
        # tensor; integer tensors cannot be wrapped in a grad-requiring
        # Parameter.
        self.beta = torch.nn.Parameter(torch.tensor(float(initial_beta)))

    def forward(self, x, edge_index):
        """Return ``x + beta * agg``, where ``agg[i]`` is the sum of
        ``x[j] - x[i]`` over every edge ``j -> i``.

        Args:
            x: node values indexed along dim 0 (e.g. ``(num_nodes, F)``).
            edge_index: ``(2, num_edges)`` integer tensor; row 0 holds the
                source node of each edge, row 1 the target node.

        Returns:
            Tensor with the same shape as ``x``.
        """
        src, dst = edge_index[0], edge_index[1]
        messages = x[src] - x[dst]
        # Aggregate per target node rather than over all edges at once: a
        # single global sum of these differences is identically zero for any
        # symmetric (undirected) edge list, which would leave beta with no
        # gradient.
        agg = torch.zeros_like(x).index_add(0, dst, messages)
        return x + self.beta * agg

# Define a simple graph for this example: a 3-node path 0-1-2 with each
# undirected edge stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
# Node values, shape (num_nodes, 1). Deliberately non-uniform: the model is
# driven by differences between neighbouring nodes, so identical values would
# give it nothing to learn from.
x = torch.tensor([[1], [2], [4]], dtype=torch.float)

# Simulated observed data: one target per node, matching the model's output
# shape. MSELoss broadcasts mismatched shapes (with only a warning) instead of
# raising, so a mismatch here would silently produce a meaningless loss.
observed_data = torch.rand(x.size(0), 1) * 10

data = Data(x=x, edge_index=edge_index)

def loss_fn(output, observed_data):
    """Return the mean-squared error between ``output`` and ``observed_data``.

    Uses ``torch.nn.MSELoss`` with its default mean reduction. PyTorch
    broadcasts mismatched shapes here (emitting a UserWarning) rather than
    raising, so callers should pass tensors of the same shape.
    """
    return MSELoss()(output, observed_data)

def train(model, data, observed_data, epochs=500, *, lr=0.01, log_every=1):
    """Fit ``model`` to ``observed_data`` with Adam and return its ``beta``.

    Args:
        model: module called as ``model(data.x, data.edge_index)`` that
            exposes a ``beta`` attribute.
        data: graph object providing ``x`` and ``edge_index``.
        observed_data: target passed to ``loss_fn`` alongside the model
            output; it should have the same shape as that output.
        epochs: number of optimisation steps (one full-graph step each).
        lr: Adam learning rate.
        log_every: print the loss every ``log_every`` epochs (and on the
            final epoch); ``0`` or a negative value disables printing.

    Returns:
        ``model.beta`` itself — the trained Parameter, not a detached copy.
    """
    optimizer = Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = loss_fn(out, observed_data)
        loss.backward()
        optimizer.step()

        should_log = log_every > 0 and (
            epoch % log_every == 0 or epoch == epochs - 1
        )
        if should_log:
            print(f"Epoch {epoch}, Loss: {loss.item()}")

    return model.beta

# Build the model, run the optimiser, then report the learned rate.
model = SIRModel(0.5)

# NOTE(review): this standalone Parameter is never read afterwards —
# train() optimises ``model.beta`` only. It looks like a leftover from an
# earlier version that passed beta into forward() as an extra argument
# (compare the TypeError recorded below); delete it once no later cell
# relies on it.
beta = torch.nn.Parameter(torch.tensor(0.5))

trained_beta = train(model, data, observed_data)
inferred_value = trained_beta.item()
print(f"Inferred Beta: {inferred_value}")


TypeError: SIRModel.forward() takes 3 positional arguments but 4 were given