In [None]:
import sys
sys.path.append("../") 

In [2]:
import torch
torch.manual_seed(42)
import torch.optim as optim
from torch import nn

import dhg
from dhg import Hypergraph

import hgp
from hgp.models import HGNNP
from hgp.utils import from_pickle_to_hypergraph,from_file_to_hypergraph
from hgp.function import StraightThroughEstimator

DEVICE = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda', index=1)

In [3]:
from hgp.models import ParameterDict

# fmt: off
h_hyper_prmts = ParameterDict()
l_hyper_prmts = ParameterDict()

partitions = 3

h_hyper_prmts["convlayers1"] = {"in_channels": 3824, "out_channels": 2048, "use_bn": False, "drop_rate": 0.3}
h_hyper_prmts["convlayers12"] = {"in_channels": 2048, "out_channels": 2048, "use_bn": False, "drop_rate": 0.3}
h_hyper_prmts["convlayers13"] = {"in_channels": 2048, "out_channels": 1024, "use_bn": False, "drop_rate": 0.25}
#h_hyper_prmts["convlayers14"] = {"in_channels": 1536, "out_channels": 1536, "use_bn": False, "drop_rate": 0.3}
#h_hyper_prmts["convlayers15"] = {"in_channels": 1536, "out_channels": 1536, "use_bn": False, "drop_rate": 0.3}
#h_hyper_prmts["convlayers16"] = {"in_channels": 2048, "out_channels": 1536, "use_bn": False, "drop_rate": 0.2}
h_hyper_prmts["convlayers3"] = {"in_channels": 1024, "out_channels": 1024, "use_bn": False, "drop_rate": 0.2}
h_hyper_prmts["convlayers4"] = {"in_channels": 1024, "out_channels": 1024, "use_bn": False, "drop_rate": 0.2}
h_hyper_prmts["convlayers5"] = {"in_channels": 1024, "out_channels": 512, "use_bn": False, "drop_rate": 0.1}
#h_hyper_prmts["convlayers52"] = {"in_channels": 512, "out_channels": 512, "use_bn": False, "drop_rate": 0.1}


l_hyper_prmts["linerlayer1"] = {"in_channels":list(h_hyper_prmts.values())[-1]["out_channels"], "out_channels":512, "use_bn":True, "drop_rate":0.05}
l_hyper_prmts["linerlayer2"] = {"in_channels":512, "out_channels":256, "use_bn":True, "drop_rate":0.05}
l_hyper_prmts["linerlayer3"] = {"in_channels":256, "out_channels":128, "use_bn":True, "drop_rate":0.05}
l_hyper_prmts["linerlayer32"] = {"in_channels":128, "out_channels":64, "use_bn":True, "drop_rate":0.05}
l_hyper_prmts["linerlayer33"] = {"in_channels":64, "out_channels":32, "use_bn":False, "drop_rate":0.05}
l_hyper_prmts["linerlayer34"] = {"in_channels":32, "out_channels":16, "use_bn":False, "drop_rate":0.05}
l_hyper_prmts["linerlayer4"] = {"in_channels":16, "out_channels":3, "use_bn":False, "drop_rate":0.05}


hyper = {
    "h_hyper_prmts": h_hyper_prmts,
    "l_hyper_prmts":l_hyper_prmts,
    "init_features_dim":list(h_hyper_prmts.values())[0]["in_channels"],
    "partitions":partitions
}

# fmt: on

In [4]:
def loss_bs_matrix(outs, H, H_degree,device):
    outs = outs.to(device)
    nn = torch.matmul(outs, (1 - torch.transpose(outs, 0, 1)))
    ne_k = torch.matmul(nn, H)
    ne_k = ne_k.mul(H)

    H_degree = torch.sum(H, dim=0)

    H_1 = ne_k / H_degree
    a2 = 1 - H_1
    a3 = torch.prod(a2, dim=0)
    a3 = a3.sum()
    loss_1 = -1 * a3

    # pun = torch.mul(ne_k, H)

    # loss_1 = pun.sum()
    loss_2 = torch.var(torch.sum(outs, dim=0)).to(device)

    loss = 100 * loss_1 + loss_2
    return loss, loss_1, loss_2



In [None]:
class Trainer(nn.Module):

    def __init__(self, net, X, hg, optimizer):
        super().__init__()
        self.X: torch.Tensor = X.to(DEVICE)
        self.hg = hg.to(DEVICE)
        self.de = self.hg.H.to_dense().sum(dim=0).to("cpu").to(DEVICE)
        self.optimizer: torch.optim.Optimizer = optimizer
        self.layers = nn.ModuleList()
        self.layers.append(net.to(DEVICE))
        self.H = self.hg.H.to_dense().to(DEVICE)
        self.H_degree = torch.sum(self.H, dim=0)

    def forward(self, X):
        X = self.layers[0](X, self.hg)
        for layer in self.layers[1:]:
            X = layer(X)
        return X

    def run(self, epoch):
        self.train() 
        self.optimizer.zero_grad()
        outs = self.forward(self.X)
        loss, loss_1, loss_2 = loss_bs_matrix(
            outs, self.H, self.H_degree, device=DEVICE
        )
        loss.backward()
        self.optimizer.step()

        return loss.item(), loss_1.item(), loss_2.item()

In [6]:
G = dhg.data.CocitationPubmed()
G = dhg.Hypergraph(G["num_vertices"],G["edge_list"])

In [7]:
X = torch.randn(size=(G.num_v, hyper["init_features_dim"]))
#X = torch.eye(n=G.num_v)
net = HGNNP(hyper["h_hyper_prmts"]).to(DEVICE)
hgnn_trainer = Trainer(net=net, X=X, hg=G, optimizer=None).to(DEVICE)

for (k,v) in hyper["l_hyper_prmts"].items():
    hgnn_trainer.layers.append(nn.BatchNorm1d(num_features=v["in_channels"]).to(DEVICE)) if v["use_bn"] else None
    hgnn_trainer.layers.append(nn.ReLU().to(DEVICE))
    hgnn_trainer.layers.append(nn.Dropout(v["drop_rate"]))
    hgnn_trainer.layers.append(nn.Linear(in_features=v["in_channels"],out_features=v["out_channels"],device=DEVICE))
hgnn_trainer.layers.append(nn.Softmax(dim=1))

optim = optim.Adam(hgnn_trainer.parameters(), lr=8e-4, weight_decay=5e-8)
hgnn_trainer.optimizer = optim

In [9]:
temp_loss_total,temp_loss1,temp_loss2 = torch.zeros(1, requires_grad=False),torch.zeros(1, requires_grad=False),torch.zeros(1, requires_grad=False)
for epoch in range(15000):
    loss,loss_1,loss_2 = hgnn_trainer.run(epoch=epoch)
    # train
    temp_loss_total += loss
    temp_loss1 += loss_1
    temp_loss2 += loss_2
    # validation
    if epoch % 10 == 0:
        print(f"in {epoch} epoch, average loss: {temp_loss_total.item() / 10}")
        print(f"                , loss1: {temp_loss1.item() / 10}")
        print(f"                , loss2: {temp_loss2.item() / 10}")
        print(f"=================================")
        sys.stdout.flush()
        temp_loss_total,temp_loss1,temp_loss2 = torch.zeros(1, requires_grad=False),torch.zeros(1, requires_grad=False),torch.zeros(1, requires_grad=False)

in 0 epoch, average loss: -1183.530078125
                , loss1: -42.131036376953126
                , loss2: 3029.5734375
in 10 epoch, average loss: -40400.478125
                , loss1: -424.30126953125
                , loss2: 2029.6490234375
in 20 epoch, average loss: -42645.965625
                , loss1: -427.01123046875
                , loss2: 55.1591796875
in 30 epoch, average loss: -43523.7125
                , loss1: -435.898046875
                , loss2: 66.0960693359375
in 40 epoch, average loss: -45418.190625
                , loss1: -456.101220703125
                , loss2: 191.93011474609375
in 50 epoch, average loss: -53382.775
                , loss1: -537.44453125
                , loss2: 361.668115234375
in 60 epoch, average loss: -97494.6625
                , loss1: -985.40732421875
                , loss2: 1046.076953125
in 70 epoch, average loss: -176489.5
                , loss1: -1815.8767578125
                , loss2: 5098.183203125
in 80 epoch, average 

KeyboardInterrupt: 

In [10]:
hgnn_trainer.eval()
outs = hgnn_trainer.forward(hgnn_trainer.X)
outs_straight = StraightThroughEstimator.apply(outs)
G_clone = G.clone()
edges, _  = G_clone.e
cut = 0
for vertices in edges:
    if torch.prod(outs_straight[list(vertices)], dim=0).sum() == 0:
        cut += 1
    else:
        G_clone.remove_hyperedges(vertices)
assert cut == G_clone.num_e
cut

0

In [11]:
bs = torch.sum(outs_straight, dim = 0)
bs = torch.sum(outs, dim = 0)
bs

tensor([6601.8999, 6596.1392, 6518.9629], device='cuda:1',
       grad_fn=<SumBackward1>)