In [1]:
import torch
import statistics
from gnnco.models import GAT
import gnnco

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = 'cuda'
EPOCHS = 100

In [3]:
(train_dataset, val_dataset, train_loader, val_loader) = gnnco.graph_matching.setup_data(dataset_path="data/ER[100,8,0.02]/", batch_size=100)

In [4]:
# laplacian_layer = gnnco.models.LaplacianEmbeddings(k=32)
model = GAT(5, 32, 1024, 1024) #torch.nn.Sequential(torch.nn.Linear(32,4000), torch.nn.ReLU(), torch.nn.Linear(4000,4000), torch.nn.ReLU(), torch.nn.Linear(4000,256))
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss() 

In [5]:
for epoch in range(EPOCHS):
    batch: gnnco.graph_matching.GMDatasetBatch
    model.train()
    losses = []
    accuracy = []
    edge_accuracy = []
    nonedge_accuracy = []
    for i, batch in enumerate(train_loader):
        batch = batch.to(DEVICE)
        model.zero_grad()

        #laplacian_embeddings = laplacian_layer.forward(batch.base_signals, batch.base_graphs).x()
        final_embeddings = model.forward(batch.base_signals, batch.base_graphs).x().reshape((len(batch), 100, 1024))
        prediction = torch.bmm(final_embeddings, torch.transpose(final_embeddings,1,2))
        prediction = torch.sigmoid(prediction)
        target = batch.base_graphs.to_dense().get_stacked_adj().float()
        loss = (1/13)*(-12*target.flatten()*torch.log(prediction.flatten() + 1e-7) - (1-target.flatten())*torch.log(1-prediction.flatten() + 1e-7)).mean()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))
        prediction = prediction > 0.5
        target = target > 0.5
        accuracy.append(float(torch.count_nonzero(prediction == target)/torch.numel(target)))
        edge_accuracy.append(float(torch.count_nonzero(prediction[target])/ torch.count_nonzero(target)))
        prediction = torch.logical_not(prediction)
        target = torch.logical_not(target)
        nonedge_accuracy.append(float(torch.count_nonzero(prediction[target])/ torch.count_nonzero(target)))

    print(f"train {epoch}: loss: {statistics.mean(losses)}   acc:{statistics.mean(accuracy)} ({statistics.mean(edge_accuracy)}/{statistics.mean(nonedge_accuracy)})")


    
    model.eval()
    losses = []
    accuracy = []
    edge_accuracy = []
    nonedge_accuracy = []
    for i, batch in enumerate(val_loader):
        batch = batch.to(DEVICE)
        # laplacian_embeddings = laplacian_layer.forward(batch.base_signals, batch.base_graphs).x()
        final_embeddings = model.forward(batch.base_signals, batch.base_graphs).x().reshape((len(batch), 100, 1024))
        prediction = torch.bmm(final_embeddings, torch.transpose(final_embeddings,1,2))
        prediction = torch.sigmoid(prediction)
        target = batch.base_graphs.to_dense().get_stacked_adj().float()
        loss = (1/13)*(-12*target.flatten()*torch.log(prediction.flatten() + 1e-7) - (1-target.flatten())*torch.log(1-prediction.flatten() + 1e-7)).mean()
        losses.append(float(loss))
        prediction = prediction > 0.5
        target = target > 0.5
        accuracy.append(float(torch.count_nonzero(prediction == target)/torch.numel(target)))
        edge_accuracy.append(float(torch.count_nonzero(prediction[target])/ torch.count_nonzero(target)))
        prediction = torch.logical_not(prediction)
        target = torch.logical_not(target)
        nonedge_accuracy.append(float(torch.count_nonzero(prediction[target])/ torch.count_nonzero(target)))
    print(f"val {epoch}: loss: {statistics.mean(losses)}   acc:{statistics.mean(accuracy)} ({statistics.mean(edge_accuracy)}/{statistics.mean(nonedge_accuracy)})")

train 0: loss: 0.5403260678052902   acc:0.4665077455341816 (0.8125974103808403/0.436376778781414)
val 0: loss: 0.14454811215400695   acc:0.5156167984008789 (0.8426230192184448/0.4872541785240173)
train 1: loss: 0.1082688931375742   acc:0.518251147866249 (0.8903661578893661/0.4858589418232441)
val 1: loss: 0.09405350387096405   acc:0.5226575851440429 (0.9216888189315796/0.48804795145988467)
train 2: loss: 0.09120682142674923   acc:0.5211089998483658 (0.9365929141640663/0.48494444563984873)
val 2: loss: 0.08897249400615692   acc:0.5248504042625427 (0.944975733757019/0.4884117305278778)
train 3: loss: 0.08806038163602352   acc:0.5240160450339317 (0.9494688391685486/0.4869835816323757)
val 3: loss: 0.08703812956809998   acc:0.5256012082099915 (0.9538374066352844/0.4884593904018402)
train 4: loss: 0.08648472987115383   acc:0.5270135968923568 (0.9561781838536263/0.4896586000919342)
val 4: loss: 0.0857650324702263   acc:0.5303272008895874 (0.9588751792907715/0.49315961003303527)
train 5: loss