In [1]:
import flow2graph
from pathlib import Path
import networkx as nx

import torch
import dgl
import numpy as np

Using backend: pytorch


In [None]:
path_labels = Path("./Datasets/CTU-13-Extended/labels.json")
datasets = flow2graph.NetflowDataset.load_from_labels(path_labels, 
                                                      window_time_sec=120,
                                                      chunksize=int(1e6))

In [None]:
def remove_noise(g, k=2, min_degree=1, inplace=True):
    g = g if inplace else g.copy()
    if k > 0:
        bt = nx.degree_centrality(g)
        bt = sorted(bt.items() , key=lambda x: x[-1])[::-1]
        g.remove_nodes_from(map(lambda x: x[0], bt[:k]))
    if min_degree > 0:
        to_remove = list(map(lambda d: d[0], filter(lambda d: d[1] < min_degree, g.degree)))
        g.remove_nodes_from(to_remove)
    return g

In [None]:
from tqdm.notebook import tqdm

!mkdir graphs-wo-noise

root = Path('./graphs-wo-noise')
counter = 0
n_labels = None
for dt in tqdm(datasets.values()):
    for (i, df) in tqdm(enumerate(dt), leave=False): 
        if (df.label.values == flow2graph.Label.malicious.value).sum() < 10:
            continue
        df = dt.compute_features(df, normalize=True)
        g = dt.to_graph(df)
#         g = remove_noise(g, k=2)
        nx.write_gpickle(g, path=root/f"{counter:05d}.pkl")
        counter += 1

In [2]:
from dgl.data import DGLDataset
from dgl.data.utils import save_graphs, load_graphs, download

class CTU13Dataset(DGLDataset):
    from functools import cached_property
    
    url = "https://filebin.net/cy19fulrva0t71n3/ctu13.tar.gz?t=jj1h0gl9"
        
    label_normal = flow2graph.Label.normal.value
    label_background = flow2graph.Label.background.value
    label_malicious = flow2graph.Label.malicious.value
    
    def __init__(self, url=None, raw_dir=None, save_dir=None, force_reload=False, verbose=False):
        self._save_dir = Path(save_dir) # (1) temporary fix waiting for: https://github.com/dmlc/dgl/pull/2262
        self._raw_dir = Path(raw_dir)
        
        self._p_raws = sorted(list(self._raw_dir.glob('[0-9]*.pkl')))
        self._p_unraws = list(map(lambda p: self._save_dir/f"{p.name.rstrip('pkl')}nx", self._p_raws))
        
        super().__init__(name='CTU13',
                         url=url, 
                         raw_dir=self._raw_dir, 
                         save_dir=123, # (1)
                         force_reload=force_reload, 
                         verbose=verbose)
    def process(self):
        self._save_dir.mkdir(parents=True, exist_ok=True)
        for i in tqdm(range(len(self._p_raws)), desc='processing', disable=self.verbose):
            rp, wp = self._p_raws[i], self._p_unraws[i]
            
            g = nx.read_gpickle(rp)
            g = dgl.from_networkx(g, node_attrs=['label'], edge_attrs=['features', 'label'])
            
            label_edge, label_node = g.edata.pop('label'), g.ndata.pop('label')
            save_graphs(str(wp), g, labels={
                'edge': label_edge,
                'node': label_node,
            })
            
    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self):
            raise IndexError
        g, labels = load_graphs(str(self._p_unraws[idx]))
        return g[0], labels
    
    def __len__(self):
        return len(self._p_unraws)
        
    def has_cache(self):
        return len(self) == len(list(self._save_dir.glob('[0-9]*.nx')))

In [3]:
from dgl.data.utils import split_dataset
from tqdm.notebook import tqdm

dt = CTU13Dataset(raw_dir='./graphs-noise/', save_dir='/tmp/dgl-test')
dt_train, dt_val, dt_test = split_dataset(dt, shuffle=True)

HBox(children=(FloatProgress(value=0.0, description='processing', max=1729.0, style=ProgressStyle(description_…




In [4]:
len(dt_train), len(dt_val), len(dt_test)

(1383, 172, 174)

In [5]:
import dgl.nn as dglnn
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F

class EdgeToNode(nn.Module):
    def __init__(self, in_features, out_features, non_linear=None):
        super().__init__()
        self.linear_in = nn.Linear(in_features, out_features)
        self.linear_out = nn.Linear(in_features, out_features)
        self.non_linear = non_linear or nn.Identity()
        
    def forward(self, graph: dgl.DGLGraph, h):
        h_in, h_out = self.linear_in(h), self.linear_out(h)
        h_in, h_out = self.non_linear(h_in), self.non_linear(h_out)
        
        with graph.local_scope(): 
            graph.edata['e_in'] = h_in
            graph.edata['e_out'] = h_out
            
            # copying `e_in` edge feature to dst node + aggregating with sum into `n_in`
            graph.update_all(fn.copy_e('e_in', 'n_in'), fn.sum('n_in', 'n_in')) 
            
            # reversing the graph so that src nodes become dst nodes
            r_graph = graph.reverse(copy_ndata=True, copy_edata=True) 
            # copying `e_out` edge feature to src node + aggregating with sum into `n_in`
            r_graph.update_all(fn.copy_e('e_out', 'n_out'), fn.sum('n_out', 'n_out')) 
            
            return torch.tanh(graph.ndata['n_in'] + r_graph.ndata['n_out'])

class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats, aggregator_type='pool')
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, aggregator_type='pool')

    def forward(self, graph, h):    
        h = F.relu(self.conv1(graph, h))
        h = self.conv2(graph, h)
        return h
    
class NodeToEdge(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, non_linear=None):
        super().__init__()
        self.lin_in = nn.Linear(in_features=in_feats, out_features=hid_feats)
        self.lin_out = nn.Linear(in_features=in_feats, out_features=hid_feats)
        self.lin_final = nn.Linear(in_features=hid_feats, out_features=out_feats)
        self.non_linear = non_linear or nn.Identity()
        
    def forward(self, graph, h):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.edata['h'] = torch.zeros((graph.number_of_edges(), h.shape[-1]))
            
            graph.apply_edges(dgl.function.e_add_u('h', 'h', 'h_out'))
            graph.apply_edges(dgl.function.e_add_v('h', 'h', 'h_in'))
            
            h_out = self.lin_out(graph.edata['h_out'])
            h_in = self.lin_in(graph.edata['h_in'])
            
            h_out = self.non_linear(h_out)
            h_in = self.non_linear(h_in)
            
            return self.lin_final(h_out + h_in)

class EdgePredictor(nn.Module):
    def __init__(self, in_features_e, out_features_e=16, hid_features_n=32, out_features_n=16):
        super().__init__()
        
        self.e2n = EdgeToNode(in_features=in_features_e, out_features=out_features_e, non_linear=nn.ReLU())
        self.lin1 = nn.Linear(in_features=out_features_e, out_features=out_features_e)
        self.sage = SAGE(in_feats=out_features_e, hid_feats=hid_features_n, out_feats=out_features_n)
        self.lin2 = nn.Linear(in_features=out_features_n, out_features=out_features_n)
        self.n2y = nn.Linear(in_features=out_features_n, out_features=2)
        self.n2e = NodeToEdge(in_feats=out_features_n, hid_feats=hid_features_n, out_feats=2)
        
    def forward(self, graph, h):    
        h = self.e2n(graph, h)
        h = self.lin1(h)
        h = F.relu(h)
        h = self.sage(graph, h)
        h = F.relu(h)
        h = self.lin2(h)
        
        pred_node = self.n2y(h)
        pred_edge = self.n2e(graph, h)
        
        return pred_node, pred_edge

In [6]:
nb_malicious, nb_tot = 0, 0
nb_malicious_n, nb_tot_n = 0, 0
for (_, labels) in tqdm(dt_train):
    y = labels['edge'] == CTU13Dataset.label_malicious
    nb_malicious += y.sum().item()
    nb_tot += len(y)
    
ratio = nb_malicious/nb_tot
weights = 1.0-torch.tensor([1.0-ratio, ratio])
weights

HBox(children=(FloatProgress(value=0.0, max=1383.0), HTML(value='')))




tensor([0.0075, 0.9925])

In [7]:
# model = EdgePredictor(5, out_features_e=16, hid_features_n=32, out_features_n=16)
loss_test, acc_test_n, acc_test_e = evaluate(model, dt_test)
print(f"Test      : loss:{loss_test:.3f} | acc_n:{acc_test_n:.3f} | acc_e: {acc_test_e:.3f}")       

NameError: name 'evaluate' is not defined

In [8]:
from tqdm.notebook import tqdm

model = EdgePredictor(5, out_features_e=16, hid_features_n=32, out_features_n=16)
criterion = nn.CrossEntropyLoss(weight=weights)

cst_wl_edge = 0.7
assert cst_wl_edge >= 0.0 and cst_wl_edge <= 1.0, "edges weighting loss must belong to [0, 1]"

def compute_loss(y_pred_n, y_pred_e, y_true_n, y_true_e):
    loss_n = criterion(y_pred_n, y_true_n)
    loss_e = criterion(y_pred_e, y_true_e)
    return (1.0 - cst_wl_edge) * loss_n + cst_wl_edge * loss_e

def evaluate(model, dt):
    losses = []
    acc_e, acc_n = 0, 0
    nb_e, nb_n = 0, 0
    # validation loop
    with torch.no_grad():
        model.eval()
        for (g, labels) in tqdm(dt, desc='eval', leave=False):
#             mask_node = labels['node'] != CTU13Dataset.label_background
#             mask_edge = labels['edge'] != CTU13Dataset.label_background
            label_node = (labels['node'] == CTU13Dataset.label_malicious).long()
            label_edge = (labels['edge'] == CTU13Dataset.label_malicious).long()

            ft_edges = g.edata['features']

            pred_node, pred_edge = tuple(map(lambda y: y.squeeze(), model(g, ft_edges)))
            loss = compute_loss(pred_node, pred_edge, label_node, label_edge)
            
            # that's not really the accuracy, but True Positives (malicious)
            acc_e += (pred_edge.max(dim=-1)[1])[label_edge.bool()].sum().item()
            acc_n += (pred_node.max(dim=-1)[1])[label_node.bool()].sum().item()
            nb_e += label_edge.sum().item()
            nb_n += label_node.sum().item()
            
            losses.append(loss.item()) 
    return np.mean(losses), acc_n/nb_e, acc_e/nb_e
    
def train(model, dt_train, dt_val, dt_test, nb_epochs=10, freq_show_loss=10):
    optimizer = torch.optim.Adam(model.parameters())
    
    loss_val, acc_val_n, acc_val_e = evaluate(model, dt_val)
    loss_test, acc_test_n, acc_test_e = evaluate(model, dt_test)
    print("Before train:")
    print(f"\t Validation: loss:{loss_val:.3f} | acc_n:{acc_val_n:.3f} | acc_e: {acc_val_e:.3f}")
    print(f"\t Test      : loss:{loss_test:.3f} | acc_n:{acc_test_n:.3f} | acc_e: {acc_test_e:.3f}")       
    
    for epoch in tqdm(range(nb_epochs), desc='epoch'):
        losses = []
        
        acc_train_e, acc_train_n = 0, 0
        nb_train_e, nb_train_n = 0, 0
        
        model.train()
        
        # training loop
        pb_training = tqdm(dt_train, desc='train', leave=False)
        for idx, (g, labels) in enumerate(pb_training):
#             mask_n = labels['node'] != CTU13Dataset.label_background
#             mask_e = labels['edge'] != CTU13Dataset.label_background
            label_n = (labels['node'] == CTU13Dataset.label_malicious).long()
            label_e = (labels['edge'] == CTU13Dataset.label_malicious).long()
            
            ft_e = g.edata['features']

            # criterion
            pred_n, pred_e = model(g, ft_e)
            pred_n, pred_e = pred_n.squeeze(), pred_e.squeeze()
            loss = compute_loss(y_pred_n=pred_n, y_pred_e=pred_e,
                               y_true_n=label_n, y_true_e=label_e)
            
            # loss monitoring
            acc_train_e += (pred_e.max(dim=-1)[-1])[label_e.bool()].sum().item()
            acc_train_n += (pred_n.max(dim=-1)[-1])[label_n.bool()].sum().item()
            nb_train_e += label_e.sum().item()
            nb_train_n += label_n.sum().item()
            losses.append(loss.item())
            if idx % freq_show_loss == 0:
                pb_training.set_description(f"loss: {np.mean(losses):.2f}, \
                acc_e:{acc_train_e/nb_train_e:.2f}, \
                acc_n:{acc_train_n/nb_train_n:.2f}")
            
            # weights optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        loss_train, acc_train_n, acc_train_e = np.mean(losses), acc_train_n/nb_train_n, acc_train_e/nb_train_e
        loss_val, acc_val_n, acc_val_e = evaluate(model, dt_val)
        print(f"epoch {epoch}: loss_train={loss_train:.3f} loss_val={loss_val:.3f} | "
              f"acc_train_n={acc_train_n:.3f} acc_val_n={acc_val_n:.3f} | "
              f"acc_train_e={acc_train_e:.3f} acc_val_e={acc_val_e:.3f}") 

    loss_val, acc_val_n, acc_val_e = evaluate(model, dt_val)
    loss_test, acc_test_n, acc_test_e = evaluate(model, dt_test)
    print("After train:")
    print(f"\t Validation: loss:{loss_val:.3f} | acc_n:{acc_val_n:.3f} | acc_e: {acc_val_e:.3f}")
    print(f"\t Test      : loss:{loss_test:.3f} | acc_n:{acc_test_n:.3f} | acc_e: {acc_test_e:.3f}")       
    
train(model, dt_train, dt_val, dt_test, nb_epochs=10)

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='eval', max=174.0, style=ProgressStyle(description_width='…

Before train:
	 Validation: loss:0.635 | acc_n:0.000 | acc_e: 0.000
	 Test      : loss:0.638 | acc_n:0.000 | acc_e: 0.000


HBox(children=(FloatProgress(value=0.0, description='epoch', max=10.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 0: loss_train=0.101 loss_val=0.058 | acc_train_n=0.611 acc_val_n=0.019 | acc_train_e=0.883 acc_val_e=0.926


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 1: loss_train=0.044 loss_val=0.048 | acc_train_n=0.802 acc_val_n=0.019 | acc_train_e=0.986 acc_val_e=0.930


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 2: loss_train=0.035 loss_val=0.052 | acc_train_n=0.823 acc_val_n=0.019 | acc_train_e=0.990 acc_val_e=0.912


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 3: loss_train=0.033 loss_val=0.051 | acc_train_n=0.840 acc_val_n=0.019 | acc_train_e=0.989 acc_val_e=0.941


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 4: loss_train=0.029 loss_val=0.032 | acc_train_n=0.857 acc_val_n=0.020 | acc_train_e=0.992 acc_val_e=0.954


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 5: loss_train=0.024 loss_val=0.035 | acc_train_n=0.870 acc_val_n=0.020 | acc_train_e=0.993 acc_val_e=0.950


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 6: loss_train=0.023 loss_val=0.036 | acc_train_n=0.880 acc_val_n=0.020 | acc_train_e=0.994 acc_val_e=0.948


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 7: loss_train=0.021 loss_val=0.030 | acc_train_n=0.887 acc_val_n=0.020 | acc_train_e=0.995 acc_val_e=0.977


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 8: loss_train=0.019 loss_val=0.026 | acc_train_n=0.899 acc_val_n=0.021 | acc_train_e=0.994 acc_val_e=0.947


HBox(children=(FloatProgress(value=0.0, description='train', max=1383.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

epoch 9: loss_train=0.016 loss_val=0.034 | acc_train_n=0.915 acc_val_n=0.021 | acc_train_e=0.995 acc_val_e=0.936



HBox(children=(FloatProgress(value=0.0, description='eval', max=172.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='eval', max=174.0, style=ProgressStyle(description_width='…

After train:
	 Validation: loss:0.034 | acc_n:0.021 | acc_e: 0.936
	 Test      : loss:0.020 | acc_n:0.019 | acc_e: 0.987


In [10]:
from sklearn.metrics import classification_report

g, labels = dt_test[5]
ft_edges = g.edata['features']
mask = labels['edge'] != CTU13Dataset.label_background
y_true_n = labels['node'] == CTU13Dataset.label_malicious
y_true_e = labels['edge'] == CTU13Dataset.label_malicious
with torch.no_grad():
    y_pred_n, y_pred_e = model(g, ft_edges)
    _, y_pred_n = y_pred_n.max(dim=-1)
    _, y_pred_e = y_pred_e.max(dim=-1)
print('Eval Nodes Classification:')
print(classification_report(y_true_n, y_pred_n, target_names=['normal', 'malicious']))
print('Eval Edges Classification:')
print(classification_report(y_true_e, y_pred_e, target_names=['normal', 'malicious']))

Eval Nodes Classification:
              precision    recall  f1-score   support

      normal       1.00      0.99      1.00       659
   malicious       0.17      1.00      0.29         1

    accuracy                           0.99       660
   macro avg       0.58      1.00      0.64       660
weighted avg       1.00      0.99      1.00       660

Eval Edges Classification:
              precision    recall  f1-score   support

      normal       1.00      0.99      1.00      1044
   malicious       0.88      1.00      0.94        67

    accuracy                           0.99      1111
   macro avg       0.94      1.00      0.97      1111
weighted avg       0.99      0.99      0.99      1111



In [10]:
b = []
for i, (_, l) in enumerate(dt_test):
    ml = (l['edge']==CTU13Dataset.label_malicious).sum().item()
    b.append((ml, i))
sorted(b)[::-1]

[(866, 131),
 (861, 5),
 (835, 168),
 (824, 159),
 (823, 31),
 (820, 28),
 (795, 114),
 (767, 4),
 (764, 52),
 (583, 151),
 (349, 120),
 (305, 133),
 (263, 104),
 (188, 41),
 (187, 32),
 (158, 37),
 (108, 165),
 (91, 164),
 (90, 100),
 (89, 67),
 (89, 42),
 (88, 71),
 (84, 44),
 (82, 103),
 (76, 149),
 (74, 153),
 (71, 25),
 (70, 160),
 (70, 129),
 (69, 132),
 (69, 102),
 (69, 87),
 (69, 22),
 (67, 13),
 (66, 171),
 (66, 134),
 (66, 116),
 (66, 101),
 (65, 148),
 (65, 112),
 (65, 69),
 (65, 65),
 (64, 162),
 (64, 156),
 (64, 115),
 (64, 113),
 (64, 94),
 (64, 76),
 (64, 73),
 (64, 38),
 (64, 14),
 (63, 163),
 (63, 143),
 (63, 92),
 (63, 90),
 (63, 79),
 (63, 36),
 (63, 7),
 (63, 2),
 (63, 1),
 (62, 147),
 (62, 99),
 (62, 89),
 (62, 62),
 (62, 57),
 (62, 27),
 (61, 166),
 (61, 85),
 (61, 48),
 (61, 19),
 (61, 16),
 (60, 173),
 (60, 140),
 (60, 139),
 (60, 110),
 (60, 95),
 (60, 91),
 (60, 15),
 (60, 3),
 (59, 98),
 (59, 53),
 (58, 157),
 (57, 105),
 (57, 81),
 (56, 60),
 (55, 55),
 (54,

In [None]:
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

In [None]:
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = th.optim.Adam(model.parameters())

In [None]:
from tqdm.notebook import tqdm
model.train()
t = tqdm(range(10))
for epoch in t:
    logits = model(graph, node_features)
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    opt.zero_grad()
    loss.backward()
    opt.step()
    t.set_description(f"val_acc: {acc:0.3f}, loss: {loss.item():0.3f}")

In [None]:
evaluate(model, graph, node_features, node_labels, valid_mask)

In [None]:
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.linear(in_features * 2, out_classes)
    def apply_edges(self, edges):    
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(th.cat([h_u, h_v], 1))
        return {'score': score}
    def forward(self, graph, h):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

In [None]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()
    def forward(self, g, x):
        h = self.sage(g, x)
        h = self.pred(g, h)
        h = th.sigmoid(h)
        return h

In [None]:
src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# make it symmetric
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
# synthetic node and edge features, as well as edge labels
edge_pred_graph.ndata['feature'] = th.randn(100, 10)
edge_pred_graph.edata['feature'] = th.randn(1000, 10)
edge_pred_graph.edata['label'] = th.randn(1000)
# synthetic train-validation-test splits
edge_pred_graph.edata['train_mask'] = th.zeros(1000, dtype=th.bool).bernoulli(0.6)

In [None]:
node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
train_mask = edge_pred_graph.edata['train_mask']
model = Model(10, 32, 16)
opt = th.optim.Adam(model.parameters())
t = tqdm(range(100))
for epoch in t:
    pred = model(edge_pred_graph, node_features)
    loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    t.set_description(f"loss: {loss.item():.03f}")