In [None]:
# default_exp train
%load_ext autoreload
%autoreload 2

# train


> Helpers for building `timm`-based models and learners, plus training callbacks.

In [None]:
#hide
from nbdev.showdoc import *

![pipeline](images/graphic10.PNG)

In [None]:
#export
from fastai.vision.all import *
from timm import create_model
from fastai.vision.learner import _update_first_layer

### Using timm models

**>> Since `fastai` 2.6.0, `timm` models are integrated into `fastai` itself, so the helpers in this section are no longer needed.**


The following code was originally published in this [notebook](https://github.com/muellerzr/Practical-Deep-Learning-for-Coders-2.0/blob/master/Computer%20Vision/05_EfficientNet_and_Custom_Weights.ipynb).

In [None]:
#export
def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    """Creates a body from any model in the `timm` library.

    Args:
        arch: name of a `timm` architecture, passed to `timm.create_model`.
        pretrained: whether to load pretrained weights.
        cut: how to cut the model into a body.
            - An `int` keeps `list(model.children())[:cut]`.
            - A callable is called with the full model, and its result is returned.
            - `None` cuts at the last top-level child that contains a pooling layer,
              as detected by fastai's `has_pool_type`.
        n_in: number of input channels; the first layer is adapted by `_update_first_layer`.

    Returns:
        An `nn.Sequential` of the kept children when `cut` is an int, otherwise `cut(model)`.

    Raises:
        ValueError: `cut` is None and no top-level child contains a pooling layer.
        TypeError: `cut` is neither an int nor a callable.
    """
    # Ask timm for the model without its classifier (`num_classes=0`) and global pooling (`global_pool=''`).
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    children = list(model.children())
    if cut is None:
        # Search from the end so the cut lands on the *last* child that holds a pooling layer.
        cut = next((i for i in reversed(range(len(children))) if has_pool_type(children[i])), None)
        if cut is None:
            # A bare `next` would leak an opaque StopIteration here.
            raise ValueError(f"could not find a pooling layer in `{arch}` to cut at; pass `cut` explicitly")
    if isinstance(cut, int): return nn.Sequential(*children[:cut])
    elif callable(cut): return cut(model)
    # The original raised the undefined `NamedError`; a wrong argument type is a TypeError.
    else: raise TypeError("cut must be either integer or function")

In [None]:
#export
def create_timm_model(arch:str, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,
                     concat_pool=True, **kwargs):
    "Create custom architecture using `arch`, `n_in` and `n_out` from the `timm` library"
    body = create_timm_body(arch, pretrained, None, n_in)
    if custom_head is None:
        nf = num_features_model(nn.Sequential(*body.children()))
        head = create_head(nf, n_out, concat_pool=concat_pool, **kwargs)
    else: head = custom_head
    model = nn.Sequential(body, head)
    if init is not None: apply_init(model[1], init)
    return model

In [None]:
#export
def timm_learner(dls, arch:str, loss_func=None, pretrained=True, cut=None, splitter=None,
                y_range=None, config=None, n_out=None, normalize=True, **kwargs):
    "Build a convnet style learner from `dls` and `arch` using the `timm` library"
    if config is None: config = {}
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    if y_range is None and 'y_range' in config: y_range = config.pop('y_range')
    model = create_timm_model(arch, n_out, default_split, pretrained, y_range=y_range, **config)
    learn = Learner(dls, model, loss_func=loss_func, splitter=default_split, **kwargs)
    if pretrained: learn.freeze()
    return learn

### Callbacks

In [None]:
#export
class EpochIteration(Callback):
    "Optionally show each batch's input images, titled with phase, epoch, iteration and loss."
    def __init__(self, show_img=False):
        "`show_img`: when truthy, display `self.learn.xb[0]` with `show_images` before every batch."
        self.show_img = show_img
    def before_batch(self):
        # Nothing is displayed when `show_img` is off. Return early so we don't format
        # `self.loss` (presumably a tensor, possibly on GPU) on every batch for nothing.
        if not self.show_img: return
        if self.training:
            b = f'Training: Epoch: {self.epoch} Iter: {self.iter} Loss:{self.loss}'
        else:
            b = f'Validation: Epoch: {self.epoch} Iter: {self.iter} Loss:{self.loss}'
        # NOTE(review): this runs before the current batch's forward pass, so `self.loss` is
        # presumably the previous batch's loss — confirm against fastai's event order.
        show_images(self.learn.xb[0], suptitle=b)

In [None]:
#hide
# Regenerate the library modules from the `#export` cells of every project notebook (nbdev);
# prints one `Converted ...` line per notebook.
from nbdev.export import notebook2script
notebook2script()

Converted 02_explore.ipynb.
Converted 03_preprocessing.ipynb.
Converted 04_pipeline.ipynb.
Converted 05_train.ipynb.
Converted 06_examine.ipynb.
Converted 10_wearable.ipynb.
Converted 20_retinanet.ipynb.
Converted 90_tutorial.ipynb.
Converted index.ipynb.
