# BM3 GNN Pred

In [1]:
from typing import *
import numpy as np
from functools import partial

import torch
from torch import Tensor, nn
import torch.nn.functional as F

from tomopt.volume import *
from tomopt.muon import *
from tomopt.inference import *
from tomopt.optimisation import *

In [2]:
DEVICE = torch.device("cpu")

In [3]:
from tomopt.volume.layer import Layer

def get_volume(size: float = 0.2, lwh: Tensor = Tensor([1.0, 1.0, 1.4]), device: torch.device = torch.device("cpu")) -> Volume:
    def area_cost(x: Tensor) -> Tensor:
        return F.relu(x)

    layers: List[Layer] = []
    n_panels = 4
    layers.append(
        PanelDetectorLayer(
            pos="above",
            lw=lwh[:2],
            z=lwh[2].item(),
            size=size,
            panels=[
                DetectorPanel(
                    res=1e3, eff=1, init_xyz=(0.5, 0.5, 1 - (i * (size) / n_panels)), init_xy_span=(1.0, 1.0), area_cost_func=area_cost, device=device
                )
                for i in range(n_panels)
            ],
        )
    )
    for z in np.round(np.arange(lwh[2] - size, size, -size), decimals=1):
        layers.append(PassiveLayer(lw=lwh[:2], z=z, size=size, device=device))
    layers.append(
        PanelDetectorLayer(
            pos="below",
            lw=lwh[:2],
            z=size,
            size=size,
            panels=[
                DetectorPanel(
                    res=1e3, eff=1, init_xyz=(0.5, 0.5, 0.2 - (i * (size) / n_panels)), init_xy_span=(1.0, 1.0), area_cost_func=area_cost, device=device
                )
                for i in range(n_panels)
            ],
        )
    )

    return Volume(nn.ModuleList(layers))

In [4]:
volume = get_volume(device=DEVICE)

In [5]:
muons = MuonBatch(generate_batch(250), init_z=volume.h, device=DEVICE)

In [6]:
volume(muons)

In [7]:
sb = PanelScatterBatch(muons, volume)

## GNN inferer

In [8]:
model = torch.jit.load('../../mode_muon_tomo_inference/dev/exported_models/bm3_traced.pt')

In [9]:
type(model)

torch.jit._script.RecursiveScriptModule

In [58]:
from tomopt.inference.volume import AbsVolumeInferer, AbsX0Inferer
from tomopt.inference.scattering import AbsScatterBatch


class DeepVolumeInferer(AbsVolumeInferer):
    def __init__(self, model: Union[torch.jit._script.RecursiveScriptModule, nn.Module], base_inferer:AbsX0Inferer, volume: Volume):
        super().__init__(volume=volume)
        self.model,self.base_inferer = model,base_inferer
        self.voxel_centres = self._build_centres()
        
        self.in_vars: List[Tensor] = []
        self.in_var_uncs: List[Tensor] = []
        self.efficiencies: List[Tensor] = []
        self.in_var: Optional[Tensor] = None
        self.in_var_unc: Optional[Tensor] = None
        self.efficiency: Optional[Tensor] = None
        
    def compute_efficiency(self, scatters:AbsScatterBatch) -> Tensor:
        return self.base_inferer.compute_efficiency(scatters=scatters)
    
    def get_base_predictions(self, scatters:AbsScatterBatch) -> Tensor:
        x, u = self.base_inferer.x0_from_dtheta(scatters=scatters)
        return x[:, None], u[:, None]
    
    def add_scatters(self, scatters: AbsScatterBatch) -> None:
        super().add_scatters(scatters=scatters)
        x0, x0_unc = self.get_base_predictions(scatters)
        self.in_vars.append(torch.cat((sb.dtheta, sb.dxy, x0, sb.location), dim=-1))
        self.in_var_uncs.append(torch.cat((sb.dtheta_unc, sb.dxy_unc, x0_unc, sb.location_unc), dim=-1))
        self.efficiencies.append(self.compute_efficiency(scatters=scatters))
        
    def _build_centres(self) -> Tensor:
        bounds = (
            self.volume.passive_size
            * np.mgrid[
                round(self.volume.get_passive_z_range()[0].detach().cpu().numpy()[0] / self.volume.passive_size) : round(
                    self.volume.get_passive_z_range()[1].detach().cpu().numpy()[0] / self.volume.passive_size
                ) : 1,
                0 : round(self.volume.lw.detach().cpu().numpy()[0] / self.volume.passive_size) : 1,
                0 : round(self.volume.lw.detach().cpu().numpy()[1] / self.volume.passive_size) : 1,
            ]
        )
        #         bounds[0] = np.flip(bounds[0])  # z is reversed
        return torch.tensor(bounds.reshape(3, -1).transpose(-1, -2), dtype=torch.float32) + (self.volume.passive_size / 2)
        
    def _build_inputs(self, in_var:Tensor) -> Tensor:
        data = in_var[None, :].repeat_interleave(len(self.voxel_centres), dim=0)
        data[:, :, -3:] -= self.voxel_centres[:, None].repeat_interleave(len(in_var), dim=1)
        data = torch.cat((data, torch.norm(data[:, :, -3:], dim=-1, keepdim=True)), dim=-1)  # dR
        return data
    
    def _get_weight(self) -> Tensor:
        '''Maybe alter this to include resolution/pred uncertainties'''
        return self.efficiency
        
    def get_prediction(self) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        self.in_var = torch.cat(self.in_vars, dim=0)
        self.in_var_unc = torch.cat(self.in_var_uncs, dim=0)
        self.efficiency = torch.cat(self.efficiencies, dim=0)
        
        inputs = self._build_inputs(self.in_var)
        pred = self.model(inputs[None])
        weight = self._get_weight()
        return pred, weight

In [73]:
dvi = DeepVolumeInferer(model=model, base_inferer=PanelX0Inferer(volume), volume=volume)

In [74]:
dvi.in_vars, dvi.in_var_uncs

([], [])

In [75]:
dvi.add_scatters(sb)

In [76]:
dvi.in_vars[0].shape, dvi.in_var_uncs[0].shape

(torch.Size([250, 8]), torch.Size([250, 8]))

In [77]:
%%time
p,w = dvi.get_prediction()

CPU times: user 53.4 s, sys: 623 ms, total: 54 s
Wall time: 1.98 s


In [78]:
torch.exp(p).shape

torch.Size([1, 5, 125])

In [79]:
p

tensor([[[-4.7199, -4.6612, -4.8523, -5.3027, -5.9038, -4.4587, -4.3215,
          -4.4367, -4.8904, -5.5355, -4.4805, -4.3712, -4.4131, -4.7845,
          -5.3617, -4.7846, -4.6841, -4.7248, -5.0203, -5.5507, -5.3168,
          -5.2353, -5.3210, -5.5242, -5.9422, -3.1803, -3.0079, -3.1728,
          -3.7733, -4.6853, -2.7979, -2.5219, -2.6782, -3.2331, -4.2915,
          -2.7714, -2.4809, -2.5953, -3.0754, -4.0568, -3.0980, -2.7760,
          -2.8489, -3.3317, -4.1621, -3.7012, -3.4231, -3.5473, -4.0163,
          -4.6220, -2.0405, -1.6804, -1.8288, -2.4105, -3.3542, -1.5385,
          -1.4254, -1.5604, -1.9171, -2.8633, -1.4701, -1.3884, -1.5540,
          -1.8216, -2.5975, -1.6655, -1.5257, -1.6327, -2.0116, -2.8083,
          -2.3029, -2.1007, -2.1466, -2.6024, -3.3770, -1.5095, -1.1738,
          -1.2140, -1.6123, -2.3049, -1.1016, -0.8568, -0.8933, -1.1963,
          -2.0857, -0.9639, -0.8060, -0.9017, -1.1481, -1.9760, -1.0776,
          -0.9097, -1.0306, -1.3157, -2.0047, -1.69

In [81]:
from abc import abstractmethod, ABCMeta

class AbsDetectorLoss(nn.Module, metaclass=ABCMeta):
    def __init__(
        self,
        *,
        target_budget: float,
        budget_smoothing: float = 10,
        cost_coef: Optional[Union[Tensor, float]] = None,
        steep_budget: bool = True,
        debug: bool = False,
    ):
        super().__init__()
        self.target_budget, self.budget_smoothing, self.cost_coef, self.steep_budget, self.debug = (
            target_budget,
            budget_smoothing,
            cost_coef,
            steep_budget,
            debug,
        )
        self.sub_losses: Dict[str, Tensor] = {}  # Store subcomponents in dict for telemetry

    def _get_budget_coef(self, cost: Tensor) -> Tensor:
        r"""Switch-on near target budget, plus linear increase above budget"""

        if self.target_budget is None:
            return cost.new_zeros(1)

        if self.steep_budget:
            d = self.budget_smoothing * (cost - self.target_budget) / self.target_budget
            if d <= 0:
                return 2 * torch.sigmoid(d)
            else:
                return 1 + (d / 2)
        else:
            d = cost - self.target_budget
            return (2 * torch.sigmoid(self.budget_smoothing * d / self.target_budget)) + (F.relu(d) / self.target_budget)

    def _compute_cost_coef(self, inference: Tensor) -> None:
        self.cost_coef = inference.detach().clone()
        print(f"Automatically setting cost coefficient to {self.cost_coef}")
        
    @abstractmethod
    def _get_inference_loss(self, pred:Tensor, pred_weight: Tensor, volume: Volume) -> Tensor:
        pass
    
    def _get_cost_loss(self, volume: Volume) -> Tensor:
        if self.cost_coef is None:
            self._compute_cost_coef(self.sub_losses["error"])
        cost = volume.get_cost()
        cost_loss = self._get_budget_coef(cost) * self.cost_coef
        if self.debug:
            print(
                f'cost {cost}, cost coef {self.cost_coef}, budget coef {self._get_budget_coef(cost)}. error loss {self.sub_losses["error"]}, cost loss {self.sub_losses["cost"]}'
            )
        return cost_loss

    def forward(self, pred: Tensor, pred_weight: Tensor, volume: Volume) -> Tensor:
        self.sub_losses = {}
        self.sub_losses["error"] = self._get_inference_loss(pred, pred_weight, volume)
        self.sub_losses["cost"] = self._get_cost_loss(volume)
        return self.sub_losses["error"] + self.sub_losses["cost"]

In [88]:
class VoxelX0Loss(AbsDetectorLoss):
    def _get_inference_loss(self, pred:Tensor, pred_weight: Tensor, volume: Volume) -> Tensor:
        true_x0 = volume.get_rad_cube()
        return torch.mean(F.mse_loss(pred, true_x0, reduction='none') / pred_weight)

In [89]:
class AbsMaterialClassLoss(AbsDetectorLoss):
    def __init__(
        self,
        *,
        x02id: Dict[float, int],
        target_budget: float,
        budget_smoothing: float = 10,
        cost_coef: Optional[Union[Tensor, float]] = None,
        steep_budget: bool = True,
        debug: bool = False,
    ):
        super().__init__(target_budget=target_budget, budget_smoothing=budget_smoothing, cost_coef=cost_coef, steep_budget=steep_budget, debug=debug)
        self.x02id = x02id

In [None]:
class VoxelClassLoss(AbsMaterialClassLoss):
    def _get_inference_loss(self, pred:Tensor, pred_weight: Tensor, volume: Volume) -> Tensor:
        true_x0 = volume.get_rad_cube()
        for x0 in true_x0.unique():
            true_x0[true_x0 == x0] = self.x02id[x0]
        true_x0 = true_x0.long()        
        return torch.mean(F.nll_loss(pred, true_x0, reduction='none') / pred_weight)

In [99]:
class VolumeClassLoss(AbsMaterialClassLoss):
    def _get_inference_loss(self, pred:Tensor, pred_weight: Tensor, volume: Volume) -> Tensor:
        targ = volume.target.clone()
        for x0 in targ.unique():
            targ[targ == x0] = self.x02id[x0]
        targ = true_targx0.long()
        loss = F.nll_loss(pred, true_x0, reduction='none') if pred.shape[1] > 1 else F.binary_cross_entropy(pred, true_x0, reduction='none')
        return torch.mean(loss / pred_weight)