In [1]:
import os
import dgl
import random
import time
import torch
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import pandas as pd
from networkx.algorithms import bipartite
from pathlib import Path

Using backend: pytorch


In [2]:
data_path = Path(os.path.join(os.getcwd(), '../clean_data/human'))

In [3]:
class FileNxData(object):
    """In tis BiGraph, we think (src_nodes == genes) and (tgt_nodes == cells)

    """
    def __init__(self, species, tissue, number, data_path):
        # put species, tissue, number into one tuple for easy param passing
        self.target_file_tuple = (species, tissue, str(number))
        self.G = nx.DiGraph()
        
        # collect location of *_data.csv and *_celltype.csv
        cell_type_file, data_file = self.get_file_path(data_path)
        
        # add gene nodes
        src_list = self.get_src(data_file)
        self.add_src_nodes(src_list, cell_type_file, data_file)
        
        # add cell nodes (cell has type attribute called 'type_name')
        tgt_list = self.get_tgt(cell_type_file)
        self.add_tgt_nodes(tgt_list, cell_type_file, data_file)
        
        # add edges between gene and cell with weight in form of numpy array (weight > 0)
        weight_mat = self.get_weight(data_file)
        self.add_edges_with_weight(src_list, tgt_list, weight_mat)
        
        self.G = self.delete_isolated_nodes()

    def get_file_path(self, data_path):
        species, tissue, number = self.target_file_tuple
        celltype_file_name = f'{species}_{tissue}{number}_celltype.csv'
        data_file_name = f'{species}_{tissue}{number}_data.csv'
        celltype_file = data_path / celltype_file_name
        data_file = data_path / data_file_name
        return celltype_file, data_file

    def get_tgt(self, cell_type_file):
        species, tissue, number = self.target_file_tuple
        cell_name_df = pd.read_csv(cell_type_file, dtype=str, usecols=[0])
        cell_name_list = cell_name_df.values[:,0].tolist()
        cell_name_list = [species + '_' + tissue + '_' + number + '_' + cell_name for cell_name in cell_name_list]
        return cell_name_list

    def get_src(self, data_file):
        gene_name_df = pd.read_csv(data_file, dtype=str, usecols=[0])
        gene_name_list = gene_name_df.values[:,0].tolist()
        return gene_name_list

    def get_weight(self, data_file):
        gene_cell_mat_df = pd.read_csv(data_file)
        gene_cell_mat = gene_cell_mat_df.values[:,1:]
        return gene_cell_mat

    def add_src_nodes(self, src_list, cell_type_file, data_file):
        self.G.add_nodes_from(src_list, bipartite=0)
    
    def add_tgt_nodes(self, tgt_list, cell_type_file, data_file):
        self.G.add_nodes_from(tgt_list, bipartite=1)
        # add cell_type attribute
        species, tissue, number = self.target_file_tuple
        cell_df = pd.read_csv(cell_type_file, usecols=[0,1])
        # change Cell | Cell_type two columsn of csv into {Cell: Cell_type} dict
        cell_dict = cell_df.set_index('Cell').T.to_dict('records')[0]
        # add 'type_name' attributes based on 'Cell_type' in the dict
        for key, value in cell_dict.items():
            self.G.nodes[species + '_' + tissue + '_' + number + '_' + key]['type_name'] = value
    
    def add_edges_with_weight(self,src_list, tgt_list, weight_mat):
        for i in range(len(src_list)):
            for j in range(len(tgt_list)):
                # if weight == 0, we ignore this edge
                if (weight_mat[i][j] > 0):
                    self.G.add_edge(src_list[i], tgt_list[j], weight=weight_mat[i][j])
                    # if bidirected, we need to add this line, else not
                    #self.G.add_edge(tgt_list[j], src_list[i], weight=weight_mat[i][j])

    def delete_isolated_nodes(self):
        self.G.remove_nodes_from(list(nx.isolates(self.G)))
        return self.G
    
    def print_nx_graph(self):
        # Separate by group
        l, r = nx.bipartite.sets(self.G)
        pos = {}
        # Update position for node from each group
        pos.update((node, (1, index)) for index, node in enumerate(l))
        pos.update((node, (2, index)) for index, node in enumerate(r))
        nx.draw(self.G, with_labels=True, pos=pos)
        plt.show()

In [4]:
class SpeciesNxData(object):
    def __init__(self, filebigraph_list):
        graph_list = [graph.G for graph in filebigraph_list]
        self.G = nx.compose_all(graph_list)
        self.node_index_dict = self.save_node_index_dict()
        self.type_name_dict = self.save_type_name_dict()
        self.gene_name_dict, self.cell_name_dict = self.save_gene_cell_name_dict()
        self.gene_num, self.cell_num = self.check_gene_cell_num()
        self.get_gene_embed()
        self.get_cell_embed()
        self.dgl_G = self.nx_to_dgl()
    
    def save_node_index_dict(self):
        # node_index_dict : {"C_1": 1, "10001": 2}
        gene_set = {n for n, d in self.G.nodes(data=True) if d["bipartite"] == 0}
        cell_set = set(self.G) - gene_set
        nx_node_name_list = list(gene_set) + list(cell_set)
        
        node_index_dict = {}
        for i in range(len(nx_node_name_list)):
            node_index_dict[nx_node_name_list[i]] = i
        return node_index_dict
    
    def save_type_name_dict(self):
        # type_name_dict : {"C_1" : "T Cell"}
        type_name_dict = nx.get_node_attributes(self.G, "type_name")
        return type_name_dict
         
    def save_gene_cell_name_dict(self):
        # seperately store gene_name dict and cell_name dict
        # two dict : {"C_1": 0}, {"10001": 1}
        gene_set = {n for n, d in self.G.nodes(data=True) if d["bipartite"] == 0}
        cell_set = set(self.G) - gene_set
        gene_name_dict = {gene_name: idx for idx, gene_name in enumerate(list(gene_set))}
        cell_name_dict = {cell_name: idx for idx, cell_name in enumerate(list(cell_set))}
        return gene_name_dict, cell_name_dict
    
    def check_gene_cell_num(self):
        gene_num = len(self.gene_name_dict)
        cell_num = len(self.cell_name_dict)
        return gene_num, cell_num
    
    def get_gene_embed(self):
        for n in self.gene_name_dict:
            gene_embed = torch.zeros(self.gene_num)
            gene_embed[self.gene_name_dict[n]] = 1
            self.G.nodes[n]['embed'] = gene_embed

    def get_cell_embed(self):
        cell_embed = {}
        for n in self.cell_name_dict:
            cell_embed = torch.zeros(self.gene_num)
            for m in self.gene_name_dict:
                #if self.G.get_edge_data(m, n, default=0):
                #    cell_embed[self.gene_name_dict[m]] = 1
                if self.G.get_edge_data(m, n, default=0) and self.G[m][n]['weight'] > 0:
                    cell_embed[self.gene_name_dict[m]] = 1
            self.G.nodes[n]['embed'] = cell_embed
    
    def nx_to_dgl(self):
        dgl_G = dgl.from_networkx(self.G, edge_attrs=['weight'], node_attrs=['embed'])
        dgl_G = self.make_bidirected(dgl_G)
        return dgl_G
    
    def make_bidirected(self, dgl_G):
        src_node_tensor = dgl_G.edges()[0]
        tgt_node_tensor = dgl_G.edges()[1]
        weight_tensor = dgl_G.edata['weight']
        dgl_G.add_edges(tgt_node_tensor, src_node_tensor)
        return dgl_G
    
    def print_embed(self):
        print(self.gene_name_dict)
        for n, _ in self.G.nodes(data=True):
            print(n, self.G.nodes[n]['embed'])
    
    def print_nx_graph(self):
        # Separate by group
        l, r = nx.bipartite.sets(self.G)
        pos = {}
        # Update position for node from each group
        pos.update((node, (1, index)) for index, node in enumerate(l))
        pos.update((node, (2, index)) for index, node in enumerate(r))
        nx.draw(self.G, with_labels=True, pos=pos)
        plt.show()
    

In [5]:
start1 = time.time()
graph1 = FileNxData('human', 'lung', '8426', data_path)
end1 = time.time()
print('time1 : {}'.format(end1-start1))

#graph1.print_nx_graph()
start2 = time.time()
graph2 = FileNxData('human', 'lung', '6022', data_path)
end2 = time.time()
print('time2 : {}'.format(end2-start2))
#graph2.print_nx_graph()
#start3 = time.time()
#graph3 = SpeciesNxData([graph1, graph2])
#end3 = time.time()
#print('time3 : {}'.format(end3-start3))
#graph3.print_nx_graph()
#dgl_G = graph3.nx_to_dgl()
'''
print('--------------------dgl_G---------------------')
print(dgl_G)
print(dgl_G.nodes())
print(dgl_G.edges())
print(dgl_G.edata['weight'])
print(dgl_G.ndata['embed'])
print('--------------------dgl_G---------------------')
'''

time1 : 267.43579149246216
time2 : 159.8538544178009


"\nprint('--------------------dgl_G---------------------')\nprint(dgl_G)\nprint(dgl_G.nodes())\nprint(dgl_G.edges())\nprint(dgl_G.edata['weight'])\nprint(dgl_G.ndata['embed'])\nprint('--------------------dgl_G---------------------')\n"

In [6]:
class SpeciesDGLData(object):
    def __init__(self, nx_data):
        self.G = nx_data.dgl_G
        self.node_index_dict = nx_data.node_index_dict
        self.type_name_dict = nx_data.type_name_dict
        self.type_index_dict = self.name_to_index(self.type_name_dict)
        self.gene_name_dict, self.cell_name_dict = nx_data.gene_name_dict, nx_data.cell_name_dict
        self.gene_num, self.cell_num = nx_data.gene_num, nx_data.cell_num
        
        self.feature = self.extract_feature()
        self.label = self.extract_label()
        self.train_index, self.valid_index, self.test_index = self.split_cell_nodes()
        
    def name_to_index(self, type_name_dict):
        type_index_dict = {}
        type_set = set()
        for key, value in type_name_dict.items():
            type_set |= {value}
        type_list = list(type_set)
        for key, value in type_name_dict.items():
            type_index_dict[key] = type_list.index(value)
        return type_index_dict
    
    def extract_feature(self):
        feature_tensor = self.G.ndata['embed']
        return feature_tensor
        
    def extract_label(self):
        label_list = []
        for key, _ in self.node_index_dict.items():
            if key in self.type_index_dict.keys():
                label_list += [self.type_index_dict[key]]
            else:
                label_list += [-1]
        label_tensor = torch.tensor(label_list)
        return label_tensor
        
    def split_cell_nodes(self):
        # only calculate loss on cells
        # TODO: here we suppose gene name is smaller than cell name
        cell_dataset = self.G.nodes()[self.gene_num:]
        train_subset, valid_subset, test_subset = dgl.data.utils.split_dataset(cell_dataset, 
                                                                               shuffle=True, 
                                                                               frac_list=[0.7, 0.2, 0.1])
        train_index = train_subset[:]
        valid_index = valid_subset[:]
        test_index = test_subset[:]
        return train_index, valid_index, test_index
    
    def load_data(self):
        return self.G, self.feature, self.label, self.train_index, self.valid_index, self.test_index
    
    def print_dgl_graph(self):
        fig, ax = plt.subplots()
        nx.draw(self.G.to_networkx(), ax=ax)
        plt.show()
        
    def save_graph(self):
        dgl.data.utils.save_graphs("./dgl_data.bin", self.G)
        
    def load_graph(self):
        glist, _ = dgl.data.utils.load_graphs("./data.bin")

In [33]:
start = time.time()
nx_data = SpeciesNxData([graph1, graph2])
end = time.time()
print(end - start)
#print(nx_data.type_name_dict)
dgl_data = SpeciesDGLData(nx_data)
#print(dgl_data.G.edges())
#dgl_data.G.edata['weight']
#print(dgl_data.test_mask)
#print(dgl_data.train_mask)
#print(dgl_data.G.nodes())
#print(dgl_data.node_index_dict)
#print(dgl_data.G.srcdata)
#print(dgl_data.G.dstdata)
#dgl_data.extract_label()
#dgl_data.print_dgl_graph()
print(dgl_data.train_index)
print(dgl_data.test_index)
print(dgl_data.valid_index)
#dgl_data.save_graph()
#sg = dgl.sampling.select_topk(g=dgl_data.G, k=2, weight='weight', nodes=dgl_data.G.nodes())
#fig, ax = plt.subplots()
#nx.draw(sg.to_networkx(), ax=ax)
#plt.show()
#dgl_data.G.nodes()[:dgl_data.gene_num]

2.8133392333984375e-05
tensor([18491, 22958, 25156,  ..., 17796, 19552, 20332])
tensor([25303, 22335, 16544,  ..., 26906, 16759, 27449])
tensor([27698, 26187, 24146,  ..., 22476, 21708, 19343])


In [None]:
import pickle

file = open('nx_data.pkl', 'wb')
pickle.dump(nx_data, file)

In [36]:
file = open('dgl_data.pkl', 'wb')
pickle.dump(dgl_data, file)

In [41]:
start = time.time()
file1 = open('dgl_data.pkl', 'rb')
file2 = open('nx_data.pkl', 'rb')
dgl_data = pickle.load(file1)
#nx_data = pickle.load(file2)
end = time.time()
print(end - start)

7.065536022186279


30587

In [8]:
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn.pytorch.conv import SAGEConv
from torch.utils.data import DataLoader
import tqdm

In [29]:
class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        # block is the bipartite graph we sample. Here it is used for message passing.
        # x is node feature
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h_dst = h[:block.number_of_dst_nodes()]
            h = layer(block, (h, h_dst))
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        # inference 用于评估测试，针对的是完全图
        # 目前会出现重复计算的问题，优化方案还在 to do list 上
        nodes = torch.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            y = torch.zeros(g.number_of_nodes(), 
                         self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]
                h = x[input_nodes].to(device)
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[start:end] = h.cpu()
            x = y
        return y

In [30]:
class NeighborSampler(object):
    def __init__(self, g, fanouts):
        """
        g 为 DGLGraph；
        fanouts 为采样节点的数量，实验使用 10,25，指一阶邻居采样 10 个，二阶邻居采样 25 个。
        """
        self.g = g
        self.fanouts = fanouts

    def sample_blocks(self, seeds):
        seeds = torch.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts: 
            # sample_neighbors 可以对每一个种子的节点进行邻居采样并返回相应的子图
            # replace=True 表示用采样后的邻居节点代替所有邻居节点
            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
            # 将图转变为可以用于消息传递的二部图（源节点和目的节点）
            # 其中源节点的 id 也可能包含目的节点的 id（原因上面说了）
            # 转变为二部图主要是为了方便进行消息传递
            block = dgl.to_block(frontier, seeds)
            # 获取新图的源节点作为种子节点，为下一层作准备
            # 之所以是从 src 中获取种子节点，是因为采样操作相对于聚合操作来说是一个逆向操作
            seeds = block.srcdata[dgl.NID]
            # 把这一层放在最前面。
            # PS：如果数据量大的话，插入操作是不是不太友好。
            blocks.insert(0, block)
        return blocks

In [73]:
  
class Runner(object):
    def __init__(self, dgl_data):
        # init and load data from dgl_data object
        self.g, self.features, self.labels, self.train_index, self.valid_index, self.test_index = dgl_data.load_data()        
        self.batch_size = 512
        self.epochs = 300
        self.device = 'cpu'
        
        # init the model
        feat_size = self.features.shape[1]
        n_hidden = 512
        n_classes = torch.max(self.labels).item()+1
        n_layers = 1
        activation = F.relu
        dropout = 0.5
        self.model = GraphSAGE(feat_size, n_hidden, n_classes, n_layers, activation, dropout).to(self.device)
    
        # init the optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)

        # init sampler and dataloader to produce batch shape data
        fan_out = '5'
        num_workers = 4
        # Create sampler
        self.sampler = NeighborSampler(self.g, [int(fanout) for fanout in fan_out.split(',')])

        # Create PyTorch DataLoader for constructing blocks
        self.dataloader = DataLoader(
            dataset=self.train_index.numpy(),
            batch_size=self.batch_size,
            collate_fn=self.sampler.sample_blocks,
            shuffle=True,
            drop_last=False,
            num_workers=num_workers)
    
    def compute_loss(self, logits, labels):
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, labels)
        return loss
    
    def compute_acc(self, pred, labels):
        return (torch.argmax(pred, dim=1) == labels).float().sum() / len(pred)
 
    '''
    def evaluate(self, model, g, inputs, labels, val_index, batch_size, device):
        model.eval()
        with torch.no_grad():
            pred = model.inference(g, inputs, batch_size, device)
        model.train()
        return self.compute_acc(pred[val_index], labels[val_index])
    '''
    
    def evaluate(self, model, g, features, labels, index):
        model.eval()
        with torch.no_grad():
            dataloader = DataLoader(
                dataset=self.valid_index.numpy(),
                batch_size=self.batch_size,
                collate_fn=self.sampler.sample_blocks,
                shuffle=True,
                drop_last=False,
                num_workers=4)
            
            for step, blocks in enumerate(dataloader):
                tic_step = time.time()
                
                input_nodes = blocks[0].srcdata[dgl.NID]
                seeds = blocks[-1].dstdata[dgl.NID]

                # Load the input features as well as output labels
                batch_inputs, batch_labels = self.load_subtensor(g, labels, seeds, input_nodes, self.device)

                # Compute loss and prediction
                batch_preds = self.model(blocks, batch_inputs)
                
                acc = self.compute_acc(batch_pred, batch_labels)

            return acc
    
    
    def train(self):
        # Training loop
        avg = 0
        iter_output = []
        for epoch in range(self.epochs):
            tic = time.time()

            for step, blocks in enumerate(self.dataloader):
                tic_step = time.time()
                
                input_nodes = blocks[0].srcdata[dgl.NID]
                seeds = blocks[-1].dstdata[dgl.NID]

                # Load the input features as well as output labels
                batch_inputs, batch_labels = self.load_subtensor(self.g, self.labels, seeds, input_nodes, self.device)

                # Compute loss and prediction
                batch_pred = self.model(blocks, batch_inputs)
                loss = self.compute_loss(batch_pred, batch_labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                iter_output.append(len(seeds) / (time.time() - tic_step))
                if step % 2 == 0:
                    acc = self.compute_acc(batch_pred, batch_labels)
                    gpu_mem_alloc = torch.cuda.max_memory_allocated() / 1000000 if torch.cuda.is_available() else 0
                    print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                        epoch, step, loss.item(), acc.item(), np.mean(iter_output[3:]), gpu_mem_alloc))

            toc = time.time()
            #print('Epoch Time(s): {:.4f}'.format(toc - tic))
            if epoch >= 5:
                avg += toc - tic
            if  epoch != 0:
                print(self.valid_index)
                eval_acc = self.evaluate(self.model, self.g, self.features, self.labels, self.valid_index)
                print('Eval Acc {:.4f}'.format(eval_acc))

        print('Avg epoch time: {}'.format(avg / (epoch - 4)))


        losses = []
        self.model.train()
        # model forward
        logits = self.model(self.g, self.features)
        # loss calculate
        labels = self.labels
        loss = self.calculate_loss(logits, labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        losses.append(loss)
        return np.mean(losses)
    
    def load_subtensor(self, g, labels, seeds, input_nodes, device):
        """
        将一组节点的特征和标签复制到 GPU 上。
        """
        batch_inputs = self.features[input_nodes].to(device)
        batch_labels = labels[seeds].to(device)
        return batch_inputs, batch_labels

In [74]:
runner = Runner(dgl_data)
runner.train()

Epoch 00000 | Step 00000 | Loss 6.3098 | Train Acc 0.0020 | Speed (samples/sec) nan | GPU 0.0 MiB
Epoch 00000 | Step 00002 | Loss 6.2631 | Train Acc 0.0000 | Speed (samples/sec) nan | GPU 0.0 MiB
Epoch 00000 | Step 00004 | Loss 6.1315 | Train Acc 0.0137 | Speed (samples/sec) 270.5808 | GPU 0.0 MiB
Epoch 00000 | Step 00006 | Loss 6.1212 | Train Acc 0.0215 | Speed (samples/sec) 263.9268 | GPU 0.0 MiB
Epoch 00000 | Step 00008 | Loss 5.9569 | Train Acc 0.0586 | Speed (samples/sec) 269.4014 | GPU 0.0 MiB
Epoch 00000 | Step 00010 | Loss 5.9528 | Train Acc 0.0742 | Speed (samples/sec) 265.3685 | GPU 0.0 MiB
Epoch 00000 | Step 00012 | Loss 5.8687 | Train Acc 0.0840 | Speed (samples/sec) 285.2210 | GPU 0.0 MiB
Epoch 00000 | Step 00014 | Loss 5.8247 | Train Acc 0.1270 | Speed (samples/sec) 300.4771 | GPU 0.0 MiB
Epoch 00000 | Step 00016 | Loss 5.7148 | Train Acc 0.1641 | Speed (samples/sec) 303.6623 | GPU 0.0 MiB
Epoch 00000 | Step 00018 | Loss 5.6911 | Train Acc 0.1719 | Speed (samples/sec) 303

RuntimeError: The size of tensor a (329) must match the size of tensor b (2889) at non-singleton dimension 0

In [None]:
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn.pytorch.conv import SAGEConv

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
    
    def forward(self, g, feat):
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        print(feat)
        with g.local_scope():
            g.srcdata['h1'] = feat
            g.update_all(self.edge_message, self.node_reduce)
            h = g.ndata['h2']
            return self.linear(h)
    
    @staticmethod
    def node_reduce(nodes):
        # nodes.data['h'] is a tensor of shape (N, 1),
        # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
        # where N is the number of nodes in the batch,
        # D is the number of messages received per node for this node batch
        return {'h2': nodes.dstdata['h1'] + nodes.mailbox['m'].mean(1)}
    
    @staticmethod
    def edge_message(edges):
        # edges.src with the shape of (#edges, embed_dim)
        # edges.data['weight'] with the shape of (#edges, 1)
        # broadcast in order to update with weight
        w = edges.data['weight'].float().reshape(-1, 1)
        return {'m': edges.src['h1'] * w}

class GCN(nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        # TODO: modify the layers in the GCN layer
        # input and output for the GCN should be modified to real numbers
        self.layer1 = GCNLayer(8, 512)
        self.layer2 = GCNLayer(512, 6)

    def forward(self, blocks, input_nodes, output_nodes):
        print(blocks[0].ndata['embed'])
        print(input_nodes)
        print(output_nodes)
        for block in blocks:
            x = F.relu(self.layer1(block, block.ndata['embed']))
            print(x)
            x = self.layer2(block, x)
        return x
    
class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type)) # activation None

    def forward(self, graph, inputs):
        h = self.dropout(inputs)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h
    
class Runner(object):
    def __init__(self, dgl_data):
        self.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        self.g, self.features, self.labels, self.train_index, self.valid_index, self.test_index = dgl_data.load_data()
        self.epochs = 300
        
        feat_size = self.features.shape[1]
        n_hidden = 32
        n_classes = torch.max(labels).item()+1
        n_layers = 3
        activation = F.relu
        dropout = 0.5
        aggregator_type = 'gcn'
        
        self.model = GraphSAGE(feat_size, n_hidden, n_classes, n_layers, activation, dropout, aggregator_type)
    
    def compute_loss(self, logits, labels):
        print(logits.shape)
        print(labels.shape)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, labels)
        return loss
    
    def compute_acc(pred, labels):
        return (torch.argmax(pred, dim=1) == labels).float().sum() / len(pred)
    
    def evaluate(self, model, g, features, labels, index):
        model.eval()
        with torch.no_grad():
            logits = model(g, features)
            logits = logits[index]
            labels = labels[index]
            _, indices = torch.max(logits, dim=1)
            correct = torch.sum(indices == labels)
            return correct.item() * 1.0 / len(labels)
    
    def train(self):
        # Training loop
        avg = 0
        iter_output = []
        for epoch in range(self.epochs):
            tic = time.time()

            for step, blocks in enumerate(self.dataloader):
                tic_step = time.time()

                input_nodes = blocks[0].srcdata[dgl.NID]
                seeds = blocks[-1].dstdata[dgl.NID]

                # Load the input features as well as output labels
                batch_inputs, batch_labels = self.load_subtensor(g, labels, seeds, input_nodes, device)

                # Compute loss and prediction
                batch_pred = self.model(blocks, batch_inputs)
                loss = self.compute_loss(batch_pred, batch_labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                iter_tput.append(len(seeds) / (time.time() - tic_step))
                if step % log_every == 0:
                    acc = self.compute_acc(batch_pred, batch_labels)
                    gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                    print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                        epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))

            toc = time.time()
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
            if epoch >= 5:
                avg += toc - tic
            if epoch % eval_every == 0 and epoch != 0:
                eval_acc = evaluate(self.model, self.g, self.g.ndata['features'], self.labels, val_mask, batch_size, device)
                print('Eval Acc {:.4f}'.format(eval_acc))

        print('Avg epoch time: {}'.format(avg / (epoch - 4)))


        losses = []
        self.model.train()
        # model forward
        logits = self.model(self.g, self.features)
        # loss calculate
        labels = self.labels
        loss = self.calculate_loss(logits, labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        losses.append(loss)
        return np.mean(losses)
    
    def load_subtensor(self, g, labels, seeds, input_nodes, device):
        """
        将一组节点的特征和标签复制到 GPU 上。
        """
        batch_inputs = g.ndata['features'][input_nodes].to(device)
        batch_labels = labels[seeds].to(device)
        return batch_inputs, batch_labels

In [None]:
runner = Runner(dgl_data)

In [None]:
def evaluate(model, g, features, labels, index):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[index]
        labels = labels[index]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

In [None]:
import numpy as np
import torch







dur = []
for epoch in range(300):
    for input_nodes, output_nodes, blocks in dataloader:
        t0 = time.time()
        model.train()
        logits = model(blocks)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[input_nodes], labels[output_nodes])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        dur.append(time.time() - t0)

        acc = evaluate(gcn, g, features, labels, train_index)
        print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
                epoch, loss.item(), acc, np.mean(dur)))

In [None]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([1,1,1])
dataloader = dgl.dataloading.NodeDataLoader(
    g, g.nodes()[train_index], sampler,
    batch_size=1, shuffle=True, drop_last=False, num_workers=4)
for input_nodes, output_nodes, blocks in dataloader:
    print(input_nodes)
    print(output_nodes)
    print(blocks)
    print(blocks[0].dstdata['_ID'], blocks[0].srcdata['_ID'])
    print(blocks[1].dstdata['_ID'], blocks[1].srcdata['_ID'])
    print(blocks[2].dstdata['_ID'], blocks[2].srcdata['_ID'])
    print(blocks[0].ndata['embed'])

In [None]:
# GraphSAGE 的代码实现
class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        # block 是我们采样获得的二部图，这里用于消息传播
        # x 为节点特征
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h_dst = h[:block.number_of_dst_nodes()]
            h = layer(block, (h, h_dst))
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        # inference 用于评估测试，针对的是完全图
        # 目前会出现重复计算的问题，优化方案还在 to do list 上
        nodes = th.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.number_of_nodes(), 
                         self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]
                h = x[input_nodes].to(device)
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[start:end] = h.cpu()
            x = y
        return y

def compute_acc(pred, labels):
    """
    计算准确率
    """
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
    """
    评估模型，调用 model 的 inference 函数
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_mask], labels[val_mask])

def load_subtensor(g, labels, seeds, input_nodes, device):
    """
    将一组节点的特征和标签复制到 GPU 上。
    """
    batch_inputs = g.ndata['features'][input_nodes].to(device)
    batch_labels = labels[seeds].to(device)
    return batch_inputs, batch_labels

# 参数设置
gpu = -1
num_epochs = 20
num_hidden = 16
num_layers = 2
fan_out = '10,25'
batch_size = 1000
log_every = 20  # 记录日志的频率
eval_every = 5
lr = 0.003
dropout = 0.5
num_workers = 0  # 用于采样进程的数量

if gpu >= 0:
    device = th.device('cuda:%d' % gpu)
else:
    device = th.device('cpu')

# load reddit data
# NumNodes: 232965
# NumEdges: 114848857
# NumFeats: 602
# NumClasses: 41
# NumTrainingSamples: 153431
# NumValidationSamples: 23831
# NumTestSamples: 55703
data = RedditDataset(self_loop=True)
train_mask = data.train_mask
val_mask = data.val_mask
features = th.Tensor(data.features)
in_feats = features.shape[1]
labels = th.LongTensor(data.labels)
n_classes = data.num_labels
# Construct graph
g = dgl.graph(data.graph.all_edges())
g.ndata['features'] = features

In [None]:
train_nid = th.LongTensor(np.nonzero(train_mask)[0])
val_nid = th.LongTensor(np.nonzero(val_mask)[0])
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)

# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in fan_out.split(',')])

# Create PyTorch DataLoader for constructing blocks
# collate_fn 参数指定了 sampler，可以对 batch 中的节点进行采样
dataloader = DataLoader(
    dataset=train_nid.numpy(),
    batch_size=batch_size,
    collate_fn=sampler.sample_blocks,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers)

# Define model and optimizer
model = GraphSAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
avg = 0
iter_tput = []
for epoch in range(num_epochs):
    tic = time.time()

    for step, blocks in enumerate(dataloader):
        tic_step = time.time()

        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]

        # Load the input features as well as output labels
        batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, device)

        # Compute loss and prediction
        batch_pred = model(blocks, batch_inputs)
        loss = loss_fcn(batch_pred, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter_tput.append(len(seeds) / (time.time() - tic_step))
        if step % log_every == 0:
            acc = compute_acc(batch_pred, batch_labels)
            gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
            print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))

    toc = time.time()
    print('Epoch Time(s): {:.4f}'.format(toc - tic))
    if epoch >= 5:
        avg += toc - tic
    if epoch % eval_every == 0 and epoch != 0:
        eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, batch_size, device)
        print('Eval Acc {:.4f}'.format(eval_acc))

print('Avg epoch time: {}'.format(avg / (epoch - 4)))