In [6]:
########################################
# 1) Imports
########################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
import math


# Pick the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [15]:
# GCN File from https://github.com/seongjunyun/Graph_Transformer_Networks/blob/master/gcn.py + additional functions from other files

from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv.message_passing import MessagePassing
from torch_geometric.utils import add_self_loops

def glorot(tensor):
    """Initialise ``tensor`` in place with Glorot (Xavier) uniform values.

    Values are drawn from ``U(-a, a)`` with
    ``a = sqrt(6 / (tensor.size(-2) + tensor.size(-1)))``.

    Args:
        tensor: Tensor with at least two dimensions, or ``None``.
            ``None`` is accepted and ignored so callers can pass optional
            parameters without checking them first.
    """
    # Guard first: the bound below dereferences ``tensor``, so checking for
    # None afterwards (as before) could never prevent the AttributeError.
    if tensor is None:
        return
    bound = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
    tensor.data.uniform_(-bound, bound)

def zeros(tensor):
    """Overwrite every element of ``tensor`` with 0 in place.

    A ``None`` argument is ignored, which lets callers pass an optional
    parameter (such as a disabled bias) without checking it first.
    """
    if tensor is None:
        return
    tensor.data.fill_(0)

class GCNConv(MessagePassing):
    r"""Graph convolution layer adapted from the `"Semi-supervised
    Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper.

    Unlike the symmetric normalization used in the paper, this variant
    row-normalizes the adjacency matrix, i.e. every edge weight is divided by
    the weighted degree of the node the edge is aggregated into:

    .. math::
        \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1} \mathbf{\hat{A}}
        \mathbf{X} \mathbf{\Theta},

    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
    adjacency matrix with inserted self-loops and
    :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
            (default: :obj:`False`)
        cached (bool, optional): If set to :obj:`True`, the normalized
            edge list is computed on the first call and reused afterwards.
            Only meaningful when the same graph is passed on every call.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        args (optional): Opaque configuration object; it is stored and
            forwarded to :meth:`norm`, which currently does not read it.
            (default: :obj:`None`)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 args=None):
        # Sum aggregation; with flow='target_to_source' messages are summed at
        # the node stored in the first row of ``edge_index``.
        super(GCNConv, self).__init__('add', flow='target_to_source')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.args = args
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight (Glorot uniform) and bias (zeros); drop the cache."""
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None, args=None):
        """Add self-loops and row-normalize the edge weights.

        Args:
            edge_index: Edge list; one column per edge.
            num_nodes (int): Number of nodes in the graph.
            edge_weight: One weight per edge, or ``None`` for all-ones.
            improved (bool): Use self-loop weight 2 instead of 1.
            dtype: dtype of the all-ones weights created when
                ``edge_weight`` is ``None``.
            args: Unused; accepted for interface compatibility.

        Returns:
            Tuple ``(edge_index, norm)``: the edge list with ``num_nodes``
            self-loops appended, and the matching normalized weights.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # The self-loops are appended after the original edges, so their
        # weights are concatenated at the end as well. Honour ``improved``
        # (it used to be silently ignored).
        loop_weight = torch.full((num_nodes, ),
                                 2 if improved else 1,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row = edge_index[0]

        # Weighted degree of the aggregating node of every edge.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv = deg.pow(-1)
        # A zero degree yields inf; such nodes contribute nothing.
        deg_inv[deg_inv == float('inf')] = 0

        return edge_index, deg_inv[row] * edge_weight

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the linear transform, then propagate over the normalized graph.

        Args:
            x: Node feature matrix whose last dimension is ``in_channels``.
            edge_index: Edge list; one column per edge.
            edge_weight: Optional per-edge weights (``None`` means all ones).

        Returns:
            Node feature matrix whose last dimension is ``out_channels``.
        """
        x = torch.matmul(x, self.weight)

        if self.cached and self.cached_result is not None:
            edge_index, norm = self.cached_result
        else:
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight,
                                         self.improved, x.dtype, args=self.args)
            # Only keep the result when caching was requested, so an uncached
            # layer does not pin the previous batch's tensors in memory.
            if self.cached:
                self.cached_result = edge_index, norm

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        """Scale each neighbour feature row by its normalized edge weight."""
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Add the bias (when present) to the aggregated messages."""
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

In [9]:
########################################
# 2) Load the ZINC dataset from PyG
########################################
# subset=True is passed explicitly to select the reduced ZINC split;
# every molecule is a separate graph.
# Each data object carries .x (node features), .edge_index, .edge_attr and .y.

root = '../data/ZINC'
train_dataset, val_dataset, test_dataset = (
    ZINC(root, split=split_name, subset=True)
    for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(held_out, batch_size=32, shuffle=False)
    for held_out in (val_dataset, test_dataset)
)

print(f"Train size: {len(train_dataset)}  Val size: {len(val_dataset)}  Test size: {len(test_dataset)}")

Train size: 10000  Val size: 1000  Test size: 1000


In [10]:
########################################
# 3) Define a simple GCN-based Model
########################################
# We'll adapt a standard "Graph-level" GCN model:
#   2 or 3 layers of GCNConv from gcn.py
#   global mean pool
#   final MLP for the property (ZINC is regression, so out_dim=1)
########################################

class GCNNet(nn.Module):
    """Graph-level regressor: stacked GCNConv layers, mean pooling, linear head.

    Args:
        hidden_dim (int): Width of every GCNConv output and of the pooled
            graph embedding.
        num_layers (int): Total number of GCNConv layers. The first layer is
            always created; ``num_layers - 1`` hidden-to-hidden layers follow.
        in_dim (int): Size of the input node features.
        out_dim (int, optional): Size of the prediction per graph.
            (default: 1)
    """

    def __init__(self, hidden_dim, num_layers, in_dim, out_dim=1):
        super().__init__()

        # Input projection layer followed by the hidden -> hidden stack.
        input_layer = GCNConv(in_channels=in_dim, out_channels=hidden_dim)
        hidden_layers = [
            GCNConv(in_channels=hidden_dim, out_channels=hidden_dim)
            for _ in range(num_layers - 1)
        ]
        self.convs = nn.ModuleList([input_layer, *hidden_layers])

        self.lin = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch, edge_attr=None):
        """Predict one value (or ``out_dim`` values) per graph in the batch.

        Args:
            x: Node features, one row per node.
            edge_index: Edge list, one column per edge.
            batch: Graph id of every node, used for pooling.
            edge_attr: Accepted for interface compatibility but not used;
                every convolution runs with ``edge_weight=None``.

        Returns:
            The linear head's output with its last dimension squeezed.
        """
        node_repr = x
        for layer in self.convs:
            # ReLU follows every convolution, including the last one.
            node_repr = F.relu(layer(node_repr, edge_index, edge_weight=None))

        # Average node embeddings per graph, then regress.
        graph_repr = global_mean_pool(node_repr, batch)
        return self.lin(graph_repr).squeeze(-1)


In [13]:
########################################
# 4) Training / Evaluation
########################################

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network called as ``model(x, edge_index, batch, edge_attr=...)``.
        loader: Iterable of batched graphs; must expose ``loader.dataset``.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss called as ``criterion(prediction, target)``.

    Returns:
        Sum over batches of ``loss * num_graphs`` divided by
        ``len(loader.dataset)``.
    """
    model.train()
    running_loss = 0
    for graphs in loader:
        graphs = graphs.to(device)

        # Targets are flattened to one value per graph.
        targets = graphs.y.view(-1).float()
        preds = model(
            graphs.x.float(),
            graphs.edge_index,
            graphs.batch,
            edge_attr=graphs.edge_attr,
        )
        batch_loss = criterion(preds, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight by the number of graphs so a smaller final batch is not
        # over-represented in the epoch total.
        running_loss += batch_loss.item() * graphs.num_graphs
    return running_loss / len(loader.dataset)

@torch.no_grad()
def evaluate(model, loader, criterion):
    """Compute the loss over ``loader`` without tracking gradients.

    Args:
        model: Network called as ``model(x, edge_index, batch, edge_attr=...)``.
        loader: Iterable of batched graphs; must expose ``loader.dataset``.
        criterion: Loss called as ``criterion(prediction, target)``.

    Returns:
        Sum over batches of ``loss * num_graphs`` divided by
        ``len(loader.dataset)``.
    """
    model.eval()
    running_loss = 0
    for graphs in loader:
        graphs = graphs.to(device)
        targets = graphs.y.view(-1).float()
        preds = model(
            graphs.x.float(),
            graphs.edge_index,
            graphs.batch,
            edge_attr=graphs.edge_attr,
        )
        running_loss += criterion(preds, targets).item() * graphs.num_graphs
    return running_loss / len(loader.dataset)


In [None]:
########################################
# 5) Putting It All Together
########################################

model = GCNNet(
    hidden_dim=64,
    num_layers=3,  # e.g. 3-layer GCN
    in_dim=train_dataset.num_node_features,  # taken from the dataset, not hard-coded
    out_dim=1
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = nn.MSELoss()

# Early stopping: stop after `patience` consecutive epochs without a new
# best validation loss.
# NOTE(review): with epochs=2 and patience=10 the early-stopping branch can
# never fire; raise `epochs` for a real run.
patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0
best_model_state = None

epochs = 2
for epoch in range(1, epochs+1):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss   = evaluate(model, val_loader, criterion)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live parameter tensors, so
        # they must be cloned; otherwise later optimizer steps overwrite the
        # "best" snapshot in place.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    # Printed after the counter update so the reported value includes the
    # current epoch.
    print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train MSE: {train_loss:.4f}, Val MSE: {val_loss:.4f}, Epochs no improvement: {epochs_no_improve}")

    if epochs_no_improve >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

# Evaluate the checkpoint with the best validation loss, not whatever
# weights the final epoch happened to leave behind.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

test_mse = evaluate(model, test_loader, criterion)
print(f"Final Test MSE: {test_mse:.4f}")


Epoch 1, Train Loss: 3.1930, Val Loss: 2.9130, Train MSE: 3.1930, Val MSE: 2.9130, Epochs no improvement: 0
