Update the learning model to align with the latest graphlearn (#2235)
* Update the learning model to align with the latest graphlearn
* Add the fragment ids into vy_info
* Move to the latest graphlearn
* Fix the source and build paths
* Set the fragment ids correctly
* Drop learning instances and interactive queries correctly

Signed-off-by: Tao He <sighingnow@gmail.com>
sighingnow committed Nov 26, 2022
1 parent 1c4c1cd commit f77627d
Showing 24 changed files with 614 additions and 506 deletions.
1 change: 1 addition & 0 deletions .dockerignore
```diff
@@ -80,6 +80,7 @@ package.lock

 # docker: ignore learning engine's build artifacts
 learning_engine/graph-learn/cmake-build
+learning_engine/graph-learn/graphlearn/cmake-build

 # docker: ignore .install_prefix
 .install_prefix
```
7 changes: 4 additions & 3 deletions Makefile
```diff
@@ -4,7 +4,7 @@ GAE_DIR := $(WORKING_DIR)/analytical_engine
 GIE_DIR := $(WORKING_DIR)/interactive_engine
 GLE_DIR := $(WORKING_DIR)/learning_engine/graph-learn
 GAE_BUILD_DIR := $(GAE_DIR)/build
-GLE_BUILD_DIR := $(GLE_DIR)/cmake-build
+GLE_BUILD_DIR := $(GLE_DIR)/graphlearn/cmake-build
 CLIENT_DIR := $(WORKING_DIR)/python
 COORDINATOR_DIR := $(WORKING_DIR)/coordinator
 K8S_DIR := $(WORKING_DIR)/k8s
@@ -121,14 +121,15 @@ $(GIE_DIR)/assembly/target/graphscope.tar.gz:
 gle-install: gle
 	mkdir -p $(INSTALL_PREFIX)
 	$(MAKE) -C $(GLE_BUILD_DIR) install
-gle: $(GLE_DIR)/built/lib/libgraphlearn_shared.$(SUFFIX)
+gle: $(GLE_DIR)/graphlearn/built/lib/libgraphlearn_shared.$(SUFFIX)

-$(GLE_DIR)/built/lib/libgraphlearn_shared.$(SUFFIX):
+$(GLE_DIR)/graphlearn/built/lib/libgraphlearn_shared.$(SUFFIX):
 	git submodule update --init
 	cd $(GLE_DIR) && git submodule update --init third_party/pybind11
 	mkdir -p $(GLE_BUILD_DIR)
 	cd $(GLE_BUILD_DIR) && \
 	cmake -DCMAKE_INSTALL_PREFIX=$(INSTALL_PREFIX) \
+		-DKNN=OFF \
 		-DWITH_VINEYARD=ON \
 		-DTESTING=${BUILD_TEST} .. && \
 	$(MAKE) -j$(NUMPROC)
```
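Both the `.dockerignore` entry and the Makefile paths above track the new layout of the graph-learn submodule, whose buildable sources now live under a `graphlearn/` subdirectory (hence `graphlearn/cmake-build` and `graphlearn/built/lib`). The added `-DKNN=OFF` appears to switch off graph-learn's optional KNN component at configure time, which GraphScope's build does not require.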
104 changes: 52 additions & 52 deletions README-zh.md
```diff
@@ -145,64 +145,64 @@ lg = graphscope.graphlearn(sub_graph, nodes=[("paper", paper_features)],
 # Note: Here we use tensorflow as NN backend to train GNN model. so please
 # install tensorflow.
+try:
+    # https://www.tensorflow.org/guide/migrate
+    import tensorflow.compat.v1 as tf
+    tf.disable_v2_behavior()
+except ImportError:
+    import tensorflow as tf
+
 import graphscope.learning
-from graphscope.learning.examples import GCN
-from graphscope.learning.graphlearn.python.model.tf.trainer import LocalTFTrainer
-from graphscope.learning.graphlearn.python.model.tf.optimizer import get_tf_optimizer
+from graphscope.learning.examples import EgoGraphSAGE
+from graphscope.learning.examples import EgoSAGESupervisedDataLoader
+from graphscope.learning.examples.tf.trainer import LocalTrainer

 # supervised GCN.
-def train(config, graph):
-    def model_fn():
-        return GCN(
-            graph,
-            config["class_num"],
-            config["features_num"],
-            config["batch_size"],
-            val_batch_size=config["val_batch_size"],
-            test_batch_size=config["test_batch_size"],
-            categorical_attrs_desc=config["categorical_attrs_desc"],
-            hidden_dim=config["hidden_dim"],
-            in_drop_rate=config["in_drop_rate"],
-            neighs_num=config["neighs_num"],
-            hops_num=config["hops_num"],
-            node_type=config["node_type"],
-            edge_type=config["edge_type"],
-            full_graph_mode=config["full_graph_mode"],
-        )
-    graphscope.learning.reset_default_tf_graph()
-    trainer = LocalTFTrainer(
-        model_fn,
-        epoch=config["epoch"],
-        optimizer=get_tf_optimizer(
-            config["learning_algo"], config["learning_rate"], config["weight_decay"]
-        ),
-    )
-    trainer.train_and_evaluate()
-
-config = {
-    "class_num": 349,  # output dimension
-    "features_num": 130,  # 128 dimension + kcore + triangle count
-    "batch_size": 500,
-    "val_batch_size": 100,
-    "test_batch_size": 100,
-    "categorical_attrs_desc": "",
-    "hidden_dim": 256,
-    "in_drop_rate": 0.5,
-    "hops_num": 2,
-    "neighs_num": [5, 10],
-    "full_graph_mode": False,
-    "agg_type": "gcn",  # mean, sum
-    "learning_algo": "adam",
-    "learning_rate": 0.0005,
-    "weight_decay": 0.000005,
-    "epoch": 20,
-    "node_type": "paper",
-    "edge_type": "cites",
-}
-
-train(config, lg)
+def train_gcn(graph, node_type, edge_type, class_num, features_num,
+              hops_num=2, nbrs_num=[25, 10], epochs=2,
+              hidden_dim=256, in_drop_rate=0.5, learning_rate=0.01,
+):
+    graphscope.learning.reset_default_tf_graph()
+
+    dimensions = [features_num] + [hidden_dim] * (hops_num - 1) + [class_num]
+    model = EgoGraphSAGE(dimensions, act_func=tf.nn.relu, dropout=in_drop_rate)
+
+    # prepare train dataset
+    train_data = EgoSAGESupervisedDataLoader(
+        graph, graphscope.learning.Mask.TRAIN,
+        node_type=node_type, edge_type=edge_type, nbrs_num=nbrs_num, hops_num=hops_num,
+    )
+    train_embedding = model.forward(train_data.src_ego)
+    train_labels = train_data.src_ego.src.labels
+    loss = tf.reduce_mean(
+        tf.nn.sparse_softmax_cross_entropy_with_logits(
+            labels=train_labels, logits=train_embedding,
+        )
+    )
+    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
+
+    # prepare test dataset
+    test_data = EgoSAGESupervisedDataLoader(
+        graph, graphscope.learning.Mask.TEST,
+        node_type=node_type, edge_type=edge_type, nbrs_num=nbrs_num, hops_num=hops_num,
+    )
+    test_embedding = model.forward(test_data.src_ego)
+    test_labels = test_data.src_ego.src.labels
+    test_indices = tf.math.argmax(test_embedding, 1, output_type=tf.int32)
+    test_acc = tf.div(
+        tf.reduce_sum(tf.cast(tf.math.equal(test_indices, test_labels), tf.float32)),
+        tf.cast(tf.shape(test_labels)[0], tf.float32),
+    )
+
+    # train and test
+    trainer = LocalTrainer()
+    trainer.train(train_data.iterator, loss, optimizer, epochs=epochs)
+    trainer.test(test_data.iterator, test_acc)
+
+train_gcn(lg, node_type="paper", edge_type="cites",
+          class_num=349,  # output dimension
+          features_num=130,  # input dimension, 128 + kcore + triangle count
+)
```

See [node_classification_on_citation.ipynb](tutorials/zh/10_node_classification_on_citation.ipynb) for the complete code and its execution results.
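For context, `lg` in the new example above is the learning-graph handle created earlier in the README from `sub_graph` and `paper_features` (both visible in the hunk header). A minimal sketch of that call follows; the feature names and the train/val/test split percentages here are illustrative, following the GraphScope tutorials, and are not part of this diff:

```python
# Sketch: build the learning graph handle consumed by train_gcn above.
# Feature names and split ratios are illustrative assumptions.
paper_features = [f"feat_{i}" for i in range(128)] + ["kcore", "tc"]

lg = graphscope.graphlearn(
    sub_graph,
    nodes=[("paper", paper_features)],
    edges=[("paper", "cites", "paper")],
    gen_labels=[
        ("train", "paper", 100, (0, 75)),   # 75% of "paper" nodes for training
        ("val", "paper", 100, (75, 85)),    # 10% for validation
        ("test", "paper", 100, (85, 100)),  # 15% for testing
    ],
)
```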
104 changes: 52 additions & 52 deletions README.md
```diff
@@ -170,64 +170,64 @@ Then we define the training process, and run it.
 # Note: Here we use tensorflow as NN backend to train GNN model. so please
 # install tensorflow.
+try:
+    # https://www.tensorflow.org/guide/migrate
+    import tensorflow.compat.v1 as tf
+    tf.disable_v2_behavior()
+except ImportError:
+    import tensorflow as tf
+
 import graphscope.learning
-from graphscope.learning.examples import GCN
-from graphscope.learning.graphlearn.python.model.tf.trainer import LocalTFTrainer
-from graphscope.learning.graphlearn.python.model.tf.optimizer import get_tf_optimizer
+from graphscope.learning.examples import EgoGraphSAGE
+from graphscope.learning.examples import EgoSAGESupervisedDataLoader
+from graphscope.learning.examples.tf.trainer import LocalTrainer

 # supervised GCN.
-def train(config, graph):
-    def model_fn():
-        return GCN(
-            graph,
-            config["class_num"],
-            config["features_num"],
-            config["batch_size"],
-            val_batch_size=config["val_batch_size"],
-            test_batch_size=config["test_batch_size"],
-            categorical_attrs_desc=config["categorical_attrs_desc"],
-            hidden_dim=config["hidden_dim"],
-            in_drop_rate=config["in_drop_rate"],
-            neighs_num=config["neighs_num"],
-            hops_num=config["hops_num"],
-            node_type=config["node_type"],
-            edge_type=config["edge_type"],
-            full_graph_mode=config["full_graph_mode"],
-        )
-    graphscope.learning.reset_default_tf_graph()
-    trainer = LocalTFTrainer(
-        model_fn,
-        epoch=config["epoch"],
-        optimizer=get_tf_optimizer(
-            config["learning_algo"], config["learning_rate"], config["weight_decay"]
-        ),
-    )
-    trainer.train_and_evaluate()
-
-config = {
-    "class_num": 349,  # output dimension
-    "features_num": 130,  # 128 dimension + kcore + triangle count
-    "batch_size": 500,
-    "val_batch_size": 100,
-    "test_batch_size": 100,
-    "categorical_attrs_desc": "",
-    "hidden_dim": 256,
-    "in_drop_rate": 0.5,
-    "hops_num": 2,
-    "neighs_num": [5, 10],
-    "full_graph_mode": False,
-    "agg_type": "gcn",  # mean, sum
-    "learning_algo": "adam",
-    "learning_rate": 0.0005,
-    "weight_decay": 0.000005,
-    "epoch": 20,
-    "node_type": "paper",
-    "edge_type": "cites",
-}
-
-train(config, lg)
+def train_gcn(graph, node_type, edge_type, class_num, features_num,
+              hops_num=2, nbrs_num=[25, 10], epochs=2,
+              hidden_dim=256, in_drop_rate=0.5, learning_rate=0.01,
+):
+    graphscope.learning.reset_default_tf_graph()
+
+    dimensions = [features_num] + [hidden_dim] * (hops_num - 1) + [class_num]
+    model = EgoGraphSAGE(dimensions, act_func=tf.nn.relu, dropout=in_drop_rate)
+
+    # prepare train dataset
+    train_data = EgoSAGESupervisedDataLoader(
+        graph, graphscope.learning.Mask.TRAIN,
+        node_type=node_type, edge_type=edge_type, nbrs_num=nbrs_num, hops_num=hops_num,
+    )
+    train_embedding = model.forward(train_data.src_ego)
+    train_labels = train_data.src_ego.src.labels
+    loss = tf.reduce_mean(
+        tf.nn.sparse_softmax_cross_entropy_with_logits(
+            labels=train_labels, logits=train_embedding,
+        )
+    )
+    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
+
+    # prepare test dataset
+    test_data = EgoSAGESupervisedDataLoader(
+        graph, graphscope.learning.Mask.TEST,
+        node_type=node_type, edge_type=edge_type, nbrs_num=nbrs_num, hops_num=hops_num,
+    )
+    test_embedding = model.forward(test_data.src_ego)
+    test_labels = test_data.src_ego.src.labels
+    test_indices = tf.math.argmax(test_embedding, 1, output_type=tf.int32)
+    test_acc = tf.div(
+        tf.reduce_sum(tf.cast(tf.math.equal(test_indices, test_labels), tf.float32)),
+        tf.cast(tf.shape(test_labels)[0], tf.float32),
+    )
+
+    # train and test
+    trainer = LocalTrainer()
+    trainer.train(train_data.iterator, loss, optimizer, epochs=epochs)
+    trainer.test(test_data.iterator, test_acc)
+
+train_gcn(lg, node_type="paper", edge_type="cites",
+          class_num=349,  # output dimension
+          features_num=130,  # input dimension, 128 + kcore + triangle count
+)
```

A Python script with the entire process is available [here](https://colab.research.google.com/github/alibaba/GraphScope/blob/main/tutorials/1_node_classification_on_citation.ipynb); you may try it out yourself.
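As a quick check of the model wiring in the new example, the `dimensions` list built inside `train_gcn` maps the input features through one hidden layer to the output logits. With the values used in the final call it works out as follows (a small, self-contained sketch using only values from the diff):

```python
# Worked example of the layer dimensions computed in train_gcn, using the values
# from the call above: features_num=130, class_num=349, and the defaults
# hops_num=2, hidden_dim=256.
features_num, hops_num, hidden_dim, class_num = 130, 2, 256, 349
dimensions = [features_num] + [hidden_dim] * (hops_num - 1) + [class_num]
assert dimensions == [130, 256, 349]  # input -> one hidden layer -> output logits
```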
24 changes: 24 additions & 0 deletions analytical_engine/core/object/fragment_wrapper.h
```diff
@@ -280,6 +280,8 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
     auto* client = dynamic_cast<vineyard::Client*>(meta.GetClient());
     BOOST_LEAF_AUTO(frag_group_id, vineyard::ConstructFragmentGroup(
                                        *client, fragment_->id(), comm_spec));
+    auto fg = std::dynamic_pointer_cast<vineyard::ArrowFragmentGroup>(
+        client->GetObject(frag_group_id));
     auto dst_graph_def = graph_def_;

     dst_graph_def.set_key(dst_graph_name);
@@ -288,6 +290,10 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
       dst_graph_def.extension().UnpackTo(&vy_info);
     }
     vy_info.set_vineyard_id(frag_group_id);
+    vy_info.clear_fragments();
+    for (auto const& item : fg->Fragments()) {
+      vy_info.add_fragments(item.second);
+    }
     dst_graph_def.mutable_extension()->PackFrom(vy_info);

     auto wrapper = std::make_shared<FragmentWrapper<fragment_t>>(
@@ -319,6 +325,8 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
     VINEYARD_CHECK_OK(client->Persist(new_frag_id));
     BOOST_LEAF_AUTO(frag_group_id, vineyard::ConstructFragmentGroup(
                                        *client, new_frag_id, comm_spec));
+    auto fg = std::dynamic_pointer_cast<vineyard::ArrowFragmentGroup>(
+        client->GetObject(frag_group_id));
     auto new_frag = client->GetObject<fragment_t>(new_frag_id);

     rpc::graph::GraphDefPb new_graph_def;
@@ -330,6 +338,10 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
       graph_def_.extension().UnpackTo(&vy_info);
     }
     vy_info.set_vineyard_id(frag_group_id);
+    vy_info.clear_fragments();
+    for (auto const& item : fg->Fragments()) {
+      vy_info.add_fragments(item.second);
+    }
     new_graph_def.mutable_extension()->PackFrom(vy_info);

     set_graph_def(new_frag, new_graph_def);
@@ -518,6 +530,8 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
     VINEYARD_CHECK_OK(client->Persist(new_frag_id));
     BOOST_LEAF_AUTO(frag_group_id, vineyard::ConstructFragmentGroup(
                                        *client, new_frag_id, comm_spec));
+    auto fg = std::dynamic_pointer_cast<vineyard::ArrowFragmentGroup>(
+        client->GetObject(frag_group_id));
     auto new_frag = client->GetObject<fragment_t>(new_frag_id);

     rpc::graph::GraphDefPb new_graph_def;
@@ -527,6 +541,10 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
       graph_def_.extension().UnpackTo(&vy_info);
     }
     vy_info.set_vineyard_id(frag_group_id);
+    vy_info.clear_fragments();
+    for (auto const& item : fg->Fragments()) {
+      vy_info.add_fragments(item.second);
+    }
     new_graph_def.mutable_extension()->PackFrom(vy_info);

     set_graph_def(new_frag, new_graph_def);
@@ -675,6 +693,8 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
     VINEYARD_CHECK_OK(client->Persist(new_frag_id));
     BOOST_LEAF_AUTO(frag_group_id, vineyard::ConstructFragmentGroup(
                                        *client, new_frag_id, comm_spec));
+    auto fg = std::dynamic_pointer_cast<vineyard::ArrowFragmentGroup>(
+        client->GetObject(frag_group_id));
     auto new_frag = client->GetObject<fragment_t>(new_frag_id);

     rpc::graph::GraphDefPb new_graph_def;
@@ -686,6 +706,10 @@ class FragmentWrapper<vineyard::ArrowFragment<OID_T, VID_T, VERTEX_MAP_T>>
       graph_def_.extension().UnpackTo(&vy_info);
     }
     vy_info.set_vineyard_id(frag_group_id);
+    vy_info.clear_fragments();
+    for (auto const& item : fg->Fragments()) {
+      vy_info.add_fragments(item.second);
+    }
     new_graph_def.mutable_extension()->PackFrom(vy_info);

     set_graph_def(new_frag, new_graph_def);
```
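The repeated C++ change above records, alongside the fragment-group id, the vineyard object id of every fragment in the group inside the `VineyardInfoPb` extension (`vy_info`) — the "fragment ids into vy_info" item from the commit message — so that consumers such as the learning-engine launcher can hand each worker its local fragment. A minimal sketch of reading those ids back on the Python side (the module path `graphscope.proto.graph_def_pb2` and message names are assumptions based on GraphScope's proto layout, not shown in this diff):

```python
# Hypothetical sketch: unpack the VineyardInfoPb extension from a GraphDefPb and
# list the per-fragment vineyard object ids that this commit populates.
from graphscope.proto import graph_def_pb2  # assumed module path


def fragment_ids(graph_def: "graph_def_pb2.GraphDefPb") -> list:
    vy_info = graph_def_pb2.VineyardInfoPb()
    graph_def.extension.Unpack(vy_info)  # extension is a google.protobuf.Any
    return list(vy_info.fragments)       # one object id per fragment in the group
```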
