# Testing Grounds - Fit Loop
A library that takes care of the boilerplate in the PyTorch training loop.

In [1]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

In [33]:
"""
Basic setup with FakeData for testing.
"""

# Phase names; TR / VA / TE are shorthand keys reused by later cells.
sets = ['train', 'valid', 'test']
TR, VA, TE = sets
class_names = ['a', 'b', 'c', 'd']
num_classes = 4

# Per-phase dataset size, fake dataset and DataLoader, all keyed by phase name.
sz = dict(zip(sets, [1024, 128, 256]))
ds = {
    name: FakeData(size=sz[name], transform=ToTensor(), num_classes=num_classes)
    for name in sets
}
dl = {name: DataLoader(data, batch_size=16) for name, data in ds.items()}

# ResNet-18 whose final fully connected layer is replaced so that it
# outputs `num_classes` logits.
model = resnet18()
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

"""
Batch for testing purposes
"""
X, y = next(iter(dl[TR]))

In [223]:
import time
import math

from uuid import uuid4
from pathlib import Path
from copy import deepcopy
from typing import Union, List, Callable, Optional, Any, Dict, Tuple

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

In [5]:
# Module-level device. The reference `fit` function at the end of this file
# reads this global (`model.to(device)`, `X.to(device)`) instead of taking a
# device parameter.
device = torch.device('cpu')

In [22]:
class A():
    """
    Scratch wrapper used to try out attribute delegation.

    Attributes that are not found on the wrapper itself are looked up on
    the wrapped object `f`; square bracket indexing simply echoes the key.
    """
    def __init__(self, f):
        # Name-mangled to `_A__f`, so it cannot clash with attributes of `f`.
        self.__f = f
        self.a = 22
    
    def __getitem__(self, name):
        return name
        
    def __getattr__(self,name):
        # __getattr__ only runs after normal lookup has failed. If the wrapped
        # object itself is missing (instance created through __new__, as
        # copy/pickle do), reading `self.__f` would re-enter this method
        # forever, so fetch it without going through __getattr__.
        try:
            wrapped = self.__dict__['_A__f']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped,name)
class B:
    """Scratch class exposing a single attribute, `b`, set to 33."""

    def __init__(self):
        self.b = 33

LoopState
- ✅should cast the batch to device before passing it using `state.batch()`
- ✅should get the batch num `state.batch_num` and epoch num `state.epoch_num`
- ✅the model, optimizer, loss_function, lr_scheduler should be available
    `state.model`, `state.optimizer`, `state.loss_function`, `state.lr_scheduler`
- ✅should return the batch metrics as float tensors using square bracket indexing
    `state['loss']` 
    - every step function hook receives the LoopState object.
- ✅The loop state object should have a copy of all the values returned from the function hook
- ✅Example: the dict values returned below should be available in the LoopState object

```python
def train_step(state):
    X,y = state.batch # should device cast automatically
    y_ = state.model(X)
    loss = state.loss_function(y_, y)
    
    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()
    state.lr_scheduler.step() 
    
    loss = loss.item()
    batch_loss = loss * y.size(0)
    batch_corr = (y_.argmax(dim=1) == y).sum().float().item()
    
    return {'loss':loss,'batch_loss':batch_loss,'batch_corr':batch_corr}
```
- The LoopState object should be cleared of the above values at the start 
  of the next epoch.
- The returned values should be available through the FitLoop object
  Eg: `FitLoop.train.batch['loss']`
- The (above statement) returned value should be optionally available by setting the flag 
  `store_batch_metrics`

```python
def train_epoch_end(state):
    loss = state['loss']
    batch_loss = state['batch_loss']
    batch_corr = state['batch_corr']
    
    size = state.size
    
    epoch_loss = batch_loss.sum().item()/size
    epoch_accu = batch_corr.sum().item()/size
    
    return {"loss":epoch_loss,"accu":epoch_accu}
```

- The returned values should be available through the FitLoop object
  Eg: `FitLoop.train.epoch['loss']`
- ✅For each phase a different LoopState object is maintained.



In [228]:
class FitLoop:
    """Dummy placeholder so that `FitLoop` can be named before the real class is defined."""

class LoopState:
    """
    Maintains train/valid/test loop state for a single run of 
    a certain number of epochs; it is not used to preserve state 
    between runs.
    
    Attribute lookups that fail on the LoopState itself are forwarded
    to the FitLoop object passed to the constructor, which is how 
    `state.model`, `state.optimizer`, `state.device` etc. are resolved.
    Square bracket indexing returns the values stored for the batch stage,
    eg: `state['loss']`.
    """
    _batch = 'batch'
    _epoch = 'epoch'
    def __init__(self, phase:str, floop:'FitLoop', no_cast:bool):
        """
        phase : phase name 'train', 'valid' or 'test'
        floop : the calling FitLoop object, must have a `{phase}_dl` attribute
        no_cast : if True `batch` returns the batch exactly as it was set,
            else every tensor is cast to the FitLoop device and dtype.
        """
        self.metrics = {}
        self.__batch = ()
        self.__floop = floop
        self.phase = phase
        self.batch_num = 0
        self.epoch_num = 0
        self.no_cast = no_cast
        
        # For easy access
        dl = getattr(floop, f'{phase}_dl')
        bs = dl.batch_size
        dr = dl.drop_last
        sz = len(dl.dataset)
        bt = sz / bs
        
        # Gives dataset size and batch count
        self.size = sz
        self.batches = math.floor(bt) if dr else math.ceil(bt)
    
    def __getattr__(self, name:str) -> Any:
        # Only reached when normal lookup fails. The FitLoop reference is 
        # read straight from __dict__: going through `self.__floop` on an 
        # instance where it is not set (eg: one created by copy/pickle) 
        # would re-enter __getattr__ without end.
        try:
            floop = self.__dict__['_LoopState__floop']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(floop, name)
    
    @property
    def batch(self) -> Tuple[Tensor,...]:
        """
        The current batch. Unless `no_cast` is set, floating point tensors 
        are cast to (device, dtype) of the FitLoop and all other tensors 
        to (device, torch.long).
        """
        if self.no_cast:
            return self.__batch
        
        # A tuple, not a generator, so that the batch can be read more than once.
        return tuple(
            d.to(device=self.device,dtype=self.dtype) 
            if d.is_floating_point() 
            else d.to(device=self.device,dtype=torch.long) 
            for d in self.__batch
        )
    
    @batch.setter
    def batch(self, current_batch:Tuple[Tensor,...]) -> None:
        self.__batch = current_batch
        
    def _get_score(self, criteria:str) -> float:
        # Last added value of the epoch metric `criteria`, it is used as the 
        # model selection criteria. Raises KeyError if no such epoch metric
        # has been appended.
        return float(self.metrics[self._epoch][criteria][-1])
    
    def _append(self, rdict:Dict[str, float], stage:str) -> None:
        #  Append metrics to the specific stage.
        stage_metrics = self.metrics.setdefault(stage, {})
        for key, value in rdict.items():
            stage_metrics.setdefault(key, []).append(value)
    
    def _clear(self, stage:str) -> None:
        # Empty every metric list of the stage, the metric names are kept.
        # A stage that has nothing appended yet is left alone.
        for values in self.metrics.get(stage, {}).values():
            values.clear()
            
    def _append_batch(self, rdict:Dict[str, float]) -> None:
        self._append(rdict, self._batch)
        
    def _append_epoch(self, rdict:Dict[str, float]) -> None:
        self._append(rdict, self._epoch)
            
    def _clear_batch(self) -> None:
        self._clear(self._batch)
        
    def _clear_epoch(self) -> None:
        self._clear(self._epoch)
    
    def __getitem__(self, metric_name:str):
        """
        Values appended for the batch stage metric `metric_name`, as a float 
        tensor if they can be converted to one, else as the stored list.
        Raises KeyError if the metric has not been appended.
        """
        metric_value = self.metrics[self._batch][metric_name]
        try:
            return torch.tensor(metric_value).float()
        except (TypeError, ValueError, RuntimeError):
            # Values that torch.tensor can't convert are returned as they are.
            return metric_value
        

FitLoop
- ✅if `FitLoop.fit(define_all:bool=False)` the zero_grad and the context manager are not auto set.
- ✅should keep track of epochs that have been completed
- ✅epoch_number can be reset 
- metrics can be cleared
- `FitLoop.set_name.loop_stage['metric_name']` to access the metric
- `FitLoop.store_pretrained:bool` arg to store the pretrained weights before training
    if path then store at given path else store in memory.
- ✅`FitLoop.reset(reset_model:bool)` to clear metrics, epoch_num and to reset the model, to pretrained state
    will load the weight from passed path else from memory.
- `FitLoop.save(path:str)` to save the model and training state somehow even the fitloop state.
- `FitLoop.load(path:str)` to load the FitLoop state from given path.
- Some basic step should be used such that one can use it without defining step functions.
- `FitLoop.fit(continue_loop:int=0)` ask after `int` whether to continue training or to end.
- `FitLoop.fit(profiler:bool=False)` mode to capture all stage timings and maybe even CPU, GPU, RAM usage to check for bottlenecks and usage spikes, to be used with timed_test.
- Make it easy to train,validate, test with some other DataLoader that is not attached to the object.
- `FitLoop.fit(no_print:bool=False)` mode to suppress all print statements, to be used when custom logging is done in the stage functions.
- Use a loading bar for epoch and a disappearing one for batch.
- Functionality to view the metrics.
- ✅Model score should be a loop instance so that the best model may not be erased.

In [None]:
class FitLoop:
    """
    Runs the train/valid/test loops around user defined stage functions.
    
    Every stage function receives the LoopState object of its phase. The 
    dicts returned by the batch steps and the epoch end steps are appended 
    to that LoopState.
    """
    sets = ['train','valid','test']
    TR, VA, TE = sets
    
    model_typ = ['pretrained','best']
    PR, BS = model_typ
    def __init__(self, 
                 # Basic blocks
                 model: Module, optimizer: Union[Optimizer,List[Optimizer]], 
                 loss_function: Callable[[Tensor,Tensor],Tensor], 
                 
                 # DataLoader
                 train_dl: DataLoader, 
                 valid_dl: Optional[DataLoader]=None, 
                 test_dl: Optional[DataLoader]=None, 
                 
                 # Batch Step
                 train_step: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 valid_step: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 test_step: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 
                 # Epoch Start Step
                 train_epoch_start: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 valid_epoch_start: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 test_epoch_start: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 
                 # Epoch End Step
                 train_epoch_end: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 valid_epoch_end: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 test_epoch_end: Optional[Callable[['LoopState'],Dict[str, Any]]]=None,
                 
                 # Other Args
                 lr_scheduler: Optional[Union[_LRScheduler, List[_LRScheduler], Any]]=None,
                 device: torch.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 
                 dtype: torch.dtype=torch.float32,
                 criteria_direction: int=1,
                 
                 # Preserving Model State
                 save_to_disk: bool=False,
                 save_path: str="models",
                 pretrained_name: Optional[str]=None,
                 best_model_name: Optional[str]=None,
                ) -> None:
        """
        FitLoop constructor
        ----
        Parameters:
        # Basic Blocks
            The bare minimum required along with train_dl, train_step and train_epoch_end.
            - model : nn.Module model that has to be trained
            - optimizer : an optimizer from torch.optim, or a list of them
            - loss_function : function to compute loss
         
        # DataLoader
            - train_dl : training DataLoader
            - valid_dl : validation DataLoader, if None validation will be ignored
            - test_dl : testing DataLoader, if None `.test()` will not run
         
        # Batch Step
            Functions that take in a LoopState object to perform 
            required calculations, functions should return a dict with values
            to be used in the epoch end step.
            All of them default to None so that the argument order stays valid; 
            a missing step raises AttributeError when the loop reaches it.
            - train_step : portion of the loop where forward and backward 
                passes take place.
            - valid_step : validation portion of the loop.
            - test_step : called when `FitLoop.test()` is called.
        
        # Epoch Start Step
            Called with the LoopState object before the batches of a phase,
            skipped if None. The returned value is not stored.
            - train_epoch_start :
            - valid_epoch_start :
            - test_epoch_start : 
        
        # Epoch End Step
            Functions that take in a LoopState object to perform 
            required calculations, functions should return a dict with values
            that are to be returned when the loop is over.
            - train_epoch_end : after training epoch has ended.
            - valid_epoch_end : after validation epoch has ended.
            - test_epoch_end : called when the test loop is done, one iteration
                over all batches in the test dataloader.
        
        # Other Args
            - lr_scheduler : scheduler from torch.optim.lr_scheduler
            - device : torch.device model will be cast to device this prior to the loop
            - dtype : floating point dtype to cast model and data to
            - criteria_direction : whether more is better (1) or less is better (-1) 
                for model score criteria.
        
        # Preserving Model State
            - save_to_disk : True then save pretrained and best_model to the disk, else it is 
                stored as an attribute.
            - save_path : location where the initial and pretrained models are to be saved
            - pretrained_name : Name to save the pretrained model by
            - best_model_name : Name to save the best model by
        """
        # Basic Blocks
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function
        
        # DataLoaders
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.test_dl = test_dl
        
        # Batch Step
        self.train_step = train_step
        self.valid_step = valid_step
        self.test_step = test_step
        
        # Epoch Start Step
        self.train_epoch_start = train_epoch_start
        self.valid_epoch_start = valid_epoch_start
        self.test_epoch_start = test_epoch_start
        
        # Epoch End Step
        self.train_epoch_end = train_epoch_end
        self.valid_epoch_end = valid_epoch_end
        self.test_epoch_end = test_epoch_end
        
        # Other Args
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.dtype = dtype
        self.criteria_direction = criteria_direction
        
        # Preserving Model State
        # File names are generated only when the weights go to the disk.
        if save_to_disk and pretrained_name is None:
            u = str(uuid4()).split('-')[1]
            pretrained_name = f"pretrained_{u}.pt"
        if save_to_disk and best_model_name is None:
            u = str(uuid4()).split('-')[1]
            best_model_name = f"best_{u}.pt"
        self.pretrained_name = pretrained_name
        self.best_model_name = best_model_name
        self.save_to_disk = save_to_disk
        self.save_path = Path(save_path)
        
        
        # INITIALIZE NON ARGS
        self.__save_model(self.PR)
        self.epoch_num = 0
        # -inf when more is better, +inf when less is better.
        self.best_score = self.criteria_direction * float('-inf')
            
        
    @staticmethod
    def __ftime(t1:float,t2:float)->str:
        """
        Time string such as ' 1 m 15 s 500 ms' for the duration between two 
        `time.time()` values, padded to 20 characters. Components that are 
        zero are left out, a duration under half a millisecond gives '0 ms'.
        """
        # Integer milliseconds; reading the digits of str(t) breaks for 
        # values printed in scientific notation and misreads 0.5 s as 5 ms.
        total_ms = max(0, int(round((t2 - t1) * 1000)))
        s, ms = divmod(total_ms, 1000)
        m, s = divmod(s, 60)
        h, m = divmod(m, 60)
        v = [(h, str(h)), (m, str(m).rjust(2)), (s, str(s).rjust(2)), (ms, str(ms))]
        u = ['h','m','s','ms']
        t = [f"{text} {unit}" for (value, text), unit in zip(v, u) if value > 0]
        return (' '.join(t) if t else '0 ms').ljust(20)
    
    @staticmethod
    def __time(profiler:bool) -> Optional[float]:
        # Current time if profiling is on, else None.
        if profiler:
            return time.time()
        return None
    
    def __get_hook(self, phase:str, kind:str, required:bool=True) -> Optional[Callable]:
        """
        Returns the stage function stored as `{phase}_{kind}`, where kind is 
        'step', 'epoch_start' or 'epoch_end'.
        Raises AttributeError if it is None and `required` is True.
        """
        hook = getattr(self, f"{phase}_{kind}")
        if hook is None and required:
            raise AttributeError(f"{phase}_{kind} not assigned")
        return hook
    
    def __zero_grad(self) -> None:
        # `optimizer` may be a single optimizer or a list of optimizers.
        optimizers = self.optimizer
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        for opt in optimizers:
            opt.zero_grad()
    
    def __call_batch_step(self, state:'LoopState') -> Optional[Dict[str,Any]]:
        """
        Calls the batch step of the phase of `state` and appends the
        returned dict to the batch metrics of `state`.
        """
        step_func = self.__get_hook(state.phase, 'step')
        state_dict = step_func(state)
        if isinstance(state_dict, dict):
            state._append_batch(state_dict)
        return state_dict
        
    def __call_epoch_start_step(self, state:'LoopState') -> Optional[Dict[str,Any]]:
        """
        Calls the epoch start step of the phase of `state` if one is 
        assigned. The returned value is passed back but not stored.
        """
        step_func = self.__get_hook(state.phase, 'epoch_start', required=False)
        if step_func is None:
            return None
        return step_func(state)
        
    def __call_epoch_end_step(self, state:'LoopState') -> Optional[Dict[str,Any]]:
        """
        Calls the epoch end step of the phase of `state` and appends the
        returned dict to the epoch metrics of `state`.
        """
        step_func = self.__get_hook(state.phase, 'epoch_end')
        state_dict = step_func(state)
        if isinstance(state_dict, dict):
            state._append_epoch(state_dict)
        return state_dict
        
    def __get_dl(self, is_test:bool)-> Dict[str,DataLoader]:
        """
        DataLoaders to iterate over, keyed by phase name. Validation is
        included only if both `valid_dl` and `valid_step` are assigned.
        Raises AttributeError for a test loop without `test_dl`.
        """
        if is_test:
            if self.test_dl is None:
                raise AttributeError("test_dl not assigned")
            return {self.TE:self.test_dl}
        
        va_dl = self.valid_dl is not None
        va_st = self.valid_step is not None
        if  va_dl and va_st:
            return {self.TR:self.train_dl, self.VA:self.valid_dl}
        else:
            return {self.TR:self.train_dl}
    
    def __loop(self, 
            epochs:int=1, print_every:int=1, steps: Optional[int]=None,
            criteria:Union[str, None]=None, load_best:bool=False, 
            profiler:bool=False, is_test:bool=False,
            track_batch_metrics:bool=True, define_all:bool=False,
            continue_loop:int=0, no_print:bool=False, no_cast:bool=False
           ) -> Dict[str,Dict[str,Any]]:
        """
        Runs the training loop for `epochs`
        ----
        Parameters
         - epochs : should be a non negative integer
         - print_every : if 0 will not print, else will print at given epoch
         - steps : number of batches to run in each phase [train,valid] 
             to check if everything is working, None runs all batches.
         - criteria : value in the `valid_epoch_end` return dict
             that is used to define a better model, eg: 'accuracy'
         - load_best : whether to load the best model after training, works only if validation
             parameters are defined `valid_dl`, `valid_step`, `valid_epoch_end` and
             `criteria` is given.
         - profiler : whether to keep track of time taken by various sections
         - is_test : whether it is a model testing loop or training/validation loop
         - track_batch_metrics : whether to store the values returned in the batch steps
         - define_all : If True then `torch.set_grad_enabled`, `optimizer.zero_grad` and model mode 
             ie [train,eval] have to be called where required (usually in the `train_step` function).
         -  continue_loop : Will ask whether to continue training after `continue` epochs, should
             be a positive integer.
         - no_print : If True will suppress all print statements, can be used when custom logging is
             used in the stage functions.
         - no_cast : True, if model and data casting has to be manually set in the stage functions
        
        TODO : print_every, profiler, track_batch_metrics, continue_loop and
            no_print are accepted but nothing acts on them yet; logging and 
            profiling of the phases are still to be added.
        ----
        Returns
            The `metrics` dict of the LoopState of every phase that was run, 
            keyed by phase name.
        """
        # INITILIZING VARIABLES -----
        
        # Storage
        dl = self.__get_dl(is_test)
        phases = [ph for ph in dl]
        state = {}
        for ph in phases:
            state[ph] = LoopState(ph, self, no_cast)
            # The step dispatchers read the phase from the LoopState.
            state[ph].phase = ph

        # Markers
        # The model at the start is the best model only as long as nothing 
        # has been scored, else a best model kept from an earlier run 
        # would get overwritten.
        if not is_test and math.isinf(self.best_score):
            self.__save_model(self.BS)

        # ----------------------------


        # THE LOOP - START -----------
        
        """
        TODO : Preloop section, to initilize parameters in whichever way, add profiling.
            Manual cast for below line
        """
        model = self.model
        if not no_cast:
            model = model.to(device=self.device, dtype=self.dtype)
        for e in range(epochs):
            
            # Update FitLoop epoch_num
            if not is_test:
                self.epoch_num += 1
            
            for phase in phases:
                st = state[phase]
                
                # Update LoopState epoch_num, the batch values of the 
                # previous epoch are dropped.
                st.epoch_num = e
                st.batch_num = 0
                st.metrics.pop(st._batch, None)
                
                # EPOCH START STEP - START ---
                self.__call_epoch_start_step(st)
                # EPOCH START STEP - END -----
                
                is_tr = phase == self.TR
                if not define_all:
                    if is_tr:
                        model.train()
                    else:
                        model.eval()
                    
                # BATCH LOOP - START ---------
                for batch in dl[phase]:
                    if steps is not None and st.batch_num >= steps:
                        break
                    
                    # Update LoopState batch_num and set batch
                    st.batch_num += 1
                    st.batch = batch
                    
                    # BATCH STEP - START ---------
                    if define_all:
                        self.__call_batch_step(st)
                    else:
                        self.__zero_grad()
                        with torch.set_grad_enabled(is_tr):
                            self.__call_batch_step(st)
                    # BATCH STEP - END -----------
                    
                # BATCH LOOP - END -----------
                
                # EPOCH END STEP - START -----
                self.__call_epoch_end_step(st)
                # EPOCH END STEP - END -------
                
                # Update markers save best model
                if not is_tr and not is_test and criteria is not None:
                    score = st._get_score(criteria)
                    if self.criteria_direction > 0:
                        is_better = score > self.best_score
                    else:
                        is_better = score < self.best_score
                    if is_better:
                        self.best_score = score
                        self.__save_model(self.BS)

        # THE LOOP - END -------------
    
        # Without a criteria no epoch was scored, loading the stored "best"
        # model would then only throw away the training that was just done.
        if load_best and criteria is not None and self.VA in phases:
            self.__load_model(self.BS)

        # ----------------------------
        return {ph: state[ph].metrics for ph in phases}
    
    def fit(self, 
            epochs:int=1, print_every:int=1, 
            criteria:Union[str, None]=None, load_best:bool=False, 
            profiler:bool=False
           ):
        """
        Runs the training loop for `epochs`
        ----
        Parameters
         - epochs : should be a non negative integer
         
         - print_every : if 0 will not print, else will print at given epoch
         - criteria : value in the validation epoch end return dict
             that is used to define a better model, eg: 'accuracy'
         - load_best : whether to load the best model after training for `epochs`
         - profiler : whether to keep track of time taken by various sections
        ----
        Returns
            The metrics stored for every phase, keyed by phase name.
        """
        return self.__loop(
            epochs=epochs, print_every=print_every,
            criteria=criteria, load_best=load_best,
            profiler=profiler
        )
    
    def train(self,*args,**kwargs):
        """
        An alias for FitLoop.fit
        """
        return self.fit(*args, **kwargs)
    
    def test(self):
        """
        Runs the loop once over `test_dl` in test mode and returns the 
        metrics stored for the test phase, keyed by phase name.
        `epoch_num`, `best_score` and the stored best model are not touched.
        """
        return self.__loop(epochs=1, is_test=True)
    
    def timed_test(self):
        """
        TODO: Runs the model in training mode for one epoch to get
            to get the runtime and other stats, will restore model
            state after.
        """
        pass
    
    def sanity_check(self, use_test_dl=False):
        """
        use_test_dl whether to use test or validation data for the 
        test loop.
        TODO : Run train and test loops for 1 epoch and fewer steps
            ie smaller dataset to check if everything is working,
            will restore model state after.
        """
        pass
    
    def __model_path(self, typ:str) -> Path:
        # Path of the weights file on the disk, names are set only if 
        # `save_to_disk` is True.
        name = self.best_model_name if typ == self.BS else self.pretrained_name
        return self.save_path/ name
    
    def __save_model(self, typ:str) -> None:
        """
        Save model to object or to the disk.
        """
        state_dict = self.model.state_dict()
        if self.save_to_disk:
            self.save_path.mkdir(parents=True, exist_ok=True)
            torch.save(state_dict,self.__model_path(typ))
        elif typ == self.BS:
            self.best_model_state_dict = deepcopy(state_dict)
        else:
            self.pretrained_state_dict = deepcopy(state_dict)
        
    def __load_model(self, typ:str):
        """
        Load model from the object or from the disk.
        """
        if self.save_to_disk:
            state_dict = torch.load(self.__model_path(typ), map_location=self.device)
        elif typ == self.BS:
            state_dict = self.best_model_state_dict
        else:
            state_dict = self.pretrained_state_dict
        self.model.load_state_dict(state_dict)
        print("please reinitilize FitLoop.optimizer param groups before training")
    
    def reset(self, reset_model:bool=True) -> None:
        """
        Resets FitLoop to initial state.
        Parameters reset:
            - model, to pretrained state if reset_model==True
            - epoch_num, to 0
            - best_score to ∓inf
        FitLoop.optimizer param groups will have to be set again
        """
        if reset_model:
            self.__load_model(self.PR)
        self.epoch_num = 0
        self.best_score = self.criteria_direction * float('-inf')
    
    def save(self, path, only_model=False):
        """
        TODO : save the FitLoop state, if only_model then save only model.
        """
        pass
    
    def load(self, path):
        """
        TODO : load the FitLoop state, if only model then load the model 
            state dict.
        """
        pass
    
    def del_pretrained(self):
        """
        TODO : function to delete pretrained model from save path
        """
        
    def del_best_model(self):
        """
        TODO : function to delete best model from save path
        """

## Portions of the loop

- **Single step** Portion of the loop that receives the batch and will call the forward pass on it.
    - *Set* zero grad
    - **Train Step** 
        - *Set* 
            - enable gradients
            - model.train
        - *Define*: 
            - compute loss
            - call backward 
            - update gradients and maybe scheduler step
            - loss and whatever metrics will be saved in a list.
    - **Valid Step**
        - *Set*
            - disable gradients
            - model.eval
        - *Define*: 
            - compute loss
            - loss and whatever metrics will be saved in a list.
    - **Test Step** Only in test loop
        - *Set*
            - disable gradients
            - model.eval
        - *Define*: 
            - compute loss
            - loss and whatever metrics will be saved in a list.
- **Pre Epoch**
- **Post Epoch**

## Reference

In [None]:
# Function to calculate a time string from
# 2 instances of time.time
def ftime(t1,t2):
    """
    Returns a time string such as ' 1 m 15 s 500 ms' for the duration 
    between t1 and t2 (seconds, as given by time.time), left justified 
    to 20 characters.
    
    Components that are zero are left out; a duration that rounds to 
    zero milliseconds gives '0 ms'.
    """
    # Work in integer milliseconds. Reading the digits of str(t) fails 
    # for values printed in scientific notation (eg: 1e-05) and misreads 
    # a fractional part of '.5' as 5 ms.
    total_ms = max(0, int(round((t2 - t1) * 1000)))
    s, ms = divmod(total_ms, 1000)
    m, s = divmod(s, 60)
    h, m = divmod(m, 60)
    v = [(h, str(h)), (m, str(m).rjust(2)), (s, str(s).rjust(2)), (ms, str(ms))]
    u = ['h','m','s','ms']
    t = [f"{text} {unit}" for (value, text), unit in zip(v, u) if value > 0]
    return (' '.join(t) if t else '0 ms').ljust(20)

In [None]:
def fit(model, dl_train, dl_valid, optim, loss_function, epochs, sched=None, should_pass=False, print_every=1, load_best=True):
    """
    Reference training/validation loop.
    
    Uses the module-level names `TR`, `VA` (phase keys), `device` and 
    `ftime`; `device` is NOT a parameter of this function.
    
    Args:
    # model
      - pytorch model to be trained
      
    # dl_train
      - Dataloader for the train set
      
    # dl_valid
      - Dataloader for the valid set
      
    # optim
      - optimizer attached to model params from torch.optim
      
    # loss_function
      - function to calculate the loss; loss_function(model(X),y)
      
    # epochs
      - number of epochs to train for
      
    # sched
      - scheduler for the optimizer learning rate
      
    # should_pass
      - if loss should be passed to the scheduler.
      
    # print_every
      - to print the loss every n epochs, 0 prints no epoch lines
    
    # load_best 
      - whether to load the best model as determined by accuracy
    
    Returns:
      dict with the keys "losses_epo", "accuracy_epo" (one value per epoch) 
      and "losses_all" (one value per batch), each a dict keyed by phase.
    """
    
    total_time_start = time.time()
    # ----------------------------
    phases = [TR,VA]
    dl = {TR:dl_train, VA:dl_valid}
    sizes = {k:len(dl[k].dataset) for k in dl}
    
    # To contain metrics to plot later
    losses_all = {TR:[],VA:[]}
    losses_epo = {TR:[],VA:[]}
    accuracy_epo = {TR:[],VA:[]}
    
    # Markers
    least_loss = float('inf')
    best_accuracy = 0
    best_model = deepcopy(model.state_dict())
    
    # ----------------------------
    # Convenience functions (that won't be used elsewhere)
    
    # Function to get formatted epochs (from 1 not 0)
    r_just_val = len(str(epochs))*2 + 3
    estr = lambda e: f"[{e + 1}/{epochs}]".rjust(r_just_val)
    
    # Convenience function to print every `print_every` epochs.
    # print_every == 0 would divide by zero in the modulo, it prints nothing.
    def eprint(e,st):
        if print_every and ((e + 1) % print_every == 0):
            print(st,end="")

    def grad_times(t):
        t = t*1000
        return f"{t:0.3f} ms".rjust(10)
    
    # Convenience function for phase strings.
    def statstr(phase,loss,accu,time,infr, rjust=True):
        infr = grad_times(infr)
        st =  f"{phase} :: loss: {loss:0.4f} | accu: {accu:0.4f} | infr: {infr}  | time: {time} \n"
        if rjust:
            return st.rjust(r_just_val + len(st) + 3)
        else:
            return st

    
    # ----------------------------
    # Training 
    model.to(device)
    for e in range(epochs):
        epoch_time_start = time.time()
        
        # Training or validation phase
        for phase in phases:
            phase_time_start = time.time()

            # Running metrics
            running_loss = 0
            running_corr = 0

            is_tr = phase == TR
            if is_tr:
                model.train()
            else:
                model.eval()
                  
            # Keeping track of computation times.
            inference_times = []
            if is_tr:
                eprint(e,estr(e)+f" - ")
            for batch in dl[phase]:
                X,y = batch
                bs = y.size(0)
                X = X.to(device)
                y = y.to(device)
                
                optim.zero_grad()
                
                with torch.set_grad_enabled(is_tr):
                    # Forward pass and log times (per sample)
                    inference_time_start = time.time()
                    y_ = model(X)
                    inference_time_end = time.time()
                    inference_times.append((inference_time_end-inference_time_start)/bs)
                    
                    loss = loss_function(y_,y)
                    if is_tr:
                        # Backprop (grad calc and param update)
                        loss.backward()
                        optim.step()
                        
                    elif sched is not None:
                        # Scheduler step
                        # NOTE(review): this steps the scheduler once per 
                        # validation batch, not once per epoch - confirm 
                        # that this is intended.
                        if should_pass:
                            sched.step(loss.item())
                        else:
                            sched.step()
                    
                running_loss += loss.item() * bs
                running_corr += (y_.argmax(dim=1) == y).sum().item()
                losses_all[phase].append(loss.item())
             
            # Calculate phase losses and scores
            phase_loss = running_loss / sizes[phase]
            phase_accu = running_corr / sizes[phase]
                  
            # Calculate timings
            # Plain mean; numpy is not imported in this file. 0.0 if the 
            # DataLoader gave no batches.
            mean_inf = sum(inference_times) / len(inference_times) if inference_times else 0.0
            # Log scores and losses
            losses_epo[phase].append(phase_loss)
            accuracy_epo[phase].append(phase_accu)
                
            # Update markers save best model
            if not is_tr:
                if phase_accu > best_accuracy:
                    best_accuracy = phase_accu
                    best_model = deepcopy(model.state_dict())
                if phase_loss < least_loss:
                    least_loss = phase_loss
            
            # Logging phase time
            phase_time_end = time.time()
            phase_time = ftime(phase_time_start,phase_time_end)
            if is_tr:
                eprint(e,statstr(phase, phase_loss,phase_accu, phase_time, mean_inf, False))
            else:
                eprint(e,statstr(phase, phase_loss,phase_accu, phase_time, mean_inf))
        
        # Logging epoch times
        epoch_time_end = time.time()
        epoch_time = f"epoch time: {ftime(epoch_time_start, epoch_time_end)}\n"
        eprint(e,epoch_time.rjust(len(epoch_time)+r_just_val+3)+"\n")
    
    if load_best:
        model.load_state_dict(best_model)
    
    # ----------------------------
    total_time_end = time.time()
    total_time = ftime(total_time_start,total_time_end)
    print(f"least loss: {least_loss:0.4f} | best accu: {best_accuracy:0.4f} | total time taken: {total_time}")
    
    return {"losses_epo":losses_epo,"accuracy_epo":accuracy_epo,"losses_all":losses_all}