In [1]:
import os
import torch
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.functional as f

from sklearn.model_selection import train_test_split
from karateclub import BoostNE
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings("ignore")

In [2]:
class GraphDataset(Dataset):
    def __init__(self, 
                 graph_dir, 
                 graph_files, 
                 emb_model):
        
        self.graph_dir = graph_dir
        self.emb_model = emb_model
        self.graph_files = graph_files        
    
    def __len__(self):
        return len(self.graph_files)

    def __getitem__(self, idx):
        
        # single graph
        file = self.graph_files[idx]
        
        # graph paths
        graph_gt_path = os.path.join(self.graph_dir, 'original', file)
        graph_del_path = os.path.join(self.graph_dir, 'deletion', file)
        graph_ins_path = os.path.join(self.graph_dir, 'insertion', file)
        
        # ground truth adj
        graph_gt = torch.from_numpy(nx.to_numpy_array(nx.read_gpickle(graph_gt_path))).float()
        
        # deletion graph embedding
        graph_del = nx.read_gpickle(graph_del_path)
        graph_ins = nx.read_gpickle(graph_ins_path)
        graph_comb = nx.compose(graph_ins, graph_del)
        self.emb_model.fit(graph_comb)
        graph_emb = torch.from_numpy(self.emb_model.get_embedding())
        
        return graph_gt, graph_emb

In [3]:
def collate_fn(data):
    """
       data: is a list of tuples with (example, label, length)
             where 'example' is a tensor of arbitrary shape
             and label/length are scalars
    """
    graph_gt, graph_emb = zip(*data)
    index = np.cumsum([g.size(0) for g in graph_gt])
    index = np.insert(index,0,0)
    graph_emb = torch.concat(graph_emb)
    
    return graph_gt, graph_emb.float(), index

In [4]:
class graph_conflict(nn.Module):
    
    """
    Resolve graph conflicts
    
    """
    
    def __init__(self):
        super().__init__()

        # layers
        self.lr = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU()
        )    
        
    def forward(self, graph_emb):
        
        ## projection
        x = self.lr(graph_emb)
        A = torch.matmul(x, x.T)
        
        return A

In [5]:
# parameters
lr = 1e-4
epochs = 20
batch_size = 1
emb_model = BoostNE(dimensions=16, iterations=15)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
graph_model=graph_conflict()
graph_model.to(device)
optimizer = torch.optim.Adam(graph_model.parameters(), lr=lr)
criterion= nn.BCEWithLogitsLoss()

# load files
train_files = os.listdir('../graph-data/seattle-graphs/original/')
test_files = os.listdir('../graph-data/west-seattle-graphs/original/')
train_files, val_files = train_test_split(train_files, test_size=0.1, random_state=42)

# make datasets
train_data = GraphDataset('../graph-data/seattle-graphs/', train_files, emb_model)
val_data = GraphDataset('../graph-data/seattle-graphs/', val_files, emb_model)
test_data = GraphDataset('../graph-data/west-seattle-graphs/', test_files, emb_model)

# data loader
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [6]:
for i, batch in enumerate(train_dataloader): 
    # load data
    graph_gt, graph_emb, index = batch
    labels = torch.concat([g.flatten() for g in graph_gt]).float().to(device)
    graph_emb = graph_emb.to(device)
    
    # prediction
    optimizer.zero_grad()
    A = graph_model(graph_emb)
    preds = A.flatten()
    loss = criterion(preds, labels)
    loss.backward()    
    optimizer.step()
    
    if i % 50 == 0: print(loss.item())

0.7186110019683838
0.6966872811317444
0.693310022354126
0.6932337284088135
0.6931611895561218
0.6931530833244324
0.6931631565093994
0.6931513547897339
0.6931484937667847
0.6931500434875488
0.6931482553482056
0.6931481957435608
0.6931479573249817
0.6931477189064026
0.6931474804878235
0.6931474804878235
0.6931474804878235
0.6931474804878235


KeyboardInterrupt: 

In [11]:
acc = []
sig = nn.Sigmoid()
for i, batch in enumerate(val_dataloader):
    
    # load data
    graph_gt, graph_emb, index = batch
    labels = torch.concat([g.flatten() for g in graph_gt]).float().to(device)
    graph_emb = graph_emb.to(device)
    
    # prediction
    with torch.no_grad():
        A = graph_model(graph_emb)
    acc.append(torch.mean(1.0*(A.detach().cpu() == graph_gt[0])).item())

In [12]:
np.mean(acc)

0.819476006015151