In [None]:
# Defining the model
import torch
from torch_geometric.nn import GCNConv, global_mean_pool

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network with a mean-pool readout.

    Each graph's nodes are embedded by two ``GCNConv`` + ReLU layers. The node
    embeddings of each graph are averaged with ``global_mean_pool``. A linear
    head then maps the pooled vector to ``out_channels`` values per graph.

    Args:
        num_features: Number of input features per node.
        hidden_channels: Width of both GCN layers and of the pooled graph embedding.
        out_channels: Number of predicted values per graph. The default of 1
            keeps the original single-output regression head.
    """

    def __init__(self, num_features, hidden_channels, out_channels=1):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Compute one prediction row per graph in the mini-batch.

        Args:
            x: Node feature matrix, presumably (num_nodes, num_features) float.
                TODO confirm against the dataset.
            edge_index: Graph connectivity in the format ``GCNConv`` expects.
            batch: Node-to-graph assignment vector consumed by ``global_mean_pool``.

        Returns:
            Tensor of shape ``(num_graphs, out_channels)``.
        """
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        # Collapse node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)
        return self.lin(x)


In [None]:
# Training
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd

# Load data
# TODO(review): the original cell left `dataset = ` unfinished, which is a SyntaxError.
# Assign a list of torch_geometric `Data` graphs here. Each graph is assumed to
# carry node features `x`, an `edge_index` and a single regression target `y`.
dataset = None
if dataset is None:
    raise NotImplementedError(
        "Set `dataset` to a list of torch_geometric Data objects before training."
    )

# Hyperparameters (previously inline magic numbers)
SEED = 42
TEST_SIZE = 0.2
BATCH_SIZE = 32
HIDDEN_CHANNELS = 64
LEARNING_RATE = 0.001
NUM_EPOCHS = 50

# Seed torch so that weight init and DataLoader shuffling are reproducible,
# not just the train/test split.
torch.manual_seed(SEED)

train_data, test_data = train_test_split(dataset, test_size=TEST_SIZE, random_state=SEED)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Infer the input width from the data instead of hard-coding 1.
model = GCN(
    num_features=dataset[0].num_node_features,
    hidden_channels=HIDDEN_CHANNELS,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.MSELoss()

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x.float(), batch.edge_index, batch.batch)  # (num_graphs, 1)
        # Align the target with the prediction. An (N,) target against an (N, 1)
        # prediction would silently broadcast MSELoss into an (N, N) comparison.
        # reshape_as raises if y does not hold exactly one value per graph.
        target = batch.y.reshape_as(out).to(out.dtype)
        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()
        # Weight by graphs in the batch so a smaller final batch is not over-counted.
        total_loss += loss.item() * batch.num_graphs
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_data):.4f}")


In [None]:
# Evaluation
from sklearn.metrics import mean_absolute_error, r2_score

# Collect per-graph predictions and targets over the held-out set.
model.eval()
pred_batches, true_batches = [], []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        out = model(batch.x.float(), batch.edge_index, batch.batch)
        pred_batches.append(out.cpu())
        true_batches.append(batch.y.cpu())

# torch.cat raises an opaque error on an empty list, so fail clearly instead.
assert pred_batches, "test_loader yielded no batches; cannot compute metrics"

# Keep the concatenated tensors under their original names for later cells.
preds = torch.cat(pred_batches)
trues = torch.cat(true_batches)

# Flatten to 1-D so sklearn scores one value per graph whether y is (N,) or (N, 1).
y_pred = preds.reshape(-1).numpy()
y_true = trues.reshape(-1).numpy()

print("MAE:", mean_absolute_error(y_true, y_pred))
print("R²:", r2_score(y_true, y_pred))
