In [1]:
import os
import torch
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils_simulated import *
from model import *

In [2]:
# ---------------------
# parameters
# ---------------------
lr = 1e-3
epochs = 100
batch_size = 1
pos_weights = 5
# NOTE(review): 'D:graph-conflation-data/' has no separator after the drive letter,
# which Windows treats as relative to the current directory on drive D:. Confirm
# that 'D:/graph-conflation-data/' was not intended.
path = 'D:graph-conflation-data/'
model = 'gcn'          # 'gcn', 'gat', 'graphsage'; any other value selects GraphUNet
simulation = 'mixed'

# ---------------------
# load data
# ---------------------
print('Load Datasets...')
# NOTE(review): os.listdir order is filesystem-dependent, so the seeded split below is
# only reproducible on a filesystem that lists files in the same order. It must match
# the split used at training time; confirm before changing (e.g. to sorted()).
files = os.listdir(path+f'simulated-graphs/{simulation}/gt/')
# 64% train / 16% val / 20% test, with a fixed random_state for each split
train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)
train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=42)

# make datasets
train_data = GraphDataset(path, simulation, train_files)
val_data = GraphDataset(path, simulation, val_files)
test_data = GraphDataset(path, simulation, test_files)

# data loaders (unshuffled, so sample indices in the evaluation loops are stable)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# ---------------------
#  models
# ---------------------
print('Load Model...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if model == 'gcn':
    model_set1 = GCN().to(device)
    model_set2 = GCN().to(device)
elif model == 'gat':
    model_set1 = GAT().to(device)
    model_set2 = GAT().to(device)
elif model == 'graphsage':
    model_set1 = GraphSAGE().to(device)
    model_set2 = GraphSAGE().to(device)
else:
    model_set1 = GraphUNet(2,64,128,3).to(device)
    model_set2 = GraphUNet(2,64,128,3).to(device)

# Always remap stored tensors onto `device`: this loads CUDA-saved checkpoints on a
# CPU-only machine and is a no-op when devices already match. (A torch.device never
# compares equal to the plain string 'cpu', so a string check cannot select this.)
# torch.load unpickles its input: only load trusted checkpoint files.
model_set1.load_state_dict(torch.load(f'model_states/{model}_set1_{pos_weights}_{simulation}', map_location=device))
model_set2.load_state_dict(torch.load(f'model_states/{model}_set2_{pos_weights}_{simulation}', map_location=device))
model_set1.eval()
model_set2.eval()
# Float pos_weight placed on `device`, so it matches float logits on that device.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(float(pos_weights), device=device))

Load Datasets...
Load Model...


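### Graph-Level Accuracy

Fraction of entries of the thresholded predicted matrix (sigmoid score > 0.5) that match the ground-truth matrix, averaged over the test graphs.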

In [3]:
# Per test graph:
#   graph_acc -> fraction of matrix entries whose thresholded prediction
#                (sigmoid score > 0.5) equals the ground truth.
#   check     -> fraction of entries predicted positive.
graph_acc = []
check = []
torch.use_deterministic_algorithms(False)
for step, batch in enumerate(tqdm(test_dataloader)):
    # NOTE(review): test sample 1591 is deliberately skipped; the reason is not
    # recorded here — TODO document it.
    if step == 1591:
        continue

    # Unpack one sample. batch_size is 1, so drop the leading batch dimension
    # (in place) and move each tensor to the evaluation device.
    graph_gt, graph_set1, graph_set2, gt_x, set1_x, set2_x = batch
    graph_gt = graph_gt.squeeze_(0).to(device)
    graph_set1 = graph_set1.squeeze_(0).long().to(device)
    graph_set2 = graph_set2.squeeze_(0).long().to(device)
    gt_x = gt_x.squeeze_(0).to(device)
    set1_x = set1_x.squeeze_(0).to(device)
    set2_x = set2_x.squeeze_(0).to(device)
    labels = graph_gt.flatten()

    # Embed both input graphs, then score every node pair via dot product.
    with torch.no_grad():
        out_set1 = model_set1(set1_x, graph_set1)
        out_set2 = model_set2(set2_x, graph_set2)
    scores = torch.matmul(out_set1, out_set2.T)
    A = torch.sigmoid(scores).flatten()
    predicted = 1.0 * (A.detach() > 0.5)

    check.append(predicted.mean().item())
    graph_acc.append((1.0 * (predicted == labels)).mean().item())

100%|██████████████████████████████████████████████████████████████████████████████| 1621/1621 [00:17<00:00, 92.64it/s]


In [4]:
# Mean graph-level accuracy, and mean fraction of entries predicted positive.
mean_graph_acc = np.mean(graph_acc)
mean_positive_rate = np.mean(check)
mean_graph_acc, mean_positive_rate

(0.7959427811849265, 0.10473229003334303)

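### Conflicted-Edge Accuracy

Accuracy restricted to "conflicted" edges, i.e. edges present in only one of the two input graphs, each scored in both orientations. A graph with no conflicted edges falls back to whole-matrix accuracy.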

In [5]:
# Accuracy on "conflicted" edges: edges that appear in only one of the two input
# graphs. Each conflicted edge (u, v) is scored in both orientations, (u, v) and
# (v, u), against the ground-truth matrix. A graph with no conflicted edges falls
# back to whole-matrix accuracy. One value per test graph is appended to edge_acc.
edge_acc = []
for i, batch in enumerate(tqdm(test_dataloader)):

    # NOTE(review): test sample 1591 is skipped, as in the graph-level cell;
    # the reason is not recorded here — TODO document it.
    if i == 1591:
        continue

    # Unpack one sample (batch_size == 1): drop the leading batch dimension in place
    # and move the tensors the models need to the evaluation device.
    graph_gt, graph_set1, graph_set2, gt_x, set1_x, set2_x = batch
    graph_gt = graph_gt.squeeze_(0).to(device)
    graph_set1 = graph_set1.squeeze_(0).long().to(device)
    graph_set2 = graph_set2.squeeze_(0).long().to(device)
    set1_x = set1_x.squeeze_(0).to(device)
    set2_x = set2_x.squeeze_(0).to(device)

    # Extract conflicted edges. After .T each row is one (u, v) pair (presumably the
    # inputs are (2, E) edge-index tensors — TODO confirm). The broadcast comparison
    # needs O(E1 * E2) memory.
    edges_set1 = graph_set1.T
    edges_set2 = graph_set2.T
    only_in_set1 = ~(edges_set1[:, None] == edges_set2).all(-1).any(-1)
    only_in_set2 = ~(edges_set2[:, None] == edges_set1).all(-1).any(-1)
    conflicted = torch.cat([edges_set1[only_in_set1], edges_set2[only_in_set2]], dim=0).cpu()
    # Index tensors covering every conflicted edge in both orientations.
    rows = torch.cat([conflicted[:, 0], conflicted[:, 1]])
    cols = torch.cat([conflicted[:, 1], conflicted[:, 0]])

    # Make prediction.
    with torch.no_grad():
        out_set1 = model_set1(set1_x, graph_set1)
        out_set2 = model_set2(set2_x, graph_set2)
        A = torch.sigmoid(torch.matmul(out_set1, out_set2.T))
    # Compare on the CPU: the thresholded predictions live there, and an elementwise
    # comparison between a CPU and a CUDA tensor raises a device-mismatch error.
    A_pred = 1.0 * (A.cpu() > 0.5)
    gt_cpu = graph_gt.cpu()

    if conflicted.shape[0] == 0:
        edge_acc.append(torch.mean(1.0 * (A_pred == gt_cpu)).item())
    else:
        # double() keeps the per-graph mean in float64, like averaging Python floats.
        hits = A_pred[rows, cols] == gt_cpu[rows, cols]
        edge_acc.append(hits.double().mean().item())

100%|█████████████████████████████████████████████████████████████████████████████| 1621/1621 [00:11<00:00, 140.66it/s]


In [6]:
# Mean conflicted-edge accuracy over the evaluated test graphs.
mean_edge_acc = np.mean(edge_acc)
mean_edge_acc

0.4492415510427837