Save embeddings issue when distributed training #53
The same problem happened to me a few days ago. I guess the problem may be caused by the …
But MTS (`MonitoredTrainingSession`) freezes the TF graph (the TensorFlow computation graph, not the GNN's input graph) once the initialization step is complete. After that, users cannot modify the graph, i.e. they cannot define new variables or ops. In this issue, `trainer.py` initializes the session, which freezes the TF graph, and the model then trains for n epochs without modifying it. But when the code reaches the save-embedding step, it encodes the source nodes again to get their embeddings. That tries to add new variables/ops to the frozen TF graph, so it raises the error.

For me, I made some modifications to `graph_sage.py` so that the source embedding tensor is created in `build()`, before the graph is frozen, and simply reused when saving:

```python
def build(self):
    ...
    pos_src_emb = self.encoders['src'].encode(self.pos_src_ego_tensor)
    pos_dst_emb = self.encoders['dst'].encode(self.pos_dst_ego_tensor)
    neg_dst_emb = self.encoders['dst'].encode(self.neg_dst_ego_tensor)
    # Modification: keep references to the embedding tensors so that
    # node_embedding() can reuse them after the graph has been frozen.
    self.pos_src_emb = pos_src_emb
    self.pos_dst_emb = pos_dst_emb
    self.neg_dst_emb = neg_dst_emb
    self.loss = self._unsupervised_loss(pos_src_emb, pos_dst_emb, neg_dst_emb)
    ...

...

def node_embedding(self, type):
    iterator = self.ego_flow.iterator
    # Removed: encoding again here adds new ops to the frozen graph.
    # ego_tensor = self.ego_flow.pos_src_ego_tensor
    # src_emb = self.encoders['src'].encode(ego_tensor)
    # Added: reuse the tensor that build() created before the freeze.
    src_emb = self.pos_src_emb
    src_ids = self.pos_src_ego_tensor.src.ids
    return src_ids, src_emb, iterator

...
```

Similar issue on StackOverflow.
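To make the failure mode concrete, here is a toy, pure-Python sketch of a graph that can be frozen. `ToyGraph`, `ToyModel` and their methods are hypothetical stand-ins, not TensorFlow or graph-learn APIs; `finalize()` only mimics the freezing behaviour described above. It shows that re-encoding after the freeze fails, while returning a tensor cached in `build()` does not.

```python
# Toy model of a computation graph that can be frozen. Hypothetical classes,
# not TensorFlow: finalize() only imitates the "no new ops" rule.


class ToyGraph:
    def __init__(self):
        self.ops = []
        self.finalized = False

    def finalize(self):
        # After this call, any attempt to add an op is rejected.
        self.finalized = True

    def add_op(self, name):
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.ops.append(name)
        return name


class ToyModel:
    def __init__(self, graph):
        self.graph = graph

    def build(self):
        # Create every op needed later, including the source embedding
        # that node_embedding_fixed() returns, before the graph is frozen.
        self.pos_src_emb = self.graph.add_op("encode_src")
        self.loss = self.graph.add_op("loss")

    def node_embedding_original(self):
        # Mirrors the original code path: encodes again, i.e. adds a new op.
        return self.graph.add_op("encode_src_again")

    def node_embedding_fixed(self):
        # Mirrors the workaround: reuses the op created in build().
        return self.pos_src_emb


if __name__ == "__main__":
    graph = ToyGraph()
    model = ToyModel(graph)
    model.build()
    graph.finalize()  # stands in for the freeze done at session creation

    try:
        model.node_embedding_original()
    except RuntimeError as err:
        print("original:", err)

    print("fixed:", model.node_embedding_fixed())
```

The design choice is the same as in the workaround: every op that will be needed after training has to exist before the graph is frozen, so the save step only looks it up instead of building it.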
Hi,
I ran `dist_train.py` (`examples/tf/graphsage/dist_train.py`) and it works well. However, when I try to save the embeddings after training, it raises `RuntimeError("Graph is finalized and cannot be modified.")`. I hit the same issue when I try to run Bipartite GraphSAGE in distributed mode.