In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from typing import *
import numpy as np
from functools import partial
from fastprogress import progress_bar

import torch
from torch import Tensor, nn
import torch.nn.functional as F

from tomopt.volume import *
from tomopt.muon import *
from tomopt.inference import *
from tomopt.optimisation import *
from tomopt.core import *
from tomopt.utils import *
from tomopt.plotting import *
from tomopt.benchmarks.phi_detector import *

import seaborn as sns

In [3]:
DEVICE = torch.device("cpu")

In [4]:
def get_layers():
    layers = []
    lwh = Tensor([1,1,1])
    size = 0.1
    init_eff = 0.8
    init_res = 1000
    n_panels = 4
    layers.append(PanelDetectorLayer(pos='above', lw=lwh[:2], z=1, size=2*size,
                                     panels=[PhiDetectorPanel(init_phi=torch.rand(1)*2*torch.pi, res=init_res, eff=init_eff,
                                                              init_z=1-(i*(2*size)/n_panels), device=DEVICE) for i in range(n_panels)]))
    for z in [0.8,0.7,0.6,0.5,0.4,0.3]:
        layers.append(PassiveLayer(lw=lwh[:2], z=z, size=size, device=DEVICE))
    layers.append(PanelDetectorLayer(pos='below', lw=lwh[:2], z=0.2, size=2*size,
                                     panels=[PhiDetectorPanel(init_phi=torch.rand(1)*2*torch.pi, res=init_res, eff=init_eff,
                                                      init_z=0.2-(i*(2*size)/n_panels), device=DEVICE) for i in range(n_panels)]))

    return nn.ModuleList(layers) 

In [5]:
volume = Volume(get_layers())

In [6]:
gen = MuonGenerator2016.from_volume(volume, fixed_mom=None)

In [7]:
mu = MuonBatch(MuonResampler().resample(gen(10), volume, gen), init_z=1)

In [8]:
h = volume.get_detectors()[0].panels[0].get_hits(mu)

In [9]:
jacobian(h['reco_h'], volume.get_detectors()[0].panels[0].phi)

tensor([[ 0.1744],
        [ 0.1790],
        [ 0.1156],
        [ 0.0766],
        [-0.0966],
        [ 0.2000],
        [ 0.7836],
        [-0.1940],
        [ 0.5907],
        [ 0.5078]])

In [10]:
mu._hits

defaultdict(<function tomopt.muon.muon_batch.MuonBatch.__init__.<locals>.<lambda>()>,
            {})

In [11]:
mu.append_hits(volume.get_detectors()[0].panels[0].get_hits(mu), 'above')

In [12]:
h = mu._hits['above']['reco_h'];h

[tensor([[-0.2639],
         [-0.6542],
         [-0.8418],
         [-0.2200],
         [-0.3870],
         [-1.1488],
         [-0.7932],
         [-0.8816],
         [-0.7929],
         [-0.0874]], grad_fn=<AddBackward0>)]

In [13]:
h[0]

tensor([[-0.2639],
        [-0.6542],
        [-0.8418],
        [-0.2200],
        [-0.3870],
        [-1.1488],
        [-0.7932],
        [-0.8816],
        [-0.7929],
        [-0.0874]], grad_fn=<AddBackward0>)

In [14]:
jacobian(h[0], volume.get_detectors()[0].panels[0].phi)

tensor([[ 0.1744],
        [ 0.1790],
        [ 0.1156],
        [ 0.0766],
        [-0.0966],
        [ 0.2000],
        [ 0.7836],
        [-0.1940],
        [ 0.5907],
        [ 0.5078]])

In [15]:
volume.get_detectors()[0].scatter_and_propagate(mu, 0.1)

In [16]:
jacobian(h[0], volume.get_detectors()[0].panels[0].phi)

tensor([[ 0.1744],
        [ 0.1790],
        [ 0.1156],
        [ 0.0766],
        [-0.0966],
        [ 0.2000],
        [ 0.7836],
        [-0.1940],
        [ 0.5907],
        [ 0.5078]])

In [17]:
mu.append_hits(volume.get_detectors()[0].panels[1].get_hits(mu), 'above')

In [18]:
h = mu._hits['above']['reco_h'];h

[tensor([[-0.2639],
         [-0.6542],
         [-0.8418],
         [-0.2200],
         [-0.3870],
         [-1.1488],
         [-0.7932],
         [-0.8816],
         [-0.7929],
         [-0.0874]], grad_fn=<AddBackward0>),
 tensor([[-0.2238],
         [-0.1954],
         [-0.1823],
         [-0.0754],
         [ 0.0678],
         [-0.2371],
         [-0.7031],
         [ 0.1537],
         [-0.5588],
         [-0.4786]], grad_fn=<AddBackward0>)]

In [19]:
jacobian(h[0], volume.get_detectors()[0].panels[0].phi)

tensor([[ 0.1744],
        [ 0.1790],
        [ 0.1156],
        [ 0.0766],
        [-0.0966],
        [ 0.2000],
        [ 0.7836],
        [-0.1940],
        [ 0.5907],
        [ 0.5078]])

In [20]:
mu = MuonBatch(MuonResampler().resample(gen(10), volume, gen), init_z=1)

In [21]:
mu._hits

defaultdict(<function tomopt.muon.muon_batch.MuonBatch.__init__.<locals>.<lambda>()>,
            {})

In [22]:
volume.get_detectors()[0](mu)

In [23]:
mu._hits

defaultdict(<function tomopt.muon.muon_batch.MuonBatch.__init__.<locals>.<lambda>()>,
            {'above': defaultdict(list,
                         {'reco_h': [tensor([[-0.3784],
                                   [-0.8630],
                                   [-0.1300],
                                   [-0.4498],
                                   [-0.1434],
                                   [ 0.0579],
                                   [-0.4163],
                                   [-0.0141],
                                   [-0.1421],
                                   [-1.1306]], grad_fn=<AddBackward0>),
                           tensor([[-0.7346],
                                   [-0.9334],
                                   [-0.4070],
                                   [-0.8207],
                                   [-0.6310],
                                   [-0.2539],
                                   [-0.0518],
                                   [-0.9322],
          

In [24]:
h = mu._hits['above']['reco_h'];h

[tensor([[-0.3784],
         [-0.8630],
         [-0.1300],
         [-0.4498],
         [-0.1434],
         [ 0.0579],
         [-0.4163],
         [-0.0141],
         [-0.1421],
         [-1.1306]], grad_fn=<AddBackward0>),
 tensor([[-0.7346],
         [-0.9334],
         [-0.4070],
         [-0.8207],
         [-0.6310],
         [-0.2539],
         [-0.0518],
         [-0.9322],
         [-0.4203],
         [-0.5154]], grad_fn=<AddBackward0>),
 tensor([[-0.7719],
         [-0.9949],
         [-0.4090],
         [-0.7875],
         [-0.6817],
         [-0.2599],
         [-0.0902],
         [-0.8872],
         [-0.4068],
         [-0.5904]], grad_fn=<AddBackward0>),
 tensor([[-0.5202],
         [-0.5840],
         [-0.3481],
         [-0.5349],
         [-0.4805],
         [-0.2569],
         [ 0.0742],
         [-0.7514],
         [-0.2839],
         [-0.1480]], grad_fn=<AddBackward0>)]

In [25]:
jacobian(h[0], volume.get_detectors()[0].panels[0].phi)

tensor([[0.7350],
        [0.9321],
        [0.4116],
        [0.8822],
        [0.6060],
        [0.2480],
        [0.0434],
        [0.9785],
        [0.4458],
        [0.5156]])

In [26]:
def arb_rad_length(*,z:float, lw:Tensor, size:float) -> Tensor:
    rad_length = torch.ones(list((lw/size).long()))*X0['beryllium']
    if z >= 0.4 and z <= 0.5: rad_length[5:,5:] = X0['lead']
    return rad_length

In [27]:
volume.load_rad_length(arb_rad_length)

In [65]:
mu = MuonBatch(MuonResampler().resample(gen(100), volume, gen), init_z=1)
volume(mu)

In [66]:
hits = mu.get_hits()

In [67]:
mu.x.shape

torch.Size([100])

In [68]:
hits['above']['reco_h'].shape

torch.Size([100, 4, 1])

In [69]:
from tomopt.inference.scattering import AbsScatterBatch

class PhiDetScatterBatch(AbsScatterBatch):
    def _extract_hits(self) -> None:
        # reco x, reco y, gen z, must be a list to allow computation of uncertainty
        above_hits = torch.stack(
            [
                torch.cat([self.hits["above"]["reco_h"][:, i], self.hits["above"]["phi"][:, i], self.hits["above"]["z"][:, i]], dim=-1)
                for i in range(self.hits["above"]["reco_h"].shape[1])
            ],
            dim=1,
        )  # muons, panels, xyz
        below_hits = torch.stack(
            [
                torch.cat([self.hits["below"]["reco_h"][:, i], self.hits["below"]["phi"][:, i], self.hits["below"]["z"][:, i]], dim=-1)
                for i in range(self.hits["below"]["reco_h"].shape[1])
            ],
            dim=1,
        )
        _above_gen_hits = torch.stack(
            [torch.cat([self.hits["above"]["gen_xy"][:, i], self.hits["above"]["z"][:, i]], dim=-1) for i in range(self.hits["above"]["gen_xy"].shape[1])],
            dim=1,
        )  # muons, panels, xyz
        _below_gen_hits = torch.stack(
            [torch.cat([self.hits["below"]["gen_xy"][:, i], self.hits["below"]["z"][:, i]], dim=-1) for i in range(self.hits["below"]["gen_xy"].shape[1])],
            dim=1,
        )
        self._n_hits_above = above_hits.shape[1]
        self._n_hits_below = below_hits.shape[1]

        # Combine all input vars into single tensor, NB ideally would stack to new dim but can't assume same number of panels above & below
        self._reco_hits = torch.cat((above_hits, below_hits), dim=1)  # muons, all panels, reco h,phi,z
        self._gen_hits = torch.cat((_above_gen_hits, _below_gen_hits), dim=1)  # muons, all panels, true xyz

    def plot_scatter(self, idx: int) -> None:
        raise NotImplementedError("Ah, I see you've just volunteered to implement this!")

    @staticmethod
    def _get_hit_uncs(zordered_panels: List[PhiDetectorPanel], hits: Tensor) -> Tensor:
        uncs: List[Tensor] = []
        for l, h in zip(zordered_panels, hits.unbind(1)):
            r = 1 / l.resolution.expand(len(h), 1)
            uncs.append(torch.cat([r, torch.zeros((len(h), 2), device=r.device)], dim=-1))
        return torch.stack(uncs, dim=1)  # muons, panels, unc h,phi,z, zero unc for phi and z
    

    def _compute_tracks(self) -> None:
        def _get_panels() -> List[PhiDetectorPanel]:
            panels = []
            for det in self.volume.get_detectors():
                if not isinstance(det, PanelDetectorLayer):
                    raise ValueError(f"Detector {det} is not a PanelDetectorLayer")
                panels += [det.panels[j] for j in det.get_panel_zorder()]
            return panels

        self._hit_uncs = self._get_hit_uncs(_get_panels(), self.gen_hits)
#         self._track_in, self._track_start_in = self.get_muon_trajectory(self.above_hits, self.above_hit_uncs, self.volume.lw)
#         self._track_out, self._track_start_out = self.get_muon_trajectory(self.below_hits, self.below_hit_uncs, self.volume.lw)

    def _compute_scatters(self) -> None:
        r"""
        Currently only handles detectors above and below passive volume

        Scatter locations adapted from:
        @MISC {3334866,
            TITLE = {Closest points between two lines},
            AUTHOR = {Brian (https://math.stackexchange.com/users/72614/brian)},
            HOWPUBLISHED = {Mathematics Stack Exchange},
            NOTE = {URL:https://math.stackexchange.com/q/3334866 (version: 2019-08-26)},
            EPRINT = {https://math.stackexchange.com/q/3334866},
            URL = {https://math.stackexchange.com/q/3334866}
        }
        """

        self._extract_hits()
        self._compute_tracks()
#         self._filter_scatters()

#         # Track computations
#         self._cross_track = torch.cross(self.track_in, self.track_out, dim=1)  # connecting vector perpendicular to both lines

#         rhs = self.track_start_out - self.track_start_in
#         lhs = torch.stack([self.track_in, -self.track_out, self._cross_track], dim=1).transpose(2, 1)
#         # coefs = torch.linalg.solve(lhs, rhs)  # solve p1+t1*v1 + t3*v3 = p2+t2*v2 => p2-p1 = t1*v1 - t2*v2 + t3*v3
#         self._track_coefs = (lhs.inverse() @ rhs[:, :, None]).squeeze(-1)

#     @staticmethod
#     def get_muon_trajectory(hits: Tensor, uncs: Tensor, lw: Tensor) -> Tuple[Tensor, Tensor]:
#         r"""
#         hits = (muons,panels,(h,phi,z))
#         uncs = (muons,panels,(unc,0,0))

#         Assumes no uncertainty for z and phi

#         Uses an analytic likelihood-maximisation: L = \prod_i G[h_i - h_{i,opt}(x0,y0,theta,phi)]
#         where:
#         h_{i,opt} = x_{i,opt} cos(phi_i) + y_{i,opt} sin(phi_i)
#         x_{i,opt} = x0 + z_i tan(theta) cos(phi)
#         y_{i,opt} = y0 + z_i tan(theta) sin(phi)

#         x0, y0, z0 = track coordinates at z=0
#         theta, phi = track angles

#         x0 = sum_i[ ((cos(phi_i)/(unc_i^2))*(h_i-(z_i*tan(theta)*cos(phi)*cos(phi_i))-(sin(phi_i)*(y0+(z_i*tan(theta)*sin(phi))))) ] / sum_i[ ((cos(phi_i)^2)/(unc_i^2)) ]
#         y0 = sum_i[ ((sin(phi_i)/(unc_i^2))*(h_i-(z_i*tan(theta)*sin(phi)*cos(phi_i))-(cos(phi_i)*(x0+(z_i*tan(theta)*sin(phi))))) ] / sun_i[ ((sin(phi_i)^2)/(unc_i^2)) ]
#         theta = tan^-1[ sum_i[ (z_i/(unc_i^2))*((h_i*((cos(phi)*cos(phi_i))+(sin(phi)*sin(phi_i))))-(x0*((cos(phi_i)*sin(phi_i)*sin(phi))+(cos(phi)*cos(phi_i)*cos(phi_i))))-(y0((cos(phi)*cos(phi_i)*sin(phi_i))+(sin(phi)*sin(phi_i)*sin(phi_i))))) ] / sum_i[ ((z_i^2)/(unc_i^2))*((cos(phi)*cos(phi_i))+(sin(phi)*sin(sin(phi_i)))) ] ]
#         phi = sin^-1[ ((z_i*sin(phi_i))/(unc_i^2))*(h_i-(y0*sin(phi_i))-(x0*cos(phi_i))-(z_i*tan(theta)*cos(phi)*cos(phi_i))) ] / sum_i[ (((z_i*sin(phi_i))^2)/(unc_i^2))*tan(theta) ]

#         In eval mode:
#             Muons with <2 hits within panels have NaN trajectory.
#             Muons with >=2 hits in panels have valid trajectories
#         """

#         hits = torch.where(torch.isinf(hits), lw.mean().type(hits.type()) / 2, hits)
#         uncs = torch.nan_to_num(uncs)  # Set Infs to large number

#         x0: Tensor
#         y0: Tensor  # Track positions at Z0=0
#         phi: Tensor
#         theta: Tensor  # Track angles

In [70]:
sb = PhiDetScatterBatch(mu, volume)

In [71]:
sb.above_hits.shape, sb.above_hit_uncs.shape

(torch.Size([100, 4, 3]), torch.Size([100, 4, 3]))

In [72]:
volume.get_detectors()[0].panels[0].phi

Parameter containing:
tensor(4.6192, requires_grad=True)

In [73]:
jacobian(sb.above_hits, volume.get_detectors()[0].panels[0].phi)

tensor([[[0.5373, 1.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[1.1489, 1.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.5100, 1.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        ...,

        [[0.1922, 1.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.8471, 1.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.7548, 1.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

In [74]:
hits = sb.above_hits
uncs = sb.above_hit_uncs

In [75]:
hits = torch.where(torch.isinf(hits), volume.lw.mean().type(hits.type()) / 2, hits)
uncs = torch.nan_to_num(uncs)  # Set Infs to large number

In [76]:
uncs2 = uncs[:,:,0:1]**2

In [77]:
track = torch.ones(len(hits),1,4, device=hits.device, requires_grad=True)/2

In [78]:
def get_nll(hits:Tensor, uncs2:Tensor, track:Tensor) -> Tensor:
    x0 = track[:,:,0:1]
    y0 = track[:,:,1:2]
    phi = track[:,:,2:3]
    theta = track[:,:,3:4]
    h_i = hits[:,:,0:1]
    phi_i = hits[:,:,1:2]
    z_i = hits[:,:,2:3]

    x_i_opt = x0 + (z_i*theta.tan()*phi.cos())
    y_i_opt = y0 + (z_i*theta.tan()*phi.sin())
    h_i_opt = (x_i_opt*phi_i.cos()) + (y_i_opt*phi_i.sin())  

    return (0.5*((h_i-h_i_opt).square()/uncs2).sum(1))-(1/(2*torch.pi*uncs2).sqrt().sum(1))

In [79]:
nll = get_nll(hits, uncs2, track)
nll.shape

torch.Size([100, 1])

In [80]:
grad = jacobian(nll, track, create_graph=True).sum((2,3))

In [46]:
grad.shape

torch.Size([10, 1, 4])

In [47]:
grad[0]

tensor([[  45554.0938, -106952.0938,  -60050.1172,  -20507.5391]],
       grad_fn=<SelectBackward0>)

In [48]:
hesse = jacobian(grad, track, create_graph=True).sum((3,4))

In [49]:
hesse.shape

torch.Size([10, 1, 4, 4])

In [50]:
(grad@hesse.inverse()).shape

torch.Size([10, 10, 1, 4])

In [51]:
(grad.unsqueeze(1)@hesse.inverse()).squeeze(1).shape

torch.Size([10, 1, 4])

In [52]:
step = (grad.unsqueeze(1)@hesse.inverse()).squeeze(1)

In [53]:
step.mean(0)

tensor([[4.0454e-01, 2.7361e-01, 2.1547e-04, 6.3728e-04]],
       grad_fn=<MeanBackward1>)

In [54]:
lr = Tensor([[[0.1,0.1,100,10]]])

In [55]:
lr*step.mean(0)

tensor([[[0.0405, 0.0274, 0.0215, 0.0064]]], grad_fn=<MulBackward0>)

In [81]:
lr = Tensor([[[0.1,0.1,100,10]]])
track = torch.ones(len(hits),1,4, device=hits.device, requires_grad=True)/2

for i in progress_bar(range(100)):  # Newton optimise nuisances
    nll = get_nll(hits, uncs2, track)
    grad = jacobian(nll, track, create_graph=True).sum((2,3))
    hesse = jacobian(grad, track, create_graph=True).sum((3,4))
    step = lr*(grad.unsqueeze(1)@hesse.inverse()).squeeze(1)
    track = track-step
    if i > 0 and i % 10 == 0:
            lr = lr / 2
    print(nll.mean(0).data, track.mean(0).data)

tensor([552313.7500]) tensor([[0.4526, 0.4739, 0.4703, 0.4935]])
tensor([459080.0938]) tensor([[0.4125, 0.4534, 0.4413, 0.4841]])
tensor([366024.0938]) tensor([[0.3781, 0.4370, 0.3915, 0.4724]])
tensor([299159.9688]) tensor([[0.3505, 0.4242, 0.3703, 0.4530]])
tensor([254436.2812]) tensor([[0.3306, 0.4144, 0.3774, 0.4344]])
tensor([206660.5156]) tensor([[0.3156, 0.4069, 0.3972, 0.4153]])
tensor([175511.4062]) tensor([[0.3051, 0.4011, 0.4483, 0.3922]])
tensor([132246.6562]) tensor([[0.2990, 0.3972, 0.5605, 0.3667]])
tensor([103370.5234]) tensor([[0.2959, 0.3939, 0.5222, 0.3501]])
tensor([87922.6016]) tensor([[0.2956, 0.3934, 0.6164, 0.3302]])
tensor([89197.1016]) tensor([[0.2984, 0.3923, 0.7041, 0.3203]])
tensor([79813.7188]) tensor([[0.3012, 0.3917, 0.7114, 0.3158]])
tensor([73537.2500]) tensor([[0.3044, 0.3911, 0.7800, 0.3090]])
tensor([63532.7539]) tensor([[0.3075, 0.3907, 0.8599, 0.2989]])
tensor([63752.5508]) tensor([[0.3112, 0.3904, 0.9420, 0.2865]])
tensor([71545.9375]) tensor([[0

KeyboardInterrupt: 

In [64]:
%%time
jac = jacobian(track, volume.get_detectors()[0].panels[0].phi)

CPU times: user 7.71 s, sys: 931 ms, total: 8.65 s
Wall time: 4.38 s


In [62]:
jac.shape

torch.Size([10, 1, 4])