In [None]:
#default_exp learner

In [None]:
#export
from mantisshrimp.imports import *
from mantisshrimp.core import *
from mantisshrimp.models import *

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Learner" data-toc-modified-id="Learner-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Learner</a></span><ul class="toc-item"><li><span><a href="#Core" data-toc-modified-id="Core-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Core</a></span></li><li><span><a href="#Visualize" data-toc-modified-id="Visualize-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Visualize</a></span></li></ul></li><li><span><a href="#Export" data-toc-modified-id="Export-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Export</a></span></li></ul></div>

# Learner
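> Defines `Learner`, which bundles a model, its dataloaders and an optimizer factory, and drives training, one-cycle training and learning-rate finding through a `Trainer`.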

## Core

In [None]:
#export
class Learner:
    """Bundle a model, its dataloaders and an optimizer factory, and train it with a `Trainer`.

    Args:
        m: the model. It must provide `prepare_optimizers(opt_fn, lr, sched_fn=...)`
            and `get_lrs(lr)`, which are the methods called below.
        train_dl: training dataloader. `len(train_dl)` is used by `fit_one_cycle`
            as the number of optimizer steps per epoch.
        valid_dl: validation dataloader.
        opt_fn: optimizer factory forwarded to `m.prepare_optimizers`.
        logger: forwarded to `Trainer(logger=...)`. `None` (the default) becomes
            `True`. An explicit `False` is kept, so logging can be disabled.
    """
    def __init__(self, m, train_dl, valid_dl, opt_fn, logger=None):
        store_attr(self, 'm,train_dl,valid_dl,opt_fn')
        # Only `None` means "use the default". The previous `logger or True`
        # also replaced an explicit `False`, making logging impossible to turn off.
        self.logger = True if logger is None else logger
        self.gpus = get_all_available_gpus()

    @delegates(Trainer.__init__)
    def fit(self, max_epochs, lr, lr_sched_fn=None, gpus=None, callbacks=None, **kwargs):
        """Train `self.m` for `max_epochs` epochs.

        Args:
            max_epochs: number of epochs passed to the `Trainer`.
            lr: learning rate forwarded to `self.m.prepare_optimizers`.
            lr_sched_fn: optional scheduler factory forwarded as `sched_fn`.
            gpus: GPUs to train on. Defaults to every GPU found at construction time.
            callbacks: extra callback(s), appended after the LR logger (if any).
            **kwargs: any other `Trainer.__init__` argument.
        """
        self.m.prepare_optimizers(self.opt_fn, lr, sched_fn=lr_sched_fn)
        gpus = ifnone(gpus, self.gpus)
        cbs = L(callbacks)
        # LearningRateLogger writes through the trainer's logger, so only add it
        # when logging has not been disabled with `logger=False`.
        if self.logger is not False: cbs = L(LearningRateLogger()) + cbs
        trainer = Trainer(max_epochs=max_epochs, logger=self.logger, callbacks=cbs, gpus=gpus, **kwargs)
        trainer.fit(self.m, self.train_dl, self.valid_dl)

    @delegates(Trainer.__init__)
    def fit_one_cycle(self, max_epochs, lr_max, pct_start=.25, **kwargs):
        """Train with a per-step `OneCycleLR` schedule peaking at `lr_max`.

        Args:
            max_epochs: number of epochs. Together with `len(self.train_dl)`, it sets
                the total number of scheduler steps.
            lr_max: peak learning rate. It is expanded with `self.m.get_lrs`, so models
                with several parameter groups each get their own peak.
            pct_start: fraction of the schedule spent increasing the learning rate.
            **kwargs: forwarded to `fit` (and from there to `Trainer`).
        """
        def lr_sched_fn(opt):
            lrs = self.m.get_lrs(lr_max)
            # NOTE(review): assumes one optimizer step per batch; options such as
            # gradient accumulation would change the real step count.
            sched = OneCycleLR(opt, lrs, len(self.train_dl)*max_epochs, pct_start=pct_start)
            # 'interval': 'step' advances the scheduler every batch rather than every epoch.
            return {'scheduler':sched, 'interval':'step'}
        return self.fit(max_epochs=max_epochs, lr=lr_max, lr_sched_fn=lr_sched_fn, **kwargs)

    @delegates(Trainer.__init__)
    def lr_find(self, gpus=None, **kwargs):
        """Run the `Trainer`'s learning-rate finder on `self.m` and return its result.

        Args:
            gpus: GPUs to use. Defaults to every GPU found at construction time.
            **kwargs: any other `Trainer.__init__` argument.
        """
        # Set up optimizers through the model, exactly as `fit` does. The lr is 0
        # and there is no scheduler; the finder presumably drives the lr itself.
        # TODO confirm against the installed Lightning version.
        self.m.prepare_optimizers(self.opt_fn, 0, sched_fn=None)
        gpus = ifnone(gpus, self.gpus)
        return Trainer(gpus=gpus, **kwargs).lr_find(self.m, self.train_dl, self.valid_dl)

## Visualize

In [None]:
#export
@patch
def show_results(self:Learner, k=5):
    "Show the model's predictions on `k` distinct, randomly chosen validation records."
    records = self.valid_dl.dataset.records
    # Sample indices without replacement so no record is shown twice, and cap `k`
    # at the dataset size. Working on indices keeps any `len`/indexable container valid.
    idxs = random.sample(range(len(records)), k=min(k, len(records)))
    rs = [records[i] for i in idxs]
    show_preds(*self.m.predict(rs=rs))

# Export

In [None]:
# Convert every notebook in the project into library modules
# (one "Converted ..." line per notebook is printed below).
from nbdev.export import notebook2script
notebook2script()

Converted 00a_core.ipynb.
Converted 00b_lightning_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.core.ipynb.
Converted 04_data.annotations.ipynb.
Converted 06_data.load.ipynb.
Converted 07_transforms.ipynb.
Converted 08_models.ipynb.
Converted 09_learner.ipynb.
Converted 11_metrics.core.ipynb.
Converted Untitled.ipynb.
Converted Untitled1.ipynb.
Converted data_refactor.ipynb.
Converted index.ipynb.
