In [1]:
import os
import torch
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from karateclub import BoostNE
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings("ignore")

In [2]:
class GraphDataset(Dataset):
    def __init__(self, 
                 graph_dir, 
                 graph_files, 
                 emb_model):
        
        self.graph_dir = graph_dir
        self.emb_model = emb_model
        self.graph_files = graph_files        
    
    def __len__(self):
        return len(self.graph_files)

    def __getitem__(self, idx):
        
        # single graph
        file = self.graph_files[idx]
        
        # graph paths
        graph_gt_path = os.path.join(self.graph_dir, 'original', file)
        graph_del_path = os.path.join(self.graph_dir, 'deletion', file)
        graph_ins_path = os.path.join(self.graph_dir, 'insertion', file)
        
        # ground truth adj
        graph_gt = torch.from_numpy(nx.to_numpy_array(nx.read_gpickle(graph_gt_path))).float()
        
        # deletion graph embedding
        graph_del = nx.read_gpickle(graph_del_path)
        self.emb_model.fit(graph_del)
        graph_del_emb = torch.from_numpy(self.emb_model.get_embedding())
        
        # insertion graph embedding
        graph_ins = nx.read_gpickle(graph_ins_path)
        self.emb_model.fit(graph_ins)
        graph_ins_emb = torch.from_numpy(self.emb_model.get_embedding())
        
        return graph_gt, graph_del_emb, graph_ins_emb

In [3]:
def collate_fn(data):
    """
       data: is a list of tuples with (example, label, length)
             where 'example' is a tensor of arbitrary shape
             and label/length are scalars
    """
    graph_gt, graph_del_emb, graph_ins_emb = zip(*data)
    
    index = np.cumsum([g.size(0) for g in graph_gt])
    index = np.insert(index,0,0)
    graph_del_emb = torch.concat(graph_del_emb)
    graph_ins_emb = torch.concat(graph_ins_emb)    
    
    return graph_gt, graph_del_emb.float(), graph_ins_emb.float(), index

In [4]:
class graph_conflict(nn.Module):
    
    """
    Resolve graph conflicts
    
    """
    
    def __init__(self):
        super().__init__()

        # layers
        self.lr_del = nn.Sequential(
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64),
            nn.LeakyReLU()
        )
        
        self.lr_ins = nn.Sequential(
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64),
            nn.LeakyReLU()
        )
        
    def forward(self, graph_del, graph_ins):
        
        ## projection
        x_del = self.lr_del(graph_del)
        x_ins = self.lr_ins(graph_ins)
        A = torch.matmul(x_del, x_ins.T)
        
        return A

In [5]:
# parameters
lr = 1e-3
epochs = 20
batch_size = 4
emb_model = BoostNE(dimensions=16, iterations=15)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
graph_model=graph_conflict()
graph_model.to(device)
optimizer = torch.optim.Adam(graph_model.parameters(), lr=lr)
criterion= nn.BCEWithLogitsLoss()

# load files
train_files = os.listdir('../graph-data/seattle-graphs/original/')
test_files = os.listdir('../graph-data/west-seattle-graphs/original/')
train_files, val_files = train_test_split(train_files, test_size=0.1, random_state=42)

# make datasets
train_data = GraphDataset('../graph-data/seattle-graphs/', train_files, emb_model)
val_data = GraphDataset('../graph-data/seattle-graphs/', val_files, emb_model)
test_data = GraphDataset('../graph-data/west-seattle-graphs/', test_files, emb_model)

# data loader
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [7]:
losses = []
for i, batch in enumerate(train_dataloader): 
    # load data
    graph_gt, graph_del, graph_ins, index = batch
    labels = torch.concat([g.flatten() for g in graph_gt]).float().to(device)
    graph_del = graph_del.to(device)
    graph_ins = graph_ins.to(device)
    
    # prediction
    optimizer.zero_grad()
    A = graph_model(graph_del, graph_ins)
    preds = torch.concat([A[index[i]:index[i+1], index[i]:index[i+1]].flatten() for i in range(len(index)-1)])
    loss = criterion(preds, labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    
    if i % 50 == 0: print(loss.item())

0.7032991051673889
0.6931471824645996
0.6931471824645996
0.6931471824645996
0.6931471228599548
0.6931471824645996
0.6931471824645996
0.6931472420692444
0.6931471824645996
0.6931471824645996
0.6931471824645996
0.6931472420692444


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "C:\Users\binha\anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\binha\AppData\Local\Temp\ipykernel_1216\1816005009.py", line 2, in <module>
    for i, batch in enumerate(train_dataloader):
  File "C:\Users\binha\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 630, in __next__
    data = self._next_data()
  File "C:\Users\binha\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\binha\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\binha\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
 

TypeError: object of type 'NoneType' has no len()

In [9]:
acc = []
sig = nn.Sigmoid()
for i, batch in enumerate(val_dataloader):
    
    # load data
    graph_gt, graph_del, graph_ins, index = batch
    graph_del = graph_del.to(device)
    graph_ins = graph_ins.to(device)
    
    # prediction
    with torch.no_grad():
        A = graph_model(graph_del, graph_ins)
    preds = [ sig(A[index[i]:index[i+1], index[i]:index[i+1]].detach().cpu())>0.5 for i in range(len(index)-1)]
    for p in range(len(preds)):
        acc.append(torch.mean(1.0*(preds[p] == graph_gt[p])).item())

In [11]:
np.mean(acc)

0.8975347871714091