# BM3 GNN Pred

In [1]:
from typing import *
import numpy as np
from functools import partial

import torch
from torch import Tensor, nn
import torch.nn.functional as F

from tomopt.volume import *
from tomopt.muon import *
from tomopt.inference import *
from tomopt.optimisation import *
from tomopt.core import *
from tomopt.utils import *

In [2]:
DEVICE = torch.device("cpu")

In [3]:
from tomopt.volume.layer import Layer

def get_volume(size: float = 0.2, lwh: Tensor = Tensor([1.0, 1.0, 1.4]), device: torch.device = torch.device("cpu")) -> Volume:
    def area_cost(x: Tensor) -> Tensor:
        return F.relu(x)

    layers: List[Layer] = []
    n_panels = 4
    layers.append(
        PanelDetectorLayer(
            pos="above",
            lw=lwh[:2],
            z=lwh[2].item(),
            size=size,
            panels=[
                DetectorPanel(
                    res=1e3, eff=1, init_xyz=(0.5, 0.5, 1 - (i * (size) / n_panels)), init_xy_span=(1.0, 1.0), area_cost_func=area_cost, device=device
                )
                for i in range(n_panels)
            ],
        )
    )
    for z in np.round(np.arange(lwh[2] - size, size, -size), decimals=1):
        layers.append(PassiveLayer(lw=lwh[:2], z=z, size=size, device=device))
    layers.append(
        PanelDetectorLayer(
            pos="below",
            lw=lwh[:2],
            z=size,
            size=size,
            panels=[
                DetectorPanel(
                    res=1e3, eff=1, init_xyz=(0.5, 0.5, 0.2 - (i * (size) / n_panels)), init_xy_span=(1.0, 1.0), area_cost_func=area_cost, device=device
                )
                for i in range(n_panels)
            ],
        )
    )

    return Volume(nn.ModuleList(layers))

In [4]:
volume = get_volume(device=DEVICE)

In [5]:
muons = MuonBatch(generate_batch(250), init_z=volume.h, device=DEVICE)

In [6]:
gen = RandomBlockPassiveGenerator(
        block_size=None, volume=volume, sort_x0=False, enforce_diff_mat=True, materials=["beryllium", "carbon", "silicon", "iron", "lead"], block_size_max_half=False
    )

In [7]:
volume.load_rad_length(*gen.get_data())

In [8]:
volume(muons)

In [9]:
volume.target

tensor([0.0056])

In [10]:
sb = PanelScatterBatch(muons, volume)

## GNN inferer

In [11]:
model = torch.jit.load('../../mode_muon_tomo_inference/dev/exported_models/bm3_traced.pt')

In [12]:
type(model)

torch.jit._script.RecursiveScriptModule

In [13]:
dvi = DeepVolumeInferer(model=model, base_inferer=PanelX0Inferer(volume), volume=volume)

In [14]:
dvi.in_vars, dvi.in_var_uncs

([], [])

In [15]:
dvi.add_scatters(sb)

  idxs = torch.combinations(torch.arange(0, unc.shape[-1]), with_replacement=True)


In [16]:
dvi.in_vars[0].shape, dvi.in_var_uncs[0].shape

(torch.Size([250, 8]), torch.Size([250, 8]))

In [17]:
%%time
p,w = dvi.get_prediction()

CPU times: user 3.65 s, sys: 307 ms, total: 3.96 s
Wall time: 1.52 s


In [18]:
torch.exp(p).shape

torch.Size([1, 5, 125])

In [19]:
p

tensor([[[-5.9334, -5.2389, -4.9669, -5.2601, -5.9614, -5.1603, -4.2639,
          -3.9750, -4.2559, -5.1431, -4.8196, -3.9492, -3.6674, -4.0147,
          -4.8703, -5.1372, -4.3484, -4.0858, -4.4049, -5.1515, -5.8525,
          -5.3695, -5.0904, -5.3317, -5.7418, -4.7668, -3.8202, -3.4970,
          -3.8370, -4.8650, -3.7280, -2.7794, -2.5027, -2.8293, -3.7898,
          -3.4185, -2.4889, -2.2501, -2.5520, -3.4895, -3.7135, -2.8166,
          -2.5475, -2.8892, -3.8609, -4.6608, -3.8556, -3.5590, -3.8882,
          -4.8094, -4.1684, -3.1392, -2.8092, -3.1621, -4.2987, -3.0351,
          -2.1529, -1.9421, -2.1903, -3.0990, -2.7062, -1.9377, -1.7589,
          -1.9830, -2.7976, -2.9534, -2.1705, -1.9696, -2.2111, -3.1589,
          -3.8886, -2.9935, -2.7427, -3.1113, -4.1781, -4.4804, -3.3793,
          -2.9945, -3.3840, -4.6269, -3.2293, -2.2146, -1.9553, -2.2281,
          -3.3129, -2.7035, -1.9336, -1.7279, -1.9526, -2.9055, -2.9598,
          -2.0890, -1.9312, -2.1880, -3.2784, -4.17

In [20]:
%%time
# jacobian(p, dvi.in_var)  # crashes

CPU times: user 3 µs, sys: 1e+03 ns, total: 4 µs
Wall time: 5.01 µs


## Loss

In [21]:
x02id = {X0[m]: i for i, m in enumerate(gen.materials)}; x02id

{0.3528: 0, 0.1932: 1, 0.0937: 2, 0.01757: 3, 0.005612: 4}

In [22]:
loss = VoxelClassLoss(x02id=x02id, target_budget=None)

In [23]:
l = loss(p, 1, volume); l

Automatically setting cost coefficient to 0.7592974305152893


tensor([0.7593], grad_fn=<AddBackward0>)

In [24]:
jacobian(l, volume.get_detectors()[0].panels[0].xy)

tensor([[0.0550, 0.0775]])

In [25]:
jacobian(l, volume.get_detectors()[0].panels[0].xy_span)

tensor([[-0.0265,  0.1094]])

In [26]:
jacobian(l, volume.get_detectors()[0].panels[0].z)

tensor([[-13.3966]])