In [1]:
import dgl

g_path = '/root/autodl-tmp/source/pprogo-flg/data/bp/graph.dgl'
g, _ = dgl.load_graphs(g_path)
g = g[0]
ppi = dgl.node_subgraph(g, {'protein': range(g.num_nodes('protein'))})
ppi = dgl.to_homogeneous(ppi, ndata = 'h')

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
import dgl
import torch
import torch.nn.functional as F


class GCN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, num_classes, dropout=0.5, num_gcn=0):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.num_gcn = num_gcn
        self.input = torch.nn.Linear(input_features, hidden_size)
        self.conv1 = dgl.nn.GraphConv(hidden_size, hidden_size)
        self.conv2 = dgl.nn.GraphConv(hidden_size, hidden_size)
        self.output = torch.nn.Linear(hidden_size, num_classes)
        self.input_bias = torch.nn.Parameter(torch.zeros(hidden_size))
        
        
    def forward(self, blocks, x):
        outputs = self.dropout(F.relu(self.input(x)) + self.input_bias)
        outputs = self.conv1(blocks[0], outputs)
        outputs = self.conv2(blocks[1], outputs)
        outputs = self.output(outputs)
        return outputs

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_dim = len(ppi.ndata['h'][0])
hidden_size = 256
num_classes = len(g.ndata['h']['go_annotation'][0])
model = GCN(feature_dim, hidden_size, num_classes)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.BCEWithLogitsLoss()
from tqdm import tqdm

sampler = dgl.dataloading.NeighborSampler([3,3])
dataloader = dgl.dataloading.DataLoader(
    ppi, torch.arange(ppi.num_nodes()), sampler,
    batch_size=8,
    device=device,
    shuffle=True,
    drop_last=False,
    num_workers=4
)
for epoch in range(1):
    loss_all = 0
    dataloader_tqdm = tqdm(dataloader)
    for i, (input_nodes, output_nodes, blocks) in enumerate(dataloader_tqdm):
        input_features = blocks[0].srcdata['h']
        pred = model(blocks, input_features)
        labels = torch.zeros([pred.shape[0], num_classes]).to(device)
        loss = loss_func(pred, labels)
        loss_all += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_all = loss_all / (i+1)
    


100%|██████████| 15510/15510 [05:21<00:00, 48.24it/s]


In [22]:
print(F.sigmoid(pred).max())
print(loss_all)

tensor(2.7639e-24, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0.0002, device='cuda:0', grad_fn=<DivBackward0>)


In [3]:
import dgl
from dgl.dataloading import BlockSampler

class nodeflowSampler(BlockSampler):
    def __init__(self, fanout, num_layers):
        super().__init__()
        self.fanout = fanout
        self.num_layers = num_layers
        
    def sample(self, g, seed_nodes):
        frontier = dgl.sampling.sample_neighbors(g, seed_nodes, self.fanout)
        
        return frontier
        

In [6]:
import torch

fanout = 5
num_layers = 2
# sampler = nodeflowSampler(fanout, num_layers)
sampler = dgl.dataloading.NeighborSampler([1])

dataloader = dgl.dataloading.DataLoader(
    g, {'protein': torch.arange(g.num_nodes('protein'))}, sampler,
    batch_size=8,
    shuffle=True,
    drop_last=False,
    num_workers=4
)
# sampler = dgl.dataloading.NeighborSampler([3,3])
# dataloader = dgl.dataloading.DataLoader(
#     g, {'protein':torch.arange(g.num_nodes('protein'))}, sampler,
#     batch_size=8, shuffle=True, drop_last=False, num_workers=4)
for i, j, blocks in dataloader:
    print(i)
    print(j)
    print(blocks)
    break



{'go_annotation': tensor([ 5580, 11727,    48, 17745, 16902, 21624,   873, 14156,  4724, 20571,
        16199, 14890, 17325, 15111,  4446,   242,  8556,   461,  1825,  2831,
         2671,  3105, 15689,  9689,  2973, 16038, 13163, 11441, 18951,  8652,
          767, 19398, 19369,  8172, 16268,   543,  2272,  4140, 11677, 10457]), 'protein': tensor([ 32885, 100773,  84355, 100808,  24470,  50295,  48835, 104497,  12404,
         63353,  36421,  32785,   2893,  17806,  60819, 110602,  99974,  53059,
         12280,  94177,  52625,  81683,  86981, 100855,  66669, 100708,  14424,
         16551,  27036,  84221,  62643,  38292,  42225,  20000,  96740,  31275,
         30599,  27334, 107619,  74222, 102975, 111183,  16244, 112304,  28469,
         45082, 108163,  27541,  94368,  93131,  84224, 114423,  87998,  31263,
         82588,  22101,  94728, 113318,  92851, 113710,  22034,  30128, 113296,
        122393,  43214,  22406,  24782,  93222,  77700,  52151,  33575,  10748,
         26937, 1