Hi,
Regarding the code between lines 265-303 in 'trainer.py', are the entity agent and the cluster agent walking independently on the graph?
When I look through this part, I don't see any synchronization between the two agents.
If I'm mistaken, could you please tell me which variable handles the synchronization?
Otherwise, may I know the rationale behind this design, and how we ensure that the entities at each step belong to the correct corresponding cluster? (After the quoted code I sketch the kind of check I was expecting.)
Thank you so much in advance!
Below is the code between lines 265-303 in 'trainer.py':
for i in range(self.path_length):
    loss, cluster_state_emb, logits, idx, chosen_relation, scores = self.c_agent.cluster_step(
        prev_possible_clusters, next_possible_clusters,
        cluster_state_emb, prev_cluster, end_cluster,
        current_clusters_t, range_arr,
        first_step_of_test, entity_state_emb
    )
    c_all_losses.append(loss)
    c_all_logits.append(logits)
    c_all_action_id.append(idx)
    cluster_scores.append(scores)

    cluster_state = cluster_episode.next_action(idx)  ## important !! switch to next state with new cluster
    prev_possible_clusters = next_possible_clusters.clone()
    next_possible_clusters = torch.tensor(cluster_state['next_clusters']).long().to(self.device)
    current_clusters_t = torch.tensor(cluster_state['current_clusters']).long().to(self.device)
    prev_cluster = chosen_relation.to(self.device)

    loss, entity_state_emb, logits, idx, chosen_relation = self.e_agent.step(
        next_possible_relations,
        next_possible_entities, entity_state_emb,
        prev_relation, query_relation,
        current_entities_t, range_arr,
        first_step_of_test, cluster_state_emb
    )
    entity_state, whether_e_agent_follows_c_agent = entity_episode(idx, prev_cluster.cpu(), i)  ## important !! switch to next state with new entity and new relation
    next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)
    next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)
    current_entities_t = torch.tensor(entity_state['current_entities']).long().to(self.device)
    prev_relation = chosen_relation.to(self.device)

    entity_episode.get_stepwise_approximated_reward(current_entities_t, current_clusters_t, prev_entities)  ## estimate the reward by taking each step
    prev_entities = current_entities_t.clone()
    e_all_losses.append(loss)
    e_all_logits.append(logits)
    e_all_action_id.append(idx)
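From the quoted code, the only coupling I can see is that `entity_state_emb` is fed into `cluster_step`, `cluster_state_emb` is fed into `e_agent.step`, and `prev_cluster` is handed to `entity_episode`, so I may simply be missing where an explicit check happens. For concreteness, below is a minimal sketch of the kind of per-step check I was expecting to find somewhere in the loop. It is my own illustration, not code from this repository, and it assumes a hypothetical `entity2cluster` lookup tensor that maps each entity id to the id of the cluster it belongs to:

```python
import torch

def check_agents_synchronized(current_entities_t: torch.Tensor,
                              current_clusters_t: torch.Tensor,
                              entity2cluster: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask over the batch: True where the entity agent's
    current entity lies inside the cluster agent's current cluster.
    `entity2cluster` is a hypothetical [num_entities] lookup, not from the repo."""
    clusters_of_entities = entity2cluster[current_entities_t]  # [batch]
    return clusters_of_entities == current_clusters_t          # [batch] bool

# Toy usage with made-up ids: entities 0,1 live in cluster 0; entities 2,3 in cluster 1.
entity2cluster = torch.tensor([0, 0, 1, 1])
current_entities_t = torch.tensor([1, 3, 2])
current_clusters_t = torch.tensor([0, 1, 0])
print(check_agents_synchronized(current_entities_t, current_clusters_t, entity2cluster))
# tensor([ True,  True, False])  -> the third walker would be out of sync
```

Is something equivalent to this happening implicitly, for example inside `entity_episode` via the `whether_e_agent_follows_c_agent` return value?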