# Testing Grounds - Fit Loop
Creating a library that helps with writing PyTorch training loops.

In [9]:
import time
import math

from uuid import uuid4
from pathlib import Path
from copy import deepcopy
from typing import Union, List, Callable, Optional, Any, Dict, Tuple

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

In [22]:
class A():
    """
    Scratch proxy: attribute lookups that fail on the instance are forwarded
    to the wrapped object `f`; indexing returns the key itself.
    """
    def __init__(self, f):
        self.__f = f
        self.a = 22
    
    def __getitem__(self, name):
        # Identity indexing: A(...)[k] is k.
        return name
        
    def __getattr__(self, name):
        # Only reached when normal lookup fails. If the wrapped object has not
        # been stored yet (e.g. an instance built by copy/pickle without
        # __init__), fail with AttributeError instead of recursing through
        # `self.__f` -> __getattr__('_A__f') -> ...
        if name == '_A__f':
            raise AttributeError(name)
        return getattr(self.__f, name)
class B():
    """Scratch class carrying a single instance attribute `b` set to 33."""
    def __init__(self):
        setattr(self, 'b', 33)

LoopState
- ✅should cast the batch to the device before handing it over via the `state.batch` property
- ✅should get the batch num `state.batch_num` and epoch num `state.epoch_num`
- ✅the model, optimizer, loss_function, lr_scheduler should be available
    `state.model`, `state.optimizer`, `state.loss_function`, `state.lr_scheduler`
- ✅should return the batch metrics as float tensors using square bracket indexing
    `state['loss']` 
    - every step function hook receives the LoopState object.
- ✅The loop state object should have a copy of all the values returned from the function hook
- ✅e.g. the dict values returned by the function below should be available in the LoopState object

```python
def train_step(state):
    X,y = state.batch # cast to device automatically
    y_ = state.model(X)
    loss = state.loss_function(y_, y)
    
    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()
    state.lr_scheduler.step() 
    
    loss = loss.item()
    batch_loss = loss * y.size(0)
    batch_corr = (y_.argmax(dim=1) == y).sum().float().item()
    
    return {'loss':loss,'batch_loss':batch_loss,'batch_corr':batch_corr}
```
- The LoopState object should be cleared of the above values at the start 
  of the next epoch.
- The returned values should be available through the FitLoop object
  Eg: `FitLoop.train.batch['loss']`
- Storing the values described in the statement above should be optional, controlled by a flag
  (`track_batch_metrics`)

```python
def train_epoch_end(state):
    loss = state['loss']
    batch_loss = state['batch_loss']
    batch_corr = state['batch_corr']
    
    size = state.size
    
    epoch_loss = batch_loss.sum().item()/size
    epoch_accu = batch_corr.sum().item()/size
    
    return {"loss":epoch_loss,"accu":epoch_accu}
```

- The returned values should be available through the FitLoop object
  Eg: `FitLoop.train.epoch['loss']`
- ✅For each phase a different LoopState object is maintained.



In [14]:
class FitLoop:
    """Dummy stand-in so that later annotations (e.g. in LoopState) can name
    `FitLoop` before the real class is defined further down."""


In [237]:
# Scratch: build a two-element list and inspect its last element.
a = [2, 3]

a[-1]

3

In [245]:
class LoopState:
    """
    Maintains train/valid/test loop state for a single run of 
    a certain number of epochs; it is not used to preserve state 
    between runs.

    One instance is created per phase. Attribute lookups that fail on the
    state are forwarded to the owning FitLoop (e.g. `state.model`,
    `state.optimizer`, `state.device`), and `state['name']` returns the
    values collected for `name` from the batch-step stage.
    """
    _stages = ['batch','epoch_start','epoch_end']
    _batch_step, _epoch_start, _epoch_end = _stages
    def __init__(self, phase:str, floop:"FitLoop", no_cast:bool, no_float:bool, is_train:bool, is_test:bool):
        """
        phase : phase name 'train', 'valid' or 'test'
        floop : the calling FitLoop object; must have a `{phase}_dl` attribute
        no_cast : if True `batch` is returned exactly as it was set
        no_float : if True epoch metrics are returned without float() conversion
        is_train : stored as `self.is_train`
        is_test : stored as `self.is_test`
        """
        self.__batch = ()
        self.__floop = floop
        self._no_cast = no_cast
        self._no_float = no_float
        self.phase = phase
        self.batch_num = 0
        self.epoch_num = 0
        # metrics[stage][key] -> list with one value appended per stage call
        self.metrics = {s:{} for s in self._stages}
        self.is_train = is_train
        self.is_test = is_test
        
        # For easy access
        dl = getattr(floop, f'{phase}_dl')
        bs = dl.batch_size
        dr = dl.drop_last
        sz = len(dl.dataset)
        bt = sz / bs
        
        # Dataset size and batch count per epoch (a trailing partial batch
        # counts unless the loader drops it).
        self.size = sz
        self.batches = math.floor(bt) if dr else math.ceil(bt)
        # Size of the current batch; set in `_pre_batch_step_update`.
        self.batch_size = 0
    
    def __getattr__(self, name:str) -> Any:
        # Only reached when normal lookup fails: forward to the FitLoop so
        # stage functions can use `state.model`, `state.device`, etc.
        # Private attributes are never forwarded; otherwise a lookup on an
        # instance whose __init__ has not run (copy/pickle) would recurse
        # forever through `self.__floop`.
        if name.startswith('_LoopState__'):
            raise AttributeError(name)
        return getattr(self.__floop, name)
    
    def __getitem__(self, metric_name:str):
        """
        Values returned under `metric_name` by the batch step, as a float
        tensor when convertible, otherwise the raw list.
        Raises KeyError if the batch step never returned `metric_name`.
        """
        metric_value = self.metrics[self._batch_step][metric_name]
        try:
            return torch.tensor(metric_value).float()
        except (TypeError, ValueError, RuntimeError):
            # e.g. non-numeric or ragged values: hand back the list unchanged
            return metric_value
    
    # Getter and setter for the current batch

    @property
    def batch(self) -> Tuple[Tensor,...]:
        """
        The current batch. Unless `no_cast`, every tensor is moved to
        `device`; floating tensors are cast to `dtype`, all others to
        torch.long. Returned as a tuple so it can be unpacked and indexed.
        """
        if self._no_cast:
            return self.__batch
        
        return tuple(
            d.to(device=self.device,dtype=self.dtype) 
            if d.is_floating_point() 
            else d.to(device=self.device,dtype=torch.long) 
            for d in self.__batch
        )
    
    @batch.setter
    def batch(self, current_batch:Tuple[Tensor,...]) -> None:
        self.__batch = current_batch
        
    # Functions to append rdict values to self.metrics

    def _append(self, rdict:Dict[str, float], stage:str) -> None:
        # Append each value of `rdict` to the list kept for its key in `stage`.
        if rdict is None:
            if stage == self._epoch_end:
                print(f"no rdict returned from: {self.phase}_{stage}")
            # TODO: Add warning if rdict of stage is None
            return
        
        for key, value in rdict.items():
            self.metrics[stage].setdefault(key, []).append(value)
            
    def _append_batch_step(self, rdict:Dict[str, float]) -> None:
        # Called after batch step rdict is returned
        self._append(rdict, self._batch_step)
        
    def _append_epoch_start(self, rdict:Dict[str, float]) -> None:
        # Called after the epoch start step rdict is returned
        self._append(rdict, self._epoch_start)
        
    def _append_epoch_end(self, rdict:Dict[str, float]) -> None:
        # Called after epoch end step rdict is returned
        self._append(rdict, self._epoch_end)
    
    # Functions to clear rdict values from self.metrics

    def _clear(self, stage:str) -> None:
        # Empty every value list of `stage`; the keys themselves are kept.
        for values in self.metrics[stage].values():
            values.clear()
            
    def _clear_batch_step(self) -> None:
        # Called before epoch start
        self._clear(self._batch_step)
        
    def _clear_epoch_start(self) -> None:
        # Not called from this class.
        self._clear(self._epoch_start)
        
    def _clear_epoch_end(self) -> None:
        # Not called from this class.
        self._clear(self._epoch_end)
    
    # State updates before epoch start and batch step stages

    def _pre_epoch_start_update(self, epoch_num:int) -> None:
        # Reset per-epoch batch data before the phase's epoch begins.
        self._clear_batch_step()
        self.batch_num = 0
        self.epoch_num = epoch_num
    
    def _pre_batch_step_update(self, current_batch):
        # Record the incoming batch; its size is read from the first tensor.
        self.batch_size = current_batch[0].size(0)
        self.batch_num += 1
        self.batch = current_batch
    
    # Functions to get various metrics at different stages

    def _get_epoch_metric(self, criteria:str) -> float:
        # Last value stored for `criteria` at epoch end (e.g. the model
        # selection criteria); float() applied unless `no_float`.
        metric = self.metrics[self._epoch_end][criteria][-1]
        if self._no_float:
            return metric
        else:
            return float(metric)
    
    def _get_epoch_metrics(self, 
                display_metrics:Optional[Union[str,List[str]]]=None
                ) -> Dict[str,float]:
        # Last saved epoch-end value for one metric name, a list of names,
        # or (None) every recorded epoch-end metric.
        if isinstance(display_metrics, str):
            return {display_metrics:self._get_epoch_metric(display_metrics)}
        elif isinstance(display_metrics, list):
            return {
                metric:self._get_epoch_metric(metric)
                for metric in display_metrics
            }
        else:
            return {
                metric: self._get_epoch_metric(metric)
                for metric in self.metrics[self._epoch_end]
            }

FitLoop
- All stage functions except `train_step` should be optional: `train_step` is the only one required to train the model; the rest are for metric keeping.
- ✅if `FitLoop.fit(define_all=True)`, `zero_grad`, the grad-mode context manager and the model train/eval mode are not set automatically.
- ✅should keep track of epochs that have been completed
- ✅epoch_number can be reset 
- metrics can be cleared
- `FitLoop.set_name.loop_stage['metric_name']` to access the metric
- `FitLoop.store_pretrained:bool` arg to store the pretrained weights before training
    if path then store at given path else store in memory.
- ✅`FitLoop.reset(reset_model:bool)` to clear metrics, epoch_num and to reset the model, to pretrained state
    will load the weight from passed path else from memory.
- `FitLoop.save(path:str)` to save the model, the training state and possibly the FitLoop state itself.
- `FitLoop.load(path:str)` to load the FitLoop state from given path.
- Some basic step should be used such that one can use it without defining step functions.
- `FitLoop.fit(continue_loop:int=0)` ask after `int` whether to continue training or to end.
- `FitLoop.fit(profiler:bool=False)` mode to capture all stage timings and maybe even CPU, GPU, RAM usage to check for bottlenecks and usage spikes, to be used with timed_test.
- Make it easy to train,validate, test with some other DataLoader that is not attached to the object.
- `FitLoop.fit(no_print:bool=False)` mode to suppress all print statements, for when custom logging is done in the stage functions.
- Use a loading bar for epoch and a disappearing one for batch.
- Functionality to view the metrics.
- ✅Model score should be a loop instance so that the best model may not be erased.
- Time keeping/ metric keeping:
    - Profiler:
        - Individual Stage Timings
        - Individual Stage CPU Usage
        - Individual Stage GPU Usage
        - Individual Stage RAM Usage
        - For these Stages:
            - Batch step for each phase
            - Epoch start step for each phase
            - Epoch end step for each phase
    - General
        - Metrics returned in the batch step
        - Metrics returned in the end step
        - Progress bar for epoch
        - Progress bar for batch that disappears after complete
        - ✅ Epoch timing (for both phases when training)
        - ✅ Total timing 
- Default train/valid/test stage functions (see `FitLoopDefault`).

In [37]:
class FitLoopDefault:
    """
    Namespace of fallback stage functions used as FitLoop defaults.
    Each one prints its name and returns an empty rdict.

    They are static methods, so they work whether accessed on the class
    (`FitLoopDefault.train_step`) or on an instance.
    """
    @staticmethod
    def train_step(state):
        print("default train_step")
        return {}

    @staticmethod
    def valid_step(state):
        print("default valid_step")
        return {}

    @staticmethod
    def test_step(state):
        print("default test_step")
        return {}

    @staticmethod
    def train_epoch_end(state):
        print("default train_epoch_end")
        return {}

    @staticmethod
    def valid_epoch_end(state):
        print("default valid_epoch_end")
        return {}

    @staticmethod
    def test_epoch_end(state):
        print("default test_epoch_end")
        return {}

In [42]:
class Metrics:
    """
    Placeholder for keeping track of every metric produced during fitting;
    visualisation of those metrics is planned but not implemented yet.
    """

## FitLoop

### Helpers

In [228]:
def ftime(t1:float,t2:float)->str:
    """
    Format the elapsed time between two `time.time()` stamps.

    Nonzero components are rendered as "<h> h <m> m <s> s <ms> ms"
    (minutes and seconds right-justified to width 2, milliseconds to
    width 3), and the result is left-justified to 20 characters.
    The span is rounded to the nearest millisecond; negative spans count
    as zero. If either stamp is None (timing disabled), a blank
    20-character string is returned.
    """
    if t1 is None or t2 is None:
        return ''.ljust(20)
    # Work in integer milliseconds rather than parsing str(float), which
    # breaks on scientific notation ('1e-05') and mis-scales fractions with
    # fewer than three digits ('1.5' would give '5 ms' instead of '500 ms').
    total_ms = round(max(t2 - t1, 0.0) * 1000)
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    parts = [
        f"{value:>{width}} {unit}"
        for value, width, unit in ((h, 1, 'h'), (m, 2, 'm'), (s, 2, 's'), (ms, 3, 'ms'))
        if value > 0
    ]
    return ' '.join(parts).ljust(20)

### The Class

In [270]:
class FitLoop:
    """
    class that helps in training pytorch models.
    """
    
    # ---------------------------------------------------------------------
    """
    SECTION: 0 
    
    Initialization
    """
    _sets = ['train','valid','test']
    _TR, _VA, _TE = _sets
    
    _model_type = ['pretrained','best']
    _PR, _BS = _model_type
    def __init__(self, 
                 # Basic Blocks
                 model: Module, 
                 optimizer: Union[Optimizer,List[Optimizer]], 
                 loss_function: Callable[[Tensor,Tensor],Tensor], 
                 
                 # DataLoader
                 train_dl: DataLoader, 
                 valid_dl: Optional[DataLoader]=None, 
                 test_dl: Optional[DataLoader]=None, 
                 
                 # Batch Step
                 train_step: Callable[[LoopState],Dict[str, Any]]=FitLoopDefault.train_step,
                 valid_step: Optional[Callable[[LoopState],Dict[str, Any]]]=FitLoopDefault.valid_step,
                 test_step: Optional[Callable[[LoopState],Dict[str, Any]]]=FitLoopDefault.test_step,
                 
                 # Epoch Start Step
                 train_epoch_start: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 valid_epoch_start: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 test_epoch_start: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 
                 # Epoch End Step
                 train_epoch_end: Callable[[LoopState],Dict[str, Any]]=FitLoopDefault.train_epoch_end,
                 valid_epoch_end: Optional[Callable[[LoopState],Dict[str, Any]]]=FitLoopDefault.valid_epoch_end,
                 test_epoch_end: Optional[Callable[[LoopState],Dict[str, Any]]]=FitLoopDefault.test_epoch_end,
                 
                 # Other Args
                 lr_scheduler: Optional[Union[_LRScheduler, Any, List[Union[_LRScheduler,Any]]]]=None,
                 device: torch.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 
                 configure_optimizer:Callable[[FitLoop],None]=None,
                 dtype: torch.dtype=torch.float32,
                 
                 # Model Evaluation
                 criteria: Optional[str]=None,
                 criteria_direction: int=1,
                 
                 # Preserving Model State
                 save_to_disk: bool=False,
                 save_path: str="models",
                 pretrained_model_name: Optional[str]=None,
                 best_model_name: Optional[str]=None,
                ) -> None:
        """
        FitLoop constructor
        ----
        Parameters:
        # Basic Blocks
            The bare minimum required along with train_dl.
            - model : nn.Module model that has to be trained
            - optimizer : an optimizer from torch.optim, or a list of them
            - loss_function : function to compute loss
         
        # DataLoader
            - train_dl : training DataLoader
            - valid_dl : validation DataLoader, if None validation will be ignored
            - test_dl : testing DataLoader, if None `.test()` will not run
         
        # Batch Step
            Functions that take in a LoopState object to perform 
            required calculations; they should return a dict with values
            to be used in the epoch end step.
            - train_step : portion of the loop where forward and backward 
                passes take place.
            - valid_step : validation portion of the loop.
            - test_step : called when `FitLoop.test()` is called.
        
        # Epoch Start Step
            Optional functions that take in a LoopState object and are called at
            the start of every epoch of the matching phase; a returned dict is
            stored in the LoopState. None means nothing is called.
            - train_epoch_start
            - valid_epoch_start
            - test_epoch_start
        
        # Epoch End Step
            Functions that take in a LoopState object to perform 
            required calculations; they should return a dict with values
            that are to be returned when the loop is over.
            - train_epoch_end : after training epoch has ended.
            - valid_epoch_end : after validation epoch has ended.
            - test_epoch_end : called when the test loop is done, one iteration
                over all batches in the test dataloader.
        
        # Other Args
            - lr_scheduler : scheduler from torch.optim.lr_scheduler
            - device : torch.device the model is cast to prior to the loop
            - configure_optimizer : function that configures the optimizer, will be called
                whenever the model weights have to be restored.
            - dtype : floating point dtype to cast model and data to
            
        # Model Evaluation
            - criteria : model evaluation metric that is returned in the dict of the
                `valid_epoch_end` stage function; if None (default) the best model and 
                best score are not tracked.
            - criteria_direction : whether more is better (> 0) or less is better (<= 0) 
                for the model score criteria.
        
        # Preserving Model State
            - save_to_disk : if True save the pretrained and best models to disk, else
                they are stored as attributes.
            - save_path : location where the pretrained and best models are to be saved
            - pretrained_model_name : Name to save the pretrained model by
                (default: "pretrained_<4 hex chars>.pt")
            - best_model_name : Name to save the best model by
                (default: "best_<4 hex chars>.pt")
        """
        # Basic Blocks
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function
        
        # DataLoaders
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.test_dl = test_dl
        
        # Batch Step
        self.train_step = train_step
        self.valid_step = valid_step
        self.test_step = test_step
        
        # Epoch Start Step
        self.train_epoch_start = train_epoch_start
        self.valid_epoch_start = valid_epoch_start
        self.test_epoch_start = test_epoch_start
        
        # Epoch End Step
        self.train_epoch_end = train_epoch_end
        self.valid_epoch_end = valid_epoch_end
        self.test_epoch_end = test_epoch_end
        
        # Other Args
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.configure_optimizer = configure_optimizer
        self.dtype = dtype
        
        # Model Evaluation
        self.criteria = criteria
        self.criteria_direction = criteria_direction
        
        # Preserving Model State
        if pretrained_model_name is None:
            u = str(uuid4()).split('-')[1]
            pretrained_model_name = f"pretrained_{u}.pt"
        if best_model_name is None:
            u = str(uuid4()).split('-')[1]
            best_model_name = f"best_{u}.pt"
        self.pretrained_model_name = pretrained_model_name
        self.best_model_name = best_model_name
        self.save_to_disk = save_to_disk
        self.save_path = Path(save_path)
        
        # INITIALIZE NON ARGS
        # These defaults are set *before* the pretrained snapshot below, so
        # they can never overwrite whatever `__save_model` stores.
        self.epoch_num = 0
        # Worst possible starting score for the chosen direction, using the
        # same `> 0` test as the score comparison in `__loop` (avoids the nan
        # that `0 * -inf` gave for criteria_direction == 0).
        self.best_score = float('-inf') if self.criteria_direction > 0 else float('inf')
        self.profiler = {}
        self.pretrained_model_state_dict = None
        self.best_model_state_dict = None
        
        self.__save_model(self._PR)
            
            
    # ---------------------------------------------------------------------
    """
    SECTION: 1
    
    Helper functions used in `__loop`
    """
        
    
    def __call_batch_step(self, state:LoopState) -> None:
        """
        Run the batch-step function for `state.phase` and record the dict it
        returns (non-dict returns are ignored).

        Raises AttributeError if that phase's step function is None.
        """
        step_funcs = [self.train_step, self.valid_step, self.test_step]
        step_funcs = {s:f for s,f in zip(self._sets, step_funcs)}
        step_func = step_funcs[state.phase]
        
        if step_func is None:
            raise AttributeError(f"{state.phase}_step not assigned")
        rdict = step_func(state)
        if isinstance(rdict,dict):
            state._append_batch_step(rdict)
        
    def __call_epoch_start_step(self, state:LoopState) -> None:
        """
        Run the optional epoch-start function for `state.phase`, if one is
        set, and record the dict it returns (non-dict returns are ignored).
        """
        hooks = dict(zip(
            self._sets,
            (self.train_epoch_start, self.valid_epoch_start, self.test_epoch_start),
        ))
        hook = hooks[state.phase]
        if hook is not None:
            result = hook(state)
            if isinstance(result, dict):
                state._append_epoch_start(result)
        
    def __call_epoch_end_step(self, state:LoopState) -> None:
        """
        Run the epoch-end function for `state.phase` and record the dict it
        returns (non-dict returns are ignored).

        Raises AttributeError if that phase's epoch-end function is None.
        """
        step_funcs = [self.train_epoch_end,self.valid_epoch_end,self.test_epoch_end]
        step_funcs = {s:f for s,f in zip(self._sets, step_funcs)}
        step_func = step_funcs[state.phase]
        
        if step_func is None:
            # Name the missing attribute exactly as it is stored on FitLoop.
            raise AttributeError(f"{state.phase}_epoch_end not assigned")
        rdict = step_func(state)
        if isinstance(rdict,dict):
            state._append_epoch_end(rdict)
        
    def __get_dl(self, is_test:bool)-> Dict[str,DataLoader]:
        """
        Phase name -> DataLoader for this run: only 'test' when `is_test`
        (AttributeError if test_dl is missing), otherwise 'train' followed by
        'valid' when a validation loader is set.
        """
        if not is_test:
            loaders = {self._TR: self.train_dl}
            if self.valid_dl is not None:
                loaders[self._VA] = self.valid_dl
            return loaders

        if self.test_dl is None:
            raise AttributeError("test_dl not assigned")
        return {self._TE: self.test_dl}
    
    def __profile_time(self,t1,t2,name):
        # TODO: record the (t2 - t1) span under `name` in the profiler.
        # Nothing is recorded yet; calls with a missing stamp return early.
        if any(stamp is None for stamp in (t1, t2)):
            return
    
    def __profile_other(self,val,name):
        """TODO: profiler hook for non-timing metrics; currently a no-op."""
        return None
        
        
    # ---------------------------------------------------------------------
    """
    SECTION: 2
    
    The main loop function __loop 
    """
    
    def __loop(self, 
            epochs:int=1,  print_every:int=1, 
            steps: Optional[int]=None, load_best:bool=False, 
            profiler:bool=False, is_test:bool=False,
            track_batch_metrics:bool=True, define_all:bool=False,
            continue_loop:int=0, no_print:bool=False, no_cast:bool=False,
            display_metrics:Optional[Union[str,List[str]]]=None, no_float:bool=False,
            is_sanity_check:bool=False 
           ) -> None:
        """
        Runs the loop for `epochs` over the phases given by `__get_dl`
        ('train' [+ 'valid'], or only 'test' when `is_test`).
        ----
        PARAMETERS
         - epochs : should be a non negative integer
         - print_every : if 0 nothing is printed; otherwise epoch metrics and
             timings are printed on the first epoch and every `print_every`-th epoch.
         - steps : if not None, at most `steps` batches are run per phase in
             each epoch (to check that everything is working).
         - load_best : whether to load the best model after the loop; only has an
             effect when `criteria` is set and a validation phase ran.
         - profiler : profiler mode; epochs are not counted, the best model is not
             tracked and the model is reloaded at the end. Timing collection is
             not implemented yet.
         - is_test : whether it is a model testing loop or training/validation loop
         - track_batch_metrics : currently ignored; batch-step values are always stored.
         - define_all : If True then `torch.set_grad_enabled`, `optimizer.zero_grad` and model mode 
             ie [train,eval] have to be called where required (usually in the `train_step` function).
         - continue_loop : currently ignored (not implemented).
         - no_print : If True all print statements of the loop are suppressed, e.g. when
             custom logging is done in the stage functions.
         - no_cast : True, if model and data casting has to be manually set in the stage functions
         - display_metrics : List of metrics returned in the epoch_end stage rdict that has to be 
             displayed, if None (default) all the returned metrics are displayed.
         - no_float : True don't apply float conversion to returned metrics.
         - is_sanity_check : sanity check mode; same bookkeeping as profiler mode
             (no epoch counting, no best-model tracking, model reloaded at the end).
        """
        # Printing (and the wall-clock stamps it needs) is enabled unless
        # print_every == 0 or no_print is set.
        show = print_every != 0 and not no_print

        def tpe():
            return time.time() if show else None

        def should_print(e:int) -> bool:
            # First epoch always, then every `print_every`-th epoch.
            return show and (e == 0 or (e + 1) % print_every == 0)

        total_time_start = tpe()
        
        # INITIALIZING VARIABLES -----
        # Only a plain training run counts epochs and tracks the best model.
        is_train = not (is_test or is_sanity_check or profiler)
        
        dl = self.__get_dl(is_test)
        phases = list(dl)
        state = {ph: LoopState(ph, self, no_cast, no_float, is_train, is_test) for ph in phases}

        # Snapshot the current weights as the "best" model. This is what is
        # restored at the end when nothing better was found, and what
        # profiler / sanity-check runs restore.
        self.__save_model(self._BS)

        # `model` must be bound even when casting is manual (no_cast=True),
        # because model.train()/model.eval() below use it.
        model = self.model
        if not no_cast:
            model = self.model.to(device=self.device, dtype=self.dtype)

        # THE LOOP - START -----------
        for e in range(epochs):
            epoch_time_start = tpe()
            
            # Update FitLoop epoch_num
            if is_train:
                self.epoch_num += 1
            
            for phase in phases:
                st = state[phase]
                is_tr = phase == self._TR
                
                # Update LoopState: batch_num, metrics['batch'], epoch_num
                st._pre_epoch_start_update(e)
                
                # EPOCH START STEP
                self.__call_epoch_start_step(st)
                
                if not define_all:
                    if is_tr:
                        model.train()
                    else:
                        model.eval()
                    
                # BATCH LOOP - START 
                for batch in dl[phase]:
                    # Stop once `steps` batches have run in this phase.
                    if steps is not None and st.batch_num >= steps:
                        break
                    
                    # Update LoopState: batch_num, batch and batch_size
                    st._pre_batch_step_update(batch)
                    
                    # BATCH STEP
                    if define_all:
                        self.__call_batch_step(st)
                    else:
                        if isinstance(self.optimizer, list):
                            for opt in self.optimizer:
                                opt.zero_grad()
                        else:
                            self.optimizer.zero_grad()
                        with torch.set_grad_enabled(is_tr):
                            self.__call_batch_step(st)
                # BATCH LOOP - END 
                
                # EPOCH END STEP
                self.__call_epoch_end_step(st)
                
                # UPDATE MARKERS: non-train phases of a training run only
                if is_train and not is_tr and self.criteria is not None:
                    score = st._get_epoch_metric(self.criteria)
                    if self.criteria_direction > 0:
                        is_better = score > self.best_score
                    else:
                        is_better = score < self.best_score
                    if is_better:
                        self.best_score = score
                        self.__save_model(self._BS)

                # Display the epoch metrics after the phase end
                if should_print(e):
                    print(phase, st._get_epoch_metrics(display_metrics))
            
            # Logging epoch times
            if should_print(e):
                print('epoch_time:', ftime(epoch_time_start, tpe()))

        # THE LOOP - END -------------

        if show:
            print('total_time:', ftime(total_time_start, tpe()))

        # Restore the best model only when one was actually tracked (criteria
        # set and a validation phase ran in a training run). Otherwise `_BS` is
        # still the snapshot saved at the start of this loop, and loading it
        # would throw away the training just done. Profiler and sanity-check
        # runs always restore that snapshot.
        best_was_tracked = is_train and self.criteria is not None and self._VA in dl
        if (load_best and best_was_tracked) or profiler or is_sanity_check:
            self.__load_model(self._BS)

        # __loop - END ---------------

    
    # ---------------------------------------------------------------------
    """
    SECTION: 3
    
    Loop methods that are called by the FitLoop user.
    """
    
    def fit(self, 
            epochs:int=1, print_every:int=1,
            display_metrics:Optional[Union[str,List[str]]]=None,
            track_batch_metrics:bool=True, load_best:bool=True,
            continue_loop:int=0, define_all:bool=False,  
            no_print:bool=False, no_cast:bool=False,
            no_float:bool=False
           ) -> None:
        """
        Train the model (validating too when `valid_dl` is set) for `epochs` epochs.
        ----
        PARAMETERS
         - epochs : non negative number of epochs to run
         - print_every : 0 disables printing, otherwise print at the given epoch interval
         - display_metrics : names of epoch_end rdict metrics to display; None (default)
             displays every returned metric.
         - track_batch_metrics : whether to store the values returned in the batch steps
         - load_best : load the best model after training; needs the validation setup
             (`valid_dl`, `valid_step`, `valid_epoch_end`).
         - continue_loop : ask whether to continue training after this many epochs;
             should be a positive integer.
         - define_all : if True, `torch.set_grad_enabled`, `optimizer.zero_grad` and the
             model train/eval mode must be handled by the stage functions (usually `train_step`).
         - no_print : if True, suppress print statements (e.g. when the stage functions log).
         - no_cast : if True, model and data casting is done manually in the stage functions.
         - no_float : if True, returned metrics are not converted to float.
        """
        loop_kwargs = {
            'epochs': epochs,
            'print_every': print_every,
            'display_metrics': display_metrics,
            'track_batch_metrics': track_batch_metrics,
            'load_best': load_best,
            'continue_loop': continue_loop,
            'define_all': define_all,
            'no_print': no_print,
            'no_cast': no_cast,
            'no_float': no_float,
        }
        self.__loop(**loop_kwargs)
    
    def train(self, *args, **kwargs):
        """
        Alias of `FitLoop.fit`; all arguments are forwarded unchanged.
        """
        return self.fit(*args, **kwargs)
    
    def test(self, no_print:bool=False, no_cast:bool=False, no_float:bool=False) -> None:
        """
        Evaluate the model once over the test DataLoader using the test stage
        functions (a single-epoch loop).
        ----
        PARAMETERS
         - no_print : if True, suppress print statements (e.g. when the stage functions log).
         - no_cast : if True, model and data casting is done manually in the stage functions.
         - no_float : if True, returned metrics are not converted to float.
        """
        self.__loop(
            is_test=True,
            no_print=no_print,
            no_cast=no_cast,
            no_float=no_float,
        )
    
        
    def run_profiler(self,
            steps: Optional[int]=None, define_all:bool=False,
            no_cast:bool=False, no_float:bool=False
            ) -> None:
        """
        Profiler mode. NOT IMPLEMENTED YET: it currently only prints
        "NOT IMPLEMENTED" and returns None without touching the model.

        Intended behaviour: run every configured stage (train, valid, test) for a
        single epoch over `steps` batches, and profile time taken and resource
        usage per stage. The model state should be reloaded afterwards unless the
        run is interrupted.
        ----
        PARAMETERS
         - steps : batches to iterate over in each phase [train,valid,test]; None
             means all batches.
         - define_all : If True then `torch.set_grad_enabled`, `optimizer.zero_grad` and
             the model mode (train/eval) have to be called where required (usually in
             the `train_step` function).
         - no_cast : True if model and data casting is done manually in the stage functions.
         - no_float : True to skip float conversion of the returned metrics.
        """
        # TODO : Implement profiler mode in __loop; until then this is a no-op stub.
        print("NOT IMPLEMENTED")
        return

        # Planned implementation (unreachable until the TODO above is resolved).
        shared = dict(steps=steps, define_all=define_all, no_cast=no_cast,
                      no_float=no_float, no_print=True, profiler=True)
        self.__loop(**shared)
        self.__loop(is_test=True, **shared)
    
    def sanity_check(self, use_test_dl=False,
            epochs:int=1, steps:int=5, print_every:int=1,
            display_metrics:Optional[Union[str,List[str]]]=None,
            continue_loop:int=0, define_all:bool=False,
            no_print:bool=False, no_cast:bool=False,
            no_float:bool=False
           ) -> None:
        """
        Sanity-check mode. NOT IMPLEMENTED YET: it currently only prints
        "NOT IMPLEMENTED" and returns None without touching the model.

        Intended behaviour: run every configured stage (train, valid, test) for
        `epochs` epochs of `steps` batches each. The model state should be reloaded
        afterwards unless the run is interrupted.
        ----
        PARAMETERS
         - use_test_dl : If False the validation DataLoader is used for the test phase,
             otherwise the test DataLoader.
         - epochs : should be a non negative integer.
         - steps : number of batches to run in each phase, to check that everything works.
         - print_every : if 0 nothing is printed, else stats are printed every `print_every` epochs.
         - display_metrics : metrics (keys of the dict returned by the epoch_end stage) to
             display; if None (default) all returned metrics are displayed.
         - continue_loop : asks whether to continue training after every `continue_loop`
             epochs; should be a positive integer.
         - define_all : If True then `torch.set_grad_enabled`, `optimizer.zero_grad` and
             the model mode (train/eval) have to be called where required (usually in
             the `train_step` function).
         - no_print : True suppresses all print statements (useful when the stage
             functions do their own logging).
         - no_cast : True if model and data casting is done manually in the stage functions.
         - no_float : True to skip float conversion of the returned metrics.
        """
        # TODO : Implement sanity-check mode in __loop; until then this is a no-op stub.
        print("NOT IMPLEMENTED")
        return

        # Planned implementation (unreachable until the TODO above is resolved).
        common = dict(epochs=epochs, steps=steps, print_every=print_every,
                      display_metrics=display_metrics, continue_loop=continue_loop,
                      define_all=define_all, no_print=no_print, no_cast=no_cast,
                      no_float=no_float, is_sanity_check=True)
        self.__loop(**common)
        self.__loop(use_test_dl=use_test_dl, is_test=True, **common)
    
    
    # ---------------------------------------------------------------------
    """
    SECTION: 4
    
    Functions to preserve the model state.
    """
    
    def __save_model(self, typ:str) -> None:
        """
        Store the current model weights in one of two slots.

        `typ == self._BS` selects the best-model slot; any other value selects the
        pretrained slot.
        - With `save_to_disk`, the state dict is written with `torch.save` to
          `save_path / <slot file name>`.
        - Otherwise a deep copy is kept on the object (`best_model_state_dict` or
          `pretrained_model_state_dict`).
        """
        to_best_slot = typ == self._BS
        file_name = self.best_model_name if to_best_slot else self.pretrained_model_name
        # Note: the target path is built even for in-memory saves.
        target = self.save_path / file_name
        weights = self.model.state_dict()
        if self.save_to_disk:
            torch.save(weights, target)
            return
        # deepcopy so that later training does not mutate the stored snapshot.
        snapshot = deepcopy(weights)
        if to_best_slot:
            self.best_model_state_dict = snapshot
        else:
            self.pretrained_model_state_dict = snapshot
        
    def __load_model(self, typ:str):
        """
        Restore model weights, then re-create the optimizer parameter groups.

        `typ == self._BS` restores the best-model slot; any other value restores
        the pretrained slot.
        - With `save_to_disk`, the state dict is read with `torch.load` from
          `save_path / <slot file name>` and mapped to `self.device`.
        - Otherwise the in-memory copy is used.

        If no `configure_optimizer` hook is set, a message asks the user to
        reconfigure the optimizer manually.
        """
        from_best_slot = typ == self._BS
        file_name = self.best_model_name if from_best_slot else self.pretrained_model_name
        source = self.save_path / file_name
        if self.save_to_disk:
            weights = torch.load(source, map_location=self.device)
        else:
            weights = (self.best_model_state_dict if from_best_slot
                       else self.pretrained_model_state_dict)
        self.model.load_state_dict(weights)

        # `self` is passed explicitly: the hook is presumably stored as a plain
        # (unbound) function on the instance -- confirm against __init__.
        if self.configure_optimizer is not None:
            self.configure_optimizer(self)
        else:
            print("please reconfigure FitLoop.optimizer before training")
    
    def reset(self, reset_model:bool=True) -> None:
        """
        Resets FitLoop to its initial state.
        Parameters reset:
            - model, to the pretrained weights if `reset_model`. This also re-runs
              `configure_optimizer` when one was given.
            - epoch_num, to 0.
            - best_score, to `criteria_direction * -inf`: -inf for a positive
              direction, +inf for a negative one.
        If no `configure_optimizer` was given, FitLoop.optimizer param groups
        will have to be set again.
        """
        if reset_model:
            # `self` is bound implicitly; passing it again would raise TypeError.
            self.__load_model(self._PR)
        self.epoch_num = 0
        self.best_score = self.criteria_direction * float('-inf')
        
        
    # ---------------------------------------------------------------------
    """
    SECTION: 5
    
    Functions to preserve the FitLoop object state so that training can be resumed.
    """
    
    def save(self, path, only_model=False):
        """
        Placeholder (does nothing, returns None).

        TODO : persist the whole FitLoop state to `path`, or only the model
        weights when `only_model` is True.
        """
    
    def load(self, path):
        """
        Placeholder (does nothing, returns None).

        TODO : restore the FitLoop state from `path`; if only the model was
        saved, load just the model state dict.
        """
    
    
    # ---------------------------------------------------------------------
    """
    SECTION: 6
    
    Functions to delete stored model weights.
    """
    
    def del_pretrained(self) -> None:
        """
        Delete the stored pretrained model weights.

        - With `save_to_disk`, the file `save_path / pretrained_model_name` is
          removed; `Path.unlink` raises if it does not exist.
        - Otherwise `pretrained_model_state_dict` is set to None.
        """
        if not self.save_to_disk:
            self.pretrained_model_state_dict = None
            return
        weights_file = self.save_path / self.pretrained_model_name
        weights_file.unlink()
        
    def del_best_model(self) -> None:
        """
        Delete the stored best-model weights.

        - With `save_to_disk`, the file `save_path / best_model_name` is removed;
          `Path.unlink` raises if it does not exist.
        - Otherwise `best_model_state_dict` is set to None.
        """
        if not self.save_to_disk:
            self.best_model_state_dict = None
            return
        weights_file = self.save_path / self.best_model_name
        weights_file.unlink()


## Testing

### Model Setup

In [200]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

# The test cells below run everything on the CPU.
device = torch.device("cpu")

In [271]:
"""
Basic setup with FakeData for testing.
"""

sets = ['train', 'valid', 'test']
TR, VA, TE = sets
class_names = ['a', 'b', 'c', 'd']
num_classes = len(class_names)  # 4
batch_size = 4

# Split sizes in samples: 5, 2 and 3 batches, so each split is a whole number of batches.
sz = dict(zip(sets, (n_batches * batch_size for n_batches in (5, 2, 3))))
ds = {split: FakeData(size=sz[split], transform=ToTensor(), num_classes=num_classes)
      for split in sets}
dl = {split: DataLoader(dataset, batch_size=batch_size) for split, dataset in ds.items()}

# ResNet-18 with its final linear layer replaced to output `num_classes` logits.
model = resnet18()
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())



#### Test Calculations

In [112]:
print("get batch")
%time X,y = next(iter(dl[TR]))
print(X.size(),y.size())

print("\ninference")
%time y_ = model(X)
print(y_.size())

print("\nloss")
%time loss = loss_function(y_,y)
print(loss)

print("\naccu")
accu = (y_.argmax(dim=0)==y).sum().item()
print(accu)

get batch
CPU times: user 18.3 ms, sys: 7.25 ms, total: 25.5 ms
Wall time: 34.6 ms
torch.Size([4, 3, 224, 224]) torch.Size([4])

inference
CPU times: user 335 ms, sys: 98.4 ms, total: 434 ms
Wall time: 609 ms
torch.Size([4, 4])

loss
CPU times: user 1.06 ms, sys: 5.85 ms, total: 6.91 ms
Wall time: 7.97 ms
tensor(1.4268, grad_fn=<NllLossBackward>)

accu
0


### Stage Functions

#### Batch Steps

In [216]:
def common_step(state):
    """
    Shared forward pass used by the train/valid/test batch steps.

    Prints a progress line, runs `state.model` on `state.batch`, and scores the
    output with `state.loss_function`.

    Returns (loss, r_loss, r_corr):
      - loss   : the loss tensor, kept so the caller can back-propagate it
      - r_loss : loss.item() multiplied by `state.batch_size`
      - r_corr : count of samples whose argmax over dim 1 equals the target
    """
    print(f"{state.phase}_step, bn: {state.batch_num} en: {state.epoch_num}, ", end="")
    inputs, targets = state.batch
    preds = state.model(inputs)
    loss = state.loss_function(preds, targets)
    loss_value = loss.item()
    print("loss", loss_value)
    scaled_loss = loss_value * state.batch_size
    n_correct = (preds.argmax(dim=1) == targets).sum().item()
    return loss, scaled_loss, n_correct

def train_step(state):
    """
    Training batch step: forward pass, back-propagation and one optimizer step.

    Gradients are not zeroed here. Per the `define_all` docs, FitLoop calls
    `optimizer.zero_grad` itself unless `define_all` is set.
    Returns the per-batch metrics {'r_loss', 'r_corr'}.
    """
    loss, running_loss, running_corr = common_step(state)
    loss.backward()
    state.optimizer.step()
    return dict(r_loss=running_loss, r_corr=running_corr)

def valid_step(state):
    """Validation batch step: forward pass only; returns {'r_loss', 'r_corr'}."""
    _, running_loss, running_corr = common_step(state)
    return dict(r_loss=running_loss, r_corr=running_corr)

def test_step(state):
    """Test batch step: forward pass only; returns {'r_loss', 'r_corr'}."""
    _, running_loss, running_corr = common_step(state)
    return dict(r_loss=running_loss, r_corr=running_corr)

#### Epoch Start Step

In [241]:
def common_epoch_start(state):
    """Print an epoch-start banner for `state.phase` and return a placeholder dict."""
    banner = f"\n{state.phase}_epoch_start, # {state.epoch_num}"
    print(banner)
    return dict(dummy='dict')
    
def train_epoch_start(state):
    """Train-phase epoch start hook; delegates to `common_epoch_start`."""
    result = common_epoch_start(state)
    return result

def valid_epoch_start(state):
    """Valid-phase epoch start hook; delegates to `common_epoch_start`."""
    result = common_epoch_start(state)
    return result

def test_epoch_start(state):
    """Test-phase epoch start hook; delegates to `common_epoch_start`."""
    result = common_epoch_start(state)
    return result

#### Epoch End Steps

In [231]:
def common_epoch_end(state):
    """
    Shared epoch-end hook: aggregate the per-batch metrics into epoch metrics.

    Reads the collected 'r_loss' and 'r_corr' values from `state`. It calls
    `.sum()` on them, so they are presumably tensors built by the loop from the
    batch-step returns (confirm in LoopState). Both sums are divided by
    `state.size`.
    Prints intermediate values and returns {'loss': ..., 'accu': ...}.
    """
    print(f"{state.phase}_epoch_end, # {state.epoch_num}")
    losses, corrects = state['r_loss'], state['r_corr']

    for label, values in (('r_loss', losses), ('r_corr', corrects)):
        print(label + ' len', len(values))

    n_samples = state.size
    epoch_loss = losses.sum() / n_samples
    epoch_accu = corrects.sum() / n_samples

    print('loss', epoch_loss)
    print('accu', epoch_accu)

    return dict(loss=epoch_loss, accu=epoch_accu)
    
def train_epoch_end(state):
    """Train-phase epoch end hook; delegates to `common_epoch_end`."""
    metrics = common_epoch_end(state)
    return metrics

def valid_epoch_end(state):
    """Valid-phase epoch end hook; delegates to `common_epoch_end`."""
    metrics = common_epoch_end(state)
    return metrics

def test_epoch_end(state):
    """Test-phase epoch end hook; delegates to `common_epoch_end`."""
    metrics = common_epoch_end(state)
    return metrics

### FitLoop - Usage

#### Setup

In [272]:
def configure_optimizer(self):
    """
    Point `self.optimizer` at the current model parameters.

    All existing parameter groups are dropped and replaced by a single group
    holding `self.model.parameters()`.
    """
    groups = self.optimizer.param_groups
    groups.clear()
    self.optimizer.add_param_group({'params': self.model.parameters()})
    
# Keyword arguments for FitLoop:
# - the model, optimizer and loss function;
# - one DataLoader per split;
# - the stage hooks defined above;
# - `criteria`, which presumably names the epoch-end metric used to rank epochs.
fl_dict = dict(
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    train_dl=dl[TR],
    valid_dl=dl[VA],
    test_dl=dl[TE],
    train_step=train_step,
    valid_step=valid_step,
    test_step=test_step,
    train_epoch_start=train_epoch_start,
    valid_epoch_start=valid_epoch_start,
    test_epoch_start=test_epoch_start,
    train_epoch_end=train_epoch_end,
    valid_epoch_end=valid_epoch_end,
    test_epoch_end=test_epoch_end,
    configure_optimizer=configure_optimizer,
    criteria="accu",
)

trainer = FitLoop(**fl_dict)

#### Execution

In [275]:
# Train for 3 epochs using the stage functions registered in fl_dict.
trainer.fit(epochs=3)


train_epoch_start, # 0
train_step, bn: 1 en: 0, loss 1.457542896270752
train_step, bn: 2 en: 0, loss 1.6427596807479858
train_step, bn: 3 en: 0, loss 1.4610310792922974
train_step, bn: 4 en: 0, loss 1.7224218845367432
train_step, bn: 5 en: 0, loss 1.293718695640564
train_epoch_end, # 0
r_loss len 5
r_corr len 5
loss tensor(1.5155)
accu tensor(0.2500)
train {'loss': 1.5154948234558105, 'accu': 0.25}

valid_epoch_start, # 0
valid_step, bn: 1 en: 0, loss 2.981490135192871
valid_step, bn: 2 en: 0, loss 4.458968162536621
valid_epoch_end, # 0
r_loss len 2
r_corr len 2
loss tensor(3.7202)
accu tensor(0.1250)
valid {'loss': 3.720229148864746, 'accu': 0.125}
epoch_time:  5 s 377 ms         

train_epoch_start, # 1
train_step, bn: 1 en: 1, loss 0.15699230134487152
train_step, bn: 2 en: 1, loss 0.9917818903923035
train_step, bn: 3 en: 1, loss 0.024288427084684372
train_step, bn: 4 en: 1, loss 1.3614912033081055
train_step, bn: 5 en: 1, loss 0.00760624697431922
train_epoch_end, # 1
r_loss len 5
r

## Portions of the loop

- **Single step** Portion of the loop that receives the batch and calls the forward pass on it.
    - *Set* zero grad
    - **Train Step** 
        - *Set* 
            - enable gradients
            - model.train
        - *Define*: 
            - compute loss
            - call backward 
            - update the parameters (optimizer step) and optionally step the scheduler
            - loss and whatever metrics will be saved in a list.
    - **Valid Step**
        - *Set*
            - disable gradients
            - model.eval
        - *Define*: 
            - compute loss
            - loss and whatever metrics will be saved in a list.
    - **Test Step** Only in test loop
        - *Set*
            - disable gradients
            - model.eval
        - *Define*: 
            - compute loss
            - loss and whatever metrics will be saved in a list.
- **Pre Epoch**
- **Post Epoch**

## Reference

In [None]:
# Function to calculate a time string from
# 2 instances of time.time
def ftime(t1,t2):
    """
    Format the time elapsed between two `time.time()` readings.

    Returns e.g. ' 5 s 377 ms', left-justified to 20 characters.
    - Units are hours, minutes, seconds and milliseconds; zero-valued units are
      omitted, so sub-millisecond spans give 20 spaces.
    - Minutes and seconds are right-justified to width 2.
    - Milliseconds are truncated and zero-padded to 3 digits.
    """
    t = t2 - t1
    # Use fixed-point formatting rather than str(t). str() can produce exponent
    # notation ('3e-05'), which breaks the split. It also drops trailing zeros:
    # 1.05 s became '05' ms instead of '050' ms.
    whole, frac = f"{t:.6f}".split('.')
    secs = int(whole)
    ms = frac[:3]
    s = str(secs % 60).rjust(2)
    mins = secs // 60
    h = str(mins // 60)
    m = str(mins % 60).rjust(2)
    u = ['h','m','s','ms']
    v = [h,m,s,ms]
    parts = filter(lambda x: int(x[0]) > 0, zip(v, u))
    return ' '.join([' '.join(x) for x in parts]).ljust(20)

In [None]:
def fit(model, dl_train, dl_valid, optim, loss_function, epochs, sched=None, should_pass=False, print_every=1, load_best=True):
    """
    Reference training loop: every epoch runs a training phase followed by a
    validation phase.

    Args:
    # model
      - pytorch model to be trained; it and every batch are moved to the
        module-level `device`.

    # dl_train
      - DataLoader for the train set

    # dl_valid
      - DataLoader for the valid set

    # optim
      - optimizer attached to the model params (from torch.optim)

    # loss_function
      - function to calculate the loss; loss_function(model(X), y)

    # epochs
      - number of epochs to train for

    # sched
      - scheduler for the optimizer learning rate; stepped once per epoch,
        after the validation phase.

    # should_pass
      - if True the epoch's validation loss is passed to `sched.step`
        (e.g. for ReduceLROnPlateau).

    # print_every
      - print the stats every n epochs; 0 disables printing.

    # load_best
      - whether to load the weights with the best validation accuracy at the end

    Returns:
      dict with
      - "losses_epo" and "accuracy_epo": per-epoch values for each phase;
      - "losses_all": per-batch loss values for each phase.
    """

    total_time_start = time.time()
    # ----------------------------
    phases = [TR, VA]
    dl = {TR: dl_train, VA: dl_valid}
    sizes = {k: len(dl[k].dataset) for k in dl}

    # To contain metrics to plot later
    losses_all = {TR: [], VA: []}
    losses_epo = {TR: [], VA: []}
    accuracy_epo = {TR: [], VA: []}

    # Markers
    least_loss = float('inf')
    best_accuracy = 0
    best_model = deepcopy(model.state_dict())

    # ----------------------------
    # Convenience functions (that won't be used elsewhere)

    # Function to get formatted epochs (from 1 not 0)
    r_just_val = len(str(epochs)) * 2 + 3
    estr = lambda e: f"[{e + 1}/{epochs}]".rjust(r_just_val)

    # Print only every `print_every` epochs. A falsy `print_every` (0/None)
    # disables printing instead of raising ZeroDivisionError in the modulo.
    def eprint(e, st):
        if print_every and (e + 1) % print_every == 0:
            print(st, end="")

    def grad_times(t):
        t = t * 1000
        return f"{t:0.3f} ms".rjust(10)

    # Convenience function for phase strings.
    def statstr(phase, loss, accu, elapsed, infr, rjust=True):
        infr = grad_times(infr)
        st = f"{phase} :: loss: {loss:0.4f} | accu: {accu:0.4f} | infr: {infr}  | time: {elapsed} \n"
        if rjust:
            return st.rjust(r_just_val + len(st) + 3)
        return st

    # ----------------------------
    # Training
    model.to(device)
    for e in range(epochs):
        epoch_time_start = time.time()

        # Training or validation phase
        for phase in phases:
            phase_time_start = time.time()

            # Running metrics
            running_loss = 0
            running_corr = 0

            is_tr = phase == TR
            if is_tr:
                model.train()
            else:
                model.eval()

            # Per-sample inference times of this phase.
            inference_times = []
            if is_tr:
                eprint(e, estr(e) + " - ")
            for X, y in dl[phase]:
                bs = y.size(0)
                X = X.to(device)
                y = y.to(device)

                optim.zero_grad()

                with torch.set_grad_enabled(is_tr):
                    # Forward pass and log times
                    inference_time_start = time.time()
                    y_ = model(X)
                    inference_time_end = time.time()
                    inference_times.append((inference_time_end - inference_time_start) / bs)

                    loss = loss_function(y_, y)
                    if is_tr:
                        # Backprop (grad calc and param update)
                        loss.backward()
                        optim.step()

                batch_loss = loss.item()
                running_loss += batch_loss * bs
                running_corr += (y_.argmax(dim=1) == y).sum().item()
                losses_all[phase].append(batch_loss)

            # Calculate phase losses and scores
            phase_loss = running_loss / sizes[phase]
            phase_accu = running_corr / sizes[phase]

            # Mean per-sample inference time (NaN if the loader yielded no batches).
            # Computed without numpy, which this file never imports.
            mean_inf = sum(inference_times) / len(inference_times) if inference_times else float('nan')

            # Log scores and losses
            losses_epo[phase].append(phase_loss)
            accuracy_epo[phase].append(phase_accu)

            if not is_tr:
                # Update markers, save best model
                if phase_accu > best_accuracy:
                    best_accuracy = phase_accu
                    best_model = deepcopy(model.state_dict())
                if phase_loss < least_loss:
                    least_loss = phase_loss

                # Step the scheduler once per epoch, after validation.
                # It used to be stepped once per validation *batch*, with that
                # batch's loss.
                if sched is not None:
                    if should_pass:
                        sched.step(phase_loss)
                    else:
                        sched.step()

            # Logging phase time
            phase_time_end = time.time()
            phase_time = ftime(phase_time_start, phase_time_end)
            eprint(e, statstr(phase, phase_loss, phase_accu, phase_time, mean_inf, rjust=not is_tr))

        # Logging epoch times
        epoch_time_end = time.time()
        epoch_time = f"epoch time: {ftime(epoch_time_start, epoch_time_end)}\n"
        eprint(e, epoch_time.rjust(len(epoch_time) + r_just_val + 3) + "\n")

    if load_best:
        model.load_state_dict(best_model)

    # ----------------------------
    total_time_end = time.time()
    total_time = ftime(total_time_start, total_time_end)
    print(f"least loss: {least_loss:0.4f} | best accu: {best_accuracy:0.4f} | total time taken: {total_time}")

    return {"losses_epo": losses_epo, "accuracy_epo": accuracy_epo, "losses_all": losses_all}