In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv

# A simple GPS layer that fuses local (R-GCN) and global (Transformer) processing.
class GPSLayer(nn.Module):
    """One GPS block: relational message passing followed by global self-attention.

    The layer first applies an R-GCN convolution over the typed edges (local
    branch) and then a Transformer encoder layer over all nodes of the graph,
    treated as a single sequence (global branch). Each branch is wrapped in a
    residual connection and followed by its own LayerNorm.
    """

    def __init__(self, hidden_dim, num_relations, nhead=8):
        """
        Args:
            hidden_dim (int): Width of the node embeddings; used as both the
                input and output size of every sub-module.
            num_relations (int): Number of distinct relation (edge) types
                handled by the R-GCN convolution.
            nhead (int): Number of attention heads in the transformer layer.
        """
        super().__init__()
        # Local branch: relational graph convolution, hidden_dim -> hidden_dim.
        self.local = RGCNConv(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            num_relations=num_relations,
        )
        # Global branch: a stock PyTorch encoder layer in its default
        # sequence-first layout, i.e. inputs shaped (S, N, E) with
        # S = num_nodes and N = batch size.
        self.global_transformer = nn.TransformerEncoderLayer(hidden_dim, nhead)
        # One LayerNorm per branch, applied after the residual sum.
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x, edge_index, edge_type):
        """Run the local branch, then the global branch, on the node features.

        Args:
            x: Node feature matrix with hidden_dim columns (one row per node).
            edge_index: Edge indices forwarded unchanged to the R-GCN layer.
            edge_type: Per-edge relation ids forwarded unchanged to the R-GCN layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Local branch: residual R-GCN update, then normalise.
        after_local = self.norm1(x + self.local(x, edge_index, edge_type))

        # Global branch: the whole graph is handled as one sequence with a
        # batch size of 1, so a singleton batch axis is inserted at dim 1
        # before the transformer and removed again afterwards.
        as_sequence = after_local.unsqueeze(1)
        attended = self.global_transformer(as_sequence).squeeze(1)

        # Residual around the global branch, then the second normalisation.
        return self.norm2(after_local + attended)

# The overall model that stacks multiple GPS layers for node prediction.
class GraphGPSNodeClassifier(nn.Module):
    """Per-node predictor: linear embedding -> stacked GPS layers -> linear head."""

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, num_relations, nhead=8):
        """
        Args:
            in_channels (int): Size of the raw input feature vector of each node.
            hidden_channels (int): Width used by the embedding and by every GPS layer.
            out_channels (int): Size of the per-node output (number of classes,
                or the regression output dimension).
            num_layers (int): How many GPS layers to stack.
            num_relations (int): Number of relation types in the graph.
            nhead (int): Number of attention heads in each GPS layer's transformer.
        """
        super().__init__()
        # Project raw node features into the hidden space.
        self.embedding = nn.Linear(in_channels, hidden_channels)
        # Build the stack one GPS layer at a time.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GPSLayer(hidden_channels, num_relations, nhead=nhead))
        # Prediction head, applied independently to every node.
        self.classifier = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_type):
        """
        Args:
            x (torch.Tensor): Node features of shape (num_nodes, in_channels).
            edge_index (torch.Tensor): Edge indices (2, num_edges).
            edge_type (torch.Tensor): Edge type labels (num_edges,).
        Returns:
            torch.Tensor: Predictions for each node (num_nodes, out_channels).
        """
        hidden = self.embedding(x)
        # Every GPS layer receives the same graph structure; only the node
        # states change from layer to layer.
        for gps_layer in self.layers:
            hidden = gps_layer(hidden, edge_index, edge_type)
        return self.classifier(hidden)

# ---------------------------
# Example usage:
# Assume:
#   - each node has 'in_channels' features,
#   - the graph has 'num_relations' distinct relation (edge) types (e.g. from your KG or another source),
#   - we wish to classify nodes into 'out_channels' classes.
# Replace these numbers with the ones for your application.
in_channels = 128
hidden_channels = 256
out_channels = 10   # for 10-class node classification
num_layers = 3
num_relations = 4   # change this based on your relation vocab

model = GraphGPSNodeClassifier(in_channels, hidden_channels, out_channels, num_layers, num_relations)

# Dummy data for illustration. A dedicated, seeded generator is used so the
# demo inputs are identical on every run without reseeding the global RNG.
num_nodes = 50
num_edges = 200
demo_rng = torch.Generator().manual_seed(0)
x = torch.randn(num_nodes, in_channels, generator=demo_rng)
edge_index = torch.randint(0, num_nodes, (2, num_edges), generator=demo_rng)
edge_type = torch.randint(0, num_relations, (num_edges,), generator=demo_rng)

# Forward pass: node predictions.
# eval() switches off the dropout inside the transformer layers, which would
# otherwise make the predictions stochastic, and no_grad() skips autograd
# bookkeeping that inference does not need.
# Call model.train() again before training this model.
model.eval()
with torch.no_grad():
    predictions = model(x, edge_index, edge_type)
print("Node predictions shape:", predictions.shape)  # Expected: (num_nodes, out_channels)

  from .autonotebook import tqdm as notebook_tqdm


Node predictions shape: torch.Size([50, 10])


In [2]:
from KG_trainer_w_transformer import get_KG_trainer

Loading cpnet...
Done
Loaded 3947 allowed concept IDs from train/val data.


In [3]:
# Training source/target files, both located under DATA_PATH.
DATA_PATH = "data/eg"
source_path = DATA_PATH + "/train.source"
target_path = DATA_PATH + "/train.target"

# Model checkpoint name and output directory handed to the trainer factory.
model_name, output_dir = "facebook/bart-base", "KG_finetuned_out_transformer"

# Sequence-length, epoch and batch-size settings handed to the trainer factory.
max_len, epochs, train_batch_size = 128, 1, 20

In [4]:
# Collect the settings from the configuration cell into one mapping and pass
# them to the project's trainer factory as keyword arguments.
trainer_kwargs = {
    "source_path": source_path,
    "target_path": target_path,
    "model_name": model_name,
    "output_dir": output_dir,
    "max_len": max_len,
    "epochs": epochs,
    "train_batch_size": train_batch_size,
}
trainer = get_KG_trainer(**trainer_kwargs)

Preprocessing dataset...


Preprocessing data: 100%|██████████| 500/500 [00:00<00:00, 872.19 examples/s]


Saving preprocessed dataset to disk...


Saving the dataset (1/1 shards): 100%|██████████| 500/500 [00:00<00:00, 100002.48 examples/s]
Some weights of BartGraphAwareForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['graph_encoder.embedding.weight', 'graph_encoder.gps_layers.0.global_transformer.linear1.bias', 'graph_encoder.gps_layers.0.global_transformer.linear1.weight', 'graph_encoder.gps_layers.0.global_transformer.linear2.bias', 'graph_encoder.gps_layers.0.global_transformer.linear2.weight', 'graph_encoder.gps_layers.0.global_transformer.norm1.bias', 'graph_encoder.gps_layers.0.global_transformer.norm1.weight', 'graph_encoder.gps_layers.0.global_transformer.norm2.bias', 'graph_encoder.gps_layers.0.global_transformer.norm2.weight', 'graph_encoder.gps_layers.0.global_transformer.self_attn.in_proj_bias', 'graph_encoder.gps_layers.0.global_transformer.self_attn.in_proj_weight', 'graph_encoder.gps_layers.0.global_transformer.self_attn.out_proj.bias', 'grap

In [None]:
# Start the training run on the trainer returned by get_KG_trainer above.
trainer.train()

Step,Training Loss
