Simplest Env: a toy tank network for predicting per-edge flows with a graph neural network

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader

# Example graph data preparation: a toy network of tanks (nodes) joined by
# connections (edges). All feature and target values below are random
# placeholders, so the RNG is seeded to make Restart & Run All reproducible
# (this also makes the model's weight initialisation in the next cell repeatable).
SEED = 0
torch.manual_seed(SEED)

# Number of tanks (nodes)
num_tanks = 5

# Edge indices (connections between tanks), shape (2, num_edges).
# PyG convention: row 0 holds the source tank, row 1 the target tank.
edge_index = torch.tensor([[0, 1, 2, 3, 4, 0, 2],
                           [1, 2, 3, 4, 0, 2, 4]], dtype=torch.long)

# Derive the edge count from the connectivity instead of hardcoding it, so the
# edge features / targets below can never get out of sync with edge_index.
num_edges = edge_index.size(1)
assert int(edge_index.max()) < num_tanks, "edge_index references a non-existent tank"

# Node features: e.g., [volume, concentration, temperature]
node_features = torch.randn((num_tanks, 3))

# Edge features: e.g., [flow capacity, resistance]
edge_features = torch.randn((num_edges, 2))

# True flow values for training (example values), one per edge
flow_values = torch.randn((num_edges, 1))

# Create a Data object
data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features, y=flow_values)


In [3]:

# Define the GNN model
class TankFlowGNN(torch.nn.Module):
    """Two-layer GCN that predicts one scalar flow value per edge.

    Node embeddings are computed by two ``GCNConv`` layers with ReLU
    activations. The convolutions only use ``edge_index``; ``edge_attr`` is
    not passed to them. For every edge (src, dst), the embeddings of both
    endpoints are concatenated with that edge's features and mapped to a
    single output by a linear head.

    Args:
        in_channels: number of node features in ``data.x`` (default 3).
        hidden_channels: width of the GCN node embeddings (default 16).
        edge_dim: number of edge features in ``data.edge_attr`` (default 2).
    """

    def __init__(self, in_channels=3, hidden_channels=16, edge_dim=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # The head receives [x_src, x_dst, edge_attr], so its input width is
        # 2 * hidden_channels + edge_dim (34 with the defaults). Sizing it to
        # hidden_channels alone makes the first forward pass fail.
        self.fc = torch.nn.Linear(2 * hidden_channels + edge_dim, 1)

    def forward(self, data):
        """Return flow predictions of shape (num_edges, 1).

        Args:
            data: graph object exposing ``x`` (num_nodes, in_channels),
                ``edge_index`` (2, num_edges) and ``edge_attr``
                (num_edges, edge_dim).
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # Build one feature row per edge by concatenating (not aggregating)
        # both endpoint embeddings with the edge's own features.
        src, dst = edge_index[0], edge_index[1]
        edge_x = torch.cat([x[src], x[dst], edge_attr], dim=1)

        # Predict flows on edges
        return self.fc(edge_x)

# Instantiate the model (the loss and optimizer are defined in the training cell)
model = TankFlowGNN()


In [4]:
# Show the graph summary: tensor shapes of x, edge_index, edge_attr and y
data

Data(x=[5, 3], edge_index=[2, 7], edge_attr=[7, 2], y=[7, 1])

In [None]:

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()


def train(data, epochs=100):
    """Fit the notebook-level ``model`` to ``data.y`` for ``epochs`` steps.

    Every step is full-batch: the whole graph is passed through the model.
    This function relies on the module-level ``model``, ``optimizer`` and
    ``criterion`` defined in this notebook. The loss is printed on epochs
    0, 10, 20, ...
    """

    def _step():
        # One optimisation step; returns the loss tensor for logging.
        optimizer.zero_grad()
        step_loss = criterion(model(data), data.y)
        step_loss.backward()
        optimizer.step()
        return step_loss

    model.train()
    for epoch in range(epochs):
        loss = _step()
        if not epoch % 10:
            print(f'Epoch {epoch}, Loss: {loss.item()}')


# Train the model
train(data)


In [5]:
from torch_geometric.datasets import TUDataset

# ENZYMES graph dataset from the TUDataset collection. The first run downloads
# and processes it under `root` (see the captured output below).
# NOTE(review): `root` is a hardcoded absolute /tmp path; consider making it configurable.
dataset = TUDataset(
    name='ENZYMES',
    root='/tmp/ENZYMES',
)

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Processing...
Done!


In [8]:
# Inspect one ENZYMES graph. Per the output, it has 3 node features (like the
# tank graph) but no `edge_attr` and a single graph-level label `y`, not one
# target per edge. It therefore does not fit TankFlowGNN / train() as-is.
dataset[1]

Data(edge_index=[2, 102], x=[23, 3], y=[1])