In [1]:
import jax
import jax.numpy as jnp
import uproot

In [2]:
# Open the ROOT file, then pick the object stored under the key 'trks;65'
# (name `trks`, cycle 65). Later cells stream events from `raw_events`.
_trks_file = uproot.open('/share/lazy/will/data/trks_June30_2020_80k_14.root')
raw_events = _trks_file['trks;65']

In [3]:
# This conversion is slow - it accounts for roughly 1/5 of the total
# evaluation time.
def format_event(event):
    """Convert one chunk of reconstructed tracks into the arrays used by `generate_KDE`.

    Parameters
    ----------
    event
        Indexable whose first element (`event[0]`) maps the keys 'recon_x',
        'recon_y', 'recon_z', 'recon_tx' and 'recon_ty' to 1-D, per-track
        sequences of equal length.

    Returns
    -------
    tuple
        `(track_pos, track_vec, beam_pos, beam_vec)`, each of shape (n_tracks, 3):
        * track_pos - columns (recon_x, recon_y, recon_z)
        * track_vec - columns (recon_tx, recon_ty, 1)
        * beam_pos  - all zeros
        * beam_vec  - (0, 0, 1) for every track
    """
    d = event[0]

    def column(key):
        # (n_tracks,) -> (n_tracks, 1) so the columns can be joined along axis 1
        return jnp.array(d[key])[:, None]

    track_pos = jnp.concatenate(
        (column('recon_x'), column('recon_y'), column('recon_z')), axis=1
    )

    # Convert the x slope once and reuse it to build the constant z component.
    slope_x = column('recon_tx')
    track_vec = jnp.concatenate(
        (slope_x, column('recon_ty'), jnp.ones_like(slope_x)), axis=1
    )

    beam_pos = jnp.zeros_like(track_vec)
    # For each track the beamline direction is (0, 0, 1). `jax.ops.index_add`
    # no longer exists in JAX; the `.at[...]` indexed update replaces it.
    beam_vec = beam_pos.at[:, 2].set(1)

    return track_pos, track_vec, beam_pos, beam_vec

In [4]:
@jax.jit
def distance(beam_pos, beam_vec, track_pos, track_vec):
    """Per-row parameter along the beam line at its closest approach to the track line.

    All four arguments are (n, 3) arrays; row i describes one beam line
    `beam_pos[i] + s * beam_vec[i]` and one track line
    `track_pos[i] + t * track_vec[i]`.

    Returns an (n,) array `s` such that `beam_pos + s[:, None] * beam_vec` is
    the point on the beam line nearest to the track line. The result is not
    finite when the two directions are parallel (zero denominator).
    """
    def rowwise_dot(a, b):
        return jnp.einsum('ij,ij->i', a, b)

    # Perpendicular to the track direction, lying in the plane spanned by the
    # beam and track directions.
    normal = jnp.cross(track_vec, jnp.cross(beam_vec, track_vec))

    numerator = rowwise_dot(track_pos - beam_pos, normal)
    denominator = rowwise_dot(beam_vec, normal)
    return numerator / denominator

@jax.jit
def generate_KDE(event):
    """Evaluate every track's Gaussian kernel on a coarse (x, y) grid for each z bin.

    Each track gets an ellipsoidal kernel centred on `track_pos` with three axes:
    two minor axes of length `road_error` (along `z_hat` and `y_hat`) and a
    major axis along the track direction of length
    `road_error * cos(theta) / sin(theta)`, where theta is the angle between
    the beam direction and the track direction.

    Parameters
    ----------
    event : tuple
        `(track_pos, track_vec, beam_pos, beam_vec)` as returned by
        `format_event`; each entry is an (n_tracks, 3) array.

    Returns
    -------
    Array of shape (nbxy, nbxy, n_z_bins, n_tracks, 1) = (20, 20, 4000, n_tracks, 1)
        Entry [ix, iy, iz, t, 0] is the kernel value of track `t` at the centre
        of grid cell (ix, iy, iz). Values are not finite for tracks parallel to
        the beam (sin(theta) == 0) or whose `track_pos` lies on the beam line.
    """
    track_pos, track_vec, beam_pos, beam_vec = event

    def row_dot(a, b):
        # Row-wise dot product: (n, 3), (n, 3) -> (n,)
        return jnp.einsum('ij,ij->i', a, b)

    def row_norm(v):
        # Row-wise Euclidean norm kept as a column (n, 1) for broadcasting
        return jnp.linalg.norm(v, axis=-1)[:, None]

    # Point on the beam line closest to each track.
    beam_poca = beam_pos + distance(beam_pos, beam_vec, track_pos, track_vec)[:, None] * beam_vec

    # Vector from that beam-line point to the kernel centre (`track_pos`).
    v3 = track_pos - beam_poca

    # Local frame of each kernel. x_hat must be the unit track direction:
    # building it from v3 made it parallel to z_hat, so y_hat (and with it
    # minor_ax2) collapsed to zero and chi_sq became NaN.
    z_hat = v3 / row_norm(v3)
    x_hat = track_vec / row_norm(track_vec)
    y_hat = jnp.cross(z_hat, x_hat)

    # Cosine and sine of the angle between the beam and track directions.
    cos_theta = row_dot(beam_vec, track_vec)[:, None] / (row_norm(beam_vec) * row_norm(track_vec))
    sin_theta = jnp.sqrt(1 - cos_theta**2)

    # Length of the two minor axes of every kernel.
    # NOTE(review): presumably the track "road" uncertainty - confirm meaning/units.
    road_error = 0.1

    minor_ax1 = road_error * z_hat
    minor_ax2 = road_error * y_hat
    # Scalar length times the unit direction (the previous code divided
    # element-wise by x_hat, which is not a vector along the track).
    major_ax = (road_error * cos_theta / sin_theta) * x_hat

    # Normalisation term used by the kernel, shape (n_tracks, 1).
    determinant = cos_theta * road_error**3 / sin_theta

    # Unit axes and squared lengths do not depend on the evaluation point, so
    # compute them once instead of inside PDF.
    kernel_axes = []
    for axis in (minor_ax1, minor_ax2, major_ax):
        length = row_norm(axis)
        kernel_axes.append((axis / length, length**2))

    def PDF(x, y, z):
        """Kernel value of every track at the single point (x, y, z); shape (n_tracks, 1)."""
        offset = track_pos - jnp.array([x, y, z])

        # Squared offset along each kernel axis, in units of that axis' length.
        chi_sq = sum(
            jnp.square(row_dot(offset, unit_axis)[:, None]) / length_sq
            for unit_axis, length_sq in kernel_axes
        )

        return jnp.exp(-0.5 * chi_sq) / jnp.sqrt(determinant)

    # Grid definition.
    nbxy = 20        # number of coarse bins along x and along y
    xymin = -0.4
    xymax = 0.4
    n_z_bins = 4000  # number of bins along z
    zmin = -200      # NOTE(review): z range was flagged as unconfirmed by the author
    zmax = 200

    coarse_search_points = jnp.arange(nbxy)  # x / y bin indices
    kde_points = jnp.arange(n_z_bins)        # z bin indices

    # Defined inside generate_KDE so everything is compiled as one function.
    def get_pdf(bx, by, bz):
        # Convert bin indices to bin-centre coordinates.
        centered_x = (bx + 0.5) / nbxy * (xymax - xymin) + xymin
        centered_y = (by + 0.5) / nbxy * (xymax - xymin) + xymin
        # NOTE(review): z previously received the raw bin index (0..3999) even
        # though zmin/zmax were defined; it now uses the same bin-centre mapping
        # as x and y.
        centered_z = (bz + 0.5) / n_z_bins * (zmax - zmin) + zmin
        return PDF(centered_x, centered_y, centered_z)  # (n_tracks, 1)

    # 3-D grid evaluation: vectorise over z bins, then y bins, then x bins.
    over_z = jax.vmap(get_pdf, (None, None, 0))
    over_yz = jax.vmap(over_z, (None, 0, None))
    over_xyz = jax.vmap(over_yz, (0, None, None))
    result = over_xyz(coarse_search_points, coarse_search_points, kde_points)

    # TODO: the fine (x, y) search around the coarse maximum (planned as +/- 3
    # steps of 0.01 in each direction) and the 2-D argmax over the grid
    # (e.g. jnp.unravel_index(slice.argmax(), slice.shape)) are not implemented.

    return result

In [7]:
from time import time


def put_device(array):
    """Return `array` placed on the first JAX device.

    The placed array must be used by the caller: `jax.device_put` does not
    move its argument in place, it returns a new array.
    """
    return jax.device_put(array, jax.devices()[0])  # hardcode 3090


# Loop syntax for going over all 80k events - useful once it is determined how
# much memory/compute each KDE generation requires.
# The first call is significantly slower because it includes JIT compilation;
# run the cell multiple times to benchmark.
for event in raw_events.iterate():
    event = format_event(event)

    start = time()
    # Ensure the arrays are backed by device memory (keep the returned arrays).
    event = tuple(put_device(array) for array in event)

    # block_until_ready() waits for the asynchronous computation so the timing is real.
    result = generate_KDE(event).block_until_ready()
    print(time() - start)
    print(result.shape)

    break


0.018097877502441406
(20, 20, 4000, 202, 1)
