In [20]:
# Render matplotlib figures inline, and have IPython re-import edited modules (e.g. tomopt) before each cell runs
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [21]:
from tomopt.muon import *
from tomopt.inference import *
from tomopt.loss import *
from tomopt.volume import *
from tomopt.core import *
from tomopt.optimisation import *

import matplotlib.pyplot as plt
import seaborn as sns
from typing import *
import numpy as np

import torch
from torch import Tensor, nn
import torch.nn.functional as F

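# Basics

Define a toy volume of ten layers of thickness 0.1: two detector layers at the top, six passive beryllium layers, and two detector layers at the bottom. The passive layers with z between 0.4 and 0.5 also contain a block of lead.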

In [22]:
def arb_rad_length(*,z:float, lw:Tensor, size:float, tol:float=1e-5) -> Tensor:
    r'''
    Radiation-length map for one passive layer of the toy volume.

    Every voxel is beryllium. Layers whose z lies in [0.4, 0.5] (within `tol`) also get lead in the [5:,5:] corner.

    Arguments:
        z: z position of the layer
        lw: length and width of the layer
        size: voxel size; the map has round(lw/size) voxels along each dimension
        tol: tolerance on the z window, so that z values produced by float arithmetic
            (e.g. np.arange with a float step) are not excluded by rounding noise

    Returns:
        Tensor of radiation lengths with shape round(lw/size)
    '''
    # Round before casting: .long() truncates, so a quotient such as 6.9999997 would otherwise lose a voxel.
    n_vox = (lw/size).round().long()
    rad_length = torch.ones(n_vox.tolist())*X0['beryllium']
    if 0.4-tol <= z <= 0.5+tol: rad_length[5:,5:] = X0['lead']
    return rad_length

In [23]:
def eff_cost(x:Tensor) -> Tensor:
    r'''
    Efficiency cost: exp(3*max(x, 0)) - 1, evaluated with expm1 for accuracy near zero.

    The cost is zero for non-positive inputs and grows exponentially for positive ones.
    '''
    clipped = F.relu(x)
    return (3*clipped).expm1()

In [24]:
def res_cost(x:Tensor) -> Tensor:
    r'''
    Resolution cost: the square of max(x/100, 0).

    The cost is zero for non-positive inputs and quadratic above zero.
    '''
    scaled = F.relu(x/100)
    return scaled.pow(2)

In [25]:
def get_layers(lwh:Optional[Tensor]=None, size:float=0.1, init_eff:float=0.5, init_res:float=1000,
               layout:Sequence[int]=(1,1,0,0,0,0,0,0,1,1)) -> nn.ModuleList:
    r'''
    Build the toy stack of detector and passive layers, ordered from the top of the volume downwards.

    With the defaults this reproduces the original fixed setup: a 1x1x1 volume split into ten layers
    spaced 0.1 apart in z, with two detector layers at the top, six passive layers, and two detector layers at the bottom.

    Arguments:
        lwh: length, width, and height of the volume; defaults to Tensor([1,1,1])
        size: spacing in z between consecutive layers, also passed as `size` to every layer
        init_eff: initial efficiency given to every detector layer
        init_res: initial resolution given to every detector layer
        layout: one flag per layer, starting from the top; truthy for a DetectorLayer, falsy for a PassiveLayer

    Returns:
        nn.ModuleList of the layers, top layer first

    Raises:
        ValueError: if the height does not fit len(layout) layers spaced `size` apart
    '''
    if lwh is None: lwh = Tensor([1,1,1])  # created per call rather than as a shared mutable default
    zs = np.arange(lwh[2], 0, -size)
    # Without this check, zip would silently drop the layers that do not fit in the volume.
    if len(zs) < len(layout):
        raise ValueError(f'Volume height {float(lwh[2])} only fits {len(zs)} layers of size {size}, but layout has {len(layout)}')
    layers = []
    # Detector layers before the first passive layer are 'above'; every detector layer after it is 'below'.
    pos = 'above'
    for z, is_detector in zip(zs, layout):
        if is_detector:
            layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=init_res,
                                        lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))
        else:
            pos = 'below'
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))
    return nn.ModuleList(layers)

In [26]:
# Assemble the layer stack returned by get_layers() into a Volume
volume = Volume(get_layers())

In [28]:
# Sanity check: no backward pass has run yet, so every parameter's .grad is still None (see output)
for p in volume.parameters(): print(p.grad)

None
None
None
None
None
None
None
None


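# VolumeWrapper

Wrap a freshly built volume with separate SGD optimisers for resolution (`res_opt`) and efficiency (`eff_opt`), fit it on the toy passive volume, and then predict on the same passives.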

In [10]:
from functools import partial

In [11]:
# Rebuild the volume: get_layers() constructs new layer objects, so the wrapper starts from freshly initialised layers
volume = Volume(get_layers())

In [12]:
# Wrap the volume with separate optimiser factories (partial fixes only the learning rate).
# The learning rates differ by six orders of magnitude: 2e1 for res_opt vs 2e-5 for eff_opt.
# NOTE(review): the meaning of the 0.15 passed to DetectorLoss is not visible here — TODO document it
wrapper = VolumeWrapper(volume=volume, res_opt=partial(torch.optim.SGD, lr=2e1), eff_opt=partial(torch.optim.SGD, lr=2e-5), loss_func=DetectorLoss(0.15))

In [13]:
# Training passives: a yielder over a single passive-volume function, arb_rad_length
trn_passives = PassiveYielder([arb_rad_length])

In [14]:
# Iterating the yielder gives back the function(s) it was constructed with (see output)
# NOTE(review): `p` is reused further down for the predictions; a distinct name would be clearer
for p in trn_passives: print(p)

<function arb_rad_length at 0x7fb2cbf21430>


In [16]:
# Fit the detector parameters (first argument presumably the number of epochs), stopping NaNs via the NoMoreNaNs callback.
# The trailing ';' hides fit's return value (the callback list), which would otherwise be echoed as cell output.
wrapper.fit(10, n_mu_per_volume=1000, passive_bs=1, mu_bs=1000, trn_passives=trn_passives, val_passives=None, cbs=[NoMoreNaNs()]);

[<tomopt.optimisation.callback.grad_callbacks.NoMoreNaNs at 0x7fb2ce1d6910>]

In [17]:
# Predict on the training passives; judging by the shape below, p[0] covers the 6 passive layers x 10 x 10 voxels — confirm
p = wrapper.predict(n_mu_per_volume=1000, mu_bs=1000, passives=trn_passives)

In [18]:
# Presumably (n_passive_layers, n_x, n_y): 6 passive layers in get_layers(), each round(1/0.1) = 10 voxels per side
p[0].shape

(6, 10, 10)