In [1]:
from pprint import pprint
from itertools import chain
from functools import reduce
from glob import iglob
from yaml import safe_load
from pyspark.sql import SparkSession, DataFrame, functions as f
from dltools import load_combiner
from dltools.sacla import restructure, load_analyzer

In [2]:
# %% Load config file
# Path is relative to the notebook's working directory.
CONFIG_PATH = "sacla_analysis-config.yaml"
# Explicit encoding: without it, open() uses the platform locale encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as file:
    print("Loading config file...")
    config = safe_load(file)
# Fail early, with context, if the keys read by later cells are absent.
# safe_load returns None for an empty document, hence the type check.
if not isinstance(config, dict):
    raise TypeError(
        f"Expected a mapping at the top level of {CONFIG_PATH!r}, "
        f"got {type(config).__name__}"
    )
missing_keys = {"momentum_analyzer", "target_files"} - config.keys()
if missing_keys:
    raise KeyError(f"{CONFIG_PATH!r} is missing required keys: {sorted(missing_keys)}")
pprint(config)

Loading config file...
{'momentum_analyzer': {'dx': 1,
                       'dy': 1,
                       'models': {'particle_a': {'fr': 4561.4,
                                                 'mass': 8501.3198901104,
                                                 'pr_coeffs': [22.5668,
                                                               -0.00087779,
                                                               4.24205e-08,
                                                               -2.76085e-12,
                                                               4.22634e-11,
                                                               -1.98336e-18],
                                                 'pz_coeffs': [26824.7,
                                                               -14.2786,
                                                               0.000193147,
                                                               6.57577e-07,
                        

In [3]:
# %% Load momentum model
print("Loading momentum model...")
# Build the momentum analyzer from its section of the config. The printed repr
# shows a wrapped Spark user-defined function, applied to a column in a later cell.
momentum_config = config["momentum_analyzer"]
analyzer = load_analyzer(momentum_config)
print(analyzer)

Loading momentum model...
<function UserDefinedFunction._wrapped.<locals>.wrapper at 0x119125a60>


In [4]:
# %% Load PySpark
print("Loading PySpark...")
# Extra Maven packages Spark resolves at session start-up. The `_2.11` suffix is
# the Scala binary version, so it must match the Spark build in use.
# NOTE(review): the Mongo connector is not used by any cell in this section — confirm it is needed.
spark_packages = ",".join([
    "org.diana-hep:spark-root_2.11:0.1.15",
    "org.mongodb.spark:mongo-spark-connector_2.11:2.3.1",
])
builder = SparkSession.builder.config("spark.jars.packages", spark_packages)
spark = builder.getOrCreate()
print(spark)

Loading PySpark...
<pyspark.sql.session.SparkSession object at 0x11934ec50>


In [5]:
# %% Load data
print("Loading data...")
# `target_files` may be a single glob pattern or a list of patterns; a bare string
# would otherwise be iterated character by character.
patterns = config["target_files"]
if isinstance(patterns, str):
    patterns = [patterns]
# De-duplicate (patterns may overlap) and sort so the union order is reproducible.
paths = sorted(set(chain.from_iterable(iglob(patt) for patt in patterns)))
if not paths:
    # reduce() over an empty iterable raises an opaque TypeError; fail with context instead.
    raise FileNotFoundError(f"No files matched target_files patterns: {patterns!r}")
# Loop variable is `path`, not `f`, to avoid shadowing the `pyspark.sql.functions` alias.
loadme = (spark.read.format("org.dianahep.sparkroot").load(path) for path in paths)
# DataFrame.union matches columns by position, not name.
# NOTE(review): assumes every file has the same column order; unionByName would be safer if that can differ.
df = restructure(reduce(DataFrame.union, loadme))
df.printSchema()
df.show()
# Preview: flatten the per-event `hits` array into one row per hit.
(
    df
    .select(f.explode("hits").alias("h"))
    .select(f.col("h.t").alias("t"),
            f.col("h.x").alias("x"),
            f.col("h.y").alias("y"),
            f.col("h.flag").alias("flag"))
    .limit(20)
    .toPandas()
)

Loading data...
root
 |-- tag: long (nullable = true)
 |-- hits: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- t: double (nullable = false)
 |    |    |-- x: double (nullable = false)
 |    |    |-- y: double (nullable = false)
 |    |    |-- as_: map (nullable = false)
 |    |    |    |-- key: string
 |    |    |    |-- value: struct (valueContainsNull = true)
 |    |    |    |    |-- pz: double (nullable = false)
 |    |    |    |    |-- px: double (nullable = false)
 |    |    |    |    |-- py: double (nullable = false)
 |    |    |    |    |-- ke: double (nullable = false)
 |    |    |-- flag: integer (nullable = true)

+---------+--------------------+
|      tag|                hits|
+---------+--------------------+
|158648231|[[803.53128890943...|
|158648232|[[804.66343712302...|
|158648233|[[794.79463683844...|
|158648234|[[786.43060318885...|
|158648235|[[709.28913731595...|
|158648236|[[627.36191628571...|
|158648237|[[899.62290167172..

Unnamed: 0,t,x,y,flag
0,803.531289,22.084287,-9.010443,13
1,805.186457,-3.410069,-2.075715,0
2,926.20598,4.763139,28.911008,0
3,944.8419,-5.041342,7.548212,17
4,1013.461017,-22.692024,-13.571205,0
5,1455.289515,1.290518,0.833531,0
6,2311.770348,2.549458,17.51087,0
7,2382.982309,2.490935,0.800476,6
8,2440.103272,-5.223339,-54.310071,18
9,2448.647841,-25.933024,-31.314864,15


In [6]:
# %% Analyze momentum
print("Analyzing momentum...")
# Apply the momentum-model UDF to each event's `hits` array. Per the printed schema,
# the result keeps the same per-hit struct layout as `hits`.
analyzed = df.select(analyzer(f.col("hits")).alias("analyzed"))
analyzed.printSchema()
analyzed.show()
# Preview: one row per (hit, `as_` map entry), with the map value's struct fields expanded.
exploded_hits = analyzed.select(f.explode("analyzed").alias("h"))
model_entries = exploded_hits.select(f.explode("h.as_").alias("as_", "m"))
model_entries.select("as_", "m.*").limit(20).toPandas()

Analyzing the data...
root
 |-- analyzed: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- t: double (nullable = false)
 |    |    |-- x: double (nullable = false)
 |    |    |-- y: double (nullable = false)
 |    |    |-- as_: map (nullable = false)
 |    |    |    |-- key: string
 |    |    |    |-- value: struct (valueContainsNull = true)
 |    |    |    |    |-- pz: double (nullable = false)
 |    |    |    |    |-- px: double (nullable = false)
 |    |    |    |    |-- py: double (nullable = false)
 |    |    |    |    |-- ke: double (nullable = false)
 |    |    |-- flag: integer (nullable = true)

+--------------------+
|            analyzed|
+--------------------+
|[[176.93128890943...|
|[[178.06343712302...|
|[[168.19463683844...|
|[[159.83060318885...|
|[[82.689137315952...|
|[[0.7619162857187...|
|[[273.02290167172...|
|[[112.74554237618...|
|[[90.438398843434...|
|[[136.92143977673...|
|[[162.13078322175...|
|[[155.82968366127...|
|[[-0

Unnamed: 0,as_,pz,px,py,ke
0,particle_a,196.714475,-35.611128,60.151577,2.563305
1,particle_a,-52.386493,-152.52438,22.46942,1.559341
2,particle_a,110.332093,-86.407544,-43.48397,1.266291
3,particle_a,304.214238,132.412487,151.972608,7.83261
4,particle_a,125.963435,46.140288,3.578911,1.05916
5,particle_a,-132.23563,26.738194,-232.583904,4.252073
6,particle_a,158.491675,-47.811755,69.175663,1.893285
7,particle_a,155.559296,-73.617873,-92.052568,2.240356
8,particle_a,217.332357,49.413152,-91.465145,3.41364
9,particle_a,138.861378,17.767142,27.275591,1.196409


In [7]:
# %% Combine hits
print("Combining hits...")
# Combination size passed to the combiner as `r`, kept as a named constant so it is
# easy to find and change.
# NOTE(review): with r=2 the output keys look like "particle_a,particle_a", so this
# presumably builds pairs of hits — confirm against dltools.load_combiner.
COMBINATION_SIZE = 2
combiner = load_combiner(r=COMBINATION_SIZE)
combined = analyzed.select(combiner(f.col("analyzed")).alias("combined"))
combined.printSchema()
combined.show()
# Preview: one row per (combination, `as_` map entry), with the map value's struct fields expanded.
(
    combined
    .select(f.explode("combined").alias("h"))
    .select(f.explode("h.as_").alias("as_", "m"))
    .select("as_", "m.*")
    .limit(20)
    .toPandas()
)

Combining hits...
root
 |-- combined: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- comb: array (nullable = false)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- t: double (nullable = false)
 |    |    |    |    |-- x: double (nullable = false)
 |    |    |    |    |-- y: double (nullable = false)
 |    |    |    |    |-- as_: map (nullable = false)
 |    |    |    |    |    |-- key: string
 |    |    |    |    |    |-- value: struct (valueContainsNull = true)
 |    |    |    |    |    |    |-- pz: double (nullable = false)
 |    |    |    |    |    |    |-- px: double (nullable = false)
 |    |    |    |    |    |    |-- py: double (nullable = false)
 |    |    |    |    |    |    |-- ke: double (nullable = false)
 |    |    |    |    |-- flag: integer (nullable = true)
 |    |    |-- as_: map (nullable = false)
 |    |    |    |-- key: string
 |    |    |    |-- value: struct (valueContainsNull = true)
 |   

Unnamed: 0,as_,pz,px,py,ke
0,"particle_a,particle_a",414.546331,46.004943,108.488638,9.098901
1,"particle_a,particle_a",372.891653,-24.204721,-183.517714,5.653995
2,"particle_a,particle_a",60.870429,-79.243493,-35.145377,12.967198
3,"particle_a,particle_a",208.833496,25.802142,109.254807,6.434601
4,"particle_a,particle_a",252.578256,-111.650218,114.970455,6.026293
5,"particle_a,particle_a",226.253723,-186.932301,-208.556622,4.38508
6,"particle_a,particle_a",279.966706,134.142331,-27.347307,7.474504
7,"particle_a,particle_a",85.795803,396.726636,-77.653323,5.898578
8,"particle_a,particle_a",166.633383,9.014688,68.800028,3.948449
9,"particle_a,particle_a",162.624627,31.857678,-112.145382,3.645012
