In [None]:
import torch
import random
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch_geometric.datasets import TUDataset
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader as HeteroDataLoader
from torch_geometric.utils import to_undirected
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, global_mean_pool
from sklearn.metrics import accuracy_score
from torch.nn.functional import cross_entropy

# Load the MUTAG molecule collection from the TU benchmark suite, stored
# under /tmp/MUTAG.
MUTAG_NAME = 'MUTAG'
dataset = TUDataset(root='/tmp/' + MUTAG_NAME, name=MUTAG_NAME)

def make_hetero_with_triangle(data, add_triangle=True, rng=None):
    """Wrap a homogeneous molecule graph in ``HeteroData``, optionally adding a triangle.

    Every returned object has the same schema regardless of ``add_triangle``:
    node types ``'mol_node'`` and ``'tri_node'``, all four edge types, and a
    graph-level label in ``hetero['graph'].y``. Keeping the schema fixed lets
    positive and negative samples be batched together. When no triangle is
    added, the ``'tri_node'`` features and all triangle-related edge indices
    are empty.

    Args:
        data: Homogeneous graph exposing ``x`` (2-D node features),
            ``edge_index`` and ``num_nodes``.
        add_triangle: If True, add three all-ones ``'tri_node'`` nodes joined
            as an undirected triangle. Then link one randomly chosen molecule
            node to one randomly chosen triangle node, in both directions, and
            set the label to 1. Otherwise the label is 0.
        rng: Object with a ``randint(a, b)`` method (e.g. a seeded
            ``random.Random``) used to pick the linked nodes. Defaults to the
            module-level ``random`` functions, matching the previous behaviour.

    Returns:
        HeteroData: the heterogeneous sample.

    Raises:
        ValueError: if ``add_triangle`` is True and ``data`` has no nodes
            (there is no molecule node to link the triangle to).
    """
    rng = random if rng is None else rng
    x = data.x
    num_feats = x.size(1)

    def empty_edge_index():
        # Fresh tensor per edge type so no two stores alias the same object.
        return torch.zeros((2, 0), dtype=torch.long)

    if add_triangle:
        if not data.num_nodes:
            raise ValueError('cannot attach a triangle to a graph with no nodes')
        # Triangle features follow the molecule features' dtype/device so both
        # node types can flow through the same layers.
        tri_x = x.new_ones((3, num_feats))
        tri_edge_index = to_undirected(
            torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.long).t().contiguous()
        )
        # Same draw order as before: molecule node first, then triangle node.
        mol_node_idx = rng.randint(0, data.num_nodes - 1)
        tri_node_idx = rng.randint(0, 2)
        mol_to_tri = torch.tensor([[mol_node_idx], [tri_node_idx]], dtype=torch.long)
        tri_to_mol = torch.tensor([[tri_node_idx], [mol_node_idx]], dtype=torch.long)
        label = 1
    else:
        tri_x = x.new_zeros((0, num_feats))
        tri_edge_index = empty_edge_index()
        mol_to_tri = empty_edge_index()
        tri_to_mol = empty_edge_index()
        label = 0

    hetero = HeteroData()
    hetero['mol_node'].x = x
    hetero['mol_node', 'to', 'mol_node'].edge_index = to_undirected(data.edge_index)
    hetero['tri_node'].x = tri_x
    hetero['tri_node', 'connects', 'tri_node'].edge_index = tri_edge_index
    hetero['mol_node', 'linked_to', 'tri_node'].edge_index = mol_to_tri
    hetero['tri_node', 'linked_to_by', 'mol_node'].edge_index = tri_to_mol
    hetero['graph'].y = torch.tensor([label], dtype=torch.long)
    return hetero

# Seed both RNGs before any sampling. Python's `random` drives triangle
# placement and the train/test shuffle. torch's RNG drives DataLoader
# shuffling and weight initialisation. Without seeding, every run of the
# notebook gives a different dataset, split and model.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Synthetic task: even-indexed MUTAG graphs get a triangle attached (label 1),
# odd-indexed ones do not (label 0). The original MUTAG labels are not used.
hetero_dataset = [
    make_hetero_with_triangle(data, add_triangle=(i % 2 == 0))
    for i, data in enumerate(dataset)
]

# Shuffle, then split 80/20 into train and test.
random.shuffle(hetero_dataset)
split_idx = int(0.8 * len(hetero_dataset))
train_dataset = hetero_dataset[:split_idx]
test_dataset = hetero_dataset[split_idx:]

train_loader = HeteroDataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = HeteroDataLoader(test_dataset, batch_size=8)

class HeteroGNN(Module):
    """Single-layer heterogeneous GNN for graph-level classification.

    One ``HeteroConv`` layer applies GCNConv to same-type relations and
    SAGEConv to cross-type relations. Outputs landing on the same destination
    node type are summed. After a ReLU, the ``'mol_node'`` embeddings are
    mean-pooled per graph and mapped to logits by a linear layer.

    Args:
        hidden_dim: Output width of every relation convolution.
        out_dim: Number of output logits (classes).
    """

    def __init__(self, hidden_dim=64, out_dim=2):
        super().__init__()
        # Input sizes are given as -1 and are resolved from the first batch.
        relation_convs = {}
        # Same-type relations: GCN over a single node set.
        relation_convs[('mol_node', 'to', 'mol_node')] = GCNConv(-1, hidden_dim, add_self_loops=True)
        relation_convs[('tri_node', 'connects', 'tri_node')] = GCNConv(-1, hidden_dim, add_self_loops=True)
        # Cross-type relations: SAGEConv takes separate (source, destination)
        # input sizes, so it handles edges between two node types.
        relation_convs[('mol_node', 'linked_to', 'tri_node')] = SAGEConv((-1, -1), hidden_dim)
        relation_convs[('tri_node', 'linked_to_by', 'mol_node')] = SAGEConv((-1, -1), hidden_dim)
        self.conv1 = HeteroConv(relation_convs, aggr='sum')
        self.lin = Linear(hidden_dim, out_dim)

    def forward(self, x_dict, edge_index_dict, batch_dict):
        """Return per-graph logits.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Edge indices keyed by edge type.
            batch_dict: Node-to-graph assignment keyed by node type; only
                ``'mol_node'`` is read.
        """
        embeddings = self.conv1(x_dict, edge_index_dict)
        activated = {}
        for node_type, h in embeddings.items():
            activated[node_type] = F.relu(h)
        # NOTE(review): only 'mol_node' embeddings reach the classifier. In this
        # one-layer model, the convolutions whose destination is 'tri_node'
        # therefore have no effect on the output.
        pooled = global_mean_pool(activated['mol_node'], batch_dict['mol_node'])
        return self.lin(pooled)

# Training setup: use the GPU when one is visible, build the model there, and
# optimise all of its parameters with Adam.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = HeteroGNN(hidden_dim=64, out_dim=2)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

def train():
    """Run one optimisation epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``, and puts the model in training mode.

    Returns:
        Mean cross-entropy per training graph for this epoch. Each batch's
        loss is weighted by the number of graphs in it, so a smaller final
        batch does not skew the average. Plain division by the number of
        batches would skew it.
    """
    model.train()
    total_loss = 0.0
    total_graphs = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
        y = batch['graph'].y
        loss = cross_entropy(out, y)
        loss.backward()
        optimizer.step()
        # cross_entropy averages over the batch by default; scale back up so
        # that every graph contributes equally to the epoch mean.
        num_graphs = y.size(0)
        total_loss += loss.item() * num_graphs
        total_graphs += num_graphs
    return total_loss / total_graphs

def test(loader):
    """Return the classification accuracy of ``model`` over all graphs in ``loader``.

    Uses the module-level ``model`` and ``device``. Switches the model to eval
    mode and runs without gradient tracking. Predictions are the arg-max of
    the logits; targets are ``batch['graph'].y``.
    """
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
            predicted.append(logits.argmax(dim=1).cpu())
            expected.append(batch['graph'].y.cpu())
    # sklearn's signature is accuracy_score(y_true, y_pred).
    return accuracy_score(torch.cat(expected), torch.cat(predicted))

# Training loop: after each epoch, report the training loss and the accuracy
# on both the train and test splits.
NUM_EPOCHS = 20
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    train_acc, test_acc = test(train_loader), test(test_loader)
    print(
        f'Epoch {epoch:02d} | Loss: {loss:.4f} | '
        f'Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}'
    )

Epoch 01 | Loss: 0.6593 | Train Acc: 0.5267 | Test Acc: 0.3947
Epoch 02 | Loss: 0.4816 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 03 | Loss: 0.2393 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 04 | Loss: 0.0878 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 05 | Loss: 0.0384 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 06 | Loss: 0.0210 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 07 | Loss: 0.0135 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 08 | Loss: 0.0097 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 09 | Loss: 0.0075 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 10 | Loss: 0.0060 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 11 | Loss: 0.0049 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 12 | Loss: 0.0041 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 13 | Loss: 0.0034 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 14 | Loss: 0.0029 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 15 | Loss: 0.0026 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 16 | Loss: 0.0023 | Train Acc: 1.0000 | Test Acc:

: 