In [3]:
import torch, dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from dgl.data import DGLDataset, CoraGraphDataset
import dgl.function as fn
import re
import numpy as np
import pandas as pd
import os.path as osp
from colorama import Fore
from glob import glob


In [4]:
class SAGEConv(nn.Module):
    r"""Single GraphSAGE layer with a mean neighbour aggregator.

    For every node ``v`` the layer concatenates the node's own feature with
    the mean of the features sent to it along its incoming edges, then
    projects the result with one linear layer::

        out_v = Linear([h_v || mean_{u -> v} h_u])

    Parameters
    ----------
    in_feat : int
        Size of each input node feature.
    out_feat : int
        Size of each output node feature.
    """

    def __init__(self, in_feat: int, out_feat: int):
        super().__init__()
        # Own feature and aggregated neighbour feature are concatenated,
        # so the projection sees twice the input width.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        """Apply the layer to graph ``g``.

        Parameters
        ----------
        g : DGLGraph
            The input graph.
        h : Tensor
            Node features, one row per node (presumably shaped
            ``(num_nodes, in_feat)`` -- concatenation is along dim 1).

        Returns
        -------
        Tensor
            Output of ``self.linear`` applied to ``[h || h_N]``.
        """
        # local_scope() discards the temporary 'h' / 'h_N' fields on exit,
        # so the caller's graph is left untouched.
        with g.local_scope():
            g.ndata['h'] = h
            # Each edge copies its source's 'h' into message 'm'; each node
            # averages the incoming messages into 'h_N'.
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_N'))
            neighbour_mean = g.ndata['h_N']
            return self.linear(torch.cat((h, neighbour_mean), dim=1))


class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE network: SAGEConv -> ReLU -> SAGEConv.

    Parameters
    ----------
    in_feats : int
        Input node-feature size.
    h_feats : int
        Hidden size between the two convolution layers.
    num_classes : int
        Output size of the second layer (presumably one score per class).
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return per-node outputs of the second layer for graph ``g``.

        No softmax is applied; the result is the raw output of ``conv2``.
        """
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

In [5]:
class DkkdGraphDataset(DGLDataset):
    """Collection of DKKD node-classification graphs, built on demand.

    Each graph is stored as three sibling files sharing a common stem
    under ``root``:

    * ``<stem>.edges.csv`` -- edge list with ``src`` and ``dst`` columns;
    * ``<stem>.nfeat.npy`` -- node-feature matrix, one row per node;
    * ``<stem>.idx.csv``   -- node table with an ``Id`` column (must be
      0..N-1 in order) and a ``label`` column.

    Graphs are constructed lazily in ``__getitem__``; ``process`` is a no-op.

    Parameters
    ----------
    root : str
        Directory containing the files above.
    """

    def __init__(self, root: str = 'dataset/graph_DKKD'):
        # Assign our attributes before DGLDataset.__init__, which triggers
        # self.process(); any future process() logic can then rely on them.
        self.root = root
        # Sort so that index i maps to the same file on every run -- glob()
        # returns files in filesystem-dependent order.
        self.edges = sorted(glob(osp.join(root, '*.edges.csv')))
        self.nodes_feat = sorted(glob(osp.join(root, '*.nfeat.npy')))
        self.nodes_label = sorted(glob(osp.join(root, '*.idx.csv')))
        super().__init__(name='graph_DKKD')

    def __len__(self):
        """Number of graphs, i.e. number of ``*.edges.csv`` files found."""
        return len(self.edges)

    @staticmethod
    def _sibling_path(edge_path: str, suffix: str) -> str:
        """Replace the trailing ``.edges.csv`` of ``edge_path`` with ``suffix``."""
        return re.sub(r'\.edges\.csv$', suffix, edge_path)

    def __getitem__(self, i):
        """Build the ``i``-th graph from its edge, feature and label files.

        Returns
        -------
        DGLGraph
            Bidirected graph with node data ``feat`` (float32), ``label``
            (int64 category codes), boolean ``train_mask`` (all True) and
            ``val_mask`` / ``test_mask`` (all False).

        Raises
        ------
        ValueError
            If the ``Id`` column is not exactly 0..N-1 in order, or the
            feature matrix row count differs from the number of nodes.
        """
        edge_path = self.edges[i]
        feat_path = self._sibling_path(edge_path, '.nfeat.npy')
        label_path = self._sibling_path(edge_path, '.idx.csv')

        nodes_feat = np.load(feat_path)
        node_table = pd.read_csv(label_path, encoding='utf-8')

        # Node IDs must equal their row positions so that row k of the
        # features, row k of the label table and node k in the edge list
        # all refer to the same node.
        node_ids = node_table['Id'].to_numpy()
        n_nodes = len(node_ids)
        if not np.array_equal(node_ids, np.arange(n_nodes)):
            raise ValueError(
                f'{label_path}: Id column must be 0..{n_nodes - 1} in order')
        if nodes_feat.shape[0] != n_nodes:
            raise ValueError(
                f'{feat_path}: {nodes_feat.shape[0]} feature rows '
                f'but {n_nodes} nodes in {label_path}')

        # Category codes are 0..K-1; missing labels would become -1.
        labels = node_table['label'].astype('category').cat.codes.to_list()

        edge_table = pd.read_csv(edge_path, encoding='utf-8')
        # Explicit int64 also covers a header-only CSV (object dtype).
        src = torch.from_numpy(edge_table['src'].to_numpy(dtype=np.int64))
        dst = torch.from_numpy(edge_table['dst'].to_numpy(dtype=np.int64))

        g = dgl.graph((src, dst), num_nodes=n_nodes)
        # to_bidirected does not copy node data by default, so node fields
        # are attached after the conversion.
        g = dgl.to_bidirected(g)
        # nn.Linear weights are float32; cast so features saved as float64
        # do not cause a dtype mismatch in the model.
        g.ndata['feat'] = torch.from_numpy(nodes_feat).float()
        g.ndata['label'] = torch.tensor(labels, dtype=torch.long)
        # Every node is used for training; no validation/test split is made.
        g.ndata['train_mask'] = torch.ones(n_nodes, dtype=torch.bool)
        g.ndata['val_mask'] = torch.zeros(n_nodes, dtype=torch.bool)
        g.ndata['test_mask'] = torch.zeros(n_nodes, dtype=torch.bool)
        return g

    def process(self):
        """No-op: graphs are read lazily in ``__getitem__``."""

In [6]:
# Load every graph under the default root. glob() silently returns an empty
# list for a wrong path / working directory, so fail loudly here instead of
# much later with an IndexError.
dataset = DkkdGraphDataset()
assert len(dataset) > 0, f'no *.edges.csv files found under {dataset.root!r}'