In [None]:
# Notebook setup: render matplotlib figures inline, and (re)load the autoreload
# extension in mode 2 so edits to imported modules (e.g. tomopt) are picked up
# before each cell runs, without restarting the kernel.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from tomopt.core import *
import torch
from torch import Tensor
from torch.nn import functional as F

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from torch import nn
from tomopt.volume import PassiveLayer, PanelDetectorLayer, DetectorPanel, DetectorHeatMap

from tomopt.volume import Volume
from functools import partial
from tomopt.optimisation import PanelVolumeWrapper, VoxelX0Loss, HeatMapVolumeWrapper
from tomopt.core import X0
from tomopt.optimisation import PassiveYielder
from tomopt.optimisation import NoMoreNaNs, PanelMetricLogger, CostCoefWarmup

# Run everything on the CPU.
DEVICE = torch.device('cpu')

# Seed the global RNGs so that random initialisation and sampling are reproducible
# under Restart & Run All (assumes tomopt draws from the torch/numpy global RNGs — TODO confirm).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
def area_cost(x: Tensor) -> Tensor:
    r"""Cost function handed to each ``DetectorHeatMap`` as ``area_cost_func``.

    Non-positive inputs cost nothing; positive inputs are charged linearly,
    i.e. the element-wise ``max(x, 0)``.

    Arguments:
        x: tensor of values to be costed

    Returns:
        tensor of the same shape as ``x`` with negatives clamped to zero
    """
    return torch.relu(x)

# Preview the panel area-cost function over [-1, 1].
x = torch.linspace(-1, 1, 100)
fig, ax = plt.subplots()
ax.plot(x, area_cost(x))
ax.set(title="Panel area cost", xlabel="x", ylabel="area_cost(x)")

def arb_rad_length(*, z: float, lw: Tensor, size: float) -> Tensor:
    r"""Radiation-length (X0) map for one passive layer: beryllium with two lead blocks.

    Every voxel starts as ``X0['beryllium']``. Layers with ``0.4 <= z <= 0.5`` get lead
    in the ``[5:, 5:]`` corner, and layers with ``0.6 <= z <= 0.7`` get lead in the
    ``[:4, :4]`` corner. (An earlier single-block variant is recovered by dropping the
    second condition.)

    Arguments:
        z: z of the passive layer being filled (presumably the layer's z as set in get_layers)
        lw: extent of the layer in x and y, in the same units as ``size``
        size: voxel size

    Returns:
        tensor of shape ``(round(lw[0]/size), round(lw[1]/size))`` holding X0 values
    """
    # Round before casting: a bare .long() truncates, so a floating-point ratio such as
    # 0.7/0.1 == 6.999... would silently drop a whole row/column of voxels.
    n_voxels = (lw / size).round().long().tolist()
    rad_length = torch.ones(n_voxels) * X0['beryllium']
    if 0.4 <= z <= 0.5:
        rad_length[5:, 5:] = X0['lead']
    if 0.6 <= z <= 0.7:
        rad_length[:4, :4] = X0['lead']
    return rad_length

In [None]:
def get_layers() -> nn.ModuleList:
    r"""Build the layer stack used in this heat-map detector study.

    Layers, in the order they are returned:
    - a ``PanelDetectorLayer`` with ``pos='above'`` at z=1 holding ``n_panels`` ``DetectorHeatMap`` panels,
    - six ``PassiveLayer``s at z = 0.8, 0.7, 0.6, 0.5, 0.4, 0.3 (each with ``size=0.1``),
    - a ``PanelDetectorLayer`` with ``pos='below'`` at z=0.2 holding ``n_panels`` ``DetectorHeatMap`` panels.

    Each detector layer is given ``size=2*size``. Its panels start at the layer's z
    and step down by ``2*size/n_panels``.

    Returns:
        ``nn.ModuleList`` of the layers in the order listed above
    """
    lwh = Tensor([1, 1, 1])  # only lwh[:2] is passed on to the layers
    size = 0.1  # passive-layer size
    init_eff = 1.0
    init_res = 1000
    n_panels = 4
    det_size = 2 * size  # size given to each detector layer
    z_above, z_below = 1, 0.2  # z of the upper and lower detector layers

    def _make_panels(z_top: float) -> list:
        # Build n_panels heat-map panels: the first at z_top, each next one det_size/n_panels lower.
        # NOTE(review): the x/y of init_xyz (-0.5, -0.5) and init_xy_span ([-0.5, 0.5]) are kept
        # from the original; confirm they match DetectorHeatMap's convention for lw=[1, 1].
        return [
            DetectorHeatMap(
                init_xyz=[-0.5, -0.5, z_top - (i * det_size / n_panels)],
                init_xy_span=[-0.5, 0.5],
                area_cost_func=area_cost,
                device=DEVICE,
                res=init_res,
                eff=init_eff,
            )
            for i in range(n_panels)
        ]

    def _make_detector_layer(pos: str, z: float, panels: list):
        # Wrap a list of heat-map panels in a PanelDetectorLayer of size det_size.
        return PanelDetectorLayer(
            pos=pos,
            lw=lwh[:2],
            z=z,
            size=det_size,
            panels=panels,
            type_label="heatmap",
        )

    # All panels are built before any layer, as in the original, so any RNG draws made
    # during construction happen in the same order.
    panels_up = _make_panels(z_above)  # detectors above the passive volume
    panels_down = _make_panels(z_below)  # detectors below the passive volume

    layers = [_make_detector_layer('above', z_above, panels_up)]
    layers += [
        PassiveLayer(lw=lwh[:2], z=z, size=size, device=DEVICE)
        for z in [0.8, 0.7, 0.6, 0.5, 0.4, 0.3]
    ]
    layers.append(_make_detector_layer('below', z_below, panels_down))
    return nn.ModuleList(layers)

In [None]:
# Build the volume from the layer stack and report its cost before optimisation.
volume = Volume(get_layers())
print(volume.get_cost())

# SGD optimisers for the heat-map parameter groups (mu, norm, sig) plus the loss.
# VoxelX0Loss combines an X0 precision term with a cost term weighted by cost_coef.
# cost_coef=None leaves that weight unset so it is balanced automatically (the original
# note says on the first batch) — TODO confirm how this interacts with CostCoefWarmup in fit().
wrapper = HeatMapVolumeWrapper(
    volume,
    mu_opt=partial(torch.optim.SGD, lr=1e5),
    norm_opt=partial(torch.optim.SGD, lr=1e5),
    sig_opt=partial(torch.optim.SGD, lr=1e4),
    loss_func=VoxelX0Loss(target_budget=4, cost_coef=None),
)

# A single hand-built passive volume, filled by arb_rad_length.
passives = PassiveYielder([arb_rad_length])

In [None]:
# Before training: show the panels of the first detector layer and their parameters.
# Non-detector layers are skipped rather than assuming the first layer has `.panels`.
for p in volume:
    if not isinstance(p, PanelDetectorLayer):
        continue
    for d in p.panels:
        print(type(d))
        print(d)
        for par in d.named_parameters():
            print(par)
    break

In [None]:
# Optimise the detector: 50 epochs, passive batch size 1, 1000 muons per volume
# with muon batch size 1000.
# NOTE(review): the same PassiveYielder is used for training and validation, so
# validation metrics are not computed on held-out volumes.
_ = wrapper.fit(
    n_epochs=50,
    trn_passives=passives,
    val_passives=passives,
    passive_bs=1,
    n_mu_per_volume=1000,
    mu_bs=1000,
    cbs=[
        CostCoefWarmup(n_warmup=5),
        NoMoreNaNs(),
        PanelMetricLogger(),
    ],
)

In [None]:
# After training: show every detector layer's panels with their (name, parameter) pairs.
for layer in volume:
    if not isinstance(layer, PanelDetectorLayer):
        continue  # skip non-detector (e.g. passive) layers
    for panel in layer.panels:
        print(type(panel))
        print(panel)
        for name_and_param in panel.named_parameters():
            print(name_and_param)

In [None]:
# Run inference on the passive volume(s): 10k muons per volume, muon batch size 100.
preds = wrapper.predict(passives, n_mu_per_volume=10_000, mu_bs=100)

In [None]:
from tomopt.plotting import plot_pred_true_x0

In [None]:
# Plot predicted vs. true X0 for the first entry of `preds`
# (presumably one entry per passive volume yielded by `passives` — TODO confirm).
plot_pred_true_x0(*preds[0])