In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from tomopt.muon import *
from tomopt.inference import *
from tomopt.loss import *
from tomopt.volume import *
from tomopt.core import *
from tomopt.optimisation import *

import matplotlib.pyplot as plt
import seaborn as sns
from typing import *
import numpy as np

import torch
from torch import Tensor, nn
import torch.nn.functional as F

# Basics

In [3]:
def arb_rad_length(*, z:float, lw:Tensor, size:float, z_tol:float=1e-6) -> Tensor:
    r'''
    Toy radiation-length map for a passive layer: every voxel is beryllium, except that for layers whose z lies in [0.4, 0.5]
    the voxels with both grid indices >= 5 are lead.

    Arguments:
        z: z position passed in for the layer
        lw: 2-element tensor of the layer's length and width
        size: voxel size; the grid has round(lw/size) voxels along each of the two dimensions
        z_tol: tolerance applied to both ends of the z range, so that z values carrying floating-point error
            (e.g. 0.3999999999999999 instead of 0.4) are still treated as inside the range

    Returns:
        Tensor of radiation lengths (values from `X0`), one per voxel
    '''

    # Round rather than truncate: e.g. 0.7/0.1 == 6.999999999999999 in float64 would otherwise give 6 voxels instead of 7
    n_voxels = (lw/size).round().long().tolist()
    rad_length = torch.ones(n_voxels)*X0['beryllium']
    if 0.4 - z_tol <= z <= 0.5 + z_tol: rad_length[5:,5:] = X0['lead']
    return rad_length

In [4]:
def eff_cost(x:Tensor) -> Tensor:
    r'''
    Cost associated with a detector efficiency value: exp(3*max(x, 0)) - 1, i.e. zero for non-positive inputs.
    '''

    clipped = F.relu(x)
    return (clipped*3).expm1()

In [5]:
def res_cost(x:Tensor) -> Tensor:
    r'''
    Cost associated with a detector resolution value: max(x/100, 0) squared.
    '''

    scaled = F.relu(x/100)
    return scaled.pow(2)

In [6]:
def get_layers(lwh:Sequence[float]=(1,1,1), size:float=0.1, init_eff:float=0.5, init_res:float=1000,
               det_layout:Sequence[int]=(1,1,0,0,0,0,0,0,1,1)) -> nn.ModuleList:
    r'''
    Builds the test layer stack: detector layers surrounding a block of passive layers.

    Layer positions run top-down from z = lwh[2] towards 0 in steps of `size`. Each entry of `det_layout` picks a `DetectorLayer`
    (truthy) or a `PassiveLayer` (falsy) for the corresponding position. Detector layers are labelled 'above' until the first
    passive layer is met, and 'below' afterwards.

    Arguments:
        lwh: length, width, and height of the volume; length and width are passed to every layer as `lw`
        size: spacing between consecutive layers in z, also passed to every layer as `size`
        init_eff: initial efficiency passed to each detector layer
        init_res: initial resolution passed to each detector layer
        det_layout: per-position flags, top to bottom (1 = detector layer, 0 = passive layer). Positions and flags are paired
            with `zip`, so any surplus on either side is ignored.

    Returns:
        `nn.ModuleList` of the layers, ordered top to bottom
    '''

    lwh = Tensor(list(lwh))
    # np.arange with a float step can return values that differ from the intended multiples of `size` in the last few bits;
    # rounding keeps z-range checks downstream (e.g. in arb_rad_length) from missing a layer by floating-point error
    zs = np.round(np.arange(float(lwh[2]), 0, -size), 10)
    layers = []
    pos = 'above'
    for z, is_det in zip(zs, det_layout):
        if is_det:
            layers.append(DetectorLayer(pos=pos, init_eff=init_eff, init_res=init_res,
                                        lw=lwh[:2], z=z, size=size, eff_cost_func=eff_cost, res_cost_func=res_cost))
        else:
            pos = 'below'
            layers.append(PassiveLayer(rad_length_func=arb_rad_length, lw=lwh[:2], z=z, size=size))

    return nn.ModuleList(layers)

In [7]:
# Check that the layer stack from get_layers builds into a Volume
volume = Volume(get_layers())

# VolumeWrapper

In [8]:
from functools import partial

In [9]:
# Build a fresh Volume for the wrapper (this rebinds `volume`, replacing the one built in the Basics section)
volume = Volume(get_layers())

In [10]:
# Wrap the volume with SGD optimiser factories (via functools.partial) for what are presumably the detector
# resolution (res_opt) and efficiency (eff_opt) parameters, plus the loss to optimise.
# NOTE(review): the meaning of DetectorLoss's 0.15 argument is defined by DetectorLoss; consider naming it.
wrapper = VolumeWrapper(
    volume=volume,
    res_opt=partial(torch.optim.SGD, lr=2e1),
    eff_opt=partial(torch.optim.SGD, lr=2e-5),
    loss_func=DetectorLoss(0.15),
)

In [11]:
# Training passives: a yielder over the single arb_rad_length layout (iterating it yields that function, as shown below)
trn_passives = PassiveYielder([arb_rad_length])

In [12]:
# Show what iterating the passive yielder produces
for p in trn_passives: print(p)

<function arb_rad_length at 0x7fe6af8210d0>


In [33]:
# NOTE(review): scratch cell unrelated to the optimisation (prints 0-4 twice); consider deleting it.
for i in range(10):
    print(i%5)

0
1
2
3
4
0
1
2
3
4


In [24]:
from tomopt.optimisation.callbacks.callback import Callback

r'''
This MetricLogger is a modified version of the MetricLogger in LUMIN (https://github.com/GilesStrong/lumin/blob/v0.7.2/lumin/nn/callbacks/monitors.py#L125), distributed under the following licence:
    Copyright 2018 onwards Giles Strong

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.

Usage is compatible with the AGPL licence under which TomOpt is distributed.
Stated changes: adaptation to work with the `VolumeWrapper` class, and replacement of the telemetry plots with task-specific information.
'''
                                                               
class MetricLogger(Callback):
    r'''
    Provides live feedback during optimisation, showing the training and validation losses and any metrics, to help highlight problems or
    test hyper-parameters without completing a full optimisation.
    If `show_plots` is false, will instead print the training and validation losses at the end of each validation epoch.
    The full history is available as a tuple of dictionaries by calling `MetricLogger.get_loss_history`.

    Arguments:
        show_plots: whether to draw live-updating plots (via IPython's `display`) rather than printing the losses
        loss_is_meaned: stored on the instance; not otherwise used by this class
        extra_detail: whether to additionally plot the epoch-to-epoch change in the validation losses and the validation/training loss ratio
        plot_settings: object exposing the plotting attributes read by this class (`style`, `cat_palette`, `tk_sz`, `tk_col`, `leg_sz`,
            `lbl_sz`, `lbl_col`, `w_mid`, `h_mid`); if None, a default set is used
        log_y: whether to draw the loss axis on a log scale
    '''

    def __init__(self, show_plots:bool=IN_NOTEBOOK, loss_is_meaned:bool=True, extra_detail:bool=True, plot_settings:Optional[Any]=None,
                 log_y:bool=False):
        from types import SimpleNamespace

        self.show_plots,self.loss_is_meaned = show_plots,loss_is_meaned
        # `extra_detail`, `log`, and `plot_settings` are read by `_reset`, `on_epoch_end`, and `update_plot`, so they must exist before training
        self.extra_detail,self.log = extra_detail,log_y
        if plot_settings is None:
            # Minimal stand-in for LUMIN's `PlotSettings`, providing only the attributes this class reads
            plot_settings = SimpleNamespace(style={'style':'whitegrid', 'rc':{'patch.edgecolor':'none'}}, cat_palette='tab10',
                                            tk_sz=16, tk_col='black', leg_sz=16, lbl_sz=24, lbl_col='black', w_mid=16, h_mid=9)
        self.plot_settings = plot_settings

    def on_train_begin(self) -> None:
        r'''
        Prepare for new training: reset all histories and add a loss entry for each of the wrapper's loss callbacks
        '''

        super().on_train_begin()
        self._reset()
        for c in self.wrapper.fit_params.loss_cbs: self._add_loss_name(type(c).__name__)

    def on_epoch_begin(self) -> None:
        r'''
        Prepare to track new loss
        '''

        self.loss,self.cnt = 0,0

    def on_epoch_end(self) -> None:
        r'''
        After a training epoch, record the training loss.
        After a validation epoch, record the validation loss, the loss-callback losses, and the metrics; then update the plots (or print the
        losses) and store the latest losses and main metric in `val_epoch_results`.
        '''

        if self.wrapper.fit_params.state == 'train':
            self.loss_vals[0].append(self.wrapper.loss_val)
        elif self.wrapper.fit_params.state == 'valid':
            self.epochs.append(self.epochs[-1]+1)
            self.loss_vals[1].append(self.wrapper.loss_val)
            for i,c in enumerate(self.wrapper.fit_params.loss_cbs): self.loss_vals[i+2].append(c.get_loss())
            for i,c in enumerate(self.metric_cbs): self.metric_vals[i].append(c.get_metric())
            if self.show_plots:
                for i, v in enumerate(self.loss_vals[1:]):
                    if len(self.loss_vals[1]) > 1 and self.extra_detail:
                        self.vel_vals[i].append(v[-1]-v[-2])  # change since the previous validation epoch
                        self.gen_vals[i].append(v[-1]/self.loss_vals[0][-1])  # ratio to the latest training loss
                    if self.loss_vals[i+1][-1] <= self.best_loss: self.best_loss = self.loss_vals[i+1][-1]
                self.update_plot()
            else:
                self.print_losses()

            ls = np.array(self.loss_vals[1:])[:,-1]
            m = None
            if self.lock_to_metric:
                m = self.metric_vals[self.main_metric_idx][-1]
                # Negate metrics where higher is better, so that lower is always better
                if not self.metric_cbs[self.main_metric_idx].lower_metric_better: m *= -1
            self.val_epoch_results = ls,m

    def _add_loss_name(self, name:str) -> None:
        r'''
        Add a new tracked loss, back-filling its history with zeros for the epochs already recorded
        '''

        self.loss_names.append(name)
        self.loss_vals.append([0 for _ in self.loss_vals[1]])
        self.vel_vals.append([0 for _ in self.vel_vals[0]])
        self.gen_vals.append([0 for _ in self.gen_vals[0]])

    def print_losses(self) -> None:
        r'''
        Print training and validation losses for the last epoch
        '''

        p = f'Epoch {len(self.loss_vals[1])}: Training = {np.mean(self.loss_vals[0][-self.n_trn_flds:]):.2E}'
        for v,m in zip(self.loss_vals[1:],self.loss_names[1:]): p += f' {m} = {v[-1]:.2E}'
        for m,v in zip(self.metric_cbs, self.metric_vals): p += f' {m.name} = {v[-1]:.2E}'
        print(p)

    def update_plot(self) -> None:
        r'''
        Updates the plot(s).

        # TODO: make this faster
        '''

        # Loss
        self.loss_ax.clear()
        with sns.axes_style(**self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette):
            self.loss_ax.plot(range(1,len(self.loss_vals[0])+1), self.loss_vals[0], label=self.loss_names[0])
            x = range(self.n_trn_flds, self.n_trn_flds*len(self.loss_vals[1])+1, self.n_trn_flds)
            for v,m in zip(self.loss_vals[1:],self.loss_names[1:]):
                self.loss_ax.plot(x, v, label=m)
            self.loss_ax.plot([1,x[-1]], [self.best_loss,self.best_loss], label=f'Best = {self.best_loss:.3E}', linestyle='--')
            if self.log:
                # Non-positive values are clipped by default; the `nonposy` keyword was removed in matplotlib 3.5
                self.loss_ax.set_yscale('log')
                self.loss_ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col, which='both')
            self.loss_ax.grid(True, which="both")
            self.loss_ax.legend(loc='upper right', fontsize=0.8*self.plot_settings.leg_sz)
            self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)

        if self.extra_detail and len(self.loss_vals[1]) > 1:
            # Velocity
            self.vel_ax.clear()
            self.vel_ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col, which='both')
            self.vel_ax.grid(True, which="both")
            with sns.axes_style(**self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette) as palette:
                for i, (v,m) in enumerate(zip(self.vel_vals,self.loss_names[1:])):
                    self.vel_ax.plot(self.epochs[2:], v, label=f'{m} {v[-1]:.2E}', color=palette[i+1])
                self.vel_ax.legend(loc='lower right', fontsize=0.8*self.plot_settings.leg_sz)
                self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)

            # Generalisation
            self.gen_ax.clear()
            self.gen_ax.grid(True, which="both")
            with sns.axes_style(**self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette) as palette:
                for i, (v,m) in enumerate(zip(self.gen_vals,self.loss_names[1:])):
                    self.gen_ax.plot(self.epochs[2:], v, label=f'{m} {v[-1]:.2f}', color=palette[i+1])
                self.gen_ax.legend(loc='upper left', fontsize=0.8*self.plot_settings.leg_sz)
                self.gen_ax.set_xlabel('Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                self.gen_ax.set_ylabel('Validation / Train', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                if len(self.epochs) > 8:  # For some reason this needs to be 2+number epochs to display...
                    # Keep only a sliding window of the most recent epochs in the velocity/generalisation plots
                    self.epochs = self.epochs[1:]
                    for i in range(len(self.vel_vals)): self.vel_vals[i],self.gen_vals[i] = self.vel_vals[i][1:],self.gen_vals[i][1:]

            # Metrics
            if self.main_metric_idx is not None:
                self.metric_ax.clear()
                with sns.axes_style(**self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette) as palette:
                    x = range(self.n_trn_flds, self.n_trn_flds*len(self.loss_vals[1])+1, self.n_trn_flds)
                    y = self.metric_vals[self.main_metric_idx]
                    self.metric_ax.plot(x, y, color=palette[1])
                    best = np.nanmin(y) if self.metric_cbs[self.main_metric_idx].lower_metric_better else np.nanmax(y)
                    self.metric_ax.plot([1,x[-1]], [best,best], label=f'Best = {best:.3E}', linestyle='--', color=palette[2])
                    self.metric_ax.legend(loc='upper left', fontsize=0.8*self.plot_settings.leg_sz)
                    self.metric_ax.grid(True, which="both")
                    self.metric_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.metric_ax.set_ylabel(self.metric_cbs[self.main_metric_idx].name, fontsize=0.8*self.plot_settings.lbl_sz,
                                              color=self.plot_settings.lbl_col)

            self.display.update(self.fig)
        else:
            self.display.update(self.loss_ax.figure)

    def _reset(self) -> None:
        r'''
        Clear all recorded histories, collect the metric callbacks from the wrapper's callbacks, and (if `show_plots`) create the figure and axes
        '''

        self.loss_names = ['Training', 'Validation']
        self.loss_vals = [[] for _ in self.loss_names]
        self.vel_vals, self.gen_vals = [[] for _ in range(len(self.loss_names)-1)], [[] for _ in range(len(self.loss_names)-1)]
        # `on_epoch_end` appends one training loss per call in the 'train' state; assuming one such call per training epoch,
        # each validation point lines up with a single training point
        self.n_trn_flds = 1
        self.best_loss,self.epochs = np.inf,[0]

        self.metric_cbs = []
        for c in self.wrapper.fit_params.cbs:
            if hasattr(c, 'get_metric'):
                self.metric_cbs.append(c)
        self.metric_vals = [[] for _ in self.metric_cbs]
        self.main_metric_idx = None
        self.lock_to_metric = False
        if len(self.metric_cbs) > 0:
            self.main_metric_idx = 0
            for i,c in enumerate(self.metric_cbs):
                if c.main_metric:
                    self.main_metric_idx = i
                    self.lock_to_metric = True
                    break

        if self.show_plots:
            with sns.axes_style(**self.plot_settings.style):
                if self.extra_detail:
                    self.fig = plt.figure(figsize=(self.plot_settings.w_mid, self.plot_settings.h_mid), constrained_layout=True)
                    gs = self.fig.add_gridspec(2, 3)
                    if self.main_metric_idx is None:
                        self.loss_ax = self.fig.add_subplot(gs[:,:-1])
                    else:
                        self.loss_ax = self.fig.add_subplot(gs[:1,:-1])
                        self.metric_ax = self.fig.add_subplot(gs[1:2,:-1])
                    self.vel_ax  = self.fig.add_subplot(gs[:1,2:])
                    self.gen_ax  = self.fig.add_subplot(gs[1:2,2:])
                    for ax in [self.loss_ax, self.vel_ax, self.gen_ax]:
                        ax.tick_params(axis='x', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                        ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                    self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    if self.main_metric_idx is not None:
                        self.metric_ax.tick_params(axis='x', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                        self.metric_ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                        self.metric_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                        self.metric_ax.set_ylabel(self.metric_cbs[self.main_metric_idx].name, fontsize=0.8*self.plot_settings.lbl_sz,
                                                  color=self.plot_settings.lbl_col)
                    self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.gen_ax.set_xlabel('Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.gen_ax.set_ylabel('Validation / Train', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.display = display(self.fig, display_id=True)
                else:
                    self.fig, self.loss_ax = plt.subplots(1, figsize=(self.plot_settings.w_mid, self.plot_settings.h_mid))
                    self.loss_ax.tick_params(axis='x', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                    self.loss_ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                    self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.display = display(self.loss_ax.figure, display_id=True)

    def get_loss_history(self) -> Tuple[OrderedDict,OrderedDict]:
        r'''
        Get the current history of losses and metrics

        Returns:
            history: tuple of ordered dictionaries: first with losses, second with validation metrics
        '''

        from collections import OrderedDict

        history = (OrderedDict(),OrderedDict())
        for v,m in zip(self.loss_vals,self.loss_names): history[0][m] = v
        for v,c in zip(self.metric_vals,self.metric_cbs): history[1][c.name] = v
        return history

    def get_results(self, save_best:bool) -> Dict[str,float]:
        r'''
        Returns the validation loss and metrics, either at the best recorded epoch or at the latest epoch

        #TODO: extend this to load at specified index

        Arguments:
            save_best: if true, return results at the epoch with the lowest validation loss (or best main metric, when one is locked);
                otherwise return the latest values

        Returns:
            dictionary of validation loss and metrics
        '''

        losses = np.array(self.loss_vals[1:])
        metrics = np.array(self.metric_vals)
        results = {}

        if save_best:
            if self.main_metric_idx is None or not self.lock_to_metric or len(losses) > 1:  # Tracking SWA only supported for loss
                idx = np.unravel_index(np.nanargmin(losses), losses.shape)[-1]
                results['loss'] = np.nanmin(losses)
            else:
                idx = np.nanargmin(self.metric_vals[self.main_metric_idx]) if self.metric_cbs[self.main_metric_idx].lower_metric_better else \
                    np.nanargmax(self.metric_vals[self.main_metric_idx])
                results['loss'] = losses[0][idx]
        else:
            results['loss'] = np.nanmin(losses[:,-1:])
            idx = -1
        if len(self.metric_cbs) > 0:
            for c,v in zip(self.metric_cbs,metrics[:,idx]): results[c.name] = v
        return results


In [None]:
# Live-feedback callback; by default it plots only when running in a notebook (show_plots=IN_NOTEBOOK)
ml = MetricLogger()

In [None]:
# Optimise for 10 epochs, with NaN protection and live logging.
# NOTE(review): with val_passives=None there are presumably no validation epochs. MetricLogger only plots or prints after a
# validation epoch, so `ml` would record training losses but display nothing. Confirm, or pass validation passives.
wrapper.fit(
    10,
    n_mu_per_volume=1000,
    mu_bs=100,
    trn_passives=trn_passives,
    val_passives=None,
    cbs=[NoMoreNaNs(), ml],
)