-
Notifications
You must be signed in to change notification settings - Fork 420
/
tutorial_node_classification_pyg_k8s.md
425 lines (392 loc) · 12.6 KB
/
tutorial_node_classification_pyg_k8s.md
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# Tutorial: Training a Node Classification Model (PyG) on a K8S Cluster
This tutorial presents a server-client example that illustrates how GraphScope trains the GraphSAGE model (implemented in PyG) for a node classification task on a Kubernetes cluster.
## Set parameters & load graph
```python
import graphscope as gs
from graphscope.dataset import load_ogbn_arxiv
# GraphScope logging: DEBUG verbosity with log output enabled.
gs.set_option(log_level="DEBUG")
gs.set_option(show_log=True)

# Cluster layout: number of server (graph-store) pods and client (training) pods.
params = dict(NUM_SERVER_NODES=2, NUM_CLIENT_NODES=2)

# Keyword options for gs.session(). One GraphScope worker is started per server node.
# NOTE: the client manifest (client.yaml) uses the same image tag
# (0.26.0a20240115-x86_64); presumably the two should be kept in sync.
session_options = {
    "with_dataset": True,
    "k8s_service_type": "NodePort",
    "k8s_vineyard_mem": "8Gi",
    "k8s_engine_mem": "8Gi",
    "vineyard_shared_mem": "8Gi",
    "k8s_image_pull_policy": "IfNotPresent",
    "k8s_image_tag": "0.26.0a20240115-x86_64",
    "num_workers": params["NUM_SERVER_NODES"],
}
sess = gs.session(**session_options)

# Load the ogbn_arxiv graph into the session as the example dataset.
g = load_ogbn_arxiv(sess=sess, prefix="/dataset/ogbn_arxiv")
```
## Launch the Server Engine
```python
# Columns exposed to the learning engine: 128 feature columns plus one label
# column on the "paper" vertices.
paper_features = [f"feat_{i}" for i in range(128)]

glt_graph = gs.graphlearn_torch(
    g,
    edges=[("paper", "citation", "paper")],
    node_features={"paper": paper_features},
    node_labels={"paper": "label"},
    edge_dir="out",
    # Validation / test split sizes (presumably fractions of the nodes).
    random_node_split={"num_val": 0.1, "num_test": 0.1},
    num_clients=params["NUM_CLIENT_NODES"],
    # Manifest describing the client pods (the PyTorchJob YAML below).
    manifest_path="./client.yaml",
    # Folder holding the client scripts to ship to the client pods.
    client_folder_path="./",
)
print("Exiting...")
```
## Configure the parameters for client pods
```yaml
apiVersion: "kubeflow.org/v1"
kind: PyTorchJob
metadata:
  name: graphlearn-torch-client
  namespace: default
spec:
  pytorchReplicaSpecs:
    # Client with node rank 0. It is started without --group_master (defaults to
    # localhost); workers point GROUP_MASTER at this pod.
    Master:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: pytorch
              image: registry.cn-hongkong.aliyuncs.com/graphscope/graphlearn-torch-client:0.26.0a20240115-x86_64
              imagePullPolicy: IfNotPresent
              command:
                - bash
                - -c
                - |-
                  python3 /workspace/client.py --node_rank 0 --master_addr ${MASTER_ADDR} --num_server_nodes ${NUM_SERVER_NODES} --num_client_nodes ${NUM_CLIENT_NODES}
              volumeMounts:
                - mountPath: /dev/shm
                  name: cache-volume
                - mountPath: /workspace
                  name: client-volume
          volumes:
            # Memory-backed /dev/shm for the sampling buffers.
            - name: cache-volume
              emptyDir:
                medium: Memory
                sizeLimit: "8G"
            # Client scripts, mounted at /workspace.
            - name: client-volume
              configMap:
                name: graphlearn-torch-client-config
    Worker:
      replicas: ${NUM_WORKER_REPLICAS}
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: pytorch
              image: registry.cn-hongkong.aliyuncs.com/graphscope/graphlearn-torch-client:0.26.0a20240115-x86_64
              imagePullPolicy: IfNotPresent
              command:
                - bash
                - -c
                # Node rank = worker index + 1 (rank 0 is the Master). The index is
                # everything after the last "-" in the pod name (presumably
                # <job>-worker-<index>), so it stays correct for 10 or more workers,
                # unlike taking only the last character of the name.
                - |-
                  python3 /workspace/client.py --node_rank $((${MY_POD_NAME##*-}+1)) --master_addr ${MASTER_ADDR} --group_master ${GROUP_MASTER} --num_server_nodes ${NUM_SERVER_NODES} --num_client_nodes ${NUM_CLIENT_NODES}
              env:
                - name: GROUP_MASTER
                  value: graphlearn-torch-client-master-0
                - name: MY_POD_NAME
                  valueFrom:
                    fieldRef:
                      fieldPath: metadata.name
              volumeMounts:
                - mountPath: /dev/shm
                  name: cache-volume
                - mountPath: /workspace
                  name: client-volume
          volumes:
            - name: cache-volume
              emptyDir:
                medium: Memory
                sizeLimit: "8G"
            - name: client-volume
              configMap:
                name: graphlearn-torch-client-config
```
## Write the training and testing script
### Import packages
```python
import argparse
import time
from typing import List
import torch
import torch.nn.functional as F
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.nn import GraphSAGE
import graphscope as gs
import graphscope.learning.graphlearn_torch as glt
from graphscope.learning.gl_torch_graph import GLTorchGraph
from graphscope.learning.graphlearn_torch.typing import Split
# Client-side GraphScope logging: DEBUG verbosity with show_log enabled.
gs.set_option(log_level="DEBUG")
gs.set_option(show_log=True)
```
### Define test function
```python
@torch.no_grad()
def test(model, test_loader, dataset_name):
model.eval()
xs = []
y_true = []
for i, batch in enumerate(test_loader):
if i == 0:
device = batch.x.device
batch.x = batch.x.to(torch.float32) # TODO
x = model.module(batch.x, batch.edge_index)[: batch.batch_size]
xs.append(x.cpu())
y_true.append(batch.y[: batch.batch_size].cpu())
del batch
xs = [t.to(device) for t in xs]
y_true = [t.to(device) for t in y_true]
y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True)
y_true = torch.cat(y_true, dim=0)
test_acc = sum((y_pred.T == y_true.T)[0]) / len(y_true.T)
return test_acc.item()
```
### Define the loader and training process
```python
def _make_remote_loader(
    glt_graph,
    server_rank_list: List[int],
    split,
    batch_size: int,
    shuffle: bool,
    workload_type: str,
    device,
):
    """Create a DistNeighborLoader whose neighbor sampling runs on the remote servers.

    The train and test loaders differ only in the seed split, shuffling and the
    ``workload_type`` tag, so both are built by this helper.
    """
    return glt.distributed.DistNeighborLoader(
        data=None,
        num_neighbors=[5, 3, 2],
        input_nodes=split,
        batch_size=batch_size,
        shuffle=shuffle,
        collect_features=True,
        to_device=device,
        worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
            server_rank=server_rank_list,
            num_workers=1,
            worker_devices=[torch.device("cpu")],
            worker_concurrency=1,
            buffer_size="256MB",
            prefetch_size=1,
            glt_graph=glt_graph,
            workload_type=workload_type,
        ),
    )


def run_client_proc(
    glt_graph,
    group_master: str,
    num_servers: int,
    num_clients: int,
    client_rank: int,
    server_rank_list: List[int],
    dataset_name: str,
    epochs: int,
    batch_size: int,
    training_pg_master_port: int,
):
    """Run one training client: connect to the servers, train, test, then shut down.

    Args:
        glt_graph: Graph handle. Supplies ``master_addr`` and
            ``server_client_master_port`` for the client/server connection and
            is passed on to the remote samplers.
        group_master: Host of the PyTorch training process group.
        num_servers: Number of sampling servers.
        num_clients: Number of training clients.
        client_rank: Rank of this client among the clients.
        server_rank_list: Ranks of the servers this client samples from.
        dataset_name: Forwarded to ``test``.
        epochs: Number of training epochs (now honored; it used to be
            overwritten with 10).
        batch_size: Seed nodes per batch for both loaders.
        training_pg_master_port: Port of the training process-group rendezvous.
    """
    print("-- Initializing client ...")
    glt.distributed.init_client(
        num_servers=num_servers,
        num_clients=num_clients,
        client_rank=client_rank,
        master_addr=glt_graph.master_addr,
        master_port=glt_graph.server_client_master_port,
        num_rpc_threads=4,
        client_group_name="k8s_glt_client",
        is_dynamic=True,
    )

    # Initialize the PyTorch training process group, using this client's rank
    # and world size from the graphlearn distributed context.
    current_ctx = glt.distributed.get_context()
    torch.distributed.init_process_group(
        backend="gloo",
        rank=current_ctx.rank,
        world_size=current_ctx.world_size,
        init_method="tcp://{}:{}".format(group_master, training_pg_master_port),
    )
    device = torch.device("cpu")

    print("-- Creating training dataloader ...")
    train_loader = _make_remote_loader(
        glt_graph, server_rank_list, Split.train, batch_size,
        shuffle=True, workload_type="train", device=device,
    )
    print("-- Creating testing dataloader ...")
    test_loader = _make_remote_loader(
        glt_graph, server_rank_list, Split.test, batch_size,
        shuffle=False, workload_type="test", device=device,
    )

    print("-- Initializing model and optimizer ...")
    # NOTE(review): ogbn-arxiv is usually described as having 40 classes; 47 only
    # over-allocates output logits here. Confirm before changing.
    model = GraphSAGE(
        in_channels=128,
        hidden_channels=128,
        num_layers=3,
        out_channels=47,
    ).to(device)
    model = DistributedDataParallel(model, device_ids=None)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print("-- Start training and testing ...")
    for epoch in range(epochs):
        model.train()
        start = time.time()
        # Printed as-is if this rank receives no training batches this epoch.
        loss = torch.tensor(float("nan"))
        # Join lets ranks that run out of batches early keep DDP's collectives in step.
        with Join([model]):
            for batch in train_loader:
                optimizer.zero_grad()
                # Cast features to float32 before the forward pass.
                batch.x = batch.x.to(torch.float32)  # TODO
                out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(
                    dim=-1
                )
                loss = F.nll_loss(out, torch.flatten(batch.y[: batch.batch_size]))
                loss.backward()
                optimizer.step()
        end = time.time()
        print(f"-- Epoch: {epoch:03d}, Loss: {loss:04f} Epoch Time: {end - start}")
        torch.distributed.barrier()
        # Evaluate after the first epoch and in every epoch past the halfway point.
        if epoch == 0 or epoch > (epochs // 2):
            test_acc = test(model, test_loader, dataset_name)
            print(f"-- Test Accuracy: {test_acc:.4f}")
        torch.distributed.barrier()

    print("-- Shutting down ...")
    glt.distributed.shutdown_client()
    print("-- Exited ...")
```
### Main function
```python
if __name__ == "__main__":
    # Client command line. The client manifest passes --node_rank, --master_addr
    # and the node counts (workers also pass --group_master); every other option
    # falls back to the defaults below.
    parser = argparse.ArgumentParser(
        description="Arguments for distributed training of supervised SAGE with servers."
    )
    # (flag, type, default, help) for each option the client accepts.
    cli_options = (
        ("--dataset", str, "ogbn-arxiv", "The name of ogbn arxiv."),
        ("--num_server_nodes", int, 2, "Number of server nodes for remote sampling."),
        ("--num_client_nodes", int, 1, "Number of client nodes for training."),
        ("--node_rank", int, 0, "The node rank of the current role."),
        ("--epochs", int, 10, "The number of training epochs. (client option)"),
        ("--batch_size", int, 256, "Batch size for the training and testing dataloader."),
        (
            "--training_pg_master_port",
            int,
            9997,
            "The port used for PyTorch's process group initialization across all training processes.",
        ),
        (
            "--train_loader_master_port",
            int,
            9998,
            "The port used for RPC initialization across all sampling workers of training loader.",
        ),
        (
            "--test_loader_master_port",
            int,
            9999,
            "The port used for RPC initialization across all sampling workers of testing loader.",
        ),
        ("--master_addr", str, "localhost", "The master address of the graphlearn server."),
        ("--group_master", str, "localhost", "The master address of the training process group."),
    )
    for flag, flag_type, flag_default, flag_help in cli_options:
        parser.add_argument(flag, type=flag_type, default=flag_default, help=flag_help)
    args = parser.parse_args()

    client_rank = args.node_rank
    num_servers = args.num_server_nodes
    num_clients = args.num_client_nodes

    # Echo the effective configuration.
    print(
        f"--- Distributed training example of supervised SAGE with server-client mode. Client {args.node_rank} ---"
    )
    for summary_line in (
        f"* dataset: {args.dataset}",
        f"* total server nodes: {args.num_server_nodes}",
        f"* total client nodes: {args.num_client_nodes}",
        f"* node rank: {args.node_rank}",
        f"* epochs: {args.epochs}",
        f"* batch size: {args.batch_size}",
        f"* training process group master port: {args.training_pg_master_port}",
        f"* training loader master port: {args.train_loader_master_port}",
        f"* testing loader master port: {args.test_loader_master_port}",
    ):
        print(summary_line)

    print("--- Loading graph info ...")
    # Four endpoints on the graphlearn master host: ports 9001 through 9004.
    glt_graph = GLTorchGraph(
        [f"{args.master_addr}:{port}" for port in range(9001, 9005)]
    )

    print("--- Launching client processes ...")
    run_client_proc(
        glt_graph=glt_graph,
        group_master=args.group_master,
        num_servers=num_servers,
        num_clients=num_clients,
        client_rank=client_rank,
        server_rank_list=list(range(num_servers)),
        dataset_name=args.dataset,
        epochs=args.epochs,
        batch_size=args.batch_size,
        training_pg_master_port=args.training_pg_master_port,
    )
```
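## Run the script

Save the two server-side snippets (session setup and server launch) as `k8s_launch.py`, the PyTorchJob manifest as `client.yaml`, and the training script (imports through the main function) as `client.py`, all in the same directory. Those paths match `manifest_path="./client.yaml"` and `client_folder_path="./"` passed to `gs.graphlearn_torch`; the client pods then run the script as `/workspace/client.py`. Launch everything with: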
```shell
# Run from the directory that holds client.yaml and the client script: the
# server launch uses manifest_path="./client.yaml" and client_folder_path="./".
python3 k8s_launch.py
```