In [1]:
import os
import torch
import argparse
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch_geometric.nn import GraphUNet
from utils_simulated import *
from model import *

# ---------------------------
# seeding for reproducibility
# ---------------------------
seed = 100
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
torch.use_deterministic_algorithms(False)
# set CUBLAS_WORKSPACE_CONFIG=:16:8

In [2]:
# ---------------------
# parameters
# ---------------------
lr = 2e-3
epochs = 200
batch_size = 1
input_dim = 2
hidden_dim = 32
output_dim = 64
pos_weights = 6.667
path = '../../../data/2023-graph-conflation/'
simulation = 'union'
model ='gcn'

# ---------------------
# load data
# ---------------------
print('Load Datasets...')
files = os.listdir(path+f'simulated-graphs/{simulation}/gt/')
train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)
train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=42)

# make datasets
train_data = GraphDataset(path, simulation, train_files)
val_data = GraphDataset(path, simulation, val_files)
test_data = GraphDataset(path, simulation, test_files)

# data loader
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False) 

Load Datasets...


In [3]:
# ---------------------
#  models
# ---------------------
print('Load Model...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
graphconflator = GraphConflator(input_dim, hidden_dim, output_dim, model)
optimizer = torch.optim.Adam(graphconflator.parameters(), lr=lr)
criterion= nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weights))
es = EarlyStopping(tolerance=10)

Load Model...


In [4]:
for i, batch in enumerate(tqdm(train_dataloader)):
    
    # load data
    graph_gt, graph_set1, graph_set2, gt_x, x_set1, x_set2 = batch
    graph_gt = graph_gt.squeeze_(0).to(device)
    labels = graph_gt.flatten()
    graph_set1 = graph_set1.squeeze_(0).long().to(device)
    graph_set2 = graph_set2.squeeze_(0).long().to(device)
    gt_x = gt_x.squeeze_(0).to(device)
    x_set1 = x_set1.squeeze_(0).to(device)
    x_set2 = x_set2.squeeze_(0).to(device)
    break

  0%|                                                  | 0/5187 [00:00<?, ?it/s]


In [5]:
logits = graphconflator(graph_set1, graph_set2, x_set1, x_set2)  

In [6]:
logits

tensor([-0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031, -0.0031,
        -0.0031, -0.0031, -0.0031, -0.00