Skip to content

Commit

Permalink
Expose global thread index via DeviceAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
Robadob committed Apr 14, 2021
1 parent d6dc891 commit 2bea090
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
4 changes: 2 additions & 2 deletions include/flamegpu/runtime/AgentFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ __global__ void agent_function_wrapper(
}
#endif
// Must be terminated here, else AgentRandom has bounds issues inside DeviceAPI constructor
if (DeviceAPI<MsgIn, MsgOut>::TID() >= popNo)
if (DeviceAPI<MsgIn, MsgOut>::getThreadIndex() >= popNo)
return;
// create a new device FLAME_GPU instance
DeviceAPI<MsgIn, MsgOut> api = DeviceAPI<MsgIn, MsgOut>(
Expand All @@ -92,7 +92,7 @@ __global__ void agent_function_wrapper(
FLAME_GPU_AGENT_STATUS flag = AgentFunction()(&api);
if (scanFlag_agentDeath) {
// (scan flags will not be processed unless agent death has been requested in model definition)
scanFlag_agentDeath[DeviceAPI<MsgIn, MsgOut>::TID()] = flag;
scanFlag_agentDeath[DeviceAPI<MsgIn, MsgOut>::getThreadIndex()] = flag;
#if !defined(SEATBELTS) || SEATBELTS
} else if (flag == DEAD) {
DTHROW("Agent death must be enabled per agent function when defining the model.\n");
Expand Down
4 changes: 2 additions & 2 deletions include/flamegpu/runtime/AgentFunctionCondition.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ __global__ void agent_function_condition_wrapper(
}
#endif
// Must be terminated here, else AgentRandom has bounds issues inside DeviceAPI constructor
if (ReadOnlyDeviceAPI::TID() >= popNo)
if (ReadOnlyDeviceAPI::getThreadIndex() >= popNo)
return;
// create a new device FLAME_GPU instance
ReadOnlyDeviceAPI api = ReadOnlyDeviceAPI(
Expand All @@ -61,7 +61,7 @@ __global__ void agent_function_condition_wrapper(
// Negate the return value, we want false at the start of the scattered array
bool conditionResult = !(AgentFunctionCondition()(&api));
// (scan flags will be processed to filter agents
scanFlag_conditionResult[ReadOnlyDeviceAPI::TID()] = conditionResult;
scanFlag_conditionResult[ReadOnlyDeviceAPI::getThreadIndex()] = conditionResult;
}
}

Expand Down
16 changes: 10 additions & 6 deletions include/flamegpu/runtime/DeviceAPI.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ReadOnlyDeviceAPI {
const Curve::NamespaceHash &instance_id_hash,
const Curve::NamespaceHash &agentfuncname_hash,
curandState *&d_rng)
: random(AgentRandom(&d_rng[TID()]))
: random(AgentRandom(&d_rng[getThreadIndex()]))
, environment(DeviceEnvironment(instance_id_hash))
, agent_func_name_hash(agentfuncname_hash) { }

Expand Down Expand Up @@ -80,13 +80,12 @@ class ReadOnlyDeviceAPI {
return environment.getProperty<unsigned int>("_stepCount");
}

protected:
Curve::NamespaceHash agent_func_name_hash;

/**
* Thread index
* Returns the current CUDA thread of the agent
* All agents execute in a unique thread, but their associated thread may change between agent functions
* Thread indices begin at 0 and continue to 1 below the number of agents executing
*/
__forceinline__ __device__ static unsigned int TID() {
__forceinline__ __device__ static unsigned int getThreadIndex() {
/*
// 3D version
auto blockId = blockIdx.x + blockIdx.y * gridDim.x
Expand All @@ -96,12 +95,17 @@ class ReadOnlyDeviceAPI {
+ (threadIdx.y * blockDim.x)
+ threadIdx.x;
return threadId;*/
#ifdef SEATBELTS
assert(blockDim.y == 1);
assert(blockDim.z == 1);
assert(gridDim.y == 1);
assert(gridDim.z == 1);
#endif
return blockIdx.x * blockDim.x + threadIdx.x;
}

protected:
Curve::NamespaceHash agent_func_name_hash;
};

/** @brief A flame gpu api class for the device runtime only
Expand Down

0 comments on commit 2bea090

Please sign in to comment.