In [32]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [33]:
from typing import *
import numpy as np
from functools import partial
from fastprogress import progress_bar
import pandas as pd
import h5py

from lumin.plotting.results import plot_roc

import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch._vmap_internals import _vmap as vmap

from tomopt.volume import *
from tomopt.muon import *
from tomopt.inference import *
from tomopt.optimisation import *
from tomopt.core import *
from tomopt.utils import *
from tomopt.plotting import *

import seaborn as sns
import matplotlib.pyplot as plt

In [34]:
# Device handed to the detector panels and passive layers built in `get_layers` (CPU only here).
DEVICE: torch.device = torch.device("cpu")

In [35]:
def area_cost(x:Tensor) -> Tensor:
    """Cost function handed to each `DetectorPanel` as `area_cost_func`.

    Elementwise ReLU: negative entries of `x` cost nothing, non-negative entries are
    returned unchanged. (Presumably `x` is a panel area — TODO confirm in DetectorPanel.)

    Arguments:
        x: input tensor

    Returns:
        tensor of the same shape as `x`, clipped below at zero
    """
    cost = F.relu(x)
    return cost

In [36]:
def get_layers(init_res: float = 1e4, init_eff: float = 0.9, n_panels: int = 4, size: float = 0.1,
               passive_zs: Sequence[float] = (0.8, 0.7, 0.6, 0.5, 0.4, 0.3)) -> nn.ModuleList:
    """Build the layer stack: a panel detector above, passive layers, and a panel detector below.

    With the defaults this reproduces the original hard-coded geometry:
    - a `PanelDetectorLayer` (pos='above') at z=1 of depth 2*size;
    - one `PassiveLayer` of depth `size` at each z in `passive_zs` (0.8 ... 0.3);
    - a `PanelDetectorLayer` (pos='below') at z=0.2 of depth 2*size.

    Each detector layer holds `n_panels` `DetectorPanel`s. They are centred at xy=(0.5, 0.5)
    with initial xy span (2, 2), and stacked downwards from the layer's z in steps of
    2*size/n_panels (e.g. 1, 0.95, 0.9, 0.85 for the upper layer).

    Arguments:
        init_res: initial resolution of every panel
        init_eff: initial efficiency of every panel
        n_panels: number of panels per detector layer
        size: depth of each passive layer; detector layers are twice as deep
        passive_zs: z positions of the passive layers, top to bottom

    Returns:
        `nn.ModuleList` of layers ordered top to bottom
    """
    lwh = Tensor([1,1,1])

    def _panels(z_top: float) -> list:
        # Evenly stack `n_panels` panels downwards from `z_top` across the 2*size detector depth
        return [DetectorPanel(res=init_res, eff=init_eff,
                              init_xyz=[0.5,0.5,z_top-(i*(2*size)/n_panels)], init_xy_span=[2.,2.],
                              area_cost_func=area_cost, device=DEVICE) for i in range(n_panels)]

    # Construction order (above panels, above layer, passives, below panels, below layer) matches the original
    layers = [PanelDetectorLayer(pos='above', lw=lwh[:2], z=1, size=2*size, panels=_panels(1))]
    layers += [PassiveLayer(lw=lwh[:2], z=z, size=size, device=DEVICE) for z in passive_zs]
    layers.append(PanelDetectorLayer(pos='below', lw=lwh[:2], z=0.2, size=2*size, panels=_panels(0.2)))
    return nn.ModuleList(layers)

In [37]:
# Build the volume from the default detector/passive layer stack returned by `get_layers`.
volume = Volume(get_layers())

In [38]:
# Inspect the initial layer stack and panel placement.
volume

Volume(
  (layers): ModuleList(
    (0): PanelDetectorLayer(
      (panels): ModuleList(
        (0): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([1.]), and xy span tensor([2., 2.])
        (1): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.9500]), and xy span tensor([2., 2.])
        (2): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.9000]), and xy span tensor([2., 2.])
        (3): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.8500]), and xy span tensor([2., 2.])
      )
    )
    (1): PassiveLayer()
    (2): PassiveLayer()
    (3): PassiveLayer()
    (4): PassiveLayer()
    (5): PassiveLayer()
    (6): PassiveLayer()
    (7): PanelDetectorLayer(
      (panels): ModuleList(
        (0): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.2000]),

In [39]:
from tomopt.optimisation import ULorryPassiveGenerator

In [40]:
# Block size; the same value is passed as `n_block_voxels` to the classifier in `wrapper`.
u_volume = 12
passive_gen = ULorryPassiveGenerator(
    volume,
    u_volume=u_volume,
    u_prob=0.5,
    fill_frac=0.8,
    x0_lorry=X0['iron'],
    bkg_materials=['air', 'iron'],
)

In [41]:
# 200 passive configurations, passed to every `wrapper.predict` call in this notebook so that
# different detector layouts are scored on the same passives.
# NOTE(review): assumes PassiveYielder does not resample per iteration — TODO confirm.
test_passives = PassiveYielder(passive_gen, n_passives=200)

In [42]:
from tomopt.optimisation import NoMoreNaNs, PanelMetricLogger, CostCoefWarmup, PanelOptConfig, MuonResampler

In [43]:
# Optimisers are configured, but the cells shown here only call `wrapper.predict` (no fitting).
wrapper = PanelVolumeWrapper(
    volume,
    xy_pos_opt=partial(torch.optim.SGD, lr=5e4),
    z_pos_opt=partial(torch.optim.SGD, lr=5e3),
    xy_span_opt=partial(torch.optim.SGD, lr=1e4),
    loss_func=VolumeClassLoss(x02id={0:0, 1:1}, target_budget=5),
    partial_volume_inferer=partial(
        DenseBlockClassifierFromX0s,
        n_block_voxels=u_volume,
        partial_x0_inferer=PanelX0Inferer,
        volume=volume,
        ratio_offset=-1,
        ratio_coef=1,
    ),
)

In [44]:
def set_detectors(volume, tops: Sequence[float] = (1, 0.2), depth: float = 0.2) -> None:
    """Randomly re-position every detector panel in z.

    Each panel of detector layer k (in `volume.get_detectors()` order) is moved to
    `z = tops[k] - U*depth`, with `U = torch.rand(1)` drawn independently per panel, i.e. into
    `(tops[k]-depth, tops[k]]`. With the defaults, panels of the first detector (the upper one
    built by `get_layers`) land in (0.8, 1] and those of the second in (0.0, 0.2]. This matches
    the original hard-coded ranges and random-draw order. Panel xy positions and spans are untouched.

    Arguments:
        volume: object whose `get_detectors()` returns detector layers exposing `.panels`
        tops: upper z bound per detector layer, in `get_detectors()` order; extra detectors are skipped
        depth: z extent below each top within which panels are placed
    """
    for top, det in zip(tops, volume.get_detectors()):
        for panel in det.panels:
            # Assign through `.data` so the panel parameter is replaced without recording autograd history
            panel.z.data = top-(torch.rand(1)*depth)

In [45]:
# Randomise the panel z positions before evaluating a single baseline layout.
set_detectors(volume)

In [46]:
# Check the randomised panel z positions.
volume

Volume(
  (layers): ModuleList(
    (0): PanelDetectorLayer(
      (panels): ModuleList(
        (0): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.8587]), and xy span tensor([2., 2.])
        (1): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.8190]), and xy span tensor([2., 2.])
        (2): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.8347]), and xy span tensor([2., 2.])
        (3): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.8592]), and xy span tensor([2., 2.])
      )
    )
    (1): PassiveLayer()
    (2): PassiveLayer()
    (3): PassiveLayer()
    (4): PassiveLayer()
    (5): PassiveLayer()
    (6): PassiveLayer()
    (7): PanelDetectorLayer(
      (panels): ModuleList(
        (0): <class 'tomopt.volume.panel.DetectorPanel'> located at xy=tensor([0.5000, 0.5000]), z=tensor([0.057

In [47]:
# Baseline: run inference on the 200 test passives (250 muons per volume, resampled by
# MuonResampler) for the current randomised panel layout.
preds = wrapper.predict(test_passives,
                n_mu_per_volume=250,
                mu_bs=250,
                pred_cb=ClassPredHandler(x02id={0:0, 1:1}),
                cbs=[MuonResampler()])

  idxs = torch.combinations(torch.arange(0, unc.shape[-1]), with_replacement=True)


In [96]:
from sklearn.metrics import roc_auc_score

def get_roc_auc(preds:List[Tuple[np.ndarray,np.ndarray]], n=1000,
                rng:Optional[np.random.Generator]=None) -> Tuple[float, float]:
    """Bootstrap estimate of the ROC AUC and its uncertainty.

    Predictions and targets are flattened into two aligned 1D arrays. Then `n` bootstrap
    resamples (with replacement, same size as the data) are drawn, and the AUC is computed on
    each resample that contains exactly two target classes.

    Arguments:
        preds: sequence of `(prediction, target)` pairs; each item may be a scalar or an array,
            but the prediction and target of a pair must hold the same number of elements
        n: number of bootstrap resamples
        rng: optional generator for reproducible resampling; if None the global `np.random`
            state is used (same draws as before)

    Returns:
        mean and sample std (`ddof=1`) of the bootstrapped AUCs; nan when there are too few
        valid resamples to compute them

    Raises:
        ValueError: if the total number of predictions and targets differ
    """
    if len(preds) == 0:
        return float('nan'), float('nan')
    # Concatenate per-pair arrays separately so scores stay aligned with their targets even when an
    # entry holds several elements (a flatten/reshape(-1, 2) of the pairs would mis-pair them).
    y_score = np.concatenate([np.ravel(np.asarray(p, dtype=float)) for p, _ in preds])
    y_true = np.concatenate([np.ravel(np.asarray(t)) for _, t in preds]).astype(int)
    if len(y_score) != len(y_true):
        raise ValueError(f"Got {len(y_score)} predictions but {len(y_true)} targets")

    n_samples = len(y_true)
    choose = np.random.choice if rng is None else rng.choice
    scores = []
    for _ in range(n):
        idx = choose(np.arange(n_samples), n_samples, replace=True)
        if np.unique(y_true[idx]).size == 2:  # AUC is undefined for a single-class resample
            scores.append(roc_auc_score(y_true=y_true[idx], y_score=y_score[idx]))

    # Avoid RuntimeWarnings from mean of an empty list / ddof=1 std of fewer than two values
    mean = float(np.mean(scores)) if scores else float('nan')
    std = float(np.std(scores, ddof=1)) if len(scores) > 1 else float('nan')
    return mean, std

In [99]:
# Bootstrapped AUC (mean, std) for the baseline layout.
get_roc_auc(preds)

(0.9019387403059502, 0.023147841100060174)

In [104]:
def get_det_data(volume:Volume) -> np.ndarray:
    """Collect the z position of every detector panel in `volume`.

    Panels from all detector layers are pooled together and sorted, so position in the output
    reflects height, not panel identity.

    Arguments:
        volume: volume whose `get_detectors()` layers expose `.panels` with a single-element `z`

    Returns:
        1D array of panel z positions in descending order
    """
    panel_zs = [panel.z.detach().cpu().item()
                for det in volume.get_detectors()
                for panel in det.panels]
    return np.array(list(reversed(sorted(panel_zs))))

In [105]:
# Panel z positions of the current layout, highest first.
get_det_data(volume)

array([0.85917819, 0.85871094, 0.83474422, 0.81897539, 0.17303939,
       0.1658545 , 0.13859355, 0.05764951])

In [None]:
# Random scan over detector z-layouts. For each random layout, re-run inference on the same test
# passives and record the panel z positions (descending) with the bootstrapped AUC.
# Rows are collected in a list (`DataFrame.append` is quadratic and was removed in pandas 2.0).
# The CSV is rewritten every iteration as a checkpoint, now with named columns.
scan_cols = ['az0','az1','az2','az3','bz0','bz1','bz2','bz3', 'auc', 'auc_unc']
scan_rows = []
df = pd.DataFrame(columns=scan_cols)
for _ in progress_bar(range(575)):
    set_detectors(volume)
    # Separate name so the baseline `preds` from earlier cells is not overwritten
    scan_preds = wrapper.predict(test_passives,
                n_mu_per_volume=250,
                mu_bs=250,
                pred_cb=ClassPredHandler(x02id={0:0, 1:1}),
                cbs=[MuonResampler()])
    scan_rows.append((*get_det_data(volume), *get_roc_auc(scan_preds)))
    df = pd.DataFrame(scan_rows, columns=scan_cols)
    df.to_csv('z_sep_data2.csv', index=False)

In [149]:
# Ensure a clean 0..N-1 row index; rebinding instead of inplace=True keeps the cell side-effect free.
df = df.reset_index(drop=True)

In [160]:
# Name the columns. The 8 panel z positions come first, in descending order as returned by
# `get_det_data`. With the ranges used by `set_detectors`, the first four are the upper detector's
# panels (az*) and the last four the lower detector's (bz*). They are followed by the bootstrapped
# AUC mean and std.
df.columns = ['az0','az1','az2','az3','bz0','bz1','bz2','bz3', 'auc', 'auc_unc']

In [161]:
# Inspect the scan results (one row per random detector layout).
df

Unnamed: 0,az0,az1,az2,az3,bz0,bz1,bz2,bz3,auc,auc_unc
0,0.914012,0.877306,0.825831,0.805421,0.132964,0.107893,0.076997,0.018282,0.922617,0.020164
1,0.944217,0.922763,0.887347,0.835888,0.151046,0.113089,0.060073,0.005641,0.881782,0.023326
2,0.973404,0.906292,0.851596,0.829492,0.145443,0.089395,0.087707,0.034052,0.917431,0.019655
3,0.960263,0.861912,0.858426,0.838359,0.164943,0.141219,0.053146,0.010399,0.909469,0.021475
4,0.953248,0.952990,0.938438,0.904456,0.113036,0.078161,0.070929,0.027537,0.895464,0.022554
...,...,...,...,...,...,...,...,...,...,...
445,0.893151,0.877606,0.855265,0.843976,0.127677,0.030037,0.028675,0.004430,0.935859,0.017659
446,0.958221,0.927571,0.879100,0.870160,0.160257,0.127979,0.059070,0.029672,0.885195,0.023082
447,0.958787,0.884533,0.873391,0.867100,0.183160,0.077336,0.007078,0.006845,0.922029,0.018496
448,0.984093,0.964606,0.935248,0.825166,0.166229,0.151129,0.105423,0.029118,0.893990,0.023395


In [163]:
# Export the final scan results.
df.to_csv('z_sep_data.csv', index=False)