In [1]:
#!/usr/bin/env python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GraphUNet


class FeatureExtractorGNN(nn.Module):
    """
    Per-node feature extractor: a GraphUNet encoder refined by two GAT layers.

    The GraphUNet output is fed through a multi-head GATConv (followed by ELU)
    and then a single-head GATConv; the GraphUNet output is then added back as
    a skip connection, so the result has ``out_channels`` features per node.

    Args:
        in_channels: input node feature size consumed by the GraphUNet.
        hidden_channels: GraphUNet hidden width.
        out_channels: GraphUNet output width, which is also the final output width.
        depth: GraphUNet depth.
        pool_ratios: GraphUNet pooling ratio(s).
        heads: number of heads in the first attention layer.
        concat: if True the first attention layer's output width is
            ``out_channels * heads``, otherwise ``out_channels``.
        dropout: dropout rate passed to both GATConv layers.

    NOTE(review): ``self.residual`` (an ``nn.Linear``) is constructed and is
    included in ``parameters()``, but ``forward`` never calls it -- the skip
    connection adds the raw GraphUNet output. Its weights therefore never get
    a gradient. Confirm whether it was meant to project the skip path before
    changing this (doing so would alter the outputs of trained models).
    """
    def __init__(self, in_channels=9, hidden_channels=64, out_channels=32,
                 depth=3, pool_ratios=0.5, heads=4, concat=True, dropout=0.6):
        super().__init__()
        # Output width of attention1 == input width of attention2.
        attn1_width = out_channels * heads if concat else out_channels

        # Submodules are registered in the same order as before so that
        # parameters() / state_dict ordering is unchanged.
        self.unet = GraphUNet(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            depth=depth,
            pool_ratios=pool_ratios,
            act=F.relu,
        )
        self.attention1 = GATConv(
            out_channels, out_channels,
            heads=heads, concat=concat, dropout=dropout,
        )
        self.attention2 = GATConv(
            attn1_width, out_channels,
            heads=1, concat=False, dropout=dropout,
        )
        self.residual = nn.Linear(out_channels, out_channels)

    def forward(self, x, edge_index):
        """
        Encode node features ``x`` over the graph described by ``edge_index``.

        Returns the second attention layer's output plus the GraphUNet output.
        """
        skip = self.unet(x, edge_index)
        hidden = F.elu(self.attention1(skip, edge_index))
        refined = self.attention2(hidden, edge_index)
        return refined + skip


class DerivativeGNN(nn.Module):
    """
    GNN that approximates Fdot = dF/dt using Graph Attention.

    Two GATConv layers: a multi-head layer (``hidden_channels`` per head)
    followed by ELU, then a single-head layer mapping to ``out_channels``.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
        heads: number of attention heads in the first layer.
        concat: if True the first layer's output width is
            ``hidden_channels * heads``; otherwise ``hidden_channels``.
        dropout: dropout rate passed to both GATConv layers.
        hidden_channels: per-head width of the first layer. Defaults to 64,
            the value that used to be hard-coded, so existing configurations,
            parameter counts and checkpoints are unaffected.
    """
    def __init__(self, in_channels, out_channels=64, heads=4, concat=True, dropout=0.6,
                 hidden_channels=64):
        super().__init__()
        # gat1's output width is gat2's input width; derive it once so the two
        # layers can never disagree (previously `64` was repeated in both).
        gat1_out_channels = hidden_channels * heads if concat else hidden_channels
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, concat=concat,
                            dropout=dropout)
        self.gat2 = GATConv(gat1_out_channels, out_channels, heads=1,
                            concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        """
        Apply both attention layers to node features ``x`` over ``edge_index``.

        Returns the output of ``gat2`` (``out_channels`` features per node).
        """
        x = F.elu(self.gat1(x, edge_index))
        return self.gat2(x, edge_index)


class IntegralGNN(nn.Module):
    """
    GNN that learns the integral operator: ΔF = ∫ Fdot dt using Graph Attention.

    Two GATConv layers: a multi-head layer (``hidden_channels`` per head)
    followed by ELU, then a single-head layer mapping to ``out_channels``.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
        heads: number of attention heads in the first layer.
        concat: if True the first layer's output width is
            ``hidden_channels * heads``; otherwise ``hidden_channels``.
        dropout: dropout rate passed to both GATConv layers.
        hidden_channels: per-head width of the first layer. Defaults to 64,
            the value that used to be hard-coded, so existing configurations,
            parameter counts and checkpoints are unaffected.
    """
    def __init__(self, in_channels=64, out_channels=4, heads=4, concat=True, dropout=0.6,
                 hidden_channels=64):
        super().__init__()
        # gat1's output width is gat2's input width; derive it once so the two
        # layers can never disagree (previously `64` was repeated in both).
        gat1_out_channels = hidden_channels * heads if concat else hidden_channels
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, concat=concat,
                            dropout=dropout)
        self.gat2 = GATConv(gat1_out_channels, out_channels, heads=1,
                            concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        """
        Apply both attention layers to node features ``x`` over ``edge_index``.

        Returns the output of ``gat2`` (``out_channels`` features per node).
        """
        x = F.elu(self.gat1(x, edge_index))
        return self.gat2(x, edge_index)


class GPARCRecurrent(nn.Module):
    """
    Recurrent GPARC container holding the feature extractor, derivative
    solver and integral solver as submodules of one model.

    Because the feature extractor is registered here alongside the two
    solvers, its weights appear in ``parameters()`` and are trained jointly
    with them.

    Args:
        feature_extractor: module producing per-node features.
        derivative_solver: module approximating the time derivative.
        integral_solver: module integrating the derivative.
        num_static_feats: number of static input features (stored only).
        num_dynamic_feats: number of dynamic input features (stored only).

    NOTE(review): the original description says this model processes a
    sequence of Data objects, but no ``forward`` is defined in this copy
    (it appears trimmed for parameter counting); calling the module as-is
    will fail.
    """
    def __init__(self, feature_extractor, derivative_solver, integral_solver,
                 num_static_feats=9, num_dynamic_feats=4):
        super().__init__()
        # Plain integers -- not registered as submodules or parameters.
        self.num_static_feats, self.num_dynamic_feats = num_static_feats, num_dynamic_feats
        # Submodules, assigned in pipeline order; assignment order fixes the
        # order of named_children() and parameters().
        self.feature_extractor = feature_extractor
        self.derivative_solver = derivative_solver
        self.integral_solver = integral_solver


# Model configuration from the script
num_static_feats = 9
num_dynamic_feats = 4

# Stage widths that more than one submodule depends on. Each consumer below
# reads these names instead of repeating the literal, so the coupled
# dimensions cannot drift apart when one of them is changed.
feature_out_channels = 128    # feature extractor output width
derivative_out_channels = 4   # derivative solver output == integral solver input

# Build submodules with exact config from script
feature_extractor = FeatureExtractorGNN(
    in_channels=num_static_feats,
    hidden_channels=64,
    out_channels=feature_out_channels,
    depth=2,
    pool_ratios=0.1,
    heads=4,
    concat=True,
    dropout=0.2
)

derivative_solver = DerivativeGNN(
    # extracted features + dynamic feats
    in_channels=feature_out_channels + num_dynamic_feats,
    out_channels=derivative_out_channels,
    heads=4,
    concat=True,
    dropout=0.2
)

integral_solver = IntegralGNN(
    in_channels=derivative_out_channels,
    out_channels=num_dynamic_feats,
    heads=4,
    concat=True,
    dropout=0.2
)

# Integrated model
model = GPARCRecurrent(
    feature_extractor=feature_extractor,
    derivative_solver=derivative_solver,
    integral_solver=integral_solver,
    num_static_feats=num_static_feats,
    num_dynamic_feats=num_dynamic_feats
)

# Count parameters
def count_parameters(module):
    """Return the total number of scalar elements in *module*'s trainable parameters.

    Parameters with ``requires_grad=False`` are skipped. Parameters shared
    between submodules are counted once, because ``Module.parameters()``
    yields each parameter only once.
    """
    total = 0
    for param in module.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Trainable-parameter totals for the whole model and each top-level stage.
total_params = count_parameters(model)
feature_params = count_parameters(model.feature_extractor)
derivative_params = count_parameters(model.derivative_solver)
integral_params = count_parameters(model.integral_solver)

# Summary table: labels are pre-padded so the counts line up in one column.
summary_rows = (
    ("\nFeature Extractor:    ", feature_params),
    ("Derivative Solver:    ", derivative_params),
    ("Integral Solver:      ", integral_params),
)
print("=" * 60)
print("GPARC Recurrent Model Parameter Count")
print("=" * 60)
for row_label, row_count in summary_rows:
    print(f"{row_label}{row_count:>12,} parameters")
print("-" * 60)
print(f"TOTAL:                {total_params:>12,} parameters")
print("=" * 60)

# Per-layer breakdown of each stage: (section heading, [(label, submodule), ...]).
breakdown_sections = (
    ("\nFeature Extractor Breakdown:", (
        ("  GraphUNet:          ", model.feature_extractor.unet),
        ("  GAT Layer 1:        ", model.feature_extractor.attention1),
        ("  GAT Layer 2:        ", model.feature_extractor.attention2),
        ("  Residual Linear:    ", model.feature_extractor.residual),
    )),
    ("\nDerivative Solver Breakdown:", (
        ("  GAT Layer 1:        ", model.derivative_solver.gat1),
        ("  GAT Layer 2:        ", model.derivative_solver.gat2),
    )),
    ("\nIntegral Solver Breakdown:", (
        ("  GAT Layer 1:        ", model.integral_solver.gat1),
        ("  GAT Layer 2:        ", model.integral_solver.gat2),
    )),
)
for section_heading, section_parts in breakdown_sections:
    print(section_heading)
    for part_label, part_module in section_parts:
        print(f"{part_label}{count_parameters(part_module):>12,}")

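============================================================
GPARC Recurrent Model Parameter Count
============================================================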

Feature Extractor:         171,072 parameters
Derivative Solver:          35,596 parameters
Integral Solver:             2,828 parameters
------------------------------------------------------------
TOTAL:                     209,496 parameters
============================================================

Feature Extractor Breakdown:
  GraphUNet:                21,568
  GAT Layer 1:              67,072
  GAT Layer 2:              65,920
  Residual Linear:          16,512

Derivative Solver Breakdown:
  GAT Layer 1:              34,560
  GAT Layer 2:               1,036

Integral Solver Breakdown:
  GAT Layer 1:               1,792
  GAT Layer 2:               1,036
