In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from tomopt.muon import *
from tomopt.inference import *
from tomopt.volume import *
from tomopt.core import *
from tomopt.optimisation import *

import matplotlib.pyplot as plt
import seaborn as sns
from typing import *
import numpy as np

import torch
from torch import Tensor, nn
import torch.nn.functional as F

# Basics

In [3]:
def arb_rad_length(*,z:float, lw:Tensor, size:float) -> Tensor:
    r"""
    Radiation-length map for one passive layer: beryllium everywhere, plus a lead block in the
    high-index corner (``[5:, 5:]``) for layers whose z lies in the inclusive window [0.4, 0.5].

    Arguments:
        z: height of the layer, compared against the [0.4, 0.5] window
        lw: in-plane extent of the layer (two entries)
        size: grid-cell size; ``lw / size`` gives the number of cells along each axis

    Returns:
        2D tensor with one X0 value per grid cell
    """
    # Round before casting: ``.long()`` truncates, so a float ratio like 6.9999999 would otherwise drop a cell
    n_cells = torch.round(lw / size).long()
    rad_length = torch.ones(n_cells.tolist()) * X0['beryllium']
    if 0.4 <= z <= 0.5:
        rad_length[5:, 5:] = X0['lead']
    return rad_length

In [4]:
def eff_cost(x:Tensor) -> Tensor:
    r"""
    Cost associated with a detector efficiency value: ``exp(3 * max(x, 0)) - 1``.

    Non-positive inputs cost nothing; ``expm1`` keeps precision for small positive inputs.
    """
    clipped = F.relu(x)
    return torch.expm1(clipped * 3)

In [5]:
def res_cost(x:Tensor) -> Tensor:
    r"""
    Cost associated with a detector resolution value: ``max(x / 100, 0) ** 2``.
    """
    scaled = x / 100
    return F.relu(scaled).pow(2)

In [6]:
def get_layers(lwh: Optional[Tensor] = None, size: float = 0.1, init_eff: float = 0.5, init_res: float = 1000,
               layout: Sequence[int] = (1, 1, 0, 0, 0, 0, 0, 0, 1, 1)) -> nn.ModuleList:
    r"""
    Build the layer stack of the demo volume, from the top (z = ``lwh[2]``) downwards in steps of ``size``.

    Arguments:
        lwh: volume extent; the first two entries are passed as ``lw`` to every layer, the third is the top z.
            Defaults to ``Tensor([1,1,1])``.
        size: z step between layers, also passed as ``size`` to every layer
        init_eff: initial efficiency of every detector layer
        init_res: resolution scale; each detector layer is initialised with ``z * init_res``
        layout: one flag per layer, top to bottom: truthy -> ``DetectorLayer``, falsy -> ``PassiveLayer``.
            Detectors before the first passive layer get ``pos='above'``, all later ones ``pos='below'``.

    Returns:
        ``nn.ModuleList`` of the layers, ordered top to bottom

    Raises:
        ValueError: if the number of z positions produced does not match ``len(layout)``
    """
    if lwh is None:
        lwh = Tensor([1, 1, 1])

    # np.arange with a float step can gain or lose an element through rounding; zip would then
    # silently truncate either the z positions or the layout, so check the lengths explicitly.
    zs = np.arange(lwh[2], 0, -size)
    if len(zs) != len(layout):
        raise ValueError(f"Got {len(zs)} z positions for a layout of {len(layout)} layers")

    layers = []
    pos = 'above'
    for z, is_detector in zip(zs, layout):
        if is_detector:
            layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=z*init_res,
                                        lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))
        else:
            # Everything after the first passive layer counts as below the passive volume
            pos = 'below'
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))

    return nn.ModuleList(layers)

In [7]:
volume = Volume(get_layers())

# Scatter inference

In [8]:
muons = MuonBatch(generate_batch(100), init_z=1.0)

In [9]:
volume(muons)

In [38]:
for i, l in enumerate(zip(volume.get_detectors(), range(5))): print(i,l)

0 (DetectorLayer(), 0)
1 (DetectorLayer(), 1)
2 (DetectorLayer(), 2)
3 (DetectorLayer(), 3)


In [139]:
list(range(10))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [143]:
list(range(9,-1, -1))

[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

In [162]:
from tomopt.utils import jacobian

class GenScatterBatch(ScatterBatch):
    r"""
    Variant of ``ScatterBatch`` that fits the incoming and outgoing muon trajectories to all hits on each
    side (weighted linear fit) and propagates hit uncertainties to the scatter variables via Jacobians.
    """

    @staticmethod
    def get_muon_trajectory(hits: List[Tensor], uncs: List[Tensor]) -> Tensor:
        r"""
        Fit x and y linearly against z for each muon, weighting every hit by ``unc**-2``, and return the
        direction of the fitted line.

        Arguments:
            hits: one tensor per detector, stacked to [muons,detector,(x,y,z)]
            uncs: one tensor per detector, stacked to [muons,detector,(unc,unc,0)]

        Returns:
            (muons, 3) vector from the fitted point at the first hit's z to the fitted point at the second hit's z

        Assumes the same uncertainty for x and y: only the x uncertainty is used as the weight.
        """

        hits, uncs = torch.stack(hits, dim=1), torch.stack(uncs, dim=1)

        inv_unc2 = uncs[:, :, 0:1] ** -2
        sum_inv_unc2 = inv_unc2.sum(dim=1)
        mean_xyz = torch.sum(hits * inv_unc2, dim=1) / sum_inv_unc2
        mean_xyz_z = torch.sum(hits * hits[:, :, 2:3] * inv_unc2, dim=1) / sum_inv_unc2
        mean_xy = mean_xyz[:, :2]
        mean_z = mean_xyz[:, 2:3]
        mean_xy_z = mean_xyz_z[:, :2]
        mean_z2 = mean_xyz_z[:, 2:3]

        # Weighted least-squares intercept (xy at z=0) and slope (dxy/dz)
        xy_star = (mean_xy - ((mean_z * mean_xy_z) / mean_z2)) / (1 - (mean_z.square() / mean_z2))
        angles = (mean_xy_z - (xy_star * mean_z)) / mean_z2

        def _calc_xyz(z: Tensor) -> Tensor:
            return torch.cat([xy_star + (angles * z), z], dim=-1)

        return _calc_xyz(hits[:, 1, 2:3]) - _calc_xyz(hits[:, 0, 2:3])

    def compute_scatters(self) -> None:
        r"""
        Reconstruct the incoming and outgoing trajectories and derive the scatter variables from them.

        Sets the hit lists and their uncertainties, the scatter location (midpoint of the shortest segment
        joining the two trajectories), theta_in / theta_out / dtheta and dxy. Uncertainties start as ``None``
        and are computed lazily by the ``*_unc`` properties.

        Requires at least two hits above and two below, because the trajectory direction is taken between
        the first two hits' z. Detectors from ``self.volume.get_detectors()`` are paired with hits in order:
        the first ``n_hits_above`` with the above hits, the remainder with the below hits.
        """

        # reco x, reco y, gen z, must be a list to allow computation of uncertainty
        self.above_hits = [torch.cat([self.hits["above"]["xy"][:,i], self.hits["above"]["z"][:,i]], dim=-1) for i in range(self.hits["above"]["xy"].shape[1])]
        self.below_hits = [torch.cat([self.hits["below"]["xy"][:,i], self.hits["below"]["z"][:,i]], dim=-1) for i in range(self.hits["below"]["xy"].shape[1])]
        self.n_hits_above = len(self.above_hits)

        def _get_hit_uncs(dets:List[DetectorLayer], hits:List[Tensor]) -> List[Tensor]:
            # Per-hit (unc, unc, 0): 1/resolution at the detector cell containing the hit, no z uncertainty
            res = []
            for l, h in zip(dets, hits):
                x = l.abs2idx(h)
                r = 1 / l.resolution[x[:, 0], x[:, 1]]
                res.append(torch.stack([r, r, torch.zeros_like(r)], dim=-1))
            return res

        dets = self.volume.get_detectors()
        self.above_hit_uncs = _get_hit_uncs(dets[:self.n_hits_above], self.above_hits)
        self.below_hit_uncs = _get_hit_uncs(dets[self.n_hits_above:], self.below_hits)

        v1 = self.get_muon_trajectory(self.above_hits, uncs=self.above_hit_uncs)
        v2 = self.get_muon_trajectory(self.below_hits, uncs=self.below_hit_uncs)

        # scatter locations
        v3 = torch.cross(v1, v2, dim=1)  # connecting vector perpendicular to both lines
        rhs = self.below_hits[0] - self.above_hits[0]
        lhs = torch.stack([v1, -v2, v3], dim=1).transpose(2, 1)
        coefs = torch.linalg.solve(lhs, rhs)  # solve p1+t1*v1 + t3*v3 = p2+t2*v2 => p2-p1 = t1*v1 - t2*v2 + t3*v3

        q1 = self.above_hits[0] + (coefs[:, 0:1] * v1)  # closest point on v1
        self._loc = q1 + (coefs[:, 2:3] * v3 / 2)  # Move halfway along v3 from q1
        self._loc_unc: Optional[Tensor] = None

        # Theta deviations
        self._theta_in = torch.arctan(v1[:, :2] / v1[:, 2:3])
        self._theta_out = torch.arctan(v2[:, :2] / v2[:, 2:3])
        self._dtheta = torch.abs(self._theta_in - self._theta_out)
        self._theta_in_unc: Optional[Tensor] = None
        self._theta_out_unc: Optional[Tensor] = None
        self._dtheta_unc: Optional[Tensor] = None

        # xy deviations
        self._dxy = coefs[:, 2:3] * v3[:, :2]
        self._dxy_unc: Optional[Tensor] = None

    def _compute_unc(self, var: Tensor, hits: List[Tensor], hit_uncs: List[Tensor]) -> Tensor:
        r"""
        Propagate hit uncertainties to ``var`` using its Jacobian with respect to each hit.

        For every pair of hits (i, j) with j >= i, adds ``sum_k(J_i * J_j * unc_i * unc_j)`` over the hit
        coordinates k, then returns the square root of the accumulated sum.

        Arguments:
            var: variable whose uncertainty is required
            hits: hits ``var`` was computed from
            hit_uncs: uncertainty of each hit, in the same order as ``hits``

        Returns:
            uncertainty on ``var``, one value per muon and per component of ``var``

        Raises:
            ValueError: if no (hit, uncertainty) pairs are supplied
        """
        pairs = list(zip(hits, hit_uncs))
        if len(pairs) == 0:
            raise ValueError("At least one hit is required to compute an uncertainty")

        # Each Jacobian is computed once per hit rather than once per (i, j) pair; they are the expensive part.
        # .sum(2) collapses an extra axis of the full Jacobian (presumably the input-muon axis), leaving
        # Muons, var_xyz, hit_xyz
        jacs = [jacobian(var, x).sum((2)) for x, _ in pairs]
        uncs = [u for _, u in pairs]

        unc2_sum: Optional[Tensor] = None
        for i in range(len(pairs)):
            for j in range(i, len(pairs)):
                dv_dx_2 = jacs[i] ** 2 if i == j else jacs[i] * jacs[j]
                unc_2 = (dv_dx_2 * uncs[i][:, None] * uncs[j][:, None]).sum(2)  # Muons, (x,y,z)
                unc2_sum = unc_2 if unc2_sum is None else unc2_sum + unc_2
        return torch.sqrt(unc2_sum)

    @property
    def location(self) -> Tensor:
        r"""Scatter location: midpoint of the shortest segment between the incoming and outgoing trajectories."""
        return self._loc

    @property
    def location_unc(self) -> Tensor:
        r"""Uncertainty on ``location`` from all hits; computed on first access and cached."""
        if self._loc_unc is None:
            self._loc_unc = self._compute_unc(
                var=self._loc,
                hits=self.above_hits+self.below_hits,
                hit_uncs=self.above_hit_uncs+self.below_hit_uncs,
            )
        return self._loc_unc

    @property
    def dtheta(self) -> Tensor:
        r"""Absolute difference between the incoming and outgoing (x, y) angles."""
        return self._dtheta

    @property
    def dtheta_unc(self) -> Tensor:
        r"""Uncertainty on ``dtheta`` from all hits; computed on first access and cached."""
        if self._dtheta_unc is None:
            self._dtheta_unc = self._compute_unc(
                var=self._dtheta,
                hits=self.above_hits+self.below_hits,
                hit_uncs=self.above_hit_uncs+self.below_hit_uncs,
            )
        return self._dtheta_unc

    @property
    def dxy(self) -> Tensor:
        r"""(x, y) components of the segment connecting the two trajectories at closest approach."""
        return self._dxy

    @property
    def dxy_unc(self) -> Tensor:
        r"""Uncertainty on ``dxy`` from all hits; computed on first access and cached."""
        if self._dxy_unc is None:
            self._dxy_unc = self._compute_unc(
                var=self._dxy,
                hits=self.above_hits+self.below_hits,
                hit_uncs=self.above_hit_uncs+self.below_hit_uncs,
            )
        return self._dxy_unc

    @property
    def theta_in(self) -> Tensor:
        r"""Incoming trajectory angles: arctan of the x and y direction components over the z component."""
        return self._theta_in

    @property
    def theta_in_unc(self) -> Tensor:
        r"""Uncertainty on ``theta_in`` from the above hits only; computed on first access and cached."""
        if self._theta_in_unc is None:
            self._theta_in_unc = self._compute_unc(var=self._theta_in, hits=self.above_hits, hit_uncs=self.above_hit_uncs)
        return self._theta_in_unc

    @property
    def theta_out(self) -> Tensor:
        r"""Outgoing trajectory angles: arctan of the x and y direction components over the z component."""
        return self._theta_out

    @property
    def theta_out_unc(self) -> Tensor:
        r"""Uncertainty on ``theta_out`` from the below hits only; computed on first access and cached."""
        if self._theta_out_unc is None:
            self._theta_out_unc = self._compute_unc(var=self._theta_out, hits=self.below_hits, hit_uncs=self.below_hit_uncs)
        return self._theta_out_unc

In [169]:
gsb = GenScatterBatch(muons, volume)
sb = ScatterBatch(muons, volume)

In [170]:
gsb.location[:10], sb.location[:10]

(tensor([[ 0.9271,  0.8857,  0.1793],
         [ 0.2274,  0.5313,  0.1083],
         [ 0.7479,  0.2227,  0.0494],
         [ 0.9296,  0.3082,  0.3317],
         [ 0.9193,  0.6987,  0.0650],
         [ 0.6436,  0.3648, -0.1672],
         [ 0.9118,  0.3332,  0.0391],
         [ 0.5644,  0.0162, -0.0211],
         [ 0.2386,  0.0617,  0.0289],
         [ 0.5692,  0.4057,  0.1498]], grad_fn=<SliceBackward>),
 tensor([[ 0.9271,  0.8857,  0.1793],
         [ 0.2274,  0.5313,  0.1083],
         [ 0.7479,  0.2227,  0.0494],
         [ 0.9296,  0.3082,  0.3317],
         [ 0.9193,  0.6987,  0.0650],
         [ 0.6436,  0.3648, -0.1672],
         [ 0.9118,  0.3332,  0.0391],
         [ 0.5644,  0.0162, -0.0211],
         [ 0.2386,  0.0617,  0.0289],
         [ 0.5692,  0.4057,  0.1498]], grad_fn=<SliceBackward>))

In [171]:
gsb.location_unc[:10], sb.location_unc[:10]

(tensor([[0.0095, 0.0086, 0.0761],
         [0.0058, 0.0058, 0.0351],
         [0.0081, 0.0083, 0.0778],
         [0.0134, 0.0083, 0.1844],
         [0.0098, 0.0075, 0.0512],
         [0.0580, 0.0420, 0.3434],
         [0.0072, 0.0261, 0.1366],
         [0.0122, 0.0088, 0.1608],
         [0.0186, 0.0083, 0.0733],
         [0.0054, 0.0050, 0.0405]], grad_fn=<SliceBackward>),
 tensor([[0.0095, 0.0086, 0.0761],
         [0.0058, 0.0058, 0.0351],
         [0.0081, 0.0083, 0.0778],
         [0.0134, 0.0083, 0.1844],
         [0.0098, 0.0075, 0.0512],
         [0.0580, 0.0420, 0.3434],
         [0.0072, 0.0261, 0.1366],
         [0.0122, 0.0088, 0.1608],
         [0.0186, 0.0083, 0.0733],
         [0.0054, 0.0050, 0.0405]], grad_fn=<SliceBackward>))

In [172]:
gsb.theta_in_unc[:10], sb.theta_in_unc[:10]

(tensor([[0.0103, 0.0106],
         [0.0104, 0.0106],
         [0.0106, 0.0106],
         [0.0106, 0.0106],
         [0.0106, 0.0106],
         [0.0103, 0.0105],
         [0.0106, 0.0105],
         [0.0106, 0.0106],
         [0.0103, 0.0106],
         [0.0105, 0.0106]], grad_fn=<SliceBackward>),
 tensor([[0.0103, 0.0106],
         [0.0104, 0.0106],
         [0.0106, 0.0106],
         [0.0106, 0.0106],
         [0.0106, 0.0106],
         [0.0103, 0.0105],
         [0.0106, 0.0105],
         [0.0106, 0.0106],
         [0.0103, 0.0106],
         [0.0105, 0.0106]], grad_fn=<SliceBackward>))

In [173]:
gsb.theta_out_unc[:10], sb.theta_out_unc[:10]

(tensor([[0.0862, 0.0857],
         [0.0845, 0.0862],
         [0.0866, 0.0848],
         [0.0848, 0.0865],
         [0.0814, 0.0859],
         [0.0854, 0.0839],
         [0.0865, 0.0810],
         [0.0848, 0.0866],
         [0.0807, 0.0842],
         [0.0849, 0.0866]], grad_fn=<SliceBackward>),
 tensor([[0.0862, 0.0857],
         [0.0845, 0.0862],
         [0.0866, 0.0848],
         [0.0848, 0.0865],
         [0.0814, 0.0859],
         [0.0854, 0.0839],
         [0.0865, 0.0810],
         [0.0848, 0.0866],
         [0.0807, 0.0842],
         [0.0849, 0.0866]], grad_fn=<SliceBackward>))

In [174]:
gsb.dtheta_unc[:10], sb.dtheta_unc[:10]

(tensor([[0.0865, 0.0860],
         [0.0848, 0.0866],
         [0.0869, 0.0852],
         [0.0851, 0.0868],
         [0.0817, 0.0863],
         [0.0857, 0.0843],
         [0.0868, 0.0813],
         [0.0852, 0.0869],
         [0.0810, 0.0846],
         [0.0852, 0.0869]], grad_fn=<SliceBackward>),
 tensor([[0.0865, 0.0860],
         [0.0848, 0.0866],
         [0.0869, 0.0852],
         [0.0851, 0.0868],
         [0.0817, 0.0863],
         [0.0857, 0.0843],
         [0.0868, 0.0813],
         [0.0852, 0.0869],
         [0.0810, 0.0846],
         [0.0852, 0.0869]], grad_fn=<SliceBackward>))

# VolumeWrapper

In [None]:
from functools import partial

In [None]:
volume = Volume(get_layers())

In [None]:
wrapper = VolumeWrapper(volume=volume, res_opt=partial(torch.optim.SGD, lr=2e10), eff_opt=partial(torch.optim.SGD, lr=2e5),
                        loss_func=DetectorLoss(0))

In [None]:
from tomopt.optimisation import MetricLogger

In [None]:
ml = MetricLogger(show_plots=True)

In [None]:
trn_passives = PassiveYielder([arb_rad_length])

In [None]:
from tomopt.optimisation.callbacks.callback import Callback

In [None]:
class ParamCap(Callback):
    r"""
    Callback that clamps every detector's resolution and efficiency into fixed ranges at the start of
    each volume batch.

    The ranges are class attributes, so they can be overridden on an instance or a subclass.
    The defaults are resolution in [1, 1e7] and efficiency in [1e-7, 1].
    """

    res_bounds: Tuple[float, float] = (1, 1e7)
    eff_bounds: Tuple[float, float] = (1e-7, 1)

    def on_volume_batch_begin(self) -> None:
        r"""Clamp the detector parameters in place, under ``no_grad`` so the clamp is not tracked by autograd."""
        res_min, res_max = self.res_bounds
        eff_min, eff_max = self.eff_bounds
        with torch.no_grad():
            for det in self.wrapper.volume.get_detectors():
                det.resolution.clamp_(min=res_min, max=res_max)
                det.efficiency.clamp_(min=eff_min, max=eff_max)

In [None]:
%%time
_ = wrapper.fit(25, n_mu_per_volume=1000, mu_bs=100, passive_bs=1, trn_passives=trn_passives, val_passives=trn_passives, cbs=[NoMoreNaNs(),ParamCap(),ml])

In [None]:
# Print each detector's optimised parameters, labelled by its position in the volume's detector list
for i, det in enumerate(volume.get_detectors()):
    print(i, det.resolution, det.efficiency)