In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from tomopt.muon import *
from tomopt.inference import *
from tomopt.volume import *
from tomopt.core import *
from tomopt.optimisation import *

import matplotlib.pyplot as plt
import seaborn as sns
from typing import *
import numpy as np

import torch
from torch import Tensor, nn
import torch.nn.functional as F

# Basics

In [3]:
def arb_rad_length(*, z: float, lw: Tensor, size: float, z_tol: float = 1e-5) -> Tensor:
    r'''
    Radiation-length map for a single passive layer: beryllium everywhere, with a lead block
    occupying the cells [5:, 5:] for layers whose z lies within [0.4, 0.5].

    Arguments:
        z: z position of the layer, as supplied by the caller
        lw: length and width of the layer
        size: side length of each voxel; the grid has round(lw/size) cells along each side
        z_tol: absolute tolerance on the z window, so that layer heights produced by
            floating-point stepping (e.g. 0.39999999999999997 instead of 0.4) still match

    Returns:
        Tensor of shape round(lw/size) holding the X0 value of each voxel
    '''

    # Round before casting: truncating e.g. 6.999999999999999 with .long() would drop a whole row of cells
    shape = (lw / size).round().long().tolist()
    rad_length = torch.ones(shape) * X0['beryllium']
    if 0.4 - z_tol <= z <= 0.5 + z_tol:
        rad_length[5:, 5:] = X0['lead']
    return rad_length

In [4]:
def eff_cost(x:Tensor) -> Tensor:
    r'''
    Cost of a detector efficiency: exp(3*max(x, 0)) - 1, evaluated with expm1 for accuracy near zero.
    Negative inputs are clipped to zero and therefore cost nothing.
    '''

    clipped = torch.relu(x)
    return clipped.mul(3).expm1()

In [5]:
def res_cost(x:Tensor) -> Tensor:
    r'''
    Cost of a detector resolution: the square of the resolution expressed in units of 100,
    with negative values clipped to zero first.
    '''

    scaled = torch.relu(x / 100)
    return scaled.square()

In [67]:
def get_layers(lwh: Optional[Tensor] = None, size: float = 0.1, init_eff: float = 0.5, init_res: float = 1000,
               det_pattern: Sequence[int] = (1, 1, 0, 0, 0, 0, 0, 0, 1, 1)) -> nn.ModuleList:
    r'''
    Build the layer stack of a test volume, from the top (z = lwh[2]) downwards in steps of `size`.

    Each entry of `det_pattern` selects the kind of layer at the corresponding height: truthy -> DetectorLayer,
    falsy -> PassiveLayer filled via `arb_rad_length`. Detector layers are labelled 'above' until the first passive
    layer has been placed and 'below' afterwards. Pattern entries beyond the bottom of the volume are ignored.

    Arguments:
        lwh: length, width and height of the volume; defaults to Tensor([1,1,1])
        size: layer thickness and voxel size
        init_eff: initial efficiency of every detector layer
        init_res: resolution scale; each detector layer starts with resolution z*init_res
        det_pattern: detector/passive flags, top layer first

    Returns:
        nn.ModuleList of the layers, top layer first
    '''

    if lwh is None:
        lwh = Tensor([1, 1, 1])
    # Pass a plain float to np.arange rather than a 0-d tensor, and round off the accumulated stepping error
    # so that heights come out as e.g. 0.4 rather than 0.39999999999999997
    zs = np.round(np.arange(float(lwh[2]), 0, -size), decimals=10)

    layers = []
    pos = 'above'
    for z, is_det in zip(zs, det_pattern):
        if is_det:
            layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=z*init_res,
                                        lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))
        else:
            # Every detector after the first passive layer sits below the passive volume
            pos = 'below'
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))

    return nn.ModuleList(layers)

In [68]:
# Volume built from the default layer stack of get_layers()
volume = Volume(get_layers())

# Scatter inference

In [69]:
# Seed the RNGs so the muon batch (and hence the comparison outputs below) can be regenerated.
# Which RNG generate_batch draws from is not visible here, so both numpy and torch are seeded.
np.random.seed(0)
torch.manual_seed(0)
# 100 muons starting at z = 1.0, the top of the volume built by get_layers() with its default height
muons = MuonBatch(generate_batch(100), init_z=1.0)

In [70]:
# Pass the muons through the volume. The return value is not kept, so `muons` is presumably updated in place
# (hits recorded) — if so, re-running this cell without regenerating the batch is invalid. TODO confirm
volume(muons)

In [71]:
class GenScatterBatch(ScatterBatch):
    r'''
    ScatterBatch variant that reconstructs the incoming and outgoing muon directions with an
    uncertainty-weighted least-squares straight-line fit to the hits of each detector group.
    '''

    @staticmethod
    def get_muon_trajectory(hits: Tensor, uncs: Tensor) -> Tensor:
        r'''
        Fit a straight line xy = xy* + angles*z to each muon's hits by weighted least squares and return the
        vector along the fitted line from the z of the first detector to the z of the second detector.

        Arguments:
            hits: (muons, detectors, (x, y, z)) hit positions
            uncs: (muons, detectors, (unc, unc, 0)) hit uncertainties; the same uncertainty is assumed for
                x and y, so only the x column is used to build the weights

        Returns:
            (muons, (dx, dy, dz)) vector along the fitted line from detector 0's z to detector 1's z
        '''

        # Weights are the inverse variances of the xy position
        inv_unc2 = uncs[:, :, 0:1]**-2
        sum_inv_unc2 = inv_unc2.sum(dim=1)
        # Weighted means <h> and <h*z> over the detectors, per coordinate
        mean_xyz = torch.sum(hits*inv_unc2, dim=1)/sum_inv_unc2
        mean_xyz_z = torch.sum(hits*hits[:, :, 2:3]*inv_unc2, dim=1)/sum_inv_unc2
        mean_xy = mean_xyz[:, :2]
        mean_z = mean_xyz[:, 2:3]
        mean_xy_z = mean_xyz_z[:, :2]
        mean_z2 = mean_xyz_z[:, 2:3]

        # Closed-form solution of the normal equations
        #   <xy>   = xy* + angles*<z>
        #   <xy*z> = xy*<z> + angles*<z^2>
        # NOTE: degenerate (division by zero) if all hits of a muon share the same z
        xy_star = (mean_xy-((mean_z*mean_xy_z)/mean_z2))/(1-(mean_z.square()/mean_z2))
        angles = (mean_xy_z-(xy_star*mean_z))/mean_z2

        def calc_xyz(z: Tensor) -> Tensor:
            # Point on the fitted line at height z
            return torch.cat([xy_star+(angles*z), z], dim=-1)

        return calc_xyz(hits[:, 1, 2:3])-calc_xyz(hits[:, 0, 2:3])

    def compute_scatters(self) -> None:
        r"""
        Compute scatter locations, angular deviations and xy deviations from the fitted incoming
        and outgoing tracks.
        Currently only handles 2 detectors above and below passive volume.
        """

        def hit_xyz(pos: str, idx: int) -> Tensor:
            # reco x, reco y, gen z of hit `idx` in detector group `pos`
            return torch.cat([self.hits[pos]["xy"][:, idx], self.hits[pos]["z"][:, idx]], dim=-1)

        self.xa0 = hit_xyz("above", 0)
        self.xa1 = hit_xyz("above", 1)
        # Below-volume hits are read in reverse index order, matching the detector pairing used for the uncertainties
        self.xb0 = hit_xyz("below", 1)
        self.xb1 = hit_xyz("below", 0)

        # Per-hit uncertainty = 1/resolution at the hit position (abs2idx presumably maps xy to cell indices),
        # with zero uncertainty in z. _hit_unc is (muons, 4, 3), ordered as xa0, xa1, xb0, xb1.
        dets = self.volume.get_detectors()
        res = []
        for p, l, i in zip(("above", "above", "below", "below"), dets, (0, 1, 1, 0)):
            x = l.abs2idx(self.hits[p]["xy"][:, i])
            r = 1 / l.resolution[x[:, 0], x[:, 1]]
            res.append(torch.stack([r, r, torch.zeros_like(r)], dim=-1))
        self._hit_unc = torch.stack(res, dim=1)

        # Fitted direction vectors of the incoming (above) and outgoing (below) tracks
        v1 = self.get_muon_trajectory(torch.stack([self.xa0, self.xa1], dim=1), uncs=self._hit_unc[:, :2])
        v2 = self.get_muon_trajectory(torch.stack([self.xb0, self.xb1], dim=1), uncs=self._hit_unc[:, 2:])

        # scatter locations
        v3 = torch.cross(v1, v2, dim=1)  # connecting vector perpendicular to both lines
        rhs = self.xb0 - self.xa0
        lhs = torch.stack([v1, -v2, v3], dim=1).transpose(2, 1)
        coefs = torch.linalg.solve(lhs, rhs)  # solve p1+t1*v1 + t3*v3 = p2+t2*v2 => p2-p1 = t1*v1 - t2*v2 + t3*v3

        q1 = self.xa0 + (coefs[:, 0:1] * v1)  # closest point on v1
        self._loc = q1 + (coefs[:, 2:3] * v3 / 2)  # Move halfway along v3 from q1
        self._loc_unc: Optional[Tensor] = None

        # Theta deviations
        self._theta_in = torch.arctan(v1[:, :2] / v1[:, 2:3])
        self._theta_out = torch.arctan(v2[:, :2] / v2[:, 2:3])
        self._dtheta = torch.abs(self._theta_in - self._theta_out)
        self._theta_in_unc: Optional[Tensor] = None
        self._theta_out_unc: Optional[Tensor] = None
        self._dtheta_unc: Optional[Tensor] = None

        # xy deviations
        self._dxy = coefs[:, 2:3] * v3[:, :2]
        self._dxy_unc: Optional[Tensor] = None

In [72]:
# Scatter reconstruction using the weighted-fit trajectories of GenScatterBatch
gsb = GenScatterBatch(muons, volume)

In [73]:
# Reference reconstruction with the library's ScatterBatch, for comparison
sb = ScatterBatch(muons, volume)

In [74]:
# First 10 reconstructed scatter locations: GenScatterBatch vs ScatterBatch
gsb.location[:10],sb.location[:10]

(tensor([[0.1952, 0.7033, 0.4431],
         [0.7786, 0.3759, 0.3105],
         [0.1156, 0.8766, 0.3230],
         [0.8084, 0.2616, 0.2662],
         [0.6475, 0.0482, 0.1029],
         [0.9144, 0.1843, 0.2169],
         [0.8951, 0.5828, 0.1589],
         [0.0338, 0.7231, 0.1495],
         [0.5464, 0.2514, 0.1051],
         [0.3565, 0.4568, 0.1339]], grad_fn=<SliceBackward>),
 tensor([[0.1952, 0.7033, 0.4434],
         [0.7786, 0.3759, 0.3104],
         [0.1156, 0.8766, 0.3231],
         [0.8084, 0.2616, 0.2661],
         [0.6475, 0.0482, 0.1029],
         [0.9144, 0.1843, 0.2169],
         [0.8951, 0.5828, 0.1588],
         [0.0338, 0.7231, 0.1492],
         [0.5464, 0.2514, 0.1051],
         [0.3565, 0.4568, 0.1340]], grad_fn=<SliceBackward>))

In [75]:
# Corresponding location uncertainties: GenScatterBatch vs ScatterBatch
gsb.location_unc[:10],sb.location_unc[:10]

(tensor([[0.0206, 0.1018, 1.0822],
         [0.0117, 0.0111, 0.2057],
         [0.0134, 0.0117, 0.1649],
         [0.0110, 0.0146, 0.1119],
         [0.0096, 0.0062, 0.0399],
         [0.0118, 0.0177, 0.1483],
         [0.0136, 0.0142, 0.0998],
         [0.0066, 0.0181, 0.2260],
         [0.0139, 0.0093, 0.0905],
         [0.0182, 0.0058, 0.1185]], grad_fn=<SliceBackward>),
 tensor([[0.0206, 0.1018, 1.0818],
         [0.0117, 0.0111, 0.2057],
         [0.0134, 0.0117, 0.1649],
         [0.0110, 0.0146, 0.1119],
         [0.0096, 0.0062, 0.0399],
         [0.0118, 0.0177, 0.1483],
         [0.0136, 0.0142, 0.0997],
         [0.0066, 0.0182, 0.2260],
         [0.0139, 0.0093, 0.0905],
         [0.0181, 0.0058, 0.1185]], grad_fn=<SliceBackward>))

# VolumeWrapper

In [8]:
from functools import partial

In [9]:
# Fresh volume, built from the default layer stack, for the optimisation below
volume = Volume(get_layers())

In [10]:
# Separate SGD optimisers for detector resolution and efficiency. The very different learning rates presumably
# reflect the different scales of the two parameters (resolution initialised at z*1000, efficiency at 0.5) — TODO confirm
wrapper = VolumeWrapper(volume=volume, res_opt=partial(torch.optim.SGD, lr=2e10), eff_opt=partial(torch.optim.SGD, lr=2e5),
                        loss_func=DetectorLoss(0))

In [None]:
from tomopt.optimisation import MetricLogger

In [None]:
# Metric logger passed as a training callback; show_plots=True presumably draws the curves during training
ml = MetricLogger(show_plots=True)

In [None]:
# Passive-volume source built from the single arb_rad_length layout; it is also passed as the validation passives in fit
trn_passives = PassiveYielder([arb_rad_length])

In [None]:
from tomopt.optimisation.callbacks.callback import Callback

In [None]:
class ParamCap(Callback):
    r'''
    Callback that, at the start of every volume batch, clamps each detector's resolution into [1, 1e7]
    and its efficiency into [1e-7, 1], in place and without recording the operation for autograd.
    '''

    def on_volume_batch_begin(self) -> None:
        with torch.no_grad():
            for det in self.wrapper.volume.get_detectors():
                det.resolution.clamp_(min=1, max=1e7)
                det.efficiency.clamp_(min=1e-7, max=1)

In [None]:
%%time
# Optimise the detectors. ParamCap keeps each detector's resolution and efficiency within fixed bounds at the start of
# every volume batch; ml presumably records the metrics; trn_passives is used for both training and validation.
# The return value of fit is discarded.
_ = wrapper.fit(25, n_mu_per_volume=1000, mu_bs=100, passive_bs=1, trn_passives=trn_passives, val_passives=trn_passives, cbs=[NoMoreNaNs(),ParamCap(),ml])

In [None]:
# Final detector parameters, labelled by the detector's position in volume.get_detectors()
for i, d in enumerate(volume.get_detectors()):
    print(i, d.resolution, d.efficiency)