# tracking

> Experiment tracking: a callback that records model, callback, epoch and batch statistics in a SQLite database

In [None]:
#|default_exp tracking

In [None]:
#|hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from isaacai.utils import *
from isaacai.dataloaders import *
from isaacai.models import *
from isaacai.training import *
import inspect

from datetime import datetime
import torchvision.transforms.functional as TF,torch.nn.functional as F

import matplotlib.pyplot as plt,matplotlib as mpl
import fastcore.all as fc
import torch
from torch import nn, Tensor
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import pandas as pd , numpy as np
from torcheval.metrics import MulticlassAccuracy,Mean

import sqlite3
from pathlib import Path
import pandas as pd

In [None]:
# Compact, non-scientific tensor printing for interactive inspection.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
# Render single-channel images in grayscale by default.
mpl.rcParams['image.cmap'] = 'gray'

import logging
# Disable all log records at WARNING severity and below.
logging.disable(logging.WARNING)

# NOTE(review): set_seed comes from one of the isaacai star imports and presumably re-seeds torch too,
# which would make torch.manual_seed(1) above redundant — confirm and keep a single seed.
set_seed(42)
from IPython.display import clear_output

In [None]:
# Normalisation constants (presumably ~ Fashion-MNIST pixel mean/std — TODO confirm).
xmean, xstd = 0.28, 0.35

@inplace
def transformi(b):
    "Replace `b['image']` with tensors standardised by `xmean`/`xstd`."
    b['image'] = [(TF.to_tensor(img) - xmean) / xstd for img in b['image']]

# Attach the transform (applied on access), then subsample via the project's `sample_dataset_dict`.
_dataset = sample_dataset_dict(load_dataset('fashion_mnist').with_transform(transformi))
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)  # 64 is presumably the batch size
clear_output()  # drop the output produced while loading

In [None]:
#| export
def sql_connect(path):
    "Open a `sqlite3` connection to the database at `path` (a file path or `':memory:'`) and return it."
    connection = sqlite3.connect(path)
    return connection

In [None]:
#| export
class TrackingCB:
    """Experiment-tracking callback: buffers fit/epoch/batch logs and appends them to SQLite after each fit.

    Records are kept as plain dicts and only turned into DataFrames on demand (`_log_batch`,
    `_log_epoch`, `_log_fit`), avoiding a full DataFrame copy on every logged batch.
    `after_fit` writes only the records not yet persisted, so calling `fit` several times does not
    duplicate rows. Columns missing from an existing table are added first, so runs with different
    metrics or callbacks can share one database.

    Args:
        db_path: location handed to `sql_connect` (e.g. a `Path`, or `':memory:'`).
        exp_name: label stored in the `exp_name` column of every record.
    """
    # record kind -> SQLite table it is appended to (dict order = write order in `after_fit`)
    _tables = {'batch': 'batch_stats', 'epoch': 'epoch_stats', 'fit': 'fit_stats'}

    def __init__(self, db_path, exp_name):
        self._exp_name = exp_name
        self._con = sql_connect(db_path)
        self._records = {kind: [] for kind in self._tables}
        # how many records of each kind have already been written to the database
        self._n_saved = {kind: 0 for kind in self._tables}

    def _frame(self, kind):
        "All buffered records of `kind` as a DataFrame (blank index, one row per record)."
        recs = self._records[kind]
        return pd.DataFrame(recs, index=[""] * len(recs))

    # Read-only DataFrame views, kept so existing code can keep inspecting e.g. `cb._log_epoch`.
    @property
    def _log_batch(self): return self._frame('batch')
    @property
    def _log_epoch(self): return self._frame('epoch')
    @property
    def _log_fit(self): return self._frame('fit')

    @staticmethod
    def _model_source(model):
        "Source code of the model's class, or a placeholder when it is unavailable (e.g. class defined in a notebook cell)."
        try: return inspect.getsource(model.__class__)
        except (OSError, TypeError): return '<source unavailable>'

    def before_fit(self, trainer):
        """Route `trainer.MetricsCB` logging through this callback and record a one-row run summary.

        The summary holds the model repr/type/source, the callback list, and every public,
        non-callable attribute of each callback (stringified)."""
        metrics_cb = getattr(trainer, 'MetricsCB', None)
        if metrics_cb is not None: metrics_cb._log = self._log
        log = {'model': str(trainer.model),
               'model_type': str(type(trainer.model)),
               'model_source': self._model_source(trainer.model)}
        log['callbacks'] = str(trainer.callbacks)
        for callback in trainer.callbacks:
            cb = getattr(trainer, callback)
            for cb_attr in dir(cb):
                if cb_attr.startswith('_'): continue
                val = getattr(cb, cb_attr)
                # NOTE: keys are not namespaced per callback, so equally named attributes overwrite each other
                if not callable(val): log[cb_attr] = str(val)
        # Explicit kind: a callback attribute named 'epoch'/'batch' must not re-route the summary row.
        self._log(log, kind='fit')

    def _log(self, x, kind=None):
        """Buffer one log record `x` (a dict), tagged with the current time and the experiment name.

        `kind` selects the target table ('batch', 'epoch' or 'fit'). When omitted it is inferred from
        the keys: 'batch' if `x` has a 'batch' key, else 'epoch' if it has an 'epoch' key, else 'fit'.
        `x` itself is left unmodified."""
        rec = {**x, 'insert_timestamp': datetime.now(), 'exp_name': self._exp_name}
        if kind is None: kind = 'batch' if 'batch' in x else 'epoch' if 'epoch' in x else 'fit'
        self._records[kind].append(rec)

    @staticmethod
    def _quote(name):
        "Quote `name` as an SQLite identifier."
        return '"' + str(name).replace('"', '""') + '"'

    def _add_missing_columns(self, table, columns):
        "Add any of `columns` that an existing `table` lacks, so appending rows with new keys cannot fail."
        info = self._con.execute(f'PRAGMA table_info({self._quote(table)})').fetchall()
        existing = {row[1].lower() for row in info}  # SQLite column names are case-insensitive
        if not existing: return  # table does not exist yet: to_sql will create it with all columns
        for col in columns:
            if str(col).lower() not in existing:
                self._con.execute(f'ALTER TABLE {self._quote(table)} ADD COLUMN {self._quote(col)}')
        self._con.commit()

    def after_fit(self, trainer):
        "Append the records logged since the last write to `batch_stats`, `epoch_stats` and `fit_stats`."
        for kind, table in self._tables.items():
            new = self._records[kind][self._n_saved[kind]:]
            if not new: continue  # don't create column-less tables for kinds that logged nothing
            df = pd.DataFrame(new)
            self._add_missing_columns(table, df.columns)
            df.to_sql(table, self._con, if_exists='append', index=False)
            self._n_saved[kind] += len(new)


In [None]:
# Build a trainer on the loaders above: cross-entropy loss, Adam, and a SimpleNet whose sizes
# (28*28, 64, 10) presumably mean input pixels / hidden units / classes — confirm in isaacai.models.
# TrackingCB logs this run as 'test_exp1' into the SQLite file '../exp_tracker' (relative to the CWD).
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  SimpleNet(28*28,64,10), 
                  callbacks=[MomentumTrainCB(.85),MetricsCB(accuracy=MulticlassAccuracy()), DeviceCB(),
                             TrackingCB(Path('../exp_tracker'),'test_exp1')])

In [None]:
# Train; TrackingCB writes its buffered logs to SQLite in its after_fit hook.
trainer.fit()

In [None]:
# Per-epoch metrics collected by TrackingCB during the fit (one row per epoch).
trainer.TrackingCB._log_epoch

Unnamed: 0,accuracy,train_loss,valid_loss,epoch,elapsed,insert_timestamp,exp_name
,0.596,1.795199,1.267658,0,0 days 00:00:00.388058,2023-02-14 01:24:38.283150,test_exp1
,0.592,1.018322,1.117246,1,0 days 00:00:00.366744,2023-02-14 01:24:38.651088,test_exp1
,0.656,0.861313,0.956027,2,0 days 00:00:00.359158,2023-02-14 01:24:39.011897,test_exp1


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()