本笔记包含以下几个模块：

1. DataModule：实现数据模块，能够分批加载训练、验证数据
2. SAGELightning：继承LightningModule的模块，用于sage的信息传递
3. CrossEntropyLoss：定义损失函数，graphsage里的无监督训练
4. UnsuptervisedClassfication：无监督训练，产出emb后进行下游的分类任务


## 1. data loader

In [22]:
import dgl
import torch as th
from dgl.data import AsNodePredDataset
import numpy as np 

In [2]:
def load_cora(data_dir='../graph_dgl/cora_csv/', split_ratio=(0.5, 0.2, 0.3)):
    """Load the Cora graph from a DGL CSV dataset directory.

    Args:
        data_dir: Directory passed to ``dgl.data.CSVDataset``. Defaults to
            the path that was previously hard-coded in this function.
        split_ratio: ``(train, val, test)`` fractions passed to
            ``AsNodePredDataset``. Defaults to the previously hard-coded
            split.

    Returns:
        A ``(g, num_classes)`` tuple: the first graph of the dataset with its
        node data ``feat``/``label`` renamed to ``features``/``labels``, and
        the class count reported by the dataset.
    """
    raw_dataset = dgl.data.CSVDataset(data_dir)
    dataset = AsNodePredDataset(raw_dataset, split_ratio=split_ratio)
    g = dataset[0]
    # Use the same node-data keys as load_reddit so both graphs can be
    # consumed interchangeably.
    # NOTE(review): the batch schema printed later in this notebook shows
    # 'features' as int64 for this dataset - confirm whether a float cast is
    # needed before feeding a model.
    g.ndata["features"] = g.ndata.pop("feat")
    g.ndata["labels"] = g.ndata.pop("label")
    return g, dataset.num_classes

In [3]:
def load_reddit(self_loop=True, raw_dir='~/.dgl/'):
    """Load the Reddit graph through ``dgl.data.RedditDataset``.

    Args:
        self_loop: Forwarded unchanged to ``RedditDataset``.
        raw_dir: Forwarded unchanged to ``RedditDataset``.

    Returns:
        A ``(graph, num_classes)`` tuple: the first graph of the dataset with
        its node data ``feat``/``label`` renamed to ``features``/``labels``,
        and the class count reported by the dataset.
    """
    # Function-scope import: within the visible code RedditDataset is only
    # referenced here.
    from dgl.data import RedditDataset

    dataset = RedditDataset(self_loop=self_loop, raw_dir=raw_dir)
    graph = dataset[0]
    # Rename the node-data keys one pair at a time, in the original order.
    for old_key, new_key in (("feat", "features"), ("label", "labels")):
        graph.ndata[new_key] = graph.ndata.pop(old_key)
    return graph, dataset.num_classes

In [4]:
# Build the working graph and its class count from the Cora CSV dataset.
g, n_classes = load_cora()

Done loading data from cached files.


In [5]:
n_edges = g.num_edges()
# Map every edge ID to the ID of its presumed reverse edge: IDs in the first
# half point into the second half and vice versa.
# PyTorch is imported in this file as `th`, so use that alias here.
# NOTE(review): this pairing is only valid if the graph stores each edge and
# its reverse at IDs i and i + n_edges // 2, which requires an even edge
# count. The run recorded in this notebook reports n_edges = 5429 (odd), so
# confirm the graph really has that layout (e.g. add reverse edges first).
reverse_eids = th.cat([
    th.arange(n_edges // 2, n_edges),
    th.arange(0, n_edges // 2)])

In [6]:
# Display the edge count of the loaded graph.
n_edges

5429

In [7]:
# Display the edge-ID -> reverse-edge-ID mapping.
reverse_eids

tensor([2714, 2715, 2716,  ..., 2711, 2712, 2713])

In [10]:
# Node IDs of each split, read from the boolean masks stored on the graph.
# th.nonzero(..., as_tuple=True) returns one index tensor per dimension; [0]
# takes the only one.
train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
# Test IDs are every node that is in neither the train nor the val mask; the
# complement is used here rather than a 'test_mask' entry.
test_nid = th.nonzero(~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]

In [11]:
def inductive_split(g):
    """Derive the train / validation / test graphs used for inductive training.

    The training graph is the subgraph selected by ``train_mask``, the
    validation graph is the subgraph selected by ``train_mask | val_mask``,
    and the test graph is the input graph itself.

    Returns:
        A ``(train_g, val_g, test_g)`` tuple.
    """
    node_data = g.ndata
    train_mask = node_data["train_mask"]
    train_g = g.subgraph(train_mask)
    # Validation sees the training nodes plus the validation nodes.
    val_g = g.subgraph(train_mask | node_data["val_mask"])
    return train_g, val_g, g

In [15]:
# Fan-out values handed to the neighbour sampler, one entry per layer.
# NOTE(review): which end of the list applies to the layer nearest the seed
# nodes is defined by DGL - confirm [10, 25] is the intended order.
fan_out=[10, 25]
# The int() cast is a no-op for the literal list above; presumably kept so
# fan_out can also arrive as strings (e.g. from CLI arguments) - TODO confirm.
base_sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])

## 2. negative sampler

In [18]:
class NegativeSampler(object):
    """Draw negative destination nodes for a batch of positive edges.

    Destinations are sampled with replacement, with probability proportional
    to ``in_degree ** 0.75`` of the graph passed to the constructor.

    Args:
        g: Graph whose in-degrees define the sampling weights.
        k: Number of negative destinations produced per positive edge.
        neg_share: When true and the batch size is divisible by ``k``, only
            one destination per positive edge is drawn and every run of ``k``
            consecutive positive edges shares the same ``k`` destinations.
        device: Device for the weight tensor; defaults to ``g.device``.
    """

    def __init__(self, g, k, neg_share=False, device=None):
        target_device = g.device if device is None else device
        degrees = g.in_degrees().float().to(target_device)
        self.weights = degrees ** 0.75
        # k answers "how many negatives per positive edge": each source is
        # repeated k times in __call__.
        self.k = k
        self.neg_share = neg_share

    def __call__(self, g, eids):
        """Return ``(src, dst)`` node tensors of the negative edges for ``eids``.

        ``src`` is the source of each positive edge repeated ``k`` times;
        ``dst`` holds the sampled destination node indices, same length.
        """
        heads = g.find_edges(eids)[0]
        num_pos = len(heads)
        num_neg = self.k
        # Sharing is only possible when the batch splits evenly into groups
        # of k; otherwise fall back to independent draws.
        share = self.neg_share and num_pos % num_neg == 0
        draws = num_pos if share else num_pos * num_neg
        tails = self.weights.multinomial(draws, replacement=True)
        if share:
            # (groups, 1, k) -> (groups, k, k): repeat each group's k draws
            # once per positive edge in that group.
            tails = tails.view(-1, 1, num_neg).expand(-1, num_neg, -1).flatten()
        return heads.repeat_interleave(num_neg), tails


In [20]:
sampler = dgl.dataloading.as_edge_prediction_sampler(
    # In unsupervised training the only labels available are the links
    # between nodes, so the neighbour sampler is wrapped for edge prediction.
    # (The original note was cut off mid-sentence here.)
    # NOTE(review): exclude='reverse_id' relies on reverse_eids pairing each
    # edge with its true reverse - confirm that mapping fits this graph.
    # NegativeSampler(g, 1, False): one negative per positive edge, with no
    # sharing of negatives.
            base_sampler, exclude='reverse_id',
            reverse_eids=reverse_eids,
            negative_sampler=NegativeSampler(g, 1, False))

In [28]:
# Edge-wise mini-batch loader on CPU: shuffled batches of 128 edge IDs, the
# last partial batch kept, loaded in the main process (num_workers=0).
# The seed IDs are every edge of g (0 .. num_edges - 1), not a training-only
# subset.
train_dataloder = dgl.dataloading.DataLoader(
            g,
            np.arange(g.num_edges()),
            sampler,
            device=th.device('cpu'),
            batch_size=128,
            shuffle=True,
            drop_last=False,
            num_workers=0)

In [29]:
# Pull a single mini-batch so its structure can be inspected.
one_batch = next(iter(train_dataloder))

In [31]:
# A batch unpacks as: input_nodes, pos_graph, neg_graph, mfgs = batch
#   1. input_nodes - the input nodes
#   2. pos_graph   - the graph of positive samples
#   3. neg_graph   - the graph of negative samples
#   4. mfgs        - the per-layer blocks
# Open question from the original note: how are the blocks told apart as
# positive or negative?
# NOTE(review): they do not appear to be split that way - in the batch
# printed further down both graphs have the same 342 nodes and the last block
# has 342 destination nodes, which suggests the blocks serve the shared node
# set of both graphs. Confirm against the DGL documentation.
# These notes are comments rather than a trailing string literal so that the
# value of this cell is the length itself (the recorded output is 4).
len(one_batch)

4

In [37]:
# Show the last three batch items (positive graph, negative graph, block
# list); the '' only puts a visual gap between the two graphs in the output.
one_batch[1], '', one_batch[2], one_batch[3]

(Graph(num_nodes=342, num_edges=128,
       ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'features': Scheme(shape=(1433,), dtype=torch.int64), 'labels': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}),
 '',
 Graph(num_nodes=342, num_edges=128,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={}),
 [Block(num_src_nodes=845, num_dst_nodes=721, num_edges=1316),
  Block(num_src_nodes=721, num_dst_nodes=342, num_edges=754)])

In [38]:
# Unpack the batch: b0 = input nodes, b1 = positive graph, b2 = negative
# graph, b3 = list of per-layer blocks.
b0, b1, b2, b3 = one_batch

In [40]:
# Shape of the input-node tensor of this batch.
b0.shape

torch.Size([845])

In [74]:
# Node count of the positive graph.
b1.number_of_nodes()

342

In [67]:
# Node-data fields carried by the positive graph.
b1.ndata.keys()

dict_keys(['train_mask', 'val_mask', 'test_mask', 'features', 'labels', '_ID'])

In [59]:
# Shape of the source-node tensor of the positive graph's edges (the first
# element of the edges() pair).
b1.edges()[0].shape

torch.Size([128])

In [64]:
# Destination-node IDs of the negative graph.
b2.dstnodes()

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 1

In [49]:
# (source, destination) node tensors of the negative graph's edges.
b2.edges()

(tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  10,  11,   9,
          12,  13,  14,  15,  16,  17,  10,  18,  19,  20,  21,  22,  23,  24,
          25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,   9,
          38,  39,  40,   4,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,
          51,  52,  53,  54,  55,  56,  10,  57,  58,  59,  60,  61,  62,  63,
          64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
          78,  79,  80,   7,  81,  82,  71,   4,  83,  84,  85,  86,  87,  88,
          89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101,   2,
         102, 103,   4, 104, 105,  10, 106, 101, 107, 108, 109,   4, 110, 111,
         112, 113]),
 tensor([  5, 179, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241,  56,
         242, 101, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254,
         255, 256, 257, 258, 259, 252, 260, 261, 262, 263, 264,  70,  61, 265,
         266, 267, 268, 269, 27