# BM3 GNN Prediction

In [1]:
from typing import *
import numpy as np
from functools import partial

import torch
from torch import Tensor, nn
import torch.nn.functional as F

from tomopt.volume import *
from tomopt.muon import *
from tomopt.inference import *
from tomopt.optimisation import *

In [2]:
# Device handed to the volume and muon batch built in the cells below
DEVICE: torch.device = torch.device(type="cpu")

In [3]:
from tomopt.volume.layer import Layer

def get_volume(size: float = 0.2, lwh: Optional[Tensor] = None, device: torch.device = torch.device("cpu")) -> Volume:
    r"""
    Builds a test volume: a 4-panel detector layer on top, a stack of passive layers, and a 4-panel detector layer at the bottom.
    Every layer (detector and passive) has thickness `size`.

    Arguments:
        size: thickness of each layer, passed as `size` to every layer
        lwh: length, width and height of the volume; defaults to (1.0, 1.0, 1.4).
            The upper detector layer is placed at z=lwh[2] and the lower detector layer at z=size.
        device: device on which the panels and passive layers are created

    Returns:
        Volume wrapping the layers, ordered from top to bottom
    """
    if lwh is None:
        # Build the default per call, so separate volumes never share (views of) the same tensor
        lwh = Tensor([1.0, 1.0, 1.4])

    def area_cost(x: Tensor) -> Tensor:
        return F.relu(x)

    n_panels = 4
    height = lwh[2].item()

    def make_panels(z_top: float) -> List[DetectorPanel]:
        # Panels start at the top of their detector layer and are spaced evenly downwards inside it
        return [
            DetectorPanel(
                res=1e3,
                eff=1,
                init_xyz=(0.5, 0.5, z_top - (i * size / n_panels)),
                init_xy_span=(1.0, 1.0),
                area_cost_func=area_cost,
                device=device,
            )
            for i in range(n_panels)
        ]

    layers: List[Layer] = [PanelDetectorLayer(pos="above", lw=lwh[:2], z=height, size=size, panels=make_panels(height))]

    # Passive layers fill the gap between the two detector layers. The count is computed with round() rather than
    # np.arange so floating-point drift cannot add or drop a layer; z values are rounded to strip that drift.
    n_passive = int(round((height - 2 * size) / size))
    for i in range(n_passive):
        z = np.round(height - (i + 1) * size, decimals=6)
        layers.append(PassiveLayer(lw=lwh[:2], z=z, size=size, device=device))

    layers.append(PanelDetectorLayer(pos="below", lw=lwh[:2], z=size, size=size, panels=make_panels(size)))

    return Volume(nn.ModuleList(layers))

In [4]:
# Build the default geometry from get_volume on DEVICE
volume = get_volume(device=DEVICE)

In [5]:
# Generate a batch of 250 muons, starting at z=volume.h (presumably the top of the volume — confirm)
muons = MuonBatch(generate_batch(250), init_z=volume.h, device=DEVICE)

In [6]:
# Pass the muons through the volume. The return value is unused: presumably the batch is propagated and
# its detector hits are recorded in-place — TODO confirm against Volume.forward
volume(muons)

In [7]:
# Build the scatter batch for the propagated muons; its dtheta/dxy/location attributes (and uncertainties)
# are the features consumed by DeepVolumeInferer below
sb = PanelScatterBatch(muons, volume)

## GNN inferer

In [8]:
# Load the exported (traced) TorchScript model; the path is relative to the notebook's working directory.
# map_location places the model's tensors on DEVICE, even if it was exported from a different device.
model = torch.jit.load('../../mode_muon_tomo_inference/dev/exported_models/bm3_traced.pt', map_location=DEVICE)

In [9]:
# Sanity check: torch.jit.load returns a TorchScript module (see the output below)
type(model)

torch.jit._script.RecursiveScriptModule

In [58]:
from tomopt.inference.volume import AbsVolumeInferer, AbsX0Inferer
from tomopt.inference.scattering import AbsScatterBatch


class DeepVolumeInferer(AbsVolumeInferer):
    r"""
    Volume inferer that runs a pre-trained deep model (e.g. a traced GNN) on per-muon features.

    For each added scatter batch, a per-muon feature vector is built: dtheta, dxy, an X0 estimate from `base_inferer`,
    and the scatter location (always the last three columns). At prediction time these features are replicated for
    every voxel of the passive volume. The location is offset by the voxel centre, the distance to the centre is
    appended, and the model is run on the resulting (voxels, muons, features) tensor.

    Arguments:
        model: TorchScript or regular torch module that maps the input tensor to predictions
        base_inferer: inferer providing the per-muon X0 estimates and muon efficiencies
        volume: volume through which the muons were passed
    """

    def __init__(self, model: Union[torch.jit._script.RecursiveScriptModule, nn.Module], base_inferer: AbsX0Inferer, volume: Volume):
        super().__init__(volume=volume)
        self.model = model
        self.base_inferer = base_inferer
        self.voxel_centres = self._build_centres()

        # Per-batch buffers filled by add_scatters; concatenated in get_prediction
        self.in_vars: List[Tensor] = []
        self.in_var_uncs: List[Tensor] = []
        self.efficiencies: List[Tensor] = []
        self.in_var: Optional[Tensor] = None
        self.in_var_unc: Optional[Tensor] = None
        self.efficiency: Optional[Tensor] = None

    def compute_efficiency(self, scatters: AbsScatterBatch) -> Tensor:
        r"""Returns the muon efficiencies computed by the base inferer for `scatters`."""
        return self.base_inferer.compute_efficiency(scatters=scatters)

    def get_base_predictions(self, scatters: AbsScatterBatch) -> Tuple[Tensor, Tensor]:
        r"""
        Returns:
            X0 predictions and their uncertainties from the base inferer, each with a trailing singleton dimension added
            so they can be concatenated column-wise with the other features
        """
        x, u = self.base_inferer.x0_from_dtheta(scatters=scatters)
        return x[:, None], u[:, None]

    def add_scatters(self, scatters: AbsScatterBatch) -> None:
        r"""
        Records the features, feature uncertainties and efficiencies of the muons in `scatters`.

        Arguments:
            scatters: scatter batch whose muons should be included in the prediction
        """
        super().add_scatters(scatters=scatters)
        x0, x0_unc = self.get_base_predictions(scatters)
        # The location must stay the last three columns: _build_inputs offsets those by the voxel centres
        self.in_vars.append(torch.cat((scatters.dtheta, scatters.dxy, x0, scatters.location), dim=-1))
        self.in_var_uncs.append(torch.cat((scatters.dtheta_unc, scatters.dxy_unc, x0_unc, scatters.location_unc), dim=-1))
        self.efficiencies.append(self.compute_efficiency(scatters=scatters))

    def _build_centres(self) -> Tensor:
        r"""
        Returns:
            (n_voxels, 3) tensor of passive-voxel centres. Rows are ordered z-major (then x, then y), and the columns
            are in (z, x, y) order, following the np.mgrid axes below.
        """
        size = self.volume.passive_size
        z_range = self.volume.get_passive_z_range()
        z_lo = round(z_range[0].detach().cpu().numpy()[0] / size)
        z_hi = round(z_range[1].detach().cpu().numpy()[0] / size)
        lw = self.volume.lw.detach().cpu().numpy()
        n_x, n_y = round(lw[0] / size), round(lw[1] / size)

        bounds = size * np.mgrid[z_lo:z_hi:1, 0:n_x:1, 0:n_y:1]  # lower voxel corners, shape (3, n_z, n_x, n_y)
        # NOTE(review): the columns are (z, x, y), but _build_inputs subtracts them from the scatter location.
        # If the location is stored as (x, y, z), the offsets mix axes; confirm against the model's training
        # convention before reordering.
        centres = torch.tensor(bounds.reshape(3, -1).transpose(-1, -2), dtype=torch.float32, device=self.volume.lw.device)
        return centres + (size / 2)

    def _build_inputs(self, in_var: Tensor) -> Tensor:
        r"""
        Replicates the per-muon features for every voxel, makes the location relative to each voxel centre, and
        appends the distance to the centre (dR).

        Arguments:
            in_var: per-muon features, assumed (n_muons, n_features) with the location in the last three columns

        Returns:
            (n_voxels, n_muons, n_features + 1) input tensor
        """
        data = in_var[None, :].repeat_interleave(len(self.voxel_centres), dim=0)
        data[:, :, -3:] -= self.voxel_centres[:, None]  # broadcast (voxels, 1, 3) over the muon dimension
        return torch.cat((data, torch.norm(data[:, :, -3:], dim=-1, keepdim=True)), dim=-1)  # dR

    def _get_weight(self) -> Tensor:
        r"""Returns the per-muon efficiencies as weights. Maybe alter this to include resolution/prediction uncertainties."""
        return self.efficiency

    def get_prediction(self) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        r"""
        Runs the model on all muons added so far.

        Returns:
            The model output for the (batched) inputs and the per-muon weights, or (None, None) if no scatters have been added
        """
        if not self.in_vars:
            return None, None
        self.in_var = torch.cat(self.in_vars, dim=0)
        self.in_var_unc = torch.cat(self.in_var_uncs, dim=0)
        self.efficiency = torch.cat(self.efficiencies, dim=0)

        inputs = self._build_inputs(self.in_var)
        pred = self.model(inputs[None])  # add a leading batch dimension
        weight = self._get_weight()
        return pred, weight

In [73]:
# Wrap the loaded model; PanelX0Inferer supplies the per-muon X0 features and efficiencies
dvi = DeepVolumeInferer(model=model, base_inferer=PanelX0Inferer(volume), volume=volume)

In [74]:
# The feature buffers are empty before any scatters are added
dvi.in_vars, dvi.in_var_uncs

([], [])

In [75]:
# Add the scatter batch: appends one entry to each of the inferer's buffers
dvi.add_scatters(sb)

In [76]:
# One (n_muons, n_features) tensor per added batch: 250 muons x 8 features here
dvi.in_vars[0].shape, dvi.in_var_uncs[0].shape

(torch.Size([250, 8]), torch.Size([250, 8]))

In [77]:
%%time
# Run the model on all accumulated muons: p is the model output, w the per-muon efficiencies used as weights
p,w = dvi.get_prediction()

CPU times: user 53.4 s, sys: 623 ms, total: 54 s
Wall time: 1.98 s


In [78]:
# Exponentiating suggests the model returns log-probabilities — confirm. Shape is (batch, 5 outputs per voxel?, 125 voxels),
# where 125 = 5 passive layers x 5 x 5 voxels for the default geometry
torch.exp(p).shape

torch.Size([1, 5, 125])