In [13]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [14]:
from tomopt.muon import generate_batch

In [15]:
import matplotlib.pyplot as plt
import seaborn as sns

In [16]:
from tomopt.muon import MuonBatch

In [17]:
import torch
from torch import Tensor

In [18]:
from tomopt.core import X0

In [19]:
def arb_rad_length(*, z: float, lw: Tensor, size: float, lead_z: float = 0.5, atol: float = 1e-6) -> Tensor:
    r"""
    Radiation-length map for a passive layer: X0['lead'] in the layer at `lead_z`, 1e5 everywhere else.

    Arguments:
        z: z position of the layer being filled
        lw: length and width of the layer
        size: voxel size
        lead_z: z position of the single lead layer
        atol: tolerance on the z match. Layer positions come from a float `arange`
            (e.g. 0.5000000000000001), so an exact equality test is fragile.

    Returns:
        Tensor of shape (round(lw[0]/size), round(lw[1]/size))
    """
    # Round rather than truncate: lw/size in float32 can land just below an integer (0.9/0.3 -> 2.9999998)
    n_voxels = [int(n) for n in torch.round(lw / size).long()]
    rad_length = torch.ones(n_voxels) * 1e5
    if abs(z - lead_z) <= atol:
        rad_length[...] = X0['lead']
    return rad_length

In [20]:
from tomopt.volume import PassiveLayer, DetectorLayer

In [21]:
import torch.nn.functional as F

In [22]:
def eff_cost(x:Tensor) -> Tensor:
    r"""
    Efficiency cost handed to DetectorLayer as `eff_cost_func`.

    Computes exp(3 * relu(x)) - 1: zero for x <= 0, growing exponentially for positive x.
    """
    positive_part = F.relu(x)
    return (positive_part * 3).expm1()

In [23]:
def res_cost(x:Tensor) -> Tensor:
    r"""
    Resolution cost handed to DetectorLayer as `res_cost_func`.

    Computes relu(x / 100) ** 2: zero for x <= 0, quadratic in x/100 above.
    """
    scaled = x / 100
    return F.relu(scaled).pow(2)

In [24]:
def get_layers(lwh: Tensor = None, size: float = 0.1, init_eff: float = 0.5, init_res: float = 100000,
               det_mask=(1, 1, 0, 0, 0, 0, 0, 0, 1, 1)) -> torch.nn.ModuleList:
    r"""
    Build the layer stack from the top of the volume downwards.

    Arguments:
        lwh: length, width and height of the volume; defaults to Tensor([1, 1, 1])
        size: layer thickness (also passed to every layer as `size`)
        init_eff: initial efficiency of every DetectorLayer
        init_res: initial resolution of every DetectorLayer
        det_mask: one flag per layer, from the top. Truthy -> DetectorLayer, falsy -> PassiveLayer.
            Detectors before the first passive layer get pos='above', all later ones pos='below'.

    Returns:
        nn.ModuleList of the layers, ordered from the top of the volume downwards

    Raises:
        ValueError: if the number of layers implied by lwh[2] and size differs from len(det_mask)
            (previously `zip` silently dropped the extra layers or flags)
    """
    # Local imports: this cell is defined before the cell that imports numpy and nn
    import numpy as np
    from torch import nn

    if lwh is None:
        lwh = Tensor([1, 1, 1])

    # Layer z positions from lwh[2] down to (excluding) 0. Rounding removes float drift from
    # arange (e.g. 0.5000000000000001) so that exact z comparisons downstream behave as intended.
    z_tops = np.round(np.arange(float(lwh[2]), 0, -size), 10)
    if len(z_tops) != len(det_mask):
        raise ValueError(f"lwh[2]/size gives {len(z_tops)} layers but det_mask has {len(det_mask)} entries")

    layers = []
    pos = 'above'
    for z, is_det in zip(z_tops, det_mask):
        if is_det:
            layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=init_res,
                                        lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))
        else:
            pos = 'below'  # every detector after the passive block sits below it
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))

    return nn.ModuleList(layers)

In [25]:
import numpy as np
from torch import nn

In [26]:
from tomopt.volume import Volume

In [27]:
# Seed the global RNGs so the generated muons (and the random tensors in later cells) reproduce on re-run.
# generate_batch's RNG source isn't visible from here, so both numpy and torch are seeded.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
batch = MuonBatch(generate_batch(1000), init_z=1)

In [28]:
# Assemble the layer stack built by get_layers() into a Volume
volume = Volume(get_layers())

In [29]:
# Pass the muon batch through the volume.
# NOTE(review): the return value is not kept, so this presumably updates `batch` in place —
# if so, re-running this cell without re-creating `batch` won't reproduce the same state; confirm.
volume(batch)

# VMAP

In [30]:
pip show torch

Name: torch
Version: 1.8.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /Users/giles/anaconda3/envs/tomopt/lib/python3.8/site-packages
Requires: typing-extensions, numpy
Required-by: torchvision, torchaudio, tomopt
Note: you may need to restart the kernel to use updated packages.


In [31]:
from torch._vmap_internals import _vmap as vmap

## Dot product

In [32]:
# torch.dot maps [D], [D] -> []; vmap adds a leading batch dimension: [N, D], [N, D] -> [N]
batched_dot = vmap(torch.dot)
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)

tensor([ 0.8373, -1.5334])

## Vector model

In [33]:
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    """Tiny linear model with a ReLU activation, written for a single (unbatched) feature vector."""
    return torch.relu(torch.dot(feature_vec, weights))

examples = torch.randn(batch_size, feature_size)
# vmap lifts the per-example model to a whole batch without rewriting it
result = vmap(model)(examples)
result

tensor([0.0000, 0.0000, 3.7810], grad_fn=<ReluBackward0>)

## Jacobian

In [34]:
# Setup: elementwise square, so the Jacobian should be diag(2x)
N = 5
def f(x):
    return x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
I_N = torch.eye(N)

# Sequential approach: backpropagating one basis vector e_i at a time yields row i (e_i^T J)
jacobian_rows = [torch.autograd.grad(y, x, grad_outputs=e_i, retain_graph=True)[0] for e_i in I_N]
jacobian = torch.stack(jacobian_rows)
jacobian

tensor([[ 0.3702,  0.0000,  0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.2037,  0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.0000,  2.1707,  0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0080, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -1.6036]])

In [35]:
def df_dx(x):
    """Analytic derivative of f(x) = x**2, used to check the autograd Jacobian's diagonal."""
    return 2 * x

In [36]:
# Analytic derivative 2x: should equal the diagonal of the Jacobian above
df_dx(x)

tensor([ 0.3702,  0.2037,  2.1707,  0.0080, -1.6036], grad_fn=<MulBackward0>)

In [37]:
# Vectorised gradient computation: vmap the VJP over all rows of I_N in a single call
def get_vjp(v):
    """Vector-Jacobian product v^T dy/dx for the y = f(x) graph built above."""
    (vjp,) = torch.autograd.grad(y, x, v, retain_graph=True)
    return vjp
jacobian = vmap(get_vjp)(I_N)
jacobian

tensor([[ 0.3702,  0.0000,  0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.2037,  0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.0000,  2.1707,  0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0080, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -1.6036]])

In [38]:
# Rows of the identity: the one-hot grad_outputs used to extract each Jacobian row
I_N.unbind()

(tensor([1., 0., 0., 0., 0.]),
 tensor([0., 1., 0., 0., 0.]),
 tensor([0., 0., 1., 0., 0.]),
 tensor([0., 0., 0., 1., 0.]),
 tensor([0., 0., 0., 0., 1.]))

In [39]:
# The inputs x and outputs y = x**2 used by the Jacobian cells above
x, y

(tensor([ 0.1851,  0.1018,  1.0853,  0.0040, -0.8018], requires_grad=True),
 tensor([3.4271e-02, 1.0372e-02, 1.1780e+00, 1.6123e-05, 6.4291e-01],
        grad_fn=<PowBackward0>))

## Batch-wise grad

In [40]:
# Batch-wise grad, sequential baseline. With y = x**2 each y[i] depends only on x[i],
# so summing row i of the Jacobian (.sum(1)) gives dy[i]/dx[i].
N = 5
f = lambda x: x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
I_N = torch.eye(N)

%timeit torch.stack([torch.autograd.grad(y,x,v, retain_graph=True)[0] for v in I_N.unbind()]).sum(1)

132 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [41]:
# vmap version of the batch-wise grad: each VJP (one Jacobian row) is summed to a scalar inside get_vjp
def get_vjp(v): return torch.autograd.grad(y,x,v, retain_graph=True)[0].sum()
vmap(get_vjp)(I_N)  # NOTE(review): result is discarded — presumably a warm-up/sanity run before timing

%timeit vmap(get_vjp)(I_N)

86.4 µs ± 3.44 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [42]:
def batchwise_grad(y: Tensor, x: Tensor, create_graph: bool = False, allow_unused: bool = True) -> Tensor:
    r"""
    For every element y[i], the sum of its gradient over all elements of `x`, evaluated with vmap.

    When each y[i] depends only on x[i] (independent per-sample computations), this equals dy[i]/dx[i].

    Arguments:
        y: output tensor depending on `x` through the autograd graph
        x: input tensor requiring grad
        create_graph: if True, the result is itself differentiable
        allow_unused: forwarded to torch.autograd.grad

    Returns:
        Tensor with the same shape as `y`
    """
    flat_y = y.reshape(-1)
    # grad_outputs must match y (not x): one one-hot vector per element of y, on y's dtype/device
    basis = torch.eye(flat_y.numel(), dtype=flat_y.dtype, device=flat_y.device)

    def get_vjp(v: Tensor) -> Tensor:
        return torch.autograd.grad(flat_y, x, v, retain_graph=True, create_graph=create_graph,
                                   allow_unused=allow_unused)[0].sum()

    return vmap(get_vjp)(basis).reshape(y.shape)

In [43]:
# y = x**2 elementwise, so each entry should be dy[i]/dx[i] = 2*x[i]
batchwise_grad(y, x)

tensor([ 0.6276,  2.7733, -3.3634, -0.5657,  0.0976])

## Batch-wise jacobian

In [49]:
N = 5
def f(x):
    return x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
I_N = torch.eye(N)

# Five identical copies of y give a 2-D [5, N] output for the batch-wise Jacobian tests below
yb = torch.stack([y] * 5)

# Reference: sequential Jacobian of the 1-D y
torch.stack([torch.autograd.grad(y, x, e_i, retain_graph=True)[0] for e_i in I_N])

tensor([[ 4.5066,  0.0000, -0.0000, -0.0000,  0.0000],
        [ 0.0000,  1.5566, -0.0000, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.5869, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -1.1729,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -0.0000,  5.8094]])

In [50]:
I_N = torch.eye(yb.numel())  # one one-hot grad_output per element of yb
flat_y = yb.reshape(-1)

In [51]:
%%timeit
# Sequential baseline: one VJP per element of flat_y, stacked and reshaped to yb.shape + x.shape
jac = []
for grad_y in I_N.unbind():
    (grad_x,) = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True)
    jac.append(grad_x.reshape(x.shape))
torch.stack(jac).reshape(yb.shape + x.shape)

1.29 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [52]:
def get_vjp(v):
    """One row of the Jacobian of flat_y w.r.t. x, laid out like x."""
    (grad_x,) = torch.autograd.grad(flat_y, x, v, retain_graph=True)
    return grad_x.reshape(x.shape)
vmap(get_vjp)(I_N).reshape(yb.shape + x.shape)

tensor([[[ 4.5066,  0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  1.5566, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.5869, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0000, -1.1729,  0.0000],
         [ 0.0000,  0.0000, -0.0000, -0.0000,  5.8094]],

        [[ 4.5066,  0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  1.5566, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.5869, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0000, -1.1729,  0.0000],
         [ 0.0000,  0.0000, -0.0000, -0.0000,  5.8094]],

        [[ 4.5066,  0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  1.5566, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.5869, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0000, -1.1729,  0.0000],
         [ 0.0000,  0.0000, -0.0000, -0.0000,  5.8094]],

        [[ 4.5066,  0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  1.5566, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.5869, -0.00

In [53]:
def batchwise_jacobian(y: Tensor, x: Tensor, create_graph: bool = False, allow_unused: bool = True) -> Tensor:
    r"""
    Full Jacobian of `y` w.r.t. `x`, with every vector-Jacobian product evaluated in a single vmap call.

    Arguments:
        y: output tensor (any shape) depending on `x`
        x: input tensor requiring grad
        create_graph: forwarded to torch.autograd.grad
        allow_unused: forwarded to torch.autograd.grad

    Returns:
        Tensor of shape `y.shape + x.shape`
    """
    y_vec = y.reshape(-1)

    def _vjp_row(grad_out: Tensor) -> Tensor:
        # One Jacobian row: grad_out^T dy/dx, laid out like x
        grads = torch.autograd.grad(y_vec, x, grad_out, retain_graph=True, create_graph=create_graph,
                                    allow_unused=allow_unused)
        return grads[0].reshape(x.shape)

    one_hots = torch.eye(len(y_vec))
    rows = vmap(_vjp_row)(one_hots)
    return rows.reshape(y.shape + x.shape)

In [54]:
# 1-D y: should reproduce the sequential Jacobian printed above
batchwise_jacobian(y, x)

tensor([[ 4.5066,  0.0000, -0.0000, -0.0000,  0.0000],
        [ 0.0000,  1.5566, -0.0000, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.5869, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -1.1729,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -0.0000,  5.8094]])

In [55]:
# vmap Jacobian of the [5, N] output yb; compare with the sequential %%timeit cell above
%timeit batchwise_jacobian(yb, x)

107 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# Scatter

In [63]:
from tomopt.inference import ScatterBatch

In [84]:
def for_jacobian(y: Tensor, x: Tensor, create_graph: bool = False, allow_unused: bool = True) -> Tensor:
    r"""
    Jacobian of `y` w.r.t. `x`, computed sequentially: one vector-Jacobian product per element of `y`.

    Arguments:
        y: output tensor (any shape) depending on `x` through the autograd graph
        x: input tensor requiring grad
        create_graph: if True, the returned Jacobian is itself differentiable
            (this flag was previously accepted but ignored)
        allow_unused: if True, an `x` that `y` does not depend on yields zeros instead of an error

    Returns:
        Tensor of shape `y.shape + x.shape`; entry [i..., j...] is dy[i...]/dx[j...]
    """
    flat_y = y.reshape(-1)
    if flat_y.numel() == 0:
        return x.new_zeros(y.shape + x.shape)

    jac = []
    # One-hot grad_outputs on y's dtype/device so non-default dtypes and GPU tensors work
    for grad_y in torch.eye(len(flat_y), dtype=flat_y.dtype, device=flat_y.device).unbind():
        (grad_x,) = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph,
                                        allow_unused=allow_unused)
        if grad_x is None:  # only possible with allow_unused=True when y does not depend on x
            grad_x = torch.zeros_like(x)
        jac.append(grad_x.reshape(x.shape))
    return torch.stack(jac).reshape(y.shape + x.shape)

In [85]:
class ForScatterBatch(ScatterBatch):
    r"""
    ScatterBatch whose resolution-uncertainty propagation builds Jacobians with the sequential
    `for_jacobian` (one autograd call per output element). Serves as the timing baseline for the
    vmap-based variant.
    """

    def compute_scatters(self) -> None:
        r"""
        Currently only handles 2 detectors above and below passive volume

        Scatter locations adapted from:
        @MISC {3334866,
            TITLE = {Closest points between two lines},
            AUTHOR = {Brian (https://math.stackexchange.com/users/72614/brian)},
            HOWPUBLISHED = {Mathematics Stack Exchange},
            NOTE = {URL:https://math.stackexchange.com/q/3334866 (version: 2019-08-26)},
            EPRINT = {https://math.stackexchange.com/q/3334866},
            URL = {https://math.stackexchange.com/q/3334866}
        }
        """

        # self.hits in layers
        # NOTE(review): assumes hits[pos]["xy"] is (muons, layers, 2) and hits[pos]["z"] is (muons, layers, 1),
        # so each point below is (muons, 3) — torch.cross(dim=1) requires size 3 on that dim
        xa0 = torch.cat([self.hits["above"]["xy"][:, 0], self.hits["above"]["z"][:, 0]], dim=-1)  # reco x, reco y, gen z
        xa1 = torch.cat([self.hits["above"]["xy"][:, 1], self.hits["above"]["z"][:, 1]], dim=-1)
        # The 'below' hits are taken in reverse index order. Only the lines through the points matter,
        # so the resulting sign of v2 cancels out of every quantity computed below.
        xb0 = torch.cat([self.hits["below"]["xy"][:, 1], self.hits["below"]["z"][:, 1]], dim=-1)
        xb1 = torch.cat([self.hits["below"]["xy"][:, 0], self.hits["below"]["z"][:, 0]], dim=-1)

        # Resolution at each muon's hit pixel for every detector, squared: (muons, n_dets, 1)
        dets = self.volume.get_detectors()
        res = []
        for p, l, i in zip(("above", "above", "below", "below"), dets, (0, 1, 0, 1)):
            x = l.abs2idx(self.hits[p]["xy"][:, i])
            res.append(l.resolution[x[:, 0], x[:, 1]])
        res2 = torch.stack(res, dim=1)[:, :, None] ** 2

        # Extrapolate muon-path vectors from self.hits
        v1 = xa1 - xa0
        v2 = xb1 - xb0

        # scatter locations
        v3 = torch.cross(v1, v2, dim=1)  # connecting vector perpendicular to both lines
        rhs = xb0 - xa0
        lhs = torch.stack([v1, -v2, v3], dim=1).transpose(2, 1)
        coefs = torch.linalg.solve(lhs, rhs)  # solve p1+t1*v1 + t3*v3 = p2+t2*v2 => p2-p1 = t1*v1 - t2*v2 + t3*v3

        q1 = xa0 + (coefs[:, 0:1] * v1)  # closest point on v1
        self._loc = q1 + (coefs[:, 2:3] * v3 / 2)  # Move halfway along v3 from q1

        # Theta deviations
        self._theta_in = torch.arctan(v1[:, :2] / v1[:, 2:3])
        self._theta_out = torch.arctan(v2[:, :2] / v2[:, 2:3])
        self._dtheta = torch.abs(self._theta_in - self._theta_out)

        # xy deviations
        self._dxy = coefs[:, 2:3] * v3[:, :2]

        # Propagate detector resolutions into each scatter variable
        self._loc_unc = self._propagate_res_unc(self._loc, dets, res2)
        self._dtheta_unc = self._propagate_res_unc(self._dtheta, dets, res2)
        self._dxy_unc = self._propagate_res_unc(self._dxy, dets, res2)
        # Incoming / outgoing angles only involve the detectors on their own side of the volume
        self._theta_in_unc = self._propagate_res_unc(self._theta_in, dets[:2], res2[:, :2])
        self._theta_out_unc = self._propagate_res_unc(self._theta_out, dets[2:], res2[:, 2:])

    @staticmethod
    def _propagate_res_unc(y: Tensor, dets, res2: Tensor) -> Tensor:
        r"""
        Uncertainty on `y` from the detector resolutions: sqrt(sum_d (dy/dres_d)^2 * res2_d).

        dy/dres_d is the Jacobian of `y` w.r.t. detector d's resolution map, summed over the map's last two dims.

        Arguments:
            y: per-muon quantity, e.g. (muons, k)
            dets: detector layers whose `resolution` enters `y`
            res2: squared resolution at each muon's hit, (muons, len(dets), 1)

        Returns:
            uncertainty, summed over detectors (same shape as a 2-D `y` with 2-D resolution maps)
        """
        dy_dres = torch.stack([for_jacobian(y, det.resolution).sum((-1, -2)) for det in dets], dim=1)
        return torch.sqrt((dy_dres.pow(2) * res2).sum(1))

In [86]:
# Baseline: uncertainties propagated with the sequential for_jacobian
%time scatters = ForScatterBatch(batch, volume)

CPU times: user 19.9 s, sys: 768 ms, total: 20.6 s
Wall time: 20 s


## Batchwise grad scatter

In [87]:
class BWScatterBatch(ScatterBatch):
    r"""
    ScatterBatch whose resolution-uncertainty propagation builds Jacobians with the vmap-based
    `batchwise_jacobian`. Otherwise identical to the sequential `for_jacobian` variant.
    """

    def compute_scatters(self) -> None:
        r"""
        Currently only handles 2 detectors above and below passive volume

        Scatter locations adapted from:
        @MISC {3334866,
            TITLE = {Closest points between two lines},
            AUTHOR = {Brian (https://math.stackexchange.com/users/72614/brian)},
            HOWPUBLISHED = {Mathematics Stack Exchange},
            NOTE = {URL:https://math.stackexchange.com/q/3334866 (version: 2019-08-26)},
            EPRINT = {https://math.stackexchange.com/q/3334866},
            URL = {https://math.stackexchange.com/q/3334866}
        }
        """

        # self.hits in layers
        # NOTE(review): assumes hits[pos]["xy"] is (muons, layers, 2) and hits[pos]["z"] is (muons, layers, 1),
        # so each point below is (muons, 3) — torch.cross(dim=1) requires size 3 on that dim
        xa0 = torch.cat([self.hits["above"]["xy"][:, 0], self.hits["above"]["z"][:, 0]], dim=-1)  # reco x, reco y, gen z
        xa1 = torch.cat([self.hits["above"]["xy"][:, 1], self.hits["above"]["z"][:, 1]], dim=-1)
        # The 'below' hits are taken in reverse index order. Only the lines through the points matter,
        # so the resulting sign of v2 cancels out of every quantity computed below.
        xb0 = torch.cat([self.hits["below"]["xy"][:, 1], self.hits["below"]["z"][:, 1]], dim=-1)
        xb1 = torch.cat([self.hits["below"]["xy"][:, 0], self.hits["below"]["z"][:, 0]], dim=-1)

        # Resolution at each muon's hit pixel for every detector, squared: (muons, n_dets, 1)
        dets = self.volume.get_detectors()
        res = []
        for p, l, i in zip(("above", "above", "below", "below"), dets, (0, 1, 0, 1)):
            x = l.abs2idx(self.hits[p]["xy"][:, i])
            res.append(l.resolution[x[:, 0], x[:, 1]])
        res2 = torch.stack(res, dim=1)[:, :, None] ** 2

        # Extrapolate muon-path vectors from self.hits
        v1 = xa1 - xa0
        v2 = xb1 - xb0

        # scatter locations
        v3 = torch.cross(v1, v2, dim=1)  # connecting vector perpendicular to both lines
        rhs = xb0 - xa0
        lhs = torch.stack([v1, -v2, v3], dim=1).transpose(2, 1)
        coefs = torch.linalg.solve(lhs, rhs)  # solve p1+t1*v1 + t3*v3 = p2+t2*v2 => p2-p1 = t1*v1 - t2*v2 + t3*v3

        q1 = xa0 + (coefs[:, 0:1] * v1)  # closest point on v1
        self._loc = q1 + (coefs[:, 2:3] * v3 / 2)  # Move halfway along v3 from q1

        # Theta deviations
        self._theta_in = torch.arctan(v1[:, :2] / v1[:, 2:3])
        self._theta_out = torch.arctan(v2[:, :2] / v2[:, 2:3])
        self._dtheta = torch.abs(self._theta_in - self._theta_out)

        # xy deviations
        self._dxy = coefs[:, 2:3] * v3[:, :2]

        # Propagate detector resolutions into each scatter variable
        self._loc_unc = self._propagate_res_unc(self._loc, dets, res2)
        self._dtheta_unc = self._propagate_res_unc(self._dtheta, dets, res2)
        self._dxy_unc = self._propagate_res_unc(self._dxy, dets, res2)
        # Incoming / outgoing angles only involve the detectors on their own side of the volume
        self._theta_in_unc = self._propagate_res_unc(self._theta_in, dets[:2], res2[:, :2])
        self._theta_out_unc = self._propagate_res_unc(self._theta_out, dets[2:], res2[:, 2:])

    @staticmethod
    def _propagate_res_unc(y: Tensor, dets, res2: Tensor) -> Tensor:
        r"""
        Uncertainty on `y` from the detector resolutions: sqrt(sum_d (dy/dres_d)^2 * res2_d).

        dy/dres_d is the Jacobian of `y` w.r.t. detector d's resolution map, summed over the map's last two dims.

        Arguments:
            y: per-muon quantity, e.g. (muons, k)
            dets: detector layers whose `resolution` enters `y`
            res2: squared resolution at each muon's hit, (muons, len(dets), 1)

        Returns:
            uncertainty, summed over detectors (same shape as a 2-D `y` with 2-D resolution maps)
        """
        dy_dres = torch.stack([batchwise_jacobian(y, det.resolution).sum((-1, -2)) for det in dets], dim=1)
        return torch.sqrt((dy_dres.pow(2) * res2).sum(1))

In [88]:
# Same computation with the vmap-based batchwise_jacobian; compare against the for_jacobian timing above
%time scatters = BWScatterBatch(batch, volume)

CPU times: user 12.3 s, sys: 1.5 s, total: 13.8 s
Wall time: 9.45 s
