In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from tomopt.muon import *
from tomopt.inference import *
from tomopt.loss import *
from tomopt.volume import *
from tomopt.core import *

import matplotlib.pyplot as plt
import seaborn as sns
from typing import *
import numpy as np

import torch
from torch import Tensor, nn
import torch.nn.functional as F

# Basics

In [3]:
def arb_rad_length(*,z:float, lw:Tensor, size:float) -> Tensor:
    r'''
    Build the X0 map for a single passive layer: every cell is set to X0['beryllium'],
    and for layers whose z lies in [0.4, 0.5] the cells from index 5 onwards along both axes are set to X0['lead'].

    Arguments:
        z: z position of the layer
        lw: length and width of the layer; divided by `size` (and truncated to integers) to get the number of cells per axis
        size: edge length of a single cell

    Returns:
        Tensor with one X0 value per cell of the layer
    '''

    n_cells = (lw/size).long()
    rad_length = torch.ones(list(n_cells))*X0['beryllium']
    # NOTE(review): exact float comparison - a z such as 0.5000000000000001 (e.g. produced by np.arange) falls outside this window; confirm that is intended
    if 0.4 <= z <= 0.5: rad_length[5:,5:] = X0['lead']
    return rad_length

In [4]:
def eff_cost(x:Tensor) -> Tensor:
    r'''
    Cost associated with a detector efficiency: exp(3*relu(x)) - 1, computed elementwise.

    Arguments:
        x: tensor of efficiencies

    Returns:
        Tensor of costs with the same shape as `x`; zero wherever `x` is not positive
    '''

    clipped = F.relu(x)
    exponent = 3*clipped
    return torch.expm1(exponent)

In [5]:
def res_cost(x:Tensor) -> Tensor:
    r'''
    Cost associated with a detector resolution: relu(x/100) squared, computed elementwise.

    Arguments:
        x: tensor of resolutions

    Returns:
        Tensor of costs with the same shape as `x`; zero wherever `x` is not positive
    '''

    scaled = x/100
    positive_part = F.relu(scaled)
    return positive_part**2

In [6]:
def get_layers():
    r'''
    Assemble the layer stack used to build the volume in this notebook.

    Ten layers are created, starting at z=1 and stepping down by `size`: two detector layers,
    six passive layers, then two more detector layers. Detector layers created before the first
    passive layer get pos='above'; those created afterwards get pos='below'.

    Returns:
        nn.ModuleList holding the layers in the order they were created
    '''

    lwh = Tensor([1,1,1])
    size = 0.1
    init_eff, init_res = 0.5, 1000
    is_detector = [1,1,0,0,0,0,0,0,1,1]  # 1 -> DetectorLayer, 0 -> PassiveLayer
    z_positions = np.arange(lwh[2],0,-size)

    layers = []
    pos = 'above'
    for z, det in zip(z_positions, is_detector):
        if not det:
            # Once a passive layer has been placed, any later detector sits below the passive volume
            pos = 'below'
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))
            continue
        layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=init_res,
                                    lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))

    return nn.ModuleList(layers)

In [7]:
# Build the volume from the layer stack returned by get_layers()
volume = Volume(get_layers())

# VolumeWrapper

In [41]:
class Callback():
    r'''
    Base class for callbacks used by VolumeWrapper.fit.

    Every hook is a no-op returning None, except on_train_begin, which checks that a wrapper has been set.
    Subclasses override the hooks they need and reach the training state through `self.wrapper`.
    '''

    # The wrapper running the fit; stays None until set_wrapper is called
    wrapper: Optional['VolumeWrapper'] = None

    def __init__(self):
        pass

    def set_wrapper(self, wrapper:'VolumeWrapper') -> None:
        r'''
        Store the wrapper on this callback instance.

        Arguments:
            wrapper: object to expose as `self.wrapper`
        '''

        self.wrapper = wrapper

    def set_plot_settings(self):
        r'''No-op placeholder.'''

    def on_train_begin(self) -> None:
        r'''
        Check that the callback has been attached to a wrapper.

        Raises:
            AttributeError: if no wrapper has been set
        '''

        if self.wrapper is not None:
            return
        raise AttributeError(f"The wrapper for {type(self).__name__} callback has not been set. Please call set_wrapper before on_train_begin.")

    def on_train_end(self) -> None:
        r'''No-op hook.'''

    # Epoch-level hooks
    def on_epoch_begin(self) -> None:
        r'''No-op hook.'''

    def on_epoch_end(self) -> None:
        r'''No-op hook.'''

    # Hooks around the scan of a single passive volume
    def on_volume_begin(self) -> None:
        r'''No-op hook.'''

    def on_volume_end(self) -> None:
        r'''No-op hook.'''

    # Hooks around a single muon batch
    def on_mu_batch_begin(self) -> None:
        r'''No-op hook.'''

    def on_mu_batch_end(self) -> None:
        r'''No-op hook.'''

    def on_scatter_end(self) -> None:
        r'''No-op hook.'''

    # Hooks around the backward pass
    def on_backwards_begin(self) -> None:
        r'''No-op hook.'''

    def on_backwards_end(self) -> None:
        r'''No-op hook.'''

    # Hooks around the X0 prediction for a volume
    def on_x0_pred_begin(self) -> None:
        r'''No-op hook.'''

    def on_x0_pred_end(self) -> None:
        r'''No-op hook.'''

    # Generic prediction hooks
    def on_pred_begin(self) -> None:
        r'''No-op hook.'''

    def on_pred_end(self) -> None:
        r'''No-op hook.'''

class CyclicCallback(Callback):
    r'''
    Marker subclass with no behaviour of its own.
    VolumeWrapper.fit collects callbacks of this type into `cyclic_cbs`: callbacks that might prevent
    a wrapper from stopping training due to a hyper-parameter cycle.
    '''

class MetricLogger(Callback):
    r'''
    Marker subclass with no behaviour of its own.
    VolumeWrapper.fit stores the callback of this type as `metric_log`: the callback that logs losses and evaluation metrics.
    '''

In [42]:
# lumin imports
import inspect

def is_partially(var:Any) -> bool:
    r'''
    Returns true if var is a functools.partial, a plain Python function, or a class, else false.

    Arguments:
        var: variable to inspect

    Return:
        true if var is a partial, a function defined with def/lambda, or a class; else false.
        Built-in functions (e.g. len) and ordinary instances give false.
    '''

    # Imported locally so the function does not rely on names that happen to be defined by other cells
    import types
    from functools import partial

    return isinstance(var, (partial,types.FunctionType)) or inspect.isclass(var)

class FitParams():
    r'''
    Plain attribute container for the state of a fit.

    Every keyword argument becomes an attribute of the same name. `epoch` is always set to 0
    afterwards, overriding any `epoch` keyword that was passed.
    '''

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.epoch = 0

In [43]:
class PassiveGenerator():
    r'''
    Placeholder for classes that produce passive-volume layouts on demand.
    PassiveYielder calls `generate` once per requested passive when given an instance of this class.
    '''

    def generate(self) -> Callable[...,Tensor]:
        r'''
        Stub with no implementation yet: returns None.
        '''

        return None

In [44]:
from random import shuffle

class PassiveYielder():
    r'''
    Iterable over passive-volume layouts, backed either by a fixed list or by a PassiveGenerator.

    Arguments:
        passives: either a list of layouts (callables), or a PassiveGenerator whose `generate` method is called for each layout
        n_passives: number of layouts to yield per iteration; required for a PassiveGenerator, ignored (replaced by the list length) for a list
        shuffle: if true, a list of layouts is yielded in a random order on each iteration; has no effect for a PassiveGenerator

    Raises:
        ValueError: if `passives` is a PassiveGenerator and `n_passives` is None
    '''

    def __init__(self, passives:Union[List[Callable[...,Tensor]],PassiveGenerator], n_passives:Optional[int]=None, shuffle:bool=True):
        self.passives,self.n_passives,self.shuffle = passives,n_passives,shuffle
        self.generator = isinstance(self.passives, PassiveGenerator)
        if self.generator:
            if self.n_passives is None:
                raise ValueError('If a PassiveGenerator class is used, n_passives must be specified')
        else:
            self.n_passives = len(self.passives)

    def __len__(self) -> int:
        r'''
        Number of layouts yielded by one full iteration.
        '''

        return self.n_passives

    def __iter__(self) -> Iterator[Callable[...,Tensor]]:
        r'''
        Yield the layouts one at a time.
        '''

        if self.generator:
            for _ in range(self.n_passives): yield self.passives.generate()
        else:
            # Shuffle a copy: random.shuffle works in place and would otherwise reorder the list owned by the caller
            order = list(self.passives)
            if self.shuffle: shuffle(order)
            yield from order

In [45]:
from __future__ import annotations
from fastcore.all import is_listy, Path
from fastprogress import progress_bar

class VolumeWrapper():
    r'''
    Couples a Volume with optimisers for its detector parameters and provides a callback-driven training loop.

    Arguments:
        volume: the Volume whose detector layers are to be optimised
        res_opt: callable which, given an iterator over the detectors' `resolution` parameters, returns an optimiser
        eff_opt: callable which, given an iterator over the detectors' `efficiency` parameters, returns an optimiser
        loss_func: called as `loss_func(pred_x0=..., pred_weight=..., volume=...)` once per scanned volume; if None, no loss is computed
        default_pred: forwarded to X0Inferer as the default prediction

    NOTE(review): the default for `loss_func` is the DetectorLoss class itself rather than an instance, so calling it
    with the keyword arguments above looks unintended - confirm and pass an instance explicitly (as done in this notebook).
    '''

    def __init__(self, volume:Volume, res_opt:Callable[[Iterator[nn.Parameter]],torch.optim.Optimizer], eff_opt:Callable[[Iterator[nn.Parameter]],torch.optim.Optimizer],
                 loss_func:Optional[nn.Module]=DetectorLoss, default_pred: Optional[float] = X0["beryllium"]):
        self.volume,self.loss_func,self.default_pred = volume,loss_func,default_pred
        self._build_opt(res_opt, eff_opt)
        self.parameters = self.volume.parameters

    def _build_opt(self, res_opt:Callable[[Iterator[nn.Parameter]],torch.optim.Optimizer], eff_opt:Callable[[Iterator[nn.Parameter]],torch.optim.Optimizer]) -> None:
        r'''
        Instantiate the resolution and efficiency optimisers for the detectors of the wrapped volume.

        Arguments:
            res_opt: callable returning an optimiser for the `resolution` parameters
            eff_opt: callable returning an optimiser for the `efficiency` parameters
        '''

        # Use the wrapped volume, not whatever global happens to be called `volume`
        self.res_opt = res_opt((l.resolution for l in self.volume.get_detectors()))
        self.eff_opt = eff_opt((l.efficiency for l in self.volume.get_detectors()))

    def get_detectors(self) -> List[DetectorLayer]:
        r'''
        Returns:
            The detector layers of the wrapped volume
        '''

        return self.volume.get_detectors()

    def save(self, name:str) -> None:
        r'''
        Save the state dicts of the volume and both optimisers to a single file.

        Arguments:
            name: path of the file to write
        '''

        torch.save({'volume':self.volume.state_dict(), 'res_opt':self.res_opt.state_dict(), 'eff_opt':self.eff_opt.state_dict()}, str(name))

    def load(self, name:str) -> None:
        r'''
        Load the state dicts of the volume and both optimisers from a file written by `save`.

        Arguments:
            name: path of the file to read
        '''

        # NOTE: torch.load unpickles the file; only load files from trusted sources
        state = torch.load(name, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        self.volume.load_state_dict(state['volume'])
        self.res_opt.load_state_dict(state['res_opt'])
        self.eff_opt.load_state_dict(state['eff_opt'])

    @classmethod
    def from_save(cls, name:str, volume:Volume, res_opt:Callable[[Iterator[nn.Parameter]],torch.optim.Optimizer], eff_opt:Callable[[Iterator[nn.Parameter]],torch.optim.Optimizer],
                  loss_func:Optional[DetectorLoss], default_pred: Optional[float] = X0["beryllium"]) -> VolumeWrapper:
        r'''
        Build a wrapper around `volume` and then load the saved state from `name` into it.

        Arguments:
            name: path of a file written by `save`
            volume, res_opt, eff_opt, loss_func, default_pred: as for the constructor

        Returns:
            The wrapper with the saved volume and optimiser states loaded
        '''

        vw = cls(volume=volume, res_opt=res_opt, eff_opt=eff_opt, loss_func=loss_func, default_pred=default_pred)
        vw.load(name)
        return vw

    def get_param_count(self, trainable:bool=True) -> int:
        r'''
        Return number of parameters in detector.

        Arguments:
            trainable: if true (default) only count trainable parameters

        Returns:
            Number of (trainable) parameters in detector
        '''

        return sum(p.numel() for p in self.parameters() if p.requires_grad or not trainable)

    def _scan_volume(self) -> None:
        r'''
        Scan the currently loaded passive volume with muon batches, combine the per-batch X0 predictions into a
        weighted average, and (unless the state is 'test' or there is no loss function) add the loss to `fit_params.loss_val`.

        `n_mu_per_batch//mu_bs` batches of `mu_bs` muons are used, so any remainder of `n_mu_per_batch` is dropped.
        '''

        # Scan volume with muon batches
        self.fit_params.wpreds, self.fit_params.weights = [], []
        for _ in range(self.fit_params.n_mu_per_batch//self.fit_params.mu_bs):
            self.fit_params.mu = MuonBatch(self.fit_params.mu_generator(self.fit_params.mu_bs), init_z=self.volume.h)
            for c in self.fit_params.cbs: c.on_mu_batch_begin()
            self.volume(self.fit_params.mu)
            self.fit_params.sb = ScatterBatch(self.fit_params.mu, self.volume)
            for c in self.fit_params.cbs: c.on_scatter_end()
            inferer = X0Inferer(self.fit_params.sb, self.fit_params.default_pred)
            pred, wgt = inferer.pred_x0(inc_default=False)
            pred = torch.nan_to_num(pred)
            self.fit_params.wpreds.append(pred*wgt)
            self.fit_params.weights.append(wgt)
            for c in self.fit_params.cbs: c.on_mu_batch_end()

        # Predict volume based on all muon batches
        for c in self.fit_params.cbs: c.on_x0_pred_begin()
        wgt = torch.stack(self.fit_params.weights, dim=0).sum(0)
        pred = torch.stack(self.fit_params.wpreds, dim=0).sum(0)/wgt
        # The inferer of the last batch is reused here only to apply the default prediction
        pred, wgt = inferer.add_default_pred(pred, wgt)
        self.fit_params.weight = wgt
        self.fit_params.pred = pred

        for c in self.fit_params.cbs: c.on_x0_pred_end()

        # Compute loss for volume
        if self.fit_params.state != 'test' and self.fit_params.loss_func is not None:
            loss = self.fit_params.loss_func(pred_x0=self.fit_params.pred, pred_weight=self.fit_params.weight, volume=self.volume)
            if self.fit_params.loss_val is None:
                self.fit_params.loss_val = loss
            else:
                self.fit_params.loss_val = self.fit_params.loss_val+loss

    def _scan_volumes(self, passives:PassiveYielder) -> None:
        r'''
        Scan every passive layout in turn, accumulating the summed loss in `fit_params.loss_val`
        and storing the per-volume average in `fit_params.mean_loss`.

        Arguments:
            passives: yielder of the layouts to load into the volume
        '''

        self.fit_params.loss_val = None
        for i, passive in enumerate(passives):
            self.fit_params.volume_id = i
            self.volume.load_rad_length(passive)
            for c in self.fit_params.cbs: c.on_volume_begin()
            self._scan_volume()
            for c in self.fit_params.cbs: c.on_volume_end()
        self.fit_params.mean_loss = self.fit_params.loss_val/len(passives)

    def fit(self, n_epochs:int, n_mu_per_batch:int, passive_bs:int, mu_bs:int, trn_passives:PassiveYielder, val_passives:Optional[PassiveYielder], mu_generator:Callable[[int],Tensor]=generate_batch,
            cbs:Optional[Union[Callback,List[Callback]]]=None, cb_savepath:Path=Path('train_weights'),
            visible_bar:bool=True) -> List[Callback]:
        r'''
        Optimise the detector parameters of the wrapped volume.

        Each epoch scans all training layouts, back-propagates the summed loss, and steps both optimisers once.
        If validation layouts are given they are then scanned with the volume in eval mode, which overwrites
        `fit_params.loss_val` and `fit_params.mean_loss` with the validation values.

        Arguments:
            n_epochs: number of epochs to run
            n_mu_per_batch: number of muons used per volume; must be at least `mu_bs`
            passive_bs: stored in `fit_params`; not otherwise used by this class
            mu_bs: number of muons generated and propagated at a time
            trn_passives: yielder of training layouts
            val_passives: optional yielder of validation layouts
            mu_generator: called with `mu_bs` to generate the muons for one batch
            cbs: optional callback or list of callbacks
            cb_savepath: directory made available to callbacks; created if missing
            visible_bar: whether to display the epoch progress bar

        Returns:
            The list of callbacks used during the fit

        Raises:
            ValueError: if the muon settings would result in no muon batches per volume
        '''

        # Without at least one muon batch there is nothing to stack in _scan_volume, so fail early with a clear message
        if mu_bs < 1 or n_mu_per_batch < mu_bs:
            raise ValueError(f"n_mu_per_batch ({n_mu_per_batch}) must be at least mu_bs ({mu_bs}), and mu_bs must be positive")

        if cbs is None: cbs = []
        elif not is_listy(cbs): cbs = [cbs]
        cyclic_cbs,loss_cbs,metric_log = [],[],None
        for c in cbs:
            if isinstance(c, CyclicCallback): cyclic_cbs.append(c)  # CBs that might prevent a wrapper from stopping training due to a hyper-param cycle
            if hasattr(c, "get_loss"): loss_cbs.append(c)  # CBs that produce alternative losses that should be considered
            if isinstance(c, MetricLogger): metric_log = c  # CB that logs losses and eval_metrics

        self.fit_params = FitParams(cbs=cbs, cyclic_cbs=cyclic_cbs, loss_cbs=loss_cbs, metric_log=metric_log, stop=False, n_epochs=n_epochs,
                                    passive_bs=passive_bs, mu_bs=mu_bs, n_mu_per_batch=n_mu_per_batch, cb_savepath=Path(cb_savepath), trn_passives=trn_passives, val_passives=val_passives, mu_generator=mu_generator,
                                    loss_func=self.loss_func, res_opt=self.res_opt, eff_opt=self.eff_opt, default_pred=self.default_pred)
        self.fit_params.cb_savepath.mkdir(parents=True, exist_ok=True)

        def fit_epoch() -> None:
            self.fit_params.epoch += 1

            # Training
            self.volume.train()
            self.fit_params.state = 'train'
            for c in self.fit_params.cbs: c.on_epoch_begin()
            self._scan_volumes(self.fit_params.trn_passives)  # Gain losses for all volumes
            # Compute update step
            self.fit_params.res_opt.zero_grad()
            self.fit_params.eff_opt.zero_grad()
            for c in self.fit_params.cbs: c.on_backwards_begin()
            self.fit_params.loss_val.backward()
            for c in self.fit_params.cbs: c.on_backwards_end()
            self.fit_params.res_opt.step()
            self.fit_params.eff_opt.step()
            for c in self.fit_params.cbs: c.on_epoch_end()

            # Validation
            if self.fit_params.val_passives is not None:
                self.volume.eval()
                self.fit_params.state = 'valid'  # lets callbacks tell the validation pass apart from training
                for c in self.fit_params.cbs: c.on_epoch_begin()
                self._scan_volumes(self.fit_params.val_passives)
                for c in self.fit_params.cbs: c.on_epoch_end()

        try:
            for c in self.fit_params.cbs: c.set_wrapper(self)
            for c in self.fit_params.cbs: c.on_train_begin()
            for _ in progress_bar(range(self.fit_params.n_epochs), display=visible_bar):
                fit_epoch()
                if self.fit_params.stop: break
            for c in self.fit_params.cbs: c.on_train_end()
        finally:
            # Always drop the per-fit state, even if training raised
            self.fit_params = None
            torch.cuda.empty_cache()
        return cbs

In [46]:
class NoMoreNaN(Callback):
    r'''
    Callback that sanitises the gradients of the detector parameters after the backward pass,
    so that NaN gradients do not reach the optimiser step.
    '''

    def on_backwards_end(self) -> None:
        r'''
        Replace NaNs in the `resolution` and `efficiency` gradients of every detector with zero, in place.
        As torch.nan_to_num_ is called with its default posinf/neginf settings, infinite values are also replaced by finite ones.
        '''

        for det in self.wrapper.volume.get_detectors():
            for param in (det.resolution, det.efficiency):
                # A parameter that took no part in the loss has grad None, which nan_to_num_ cannot handle
                if param.grad is not None:
                    torch.nan_to_num_(param.grad, 0)

In [47]:
from functools import partial

In [48]:
# Re-create the volume from a fresh get_layers() call before wrapping it
volume = Volume(get_layers())

In [49]:
# Wrap the volume: SGD with lr=2e1 for the resolutions, SGD with lr=2e-5 for the efficiencies, and a DetectorLoss instance as loss
wrapper = VolumeWrapper(volume=volume, res_opt=partial(torch.optim.SGD, lr=2e1), eff_opt=partial(torch.optim.SGD, lr=2e-5), loss_func=DetectorLoss(0.15))

In [50]:
# Training layouts: a single fixed layout defined by arb_rad_length
trn_passives = PassiveYielder([arb_rad_length])

In [51]:
# Sanity check: iterate the yielder once and print what it yields
for p in trn_passives: print(p)

<function arb_rad_length at 0x7fccb63edca0>


In [52]:
# Train for 10 epochs; n_mu_per_batch == mu_bs gives one batch of 1000 muons per volume. NoMoreNaN cleans NaN gradients after each backward pass
wrapper.fit(10, n_mu_per_batch=1000, passive_bs=1, mu_bs=1000, trn_passives=trn_passives, val_passives=None, cbs=[NoMoreNaN()])

X0 comp: 8563.88, cost comp: 6209
X0 comp: 6342.17, cost comp: 6233
X0 comp: 5576.31, cost comp: 6251
X0 comp: 5460.25, cost comp: 6265
X0 comp: 5391.18, cost comp: 6280
X0 comp: 6661.94, cost comp: 6294
X0 comp: 5537.45, cost comp: 6314
X0 comp: 6883.91, cost comp: 6328
X0 comp: 4103.39, cost comp: 6347
X0 comp: 5390.42, cost comp: 6357


[<__main__.NoMoreNaN at 0x7fcc8cb9a250>]

In [None]:
# Display the wrapped volume
wrapper.volume