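# BM1 GNN Prediction

Loads the traced BM1 GNN model, uses it through `DeepVolumeInferer` to classify the passive material of a randomly generated beryllium/lead volume, and checks that the resulting `VolumeClassLoss` can be differentiated with respect to detector-panel parameters.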

In [1]:
from typing import *
import numpy as np
from functools import partial

import torch
from torch import Tensor, nn
import torch.nn.functional as F

from tomopt.volume import *
from tomopt.muon import *
from tomopt.inference import *
from tomopt.optimisation import *
from tomopt.core import *
from tomopt.utils import *

In [2]:
DEVICE = torch.device("cpu")
# Seed the global RNGs so the sampled muon batch and passive-volume layout are reproducible.
# NOTE(review): assumes tomopt's samplers draw from the numpy/torch global RNGs — confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
from tomopt.volume.layer import Layer

def get_volume(
    size: float = 0.2,
    lwh: Optional[Tensor] = None,
    device: torch.device = torch.device("cpu"),
    n_panels: int = 4,
    res: float = 1e3,
    eff: float = 1,
) -> Volume:
    r"""
    Build a test ``Volume``: a panel detector layer on top, a stack of passive layers, and a panel detector layer at the bottom.

    Arguments:
        size: thickness along z of every layer (detector and passive)
        lwh: length, width, and height of the volume. Defaults to ``Tensor([1.0, 1.0, 1.4])``;
            a new tensor is created on every call so volumes never share a default tensor.
        device: device on which layers and panels are created
        n_panels: number of ``DetectorPanel`` objects in each of the two detector layers
        res: ``res`` value passed to every panel
        eff: ``eff`` value passed to every panel

    Returns:
        ``Volume`` whose layers are ordered: top detector layer, passive layers (descending z), bottom detector layer
    """
    if lwh is None:
        lwh = Tensor([1.0, 1.0, 1.4])
    length, width, height = lwh[0].item(), lwh[1].item(), lwh[2].item()

    def area_cost(x: Tensor) -> Tensor:
        return F.relu(x)

    def make_panels(z_top: float) -> List[DetectorPanel]:
        # Panels centred in x/y and spanning the full footprint, stepped down by size / n_panels from z_top.
        return [
            DetectorPanel(
                res=res,
                eff=eff,
                init_xyz=(length / 2, width / 2, z_top - (i * size / n_panels)),
                init_xy_span=(length, width),
                area_cost_func=area_cost,
                device=device,
            )
            for i in range(n_panels)
        ]

    layers: List[Layer] = [
        PanelDetectorLayer(pos="above", lw=lwh[:2], z=height, size=size, panels=make_panels(height))
    ]

    # Passive layers fill the gap between the two detector layers (z appears to be each layer's upper face).
    # Count them with integer arithmetic rather than a float-step np.arange, whose length can be off by one
    # through rounding. Round z to 6 d.p. to strip float32 noise from `height`; 1-d.p. rounding would
    # merge distinct layers whenever size < 0.1.
    n_passive = max(int(round(height / size)) - 2, 0)
    for z in np.round(height - size * np.arange(1, n_passive + 1), decimals=6):
        layers.append(PassiveLayer(lw=lwh[:2], z=z, size=size, device=device))

    layers.append(PanelDetectorLayer(pos="below", lw=lwh[:2], z=size, size=size, panels=make_panels(size)))

    return Volume(nn.ModuleList(layers))

In [4]:
# Default 1 x 1 x 1.4 volume with 0.2-thick layers.
volume = get_volume(size=0.2, device=DEVICE)

In [5]:
N_MUONS = 250  # size of the generated muon batch
muons = MuonBatch(generate_batch(N_MUONS), init_z=volume.h, device=DEVICE)

In [6]:
# Random passive layout built from two materials (beryllium vs lead).
gen = RandomBlockPassiveGenerator(
    volume=volume,
    materials=["beryllium", "lead"],
    block_size=None,
    block_size_max_half=False,
    sort_x0=False,
    enforce_diff_mat=True,
)

In [7]:
# Load the generator's sampled layout into the volume.
# NOTE(review): the exact contents of gen.get_data()'s return tuple are not visible here — confirm.
volume.load_rad_length(*gen.get_data())

In [8]:
# NOTE(review): return value is discarded; presumably this propagates `muons` through the layers in place — confirm.
volume(muons)

In [9]:
# Target of the generated volume; the value shown (0.3528) is the beryllium key of x02id in the Loss section.
volume.target

tensor([0.3528])

In [10]:
# Scatter data built from the muon batch and the volume; fed to the inferer below.
sb = ScatterBatch(muons, volume)

## GNN inferer

In [11]:
# Path is relative to the notebook's working directory.
MODEL_PATH = "../../mode_muon_tomo_inference/dev/exported_models/bm1_traced.pt"
# map_location remaps saved tensors onto DEVICE, so a model traced on another device (e.g. GPU) still loads here.
model = torch.jit.load(MODEL_PATH, map_location=DEVICE)

In [12]:
# Sanity check: the loaded object is a TorchScript module.
type(model)

torch.jit._script.RecursiveScriptModule

In [13]:
# Wrap the GNN in a volume-level inferer; PanelX0Inferer is passed as the base inferer
# (its exact role inside DeepVolumeInferer is not visible in this notebook).
dvi = DeepVolumeInferer(model=model, base_inferer=PanelX0Inferer(volume), volume=volume)

In [14]:
# Input buffers are empty until scatters are added.
dvi.in_vars, dvi.in_var_uncs

([], [])

In [15]:
# Fills dvi.in_vars / dvi.in_var_uncs. The torch.combinations line printed below appears to be a
# warning emitted from library code, not from this cell.
dvi.add_scatters(sb)

  idxs = torch.combinations(torch.arange(0, unc.shape[-1]), with_replacement=True)


In [16]:
# 250 rows matches the number of generated muons (presumably one row per muon — confirm).
dvi.in_vars[0].shape, dvi.in_var_uncs[0].shape

(torch.Size([250, 8]), torch.Size([250, 8]))

In [17]:
%%time
# NOTE(review): `w` is not used later in this notebook; the loss below is called with `1` instead — confirm which is intended.
p,w = dvi.get_prediction()

CPU times: user 4.23 s, sys: 354 ms, total: 4.59 s
Wall time: 1.94 s


In [18]:
# Prediction for this single volume.
p.shape

torch.Size([1, 1])

In [19]:
# Model output; the grad_fn shows it comes from a sigmoid and remains attached to the autograd graph.
p

tensor([[1.4020e-07]], grad_fn=<SigmoidBackward0>)

In [20]:
%%time
# NOTE(review): disabled because it was reported to crash. Elsewhere this notebook accesses the inputs as
# `dvi.in_vars` (a list), so `in_var` below may simply be a wrong attribute name — confirm before re-enabling.
# jacobian(p, dvi.in_var)  # crashes

CPU times: user 2 µs, sys: 1e+03 ns, total: 3 µs
Wall time: 5.96 µs


## Loss and gradients w.r.t. detector panels

In [21]:
# Map each material's radiation length X0 to its class index (position in gen.materials).
x02id = {X0[material]: class_idx for class_idx, material in enumerate(gen.materials)}
x02id

{0.3528: 0, 0.005612: 1}

In [22]:
# Repeated here to compare against the target shape in the next cell.
p.shape

torch.Size([1, 1])

In [23]:
# Target is 1-D ([1]) while p is [1, 1]; presumably VolumeClassLoss reconciles the shapes — confirm.
volume.target.shape

torch.Size([1])

In [24]:
# Classification loss over the x02id classes. With target_budget=None the cost coefficient
# appears to be set automatically on first call (see the message printed by the next cell).
loss = VolumeClassLoss(x02id=x02id, target_budget=None)

In [25]:
# NOTE(review): the second argument is the literal `1`; the `w` returned by dvi.get_prediction()
# may be the intended value here — confirm against VolumeClassLoss's call signature.
l = loss(p, 1, volume); l

Automatically setting cost coefficient to 1.1920928955078125e-07


tensor([1.1921e-07], grad_fn=<AddBackward0>)

In [26]:
# Gradient of the loss w.r.t. the xy position of the first panel in the first detector layer.
jacobian(l, volume.get_detectors()[0].panels[0].xy)

tensor([[-1.0332e-07,  6.9785e-08]])

In [27]:
# Gradient of the loss w.r.t. the xy span of the first panel in the first detector layer.
jacobian(l, volume.get_detectors()[0].panels[0].xy_span)

tensor([[-3.6483e-08, -1.2915e-07]])

In [28]:
# Gradient of the loss w.r.t. the z position of the first panel in the first detector layer.
jacobian(l, volume.get_detectors()[0].panels[0].z)

tensor([[-0.0001]])