Are entity-agent and cluster-agent walking independently on the graph? #5

@ZhixiangSu

Description

Hi,

Regarding the code between lines 265-303 in 'trainer.py': are the entity-agent and the cluster-agent walking independently on the graph?
When I looked through this part, I didn't see any explicit synchronization between the two agents.
If I'm mistaken, could you please point out which variable handles the synchronization?
Otherwise, may I know the rationale behind it? And how do we ensure that the entities at each step fall within the correct corresponding cluster?

Thank you so much in advance!

Below is the code between lines 265-303 in 'trainer.py':

		for i in range(self.path_length):
			loss, cluster_state_emb, logits, idx, chosen_relation, scores = self.c_agent.cluster_step(
				prev_possible_clusters, next_possible_clusters,
				cluster_state_emb, prev_cluster, end_cluster,
				current_clusters_t, range_arr,
				first_step_of_test, entity_state_emb
			)

			c_all_losses.append(loss)
			c_all_logits.append(logits)
			c_all_action_id.append(idx)
			cluster_scores.append(scores)

			cluster_state = cluster_episode.next_action(idx)  ## important !! switch to next state with new cluster
			prev_possible_clusters = next_possible_clusters.clone()
			next_possible_clusters = torch.tensor(cluster_state['next_clusters']).long().to(self.device)
			current_clusters_t = torch.tensor(cluster_state['current_clusters']).long().to(self.device)
			prev_cluster = chosen_relation.to(self.device)

			loss, entity_state_emb, logits, idx, chosen_relation = self.e_agent.step(
				next_possible_relations,
				next_possible_entities, entity_state_emb,
				prev_relation, query_relation,
				current_entities_t, range_arr,
				first_step_of_test, cluster_state_emb
			)

			entity_state, whether_e_agent_follows_c_agent = entity_episode(idx, prev_cluster.cpu(), i)  ## important !! switch to next state with new entity and new relation
			next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)
			next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)
			current_entities_t = torch.tensor(entity_state['current_entities']).long().to(self.device)
			prev_relation = chosen_relation.to(self.device)

		entity_episode.get_stepwise_approximated_reward(current_entities_t, current_clusters_t, prev_entities)  ## estimate the reward by taking each step
			prev_entities = current_entities_t.clone()

			e_all_losses.append(loss)
			e_all_logits.append(logits)
			e_all_action_id.append(idx)
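Reading the quoted loop, the two agents seem to be coupled implicitly rather than through an explicit sync variable: `cluster_step` receives `entity_state_emb`, `e_agent.step` receives the freshly updated `cluster_state_emb`, and `entity_episode` is called with the cluster just chosen (`prev_cluster`). Below is a minimal toy sketch of that alternating pattern; the classes and numbers here are hypothetical stand-ins, not the repository's actual API, so treat it only as an illustration of the data flow.

```python
class ToyAgent:
    """Hypothetical stand-in for c_agent / e_agent from trainer.py."""

    def step(self, own_emb, other_emb):
        # The real agents condition their policy on the partner agent's
        # state embedding; here we just mix the two scalars so that the
        # cross-dependence is visible in the output.
        return own_emb + 0.1 * other_emb


def rollout(path_length=3):
    c_agent, e_agent = ToyAgent(), ToyAgent()
    cluster_emb, entity_emb = 1.0, 2.0
    trace = []
    for _ in range(path_length):
        # Cluster agent steps first, conditioned on the entity embedding...
        cluster_emb = c_agent.step(cluster_emb, entity_emb)
        # ...then the entity agent steps, conditioned on the *updated*
        # cluster embedding, mirroring the order in the quoted loop.
        entity_emb = e_agent.step(entity_emb, cluster_emb)
        trace.append((cluster_emb, entity_emb))
    return trace
```

Because each agent's input at step i already contains the other agent's step-i (or step i-1) state, the walks influence each other at every iteration even without a dedicated synchronization variable.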
