Skip to content

Commit

Permalink
[Bug] Fix Reddit
Browse files Browse the repository at this point in the history
  • Loading branch information
dddg617 committed Jan 12, 2024
1 parent d54dca8 commit c3c06d8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 17 deletions.
15 changes: 9 additions & 6 deletions examples/graphsage/reddit_sage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ def main(args):
test_idx = mask_to_index(graph.test_mask)
val_idx = mask_to_index(graph.val_mask)

train_loader = NeighborSampler(edge_index=graph.edge_index.numpy(),
node_idx=tlx.convert_to_numpy(train_idx),
train_loader = NeighborSampler(edge_index=graph.edge_index,
node_idx=train_idx,
sample_lists=[25, 10], batch_size=2048, shuffle=True, num_workers=0)

val_loader = NeighborSampler(edge_index=graph.edge_index.numpy(),
node_idx=tlx.convert_to_numpy(val_idx),
val_loader = NeighborSampler(edge_index=graph.edge_index,
node_idx=val_idx,
sample_lists=[-1], batch_size=2048 * 2, shuffle=False, num_workers=0)
test_loader = NeighborSampler(edge_index=graph.edge_index.numpy(),
node_idx=tlx.convert_to_numpy(test_idx),
test_loader = NeighborSampler(edge_index=graph.edge_index,
node_idx=test_idx,
sample_lists=[-1], batch_size=2048 * 2, shuffle=False, num_workers=0)

x = tlx.convert_to_tensor(graph.x)
Expand All @@ -78,6 +78,9 @@ def main(args):
pbar = tqdm(total=int(len(train_loader.dataset)))
pbar.set_description(f'Epoch {epoch:02d}')
for dst_node, n_id, adjs in train_loader:
print("---------------")
print(adjs)
print(type(adjs))
net.set_train()
# input : sampled subgraphs, sampled node's feat
data = {"x": tlx.gather(x, n_id),
Expand Down
6 changes: 3 additions & 3 deletions gammagl/datasets/reddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def download(self):

def process(self):
data = np.load(osp.join(self.raw_dir, 'reddit_data.npz'))
x = np.array(data['feature'], dtype=np.float32)
y = np.array(data['label'], np.int32)
x = tlx.convert_to_tensor(data['feature'], dtype=tlx.float32)
y = tlx.convert_to_tensor(data['label'], dtype=tlx.int64)
split = np.array(data['node_types'])

adj = sp.load_npz(osp.join(self.raw_dir, 'reddit_graph.npz'))

edge = np.array([adj.row, adj.col], dtype=np.int64)
edge = tlx.convert_to_tensor([adj.row, adj.col], dtype=tlx.int64)

edge, _ = coalesce(edge, None, x.shape[0], x.shape[0])

Expand Down
8 changes: 0 additions & 8 deletions gammagl/models/graphsage.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,20 @@ def __init__(self, in_feat, hid_feat, out_feat, drop_rate, num_layers, name=None

def forward(self, x, edgeIndices):
for l, (layer, edgeIndex) in enumerate(zip(self.convs, edgeIndices)):
if tlx.BACKEND == 'torch':
edgeIndex.to(x.device)
target_x = tlx.gather(x, tlx.arange(0, edgeIndex.size[1])) # Target nodes are always placed first.
x = layer((x, target_x), edgeIndex.edge_index)
if l != len(self.convs) - 1:
x = self.dropout(x)
return x

def inference(self, feat, dataloader, cur_x):
if tlx.BACKEND == 'torch':
feat = feat.to(cur_x.device)
for l, layer in enumerate(self.convs):
y = tlx.zeros((feat.shape[0], self.num_class if l == len(self.convs) - 1 else self.hid_feat))
if tlx.BACKEND == 'torch':
y = y.to(feat.device)
for dst_node, n_id, adjs in dataloader:
if isinstance(adjs, (List, Tuple)):
sg = adjs[0]
else:
sg = adjs
if tlx.BACKEND == 'torch':
sg.to(y.device)
h = tlx.gather(feat, n_id)
target_feat = tlx.gather(h, tlx.arange(0, sg.size[1]))
h = layer((h, target_feat), sg.edge_index)
Expand Down

0 comments on commit c3c06d8

Please sign in to comment.