Skip to content

Commit

Permalink
ensured proper data types when transfering between python and c++
Browse files Browse the repository at this point in the history
  • Loading branch information
GitFuchs authored and tbaudier committed May 29, 2024
1 parent 9a478a4 commit b792627
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 45 deletions.
26 changes: 13 additions & 13 deletions core/opengate_core/opengate_lib/GatePhaseSpaceSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class GatePhaseSpaceSource : public GateVSource {
G4ParticleDefinition *fParticleDefinition;
G4ParticleTable *fParticleTable;

float fCharge;
float fMass;
std::float_t fCharge;
std::float_t fMass;
bool fGlobalFag;
bool fUseParticleTypeFromFile;
bool fVerbose;
Expand Down Expand Up @@ -86,26 +86,26 @@ class GatePhaseSpaceSource : public GateVSource {
struct threadLocalTPhsp {

bool fgenerate_until_next_primary;
int fprimary_PDGCode;
float fprimary_lower_energy_threshold;
std::int32_t fprimary_PDGCode;
std::float_t fprimary_lower_energy_threshold;

ParticleGeneratorType fGenerator;
unsigned long fNumberOfGeneratedEvents;
size_t fCurrentIndex;
size_t fCurrentBatchSize;

int *fPDGCode;
std::int32_t *fPDGCode;

float *fPositionX;
float *fPositionY;
float *fPositionZ;
std::float_t *fPositionX;
std::float_t *fPositionY;
std::float_t *fPositionZ;

float *fDirectionX;
float *fDirectionY;
float *fDirectionZ;
std::float_t *fDirectionX;
std::float_t *fDirectionY;
std::float_t *fDirectionZ;

float *fEnergy;
float *fWeight;
std::float_t *fEnergy;
std::float_t *fWeight;
// double * fTime;
};
G4Cache<threadLocalTPhsp> fThreadLocalDataPhsp;
Expand Down
105 changes: 73 additions & 32 deletions opengate/sources/phspsources.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def initialize(self, user_info):

def read_phsp_and_keys(self):
# convert str like 1e5 to int
self.user_info.batch_size = int(float(self.user_info.batch_size))
self.user_info.batch_size = int(self.user_info.batch_size)
if self.user_info.batch_size < 1:
gate.fatal("PhaseSpaceSourceGenerator: Batch size should be > 0")

Expand Down Expand Up @@ -146,6 +146,14 @@ def generate(self, source):
# print("batch_size: ", current_batch_size)

batch = self.batch
# ensure encoding is float32
for key in batch:
# Convert to float32 if the array contains floating-point values
if np.issubdtype(batch[key].dtype, np.floating):
batch[key] = batch[key].astype(np.float32)
else:
if np.issubdtype(batch[key].dtype, np.integer):
batch[key] = batch[key].astype(np.int32)

# update index if end of file
self.current_index += current_batch_size
Expand All @@ -157,12 +165,12 @@ def generate(self, source):
)
self.current_index = 0

# send to cpp
# prepare data

# set particle type
if ui.particle == "" or ui.particle is None:
# check if the keys for PDGCode are in the root file
if ui.PDGCode_key in batch:
source.SetPDGCodeBatch(batch[ui.PDGCode_key])
else:
if ui.PDGCode_key not in batch:
fatal(
f"PhaseSpaceSource: no PDGCode key ({ui.PDGCode_key}) "
f"in the phsp file and no source.particle"
Expand All @@ -171,16 +179,17 @@ def generate(self, source):
# if translate_position is set to True, the position
# supplied will be added to the phsp file position
if ui.translate_position:
batch[ui.position_key_x] += ui.position.translation[0]
batch[ui.position_key_y] += ui.position.translation[1]
batch[ui.position_key_z] += ui.position.translation[2]
source.SetPositionXBatch(batch[ui.position_key_x])
source.SetPositionYBatch(batch[ui.position_key_y])
source.SetPositionZBatch(batch[ui.position_key_z])
else:
source.SetPositionXBatch(batch[ui.position_key_x])
source.SetPositionYBatch(batch[ui.position_key_y])
source.SetPositionZBatch(batch[ui.position_key_z])
batch[ui.position_key_x] += float(ui.position.translation[0])
batch[ui.position_key_y] += float(ui.position.translation[1])
batch[ui.position_key_z] += float(ui.position.translation[2])

# source.SetPositionXBatch(batch[ui.position_key_x])
# source.SetPositionYBatch(batch[ui.position_key_y])
# source.SetPositionZBatch(batch[ui.position_key_z])
# else:
# source.SetPositionXBatch(batch[ui.position_key_x])
# source.SetPositionYBatch(batch[ui.position_key_y])
# source.SetPositionZBatch(batch[ui.position_key_z])

# direction is a rotation of the stored direction
# if rotate_direction is set to True, the direction
Expand All @@ -196,31 +205,61 @@ def generate(self, source):
)
# create rotation matrix
r = Rotation.from_matrix(ui.position.rotation)
if ui.verbose:
print("Rotation matrix: ", r.as_matrix())
# rotate vector with rotation matrix
points = r.apply(self.points)
# source.fDirectionX, source.fDirectionY, source.fDirectionZ = points.T
source.SetDirectionXBatch(points[:, 0])
source.SetDirectionYBatch(points[:, 1])
source.SetDirectionZBatch(points[:, 2])
else:
source.SetDirectionXBatch(batch[ui.direction_key_x])
source.SetDirectionYBatch(batch[ui.direction_key_y])
source.SetDirectionZBatch(batch[ui.direction_key_z])
batch[ui.direction_key_x] = points[:, 0].astype(np.float32)
batch[ui.direction_key_y] = points[:, 1].astype(np.float32)
batch[ui.direction_key_z] = points[:, 2].astype(np.float32)

# set energy
source.SetEnergyBatch(batch[ui.energy_key])
# source.SetDirectionXBatch(batch[ui.direction_key_x])
# source.SetDirectionYBatch(batch[ui.direction_key_y])
# source.SetDirectionZBatch(batch[ui.direction_key_z])

# set weight
if ui.weight_key != "" and ui.weight_key is not None:
if ui.weight_key in batch:
source.SetWeightBatch(batch[ui.weight_key])
else:
if ui.weight_key != "" or ui.weight_key is not None:
if ui.weight_key not in batch:
fatal(
f"PhaseSpaceSource: no Weight key ({ui.weight_key}) in the phsp file."
)
else:
self.w = np.ones(current_batch_size)
source.SetWeightBatch(self.w)
self.w = np.ones(current_batch_size, dtype=np.float32)
batch[ui.weight_key] = self.w.astype(np.float32)

# send to cpp
# set position
source.SetPositionXBatch(batch[ui.position_key_x])
source.SetPositionYBatch(batch[ui.position_key_y])
source.SetPositionZBatch(batch[ui.position_key_z])
# set direction
source.SetDirectionXBatch(batch[ui.direction_key_x])
source.SetDirectionYBatch(batch[ui.direction_key_y])
source.SetDirectionZBatch(batch[ui.direction_key_z])
# set energy
source.SetEnergyBatch(batch[ui.energy_key])
# set PDGCode
source.SetPDGCodeBatch(batch[ui.PDGCode_key])
# set weight
source.SetWeightBatch(batch[ui.weight_key])

if ui.verbose:
print("PhaseSpaceSourceGenerator: batch generated: ")
print("particle name: ", ui.particle)
print("source.fPDGCode: ", batch[ui.PDGCode_key])
print("source.fEnergy: ", batch[ui.energy_key])
print("source.fWeight: ", batch[ui.weight_key])
print("source.fPositionX: ", batch[ui.position_key_x])
print("source.fPositionY: ", batch[ui.position_key_y])
print("source.fPositionZ: ", batch[ui.position_key_z])
print("source.fDirectionX: ", batch[ui.direction_key_x])
print("source.fDirectionY: ", batch[ui.direction_key_y])
print("source.fDirectionZ: ", batch[ui.direction_key_z])
print("source.fEnergy dtype: ", batch[ui.energy_key].dtype)

# Release the lock when the function execution is complete
# self.lock.release()

return current_batch_size

Expand Down Expand Up @@ -252,7 +291,7 @@ def set_default_user_info(user_info):
user_info.activity = 0
user_info.half_life = -1 # negative value is not half_life
user_info.particle = "" # FIXME later as key
user_info.entry_start = 0
user_info.entry_start = None
# if a particle name is supplied, the particle type is set to it
# otherwise, information from the phase space is used

Expand Down Expand Up @@ -289,6 +328,7 @@ def set_default_user_info(user_info):
# user_info.time_key = None # FIXME TODO later
# for debug
user_info.verbose_batch = False
user_info.verbose = False

def create_g4_source(self):
return opengate_core.GatePhaseSpaceSource()
Expand Down Expand Up @@ -329,7 +369,7 @@ def initialize(self, run_timing_intervals):
f"PhaseSpaceSource: generate_until_next_primary is True but no primary_lower_energy_threshold is defined"
)
# print("threads: ", self.simulation.user_info.number_of_threads)

# print("number of particles:", ui.n)
# if not set, initialize the entry_start to 0 or to a list for multithreading
if ui.entry_start is None:
if not opengate_core.IsMultithreadedApplication():
Expand All @@ -338,6 +378,7 @@ def initialize(self, run_timing_intervals):
# create a entry_start array with the correct number of start entries
# all entries are spaced by the number of particles/thread
n_threads = self.simulation.user_info.number_of_threads
# ui.entry_start = [0] * n_threads
step = np.ceil(ui.n / n_threads) + 1 # Specify the increment value
ui.entry_start = [i * step for i in range(n_threads)]
print("INFO: entry_start not set. Using default values: ", ui.entry_start)
Expand Down

0 comments on commit b792627

Please sign in to comment.