In [24]:
import time
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn.functional as F

from dhg import Hypergraph
from dhg.data import Cooking200, CocitationCora, CoauthorshipCora, DBLP4k, Tencent2k
from dhg.models import HGNN, HGNNP, HyperGCN, UniSAGE
from dhg.random import set_seed
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluator

In [2]:
def train(net, X, A, lbls, train_idx, optimizer, epoch):
    """Run one full-batch optimisation step of ``net`` on the training vertices.

    Args:
        net: model called as ``net(X, A)``; it is switched to training mode here.
        X: vertex feature matrix forwarded to ``net``.
        A: structure argument forwarded to ``net`` (the notebook passes a ``Hypergraph``).
        lbls: per-vertex class labels; only the entries selected by ``train_idx`` are used.
        train_idx: index or boolean mask selecting the vertices that contribute to the loss.
        optimizer: optimizer that updates ``net``'s parameters.
        epoch: epoch number, used only in the progress message.

    Returns:
        The training cross-entropy loss for this step, as a Python float.
    """
    net.train()

    start = time.perf_counter()
    optimizer.zero_grad()
    outs = net(X, A)
    loss = F.cross_entropy(outs[train_idx], lbls[train_idx])
    loss.backward()
    optimizer.step()
    # `.item()` copies the loss to the host, which waits for any queued CUDA work.
    # Reading it *before* stopping the clock makes the reported time cover the
    # whole step instead of just the (asynchronous) kernel launches.
    loss_val = loss.item()
    elapsed = time.perf_counter() - start
    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_val:.5f}")
    return loss_val


@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False, metric_evaluator=None):
    """Score ``net`` on the vertices selected by ``idx`` without tracking gradients.

    Args:
        net: model called as ``net(X, A)``; it is switched to eval mode and left there.
        X: vertex feature matrix forwarded to ``net``.
        A: structure argument forwarded to ``net``.
        lbls: per-vertex labels; only the entries selected by ``idx`` are used.
        idx: index or boolean mask of the vertices to score.
        test: if True call ``metric_evaluator.test``, otherwise ``metric_evaluator.validate``.
        metric_evaluator: object exposing ``validate(lbls, outs)`` and ``test(lbls, outs)``.
            When omitted, the notebook-level ``evaluator`` is used so existing
            calls keep working; pass it explicitly to avoid relying on hidden
            global state.

    Returns:
        Whatever the selected evaluator method returns.
    """
    if metric_evaluator is None:
        # Backward-compatible fallback to the global created in the experiment cell.
        metric_evaluator = evaluator
    net.eval()
    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]
    if test:
        return metric_evaluator.test(lbls, outs)
    return metric_evaluator.validate(lbls, outs)

In [29]:
# Run configuration: hoisted so every tunable value lives in one place.
SEED = 0
HIDDEN_DIM = 32
LR = 0.01
WEIGHT_DECAY = 5e-4
NUM_EPOCHS = 50
VAL_EVERY = 1  # run validation every VAL_EVERY epochs

set_seed(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# `infer` reads this module-level `evaluator`. The validation result is treated
# below as a single scalar score (it is compared and formatted as a float).
evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
data = Tencent2k()

# Identity matrix as vertex features (one-hot per vertex).
X, lbl = torch.eye(data["num_vertices"]), data["labels"]
G = Hypergraph(data["num_vertices"], data["edge_list"])
train_mask = data["train_mask"]
val_mask = data["val_mask"]
test_mask = data["test_mask"]

net = UniSAGE(X.shape[1], HIDDEN_DIM, data["num_classes"], use_bn=True)
optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

X, lbl = X.to(device), lbl.to(device)
G = G.to(device)
net = net.to(device)

best_state = None
# Start at -inf so the first validation always records a checkpoint; starting
# at 0 with a strict `>` left best_state as None if the score never rose above 0,
# which made load_state_dict(None) crash below.
best_epoch, best_val = 0, float("-inf")
for epoch in range(NUM_EPOCHS):
    # train
    train(net, X, G, lbl, train_mask, optimizer, epoch)
    # validation (`infer` is already decorated with @torch.no_grad())
    if epoch % VAL_EVERY == 0:
        val_res = infer(net, X, G, lbl, val_mask)
        if val_res > best_val:
            print(f"update best: {val_res:.5f}")
            best_epoch = epoch
            best_val = val_res
            best_state = deepcopy(net.state_dict())
print("\ntrain finished!")
print(f"best val: {best_val:.5f}")
# test
print("test...")
if best_state is None:
    # Only reachable when no validation ran (NUM_EPOCHS == 0) or every score was NaN.
    print("warning: no checkpoint recorded; testing the final weights")
else:
    net.load_state_dict(best_state)
res = infer(net, X, G, lbl, test_mask, test=True)
print(f"final result: epoch: {best_epoch}")
print(res)

Epoch: 0, Time: 0.03744s, Loss: 1.47682
update best: 0.62500
Epoch: 1, Time: 0.01640s, Loss: 2.64121
Epoch: 2, Time: 0.01576s, Loss: 0.46155
Epoch: 3, Time: 0.01656s, Loss: 0.48771
Epoch: 4, Time: 0.01648s, Loss: 0.34893
Epoch: 5, Time: 0.01587s, Loss: 0.18031
Epoch: 6, Time: 0.01717s, Loss: 0.18295
Epoch: 7, Time: 0.01643s, Loss: 0.14636
Epoch: 8, Time: 0.01571s, Loss: 0.13556
Epoch: 9, Time: 0.01619s, Loss: 0.09177
Epoch: 10, Time: 0.01752s, Loss: 0.11499
Epoch: 11, Time: 0.01610s, Loss: 0.12797
Epoch: 12, Time: 0.01608s, Loss: 0.11654
Epoch: 13, Time: 0.01591s, Loss: 0.05929
Epoch: 14, Time: 0.01639s, Loss: 0.08277
Epoch: 15, Time: 0.01693s, Loss: 0.05057
Epoch: 16, Time: 0.01585s, Loss: 0.06059
Epoch: 17, Time: 0.01719s, Loss: 0.06320
update best: 0.67500
Epoch: 18, Time: 0.01746s, Loss: 0.04078
update best: 0.75000
Epoch: 19, Time: 0.01689s, Loss: 0.06295
update best: 0.77500
Epoch: 20, Time: 0.01656s, Loss: 0.01748
update best: 0.85000
Epoch: 21, Time: 0.01724s, Loss: 0.01381
Epo