In [1]:
!pip install optuna
!pip install torch_geometric
!pip install gudhi

Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.15.2-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Collecting Mako (from alembic>=1.5.0->optuna)
  Downloading mako-1.3.10-py3-none-any.whl.metadata (2.9 kB)
Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.15.2-py3-none-any.whl (231 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m231.9/231.9 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Downloading mako-1.3.10-py3-none-any.whl (78 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.5/78.5 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: M

Main model. For different types change flags USE_PH or USE_SPECTRAL

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import optuna

from torch_geometric.datasets import WebKB, WikipediaNetwork, Planetoid, Actor
from torch_geometric.transforms import NormalizeFeatures

import gudhi
from gudhi.representations import PersistenceImage
USE_PH = True
USE_SPECTRAL = True

def compute_local_PH(x, edge_index, k=5, maxdim=1):
    row, col = edge_index
    N = x.size(0)
    x_np = x.detach().cpu().numpy()
    neighs = [[] for _ in range(N)]
    for i,j in zip(row.tolist(), col.tolist()):
        neighs[i].append(j)

    diagrams = []
    for i in range(N):
        patch_pts = [x_np[i]]
        for j in neighs[i][:k]:
            patch_pts.append(x_np[j])
        patch = np.stack(patch_pts, axis=0)
        rips = gudhi.RipsComplex(points=patch)
        st = rips.create_simplex_tree(max_dimension=maxdim)
        st.persistence()
        diag = st.persistence_intervals_in_dimension(1)
        if len(diag)==0:
            diag = np.array([[0.0,0.0]])
        diagrams.append(diag)
    return diagrams

def get_persistence_vectors(diagrams, resolution=(5,5)):
    PI = PersistenceImage(bandwidth=1.0,
                          weight=lambda pt: pt[1]-pt[0],
                          resolution=resolution)
    vecs = PI.fit_transform(diagrams)
    vecs = np.nan_to_num(vecs, nan=0.0, posinf=1e3, neginf=-1e3)
    return torch.tensor(vecs, dtype=torch.float)

def laplacian_regularization(edge_index, num_nodes):
    row, col = edge_index
    deg = torch.bincount(row, minlength=num_nodes).float()
    D_inv_sqrt = torch.diag(torch.pow(deg.clamp(min=1), -0.5))
    A = torch.zeros(num_nodes, num_nodes, device=edge_index.device)
    A[row, col] = 1
    return torch.eye(num_nodes, device=edge_index.device) - D_inv_sqrt @ A @ D_inv_sqrt

class DeepSheafLayer(nn.Module):
    def __init__(self, d, f_dim):
        super().__init__()
        self.d, self.f = d, f_dim
        self.lin = nn.Linear(d * f_dim, d * f_dim)
        self.ln  = nn.LayerNorm([d, f_dim])
    def forward(self, x, edge_index):
        N = x.size(0)
        flat = x.view(N, -1)
        outf = self.lin(flat).view(N, self.d, self.f)
        row, col = edge_index
        agg = torch.zeros_like(outf)
        agg.index_add_(0, row, outf[col])
        deg = torch.bincount(row, minlength=N).clamp(min=1).float().view(-1,1,1)
        agg = agg / deg
        return self.ln(F.relu(agg))

class DeepSheafNet(nn.Module):
    def __init__(self, in_dim, d, f_dim, out_dim, depth=4, spectral_weight=1e-5, ph_dim=25):
        super().__init__()
        self.encoder = nn.Linear(in_dim, d)
        self.spectral_weight = spectral_weight
        node_dim = d + (ph_dim if USE_PH else 0)
        self.layers = nn.ModuleList([DeepSheafLayer(d, f_dim) for _ in range(depth)])
        self.decoder = nn.Sequential(nn.Flatten(), nn.Linear(d*f_dim, out_dim))
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        nf = self.encoder(x)
        nf = torch.nan_to_num(nf, nan=0.0, posinf=1e3, neginf=-1e3)
        xs = nf.unsqueeze(2).repeat(1,1,self.layers[0].f)

        if USE_PH:
            phd = compute_local_PH(nf, edge_index)
            phv = get_persistence_vectors(phd).to(x.device)
            phv = torch.nan_to_num(phv, nan=0.0, posinf=1e3, neginf=-1e3)
            node_feat = torch.cat([nf, phv], dim=1)
        else:
            node_feat = nf

        L = laplacian_regularization(edge_index, data.num_nodes)
        for layer in self.layers:
            xs = layer(xs, edge_index) + xs

        out = self.decoder(xs)
        if USE_SPECTRAL:
            flat = xs.view(xs.size(0), -1)
            spec = torch.trace(flat.T @ L @ flat)
            spec = torch.nan_to_num(spec, nan=0.0, posinf=1e6, neginf=1e6) / data.num_nodes
        else:
            spec = torch.tensor(0., device=x.device)
        return F.log_softmax(out, dim=1), spec

def create_masks(data, tr=0.6, va=0.2):
    n = data.num_nodes
    perm = torch.randperm(n)
    ntr = int(tr*n); nva = int(va*n)
    tri = perm[:ntr]; vai = perm[ntr:ntr+nva]; tei = perm[ntr+nva:]
    m0 = torch.zeros(n, dtype=torch.bool, device=data.x.device)
    data.train_mask = m0.clone().scatter_(0, tri, True)
    data.val_mask   = m0.clone().scatter_(0, vai, True)
    data.test_mask  = m0.clone().scatter_(0, tei, True)
    return data

def load_dataset(name):
    if name in ['Cora','Citeseer','Pubmed']:
        return Planetoid(f'/tmp/{name}',name,transform=NormalizeFeatures())
    if name in ['Texas','Wisconsin','Cornell']:
        return WebKB(f'/tmp/{name}',name,transform=NormalizeFeatures())
    if name in ['Chameleon','Squirrel']:
        return WikipediaNetwork(f'/tmp/{name}',name,transform=NormalizeFeatures())
    if name=='Film':
        return Actor(f'/tmp/Film',transform=NormalizeFeatures())
    raise ValueError(name)

def objective_deepsheaf(trial, ds_name, device):
    d     = trial.suggest_int("d", 8, 64, step=8)
    f     = trial.suggest_int("f_dim", 2, 16)
    depth = trial.suggest_int("depth", 2, 8)
    lr    = trial.suggest_float("lr",1e-4,5e-1,log=True)
    wd    = trial.suggest_float("weight_decay",1e-7,1e-1,log=True)
    sw    = trial.suggest_float("spectral_weight",1e-7,1e-2,log=True)
    epochs= trial.suggest_int("epochs",50,200)

    ds   = load_dataset(ds_name)
    data = ds[0]
    if data.y.dim()>1: data.y = data.y.argmax(dim=1)
    if data.x is None: data.x = torch.ones((data.num_nodes,1),device=data.y.device)
    data = create_masks(data).to(device)

    model = DeepSheafNet(ds.num_node_features, d, f, int(data.y.max())+1,
                         spectral_weight=sw, depth=depth).to(device)
    opt   = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    crit  = nn.NLLLoss()

    best_val = 0.
    for epoch in range(1, epochs+1):
        model.train(); opt.zero_grad()
        out, reg = model(data)
        loss = crit(out[data.train_mask], data.y[data.train_mask]) + sw*reg
        loss = torch.nan_to_num(loss, nan=0.0, posinf=1e6, neginf=1e6)
        loss.backward()
        gn = torch.sqrt(sum((p.grad.norm(2)**2 for p in model.parameters() if p.grad is not None)))
        opt.step()

        model.eval()
        with torch.no_grad():
            pred = model(data)[0].argmax(dim=1)
            tr_acc = (pred[data.train_mask]==data.y[data.train_mask]).float().mean().item()
            va_acc = (pred[data.val_mask]  ==data.y[data.val_mask]).float().mean().item()
        if va_acc>best_val: best_val=va_acc

        if epoch==1 or epoch%20==0:
            print(f"[{ds_name}] ep {epoch:03d} | loss {loss:.4f} | grad {gn:.4e} | train {tr_acc:.3f} | val {va_acc:.3f}")

        trial.report(best_val, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return best_val

torch.manual_seed(42); np.random.seed(42); random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

datasets = ["Texas","Wisconsin","Cornell",
             "Chameleon","Squirrel","Film",
             "Cora","Citeseer","Pubmed"]
results = {}

for name in datasets:
    print(f"\n{name} (DeepSheafNet+PH)")
    study = optuna.create_study(direction="maximize")
    study.optimize(lambda t: objective_deepsheaf(t, name, device), n_trials=10)
    bv, bp = study.best_value, study.best_params
    print(f"\u2192 Best val acc: {bv:.4f} | params: {bp}")
    results[name] = (bv, bp)

print("\nFinal Results")
for nm,(bv,bp) in results.items():
    print(f"{nm:12s} | Val Acc: {bv:.4f} | Params: {bp}")

example of start model with parametrs after optuna end cycle.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

from torch_geometric.datasets import WikipediaNetwork, WebKB
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_undirected

import gudhi
from gudhi.representations import PersistenceImage

def laplacian_regularization(edge_index, num_nodes):
    row, col = edge_index
    deg = torch.bincount(row, minlength=num_nodes).float()
    D_inv_sqrt = torch.diag(torch.pow(deg, -0.5))
    A = torch.zeros(num_nodes, num_nodes, device=edge_index.device)
    A[row, col] = 1
    return torch.eye(num_nodes, device=edge_index.device) - D_inv_sqrt @ A @ D_inv_sqrt

def compute_local_PH(x, edge_index, k=5, maxdim=1):
    row, col = edge_index
    N = x.size(0)
    x_np = x.detach().cpu().numpy()
    neighs = [[] for _ in range(N)]
    for i,j in zip(row.tolist(), col.tolist()):
        neighs[i].append(j)
    diagrams = []
    for i in range(N):
        patch_pts = [x_np[i]] + [x_np[j] for j in neighs[i][:k]]
        patch = np.stack(patch_pts, axis=0)
        rips = gudhi.RipsComplex(points=patch)
        st = rips.create_simplex_tree(max_dimension=maxdim)
        st.compute_persistence()
        diag = st.persistence_intervals_in_dimension(1)
        if len(diag) == 0:
            diag = np.array([[0.0, 0.0]])
        diagrams.append(diag)
    return diagrams

def get_persistence_vectors(diagrams, resolution=(5,5)):
    PI = PersistenceImage(bandwidth=1.0, weight=lambda pt: pt[1]-pt[0], resolution=resolution)
    vecs = PI.fit_transform(diagrams)
    vecs = np.nan_to_num(vecs, nan=0.0, posinf=1e3, neginf=-1e3)
    return torch.tensor(vecs, dtype=torch.float)

def create_masks(data, tr=0.6, va=0.2):
    n = data.num_nodes
    perm = torch.randperm(n)
    ntr = int(tr*n); nva = int(va*n)
    tri = perm[:ntr]; vai = perm[ntr:ntr+nva]; tei = perm[ntr+nva:]
    m0 = torch.zeros(n, dtype=torch.bool, device=data.x.device)
    data.train_mask = m0.clone().scatter_(0, tri, True)
    data.val_mask   = m0.clone().scatter_(0, vai, True)
    data.test_mask  = m0.clone().scatter_(0, tei, True)
    return data

class DeepSheafLayer(nn.Module):
    def __init__(self, d, f):
        super().__init__()
        self.d, self.f = d, f
        self.lin = nn.Linear(d * f, d * f)
        self.ln = nn.LayerNorm([d, f])
    def forward(self, x, edge_index):
        N = x.size(0)
        flat = x.view(N, -1)
        out = self.lin(flat).view(N, self.d, self.f)
        row, col = edge_index
        agg = torch.zeros_like(out)
        agg.index_add_(0, row, out[col])
        deg = torch.bincount(row, minlength=N).clamp(min=1).float().view(-1,1,1)
        return self.ln(F.relu(agg / deg))

class DeepSheafNet(nn.Module):
    def __init__(self, in_dim, d, f, out_dim, depth=3, spectral_weight=1e-5, ph_dim=25):
        super().__init__()
        self.encoder = nn.Linear(in_dim, d)
        self.spectral_weight = spectral_weight
        node_dim = d + ph_dim
        self.layers = nn.ModuleList([DeepSheafLayer(d, f) for _ in range(depth)])
        self.decoder = nn.Sequential(nn.Flatten(), nn.Linear(d * f, out_dim))
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        nf = self.encoder(x)
        nf = torch.nan_to_num(nf, nan=0.0, posinf=1e3, neginf=-1e3)
        x_sheaf = nf.unsqueeze(2).repeat(1, 1, self.layers[0].f)

        ph_diags = compute_local_PH(nf, edge_index)
        ph_vecs = get_persistence_vectors(ph_diags).to(x.device)
        ph_vecs = torch.nan_to_num(ph_vecs, nan=0.0, posinf=1e3, neginf=-1e3)
        node_feat = torch.cat([nf, ph_vecs], dim=1)

        L = laplacian_regularization(edge_index, data.num_nodes)
        for layer in self.layers:
            x_sheaf = layer(x_sheaf, edge_index) + x_sheaf

        out = self.decoder(x_sheaf)
        if self.spectral_weight > 0:
            flat = x_sheaf.view(x_sheaf.size(0), -1)
            spectral = torch.trace(flat.T @ L @ flat) / data.num_nodes
        else:
            spectral = torch.tensor(0., device=x.device)
        return F.log_softmax(out, dim=1), spectral

torch.manual_seed(42); np.random.seed(42); random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {
    'd': 8, 'f_dim': 16, 'depth': 2, 'lr': 0.07876328739451417, 'weight_decay': 0.00035039182415240504, 'spectral_weight': 3.6910574695954574e-07, 'epochs': 100}

data = WebKB(root="/tmp/Texas", name="Texas", transform=NormalizeFeatures())[0]
if data.y.dim() > 1:
    data.y = data.y.argmax(dim=1)
if data.x is None:
    data.x = torch.ones((data.num_nodes, 1))
data.edge_index = to_undirected(data.edge_index)
data = create_masks(data).to(device)

model = DeepSheafNet(
    in_dim=data.num_node_features,
    d=params['d'],
    f=params['f_dim'],
    out_dim=int(data.y.max()) + 1,
    depth=params['depth'],
    spectral_weight=params['spectral_weight']
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
criterion = nn.NLLLoss()

for epoch in range(1, params['epochs'] + 1):
    model.train()
    optimizer.zero_grad()
    out, reg = model(data)
    loss = criterion(out[data.train_mask], data.y[data.train_mask]) + params['spectral_weight'] * reg
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        out, _ = model(data)
        pred = out.argmax(dim=1)
        tr = (pred[data.train_mask] == data.y[data.train_mask]).float().mean().item()
        va = (pred[data.val_mask]   == data.y[data.val_mask]).float().mean().item()
        te = (pred[data.test_mask]  == data.y[data.test_mask]).float().mean().item()


    print(f"[Texas] ep {epoch:03d} | loss {loss:.4f} | train {tr:.3f} | val {va:.3f} | test {te:.3f}")

[Texas] ep 001 | loss 1.3549 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 002 | loss 1.6624 | train 0.202 | val 0.083 | test 0.132
[Texas] ep 003 | loss 6.8751 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 004 | loss 6.4649 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 005 | loss 2.7531 | train 0.156 | val 0.139 | test 0.289
[Texas] ep 006 | loss 8.3454 | train 0.202 | val 0.083 | test 0.132
[Texas] ep 007 | loss 6.6320 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 008 | loss 3.5822 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 009 | loss 4.6833 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 010 | loss 3.8001 | train 0.092 | val 0.056 | test 0.158
[Texas] ep 011 | loss 4.0056 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 012 | loss 1.7195 | train 0.202 | val 0.083 | test 0.132
[Texas] ep 013 | loss 2.9026 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 014 | loss 1.7829 | train 0.541 | val 0.722 | test 0.421
[Texas] ep 015 | loss 2.1022 | train 0.312 | val

Implementation of MLP

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import optuna
import numpy as np
import random

from torch_geometric.datasets import Planetoid, WebKB, WikipediaNetwork, Actor
from torch_geometric.transforms import NormalizeFeatures


class MLPNodeClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, num_classes):
        super().__init__()
        layers = [nn.Linear(in_channels, hidden_channels), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_channels, hidden_channels), nn.ReLU()]
        self.mlp = nn.Sequential(*layers)
        self.classifier = nn.Linear(hidden_channels, num_classes)

    def forward(self, x):
        return self.classifier(self.mlp(x))

def load_dataset(name):
    if name in ['Cora','Citeseer','Pubmed']:
        return Planetoid(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name in ['Texas','Wisconsin','Cornell']:
        return WebKB(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name in ['Chameleon','Squirrel']:
        return WikipediaNetwork(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name == 'Film':
        return Actor(f'/tmp/Film', transform=NormalizeFeatures())
    raise ValueError(name)


def create_masks(data, train_frac=0.6, val_frac=0.2):
    num_nodes = data.num_nodes
    perm = torch.randperm(num_nodes)
    n_train = int(train_frac * num_nodes)
    n_val   = int(val_frac   * num_nodes)
    train_idx = perm[:n_train]
    val_idx   = perm[n_train:n_train+n_val]
    test_idx  = perm[n_train+n_val:]
    data.train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=data.x.device)
    data.val_mask   = data.train_mask.clone()
    data.test_mask  = data.train_mask.clone()
    data.train_mask[train_idx] = True
    data.val_mask[  val_idx]   = True
    data.test_mask[ test_idx]  = True
    return data


def train(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data.x)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

def evaluate(model, data):
    model.eval()
    with torch.no_grad():
        out = model(data.x)
        pred = out.argmax(dim=1)
        accs = []
        for mask in (data.train_mask, data.val_mask, data.test_mask):
            acc = (pred[mask] == data.y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
    return accs


def objective_mlp(trial, dataset_name, device):
    hidden = trial.suggest_int("hidden_channels", 32, 256, step=32)
    layers = trial.suggest_int("num_layers", 1, 4)
    lr     = trial.suggest_float("lr", 1e-3, 1e-1, log=True)
    wd     = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    epochs = trial.suggest_int("epochs", 50, 200)

    dataset = load_dataset(dataset_name)
    data = dataset[0]

    if data.y.dim() > 1:
        data.y = data.y.argmax(dim=1)
    if data.x is None:
        data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)

    data = create_masks(data)
    data = data.to(device)

    in_ch = dataset.num_node_features or data.x.size(1)
    out_ch = int(torch.unique(data.y).numel())

    model = MLPNodeClassifier(in_ch, hidden, layers, out_ch).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    criterion = nn.CrossEntropyLoss()

    best_val = 0.
    best_test = 0.
    for _ in range(epochs):
        train(model, data, optimizer, criterion)
        train_acc, val_acc, test_acc = evaluate(model, data)
        if val_acc > best_val:
            best_val = val_acc
            best_test = test_acc
        trial.report(best_val, _)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return best_val


torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

datasets = ["Texas", "Wisconsin", "Cornell",
                "Chameleon", "Squirrel", "Film",
                "Cora", "Citeseer", "Pubmed"]

results = {}
for name in datasets:
    print(f"\nDataset: {name} (MLP)")
    study = optuna.create_study(direction="maximize")
    study.optimize(lambda t: objective_mlp(t, name, device), n_trials=20)

    best_val = study.best_value
    best_params = study.best_params
    print(f"Best val acc: {best_val:.4f}")
    print("Best params:", best_params)
    results[name] = (best_val, best_params)

print("\nFinal MLP Results")
for name, (val, params) in results.items():
    print(f"{name:12s} | Val Acc: {val:.4f} | Params: {params}")

Implementation of GAT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import optuna
import numpy as np
import random

from torch_geometric.datasets import Planetoid, WebKB, WikipediaNetwork, Actor
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GATConv


class GATNodeClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, heads, num_classes):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads=heads))
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads))
        self.final = nn.Linear(hidden_channels * heads, num_classes)

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = F.elu(conv(x, edge_index))
        return self.final(x)

def create_masks(data, train_frac=0.6, val_frac=0.2):
    n = data.num_nodes
    perm = torch.randperm(n)
    n_train = int(train_frac * n)
    n_val   = int(val_frac   * n)
    idx_train = perm[:n_train]
    idx_val   = perm[n_train:n_train+n_val]
    idx_test  = perm[n_train+n_val:]
    mask = lambda idx: torch.zeros(n, dtype=torch.bool, device=data.x.device).scatter_(0, idx, True)
    data.train_mask = mask(idx_train)
    data.val_mask   = mask(idx_val)
    data.test_mask  = mask(idx_test)
    return data

def load_dataset(name):
    if name in ['Cora','Citeseer','Pubmed']:
        return Planetoid(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name in ['Texas','Wisconsin','Cornell']:
        return WebKB(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name in ['Chameleon','Squirrel']:
        return WikipediaNetwork(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name == 'Film':
        return Actor(f'/tmp/Film', transform=NormalizeFeatures())
    raise ValueError(name)

def objective_gat(trial, dataset_name, device):
    hidden = trial.suggest_int("hidden_channels", 32, 256, step=32)
    layers = trial.suggest_int("num_layers", 1, 3)
    heads  = trial.suggest_int("heads", 1, 8)
    lr     = trial.suggest_float("lr", 1e-3, 1e-1, log=True)
    wd     = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    epochs = trial.suggest_int("epochs", 50, 200)

    dataset = load_dataset(dataset_name)
    data = dataset[0]

    if data.y.dim() > 1:
        data.y = data.y.argmax(dim=1)
    if data.x is None:
        data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)

    data = create_masks(data)
    data = data.to(device)

    in_ch = dataset.num_node_features or data.x.size(1)
    out_ch = int(torch.unique(data.y).numel())

    model = GATNodeClassifier(in_ch, hidden, layers, heads, out_ch).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    criterion = nn.CrossEntropyLoss()

    best_val, best_test = 0., 0.
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            pred = model(data.x, data.edge_index).argmax(dim=1)
            val_acc  = (pred[data.val_mask]  == data.y[data.val_mask]).float().mean().item()
            test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()

        if val_acc > best_val:
            best_val, best_test = val_acc, test_acc

        trial.report(best_val, _)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return best_val

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

datasets = ["Texas", "Wisconsin", "Cornell",
                "Chameleon", "Squirrel", "Film",
                "Cora", "Citeseer", "Pubmed"]

results = {}
for name in datasets:
    print(f"\n Dataset: {name} (GAT)")
    study = optuna.create_study(direction="maximize")
    study.optimize(lambda t: objective_gat(t, name, device), n_trials=20)

    best_val = study.best_value
    best_params = study.best_params
    print(f"Best val acc: {best_val:.4f}")
    print("Best params:", best_params)
    results[name] = (best_val, best_params)

print("\nFinal GAT Results")
for name, (val, params) in results.items():
    print(f"{name:12s} | Val Acc: {val:.4f} | Params: {params}")

Implementation of GCN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import optuna
import random
import numpy as np

from torch_geometric.datasets import Planetoid, WebKB, WikipediaNetwork, Actor
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GCNConv


class GCNNodeClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, num_classes):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        return self.lin(x)


def create_masks(data, train_frac=0.6, val_frac=0.2):
    n = data.num_nodes
    perm = torch.randperm(n)
    n_train = int(train_frac * n)
    n_val   = int(val_frac * n)
    idx_train = perm[:n_train]
    idx_val   = perm[n_train:n_train+n_val]
    idx_test  = perm[n_train+n_val:]
    mask = lambda idx: torch.zeros(n, dtype=torch.bool, device=data.x.device).scatter_(0, idx, True)
    data.train_mask = mask(idx_train)
    data.val_mask   = mask(idx_val)
    data.test_mask  = mask(idx_test)
    return data


def load_dataset(name):
    if name in ['Cora','Citeseer','Pubmed']:
        return Planetoid(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name in ['Texas','Wisconsin','Cornell']:
        return WebKB(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name in ['Chameleon','Squirrel']:
        return WikipediaNetwork(f'/tmp/{name}', name, transform=NormalizeFeatures())
    if name == 'Film':
        return Actor(f'/tmp/Film', transform=NormalizeFeatures())
    raise ValueError(name)

-
def objective_gcn(trial, ds_name, device):
    hidden = trial.suggest_int("hidden_channels", 32, 256, step=32)
    layers = trial.suggest_int("num_layers", 1, 4)
    lr     = trial.suggest_float("lr", 1e-3, 1e-1, log=True)
    wd     = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    epochs = trial.suggest_int("epochs", 50, 200)

    ds = load_dataset(ds_name)
    data = ds[0]

    if data.y.dim() > 1:
        data.y = data.y.argmax(dim=1)
    if data.x is None:
        data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)

    data = create_masks(data)
    data = data.to(device)

    in_ch = ds.num_node_features or data.x.size(1)
    num_cls = int(torch.unique(data.y).numel())

    model = GCNNodeClassifier(in_ch, hidden, layers, num_cls).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    crit = nn.CrossEntropyLoss()

    best_val, best_test = 0., 0.
    for epoch in range(epochs):
        model.train()
        opt.zero_grad()
        out = model(data.x, data.edge_index)
        loss = crit(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        model.eval()
        with torch.no_grad():
            pred = model(data.x, data.edge_index).argmax(dim=1)
            val_acc  = (pred[data.val_mask]  == data.y[data.val_mask]).float().mean().item()
            test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()

        if val_acc > best_val:
            best_val, best_test = val_acc, test_acc

        trial.report(best_val, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return best_val


torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

datasets = ["Texas", "Wisconsin", "Cornell",
                "Chameleon", "Squirrel", "Film",
                "Cora", "Citeseer", "Pubmed"]

results = {}
for name in datasets:
    print(f"\nDataset: {name} (GCN)")
    study = optuna.create_study(direction="maximize")
    study.optimize(lambda t: objective_gcn(t, name, device), n_trials=20)

    best_val = study.best_value
    best_params = study.best_params
    print(f"Best val acc: {best_val:.4f}")
    print("Best params:", best_params)
    results[name] = (best_val, best_params)

print("\nFinal GCN Results")
for name, (val, params) in results.items():
    print(f"{name:12s} | Val Acc: {val:.4f} | Params: {params}")