In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from tomopt.muon import *
from tomopt.inference import *
from tomopt.volume import *
from tomopt.core import *
from tomopt.optimisation import *

import matplotlib.pyplot as plt
import seaborn as sns
from typing import *
import numpy as np

import torch
from torch import Tensor, nn
import torch.nn.functional as F

In [3]:
import torch
from torch import nn, Tensor
from torch.autograd import grad

In [4]:
x = nn.Parameter(Tensor([1]))

In [5]:
y = 3*(x**2)

In [6]:
grad(y, x, retain_graph=True)

(tensor([6.]),)

In [7]:
z = 2*(y**3)

In [8]:
grad(z, y, retain_graph=True)

(tensor([54.]),)

In [9]:
6*(y**2)

tensor([54.], grad_fn=<MulBackward0>)

In [10]:
def arb_rad_length(*,z:float, lw:Tensor, size:float) -> float:
    rad_length = torch.ones(list((lw/size).long()))*X0['beryllium']
    if z >= 0.4 and z <= 0.5: rad_length[5:,5:] = X0['lead']
    return rad_length

In [11]:
def eff_cost(x:Tensor) -> Tensor:
    return torch.expm1(3*F.relu(x))

In [12]:
def res_cost(x:Tensor) -> Tensor:
    return F.relu(x/100)**2

In [13]:
def get_layers():
    layers = []
    lwh = Tensor([1,1,1])
    size = 0.1
    init_eff = 0.5
    init_res = 1000
    pos = 'above'
    for z,d in zip(np.arange(lwh[2],0,-size), [1,1,0,0,0,0,0,0,1,1]):
        if d:
            layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=init_res,
                                        lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))
        else:
            pos = 'below'
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))

    return nn.ModuleList(layers) 

In [14]:
volume = Volume(get_layers())

In [15]:
muons = MuonBatch(generate_batch(100), 1)

In [16]:
volume(muons)

In [17]:
hits = muons.get_hits(volume.lw)

In [53]:
xa0 = torch.cat([hits["above"]["xy"][:, 0], hits["above"]["z"][:, 0]], dim=-1)  # reco x, reco y, gen z
xa1 = torch.cat([hits["above"]["xy"][:, 1], hits["above"]["z"][:, 1]], dim=-1)
xb0 = torch.cat([hits["below"]["xy"][:, 1], hits["below"]["z"][:, 1]], dim=-1)
xb1 = torch.cat([hits["below"]["xy"][:, 0], hits["below"]["z"][:, 0]], dim=-1)

dets = volume.get_detectors()
res = []
for p, l, i in zip(("above", "above", "below", "below"), dets, (0, 1, 1, 0)):
    x = l.abs2idx(hits[p]["xy"][:, i])
    r = 1/l.resolution[x[:, 0], x[:, 1]]
    res.append(torch.stack([r,r,torch.zeros_like(r)], dim=-1))
hit_unc = torch.stack(res, dim=1)

# Extrapolate muon-path vectors from hits
v1 = xa1 - xa0
v2 = xb1 - xb0

# scatter locations
v3 = torch.cross(v1, v2, dim=1)  # connecting vector perpendicular to both lines
rhs = xb0 - xa0
lhs = torch.stack([v1, -v2, v3], dim=1).transpose(2, 1)
coefs = torch.linalg.solve(lhs, rhs)  # solve p1+t1*v1 + t3*v3 = p2+t2*v2 => p2-p1 = t1*v1 - t2*v2 + t3*v3

q1 = xa0 + (coefs[:, 0:1] * v1)  # closest point on v1
loc = q1 + (coefs[:, 2:3] * v3 / 2)  # Move halfway along v3 from q1

In [54]:
hit_unc.shape

torch.Size([83, 4, 3])

In [55]:
grad(xa0[0,0], volume.get_detectors()[0].resolution, retain_graph=True)[0].sum()

tensor(-2.9437e-07)

In [56]:
grad(loc[0,0], volume.get_detectors()[0].resolution, retain_graph=True)[0].sum()

tensor(-1.7415e-06)

In [57]:
grad(loc[0,0], xa0, retain_graph=True, allow_unused=True)[0].sum(0)

tensor([ 7.0401, -0.6467, -0.1099])

In [58]:
from tomopt.utils import jacobian

In [59]:
jacobian(loc[0], xa0).sum(1)

tensor([[ 7.0401e+00, -6.4668e-01, -1.0990e-01],
        [-9.0150e+00,  3.1523e+00,  2.0071e-01],
        [ 2.5401e+02, -3.2541e+01, -4.2027e+00]])

In [61]:
dloc_dx.shape, dloc_dx

(torch.Size([83, 3, 3]),
 tensor([[[ 7.0401e+00, -6.4668e-01, -1.0990e-01],
          [-9.0150e+00,  3.1523e+00,  2.0071e-01],
          [ 2.5401e+02, -3.2541e+01, -4.2027e+00]],
 
         [[-1.2281e+01,  5.5989e+00, -1.1548e+00],
          [ 3.0379e+01, -1.7975e+01,  3.3918e+00],
          [ 2.3945e+02, -1.2764e+02,  2.4913e+01]],
 
         [[-7.1015e+00,  3.6663e+00, -3.9459e-01],
          [-5.5972e+00,  2.2908e+00, -4.2108e-01],
          [ 3.3749e+01, -2.1001e+01,  1.2176e+00]],
 
         [[ 2.2093e+01,  5.2995e+00, -2.4290e+00],
          [-8.2092e+00, -4.5888e+00,  7.9899e-01],
          [ 2.3451e+02,  4.9855e+01, -2.6037e+01]],
 
         [[-1.5411e+01,  3.5093e+01, -3.4874e+00],
          [ 1.3775e+00, -4.4802e+00,  3.3421e-01],
          [ 7.2342e+01, -1.7745e+02,  1.6584e+01]],
 
         [[ 6.2211e+00,  1.5844e+01,  9.5286e-01],
          [ 8.9713e+00,  1.3800e+01,  1.0005e+00],
          [-2.2369e+02, -4.0642e+02, -2.7520e+01]],
 
         [[-9.5713e+00, -1.1150e-01, -6

In [62]:
hit_unc.shape

torch.Size([83, 4, 3])

In [63]:
loc_unc = dloc_dx

In [64]:
dloc_dx.shape, hit_unc[:,0][:,:,None].shape

(torch.Size([83, 3, 3]), torch.Size([83, 3, 1]))

In [65]:
loc_unc = ((dloc_dx*hit_unc[:,0][:,:,None])**2).sum(1)

In [66]:
loc_unc = torch.sqrt(loc_unc)

In [67]:
jacobian(loc_unc, dets[0].resolution).sum((-1, -2))[:10]

tensor([[-1.1438e-05, -3.2179e-06, -2.2882e-07],
        [-3.2768e-05, -1.8826e-05, -3.5830e-06],
        [-9.0421e-06, -4.3231e-06, -5.7707e-07],
        [-2.3569e-05, -7.0102e-06, -2.5570e-06],
        [-1.5472e-05, -3.5377e-05, -3.5034e-06],
        [-1.0917e-05, -2.1011e-05, -1.3816e-06],
        [-1.4951e-05, -1.7285e-06, -1.2190e-06],
        [-1.4833e-05, -1.3840e-05, -2.9136e-06],
        [-2.3187e-05, -2.6574e-06, -4.1735e-06],
        [-8.0298e-07, -1.7680e-06, -4.5084e-08]])

## Check cross-terms

In [89]:
unc2_sum = None
for i, xi in enumerate([xa0, xa1, xb0, xb1]):
    for j, xj in enumerate([xa0, xa1, xb0, xb1]):
        if j < i: continue
        dloc_dx_2 = jacobian(loc, xi).sum((2))*jacobian(loc, xj).sum((2)) if i != j else jacobian(loc, xi).sum((2))**2
        unc_2 = (dloc_dx_2*hit_unc[:,i][:,:,None]*hit_unc[:,j][:,:,None]).sum(1)  # Muons, (x,y,z)
        if unc2_sum is None:
            unc2_sum = unc_2
        else:
            unc2_sum = unc2_sum + unc_2
loc_unc = torch.sqrt(unc2_sum)

In [91]:
loc_unc[:10]

tensor([[0.0191, 0.0080, 0.0007],
        [0.0408, 0.0398, 0.0050],
        [0.0179, 0.0158, 0.0009],
        [0.0256, 0.0172, 0.0029],
        [0.0603, 0.0407, 0.0122],
        [0.0168, 0.0283, 0.0015],
        [0.0250, 0.0064, 0.0018],
        [0.0232, 0.0152, 0.0052],
        [0.0270, 0.0031, 0.0048],
        [0.0018, 0.0128, 0.0002]], grad_fn=<SliceBackward>)

In [92]:
jacobian(loc_unc, dets[0].resolution).sum((-1, -2))[:10]

tensor([[-3.6153e-06, -8.4130e-07, -3.5390e-08],
        [-1.3012e-05, -4.2293e-06, -1.2901e-06],
        [-2.0833e-06, -6.6369e-07, -1.8301e-07],
        [-1.1279e-05, -1.2953e-06, -1.1254e-06],
        [-1.8583e-06, -1.5323e-05, -5.0206e-07],
        [-3.7271e-06, -8.0469e-06, -6.3470e-07],
        [-4.2748e-06, -9.8556e-08, -4.1889e-07],
        [-4.4316e-06, -6.2592e-06, -8.0927e-07],
        [-1.0209e-05, -7.0036e-07, -1.8227e-06],
        [ 4.0789e-08, -1.4331e-07, -4.6976e-09]])