In [1]:
import os
import torch
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils import *
from model import *

In [2]:
# ---------------------
# parameters
# ---------------------
lr = 1e-3          # not used in this evaluation section; kept for parity with the original cell
epochs = 100       # not used in this evaluation section; kept for parity with the original cell
batch_size = 1     # the evaluation loops below squeeze dim 0, i.e. they expect one graph per batch
pos_weights = 10   # positive-class weight passed to BCEWithLogitsLoss
path = 'D:graph-conflation-data/'

# ---------------------
# load data
# ---------------------
print('Load Datasets...')
# os.path.join avoids the doubled separator that `path + '/graphs/osm/'` produced.
# NOTE(review): os.listdir returns entries in arbitrary order, so the split below is only
# reproducible if the listing order is stable. It is left unsorted here so the split is not
# changed -- confirm it matches the split used when the checkpoints were trained.
files = os.listdir(os.path.join(path, 'graphs', 'osm'))
train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)
train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=42)

# make datasets
train_data = GraphDataset(path, train_files)
val_data = GraphDataset(path, val_files)
test_data = GraphDataset(path, test_files)

# data loader
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# ---------------------
#  models
# ---------------------
print('Load Model...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_osm = GCN().to(device)
model_sdot = GCN().to(device)
# map_location lets checkpoints saved on a GPU load on a CPU-only machine (and vice versa).
model_osm.load_state_dict(torch.load('model_states/model_osm_5', map_location=device))
model_sdot.load_state_dict(torch.load('model_states/model_sdot_5', map_location=device))
# Evaluation mode for inference.
model_osm.eval()
model_sdot.eval()
# pos_weight is created as a float tensor on the same device as the models so the loss can be
# applied to their outputs without a dtype or device mismatch.
criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor(float(pos_weights), device=device)
)

Load Datasets...
Load Model...


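### Graph Level Accuracy

For each test graph, the sigmoid of the OSM/SDOT embedding product is thresholded at 0.5 and compared entry by entry with the flattened `graph_osw` labels; the per-graph score is the fraction of matching entries.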

In [3]:
# Per-graph accuracy over every entry of the predicted matrix.
graph_acc = []
# Fraction of entries predicted positive per graph (sanity check on the prediction rate).
check = []
for batch in tqdm(test_dataloader):

    # load data (squeeze_ drops the leading batch dimension added by the DataLoader)
    graph_osw, graph_osm, graph_sdot, osw_x, osm_x, sdot_x = batch
    graph_osw = graph_osw.squeeze_(0).to(device)
    labels = graph_osw.flatten()
    graph_osm = graph_osm.squeeze_(0).to(device)
    graph_sdot = graph_sdot.squeeze_(0).to(device)
    osw_x = osw_x.squeeze_(0).to(device)  # not used below
    osm_x = osm_x.squeeze_(0).to(device)
    sdot_x = sdot_x.squeeze_(0).to(device)

    # make prediction; everything stays inside no_grad so no autograd graph is built
    with torch.no_grad():
        out_osm = model_osm(graph_osm, osm_x)
        out_sdot = model_sdot(graph_sdot, sdot_x)
        # torch.sigmoid replaces F.sigmoid, which older PyTorch releases flag as deprecated
        A = torch.sigmoid(torch.matmul(out_osm, out_sdot.T)).flatten()
        # threshold once at 0.5 and reuse for both statistics
        A_pred = (A > 0.5).float()

    check.append(A_pred.mean().item())
    graph_acc.append((A_pred == labels).float().mean().item())

100%|█████████████████████████████████████████████████████████████████████████████| 1773/1773 [00:14<00:00, 118.41it/s]


In [4]:
# Average of the per-graph accuracies collected in graph_acc.
np.asarray(graph_acc).mean()

0.8951421059663862

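### Conflicted Edges Accuracy

A conflicted edge is one that appears in exactly one of the OSM and SDOT edge lists. Each such edge is checked in both directions against `graph_osw`, and the per-graph score is the fraction of those entries predicted correctly.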

In [5]:
# Per-graph accuracy restricted to conflicted edges; NaN marks graphs that have none.
edge_acc = []
for batch in tqdm(test_dataloader):

    # load data (graph_osw stays on the CPU because it is only indexed against the CPU copy
    # of the prediction below)
    graph_osw, graph_osm, graph_sdot, osw_x, osm_x, sdot_x = batch
    graph_osw = graph_osw.squeeze_(0)
    labels = graph_osw.flatten()  # not used below
    graph_osm = graph_osm.squeeze_(0).to(device)
    graph_sdot = graph_sdot.squeeze_(0).to(device)
    osw_x = osw_x.squeeze_(0).to(device)  # not used below
    osm_x = osm_x.squeeze_(0).to(device)
    sdot_x = sdot_x.squeeze_(0).to(device)

    # extract conflicted edges: rows of one transposed edge list with no identical row in the other
    osm_edges, sdot_edges = graph_osm.T, graph_sdot.T
    only_in_osm = osm_edges[~(osm_edges[:, None] == sdot_edges).all(-1).any(-1)]
    only_in_sdot = sdot_edges[~(sdot_edges[:, None] == osm_edges).all(-1).any(-1)]
    conflicted_edges = torch.cat([only_in_osm, only_in_sdot], dim=0).detach().cpu().numpy()
    # duplicate every edge and reverse one copy so both directions (i, j) and (j, i) are scored
    conflicted_edges = np.repeat(conflicted_edges, 2, axis=0)
    conflicted_edges[::2, [0, 1]] = conflicted_edges[::2, [1, 0]]
    conflicted_edges = list(zip(*conflicted_edges.T))

    # make prediction
    with torch.no_grad():
        out_osm = model_osm(graph_osm, osm_x)
        out_sdot = model_sdot(graph_sdot, sdot_x)
        # torch.sigmoid replaces F.sigmoid, which older PyTorch releases flag as deprecated
        A = torch.sigmoid(torch.matmul(out_osm, out_sdot.T))
    A_pred = 1.0*(A.cpu() > 0.5)

    if conflicted_edges:
        edge_acc.append(np.mean([1.0*(A_pred[c] == graph_osw[c]).item() for c in conflicted_edges]))
    else:
        # No conflicted edges in this graph: record NaN explicitly rather than calling
        # np.mean([]), which returns the same NaN but emits "Mean of empty slice" warnings.
        edge_acc.append(np.nan)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
100%|█████████████████████████████████████████████████████████████████████████████| 1773/1773 [00:13<00:00, 133.73it/s]


In [6]:
# NOTE(review): the original comment here was just "5" -- presumably the checkpoint suffix
# (model_osm_5 / model_sdot_5); confirm.
# Average the per-graph conflicted-edge accuracies, skipping the NaN entries.
edge_scores = np.array(edge_acc)
has_score = np.logical_not(np.isnan(edge_scores))
edge_scores[has_score].mean()

0.8341113078412254