In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import networkx as nx
import numpy as np
from torch_geometric.nn import GCNConv

# Create a grid graph
def create_grid_graph(n):
    """Build a square ``n`` x ``n`` 2-D grid graph with integer node labels.

    The grid produced by ``nx.grid_2d_graph`` is relabeled with
    ``nx.convert_node_labels_to_integers`` so its nodes can be used
    directly as tensor indices when building an edge list.

    Args:
        n: Number of nodes along each side of the grid.

    Returns:
        The relabeled networkx graph (``n * n`` nodes).
    """
    square_grid = nx.grid_2d_graph(n, n)
    integer_grid = nx.convert_node_labels_to_integers(square_grid)
    return integer_grid

# Define a simple GNN model
class GNN(nn.Module):
    """Two-layer graph convolutional network: ``GCNConv -> ReLU -> GCNConv``.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output features per node.
        hidden_channels: Width of the intermediate layer. Defaults to 16,
            the value that was previously hard-coded, so existing callers
            are unaffected.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=16):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Run both graph convolutions and return per-node outputs.

        Args:
            x: Node feature matrix; assumed (num_nodes, in_channels) —
                matches how the script below builds it.
            edge_index: Graph connectivity as a (2, num_edges) index tensor.
                For an undirected graph both directions of every edge must
                be present, since GCNConv treats each column as directed.

        Returns:
            Node outputs with ``out_channels`` features per node.
        """
        x = torch.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

# Build the grid and its (2, num_edges) edge index for GCNConv
n = 5
G = create_grid_graph(n)
# networkx reports each undirected edge once as (u, v). GCNConv treats every
# column of edge_index as a directed edge, so without the reversed copies each
# node would aggregate messages from only some of its grid neighbours.
# The explicit dtype/reshape keep the (E, 2) -> (2, E) layout valid even when
# the graph has no edges (e.g. n = 1).
edge_index = torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2).t()
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

# Example grid values (graph signals). These are random placeholders: y is
# independent of x, so this only demonstrates the training mechanics.
torch.manual_seed(0)  # make the demo reproducible across runs
x = torch.randn((n * n, 1))  # Signal on the source grid
y = torch.randn((n * n, 1))  # Signal on the target grid

# Initialize the model, loss, and optimizer
model = GNN(in_channels=1, out_channels=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop: full-batch optimisation on the single graph.
# The mode never changes inside the loop, so set it once up front.
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')
