In [None]:
%load_ext autotime
import shutil, logging, warnings, pathlib, pickle, tempfile, dataclasses, typing, contextlib, IPython
import numpy as np, scipy as sp, pandas as pd, sympy as sy, tensorflow as tf, bayesflow as bf
import matplotlib.pyplot as plt, seaborn as sns
from matplotlib.mathtext import math_to_image
from dataclasses import dataclass, field
from wolframclient.evaluation import WolframLanguageSession
from wolframclient.language import wl, wlexpr
# display: never truncate the columns of wide frames
pd.set_option('display.max_columns', None)
# silence every log record at WARNING severity and below (logging.WARNING == logging.WARN)
logging.disable(logging.WARNING)
warnings.simplefilter('ignore', category=FutureWarning)
# UserWarnings (e.g. "unknown ... scaler") are shown only on their first occurrence
warnings.simplefilter('once', category=UserWarning)
# numerical tolerance shared by the data validation and the equilibrium classification
EPS = 1e-8
# let TensorFlow grow GPU memory on demand instead of reserving all of it up front
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
IPython.display.clear_output()

def safe_delete(target):
    """Remove ``target`` (a file or a whole directory tree) after the user presses enter.

    Guards against accidental deletion of cached artifacts. A missing ``target`` is a no-op:
    no prompt is shown and nothing is removed.
    """
    remover = None
    if target.is_file():
        remover = target.unlink
    elif target.is_dir():
        remover = lambda: shutil.rmtree(target)
    if remover is None:
        return
    input(f'press enter to delete and recreate {target}')
    remover()

def read_pickle(file, overwrite=False):
    """Load the object pickled in ``file``; return ``None`` if it cannot be loaded.

    Parameters
    ----------
    file : pathlib.Path
        Cache file to read.
    overwrite : bool
        If True, first delete ``file`` via ``safe_delete`` (asks for confirmation), which
        makes the read below miss and return ``None`` so the caller regenerates the data.

    Returns
    -------
    The unpickled object, or ``None`` when the file is missing, unreadable or corrupt.
    Callers treat ``None`` as a cache miss.

    Security: ``pickle.load`` can execute arbitrary code. Only read files this notebook wrote.
    """
    if overwrite:
        safe_delete(file)
    try:
        with open(file, 'rb') as f:
            return pickle.load(f)
    except Exception:
        # best-effort cache read: any ordinary failure means "regenerate";
        # KeyboardInterrupt / SystemExit are deliberately not swallowed
        return None

def write_pickle(dct, file):
    """Pickle ``dct`` into ``file`` (highest protocol), creating missing parent directories."""
    target_dir = file.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with open(file, mode='wb') as handle:
        pickle.dump(dct, handle, protocol=pickle.HIGHEST_PROTOCOL)

def latex(x):
    """Convert a symbol name into a mathtext label for plots.

    First tries ``$\\<x>$`` so names like ``alpha`` or ``kappa`` render as Greek letters.
    If matplotlib's mathtext cannot render that string, it falls back to the plain
    ``$<x>$`` (e.g. ``r_1`` -> ``$r_1$``).
    """
    candidate = f'$\\{x}$'
    try:
        # render into a throwaway file purely to check whether mathtext accepts the string
        with tempfile.TemporaryFile() as fp:
            math_to_image(candidate, fp)
    except Exception:
        # rendering failed (e.g. unknown command) -> plain label; Ctrl-C still propagates
        return f'${x}$'
    return candidate

class BaseClass():
    """Mixin that exposes attributes through mapping syntax.

    ``obj[key]`` reads, ``obj[key] = v`` writes and ``del obj[key]`` deletes the attribute
    named ``key``. ``key in obj`` tells whether that attribute can be read.
    """

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, val):
        setattr(self, key, val)

    def __delitem__(self, key):
        delattr(self, key)

    def __contains__(self, key):
        # same semantics as hasattr(): only AttributeError means "absent"
        try:
            getattr(self, key)
        except AttributeError:
            return False
        return True

@dataclass
class ODE(BaseClass):
    """Amortized Bayesian parameter inference (BayesFlow) for a symbolic ODE model.

    Bundles:
    - a sympy right-hand side;
    - a ``scipy.integrate.solve_ivp`` simulator;
    - a prior cached with calibration statistics in ``<root_path>/<name>/prior.pkl``;
    - a BayesFlow amortizer and trainer;
    - diagnostic plots and a Wolfram-based equilibrium analysis.

    Checkpoints and results live under ``<root_path>/<name>/<param_scaler>_<data_scaler>/``.

    Fields
    ------
    name : sub-directory of ``root_path`` holding this model's artifacts.
    n_steps : default simulation horizon; states are recorded at t = 1, ..., n_steps.
    sys : mapping state symbol -> sympy expression for its time derivative.
    x0 : mapping state symbol -> initial value; its key order fixes the state order.
    cls_ode : states used for the Jacobian / equilibrium analysis.
    cls_obs : states fed to the summary network (noised in ``solve`` when the prior has ``sigma``).
    cls_dif : states converted from cumulative values to per-step increments in ``solve``.
    prior : ``Prior``-like callable; replaced by the cached prior when prior.pkl loads.
    param_scaler : 'minmax', 'robust' or 'standard'/'z'.
    data_scaler : same options as ``param_scaler`` plus 'log'/'log1p'.
    n_calibrate : number of prior simulations used to compute the scaling statistics.
    num_coupling_layers, coupling_settings : passed to ``bf.networks.InvertibleNetwork``.
    solver_kwargs : extra keyword arguments for ``solve_ivp``.
    root_path : base directory for all artifacts.
    overwrite : if True, regenerate prior.pkl (after confirmation).
    """
    name: str
    n_steps: int
    sys: typing.Dict
    x0: typing.Dict
    cls_ode: typing.List
    cls_obs: typing.List
    cls_dif: typing.List = field(default_factory=list)
    prior: typing.Callable = None  # not necessary (will be overwritten) if prior.pkl reads successfully
    param_scaler: str = 'minmax'
    data_scaler: str = 'robust'
    n_calibrate: int = 2**20
    num_coupling_layers: int = 6
    coupling_settings: typing.Dict = field(default_factory=dict)
    solver_kwargs: typing.Dict = field(default_factory=dict)
    # NOTE(review): machine-specific absolute default; prefer passing root_path explicitly
    root_path: str = '/home/pythonserver/BayesResearch/Cook_new'
    overwrite: bool = False

    def __post_init__(self):
        """Normalize settings, build the symbolic system, load/generate the prior and build the model."""
        self.ses = WolframLanguageSession()
        self.param_scaler = f'{self.param_scaler}'.lower().strip()
        self.data_scaler = f'{self.data_scaler}'.lower().strip()
        self.root_path = pathlib.Path(self.root_path) / self.name
        # each scaler combination gets its own checkpoints/results; prior.pkl is shared
        scaler_dir = self.root_path / f'{self.param_scaler}_{self.data_scaler}'
        self.checkpoint_path = scaler_dir / 'checkpoints'
        self.results_path = scaler_dir / 'results'
        self.results_path.mkdir(exist_ok=True, parents=True)
        # hand solve_ivp an analytic Jacobian only for the methods that use one
        self.use_jac = self.solver_kwargs.get('method') in ['Radau', 'BDF', 'LSODA']

        self.cls_all = list(self.x0.keys())
        self.x0 = list(self.x0.values())
        self.idx_ode = [k for k, key in enumerate(self.cls_all) if key in self.cls_ode]
        self.idx_obs = [k for k, key in enumerate(self.cls_all) if key in self.cls_obs]
        self.idx_dif = [k for k, key in enumerate(self.cls_all) if key in self.cls_dif]
        self.sys_all = sy.Matrix([self.sys[key] for key in self.cls_all])
        self.sys_ode = sy.Matrix([self.sys[key] for key in self.cls_ode])
        self.jac_all = self.sys_all.jacobian(self.cls_all)
        self.jac_ode = self.sys_ode.jacobian(self.cls_ode)
        self.prior = self.init_prior()
        self.generative_model = bf.simulation.GenerativeModel(prior=self.prior, simulator=bf.simulation.Simulator(simulator_fun=self.solve))
        self.init_model()
        IPython.display.clear_output()

    def init_prior(self):
        """Return the prior with calibration statistics, loading prior.pkl or generating it.

        On a cache miss, ``n_calibrate`` prior simulations are run and stored on the prior
        (``prior_draws``, ``sim_data`` and their DataFrames). For params (``param_*``) and
        data (``data_*``) it then stores mean, std, quantiles q0..q4, range (q4-q0) and
        IQR (q3-q1). Parameter stats are per parameter (axis 0). Data stats are per state,
        over draws and steps (axes 0 and 2). Everything is written to prior.pkl.
        """
        file = self.root_path / 'prior.pkl'
        pr = read_pickle(file, self.overwrite)
        if pr is None:
            print('generating simulations')
            pr = self.prior
            self.generative_model = bf.simulation.GenerativeModel(prior=pr, simulator=bf.simulation.Simulator(simulator_fun=self.solve))
            pr.__dict__ |= self.get_dfs(self.generative_model(self.n_calibrate))
            print('simulations complete')
            for key, val in {'prior_draws':'param', 'sim_data':'data'}.items():
                arr = pr[key]
                if arr.ndim <= 2:
                    ax = 0
                else:
                    ax = (0,2)
                pr[f'{val}_mean'] = np.mean(arr, axis=ax, keepdims=True)
                pr[f'{val}_std'] = np.std(arr, axis=ax, keepdims=True)
                for r in range(5):
                    pr[f'{val}_q{r}'] = np.quantile(arr, q=r/4, axis=ax, keepdims=True)
                pr[f'{val}_ran'] = pr[f'{val}_q4'] - pr[f'{val}_q0']
                pr[f'{val}_iqr'] = pr[f'{val}_q3'] - pr[f'{val}_q1']
            write_pickle(pr, file)
        return pr

    def init_model(self, overwrite=False):
        """(Re)build the amortizer and trainer; ``overwrite`` first deletes existing checkpoints."""
        if overwrite:
            safe_delete(self.checkpoint_path)
        self.amortizer = bf.amortizers.AmortizedPosterior(
            inference_net = bf.networks.InvertibleNetwork(
                num_params=len(self.prior.names),
                num_coupling_layers=self.num_coupling_layers,
                coupling_settings=self.coupling_settings,
            ),
            summary_net = bf.networks.SequenceNetwork())
        self.trainer = bf.trainers.Trainer(
            generative_model = self.generative_model,
            configurator = self.configurator,
            amortizer = self.amortizer,
            checkpoint_path = self.checkpoint_path,
            # memory = True,
            max_to_keep = 10000,
        )

    def load_best_checkpoint(self):
        """Rebuild the model and restore the checkpoint with the lowest validation loss.

        Returns True on success, False when there is no usable history or checkpoint.
        """
        self.init_model()
        try:
            # np.argmin accepts both list and array histories.
            # Assumes ckpt-k holds epoch k, counting from 1 — TODO confirm against BayesFlow.
            best_idx = int(np.argmin(self.trainer.loss_history.total_val_loss)) + 1
            best_ckpt = self.checkpoint_path / f'ckpt-{best_idx}'
            self.trainer.checkpoint.restore(str(best_ckpt))
            print(f'loading {best_ckpt}')
            return True
        except Exception:
            return False

    def solve(self, params, n_steps=None):
        """Integrate the system for one parameter draw and return the state trajectories.

        params : dict symbol -> value, or a sequence ordered like ``self.prior.names``.
        n_steps : horizon override (defaults to ``self.n_steps``).

        Returns ``sol.y`` with one row per state in ``cls_all``, evaluated at t = 1..n_steps.
        Rows in ``cls_dif`` become per-step increments. Rows in ``cls_obs`` pass through
        ``self.prior.lognormal`` only when a parameter named ``sigma`` is present.
        """
        if not isinstance(params, dict):
            params = dict(zip(self.prior.names, params))
        n_steps = self.n_steps if n_steps is None else n_steps
        # dedicated time argument: the global symbol `t` can be a state (e.g. a treated
        # compartment), so it must not double as time; the system is only evaluated at states
        time_sym = sy.Dummy('time')
        fun = sy.lambdify([time_sym, self.cls_all], self.sys_all.subs(params).flat())
        solver_kwargs = dict(self.solver_kwargs)  # per-call copy: don't mutate the configuration
        if self.use_jac:
            solver_kwargs['jac'] = sy.lambdify([time_sym, self.cls_all], self.jac_all.subs(params))
        sol = sp.integrate.solve_ivp(fun=fun, y0=self.x0, t_span=[0,n_steps], t_eval=np.arange(n_steps)+1, **solver_kwargs)
        assert sol.status==0, f'{params}\n{sol}'
        # cumulative -> per-step increments (first step measured from 0)
        sol.y[self.idx_dif] = np.diff(sol.y[self.idx_dif], axis=1, prepend=0)
        # sol.y[self.idx_dif] = np.diff(sol.y[self.idx_dif], axis=1, prepend=self.x0[self.idx_dif])
        # observation noise only for priors that include a `sigma` parameter (keys may be symbols or str)
        sigma = next((v for k, v in params.items() if str(k) == 'sigma'), None)
        if sigma is not None:
            sol.y[self.idx_obs] = self.prior.lognormal(np.log(sol.y[self.idx_obs]), sigma)
        return sol.y

    def get_dfs(self, forward_dict):
        """Add long-format DataFrames for the arrays present in ``forward_dict``.

        ``prior_draws`` -> ``prior_df``, ``post_draws`` -> ``post_df`` and ``sim_data`` -> ``sim_df``.
        Rows are indexed by 'draw', plus 'sample'/'step' for 3-D arrays.
        Mutates and returns ``forward_dict``.
        """
        dct = {
            'prior_draws':{'roll':-1, 'idx':'sample', 'col':self.prior.names},
            'post_draws' :{'roll':-1, 'idx':'sample', 'col':self.prior.names},
            'sim_data'   :{'roll': 1, 'idx':'step'  , 'col':self.cls_all},
        }
        for key, val in dct.items():
            if key in forward_dict:
                # move the variable axis to the front, then flatten the rest into rows
                A = np.rollaxis(forward_dict[key], val['roll'])
                sh = A.shape
                B = pd.DataFrame(A.reshape(sh[0],-1).T, columns=val['col'][-sh[0]:]).rename_axis('draw')
                if len(sh) > 2:
                    # flattened row r = draw * sh[-1] + inner index
                    B[val['idx']] = B.index % sh[-1]
                    B.index //= sh[-1]
                    B.set_index(val['idx'], append=True, inplace=True)
                forward_dict[key.split('_')[0]+'_df'] = B
        return forward_dict

    def validate(self, forward_dict):
        """Drop draws whose parameters fall outside the calibration range [q0, q4] or whose data are negative.

        A draw is removed from every checked array if any check fails for it, and the
        rejection counts are printed. Mutates and returns ``forward_dict``.
        """
        contract = lambda x: np.all(x, axis=tuple(range(1,x.ndim)))
        param = lambda key: (self.prior.param_q0 <= forward_dict[key]) & (forward_dict[key] <= self.prior.param_q4)
        data  = lambda key: forward_dict[key] > -EPS
        dct = {'prior_draws':param, 'post_draws':param, 'sim_data':data}
        masks = {key: contract(fcn(key)) for key, fcn in dct.items() if key in forward_dict}
        if not masks:
            return forward_dict  # nothing to validate
        idx_keep = np.array(True)
        for mask in masks.values():
            idx_keep = idx_keep & mask
        for key in masks.keys():
            forward_dict[key] = forward_dict[key][idx_keep]
        masks['overall'] = idx_keep
        for key, mask in masks.items():
            n, d = np.sum(~mask), len(mask)
            if n > 0:
                print(f'rejecting {n} / {d} = {round(n/d*100)}% of {key}')
        return forward_dict

    def scale_param(self, x):
        """Scale parameters with ``param_scaler`` using the prior's calibration statistics."""
        if self.param_scaler in ['minmax']:
            x = (x - self.prior.param_q0) / self.prior.param_ran
        elif self.param_scaler in ['robust']:
            x = (x - self.prior.param_q2) / self.prior.param_iqr
        elif self.param_scaler in ['standard', 'z']:
            x = (x - self.prior.param_mean) / self.prior.param_std
        else:
            warnings.warn(f'unknown param_scaler "{self.param_scaler}"; no scaling attempted', UserWarning)
        return x

    def unscale_param(self, x):
        """Inverse of ``scale_param``."""
        if self.param_scaler in ['minmax']:
            x = x * self.prior.param_ran + self.prior.param_q0
        elif self.param_scaler in ['robust']:
            x = x * self.prior.param_iqr + self.prior.param_q2
        elif self.param_scaler in ['standard', 'z']:
            x = x * self.prior.param_std + self.prior.param_mean
        else:
            warnings.warn(f'unknown param_scaler "{self.param_scaler}"; no scaling attempted', UserWarning)
        return x

    def scale_data(self, x):
        """Scale simulated/observed data with ``data_scaler`` ('log'/'log1p' applies log1p)."""
        if self.data_scaler in ['minmax']:
            x = (x - self.prior.data_q0) / self.prior.data_ran
        elif self.data_scaler in ['robust']:
            x = (x - self.prior.data_q2) / self.prior.data_iqr
        elif self.data_scaler in ['standard', 'z']:
            x = (x - self.prior.data_mean) / self.prior.data_std
        elif self.data_scaler in ['log', 'log1p']:
            x = np.log1p(x)
        else:
            warnings.warn(f'unknown data_scaler "{self.data_scaler}"; no scaling attempted', UserWarning)
        return x

    def configurator(self, forward_dict, dfs=False):
        """prepare for input to amortizer: validate, scale, select observed states, cast to float32"""
        forward_dict = self.validate(forward_dict)
        forward_dict['summary_conditions'] = self.scale_data(forward_dict['sim_data'])[:,self.idx_obs].astype(np.float32)
        if 'prior_draws' in forward_dict:
            forward_dict['parameters'] = self.scale_param(forward_dict['prior_draws']).astype(np.float32)
        return self.get_dfs(forward_dict) if dfs else forward_dict

    def defigurator(self, forward_dict, dfs=True):
        """post-process output from amortizer: unscale posterior samples into ``post_draws``"""
        forward_dict['post_draws'] = self.unscale_param(forward_dict['post_samples'])
        return self.get_dfs(forward_dict) if dfs else forward_dict

    def sampler(self, forward_dict, n_samples, dfs=True):
        """Draw ``n_samples`` posterior samples per data set in ``forward_dict``."""
        forward_dict = self.configurator(forward_dict)
        forward_dict['post_samples'] = self.trainer.amortizer.sample(forward_dict, n_samples=n_samples)
        return self.defigurator(forward_dict, dfs)

    def train(self, save_period=np.inf, overwrite=False, **kwargs):
        """Train offline on the cached calibration simulations.

        Training runs in chunks of at most ``save_period`` epochs. After each chunk the
        missing last validation loss is patched into the saved histories.
        ``kwargs`` override the defaults passed to ``Trainer.train_offline``.

        Raises ValueError if ``save_period`` is not positive (that would loop forever).
        """
        if not save_period > 0:
            raise ValueError(f'save_period must be positive, got {save_period}')
        self.init_model(overwrite=overwrite)
        defaults = {
            'simulations_dict': {key:self.prior[key] for key in ['prior_draws','sim_data']},
            'epochs': 5000,
            'batch_size': 256,
            'save_checkpoint': True,
            'validation_sims': 1024,
            'reuse_optimizer': True,
            'early_stopping': False,
        }
        kwargs = defaults | kwargs  # lets kwargs overwrite defaults
        # train for save_period, then save, fix missing, & draw new validation set
        e = kwargs['epochs']
        while e > 0:
            kwargs['epochs'] = min(e, save_period)
            h = self.trainer.train_offline(**kwargs)
            self.fix_missing_loss(h['val_losses'].values[-1][0].round(7))
            e -= kwargs['epochs']

    def fix_missing_loss(self, y):
        """Fix a bug in BayesFlow 1.1.5 where last val_loss is not recorded.

        Walks the ``history_<k>`` files in the checkpoint folder from the largest k down.
        Where a file's latest val_history run or its ``_total_val_loss`` list has fewer
        than k entries, ``y`` is appended. ``y`` then becomes that file's second-to-last
        loss for the next (older) file. Presumably history_<k> should hold k epochs — TODO confirm.
        """
        w = {int(x.stem.split('_')[1]): x for x in self.checkpoint_path.iterdir() if 'history' in x.stem}
        for k, file in sorted(w.items(), reverse=True):
            D = read_pickle(file)
            H = D['val_history']
            H = H[max(H.keys())]
            if k > len(H):
                H[f'Epoch {k}'] = [y]
            L = D['_total_val_loss']
            if k > len(L):
                L.append(y)
            if k > 1:
                y = L[-2]
            write_pickle(D, file)

    def run_sim(self, name='sim', n_samples=50, plot=False, save=False, overwrite=False):
        """Simulation-based calibration check on ``n_samples*20`` fresh simulations (cached as ``<name>_data.pkl``)."""
        file = self.results_path / f'{name}_data.pkl'
        self[name] = read_pickle(file, overwrite)
        if self[name] is None:
            self.load_best_checkpoint()
            self[name] = self.generative_model(batch_size=n_samples*20)
            self[name] = self.sampler(self[name], n_samples=n_samples)
            self[name]['name'] = name
            write_pickle(self[name], file)
        if plot:
            self.plot_ecdf(self[name], save)
            self.plot_recovery(self[name], save)

    def run_real(self, real_data, name, n_samples=500, n_subsample=100, n_steps=None, plot=False, save=False, overwrite=False):
        """Posterior and posterior-predictive simulations for an observed data set.

        real_data : DataFrame indexed by step, with one column per observed state symbol.
        The result is cached as ``<name>_data.pkl`` and stored as ``self[name]``.
        """
        file = self.results_path / f'{name}_data.pkl'
        self[name] = read_pickle(file, overwrite)
        if self[name] is None:
            self.load_best_checkpoint()
            # format real_data as needed by amortizer: (1, n_states, n_steps), unobserved states = 0
            R = pd.DataFrame(columns=self.cls_all).rename_axis('step')
            for key, val in real_data.items():
                R[key] = val
            # float dtype: the frame starts empty (object dtype), and ufuncs such as
            # np.log1p (data_scaler='log') reject object arrays
            S = R.fillna(0).to_numpy(dtype=float).T[np.newaxis]
            self[name] = {'real_data': real_data, 'sim_data': S}
            self[name] = self.sampler(self[name], n_samples=n_samples, dfs=False)
            self[name].pop('sim_data')
            self[name] = self.validate(self[name])
            self[name] |= self.generative_model.simulator(self[name]['post_draws'], n_steps=n_steps)
            self[name] = self.get_dfs(self.validate(self[name]))
            self[name]['all_df'] = pd.concat([self[name]['sim_df'], self[name]['real_data'].rename_axis('step').assign(draw=-1).reset_index().set_index(['draw','step'])])
            self[name]['name'] = name
            if n_subsample > 0:
                A = self[name]['post_df']
                A = A.sample(n=min(n_subsample, A.shape[0]))
                self[name]['post_subsample'] = A
                self[name]['sim_subsample'] = self[name]['sim_df'].loc[A.index]
            write_pickle(self[name], file)
        if plot:
            self.plot_update(self[name], save)
            self.plot_predictive(self[name], save)

    def plot_losses(self, save=True):
        """Plot training/validation losses, clipped at the 98th training-loss percentile."""
        self.init_model()
        H = self.trainer.loss_history.get_plottable()
        ub = H['train_losses'].quantile(0.98).values[0]
        H = {k:v.clip(upper=ub) for k,v in H.items()}
        bf.diagnostics.plot_losses(**H, moving_average=True)
        if save:
            plt.savefig(self.results_path / 'losses.png')
        plt.show()

    def plot_ecdf(self, forward_dict, save=True):
        """SBC ECDF diagnostic for posterior vs. true (prior) draws."""
        bf.diagnostics.plot_sbc_ecdf(forward_dict['post_draws'], forward_dict['prior_draws'], param_names=self.prior.latex)
        if save:
            plt.savefig(self.results_path / f'{forward_dict["name"]}_ecdf.png')
        plt.show()

    def plot_recovery(self, forward_dict, save=True):
        """Parameter-recovery scatter of posterior estimates vs. true (prior) draws."""
        bf.diagnostics.plot_recovery(forward_dict['post_draws'], forward_dict['prior_draws'], param_names=self.prior.latex)
        if save:
            plt.savefig(self.results_path / f'{forward_dict["name"]}_recovery.png')
        plt.show()

    def plot_update(self, forward_dict, save=True):
        """Overlay prior and posterior marginal histograms, one facet per parameter."""
        pos = forward_dict['post_df'].assign(kind='posterior')
        # the cached calibration draws serve as the prior sample
        pri = self.prior['prior_df'].assign(kind='prior')
        Q = pd.concat([pri,pos])
        Q.columns = self.prior.latex+['kind']
        fig = sns.FacetGrid(Q.melt(id_vars='kind'), hue='kind', col='variable', col_wrap=3, sharex=False, sharey=False)
        fig.map(sns.histplot, 'value', kde=False, element='bars', alpha=0.5, stat='density')
        fig.set_titles(template = "{col_name}")
        fig.add_legend()
        if save:
            plt.savefig(self.results_path / f'{forward_dict["name"]}_update.png')
        plt.show()

    def plot_predictive(self, forward_dict, save=True, include=None, exclude=None):
        """Posterior-predictive central 10/50/90% bands per state, with observed points overlaid.

        include / exclude : a ``sim_df`` column (or list of columns) to keep / drop.
        Observed values (black x) are drawn for states in ``cls_obs`` present in ``real_data``.
        """
        S = forward_dict['sim_df'].copy()
        R = forward_dict['real_data'].copy()
        if include is not None:
            inc = include if isinstance(include, list) else [include]
            S = S[inc]
        elif exclude is not None:
            exc = exclude if isinstance(exclude, list) else [exclude]
            S = S.drop(columns=exc)
        # squeeze=False keeps `ax` indexable even when only one column is plotted
        fig, ax = plt.subplots(S.shape[1], 1, sharex=False, figsize=(20,20), squeeze=False)
        ax = ax[:, 0]
        G = S.groupby('step')
        for i, nm in enumerate(S.columns):
            for ci, clr in {90:'green', 50:'red', 10:'blue'}.items():
                x = (100 - ci) / 200  # lower-tail probability of the central ci% interval
                lb = G[nm].quantile(x)
                ub = G[nm].quantile(1-x)
                ax[i].fill_between(x=lb.index, y1=lb, y2=ub, color=clr, alpha=0.3, label=f'{nm} {ci}%')
            if nm in self.cls_obs and nm in R.columns:
                ax[i].plot(R[nm], 'kx')
            ax[i].legend(loc='right')
            ax[i].set_title(nm)
        if save:
            plt.savefig(self.results_path / f'{forward_dict["name"]}_predictive.png')
        plt.show()

    def get_equ(self, params):
        """Find the non-negative real equilibria of ``sys_ode`` for ``params`` via Wolfram ``Solve``.

        Returns one row per equilibrium (state values plus ``eig``, the largest real part of
        the Jacobian eigenvalues). If Wolfram finds no solution, it prints the inputs and
        returns an empty DataFrame.
        """
        S = self.sys_ode.subs(params)
        # sympy -> Wolfram syntax: '**' -> '^', scientific 'e-' -> '*10^-'
        cmd = f"{' && '.join([f'{s} == 0' for s in S] + [f'{k} >= -{EPS}' for k in self.cls_ode])}, {{{', '.join([str(k) for k in self.cls_ode])}}}".replace('**', '^').replace('e-', '*10^-')
        cmd = f"Chop[N[Solve[{cmd}, Reals]]]"
        # cmd = f"Chop[NSolve[{cmd}, Reals, 10]]"
        with self.ses, contextlib.redirect_stderr(None):
            solutions = self.ses.evaluate(wlexpr(cmd))
        if len(solutions) == 0:
            print(params)
            print(cmd)
            # empty frame instead of None so callers can .assign()/concat without special-casing
            return pd.DataFrame()
        return pd.DataFrame([self.get_eig(params, sol) for sol in solutions])

    def get_eig(self, params, sol):
        """Merge ``params`` with one Wolfram solution and add ``eig``, the largest real part of the Jacobian eigenvalues."""
        equ = params | {sy.symbols(str(k).replace('Global`','')): v for k, v in sol}
        J = sy.matrix2numpy(self.jac_ode.subs(equ), float)
        equ['eig'] = np.max(np.real(np.linalg.eigvals(J)))
        return equ

    def get_equilibria(self, draws):
        """Classify the equilibria of each parameter draw (rows of ``draws``).

        Returns ``details`` (one row per equilibrium with a ``label``) and ``summary``
        (draws grouped by how many equilibria of each label they have).
        """
        A = pd.concat([self.get_equ(params).assign(draw=draw) for draw, params in draws.to_dict('index').items()]).set_index('draw')
        B = A[self.cls_ode+['eig']].rename(columns=str)
        B['_stable']   = B['eig'] <= -EPS
        B['_unstable'] = B['eig'] >= EPS
        B['_unknown'] = ~(B['_stable'] | B['_unstable'])
        B = B.drop(columns='eig') > EPS
        # label = names of the strictly positive states + the stability flag, concatenated
        A['label'] = (B.values * B.columns).to_series().str.join('').values

        L = pd.get_dummies(A['label'])
        targ = list(L.columns)
        G = pd.concat([A, L], axis=1).groupby('draw').agg({k:'max' for k in self.prior.names} | {k:'sum' for k in targ})
        C = G.groupby(targ).describe()
        C.insert(0,'n_equ', C[[]].reset_index().values.sum(axis=1))
        C.insert(1,'ct', C.iloc[:,1].astype(int))
        C.insert(2,'pct', (C['ct'] / C['ct'].sum() * 100).round(2))
        S = (
            C[[x for x in C.columns if 'count' not in x]]
            .sort_values(['ct','n_equ'], ascending=False)
            .set_index(['n_equ','ct','pct'], append=True)
        )
        return {'details':A, 'summary':S}

class Prior(BaseClass):
    """Independent uniform prior over the model parameters.

    The key order of ``sample()`` fixes the parameter order used everywhere
    (``names``, array columns, ``latex`` labels).
    """

    def __init__(self, seed=None):
        self.rng = np.random.RandomState(seed)
        # one throwaway scalar draw, only to discover the parameter names
        first_draw = self.sample()
        self.names = [name for name in first_draw]
        self.latex = list(map(latex, self.names))

    def __call__(self, batch_size=None, rtn='arr'):
        """Format sample - don't change

        rtn='df'  -> DataFrame with one column per parameter
        rtn='arr' -> array with one column per parameter (1-D for a scalar draw)
        otherwise -> list of {name: value} dicts, one per draw
        """
        draws = self.sample(batch_size)
        if rtn == 'df':
            return pd.DataFrame(draws)
        if rtn == 'arr':
            return np.array(list(draws.values())).T
        return [dict(zip(draws.keys(), row)) for row in zip(*draws.values())]

    def sample(self, batch_size=None):
        """define prior: one uniform draw per parameter, in the order listed"""
        uniform = lambda low, high: self.rng.uniform(low, high, size=batch_size)
        return {
            r1: uniform(0, 1),  # treatment rates
            r2: uniform(0, 1),  # treatment rates
            rho: uniform(0, 1),  # probability the treatment works
            kappa: uniform(0, 1),  # time to become infectious
            beta1: uniform(0, 5),  # infection rates
            beta2: uniform(0, 5),  # infection rates
            Lamda: uniform(500000, 1000000),  # birth counts
            mu: uniform(1/90, 1/50),  # center at 1/70
            # sigma: np.abs(self.rng.standard_cauchy(size=batch_size))
        }

# Model symbols. s, e, i, t are the compartments (presumably susceptible, exposed,
# infectious, treated — confirm); NOTE: `t` is a state here, not time.
# `de` integrates the infection inflow (beta1*s + beta2*t)*i/n, i.e. cumulative new infections.
# `Lamda` must be bound here: the system and Prior.sample both reference it.
s, e, i, t, de, beta1, beta2, Lamda, mu, kappa, rho, r1, r2 = sy.symbols('s e i t de beta_1 beta_2 lambda mu kappa rho r_1 r_2')
lamda = Lamda  # alias for the previous lowercase binding
n = s + e + i + t  # total population across compartments
sys = {
    s: Lamda - beta1*s*i/n - mu*s,
    e: (beta1*s + beta2*t)*i/n - (mu+kappa+r1)*e + rho*r2*i,
    i: kappa*e - (r2+mu)*i,
    t: r1*e + (1-rho)*r2*i - beta2*t*i/n - mu*t,
    de: (beta1*s + beta2*t)*i/n,
}
# observed per-step values of `de` (presumably yearly US TB case counts — confirm source)
us_data = pd.DataFrame({
    de:[25701,26283,26673,25107,24205,22727,21210,19751,18287,17500,16309,15945,15055,14835,14499,14068,13732,13286,12905,11537,11182,10528],
})

# Fit/evaluate one model per (param_scaler, data_scaler) combination.
# A model whose loss history plots is considered trained and is only plotted.
# Otherwise it is retrained from scratch and evaluated on simulated and US data.
# `self` stays bound to the last model for the cells below.
for param_scaler in ['minmax','standard','robust']:
    for data_scaler in ['minmax','standard','robust','log']:
        self = ODE(
            # overwrite=True,
            # name='TB_noisy',
            name='TB',
            prior=Prior(),
            n_steps=us_data.shape[0],
            sys=sys,
            x0={s:290000000, e:25000, i:25000, t:22000, de:0},
            cls_ode=[s,e,i,t],
            cls_obs=[de],
            cls_dif=[de],
            n_calibrate=2**20,
            param_scaler=param_scaler,
            data_scaler=data_scaler,
        )
        print(param_scaler, data_scaler)
        try:
            # fails (e.g. KeyError on an empty loss history) when no trained model exists
            self.plot_losses()
        except Exception:
            # not a bare except: Ctrl-C must not fall through into an overwrite=True retrain
            self.train(
                overwrite=True,
                # save_period=1000,
                validation_sims=2**12,
                epochs=1000,
                batch_size=2**16,
            )
            self.plot_losses()
            self.run_sim(
                overwrite=True,
                n_samples=50,
                plot=True,
                save=True,
            )
            self.run_real(
                overwrite=True,
                name='us',
                n_samples=2**10,
                real_data=us_data,
                # n_steps=100,
                # n_subsample=10,
                plot=True,
                save=True,
            )
# E = self.get_equilibria(self.us['post_subsample'])

robust robust


Training epoch 1: 100%|█| 16/16 [00:12<00:00,  1.25it/s, Epoch: 1, Batch: 16,Loss: 6.996,W.Decay: 0.347,Avg.Loss: 7.93
Training epoch 2: 100%|█| 16/16 [00:02<00:00,  7.43it/s, Epoch: 2, Batch: 16,Loss: 6.438,W.Decay: 0.335,Avg.Loss: 6.72
Training epoch 3: 100%|█| 16/16 [00:02<00:00,  7.34it/s, Epoch: 3, Batch: 16,Loss: 5.877,W.Decay: 0.323,Avg.Loss: 6.16
Training epoch 4: 100%|█| 16/16 [00:02<00:00,  7.33it/s, Epoch: 4, Batch: 16,Loss: 5.359,W.Decay: 0.315,Avg.Loss: 5.56
Training epoch 5: 100%|█| 16/16 [00:02<00:00,  7.32it/s, Epoch: 5, Batch: 16,Loss: 5.030,W.Decay: 0.312,Avg.Loss: 5.18
Training epoch 6: 100%|█| 16/16 [00:02<00:00,  7.20it/s, Epoch: 6, Batch: 16,Loss: 4.893,W.Decay: 0.310,Avg.Loss: 5.02
Training epoch 7: 100%|█| 16/16 [00:02<00:00,  7.34it/s, Epoch: 7, Batch: 16,Loss: 4.717,W.Decay: 0.309,Avg.Loss: 4.81
Training epoch 8: 100%|█| 16/16 [00:02<00:00,  7.33it/s, Epoch: 8, Batch: 16,Loss: 4.675,W.Decay: 0.309,Avg.Loss: 4.72
Training epoch 9: 100%|█| 16/16 [00:02<00:00,  7

In [8]:
# inspect the loss history of the last model (`self`) built by the loop above;
# it returned {} in this run — NOTE(review): confirm when BayesFlow populates the history
self.trainer.loss_history.get_plottable()

{}

time: 995 µs (started: 2023-11-10 12:04:37 -06:00)


In [None]:
# back-of-envelope cost: number of simulations (2**k) and projected run time, scaling a
# measured 28.1 per 2**11 simulations linearly (presumably seconds, so /60 -> minutes)
k = 16
2**k, 28.1 / 2**11 * 2**k / 60