In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from tomopt.muon import *
from tomopt.inference import *
from tomopt.volume import *
from tomopt.core import *
from tomopt.optimisation import *

import matplotlib.pyplot as plt
import seaborn as sns
from typing import *
import numpy as np

import torch
from torch import Tensor, nn
import torch.nn.functional as F

In [3]:
import torch
from torch import nn, Tensor
from torch.autograd import grad

In [4]:
x = nn.Parameter(Tensor([1]))

In [5]:
y = 3*(x**2)

In [6]:
grad(y, x, retain_graph=True)

(tensor([6.]),)

In [7]:
z = 2*(y**3)

In [8]:
grad(z, y, retain_graph=True)

(tensor([54.]),)

In [9]:
6*(y**2)

tensor([54.], grad_fn=<MulBackward0>)

In [10]:
def arb_rad_length(*, z: float, lw: Tensor, size: float, tol: float = 1e-6) -> Tensor:
    r"""Toy radiation-length map for one passive layer.

    Every voxel is beryllium. Layers whose position `z` falls inside the inclusive
    window [0.4, 0.5] additionally get lead in the `[5:, 5:]` corner of the voxel grid.

    Arguments:
        z: layer position (presumably supplied by `PassiveLayer`, which receives this
            function as `rad_length_func` in `get_layers`)
        lw: length and width of the layer
        size: voxel size; the grid has `round(lw / size)` voxels per side
        tol: absolute slack applied to both window bounds. Layer positions generated
            with a float step (e.g. `np.arange(1, 0, -0.1)`) can land a hair outside
            their nominal value, which would otherwise silently drop a boundary layer.

    Returns:
        Tensor of radiation lengths, one entry per voxel
    """
    # Round before casting: `.long()` truncates, so 9.9999... voxels would become 9.
    n_voxels = (lw / size).round().long()
    rad_length = torch.ones(n_voxels.tolist()) * X0['beryllium']
    if 0.4 - tol <= z <= 0.5 + tol:
        rad_length[5:, 5:] = X0['lead']
    return rad_length

In [11]:
def eff_cost(x: Tensor) -> Tensor:
    r"""Element-wise efficiency cost: exp(3 * max(x, 0)) - 1.

    Non-positive inputs cost nothing; positive inputs grow exponentially.
    `expm1` keeps the result accurate for small positive inputs.
    """
    clipped = F.relu(x)
    return torch.expm1(3 * clipped)

In [12]:
def res_cost(x: Tensor) -> Tensor:
    r"""Element-wise resolution cost: (max(x / 100, 0)) ** 2.

    Non-positive inputs cost nothing; positive inputs are scaled down by 100
    and then squared.
    """
    scaled = F.relu(x / 100)
    return scaled ** 2

In [13]:
def get_layers(lwh: Optional[Tensor] = None, size: float = 0.1, init_eff: float = 0.5, init_res: float = 1000,
               layout: Sequence[int] = (1, 1, 0, 0, 0, 0, 0, 0, 1, 1)) -> nn.ModuleList:
    r"""Build the stack of detector and passive layers for the toy volume.

    Layers are created top-down: positions start at `lwh[2]` and step down by `size`.
    Each entry of `layout` picks the layer type: truthy -> `DetectorLayer`,
    falsy -> `PassiveLayer` (radiation lengths from `arb_rad_length`).
    Detectors created before the first passive layer get `pos='above'`; every
    detector created after it gets `pos='below'`.

    Arguments:
        lwh: length, width, and height of the volume; defaults to `Tensor([1, 1, 1])`
        size: used both as the z step between layers and as each layer's `size` argument
        init_eff: initial efficiency passed to every detector layer
        init_res: initial resolution passed to every detector layer
        layout: layer-type flags, top layer first

    Returns:
        `nn.ModuleList` of the layers, top layer first
    """
    if lwh is None:
        lwh = Tensor([1, 1, 1])
    # The float-step arange is kept deliberately so layer z positions are bit-identical to the
    # original notebook (downstream z-window checks are sensitive to the last ulp).
    # zip() stops at whichever of z_positions / layout is shorter.
    z_positions = np.arange(lwh[2], 0, -size)
    layers = []
    pos = 'above'
    for z, is_detector in zip(z_positions, layout):
        if is_detector:
            layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=init_res,
                                        lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))
        else:
            pos = 'below'  # from here on, detectors are labelled as below the passive layers
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))

    return nn.ModuleList(layers)

In [14]:
volume = Volume(get_layers())

In [15]:
muons = MuonBatch(generate_batch(100), 1)

In [16]:
volume(muons)

In [17]:
hits = muons.get_hits(volume.lw)

In [18]:
# Each hit becomes (reco x, reco y, gen z). The "below" hits are taken in reverse index order
# (1 then 0), matching the (0, 1, 1, 0) order used for the uncertainties below.
xa0 = torch.cat([hits["above"]["xy"][:, 0], hits["above"]["z"][:, 0]], dim=-1)  # reco x, reco y, gen z
xa1 = torch.cat([hits["above"]["xy"][:, 1], hits["above"]["z"][:, 1]], dim=-1)
xb0 = torch.cat([hits["below"]["xy"][:, 1], hits["below"]["z"][:, 1]], dim=-1)
xb1 = torch.cat([hits["below"]["xy"][:, 0], hits["below"]["z"][:, 0]], dim=-1)

# Per-hit uncertainty: 1/resolution of the hit voxel for x and y; z is treated as exact (zero).
# Loop variables have descriptive names so they don't clobber notebook globals
# (the old loop variable `x` overwrote the `nn.Parameter` x defined earlier in the notebook).
dets = volume.get_detectors()
res = []
for pos, det, hit_idx in zip(("above", "above", "below", "below"), dets, (0, 1, 1, 0)):
    vox_idx = det.abs2idx(hits[pos]["xy"][:, hit_idx])
    xy_unc = 1/det.resolution[vox_idx[:, 0], vox_idx[:, 1]]
    res.append(torch.stack([xy_unc, xy_unc, torch.zeros_like(xy_unc)], dim=-1))
hit_unc = torch.stack(res, dim=1)  # (muons, hit, xyz), hits ordered xa0, xa1, xb0, xb1

# Muon-path direction vectors from each pair of hits
v1 = xa1 - xa0  # track through the "above" hits
v2 = xb1 - xb0  # track through the "below" hits

# Scatter location = point of closest approach of the two tracks.
# Solve p1 + t1*v1 + t3*v3 = p2 + t2*v2  =>  p2 - p1 = t1*v1 - t2*v2 + t3*v3
v3 = torch.cross(v1, v2, dim=1)  # connecting vector perpendicular to both lines
rhs = xb0 - xa0
lhs = torch.stack([v1, -v2, v3], dim=1).transpose(2, 1)  # columns: v1, -v2, v3
coefs = torch.linalg.solve(lhs, rhs)  # per-muon (t1, t2, t3)

q1 = xa0 + (coefs[:, 0:1] * v1)  # closest point on v1
loc = q1 + (coefs[:, 2:3] * v3 / 2)  # Move halfway along v3 from q1

# Theta deviations: per-axis track angle arctan(d{x,y} / dz) of each track
_theta_in = torch.arctan(v1[:, :2] / v1[:, 2:3])
_theta_out = torch.arctan(v2[:, :2] / v2[:, 2:3])
_dtheta = torch.abs(_theta_in - _theta_out)

# xy deviations: xy part of t3*v3, the vector between the two closest points
_dxy = coefs[:, 2:3] * v3[:, :2]

In [19]:
hit_unc.shape

torch.Size([80, 4, 3])

In [20]:
grad(xa0[0,0], volume.get_detectors()[0].resolution, retain_graph=True)[0].sum()

tensor(-4.5297e-07)

In [21]:
grad(loc[0,0], volume.get_detectors()[0].resolution, retain_graph=True)[0].sum()

tensor(-2.6262e-06)

In [22]:
grad(loc[0,0], xa0, retain_graph=True, allow_unused=True)[0].sum(0)

tensor([-1.1140,  4.1284,  0.5887])

In [23]:
from tomopt.utils import jacobian

In [24]:
jacobian(loc[0], xa0).sum(1)

tensor([[ -1.1140,   4.1284,   0.5887],
        [ -2.8012, -20.4058,  -2.7180],
        [ 18.4357, 122.8164,  16.3223]])

In [25]:
hit_unc.shape

torch.Size([80, 4, 3])

## Scatter-location uncertainty

In [143]:
def compute_unc(var: Tensor, hits: List[Tensor], hit_uncs: List[Tensor], cross_weight: float = 1.0) -> Tensor:
    r"""Propagate per-hit uncertainties to a derived quantity with first-order error propagation.

    For every component of `var` the squared uncertainty is accumulated as

        sum_i sum_k (dv/dx_ik * unc_ik)**2
        + cross_weight * sum_{i<j} sum_k (dv/dx_ik * unc_ik) * (dv/dx_jk * unc_jk)

    where i, j run over the hits and k over each hit's coordinates.

    NOTE(review): with the default `cross_weight=1.0` every hit pair enters once. That matches neither
    independent hits (the usual assumption, `cross_weight=0`) nor fully correlated hits
    (`cross_weight=2`, since the full double sum counts each pair twice). The default is kept so the
    values recorded in this notebook stay reproducible; confirm which assumption is intended.

    Arguments:
        var: derived quantity, one row per muon; must be differentiable w.r.t. every tensor in `hits`
        hits: hit tensors `var` was computed from, one row per muon and one column per coordinate
        hit_uncs: uncertainty of each hit, same order and shape as `hits`
        cross_weight: weight applied to the off-diagonal (hit-pair) terms; 0 skips them entirely

    Returns:
        Uncertainty on `var`, same shape as `var`

    Raises:
        ValueError: if `hits` is empty or `hits` and `hit_uncs` differ in length
    """
    if len(hits) == 0:
        raise ValueError("compute_unc requires at least one hit")
    if len(hits) != len(hit_uncs):
        raise ValueError(f"got {len(hits)} hits but {len(hit_uncs)} hit uncertainties")

    # One Jacobian per hit, computed once instead of once (or twice) per hit pair.
    # Assumes jacobian(var, x) is (muons, var, muons, xyz); summing dim 2 collapses the second
    # muon axis, leaving (muons, var, hit_xyz) -- TODO confirm against tomopt.utils.jacobian.
    jacs = [jacobian(var, x).sum(2) for x in hits]

    unc2_sum = None
    n_hits = len(jacs)
    # Same (i, j >= i) visiting order as before, so the float accumulation order is unchanged.
    for i in range(n_hits):
        for j in range(i, n_hits):
            if i == j:
                dv_dx_2 = jacs[i]**2  # Muons, var_xyz, hit_xyz
                weight = 1.0
            elif cross_weight == 0:
                continue  # independent hits: pair terms vanish
            else:
                dv_dx_2 = jacs[i]*jacs[j]
                weight = cross_weight
            unc_2 = (dv_dx_2*hit_uncs[i][:, None]*hit_uncs[j][:, None]).sum(2)  # Muons, var components
            if weight != 1.0:
                unc_2 = weight*unc_2
            unc2_sum = unc_2 if unc2_sum is None else unc2_sum + unc_2
    return torch.sqrt(unc2_sum)

In [148]:
# hit_unc's second axis is ordered xa0, xa1, xb0, xb1, matching the hit list
loc_unc = compute_unc(loc, [xa0, xa1, xb0, xb1], list(hit_unc.unbind(dim=1)))

In [149]:
loc_unc[:10]

tensor([[0.0056, 0.0303, 0.1924],
        [0.0583, 0.0431, 2.2424],
        [0.0243, 0.0101, 0.1479],
        [0.0083, 0.0215, 0.1487],
        [0.0705, 0.0170, 0.4718],
        [0.0957, 0.1632, 1.1465],
        [0.0149, 0.0085, 0.2138],
        [0.0120, 0.0163, 0.2313],
        [0.0153, 0.0231, 0.2059],
        [0.0213, 0.0252, 0.6726]], grad_fn=<SliceBackward>)

In [150]:
jacobian(loc_unc, dets[0].resolution).sum((-1, -2))[:10]

tensor([[-1.5340e-06, -6.6594e-06, -4.0085e-05],
        [-4.7773e-06, -6.1904e-06, -2.5527e-04],
        [-6.7754e-06, -2.8755e-06, -4.4612e-05],
        [-5.6997e-08, -5.5747e-09, -4.3923e-07],
        [-2.2656e-05, -5.4023e-06, -1.5331e-04],
        [-3.8480e-05, -6.5496e-05, -4.6314e-04],
        [-1.3767e-06, -4.1200e-07, -1.3547e-05],
        [-3.7930e-06, -4.5509e-06, -6.5607e-05],
        [-1.8023e-06, -2.9229e-06, -3.1267e-05],
        [-6.0952e-06, -7.5207e-06, -1.9299e-04]])

## $\Delta\theta$ uncertainty

In [151]:
xa0.shape

torch.Size([80, 3])

In [152]:
# hit_unc's second axis is ordered xa0, xa1, xb0, xb1, matching the hit list
dtheta_unc = compute_unc(_dtheta, [xa0, xa1, xb0, xb1], list(hit_unc.unbind(dim=1)))

In [153]:
dtheta_unc[:10]

tensor([[0.0141, 0.0138],
        [0.0141, 0.0141],
        [0.0137, 0.0141],
        [0.0141, 0.0139],
        [0.0138, 0.0141],
        [0.0140, 0.0139],
        [0.0141, 0.0141],
        [0.0141, 0.0141],
        [0.0141, 0.0140],
        [0.0141, 0.0141]], grad_fn=<SliceBackward>)

In [154]:
jacobian(dtheta_unc, dets[0].resolution).sum((-1, -2))[:10]

tensor([[-3.5361e-06, -3.4871e-06],
        [-3.5330e-06, -3.5349e-06],
        [-3.3855e-06, -3.5283e-06],
        [-3.5259e-06, -3.4416e-06],
        [-3.4357e-06, -3.5327e-06],
        [-3.5033e-06, -3.4599e-06],
        [-3.5290e-06, -3.5298e-06],
        [-3.5359e-06, -3.5176e-06],
        [-3.5079e-06, -3.4778e-06],
        [-3.5342e-06, -3.5339e-06]])

## Incoming-angle ($\theta_1$, i.e. $\theta_\mathrm{in}$) uncertainty

In [155]:
# the incoming angle only depends on the two "above" hits (first two entries of hit_unc's hit axis)
theta_unc = compute_unc(_theta_in, [xa0, xa1], list(hit_unc.unbind(dim=1)[:2]))

In [156]:
theta_unc[:10]

tensor([[0.0100, 0.0098],
        [0.0100, 0.0100],
        [0.0096, 0.0100],
        [0.0100, 0.0098],
        [0.0097, 0.0100],
        [0.0099, 0.0098],
        [0.0100, 0.0100],
        [0.0100, 0.0099],
        [0.0099, 0.0099],
        [0.0100, 0.0100]], grad_fn=<SliceBackward>)

In [157]:
jacobian(theta_unc, dets[0].resolution).sum((-1, -2))[:10]

tensor([[-4.9973e-06, -4.9087e-06],
        [-4.9967e-06, -4.9985e-06],
        [-4.8220e-06, -4.9870e-06],
        [-4.9878e-06, -4.8979e-06],
        [-4.8721e-06, -4.9956e-06],
        [-4.9585e-06, -4.8955e-06],
        [-4.9802e-06, -4.9927e-06],
        [-4.9948e-06, -4.9749e-06],
        [-4.9674e-06, -4.9295e-06],
        [-4.9971e-06, -4.9953e-06]])

## $\Delta xy$ uncertainty

In [158]:
# hit_unc's second axis is ordered xa0, xa1, xb0, xb1, matching the hit list
dxy_unc = compute_unc(_dxy, [xa0, xa1, xb0, xb1], list(hit_unc.unbind(dim=1)))

In [159]:
dxy_unc[:10]

tensor([[0.0053, 0.0049],
        [0.0094, 0.0116],
        [0.0031, 0.0056],
        [0.0080, 0.0007],
        [0.0061, 0.0060],
        [0.0073, 0.0089],
        [0.0009, 0.0059],
        [0.0006, 0.0058],
        [0.0043, 0.0037],
        [0.0053, 0.0085]], grad_fn=<SliceBackward>)

In [160]:
jacobian(dxy_unc, dets[0].resolution).sum((-1, -2))[:10]

tensor([[-2.0146e-06, -4.3391e-08],
        [-2.6889e-07, -1.9700e-06],
        [-1.2431e-06, -1.4905e-06],
        [-2.0203e-08, -4.0787e-08],
        [-1.8718e-06, -2.0502e-06],
        [-2.6135e-06, -3.8774e-06],
        [-9.7845e-08, -3.6993e-07],
        [-2.4772e-07, -1.6365e-06],
        [-6.0109e-07, -6.3328e-07],
        [-2.6101e-07, -3.2245e-06]])