Skip to content

Commit

Permalink
feat: Add fatras python example + test (#1030)
Browse files Browse the repository at this point in the history
This PR adds an example python script to run Fatras and write outputs to various formats. It also adds a test to check this behavior.
  • Loading branch information
paulgessinger committed Oct 11, 2021
1 parent 54532ab commit 49777c8
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
34 changes: 34 additions & 0 deletions Examples/Python/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,40 @@ def assert_entries(root_file, tree_name, exp):
assert rf.Get(tree_name).GetEntries() == exp, f"{root_file}:{tree_name}"


def test_fatras(trk_geo, tmp_path, field):
from fatras import runFatras

csv = tmp_path / "csv"
csv.mkdir()

nevents = 10

root_files = [
("fatras_particles_final.root", "particles", nevents),
("fatras_particles_initial.root", "particles", nevents),
("hits.root", "hits", 115),
]

assert len(list(csv.iterdir())) == 0
for rf, _, _ in root_files:
assert not (tmp_path / rf).exists()

seq = Sequencer(events=nevents)
runFatras(trk_geo, field, str(tmp_path), s=seq).run()

del seq

assert_csv_output(csv, "particles_final")
assert_csv_output(csv, "particles_initial")
assert_csv_output(csv, "hits")
for f, tn, exp_entries in root_files:
rfp = tmp_path / f
assert rfp.exists()
assert rfp.stat().st_size > 2 ** 10 * 10

assert_entries(rfp, tn, exp_entries)


def test_propagation(tmp_path, trk_geo, field, seq):
from propagation import runPropagation

Expand Down
127 changes: 127 additions & 0 deletions Examples/Scripts/Python/fatras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
import os

import acts
import acts.examples

u = acts.UnitConstants


def runFatras(trackingGeometry, field, outputDir, s: acts.examples.Sequencer = None):

# Preliminaries
rnd = acts.examples.RandomNumbers()

# Input
vtxGen = acts.examples.GaussianVertexGenerator()
vtxGen.stddev = acts.Vector4(0, 0, 0, 0)

ptclGen = acts.examples.ParametricParticleGenerator(
p=(1 * u.GeV, 10 * u.GeV), eta=(-2, 2)
)

g = acts.examples.EventGenerator.Generator()
g.multiplicity = acts.examples.FixedMultiplicityGenerator()
g.vertex = vtxGen
g.particles = ptclGen

evGen = acts.examples.EventGenerator(
level=acts.logging.INFO,
generators=[g],
outputParticles="particles_input",
randomNumbers=rnd,
)

# Selector
selector = acts.examples.ParticleSelector(
level=acts.logging.INFO,
inputParticles=evGen.config.outputParticles,
outputParticles="particles_selected",
)

# Simulation
alg = acts.examples.FatrasSimulation(
level=acts.logging.INFO,
inputParticles=selector.config.outputParticles,
outputParticlesInitial="particles_initial",
outputParticlesFinal="particles_final",
outputSimHits="simhits",
randomNumbers=rnd,
trackingGeometry=trackingGeometry,
magneticField=field,
generateHitsOnSensitive=True,
)

# Sequencer
s = s or acts.examples.Sequencer(
events=100, numThreads=-1, logLevel=acts.logging.INFO
)

s.addReader(evGen)
s.addAlgorithm(selector)
s.addAlgorithm(alg)

# Output
s.addWriter(
acts.examples.CsvParticleWriter(
level=acts.logging.INFO,
outputDir=outputDir + "/csv",
inputParticles="particles_final",
outputStem="particles_final",
)
)

s.addWriter(
acts.examples.RootParticleWriter(
level=acts.logging.INFO,
inputParticles="particles_final",
filePath=outputDir + "/fatras_particles_final.root",
)
)

s.addWriter(
acts.examples.CsvParticleWriter(
level=acts.logging.INFO,
outputDir=outputDir + "/csv",
inputParticles="particles_initial",
outputStem="particles_initial",
)
)

s.addWriter(
acts.examples.RootParticleWriter(
level=acts.logging.INFO,
inputParticles="particles_initial",
filePath=outputDir + "/fatras_particles_initial.root",
)
)

s.addWriter(
acts.examples.CsvSimHitWriter(
level=acts.logging.INFO,
inputSimHits=alg.config.outputSimHits,
outputDir=outputDir + "/csv",
outputStem="hits",
)
)

s.addWriter(
acts.examples.RootSimHitWriter(
level=acts.logging.INFO,
inputSimHits=alg.config.outputSimHits,
filePath=outputDir + "/hits.root",
)
)

return s


if "__main__" == __name__:

gdc = acts.examples.GenericDetector.Config()
detector = acts.examples.GenericDetector()
trackingGeometry, contextDecorators = detector.finalize(gdc, None)

field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

runFatras(trackingGeometry, field, os.getcwd()).run()

0 comments on commit 49777c8

Please sign in to comment.