In [1]:
import os
import torch
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils import *
from model import *

In [2]:
# ---------------------
# parameters
# ---------------------
lr = 1e-4
epochs = 200
batch_size = 1
pos_weights = 5

# ---------------------
# load data
# ---------------------
print('Load Datasets...')
train_files = os.listdir('../graph-data/seattle-graphs/original/')
test_files = os.listdir('../graph-data/west-seattle-graphs/original/')
train_files, val_files = train_test_split(train_files, test_size=0.1, random_state=42)

train_data = DuoGraphDataset('../graph-data/seattle-graphs/', train_files)
val_data = DuoGraphDataset('../graph-data/seattle-graphs/', val_files)
test_data = DuoGraphDataset('../graph-data/west-seattle-graphs/', test_files)
    
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)    

# ---------------------
#  models
# ---------------------
print('Load Model...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_del = GCN().to(device)
model_ins = GCN().to(device)
model_del.load_state_dict(torch.load(f'model_states/DuoGraph_del_0.01'))
model_ins.load_state_dict(torch.load(f'model_states/DuoGraph_ins_0.01'))
model_del.eval()
model_ins.eval()
criterion= nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weights))

Load Datasets...
Load Model...


### Graph Level Accuracy

In [3]:
graph_acc = []
for i, batch in enumerate(tqdm(val_dataloader)):
    
    # load data
    graph_gt, graph_del_edge_index, graph_ins_edge_index, x = batch
    graph_gt = graph_gt.squeeze_(0)
    labels = graph_gt.flatten()
    graph_del_edge_index = graph_del_edge_index.squeeze_(0).to(device)
    graph_ins_edge_index = graph_ins_edge_index.squeeze_(0).to(device)    
    x = x.squeeze_(0).to(device)
    
    # make prediction
    with torch.no_grad():
        out_del = model_del(graph_del_edge_index, x)
        out_ins = model_ins(graph_ins_edge_index, x) 
    A = F.sigmoid(torch.matmul(out_del, out_ins.T)).flatten()
    graph_acc.append(torch.mean(1.0*( 1.0*(A.detach().cpu()>0.5) == labels)).item())

100%|███████████████████████████████████████████████████████████████████████████████| 863/863 [00:06<00:00, 139.42it/s]


In [4]:
np.mean(graph_acc)

0.848387392343541

### Conflicted Edges Accuracy

In [5]:
edge_acc = []
for i, batch in enumerate(tqdm(val_dataloader)):
    
    # load data
    graph_gt, graph_del_edge_index, graph_ins_edge_index, x = batch
    graph_gt = graph_gt.squeeze_(0)
    graph_del_edge_index = graph_del_edge_index.squeeze_(0).to(device)
    graph_ins_edge_index = graph_ins_edge_index.squeeze_(0).to(device)    
    x = x.squeeze_(0).to(device)
    
    # extract conflicted edges
    conflicted_edges = graph_ins_edge_index.T[~(graph_ins_edge_index.T[:, None] == graph_del_edge_index.T).all(-1).any(-1)]
    conflicted_edges = conflicted_edges.repeat(2,1).detach().cpu()
    conflicted_edges[int(len(conflicted_edges)/2):,[0,1]] = conflicted_edges[int(len(conflicted_edges)/2):,[1,0]]
    conflicted_edges = list(zip(*conflicted_edges.T.numpy()))    
    
    # make prediction
    with torch.no_grad():
        out_del = model_del(graph_del_edge_index, x)
        out_ins = model_ins(graph_ins_edge_index, x) 
    A = F.sigmoid(torch.matmul(out_del, out_ins.T))
    A_pred = 1.0*(A.detach().cpu()>0.5)
    edge_acc.append(np.mean([1.0*(A_pred[c]==graph_gt[c]).item() for c in conflicted_edges]))

100%|███████████████████████████████████████████████████████████████████████████████| 863/863 [00:04<00:00, 177.17it/s]


In [6]:
np.mean(edge_acc)

0.5035410090117053