In [1]:
import os
import torch
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils_simulated import *
from model import *

# ---------------------------
# seeding for reproducibility
# ---------------------------
seed = 100
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
torch.use_deterministic_algorithms(True)
# set CUBLAS_WORKSPACE_CONFIG=:16:8

In [12]:
# ---------------------
# parameters
# ---------------------
lr = 2e-3
epochs = 100
batch_size = 1
pos_weights = 1
path = 'D:/graph-conflation-data/'
simulation = 'mixed'

# ---------------------
# load data
# ---------------------
print('Load Datasets...')
files = os.listdir(path+f'/simulated-graphs/{simulation}/gt/')
train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)
train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=42)

# make datasets
train_data = GraphDataset(path, simulation, train_files)
val_data = GraphDataset(path, simulation, val_files)
test_data = GraphDataset(path, simulation, test_files)

# data loader
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False) 

Load Datasets...


In [13]:
graph_acc_union=[]
graph_acc_intersection=[]
graph_mean = []
for file in tqdm(test_files):
    
    if file == 'graph_5103': continue
    
    # paths
    gt_path = os.path.join(path, f'simulated-graphs/{simulation}/gt', file)
    set1_path = os.path.join(path, f'simulated-graphs/{simulation}/set1', file)
    set2_path = os.path.join(path, f'simulated-graphs/{simulation}/set2', file)
    
    # graphs & edges
    graph_gt = nx.to_numpy_array(nx.read_graph6(gt_path))
    graph_set1 = nx.to_numpy_array(nx.read_graph6(set1_path))
    graph_set2 = nx.to_numpy_array(nx.read_graph6(set2_path))
    
    union = 1*np.logical_or(graph_set1, graph_set2)
    intersection = 1*np.logical_and(graph_set1, graph_set2)
    graph_acc_union.append(np.mean(1.0*( union == graph_gt)))  
    graph_acc_intersection.append(np.mean(1.0*( intersection == graph_gt)))  

100%|████████████████████████████████████████████████████████████████████████████| 1621/1621 [00:01<00:00, 1236.49it/s]


In [14]:
np.mean(graph_acc_union), np.mean(graph_acc_intersection)

(0.9554835857842671, 0.9555676643707968)

In [15]:
edge_acc_union = []
edge_acc_intersection = []
for file in tqdm(test_files):
    
    if file == 'graph_5103': continue
        
    # paths
    gt_path = os.path.join(path, f'simulated-graphs/{simulation}/gt', file)
    set1_path = os.path.join(path, f'simulated-graphs/{simulation}/set1', file)
    set2_path = os.path.join(path, f'simulated-graphs/{simulation}/set2', file)
    
    # graphs & edges
    graph_gt = nx.to_numpy_array(nx.read_graph6(gt_path))
    graph_set1 = nx.to_numpy_array(nx.read_graph6(set1_path))
    graph_set2 = nx.to_numpy_array(nx.read_graph6(set2_path))
    graph_set1_edges = np.array(nx.read_graph6(set1_path).edges()).T
    graph_set2_edges = np.array(nx.read_graph6(set2_path).edges()).T
    
    # extract conflicted edges
    conflicted_edges = np.concatenate(
        [graph_set1_edges.T[~(graph_set1_edges.T[:, None] == graph_set2_edges.T).all(-1).any(-1)],
         graph_set2_edges.T[~(graph_set2_edges.T[:, None] == graph_set1_edges.T).all(-1).any(-1)]],
        axis=0
    )
    conflicted_edges = np.repeat(conflicted_edges, 2,axis=0)
    conflicted_edges[::2,[0,1]] = conflicted_edges[::2,[1,0]]
    conflicted_edges = list(zip(*conflicted_edges.T))  
    
    # union & intersection
    union = 1*np.logical_or(graph_set1, graph_set2)
    intersection = 1*np.logical_and(graph_set1,graph_set2)
    
    if len(conflicted_edges) == 0:
        edge_acc_intersection.append(np.mean([1.0*(intersection==graph_gt)]))
        edge_acc_union.append(np.mean([1.0*(union==graph_gt)]))    
    else:
        edge_acc_intersection.append(np.mean([1.0*(intersection[c]==graph_gt[c]) for c in conflicted_edges]))
        edge_acc_union.append(np.mean([1.0*(union[c]==graph_gt[c]) for c in conflicted_edges]))

100%|█████████████████████████████████████████████████████████████████████████████| 1621/1621 [00:02<00:00, 710.96it/s]


In [17]:
np.mean(edge_acc_intersection), np.mean(edge_acc_union)

(0.5005053504124852, 0.4994946495875148)