In [1]:
#export
import k1lib, torch.nn as nn, torch, dill

In [2]:
#export
class Learner:
    """Container for one training setup: ``model``, ``data``, optimizer (``opt``),
    loss function (``lossF``) and a ``k1lib.Callbacks`` collection (``cbs``).

    Attribute lookups that fail on the learner itself are forwarded to ``cbs``,
    which is how ``l.Loss`` (see the usage text in ``__repr__``) reaches a
    specific callback.
    """
    def __init__(self, *args, **kwargs):
        # *args/**kwargs are accepted but not used here.
        self.model = None; self.data = None; self.opt = None
        self.lossF = None; self._cbs = None; self.fileName = None
        # Assigning through the `cbs` setter also back-links the callbacks to this learner.
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()
    @property
    def cbs(self):
        """The ``k1lib.Callbacks`` collection driving this learner."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs):
        cbs.learner = self; self._cbs = cbs
    def __getattr__(self, attr):
        """Forward attributes not found on the learner to ``self.cbs``.

        ``cbs`` and ``_cbs`` are never forwarded. Otherwise, on an instance whose
        ``_cbs`` was never set (e.g. one created via ``Learner.__new__`` before
        ``__setstate__`` runs), looking up ``_cbs`` would re-enter this method
        through the ``cbs`` property forever (RecursionError) instead of raising
        AttributeError.
        """
        if attr in ("cbs", "_cbs"): raise AttributeError(attr)
        return getattr(self.cbs, attr)
    def __getstate__(self):
        """Pickle a shallow copy of the instance dict (callbacks included)."""
        return dict(self.__dict__)
    def __setstate__(self, state):
        """Restore attributes and re-point the callbacks at this new instance."""
        self.__dict__.update(state)
        self.cbs.learner = self
    def evaluate(self):
        """Does nothing here; supposed to be overridden to provide functionality.

        NOTE(review): later in this file ``evaluate`` is also patched onto Learner
        with a version that raises NotImplementedError — keep the two consistent.
        """
        pass
    @property
    def warnings(self):
        """Setup problems as one string: one line per unset ``model`` / ``lossF`` /
        ``data`` / ``opt``, followed by two extra newlines; ``""`` when all are set.

        Uses ``is None`` rather than ``== None`` so objects with elementwise or
        custom ``__eq__`` (e.g. numpy arrays) don't break the check.
        """
        required = (("model", "model"), ("lossF", "loss function"),
                    ("data", "data"), ("opt", "optimizer"))
        warnings = "".join(f"Warning: no {label} yet. Set using `l.{attr} = ...`\n"
                           for attr, label in required if getattr(self, attr) is None)
        if warnings != "": warnings += "\n\n"
        return warnings
    def __repr__(self):
        """Setup warnings (if any), the model, optimizer and callbacks, plus usage hints."""
        return f"""{self.warnings}l.model:\n{k1lib.tab(str(self.model))}
l.opt:\n{k1lib.tab(str(self.opt))}
l.cbs:\n{k1lib.tab(self.cbs.__repr__())}
Use...
- l.configure(...): to change any configs (data, model, opt, lossF)
- l.withCbs(...): to use a custom `Callbacks` object
- l.run(epochs): to run the network
- l.Loss: to get a specific callback\n\n"""

In [3]:
#export
@k1lib.patch(Learner)
def save(self, fileName=None):
    """Pickle the whole learner (using ``dill``) to a file and remember its name.

    :param fileName: target path. If falsy, reuses ``self.fileName`` from a previous
        save; if that is also None, picks the first ``l-<n>.pth`` name (n = 0, 1, ...)
        not already present in the current directory.
    """
    import os, re  # local: the module-level imports of this file don't include them
    self.fileName = fileName or self.fileName
    if self.fileName is None:
        # Scan with the same pattern used for the generated name below, so existing
        # auto-saves are not overwritten; unrelated *.pth files are simply ignored.
        pattern = re.compile(r"l-(\d+)\.pth")
        taken = {int(m.group(1)) for m in map(pattern.fullmatch, os.listdir()) if m}
        count = 0
        while count in taken: count += 1
        self.fileName = f"l-{count}.pth"
    torch.save(self, self.fileName, pickle_module=dill)
    print(f"Saved to {self.fileName}")
@k1lib.patch(Learner)
def load(fileName=None):
    """Load a learner previously written by ``save``.

    Takes no ``self``, so it is meant to be called on the class, e.g.
    ``Learner.load("l-0.pth")`` (NOTE(review): depends on how ``k1lib.patch``
    attaches functions — confirm). Prompts for a file name when none is given.

    Security: this unpickles with ``dill``, which can execute arbitrary code —
    only load files from trusted sources.
    """
    if fileName is None: fileName = input("Enter learner file name to load:")
    learner = torch.load(fileName, pickle_module=dill)
    # Report only after the load actually succeeded.
    print(f"Loaded from {fileName}"); return learner

In [4]:
#export
@k1lib.patch(Learner)
def _run1Batch(self):
    """Process the current batch (``self.xb``, ``self.yb``).

    Stages, each bracketed by callback events: forward pass (result in
    ``self.y``), loss (scalar copy in ``self.loss``), then backward, optimizer
    step and zero_grad. Each of the last three is skipped when
    ``self.cbs(<its start event>)`` returns something truthy. Raising
    ``k1lib.CancelBatchException`` aborts the remaining stages; ``endBatch``
    fires either way.

    NOTE(review): validation batches are also run through this method, so they
    are back-propagated/stepped too unless a callback skips those stages — confirm.
    """
    self.cbs("startBatch")
    try:
        self.cbs("startPass")
        self.y = self.model(self.xb)
        self.cbs("endPass")

        self.cbs("startLoss")
        lossTensor = self.lossF(self.y, self.yb)
        self.loss = lossTensor.detach().item()
        self.cbs("endLoss")

        if not self.cbs("startBackward"):
            lossTensor.backward()
        if not self.cbs("startStep"):
            self.opt.step()
        if not self.cbs("startZeroGrad"):
            self.opt.zero_grad()
    except k1lib.CancelBatchException as reason:
        self.cbs("cancelBatch")
        print(f"Batch cancelled: {reason}.")
    self.cbs("endBatch")

In [5]:
#export
@k1lib.patch(Learner)
def _run1Epoch(self):
    """Run one epoch: all of ``self.data.train`` in train mode, then all of
    ``self.data.valid`` in eval mode.

    While a loader is iterated, ``self.nBatches`` is its length and
    ``self.batch``/``self.xb``/``self.yb`` are the current index and batch.
    ``startValidBatches`` fires after switching to eval mode. Raising
    ``k1lib.CancelEpochException`` skips the rest of the epoch; ``endEpoch``
    fires either way.
    """
    self.cbs("startEpoch")
    try:
        for phase in ("train", "valid"):
            self.nBatches = len(getattr(self.data, phase))
            if phase == "train":
                self.model.train()
            else:
                self.model.eval()
                self.cbs("startValidBatches")
            # Re-read the loader so any change made by the callback above is honoured.
            for self.batch, (self.xb, self.yb) in enumerate(getattr(self.data, phase)):
                self._run1Batch()
    except k1lib.CancelEpochException as reason:
        self.cbs("cancelEpoch")
        print(f"Epoch cancelled: {reason}.")
    self.cbs("endEpoch")

In [6]:
#export
@k1lib.patch(Learner)
def run(self, epochs):
    """Run ``epochs`` epochs and return ``self`` so calls can be chained.

    :param epochs: number of epochs; stored in ``self.epochs``, and the current
        index is kept in ``self.epoch``.
    :raises Exception: with the ``warnings`` text when model/lossF/data/opt is unset.

    Raising ``k1lib.CancelRunException`` stops early; ``endRun`` fires either way.
    """
    self.epochs = epochs
    problems = self.warnings
    if problems:
        raise Exception(problems)
    self.cbs("startRun")
    try:
        for self.epoch in range(epochs):
            self._run1Epoch()
    except k1lib.CancelRunException as reason:
        self.cbs("cancelRun")
        print(f"Run cancelled: {reason}.")
    self.cbs("endRun")
    return self

In [7]:
#export
class Recorder(k1lib.Callback):
    """Callback that collects every batch's inputs, targets and model outputs.

    ``self.xb``/``self.yb``/``self.y`` are read off the callback itself;
    presumably ``k1lib.Callback`` resolves them from the owning learner — confirm.
    """
    def __init__(self):
        super().__init__()
        self.order = 20  # ordering value interpreted by k1lib.Callbacks
        self.xbs, self.ybs, self.ys = [], [], []
    def startBatch(self):
        self.xbs.append(self.xb)
        self.ybs.append(self.yb)
    def endPass(self):
        # NOTE(review): stored as-is (not detached), so any autograd graph on y stays alive.
        self.ys.append(self.y)
    @property
    def values(self):
        """``(xbs, ybs, ys)`` — the live lists, not copies."""
        return (self.xbs, self.ybs, self.ys)
@k1lib.patch(Learner)
def record(self, epochs=1, batchesPerEpoch=4):
    """Like run(), but without training, returning what the model saw and produced.

    Temporarily suspends the listed callbacks, adds a batch limiter, a don't-train
    callback and a ``Recorder`` (all registered under a ``_record_`` prefix), runs,
    then removes the ``_record_`` callbacks and restores the suspended ones. The
    cleanup also happens when ``run`` raises (e.g. no model/data set), so the
    learner is never left half-configured.

    :param epochs: number of epochs to run
    :param batchesPerEpoch: limit passed to ``withBatchLimit``
    :returns: ``(xbs, ybs, ys)`` — recorded input batches, target batches and outputs
    """
    self.cbs.suspend(["Loss", "HookParam", "HookModule", "ParamScheduler", "Autosave", "CancelOnExplosion"])
    try:
        self.cbs.withBatchLimit(batchesPerEpoch, "_record_BatchLimit")
        self.cbs.withDontTrain("_record_DontTrain")
        self.cbs.append(Recorder(), "_record_Recorder")
        self.run(epochs)
        return self.cbs._record_Recorder.values
    finally:
        # Runs after the return value above has been captured.
        self.cbs.removePrefix("_record_").restore()
@k1lib.patch(Learner)
def evaluate(self):
    """Placeholder for task-specific evaluation.

    :raises NotImplementedError: always; provide a real implementation elsewhere.
    """
    raise NotImplementedError()

In [8]:
# Rebuild the package: export this notebook as k1lib/_learner.py (presumably from the
# `#export` cells) and reinstall k1lib — see the install log below.
!../export.py _learner

Current dir: /home/kelvin/repos/labs/k1lib, ../export.py
rm: cannot remove '__pycache__': No such file or directory
Found existing installation: k1lib 0.1.0
Uninstalling k1lib-0.1.0:
  Successfully uninstalled k1lib-0.1.0
running install
running bdist_egg
running egg_info
creating k1lib.egg-info
writing k1lib.egg-info/PKG-INFO
writing dependency_links to k1lib.egg-info/dependency_links.txt
writing top-level names to k1lib.egg-info/top_level.txt
writing manifest file 'k1lib.egg-info/SOURCES.txt'
reading manifest file 'k1lib.egg-info/SOURCES.txt'
writing manifest file 'k1lib.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib
creating build/lib/k1lib
copying k1lib/_learner.py -> build/lib/k1lib
copying k1lib/callbacks.py -> build/lib/k1lib
copying k1lib/data.py -> build/lib/k1lib
copying k1lib/imports.py -> build/lib/k1lib
copying k1lib/_basics.py -> build/lib/k1lib
copying k1lib/nn.py -> bui

In [33]:
import torch

In [28]:
# Callable that applies x[None], i.e. prepends a length-1 axis.
# NOTE(review): `Lambda` is not defined or imported in any visible cell — this relies on
# hidden kernel state and fails on Restart & Run All; import it explicitly (TODO confirm source).
l = Lambda(lambda x: x[None])

In [32]:
# A (2, 3) input comes back as (1, 2, 3) because of the x[None] above.
# NOTE(review): no RNG seed is set, so the displayed values differ on every run.
l(torch.rand(2, 3))

tensor([[[0.0042, 0.3004, 0.6085],
         [0.9042, 0.8013, 0.1034]]])

In [10]:
# Demo: dict.update overwrites existing keys ('b') and adds new ones ('c').
a = {"a": 3, "b": 5}
a.update(b=8, c=11)

In [11]:
# Show the merged dict: 'b' was overwritten, 'c' appended after the original keys.
a

{'a': 3, 'b': 8, 'c': 11}