# Testing Grounds - Fit Loop
Creating a library that helps with the pytorch looping.

In [43]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

<module 'typing' from '/Users/alan/opt/anaconda3/envs/data_sci/lib/python3.7/typing.py'>

In [27]:
"""
Basic setup 
"""

sets = ['train','valid','test']
TR, VA, TE = sets
class_names = ['a','b','c','d']
num_classes = 4

sz = {s:z for s, z in zip(sets,[1024,128,256])}
ds = {s:FakeData(size=sz[s], transform=ToTensor(), num_classes=num_classes) for s in sets}
dl = {s:DataLoader(ds[s],batch_size=16) for s in ds}

model = resnet18()
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

In [33]:
"""
Batch for testing purposes
"""
batch = next(iter(dl[TR]))
X, y = batch

In [154]:
import time

from typing import Union, List, Callable, Optional, Any, Dict, Tuple

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

In [48]:
device = torch.device('cpu')

- dict passed from from one of the states should be accesible from the loop state through `__getitem__`

In [146]:
def somefunc():
    return 22
class A:
    def __init__(self):
        self.s = somefunc

In [None]:
class LoopMetrics:
    def __init__(self):
        pass

In [153]:
'asdf' in {'asdf':22}

True

- every step function hook receives the LoopState object.
- The loop state object should have a copy of all the values returned from the function hook
- example the below returned dict values should be avialable in the LoopState object

```python
def train_step(state):
    X,y = state() # should device cast automatically
    y_ = state.model(X)
    loss = state.loss_function(y_, y)
    
    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()
    state.lr_scheduler.step() 
    
    loss = loss.item()
    batch_loss = loss * y.size()
    batch_corr = (y_.argmax(dim=0) == y).sum().float().item()
    
    return {'loss':loss,'batch_loss':batch_loss:'batch_corr'}
```
- The LoopState object should be cleared of the above values at the start 
  of the next epoch.
- The returned values should be available through the FitLoop object
  Eg: `FitLoop.train.batch['loss']`
- The returned value should be optionally available by setting the flag 
  `store_batch_metrics`

```python
def train_epoch_end(state):
    loss = state['loss']
    batch_loss = state['batch_loss']
    batch_corr = state['batch_corr']
    
    size = state.sz
    
    epoch_loss = batch_loss.sum().item()/size
    epoch_accu = batch_corr.sum().item()/size
    
    return {"loss":epoch_loss,"accu":epoch_accu}
```

- The returned values should be available through the FitLoop object
  Eg: `FitLoop.train.epoch['loss']`
- For each phase a different LoopState obect is maintained.



In [77]:
class LoopState:
    __batch = 'batch'
    __epoch = 'epoch'
    def __init__(self, phase:str, floop:FitLoop):
        self.metrics = {}
        self._batch = ()
        
        # Passed down attributes
        self.model = floop.model
        self.optimizer = floop.optimizer
        self.loss_function = floop.loss_function
        pass
    
    @property
    def batch(self) -> Tuple[...]:
        return (d.to(self.device) for d in self._batch)
    
    @batch.setter
    def batch(self, current_batch:Tuple[...]) -> None:
        self._batch = current_batch
    
    def append(self, rdict:Dict[str, float], stage:str) -> None:
        #  Append metrics to the specific stage.
        if stage not in self.metrics:
            self.metrics[stage] = {}
            
        for key in rdict:
            if key not in self.metrics[stage]:
                self.metrics[stage][key] = []
            self.metrics[stage][key].append(rdict[key])
    
    def clear(self, stage:str) -> None:
        # Clear the batch metrics at the end of the batch.
        for m in self.metrics[stage]:
            m.clear()
            
    def append_batch(self, rdict:Dict[str, float]) -> None
        self.append(rdict, self.__batch)
        
    def append_epoch(self, rdict:Dict[str, float]) -> None
        self.append(rdict, self.__epoch)
            
    def clear_batch(self) -> None
        self.clear(self.__batch)
        
    def clear_epoch(self) -> None
        self.clear(self.__epoch)
    
    def 
    
    def __getitem__(self):
        pass

In [99]:
class FitLoop:
    sets = ['train','valid','test']
    TR, VA, TE = sets
    def __init__(self, 
                 # Basic blocks
                 model: Module, optimizer: Union[Optimizer,List[Optimizer]], 
                 loss_function: Callable[[Tensor,Tensor],Tensor], 
                 
                 # DataLoader
                 train_dl: DataLoader, 
                 valid_dl: Optional[DataLoader]=None, 
                 test_dl: Optional[DataLoader]=None, 
                 
                 # Batch Step
                 train_step: Callable[[LoopState],Dict[str, Any]],
                 valid_step: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 test_step: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 
                 # Epoch Start Step
                 train_epoch_start: Callable[[LoopState],Dict[str, Any]],
                 valid_epoch_start: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 test_epoch_start: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 
                 # Epoch End Step
                 train_epoch_end: Callable[[LoopState],Dict[str, Any]],
                 valid_epoch_end: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 test_epoch_end: Optional[Callable[[LoopState],Dict[str, Any]]]=None,
                 
                 # Other Args
                 lr_scheduler: Optional[_LRScheduler, Any, List[_LRScheduler,Any]]=None,
                 device: torch.device=torch.device('cpu'), 
                 dtype: torch.dtype=torch.float32
                ) -> None:
        """
        FitLoop constructor
        ----
        Parameters:
        # Basic Blocks
            The bare minimum required along with train_dl, train_step and train_epoch_end.
            - model : nn.Module model that has to be trained
            - optimizer : an optimizer from torch.optim
            - loss_function : function to compute loss
         
        # DataLoader
            - train_dl : training DataLoader
            - valid_dl : validation DataLoader, if None validation will be ignored
            - test_dl : testing DataLoader, if None `.test()` will not run
         
        # Batch Step
            Functions that take in a LoopState object to perform 
            required calculations, functions should return a dict with values
            to be used in the epoch end step.
            - train_step : portion of the loop where forward and backward 
                passes take place.
            - valid_step : validation portion of the loop.
            - test_step : called when `FitLoop.test()` is called.
        
        # Epoch Start Step
            TODO: NEED TO IMPLEMENT
            - train_epoch_start :
            - valid_epoch_start :
            - test_epoch_start : 
        
        # Epoch End Step
            Functions that take in a LoopState object to perform 
            required calculations, functions should return a dict with values
            that are to be returned when the loop is over.
            - train_epoch_end : after training epoch has ended.
            - valid_epoch_end : after validation epoch has ended.
            - test_epoch_end : called when the test loop is done, one iteration
                over all batches in the test dataloader.
        
        # Other Args
            - lr_scheduler : scheduler from torch.optim.lr_scheduler
            - device : torch.device model will be cast to device this prior to the loop
            - dtype : dtype to cast model and data to
        """
        # Basic Blocks
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function
        
        # DataLoaders
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.test_dl = test_dl
        
        # Batch Step
        self.train_step = train_step
        self.valid_step = valid_step
        self.test_step = test_step
        
        # Epoch Start Step
        self.train_epoch_start = train_epoch_start
        self.valid_epoch_start = valid_epoch_start
        self.test_epoch_start = test_epoch_start
        
        # Epoch End Step
        self.train_epoch_end = train_epoch_end
        self.valid_epoch_end = valid_epoch_end
        self.test_epoch_end = test_epoch_end
        
        # Other Args
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.dtype = dtype
        
        # Epoch Num
        self.epoch_num = 0
        
    def __ftime(t1:float,t2:float)->str:
        t = t2-t1
        s, ms = str(t).split('.')
        ms = ms[:3]
        s = str(int(s)%60).rjust(2)
        m = int(t//60)
        h = str(m//60)
        m = str(m % 60).rjust(2)
        u = ['h','m','s','ms']
        v = [h,m,s,ms]
        t = filter(lambda x:int(x[0]) > 0,zip(v,u))
        return ' '.join([' '.join(x) for x in t]).ljust(20)
    
    def __time(profiler:bool) -> Optional[float]:
        if profiler:
            return time.time()
    
    def __call_batch_step(self, state:LoopState) -> None:
        step_funcs = [self.train_step, self.valid_step, self.test_step]
        step_funcs = {s:f for s,f in zip(self.sets, step_func)}
        step_func = step_funcs[state.phase]
        
        if step_func is None:
            raise AttributeError(f"{phase}_step not assigned")
        state_dict = step_func(phase)
        # To Do Alter LoopState according to the state_dict
        
        
    def __call_epoch_start_step(self, state:LoopState) -> Dict[str,Any]:
        step_funcs = [self.train_epoch_start,self.valid_epoch_start,self.test_epoch_start]
        step_funcs = {s:f for s,f in zip(self.sets, step_func)}
        step_func = step_funcs[state.phase]
        
        if step_func is None:
            raise AttributeError(f"{phase}_step not assigned")
        state_dict = step_func(state)
        # To Do Alter LoopState according to the state_dict
        
    def __call_epoch_end_step(self, phase:str) -> Dict[str,Any]:
        step_funcs = [self.train_epoch_end,self.valid_epoch_end,self.test_epoch_end]
        step_funcs = {s:f for s,f in zip(self.sets, step_func)}
        step_func = step_funcs[state.phase]
        
        if step_func is None:
            raise AttributeError(f"{phase}_step not assigned")
        state_dict = step_func(state)
        # To Do Alter LoopState according to the state_dict
        
    def __get_dl(self, is_test:bool)-> Dict[str,DataLoader]:
        if is_test:
            if self.test_dl is None:
                raise AttributeError("test_dl not assigned")
            return {self.TE:self.test_dl}
        v = self.valid_dl is None
        if not v:
            return {TR:self.train_dl, VA:self.valid_dl}
        else:
            return {TR:self.train_dl}
    
    def __get_phase_list(self, is_test):
        if is_test:
            return [self.TE]
        
        va_dl = self.valid_dl is not None
        va_st = self.valid_step is not None
        
        if va_dl and va_st:
            return [self.TR, self.VA]
        else:
            return [self.TR]
    
    def __loop(self, 
            epochs:int=1, print_every:int=1, steps: Optional[int],
            criteria:Union[str, None]=None, load_best:bool=False, 
            profiler:bool=False, is_test:bool=False,
            track_batch_metrics:bool=True
           ):
        """
        Runs the training loop for `epochs`
        ----
        Parameters
         - epochs : should be a non negative integer
         - print_every : if 0 will not print, else will print at given epoch
         - steps : number of batches to run in each phase [train,valid] 
             for check if everything is working.
         - criteria : value in the validation epoch end return dict
             that used is used to define a better model, eg: 'accuracy'
         - load_best : whether to load the best model after training for `epochs`
         - profiler : whether to keep track of time taken by various sections
         - is_test : whether it is a model testing loop or training/validation loop
         - track_batch_metrics : whether to store the values returned in the batch steps
        
        """
        t = lambda : self.__time(print_every != 0) # Returns the time 
        
        total_time_start = t()
        
        # ----------------------------
        phases = self.__get_phase_list(is_test)
        dl = self.__get_dl(is_test)
        sz = {k:len(dl[k].dataset) for k in dl}
        state = {}

        # Markers
        least_loss = float('inf')
        best_score = float('-inf') # value used save best model
        best_model = deepcopy(model.state_dict())

        # ----------------------------
        # Convenience functions (that aren't be used elsewhere)

        # Function to get formatted epochs (from 1 not 0)
        r_just_val = len(str(epochs))*2 + 3
        estr = lambda e: f"[{e + 1}/{epochs}]".rjust(r_just_val)

        # Convenience function to print every `print_every` epochs.
        def eprint(e,st):
            if (e == 0) and (print_every != 0):
                print(st,end="")
            elif (e + 1) % print_every == 0:
                print(st,end="")

        def grad_times(t):
            t = t*1000
            return f"{t:0.3f} ms".rjust(10)

        # Convenience function for phase strings.
        def statstr(phase, loss, accu, time, infr, rjust=True):
            infr = grad_times(infr)
            st =  f"{phase} :: loss: {loss:0.4f} | accu: {accu:0.4f} | infr: {infr}  | time: {time} \n"
            if rjust:
                return st.rjust(r_just_val + len(st) + 3)
            else:
                return st


        # ----------------------------
        # Training 
        model = self.model.to(device)
        for e in range(epochs):
            epoch_time_start = t()
            
            """
            TODO : Re-initilize LoopState
            """
            # for st in state do something
            
            # Training or validation phase
            for phase in phases:
                phase_time_start = t()
                
                # EPOCH START STEP START
                epoch_start_step = self.__call_epoch_start_step(state[phase])
                # EPOCH START STEP END
                
                is_tr = phase == TR
                if is_tr:
                      model.train()
                else:
                      model.eval()

                # Keeping track of computation times.
                inference_times = []
                dl_p = dl[phase]
                dl_l = len(dl_p)
                
                if is_tr:
                    eprint(e,estr(e)+f" - ")
                    
                for b,batch in enumerate(dl_p):
                    # BATCH STEP START
                    batch_step = self.__call_batch_step(state[phase])
                    # BATCH STEP END
                    
                    """
                    Code below goes in to above block
                    """
                    
                    ## In the function
#                     X,y = batch
#                     bs = y.size(0)
#                     X = X.to(device)
#                     y = y.to(device)

                    ## Auto set
#                     optim.zero_grad()

                    ## Auto set
#                     with torch.set_grad_enabled(is_tr):
#                         # Forward pass and log times 
#                         inference_time_start = time.time()
#                         y_ = model(X)
#                         inference_time_end = time.time()
#                         inference_times.append((inference_time_end-inference_time_start)/bs)

#                         loss = loss_function(y_,y)
#                         if is_tr:
#                             # Backprop (grad calc and param update)
#                             loss.backward()
#                             optim.step()

#                         elif not is_tr and sched is not None:
#                             # Scheduler step
#                             if should_pass:
#                                 sched.step(loss.item())
#                             else:
#                                 sched.step()

#                     running_loss += loss.item() * bs
#                     running_corr += (y_.argmax(dim=1) == y).sum().item()
#                     losses_all[phase].append(loss.item())


                # EPOCH END STEP START
                self.__call_epoch_end_step(state[phase])
                # EPOCH END STEP END
            
                # Calculate phase losses and scores
                phase_loss = running_loss / sizes[phase]
                phase_accu = running_corr / sizes[phase]

                # Calculate timings
                mean_inf = np.array(inference_times).mean()
                # Log scores and losses
                losses_epo[phase].append(phase_loss)
                accuracy_epo[phase].append(phase_accu)

                # Update markers save best model
                if not is_tr:
                    if phase_accu > best_accuracy:
                        best_accuracy = phase_accu
                        best_model = deepcopy(model.state_dict())
                    if phase_loss < least_loss:
                        least_loss = phase_loss

                # Logging phase time
                phase_time_end = time.time()
                phase_time = ftime(phase_time_start,phase_time_end)
                if is_tr:
                    eprint(e,statstr(phase, phase_loss,phase_accu, phase_time, mean_inf, False))
                else:
                    eprint(e,statstr(phase, phase_loss,phase_accu, phase_time, mean_inf))

            # Logging epoch times
            epoch_time_end = time.time()
            epoch_time = f"epoch time: {ftime(epoch_time_start, epoch_time_end)}\n"
            eprint(e,epoch_time.rjust(len(epoch_time)+r_just_val+3)+"\n")

        if load_best:
            model.load_state_dict(best_model)

        # ----------------------------
        total_time_end = time.time()
        total_time = ftime(total_time_start,total_time_end)
        print(f"least loss: {least_loss:0.4f} | best accu: {best_accuracy:0.4f} | total time taken: {total_time}")

        return {"losses_epo":losses_epo,"accuracy_epo":accuracy_epo,"losses_all":losses_all}
    
    def fit(self, 
            epochs:int=1, print_every:int=1, 
            criteria:Union[str, None]=None, load_best:bool=False, 
            profiler:bool=False
           ):
        """
        Runs the training loop for `epochs`
        ----
        Parameters
         - epochs : should be a non negative integer
         
         - print_every : if 0 will not print, else will print at given epoch
         - criteria : value in the validation epoch end return dict
             that used is used to define a better model, eg: 'accuracy'
         - load_best : whether to load the best model after training for `epochs`
         - profiler : whether to keep track of time taken by various sections
        
        """
        pass

SyntaxError: non-default argument follows default argument (<ipython-input-99-b9ca058edd5d>, line 4)

FitLoop
- should keep track of epochs that have been completed
- epoch_number can be reset 
- metrics can be cleared
- `FitLoop.set_name.loop_stage['metric_name']` to access the metric
- `FitLoop.set_name.loop_stage['metric_name']` to access the metric
- `FitLoop.store_pretrained:bool` arg to store the pretrained weights before training
    if path then store at given path else store in memory.
- `FitLoop.reset(reset_model:bool)` to clear metrics, epoch_num and to reset the model, to pretrained state
    will load the weight from passed path else from memory.
- `FitLoop.save(path:str)` to save the model and training state somehow even the fitloop state.
- `FitLoop.load(path:str)` to load the FitLoop state from given path.
- Some basic step should be used such that one can use it without defining step functions.
- `FitLoop.fit(continue:[bool,int]=False)` ask after `int` whether to continue training or to end.

LoopState
- should cast the batch to device before passing it using `state.batch()`
- should get the batch num `state.batch_num` and epoch num `state.epoch_num`
- the model, optimizer, loss_function, lr_scheduler should be available
    `state.model`, `state.optimizer`, `state.loss_function`, `state.lr_scheduler`
- should return the batch metrics as float tensors using square bracket indexing
    `state['loss']` 

## Portions of the loop

- **Single step** Portion of the loop that receives the batch and will cal forward pass on it.
    - *Set* zero grad
    - **Train Step** 
        - *Set* 
            - enable gradients
            - model.train
        - *Define*: 
            - compute loss
            - call backward 
            - update gradients and maybe scheduler step
            - loss and whatever metrics will be saved in a list.
    - **Valid Step**
        - *Set*
            - disable gradients
            - model.eval
        - *Define*: 
            - compute loss
            - loss and whatever metrics will be saved in a list.
    - **Test Step** Only in test loop
        - *Set*
            - disable gradients
            - model.eval
        - *Define*: 
            - compute loss
            - loss and whatever metrics will be saved in a list.
- **Pre Epoch**
- **Post Epoch**

## Reference

In [None]:
# Function to calculate a time string from
# 2 instances of time.time
def ftime(t1,t2):
    t = t2-t1
    s, ms = str(t).split('.')
    ms = ms[:3]
    s = str(int(s)%60).rjust(2)
    m = int(t//60)
    h = str(m//60)
    m = str(m % 60).rjust(2)
    u = ['h','m','s','ms']
    v = [h,m,s,ms]
    t = filter(lambda x:int(x[0]) > 0,zip(v,u))
    return ' '.join([' '.join(x) for x in t]).ljust(20)

In [None]:
def fit(model, dl_train, dl_valid, optim, loss_function, epochs, sched=None, should_pass=False, print_every=1, load_best=True):
    """
    Args:
    # model
      - pytorch model to be trained
      
    # dl_train
      - Dataloader for the train set
      
    # dl_valid
      - Dataloader for the valid set
      
    # optimizer
      - optimizer attached to model params from torch.optim
      
    # loss_function
      - function to calculate the loss; loss_function(model(X),y)
      
    # epochs
      - number of epochs to train for
      
    # device
      - torch.device to load the model to 
      
    # sched
      - scheduler for the optimizer learning rate
      
    # should_pass
      - if loss should be passed to the scheduler.
      
    # print_every
      - to print the loss every n epochs
    
    # load_best 
      - whether to load the best model as determined by accuracy
    """
    
    total_time_start = time.time()
    # ----------------------------
    phases = [TR,VA]
    dl = {TR:dl_train, VA:dl_valid}
    sizes = {k:len(dl[k].dataset) for k in dl}
    
    # To contain metrics to plot later
    losses_all = {TR:[],VA:[]}
    losses_epo = {TR:[],VA:[]}
    accuracy_epo = {TR:[],VA:[]}
    
    # Markers
    least_loss = float('inf')
    best_accuracy = 0
    best_model = deepcopy(model.state_dict())
    
    # ----------------------------
    # Convenience functions (that won't be used elsewhere)
    
    # Function to get formatted epochs (from 1 not 0)
    r_just_val = len(str(epochs))*2 + 3
    estr = lambda e: f"[{e + 1}/{epochs}]".rjust(r_just_val)
    
    # Convenience function to print every `print_every` epochs.
    def eprint(e,st):
        if ((e + 1) % print_every == 0):
            print(st,end="")

    def grad_times(t):
        t = t*1000
        return f"{t:0.3f} ms".rjust(10)
    
    # Convenience function for phase strings.
    def statstr(phase,loss,accu,time,infr, rjust=True):
        infr = grad_times(infr)
        st =  f"{phase} :: loss: {loss:0.4f} | accu: {accu:0.4f} | infr: {infr}  | time: {time} \n"
        if rjust:
            return st.rjust(r_just_val + len(st) + 3)
        else:
            return st

    
    # ----------------------------
    # Training 
    model.to(device)
    for e in range(epochs):
        epoch_time_start = time.time()
        
        # Training or validation phase
        for phase in phases:
            phase_time_start = time.time()

            # Running metrics
            running_loss = 0
            running_corr = 0

            is_tr = phase == TR
            if is_tr:
                  model.train()
            else:
                  model.eval()
                  
            # Keeping track of computation times.
            inference_times = []
            dl_p = dl[phase]
            dl_l = len(dl_p)
            if is_tr:
                eprint(e,estr(e)+f" - ")
            for b,batch in enumerate(dl_p):
                X,y = batch
                bs = y.size(0)
                X = X.to(device)
                y = y.to(device)
                
                optim.zero_grad()
                
                with torch.set_grad_enabled(is_tr):
                    # Forward pass and log times 
                    inference_time_start = time.time()
                    y_ = model(X)
                    inference_time_end = time.time()
                    inference_times.append((inference_time_end-inference_time_start)/bs)
                    
                    loss = loss_function(y_,y)
                    if is_tr:
                        # Backprop (grad calc and param update)
                        loss.backward()
                        optim.step()
                        
                    elif not is_tr and sched is not None:
                        # Scheduler step
                        if should_pass:
                            sched.step(loss.item())
                        else:
                            sched.step()
                    
                running_loss += loss.item() * bs
                running_corr += (y_.argmax(dim=1) == y).sum().item()
                losses_all[phase].append(loss.item())
             
            # Calculate phase losses and scores
            phase_loss = running_loss / sizes[phase]
            phase_accu = running_corr / sizes[phase]
                  
            # Calculate timings
            mean_inf = np.array(inference_times).mean()
            # Log scores and losses
            losses_epo[phase].append(phase_loss)
            accuracy_epo[phase].append(phase_accu)
                
            # Update markers save best model
            if not is_tr:
                if phase_accu > best_accuracy:
                    best_accuracy = phase_accu
                    best_model = deepcopy(model.state_dict())
                if phase_loss < least_loss:
                    least_loss = phase_loss
            
            # Logging phase time
            phase_time_end = time.time()
            phase_time = ftime(phase_time_start,phase_time_end)
            if is_tr:
                eprint(e,statstr(phase, phase_loss,phase_accu, phase_time, mean_inf, False))
            else:
                eprint(e,statstr(phase, phase_loss,phase_accu, phase_time, mean_inf))
        
        # Logging epoch times
        epoch_time_end = time.time()
        epoch_time = f"epoch time: {ftime(epoch_time_start, epoch_time_end)}\n"
        eprint(e,epoch_time.rjust(len(epoch_time)+r_just_val+3)+"\n")
    
    if load_best:
        model.load_state_dict(best_model)
    
    # ----------------------------
    total_time_end = time.time()
    total_time = ftime(total_time_start,total_time_end)
    print(f"least loss: {least_loss:0.4f} | best accu: {best_accuracy:0.4f} | total time taken: {total_time}")
    
    return {"losses_epo":losses_epo,"accuracy_epo":accuracy_epo,"losses_all":losses_all}

## ToDo
- Figure out the portions of the loop.
- CPU, GPU, RAM profiling also along with time take to run.