In [None]:
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pickle as pkl
import random
from collections import Counter
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.nn import functional as F
from layers.layers import SinusoidalPosEmb, NodeModel, BitModel
from eval_utils.evaluation.graph_structure_evaluation import Descriptor

In [3]:
def unpack_deg_matrix(degs):
    res = []
    for deg in degs:
        deg = deg.long().tolist()
        r = []
        # print(degs)
        for d in deg:
            if (sum(r)==0) or (d > 0):
                r.append(d)
        res.append(r)
    return res
    
def dec2bin(x, bits):
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    return x.unsqueeze(-1).bitwise_and(mask).ne(0).float()
    
def bin2dec(b, bits):
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, b.dtype)
    return torch.sum(mask * b, -1)

class Degree(Descriptor):
    def __init__(self, *args, **kwargs):
        self.name = 'degree'
        self.sigmas = [1.0]
        self.distance_scaling = 1.0
        super().__init__(*args, **kwargs)

    def extract_features(self, res):
        res = [s1 / np.sum(s1) for s1 in res]
        return res

    def degree_worker(self, G):
        return np.array(nx.degree_histogram(G))

In [4]:
dataset_name = 'community'

In [5]:
assert dataset_name in [ 'community', 'Ego']

num_node_classes = None
num_edge_classes = 2
num_node_feat = None
nx_graphs = pkl.load(open(f"graphs/{dataset_name}.pkl", 'rb'))
random.shuffle(nx_graphs)
l = len(nx_graphs)
train_nx_graphs = nx_graphs[:int(0.8*l)]
eval_nx_graphs = nx_graphs[:int(0.2*l)]
test_nx_graphs = nx_graphs[int(0.8*l):]

MAX_NUM_NODES =  max([g.number_of_nodes() for g in nx_graphs]) 
MIN_NUM_NODES = min([g.number_of_nodes() for g in nx_graphs])
max_degree = max(sum([[d[1] for d in nx.degree(g)] for g in nx_graphs],[]))
train_deg = []
eval_deg = []
test_deg = []

for nx_graph in train_nx_graphs:
    degs = [d[1] for d in nx.degree(nx_graph)]

    d = Counter(degs)
    vector = torch.zeros(max_degree)
    for k,v in d.items():
        vector[k-1] = v
    vector = vector.long()
    train_deg.append(vector)

for nx_graph in eval_nx_graphs:
    degs = [d[1] for d in nx.degree(nx_graph)]

    d = Counter(degs)
    vector = torch.zeros(max_degree)
    for k,v in d.items():
        vector[k-1] = v
    vector = vector.long()
    eval_deg.append(vector)

for nx_graph in test_nx_graphs:
    degs = [d[1] for d in nx.degree(nx_graph)]

    d = Counter(degs)
    vector = torch.zeros(max_degree)
    for k,v in d.items():
        vector[k-1] = v
    vector = vector.long()
    test_deg.append(vector)

train_deg = torch.stack(train_deg)
eval_deg = torch.stack(eval_deg)
test_deg = torch.stack(test_deg)

n_vocab = max(max(train_deg.max(), eval_deg.max()), test_deg.max())+1
SEQ_LENS = max_degree
NUM_BITS = bin(n_vocab)[2:].__len__()


In [6]:
DegreeMMD = Degree()
deg_mmd = DegreeMMD.evaluate([nx.degree_histogram(g) for g in train_nx_graphs], [nx.degree_histogram(g) for g in test_nx_graphs])
print('average distance between train and test:', deg_mmd)

average distance between train and test: 0.007126073743663752


In [7]:
def sample(modelNode, modelBit, num_samples):
    x = torch.zeros(num_samples, 1, NUM_BITS)
    g = torch.randint(low=MIN_NUM_NODES,high=MAX_NUM_NODES,size=(num_samples,1))
    r = g
    x = x.cuda()
    g = g.cuda()
    r = r.cuda()
    modelNode.eval()
    modelBit.eval()
    with torch.no_grad():
        for i in range(SEQ_LENS):      
            node_hidden = modelNode(x, g, r)[:,-1,:]
            y = (torch.ones(num_samples, 1).long().cuda()*2).long()

            for j in range(NUM_BITS):
                prediction = modelBit(y.view(-1, j+1), node_hidden.view(-1, node_hidden.shape[-1]), r[:,-1][:,None])[:,-1,:]
                prediction = F.sigmoid(prediction)
                index = prediction.bernoulli().long()
                y = torch.cat([y, index],dim=-1)
            y = y[:, 1:]
            n_j = bin2dec(y, NUM_BITS)-1
            r = torch.cat([r, (r[:, -1]-n_j)[:,None]],dim=-1)
            x = torch.cat([x, y[:,None,:]], dim=1)
        
        x = (bin2dec(x, NUM_BITS)-1).clamp(0)[:,1:]
        return unpack_deg_matrix(x)

In [8]:
n_epochs = 400
batch_size = 64
modelN = NodeModel(NUM_BITS, MAX_NUM_NODES, SEQ_LENS)
modelB = BitModel(NUM_BITS, SEQ_LENS)

G_train = train_deg.sum(-1,keepdim=True)
R_train = train_deg.flip(-1).cumsum(-1).flip(-1)
X_train = torch.cat([torch.zeros(train_deg.shape[0], 1).long(), train_deg[:, :-1]+1], dim=1)
y_train = train_deg[:, :] + 1

G_eval = eval_deg.sum(-1,keepdim=True)
R_eval = eval_deg.flip(-1).cumsum(-1).flip(-1)
X_eval = torch.cat([torch.zeros(eval_deg.shape[0], 1).long(), eval_deg[:, :-1]+1], dim=1)
y_eval = eval_deg[:, :] + 1

X_train_b = dec2bin(X_train, NUM_BITS)
y_train_b = dec2bin(y_train, NUM_BITS)
X_eval_b = dec2bin(X_eval, NUM_BITS)
y_eval_b = dec2bin(y_eval, NUM_BITS)

weights = torch.tensor([1]).cuda()

modelN.to('cuda')
modelB.to('cuda')
optimizer = optim.Adam(list(modelN.parameters())+list(modelB.parameters()), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss(reduction='none')
loader_train = data.DataLoader(data.TensorDataset(X_train_b, G_train, R_train, y_train_b), shuffle=True, batch_size=batch_size)
loader_eval = data.DataLoader(data.TensorDataset(X_eval_b, G_eval, R_eval, y_eval_b), shuffle=True, batch_size=batch_size)

best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    modelN.train()
    modelB.train()
    for X_batch, g_batch, r_batch, y_batch in loader_train:
        X_batch = X_batch.cuda()
        y_batch = y_batch.cuda()
        g_batch = g_batch.cuda()
        r_batch = r_batch.cuda()
        node_hidden = modelN(X_batch, g_batch, r_batch)
        # print(node_hidden.shape)
        y_batch_viewed = y_batch.view(-1, NUM_BITS)
        bit_batch = torch.cat([torch.ones(y_batch_viewed.shape[0], 1).long().cuda()*2, y_batch_viewed[:, :-1]], dim=1).long()
        # print(bit_batch.shape, bit_batch.max(),bit_batch.dtype)
        y_pred = modelB(bit_batch, node_hidden.view(-1, node_hidden.shape[-1]), r_batch.view(-1, 1)).view(y_batch.shape[0], -1, NUM_BITS) 
        loss = (loss_fn(y_pred, y_batch)*weights).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Validation
    modelN.eval()
    modelB.eval()
    loss = 0
    with torch.no_grad():
        for X_batch, g_batch, r_batch, y_batch in loader_eval:
            X_batch = X_batch.cuda()
            y_batch = y_batch.cuda()
            g_batch = g_batch.cuda()
            r_batch = r_batch.cuda()
            node_hidden = modelN(X_batch, g_batch, r_batch)
            y_batch_viewed = y_batch.view(-1, NUM_BITS)
            bit_batch = torch.cat([torch.ones(y_batch_viewed.shape[0], 1).long().cuda()*2, y_batch_viewed[:, :-1]], dim=1).long()
            y_pred = modelB(bit_batch, node_hidden.view(-1, node_hidden.shape[-1]), r_batch.view(-1, 1)).view(y_batch.shape[0], -1, NUM_BITS) 

            loss += (loss_fn(y_pred, y_batch)*weights).mean()
        
        if loss < best_loss:
            best_loss = loss
            best_modelB = modelB.state_dict()
            best_modelN = modelN.state_dict()
        print("Epoch %d: loss: %.4f" % (epoch, loss))

Epoch 0: loss: 0.9779
Epoch 1: loss: 0.8436
Epoch 2: loss: 0.7372
Epoch 3: loss: 0.6887
Epoch 4: loss: 0.6834
Epoch 5: loss: 0.6361
Epoch 6: loss: 0.7145
Epoch 7: loss: 0.5687
Epoch 8: loss: 0.5416
Epoch 9: loss: 0.5095
Epoch 10: loss: 0.4604
Epoch 11: loss: 0.4671
Epoch 12: loss: 0.4567
Epoch 13: loss: 0.4644
Epoch 14: loss: 0.6908
Epoch 15: loss: 0.5845
Epoch 16: loss: 0.5759
Epoch 17: loss: 0.5349
Epoch 18: loss: 0.5006
Epoch 19: loss: 0.4778
Epoch 20: loss: 0.4635
Epoch 21: loss: 0.4630
Epoch 22: loss: 0.4523
Epoch 23: loss: 0.4480
Epoch 24: loss: 0.4447
Epoch 25: loss: 0.4465
Epoch 26: loss: 0.4388
Epoch 27: loss: 0.4393
Epoch 28: loss: 0.4444
Epoch 29: loss: 0.4408
Epoch 30: loss: 0.4358
Epoch 31: loss: 0.4365
Epoch 32: loss: 0.4385
Epoch 33: loss: 0.4398
Epoch 34: loss: 0.4400
Epoch 35: loss: 0.4428
Epoch 36: loss: 0.4465
Epoch 37: loss: 0.4376
Epoch 38: loss: 0.4414
Epoch 39: loss: 0.4354
Epoch 40: loss: 0.4394
Epoch 41: loss: 0.4408
Epoch 42: loss: 0.4366
Epoch 43: loss: 0.439

In [9]:
with torch.no_grad():
    modelN.load_state_dict(best_modelN)
    modelB.load_state_dict(best_modelB)

    sampled_degs = sample(modelN, modelB, num_samples=400)
    eval_mmd = DegreeMMD.evaluate(sampled_degs, [nx.degree_histogram(g) for g in test_nx_graphs])
    print(eval_mmd)


0.01956113805428683


In [11]:
torch.save({'modelBit': best_modelB, 'modelNode':best_modelN, 'NUM_BITS':NUM_BITS, 'SEQ_LENS':SEQ_LENS, 'MAX_NUM_NODES': MAX_NUM_NODES}, f'graphs/{dataset_name}_degree_sampler.pt')