In [1]:
import flow2graph
from pathlib import Path
import networkx as nx

import torch
import dgl
import numpy as np

Using backend: pytorch


In [2]:
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
import dgl.nn as dglnn

class EdgeToNode(nn.Module):
    """Turn per-edge features into per-node features.

    Every edge feature vector is projected twice, by two independent linear
    layers: once in its role as an *incoming* edge (``linear_in``) and once
    in its role as an *outgoing* edge (``linear_out``). Each node then sums
    the "incoming" projections of the edges pointing at it and the
    "outgoing" projections of the edges leaving it; the two sums are added,
    projected by ``linear_final`` and squashed with ``tanh``.

    Args:
        in_feats: width of the input edge features.
        hid_feats: width of the two intermediate projections.
        out_feats: width of the returned node features.
        non_linear: optional module applied after both projections;
            ``nn.Identity()`` is used when it is falsy.
    """

    def __init__(self, in_feats, hid_feats, out_feats, non_linear=None):
        super().__init__()
        self.linear_in = nn.Linear(in_feats, hid_feats)
        self.linear_out = nn.Linear(in_feats, hid_feats)
        self.linear_final = nn.Linear(in_features=hid_feats, out_features=out_feats)
        self.non_linear = non_linear if non_linear else nn.Identity()

    def forward(self, graph: dgl.DGLGraph, h):
        """Return one feature vector per node computed from edge features ``h``."""
        proj_in = self.linear_in(h)
        proj_out = self.linear_out(h)
        proj_in = self.non_linear(proj_in)
        proj_out = self.non_linear(proj_out)

        # local_scope: the temporary features written below are dropped on exit
        with graph.local_scope():
            graph.edata['e_in'] = proj_in
            graph.edata['e_out'] = proj_out

            # each destination node sums the `e_in` feature of its incoming edges -> `n_in`
            graph.update_all(fn.copy_e('e_in', 'n_in'), fn.sum('n_in', 'n_in'))

            # on the reversed graph the original source nodes are the destinations,
            # so the same sum collects `e_out` over each node's outgoing edges -> `n_out`
            flipped = graph.reverse(copy_ndata=True, copy_edata=True)
            flipped.update_all(fn.copy_e('e_out', 'n_out'), fn.sum('n_out', 'n_out'))

            combined = graph.ndata['n_in'] + flipped.ndata['n_out']
            return torch.tanh(self.linear_final(combined))
    

class SAGE(nn.Module):
    """Two stacked ``SAGEConv`` layers with mean aggregation.

    Args:
        in_feats: width of the input node features.
        hid_feats: width of the features between the two layers.
        out_feats: width of the returned node features.
    """

    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, h):
        """Apply conv1 -> ReLU -> conv2; no activation is applied to the output."""
        hidden = F.relu(self.conv1(graph, h))
        return self.conv2(graph, hidden)
    
class NodePredictor(nn.Module):
    """Node classifier that only sees edge features.

    Pipeline: ``EdgeToNode`` -> linear -> ReLU -> ``SAGE`` -> ReLU -> linear
    -> linear classification head.

    Args:
        n_classes: number of output logits per node.
        in_features_e: width of the input edge features.
        out_features_e: width of the node features produced by ``EdgeToNode``.
        hid_features_n: hidden width used by both ``EdgeToNode`` and ``SAGE``.
        out_features_n: width of the node features produced by ``SAGE``.
    """

    def __init__(self, n_classes, in_features_e, out_features_e=16, hid_features_n=32, out_features_n=16):
        super().__init__()

        # sub-modules are created in the same order as they are applied in forward()
        self.e2n = EdgeToNode(
            in_feats=in_features_e,
            hid_feats=hid_features_n,
            out_feats=out_features_e,
            non_linear=nn.ReLU(),
        )
        self.lin1 = nn.Linear(in_features=out_features_e, out_features=out_features_e)
        self.sage = SAGE(in_feats=out_features_e, hid_feats=hid_features_n, out_feats=out_features_n)
        self.lin2 = nn.Linear(in_features=out_features_n, out_features=out_features_n)
        self.n2y = nn.Linear(in_features=out_features_n, out_features=n_classes)

    def forward(self, graph, h):
        """Return unnormalised class scores per node from edge features ``h``."""
        node_feats = self.e2n(graph, h)
        node_feats = F.relu(self.lin1(node_feats))
        node_feats = F.relu(self.sage(graph, node_feats))
        # lin2 feeds straight into the head with no activation in between
        return self.n2y(self.lin2(node_feats))
 
class NodeToEdge(nn.Module):
    """Derive one feature vector per edge from the features of its endpoints.

    For every edge ``u -> v`` the source features ``h[u]`` go through
    ``lin_out`` and the destination features ``h[v]`` go through ``lin_in``.
    Both projections pass through ``non_linear``, are summed, projected by
    ``lin_final`` and squashed with ``tanh``.

    Args:
        in_feats: width of the input node features.
        hid_feats: width of the two endpoint projections.
        out_feats: width of the returned edge features.
        non_linear: optional module applied after both projections;
            ``nn.Identity()`` is used when it is falsy.
    """

    def __init__(self, in_feats, hid_feats, out_feats, non_linear=None):
        super().__init__()
        self.lin_in = nn.Linear(in_features=in_feats, out_features=hid_feats)
        self.lin_out = nn.Linear(in_features=in_feats, out_features=hid_feats)
        self.lin_final = nn.Linear(in_features=hid_feats, out_features=out_feats)
        self.non_linear = non_linear or nn.Identity()

    def forward(self, graph, h):
        """Return edge features, one row per edge in the order given by ``graph.edges()``.

        Args:
            graph: object whose ``edges()`` returns ``(src, dst)`` index tensors
                (for a DGL graph these are ordered by edge id by default).
            h: node feature tensor indexed by node id along dimension 0.
        """
        # Gather the endpoint features by direct indexing. The previous version
        # added them onto a freshly allocated zero edge feature, which was always
        # created on the CPU in the default dtype and therefore broke for node
        # features living on another device or stored in another dtype.
        src, dst = graph.edges()
        h_src = h[src.long()]
        h_dst = h[dst.long()]

        h_out = self.lin_out(h_src)
        h_in = self.lin_in(h_dst)

        h_out = self.non_linear(h_out)
        h_in = self.non_linear(h_in)

        return torch.tanh(self.lin_final(h_out + h_in))

In [3]:
from sklearn.datasets import make_gaussian_quantiles
def generate_graph(n2e, n, n_ft_n=10, n_ft_e=None, n_classes=8, cov=3, seed=None):
    """Build one synthetic graph with node features, node labels and edge features.

    The topology is a Barabasi-Albert graph with ``n`` nodes where each new
    node attaches with a single edge. Node features and labels come from
    ``make_gaussian_quantiles``; edge features are obtained by running ``n2e``
    on the node features with gradients disabled.

    Args:
        n2e: callable ``n2e(graph, node_features)`` returning one feature row
            per edge (e.g. a ``NodeToEdge`` module).
        n: number of nodes, also the number of samples drawn.
        n_ft_n: number of features per node.
        n_ft_e: unused; kept so existing calls remain valid. The width of the
            edge features is whatever ``n2e`` returns.
        n_classes: number of label classes.
        cov: covariance scale forwarded to ``make_gaussian_quantiles``.
        seed: optional seed forwarded to both the topology generator and the
            sample generator. ``None`` (the default) keeps the previous
            behaviour of drawing from the libraries' global random state.

    Returns:
        The DGL graph with ``ndata['ft']`` (float), ``ndata['y']`` (long) and
        ``edata['ft']`` populated.
    """
    topology = nx.barabasi_albert_graph(n, 1, seed=seed)
    g = dgl.from_networkx(topology)

    x, y = make_gaussian_quantiles(
        cov=cov,
        n_classes=n_classes,
        n_features=n_ft_n,
        n_samples=n,
        random_state=seed,
    )
    g.ndata['ft'] = torch.from_numpy(x).float()
    g.ndata['y'] = torch.from_numpy(y).long()

    # edge features are generated data, not something to backpropagate through
    with torch.no_grad():
        g.edata['ft'] = n2e(g, g.ndata['ft'])

    return g

# Feature widths shared by the data generator and the models below.
n_ft_n = 8   # features per node
n_ft_e = 32  # features per edge
# Randomly initialised module used to derive edge features from node features.
n2e = NodeToEdge(n_ft_n, 3 * n_ft_n, n_ft_e, non_linear=nn.ReLU())

In [4]:
from tqdm.notebook import tqdm
# 200 synthetic graphs, each with a random number of nodes in [100, 4000).
dt = [
    generate_graph(n2e, np.random.randint(100, 4000), n_ft_n=n_ft_n, n_ft_e=n_ft_e, cov=0.05)
    for _ in tqdm(range(200))
]

# 70% / 20% / 10% split. The sizes are computed with integer arithmetic and the
# test split takes the remainder, so they always add up to len(dt); truncating
# float products could fall short of len(dt) and make random_split raise.
n_train = len(dt) * 7 // 10
n_val = len(dt) * 2 // 10
n_test = len(dt) - n_train - n_val
dt_train, dt_val, dt_test = torch.utils.data.random_split(dt, [n_train, n_val, n_test])

HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [5]:
from tqdm.notebook import tqdm

# GNN classifier over edge features, trained with cross-entropy on 8 classes.
criterion = nn.CrossEntropyLoss()
model = NodePredictor(n_classes=8, in_features_e=n_ft_e)

In [6]:
def evaluate(model, dt):
    """Score ``model`` on every graph of ``dt`` without tracking gradients.

    The model is called as ``model(g, g.edata['ft'])`` and compared with
    ``g.ndata['y']`` using the module-level ``criterion``. The model is put in
    eval mode and left there.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-graph losses and the
        fraction of correctly classified nodes over all graphs.
    """
    graph_losses = []
    n_correct = 0
    n_nodes = 0

    with torch.no_grad():
        model.eval()
        for graph in tqdm(dt, desc='eval', leave=False):
            targets = graph.ndata['y']
            scores = model(graph, graph.edata['ft'])

            graph_losses.append(criterion(scores, targets).item())

            # index of the highest score per node
            _, predicted = scores.max(dim=-1)
            n_nodes += len(predicted)
            n_correct += (targets == predicted).sum().item()

    return np.mean(graph_losses), n_correct / n_nodes
    
def train(model, dt_train, dt_val, dt_test, nb_epochs=100, freq_show_loss=10):
    """Train ``model`` with Adam, one optimisation step per graph.

    The model is called as ``model(g, g.edata['ft'])`` and fitted to
    ``g.ndata['y']`` with the module-level ``criterion``. Validation and test
    metrics are printed before and after training, and validation metrics
    after every epoch.

    Args:
        model: module to optimise in place.
        dt_train: iterable of training graphs.
        dt_val: iterable of validation graphs.
        dt_test: iterable of test graphs.
        nb_epochs: number of passes over ``dt_train``.
        freq_show_loss: the progress bar description is refreshed every this
            many graphs.
    """
    def _report(title):
        # Print validation and test metrics under the given heading.
        loss_val, acc_val = evaluate(model, dt_val)
        loss_test, acc_test = evaluate(model, dt_test)
        print(title)
        print(f'\tloss_val: {loss_val:.3f} / acc_val: {acc_val:.3f}')
        print(f'\tloss_test: {loss_test:.3f} / acc_test: {acc_test:.3f}')

    optimizer = torch.optim.Adam(model.parameters())

    _report('Before:')

    for epoch in tqdm(range(nb_epochs), desc='epoch'):
        losses = []
        corr, tot = 0, 0

        # evaluate() leaves the model in eval mode, so switch back every epoch
        model.train()
        pb_training = tqdm(dt_train, desc='train', leave=False)
        for idx, g in enumerate(pb_training):
            Y_true = g.ndata['y']
            Y_pred = model(g, g.edata['ft'])

            loss = criterion(Y_pred, Y_true)
            losses.append(loss.item())

            # running training accuracy, for the progress bar only
            Y_pred = Y_pred.max(dim=-1)[1]
            tot += len(Y_pred)
            corr += (Y_pred == Y_true).sum().item()

            if idx % freq_show_loss == 0:
                pb_training.set_description(f"l:{np.mean(losses):.3f}, acc:{corr/tot:.3f}")

            # weights optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_val, acc_val = evaluate(model, dt_val)
        print(f"epoch {epoch:03d}: loss_val:{loss_val:.3f}, acc: {acc_val:.3f}")

    # this final report used to be printed under the heading 'Before:' as well
    _report('After:')
    
# Fit the GNN classifier for 20 epochs.
train(model, dt_train=dt_train, dt_val=dt_val, dt_test=dt_test, nb_epochs=20)

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='eval', max=20.0, style=ProgressStyle(description_width='i…

Before:
	loss_val: 2.090 / acc_val: 0.125
	loss_test: 2.089 / acc_tes: 0.125


HBox(children=(FloatProgress(value=0.0, description='epoch', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 000: loss_val:2.028, acc: 0.199


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 001: loss_val:1.598, acc: 0.341


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 002: loss_val:1.521, acc: 0.369


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 003: loss_val:1.483, acc: 0.384


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 004: loss_val:1.364, acc: 0.426


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 005: loss_val:1.333, acc: 0.440


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 006: loss_val:1.299, acc: 0.450


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 007: loss_val:1.257, acc: 0.464


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 008: loss_val:1.238, acc: 0.469


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 009: loss_val:1.237, acc: 0.469


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 010: loss_val:1.205, acc: 0.480


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 011: loss_val:1.165, acc: 0.497


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 012: loss_val:1.131, acc: 0.511


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 013: loss_val:1.114, acc: 0.518


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 014: loss_val:1.089, acc: 0.528


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 015: loss_val:1.095, acc: 0.526


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 016: loss_val:1.070, acc: 0.535


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 017: loss_val:1.069, acc: 0.536


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 018: loss_val:1.072, acc: 0.534


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 019: loss_val:1.072, acc: 0.533



HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='eval', max=20.0, style=ProgressStyle(description_width='i…

Before:
	loss_val: 1.072 / acc_val: 0.533
	loss_test: 1.061 / acc_tes: 0.537


In [7]:
# Baseline: a plain MLP applied to the raw node features, ignoring the graph.
# Two hidden layers twice as wide as the input, then 8 class scores.
model = nn.Sequential(
    nn.Linear(in_features=n_ft_n, out_features=2 * n_ft_n),
    nn.ReLU(),
    nn.Linear(in_features=2 * n_ft_n, out_features=2 * n_ft_n),
    nn.ReLU(),
    nn.Linear(in_features=2 * n_ft_n, out_features=8),
)

criterion = nn.CrossEntropyLoss()

In [8]:
def evaluate(model, dt):
    """Score the node-feature baseline on every graph of ``dt``.

    Unlike the earlier definition of the same name, the model is called on the
    node features only (``model(g.ndata['ft'])``). Predictions are compared
    with ``g.ndata['y']`` using the module-level ``criterion``, with gradients
    disabled. The model is put in eval mode and left there.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-graph losses and the
        fraction of correctly classified nodes over all graphs.
    """
    graph_losses = []
    hits, seen = 0, 0

    with torch.no_grad():
        model.eval()
        for graph in tqdm(dt, desc='eval', leave=False):
            labels = graph.ndata['y']
            scores = model(graph.ndata['ft'])

            graph_losses.append(criterion(scores, labels).item())

            # index of the highest score per node
            _, guesses = scores.max(dim=-1)
            seen += len(guesses)
            hits += (labels == guesses).sum().item()

    return np.mean(graph_losses), hits / seen
    
def train(model, dt_train, dt_val, dt_test, nb_epochs=100, freq_show_loss=10):
    """Train the node-feature baseline with Adam, one step per graph.

    The model is called on the node features only (``model(g.ndata['ft'])``)
    and fitted to ``g.ndata['y']`` with the module-level ``criterion``.
    Validation and test metrics are printed before and after training, and
    validation metrics after every epoch.

    Args:
        model: module to optimise in place.
        dt_train: iterable of training graphs.
        dt_val: iterable of validation graphs.
        dt_test: iterable of test graphs.
        nb_epochs: number of passes over ``dt_train``.
        freq_show_loss: the progress bar description is refreshed every this
            many graphs.
    """
    def _report(title):
        # Print validation and test metrics under the given heading.
        loss_val, acc_val = evaluate(model, dt_val)
        loss_test, acc_test = evaluate(model, dt_test)
        print(title)
        print(f'\tloss_val: {loss_val:.3f} / acc_val: {acc_val:.3f}')
        print(f'\tloss_test: {loss_test:.3f} / acc_test: {acc_test:.3f}')

    optimizer = torch.optim.Adam(model.parameters())

    _report('Before:')

    for epoch in tqdm(range(nb_epochs), desc='epoch'):
        losses = []
        corr, tot = 0, 0

        # evaluate() leaves the model in eval mode, so switch back every epoch
        model.train()
        pb_training = tqdm(dt_train, desc='train', leave=False)
        for idx, g in enumerate(pb_training):
            Y_true = g.ndata['y']
            Y_pred = model(g.ndata['ft'])

            loss = criterion(Y_pred, Y_true)
            losses.append(loss.item())

            # running training accuracy, for the progress bar only
            Y_pred = Y_pred.max(dim=-1)[1]
            tot += len(Y_pred)
            corr += (Y_pred == Y_true).sum().item()

            if idx % freq_show_loss == 0:
                pb_training.set_description(f"l:{np.mean(losses):.3f}, acc:{corr/tot:.3f}")

            # weights optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_val, acc_val = evaluate(model, dt_val)
        print(f"epoch {epoch:03d}: loss_val:{loss_val:.3f}, acc: {acc_val:.3f}")

    # this final report used to be printed under the heading 'Before:' as well
    _report('After:')

# Fit the MLP baseline for 20 epochs on the same splits.
train(model, dt_train=dt_train, dt_val=dt_val, dt_test=dt_test, nb_epochs=20)

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='eval', max=20.0, style=ProgressStyle(description_width='i…

Before:
	loss_val: 2.094 / acc_val: 0.125
	loss_test: 2.094 / acc_tes: 0.125


HBox(children=(FloatProgress(value=0.0, description='epoch', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 000: loss_val:2.000, acc: 0.209


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 001: loss_val:1.688, acc: 0.334


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 002: loss_val:1.418, acc: 0.493


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 003: loss_val:1.250, acc: 0.567


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 004: loss_val:1.143, acc: 0.589


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 005: loss_val:1.072, acc: 0.596


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 006: loss_val:1.024, acc: 0.599


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 007: loss_val:0.990, acc: 0.601


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 008: loss_val:0.965, acc: 0.603


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 009: loss_val:0.947, acc: 0.606


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 010: loss_val:0.932, acc: 0.606


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 011: loss_val:0.921, acc: 0.608


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 012: loss_val:0.912, acc: 0.609


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 013: loss_val:0.905, acc: 0.609


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 014: loss_val:0.899, acc: 0.609


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 015: loss_val:0.895, acc: 0.610


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 016: loss_val:0.891, acc: 0.610


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 017: loss_val:0.888, acc: 0.611


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 018: loss_val:0.885, acc: 0.611


HBox(children=(FloatProgress(value=0.0, description='train', max=140.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

epoch 019: loss_val:0.883, acc: 0.611



HBox(children=(FloatProgress(value=0.0, description='eval', max=40.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='eval', max=20.0, style=ProgressStyle(description_width='i…

Before:
	loss_val: 0.883 / acc_val: 0.611
	loss_test: 0.873 / acc_tes: 0.610
