# Pipeline

A basic PyTorch pipeline for general-purpose model training.

In [None]:
#| default_exp pipeline.pipeline

In [None]:
#| export
from genQC.imports import *
from genQC.util import virtual, number_of_paramters, DataLoaders
from genQC.metrics import *
from genQC.config_loader import *

## Helper

In [None]:
#| export
# Type alias for a loss callable: takes two tensors (presumably prediction and
# target — confirm) and returns a loss tensor. NOTE: this is attached as a new
# attribute on the imported `nn` module object (presumably `torch.nn`), i.e. a
# monkeypatch, so `nn.Loss` is visible to every user of that module object.
nn.Loss = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

## Pipeline

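Note: this module uses language features that require Python >= 3.9, e.g. the dict update operator `|=` and built-in generic annotations such as `list[Metric]`.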

In [None]:
#| export
class Pipeline_IO:   
    """A class providing basic IO functionality.

    Mixin for pipelines that can be saved: `get_config` wraps the subclass's
    `params_config` with metadata and training ("fit") settings, and
    `store_pipeline` writes the recorded loss histories to disk.
    `params_config` and `from_config_file` are hooks for subclasses.
    """
    def get_config(self, save_path: str, without_metadata=False):
        """Build the pipeline configuration, cache it in `self.config` and return it.

        Args:
            save_path: forwarded to `params_config`.
            without_metadata: if True, the config is just the `params_config` result;
                otherwise it is a dict with the keys "target" (class string),
                "save_datetime", "params" and "fit" (training settings found on
                the instance).

        Returns:
            The configuration (also stored as `self.config`).
        """
        params_config = self.params_config(save_path)  
              
        if not without_metadata:       
            config = {}
            config["target"]         = class_to_str(type(self))
            config["save_datetime"]  = datetime.now().strftime("%m/%d/%Y %H:%M:%S")
            config["params"]         = params_config             
            
            # Only training attributes that are present on the instance are recorded.
            # NOTE(review): if `lr_sched` / `optimizer` exist but are None, the class
            # string of `NoneType` is what gets recorded for them.
            fit = {}
            if hasattr(self, "num_epochs"):         fit["num_epochs"]         = self.num_epochs
            if hasattr(self, "batch_size"):         fit["batch_size"]         = self.batch_size
            if hasattr(self, "lr"):                 fit["lr"]                 = self.lr
            if hasattr(self, "lr_sched"):           fit["lr_sched"]           = class_to_str(type(self.lr_sched))  
            if hasattr(self, "optimizer"):          fit["optimizer"]          = class_to_str(type(self.optimizer))   
            if hasattr(self, "dataset_size_train"): fit["dataset_size_train"] = self.dataset_size_train
            if hasattr(self, "dataset_size_valid"): fit["dataset_size_valid"] = self.dataset_size_valid
            config["fit"] = fit 
            
        else:
            config = params_config
            
        self.config = config        
        return config
    
    # Hook for subclasses: return the parameters to serialize; the base returns None.
    @virtual
    def params_config(self, save_path: str): return None
    
    def store_pipeline(self, config_path: str, save_path: str): 
        """Create the output directories and save the loss histories recorded by training.

        Args:
            config_path: directory created when `exists(config_path)` holds
                (apparently a not-None check — confirm); nothing is written to it here.
            save_path: directory created when `exists(save_path)` holds; receives
                `fit_losses.txt` and `fit_valid_losses.txt` if the corresponding
                attributes are present.
        """
        if exists(config_path): os.makedirs(config_path, exist_ok=True)       
        if exists(save_path): 
            os.makedirs(save_path, exist_ok=True)
            # Join instead of concatenating: `save_path` was just created as a directory,
            # so the files must land inside it even without a trailing path separator.
            if hasattr(self, "fit_losses"):       np.savetxt(os.path.join(save_path, "fit_losses.txt"), self.fit_losses)            
            if hasattr(self, "fit_valid_losses"): np.savetxt(os.path.join(save_path, "fit_valid_losses.txt"), self.fit_valid_losses) 
                             
    # Hook for subclasses (counterpart of `get_config`/`store_pipeline` by its name);
    # the base implementation returns None.
    @virtual
    @staticmethod
    def from_config_file(config_path, device: torch.device, save_path: str=None): return None 

In [None]:
#| export
class Pipeline(Pipeline_IO):
    """A `Pipeline_IO` class providing basic pytorch model training functionality.

    Typical use: construct with a model and a device, call `compile` to set up the
    loss, optimizer and metrics, then `fit` to train. Subclasses implement
    `__call__` and `train_step`; `train_step` must return the (differentiable)
    loss of one batch, which `train_on_batch` backpropagates.
    """
    def __init__(self, 
                 model: nn.Module,
                 device: torch.device):
        self.model  = model   # network whose parameters the optimizer updates
        self.device = device  # passed to the loss metrics created in `compile`
      
    #------------------------------------
             
    # Subclass hook.
    @virtual
    def __call__(self, inp): pass
        
    # Subclass hook: compute and return the loss of one batch `data`.
    @virtual
    def train_step(self, data, train=True, **kwargs): pass

    #------------------------------------
     
    def compile(self, optim_fn: Callable[..., torch.optim.Optimizer], loss_fn: Callable[[], nn.Loss], metrics: Union[Metric, list[Metric]]=None, lr=None, **kwargs):       
        """Configure loss, optimizer and metrics; must be called before `fit`.

        Args:
            optim_fn: optimizer factory, called as `optim_fn(params, lr, **kwargs)`
                (e.g. a `torch.optim.Optimizer` subclass).
            loss_fn: zero-argument factory (e.g. a loss class); its result is `self.loss_fn`.
            metrics: a single `Metric` or a list of them, stored by their `name`.
                "loss" and "loss_valid" `Mean` metrics are always added.
            lr: if truthy, the optimizer is created now; otherwise creation is
                deferred until `fit` is given a learning rate.
            **kwargs: extra optimizer arguments. They are remembered and reused
                whenever the optimizer is (re)created later.
        """
        self.loss_fn       = loss_fn()
        self.optim_fn      = optim_fn 
        self._optim_kwargs = dict(kwargs)
        self.optimizer     = optim_fn(self.model.parameters(), lr=lr, **kwargs) if lr else None
        
        # Accept a lone Metric as well as a list of them, as the annotation promises.
        if metrics is None:                metrics = []
        elif isinstance(metrics, Metric):  metrics = [metrics]
        metrics = {m.name: m for m in metrics}
        metrics["loss"]       = Mean("loss", self.device)  
        metrics["loss_valid"] = Mean("loss_valid", self.device)  
                          
        self.metrics = metrics    
                          
    def _reset_opt(self, lr, **kwargs):
        """(Re)create the optimizer, merging the kwargs given to `compile` with `kwargs`."""
        opt_kwargs = {**getattr(self, "_optim_kwargs", {}), **kwargs}
        self.optimizer = self.optim_fn(self.model.parameters(), lr, **opt_kwargs)   
    
    def _set_opt_param(self, lr, **kwargs):
        '''Apply a learning rate (plus optional param-group options) without resetting an existing optimizer.

        With a truthy `lr`, every param group of the existing optimizer gets `lr`
        and `kwargs`; if there is no optimizer yet, one is created. Afterwards
        `self.lr` holds the learning rate actually in effect (None if there is
        still no optimizer).
        '''
        if lr:  
            if self.optimizer:           
                for g in self.optimizer.param_groups: 
                    g['lr'] = lr
                    for k, v in kwargs.items(): g[k] = v                  
            else: self._reset_opt(lr, **kwargs)
        
        # Without a new `lr`, an optimizer built in `compile` keeps its own rate; record
        # that instead of None so the saved config reflects reality.
        if lr:               self.lr = lr
        elif self.optimizer: self.lr = self.optimizer.param_groups[0]['lr']
        else:                self.lr = None
         
    #------------------------------------

    def train_on_batch(self, data, train=True):
        """Run `train_step` on one batch; if `train`, also backpropagate and step the optimizer.

        Returns:
            The batch loss, detached from the autograd graph.
        """
        loss = self.train_step(data, train=train)   
          
        if train:            
            self.optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            self.optimizer.step()
    
        return loss.detach()
    
    def train_on_epoch(self, data_loader: DataLoader, train=True):   
        """Run one pass over `data_loader`, in training (`train=True`) or validation mode.

        Updates the "loss" (training) or "loss_valid" (validation) metric. In
        training mode it also appends each batch loss to `self.fit_losses` and
        steps the learning-rate scheduler once per batch. Relies on `fit` having
        set `self.epoch`, `self.fit_losses` and `self.lr_sched`.
        """
        self.model.train(train)
        
        mode = "" if train else "_valid"
        
        # NOTE(review): validation batches are not run under `torch.no_grad()`;
        # presumably `train_step` takes care of that when `train=False` — confirm.
        with self.progress_bar(total=len(data_loader), epoch=self.epoch, unit=" batch") as batch_prgb:                   
            for data in data_loader:    
                loss = self.train_on_batch(data, train=train)                                                                       
                self.metrics["loss"+mode].update_state(loss)            
                
                if train:
                    self.fit_losses.append(loss.item())                
                    if self.lr_sched: self.lr_sched.step()
                
                # Snapshot all non-empty metrics for the progress bar (also read by `fit`).
                self.out_metric_dict = {m.name: m.result().tolist() for m in self.metrics.values() if not m.empty}               
                self.end_batch_metrics(batch_prgb, **self.out_metric_dict)   
        
    def fit(self, num_epochs: int, data_loaders: DataLoaders, lr: float=None, lr_sched=None, log_summary=True):
        """Train for `num_epochs` epochs: a training pass, then a validation pass if available.

        Args:
            num_epochs: number of epochs.
            data_loaders: provides `.train` and optionally `.valid` loaders.
            lr: if truthy, sets the learning rate (creating the optimizer if `compile` did not).
            lr_sched: optional factory called as `lr_sched(optimizer)`; the resulting
                scheduler is stepped after every training batch.
            log_summary: plot the loss curves on a log scale at the end.

        Raises:
            RuntimeError: if `compile` was not called, or if no learning rate was ever
                given, so no optimizer exists.
        """
        if not hasattr(self, "loss_fn"): raise RuntimeError("'compile' has to be called first")       
       
        self._set_opt_param(lr=lr)    
        # Fail up front with a clear message instead of an AttributeError on None at the first batch.
        if self.optimizer is None:
            raise RuntimeError("No optimizer available: pass `lr` to 'compile' or 'fit'")
        self.lr_sched = lr_sched(self.optimizer) if lr_sched else None
            
        self.num_epochs = num_epochs
        self.epochs     = range(num_epochs)
        self.fit_losses = []
        self.fit_valid_losses = []                 
        self.batch_size = data_loaders.train.batch_size 
        # NOTE(review): `len(DataLoader)` counts batches, not samples, despite the attribute name.
        self.dataset_size_train = len(data_loaders.train)
        if data_loaders.valid: self.dataset_size_valid = len(data_loaders.valid)
            
        with self.progress_bar(total=num_epochs, desc="Fit", unit=" epoch") as epoch_prgb:       
            for self.epoch in self.epochs:   
                for m in self.metrics.values(): m.reset_state()       
                
                self.train_on_epoch(data_loaders.train, train=True) 
          
                if data_loaders.valid: 
                    self.train_on_epoch(data_loaders.valid, train=False) 
                    # x-coordinate = training batches seen so far, matching the `fit_losses` index.
                    self.fit_valid_losses.append([(self.epoch+1)*len(data_loaders.train), 
                                                  self.out_metric_dict["loss_valid"]]) 
                    
                self.end_epoch_metrics(epoch_prgb, **self.out_metric_dict)

        self.fit_summary(log_summary=log_summary)
                
    #------------------------------------
                       
    def summary(self):
        """Print the number of model parameters."""
        print("Number of model parameters:", number_of_paramters(self.model))

    def fit_summary(self, figsize=(12,2), log_summary=True, return_fig=False):
        """Plot the per-batch training loss and the per-epoch validation loss.

        Validation points are placed at the cumulative training-batch count.
        Returns the figure if `return_fig`, otherwise shows it.
        """
        fig = plt.figure(figsize=figsize, constrained_layout=True)                
        plt.xlabel("Batches")
        plt.ylabel("Loss")
        if log_summary: plt.yscale('log') 
        plt.plot(self.fit_losses, label="train")
        if self.fit_valid_losses: 
            data = np.array(self.fit_valid_losses)
            plt.plot(data[:,0], data[:,1], label="valid", color="tab:orange")
            plt.plot(data[:,0], data[:,1], ".", color="tab:orange")
        plt.legend()
        if return_fig: return fig
        plt.show()
             
    #------------------------------------
        
    def progress_bar(self, iterable=None, total=None, epoch: int=None, **progress_bar_config): 
        """Create a `tqdm` bar from the defaults in `self._progress_bar_config` plus per-call options.

        Args:
            iterable: wrap this iterable, if given.
            total: otherwise, create a manual bar with this many steps (stored as `self._n_total`).
            epoch: if given, sets the bar description to "Epoch <epoch>".
            **progress_bar_config: extra `tqdm` options; these override everything else.

        Raises:
            ValueError: if `self._progress_bar_config` is not a dict, or neither
                `iterable` nor `total` is given.
        """
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}.")
        
        # Work on a copy: per-call options (the "Epoch N" desc, the unit, ...) must not
        # leak into the persistent defaults and reappear on unrelated later bars.
        prgb_conf = dict(self._progress_bar_config)
        if epoch is not None: prgb_conf |= {"desc": f"Epoch {epoch}"} 
        prgb_conf |= progress_bar_config  

        if iterable is not None: return tqdm(iterable, **prgb_conf)
        if total is not None:
            self._n_total = total
            return tqdm(total=total, **prgb_conf)
        raise ValueError("Either `total` or `iterable` has to be defined.")
        
    def end_progress_bar_iteration(self, prgb:tqdm, print_lines=False, name="", index=None, **metrics):
        """Advance `prgb` by one step, show `metrics` as its postfix and optionally write a log line.

        Args:
            prgb: the bar to update.
            print_lines: if True, also write "<name> <index>/<total>: <metrics>" via the bar.
            name: label of the written line (e.g. "Epoch" or "Batch").
            index: number to write; defaults to the bar's count after this update.
            **metrics: values shown as the bar postfix (and written when `print_lines`).
        """
        # `**metrics` is always a dict (possibly empty), so no None-check is needed.
        prgb.set_postfix(**metrics)        
        prgb.update()
        
        if not print_lines: return
        
        # Prefer the bar's own total: `self._n_total` only remembers the most recently created bar.
        total = getattr(prgb, "total", None)
        if total is None: total = getattr(self, "_n_total", None)
        n_total = f"/{total}: " if total is not None else ": "

        if index is not None: prgb.write(f"{name} {index:03}" + n_total, end='') 
        else: prgb.write(f"{name} {(prgb.n):03}" + n_total, end='')  # prgb.n is already updated, so it is the finished step's 1-based count

        prgb.write(str(metrics) if metrics else "")
                            
    def end_epoch_metrics(self, prgb:tqdm, epoch: int=None, **metrics):
        """Finish one epoch step on `prgb` (no line printing)."""
        self.end_progress_bar_iteration(prgb, False, "Epoch", epoch, **metrics)             

    def end_batch_metrics(self, prgb:tqdm, batch: int=None, **metrics):
        """Finish one batch step on `prgb` (no line printing)."""
        self.end_progress_bar_iteration(prgb, False, "Batch", batch, **metrics)
    
    #------------------------------------
        

# Export -

In [None]:
#| hide
# Run nbdev's export step, which regenerates the library module(s) from `#| export`
# cells (this notebook targets `pipeline.pipeline`, per its `#| default_exp`).
import nbdev; nbdev.nbdev_export()