Analyzing the Obtained Data

In [None]:
# Inspect the neighbours of node 0.
# G comes from an earlier cell. Presumably it is a networkx graph, in which
# case G[0] maps each neighbour of node 0 to the attribute dict of the
# connecting edge (TODO confirm). Despite its name, `first_edge` therefore
# holds node 0's whole adjacency, not a single edge.
first_edge = G[0]

for neighbour_entry in first_edge.items():
    target_node, attributes = neighbour_entry
    print("Target Node:", target_node)
    print("Attributes:", attributes)

In [None]:
# Inspect the graph tensors and the author index built in earlier cells.
# In a notebook only the last bare expression of a cell is displayed, so
# listing them as plain expressions would show data.x alone. Print each one
# explicitly so all four values are visible.
print("edge_index:", data.edge_index)
print("author_to_index:", author_to_index)
print("y:", data.y)
print("x:", data.x)

Training the Model


In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

class GNNModel(torch.nn.Module):
    """A single GNNConv message-passing layer followed by a linear read-out.

    Args:
        num_features: size of each input node feature vector.
        hidden_channels: output width of the GNNConv layer.
        num_classes: number of output units produced per node.
    """

    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        # GNNConv is defined later in this cell. The name is only resolved
        # when a model is instantiated, so the definition order is fine.
        self.conv = GNNConv(num_features, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """Apply conv -> ReLU -> linear and return the per-node outputs.

        Assumes x is (num_nodes, num_features); the last output dimension
        is num_classes.
        """
        hidden = F.relu(self.conv(x, edge_index))
        return self.fc(hidden)

class GNNConv(MessagePassing):
    """Linear transform followed by sum aggregation over neighbours and self.

    Each node's output is the sum of ``lin(x_j)`` over the source nodes ``j``
    of its incoming edges in ``edge_index``. Its own ``lin(x_i)`` is included
    through an added self-loop.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
        normalize: if True, scale each message by ``1 / sqrt(deg(i) * deg(j))``.
            This is GCN-style symmetric normalization, with degrees counted
            after self-loops are added, and it keeps high-degree nodes from
            receiving disproportionately large sums. Defaults to False, which
            reproduces the plain un-normalized sum.
    """

    def __init__(self, in_channels, out_channels, normalize=False):
        super().__init__(aggr="add")  # "add" aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.normalize = normalize

    def forward(self, x, edge_index):
        # NOTE(review): add_self_loops appends a loop for every node even if
        # edge_index already contains one, so such nodes would count
        # themselves twice. Confirm edge_index has no self-loops, or switch
        # to add_remaining_self_loops.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)

        norm = None
        if self.normalize:
            row, col = edge_index
            # In-degree of each target node; self-loops make it >= 1.
            deg = degree(col, x.size(0), dtype=x.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: transformed features of each edge's source node.
        # norm: per-edge scale factors, or None when normalization is off.
        if norm is None:
            return x_j
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # No post-aggregation transform; return the summed messages as-is.
        return aggr_out

torch.manual_seed(1234)

num_features = data.num_node_features
hidden_channels = 64
num_classes = 1  # one regression output per node, trained with MSE

model = GNNModel(num_features, hidden_channels, num_classes)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Random 80/20 split of node indices. train_test_split does not use torch's
# RNG, so seed it explicitly to make the split reproducible.
# The results are index tensors rather than boolean masks; tensor indexing
# works the same with either.
train_idx, test_idx = train_test_split(
    range(data.num_nodes), test_size=0.2, random_state=1234
)
data.train_mask = torch.tensor(train_idx, dtype=torch.long)
data.test_mask = torch.tensor(test_idx, dtype=torch.long)
print(data.y)
data.x = data.x.float()
print(data.x)

# MSELoss needs a float target with the same shape as the prediction,
# which is (num_nodes, num_classes). Two failure modes are avoided here:
# - a 1-D y of shape (num_nodes,) would silently broadcast against
#   (M, 1) into an (M, M) loss;
# - an integer y would raise a dtype error.
# view() fails loudly if y does not hold exactly one target per output.
y_target = data.y.float().view(data.num_nodes, num_classes)

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], y_target[data.train_mask])
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

print("Training finished!")

# Evaluate on the held-out nodes, which were otherwise never used.
model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index)
    test_loss = criterion(pred[data.test_mask], y_target[data.test_mask]).item()
print(f"Test MSE: {test_loss:.4f}")

torch.save(model.state_dict(), 'model.pt')
