In [None]:
!pip install "numpy<2"

In [1]:
!pip install pandas


Collecting pandas
  Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
[K     |████████████████████████████████| 13.1 MB 12.1 MB/s eta 0:00:01
Collecting pytz>=2020.1
  Downloading pytz-2024.2-py2.py3-none-any.whl (508 kB)
[K     |████████████████████████████████| 508 kB 11.2 MB/s eta 0:00:01
Collecting tzdata>=2022.7
  Downloading tzdata-2024.2-py2.py3-none-any.whl (346 kB)
[K     |████████████████████████████████| 346 kB 10.8 MB/s eta 0:00:01
Installing collected packages: tzdata, pytz, pandas
Successfully installed pandas-2.2.3 pytz-2024.2 tzdata-2024.2
You should consider upgrading via the '/home/abazouzi/Documents/Code/DropHyper/drophyper/bin/python3 -m pip install --upgrade pip' command.[0m


In [None]:
!pip freeze

In [8]:
import time
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn.functional as F

from dhg import Hypergraph
from dhg.data import Cooking200
from dhg.models import HGNN
from dhg.random import set_seed
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluator

In [9]:
def train(net, X, A, lbls, train_idx, optimizer, epoch, verbose=True):
    """Run one full-batch optimisation step on the training vertices.

    Args:
        net: model called as ``net(X, A)``; its output rows are selected with
            ``train_idx`` and fed to ``F.cross_entropy`` as class logits.
        X: input vertex features passed to ``net``.
        A: structure argument passed to ``net`` (a dhg ``Hypergraph`` in this notebook).
        lbls: class labels for every vertex, indexed with ``train_idx``.
        train_idx: mask/index tensor selecting the training rows.
        optimizer: optimizer over ``net``'s parameters.
        epoch: epoch number, used only in the progress message.
        verbose: if True (default, the previous behaviour), print one progress line.

    Returns:
        float: the training cross-entropy for this step, measured before the update.
    """
    net.train()

    st = time.time()
    optimizer.zero_grad()
    outs = net(X, A)
    loss = F.cross_entropy(outs[train_idx], lbls[train_idx])
    loss.backward()
    optimizer.step()
    # .item() forces a host/device sync, so read it before stopping the clock; on CUDA
    # the elapsed time would otherwise cover only kernel launches, not their execution.
    loss_value = loss.item()
    elapsed = time.time() - st
    if verbose:
        print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_value:.5f}")
    return loss_value


@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False, metric_evaluator=None):
    """Score ``net`` on the vertices selected by ``idx`` with gradients disabled.

    Args:
        net: model called as ``net(X, A)``; switched to eval mode (and left there).
        X: input vertex features.
        A: structure argument passed to ``net``.
        lbls: ground-truth labels for every vertex, indexed with ``idx``.
        idx: mask/index tensor selecting the rows to score.
        test: if False, score with ``validate(lbls, outs)``; if True, with ``test(lbls, outs)``.
        metric_evaluator: object exposing ``validate`` and ``test``. When omitted, the
            notebook-global ``evaluator`` is used (previous behaviour).

    Returns:
        Whatever the evaluator returns. This notebook compares the ``validate`` result
        as a scalar and prints the ``test`` result as-is.
    """
    if metric_evaluator is None:
        # Fall back to the module-level evaluator so existing call sites keep working.
        metric_evaluator = evaluator
    net.eval()
    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]
    if test:
        return metric_evaluator.test(lbls, outs)
    return metric_evaluator.validate(lbls, outs)

In [10]:
# Hyperparameters for this run.
SEED = 2021
HIDDEN_DIM = 32
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4

# Seed before the model is built so HGNN's weight initialisation is repeatable
# (assumes dhg.random.set_seed seeds torch's RNG — confirm in the dhg docs).
set_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metrics: accuracy, f1_score with its default averaging, and micro-averaged f1_score.
# `infer` reads this module-level name, so it must stay called `evaluator`.
evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
data = Cooking200()
n_vertices = data["num_vertices"]

# Identity matrix as input features: each vertex is represented by its one-hot id.
X = torch.eye(n_vertices)
lbl = data["labels"]
G = Hypergraph(n_vertices, data["edge_list"])
train_mask, val_mask, test_mask = (data[key] for key in ("train_mask", "val_mask", "test_mask"))

net = HGNN(X.shape[1], HIDDEN_DIM, data["num_classes"], use_bn=True)
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# nn.Module.to() moves parameters in place, so the optimizer built above still
# tracks the same Parameter objects after the model is moved.
X, lbl, G = X.to(device), lbl.to(device), G.to(device)
net = net.to(device)

NUM_EPOCHS = 200
# Validate after every VAL_EVERY-th epoch (the previous `epoch % 1 == 0` check was always true).
VAL_EVERY = 1

# Best-checkpoint tracking. best_val starts at -inf rather than 0 so the first validated
# epoch always stores a checkpoint; starting at 0, a run whose validation score never
# exceeds 0.0 left best_state as None and crashed load_state_dict below.
best_state = None
best_epoch, best_val = 0, float("-inf")
for epoch in range(NUM_EPOCHS):
    train(net, X, G, lbl, train_mask, optimizer, epoch)
    if epoch % VAL_EVERY != 0:
        continue
    # infer is decorated with @torch.no_grad(), so no extra no_grad context is needed here.
    val_res = infer(net, X, G, lbl, val_mask)
    if val_res > best_val:
        print(f"update best: {val_res:.5f}")
        best_epoch, best_val = epoch, val_res
        # state_dict() returns references to the live tensors, which later optimizer
        # steps would overwrite, so take a deep copy as the snapshot.
        best_state = deepcopy(net.state_dict())
print("\ntrain finished!")
print(f"best val: {best_val:.5f}")

# Evaluate the checkpoint with the best validation score on the test vertices.
print("test...")
if best_state is not None:  # only None when NUM_EPOCHS == 0
    net.load_state_dict(best_state)
res = infer(net, X, G, lbl, test_mask, test=True)
print(f"final result: epoch: {best_epoch}")
print(res)

Epoch: 0, Time: 0.40277s, Loss: 2.99680
update best: 0.05000
Epoch: 1, Time: 0.30828s, Loss: 2.71560
Epoch: 2, Time: 0.30597s, Loss: 2.34183
Epoch: 3, Time: 0.30820s, Loss: 2.17803
Epoch: 4, Time: 0.30484s, Loss: 2.04616
Epoch: 5, Time: 0.30546s, Loss: 1.90518
Epoch: 6, Time: 0.29929s, Loss: 1.78512
Epoch: 7, Time: 0.30219s, Loss: 1.66366
Epoch: 8, Time: 0.30903s, Loss: 1.53951
Epoch: 9, Time: 0.31299s, Loss: 1.43321
Epoch: 10, Time: 0.29602s, Loss: 1.34124
Epoch: 11, Time: 0.31282s, Loss: 1.22620
Epoch: 12, Time: 0.30538s, Loss: 1.11851
Epoch: 13, Time: 0.31003s, Loss: 1.01421
Epoch: 14, Time: 0.30352s, Loss: 0.93399
Epoch: 15, Time: 0.31174s, Loss: 0.83967
Epoch: 16, Time: 0.32792s, Loss: 0.76134
Epoch: 17, Time: 0.32425s, Loss: 0.68312
update best: 0.05500
Epoch: 18, Time: 0.33249s, Loss: 0.61719
update best: 0.07000
Epoch: 19, Time: 0.33323s, Loss: 0.56950
update best: 0.08500
Epoch: 20, Time: 0.32288s, Loss: 0.50835
update best: 0.09500
Epoch: 21, Time: 0.31636s, Loss: 0.44322
upd