In [None]:
# default_exp model_wrapper

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

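# Model wrapper

> `ModelWrapper`: wraps a PyTorch model with a callback-driven training loop, prediction helpers, and save/load of the model weights.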

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from pytorch_inferno.callback import AbsCallback, PredHandler
from pytorch_inferno.utils import to_device, device
from pytorch_inferno.data import DataPair, WeightedDataLoader, DataSet

from typing import Optional, Union, List, Generator, Callable
from fastcore.all import store_attr, is_listy, typedispatch, Path
from fastprogress import master_bar, progress_bar
import numpy as np

from torch.tensor import Tensor
import torch
import torch.nn as nn
from torch import optim

In [None]:
# export
class ModelWrapper():
    r'''
    Wraps a PyTorch model with a callback-driven training loop, prediction helpers, and save/load of the weights.

    Callbacks receive this wrapper via `set_wrapper` and can read or modify the attributes it sets while running,
    e.g. `x`, `y`, `w`, `y_pred`, `loss_val`, `state` ('train', 'valid', or 'test'), `stop`, `opt`, and `data`.

    Arguments:
        model: network to wrap; it is moved to `device` on construction
        device: device onto which the model and every batch are moved
    '''
    def __init__(self, model:nn.Module, device:torch.device=device):
        self.model,self.device = to_device(model, device),device

    @staticmethod
    def _to_cb_list(cbs:Optional[Union[AbsCallback,List[AbsCallback]]]) -> List[AbsCallback]:
        r'''Return `cbs` as a *new* list (empty for None, wrapped if a single callback) so appends never touch the caller's list'''
        if cbs is None: return []
        # list() also materialises generators, which would otherwise be exhausted after the first hook loop
        return list(cbs) if is_listy(cbs) else [cbs]

    def _fit_batch(self, x:Tensor, y:Tensor, w:Tensor) -> None:
        r'''
        Process one batch: move it to `self.device`, run the forward pass, compute the loss (except in 'test' state),
        and, only in 'train' state, back-propagate and step the optimiser. Callback hooks fire around each stage.

        Arguments:
            x: inputs
            y: targets
            w: per-sample weights, handed to the loss via its `weights` attribute
        '''
        self.x,self.y,self.w = to_device(x,self.device),to_device(y,self.device),to_device(w,self.device)
        for c in self.cbs: c.on_batch_begin()
        self.y_pred = self.model(self.x)
        if self.state != 'test' and self.loss_func is not None:
            # NOTE(review): only losses that read a `weights` attribute use this; stock torch.nn losses
            # (e.g. BCELoss) store their weighting as `weight` and will silently ignore it
            self.loss_func.weights = self.w
            self.loss_val = self.loss_func(self.y_pred, self.y)
        for c in self.cbs: c.on_forwards_end()
        # Validation and prediction stop after the forward pass, so `on_batch_end` only fires while training
        if self.state != 'train': return

        self.opt.zero_grad()
        for c in self.cbs: c.on_backwards_begin()
        self.loss_val.backward()
        for c in self.cbs: c.on_backwards_end()
        self.opt.step()
        for c in self.cbs: c.on_batch_end()

    def fit(self, n_epochs:int, data:DataPair, opt:Callable[[Generator],optim.Optimizer],
            loss:Optional[Callable[[Tensor,Tensor],Tensor]], cbs:Optional[Union[AbsCallback,List[AbsCallback]]]=None) -> None:
        r'''
        Train the model for up to `n_epochs`, each epoch being a training pass over `data.trn_dl`
        followed by a validation pass over `data.val_dl`.

        Arguments:
            n_epochs: maximum number of epochs; a callback can end training early by setting `stop` to True
            data: `DataPair` providing the `trn_dl` and `val_dl` loaders
            opt: factory called with the model parameters to build the optimiser, e.g. `partialler(optim.SGD, lr=2e-3)`
            loss: loss called as `loss(y_pred, y)`; if None the wrapper computes no loss, so something else
                (presumably a callback) must set `loss_val` or the backward pass will fail
            cbs: a callback or list of callbacks; a list passed in is copied, not mutated
        '''
        def fit_epoch(epoch:int) -> None:
            self.model.train()
            self.state = 'train'
            self.epoch = epoch
            for c in self.cbs: c.on_epoch_begin()
            for b in progress_bar(self.data.trn_dl, parent=self.mb): self._fit_batch(*b)
            for c in self.cbs: c.on_epoch_end()

            self.model.eval()
            self.state = 'valid'
            for c in self.cbs: c.on_epoch_begin()
            for b in progress_bar(self.data.val_dl, parent=self.mb): self._fit_batch(*b)
            for c in self.cbs: c.on_epoch_end()

        self.cbs,self.stop,self.n_epochs = self._to_cb_list(cbs),False,n_epochs
        self.data,self.loss_func,self.opt = data,loss,opt(self.model.parameters())
        for c in self.cbs: c.set_wrapper(self)
        for c in self.cbs: c.on_train_begin()
        self.mb = master_bar(range(self.n_epochs))
        for e in self.mb:
            fit_epoch(e)
            if self.stop: break
        for c in self.cbs: c.on_train_end()

    def _predict_dl(self, x:WeightedDataLoader, pred_cb:Optional[PredHandler]=None,
                    cbs:Optional[Union[AbsCallback,List[AbsCallback]]]=None) -> np.ndarray:
        r'''
        Run the model over every batch of `x` in 'test' state and return the predictions gathered by `pred_cb`.

        Arguments:
            x: loader yielding batches that are unpacked into `_fit_batch`
            pred_cb: prediction collector; a fresh `PredHandler` is created when None
            cbs: extra callback(s); a list passed in is copied, not mutated

        Returns:
            the result of `pred_cb.get_preds()`
        '''
        # Build the handler per call: a `PredHandler()` default in the signature would be shared by every call
        if pred_cb is None: pred_cb = PredHandler()
        self.cbs,self.data = self._to_cb_list(cbs)+[pred_cb],x
        self.state = 'test'
        for c in self.cbs: c.set_wrapper(self)
        self.model.eval()
        for c in self.cbs: c.on_pred_begin()
        for b in progress_bar(self.data): self._fit_batch(*b)
        for c in self.cbs: c.on_pred_end()
        return pred_cb.get_preds()

    def _predict_array(self, x:Union[Tensor,np.ndarray], pred_cb:Optional[PredHandler]=None,
                       cbs:Optional[Union[AbsCallback,List[AbsCallback]]]=None) -> np.ndarray:
        r'''Predict on an in-memory array/tensor by wrapping it in a single-batch `WeightedDataLoader`'''
        return self._predict_dl(WeightedDataLoader(DataSet(x), batch_size=len(x)), pred_cb, cbs)

    def predict(self, x:Union[WeightedDataLoader,Tensor,np.ndarray], pred_cb:Optional[PredHandler]=None,
                cbs:Optional[Union[AbsCallback,List[AbsCallback]]]=None) -> np.ndarray:
        r'''
        Compute model predictions for `x`.

        Arguments:
            x: a `WeightedDataLoader` (iterated batch by batch) or an array/tensor of inputs (run as one batch)
            pred_cb: prediction collector; a fresh `PredHandler` is created when None
            cbs: extra callback(s) to run during prediction; a list passed in is not mutated

        Returns:
            the predictions returned by the prediction handler
        '''
        if isinstance(x, WeightedDataLoader): return self._predict_dl(x, pred_cb, cbs)
        return self._predict_array(x, pred_cb, cbs)

    def save(self, fname:Union[Path,str]) -> None:
        r'''Save the model's `state_dict` to `fname` (under the key 'model')'''
        torch.save({'model':self.model.state_dict()}, fname)

    def load(self, fname:Union[Path,str]) -> None:
        r'''Load weights written by `save` from `fname`, then place the model on this wrapper's device'''
        state = torch.load(fname, map_location='cpu')
        self.model.load_state_dict(state['model'])
        # Use the wrapper's own device (not the module default) so the model stays alongside the batches in `_fit_batch`
        self.model = to_device(self.model, self.device)

# Testing

In [None]:
from pytorch_inferno.callback import LossTracker, EarlyStopping
from pytorch_inferno.data import get_paper_data

from fastcore.all import partialler

In [None]:
# Seed the RNGs so the weight initialisation (and, presumably, the generated data) are reproducible on Restart & Run All
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

n = 1000  # number of training and of test samples
model = ModelWrapper(nn.Sequential(nn.Linear(3,50),nn.ReLU(),nn.Linear(50,1),nn.Sigmoid()))
data, test = get_paper_data(n, bm=0, bs=64, n_test=n)

In [None]:
# Train for at most 10 epochs with plain SGD on a BCE loss, tracking losses; EarlyStopping(5) may end training sooner
model.fit(
    n_epochs=10,
    data=data,
    opt=partialler(optim.SGD, lr=2e-3),
    loss=nn.BCELoss(),
    cbs=[LossTracker(), EarlyStopping(5)],
)

Train: 0.7777950843175252 Valid: 0.7710541779994965
Train: 0.739490024248759 Valid: 0.7274892117977142
Train: 0.6959971030553181 Valid: 0.6971303887367248
Train: 0.6743967692057292 Valid: 0.6751605319976807
Train: 0.6553388277689616 Valid: 0.6592334063053131
Train: 0.641307270526886 Valid: 0.6469841578006744
Train: 0.6313625574111938 Valid: 0.6365882155895233
Train: 0.625999140739441 Valid: 0.627690737247467
Train: 0.6166404604911804 Valid: 0.6199122014045715
Train: 0.6116449395815532 Valid: 0.6127951893806457


In [None]:
preds = model.predict(x=test)  # predict straight from the test loader (batch-by-batch if it is a WeightedDataLoader)

In [None]:
# One sigmoid output per test sample, each a valid probability (this also rejects NaNs)
assert preds.shape == (n, 1), f"expected ({n}, 1) predictions, got {preds.shape}"
assert ((preds >= 0) & (preds <= 1)).all(), "sigmoid outputs should lie in [0, 1]"

In [None]:
np.shape(preds)

(1000, 1)

In [None]:
preds = model.predict(x=test.dataset.x)  # predict from the raw test inputs, exercising the array/tensor code path

In [None]:
# The array path must give the same shape and valid probabilities as the loader path
assert preds.shape == (n, 1), f"expected ({n}, 1) predictions, got {preds.shape}"
assert ((preds >= 0) & (preds <= 1)).all(), "sigmoid outputs should lie in [0, 1]"

In [None]:
np.shape(preds)

(1000, 1)