In [1]:
# Make this notebook's namespace look like the module k1lib.callbacks, so the
# relative import `from .callbacks import ...` in the next cell can resolve
__name__ = "k1lib.callbacks"

In [2]:
#export
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, time
import matplotlib.pyplot as plt
from typing import Callable
try: import torch; hasTorch = True
except: hasTorch = False
__all__ = ["Landscape"]

In [2]:
#export
spacing = 0.35 # orders of magnitude between consecutive scales
offset = -2 # orders of magnitude shift: the smallest scale is 10**offset
res = 20 # resolution: grid points per axis of each landscape
# 8 log-spaced half-widths 10**(offset + k*spacing), k = 0..7, rounded to 3 decimals
scales = [round(s, 3) for s in 10 ** (offset + spacing * np.arange(8))]
scales

[0.01, 0.022, 0.05, 0.112, 0.251, 0.562, 1.259, 2.818]

In [3]:
#export
F = Callable[["k1lib.Learner"], float]
@k1lib.patch(Cbs)
class Landscape(Callback):
    " "
    def __init__(self, propertyF:F, name:str=None):
        """Plots the landscape of the network.

:param propertyF: a function that takes in :class:`k1lib.Learner` and outputs the
    desired float property
:param name: optional name for this callback. Falls back to the existing
    ``self.name`` when not given

.. warning::

    Remember to detach anything you get from :class:`k1lib.Learner` in your
    function, or else you're gonna cause a huge memory leak.
"""
        super().__init__(); self.propertyF = propertyF; self.suspended = True
        self.name = name or self.name; self.order = 23; self.parent:Callback = None
    def startRun(self):
        # snapshot the weights so endRun can put them back once the sweep is over
        self.originalParams = self.l.model.exportParams()
    def endRun(self): self.l.model.importParams(self.originalParams)
    def startPass(self):
        """Advances the sweep to the next (x, y) grid point, then moves every
parameter to ``original + x*v1 + y*v2``."""
        next(self.iter)
        for param, og, v1, v2 in zip(self.l.model.parameters(), self.originalParams, *self.vs):
            param.data = og + self.x * v1 + self.y * v2
    def endLoss(self):
        """Records ``propertyF`` at the current grid point and occasionally prints progress."""
        prop = self.propertyF(self.l)
        self.zs[self.ix, self.iy] = prop if prop == prop else 0 # check for nan (nan != nan)
        # only print on every 10th batch to keep the output light
        if self.l.batch % 10 == 0: print(f"\rProgress: {round(100*(self.ix+self.iy/res)/res)}%, {round(time.time()-self.beginTime)}s      ", end="")
    # Returning True from these checkpoints presumably tells the Learner to skip
    # backward/step/zero-grad, so the sweep never trains the model -- confirm
    # against Callbacks semantics.
    def startBackward(self): return True
    def startStep(self): return True
    def startZeroGrad(self): return True
    def __iter__(self):
        """The core running loop. The sweep is driven by the Learner's
"startPass" checkpoint rather than by its own loop, so it is written as a
generator: each ``next()`` selects the next grid point. After every scale it
draws that scale's surface. Once all scales are done it raises
:class:`k1lib.CancelRunException` to stop the run."""
        self.zss = [] # debug data: one finished grid per scale
        for i, (scale, ax) in enumerate(zip(scales, self.axes)):
            a = torch.linspace(-scale, scale, res)
            xs, ys = np.meshgrid(a, a); self.zs = np.empty((res, res))
            xs = torch.tensor(xs); ys = torch.tensor(ys)
            for ix in range(res):
                for iy in range(res):
                    self.x = xs[ix, iy]; self.y = ys[ix, iy]
                    self.ix, self.iy = ix, iy; yield True
            # zero every non-finite value (+inf, -inf, nan) so plot_surface gets a clean surface
            self.zs[~np.isfinite(self.zs)] = 0
            ax.plot_surface(xs, ys, self.zs, cmap=plt.cm.coolwarm)
            self.zss.append(self.zs)
            print(f"     {i+1}/{len(scales)} Finished [{-scale}, {scale}] range              ", end="")
        raise k1lib.CancelRunException("Landscape finished")
    def plot(self):
        """Creates the landscapes and show plots.

Runs the Learner with this callback active until the sweep over all ``scales``
raises :class:`k1lib.CancelRunException`. One 3d surface is drawn per scale on a
2x4 grid of axes. Any other exception propagates to the caller. The callback is
always re-suspended afterwards."""
        self.suspended = False; self.iter = iter(self); self.beginTime = time.time()
        def inner():
            # the two directions (v1, v2) that startPass moves the weights along
            self.vs = [self.l.model.getParamsVector(), self.l.model.getParamsVector()]
            fig, axes = plt.subplots(2, 4, subplot_kw={"projection": "3d"}, figsize=(16, 8), dpi=120)
            # "run forever": the sweep itself ends the run by raising CancelRunException
            self.axes = axes.flatten(); self.l.run(1000000)
        try:
            with self.cbs.suspendEval(), torch.no_grad(): inner()
        except k1lib.CancelRunException: pass # expected: the sweep is complete
        finally: self.suspended = True; self.iter = None
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""

In [4]:
import random
# seed both RNGs so the sample model and the random "landscape" are reproducible
torch.manual_seed(0); random.seed(0)
l = k1lib.Learner.sample(); c = k1lib.viz.Carousel()
# the property is pure noise: this cell only checks that plotting runs end to end
l.cbs.add(Landscape(lambda learner: random.random()));
l.run(1); l.Landscape.plot(); c.savePlt()
# presumably c[0][1] holds the saved figure data; a large size suggests it actually rendered
assert len(c[0][1]) > 10000; #c

Progress: 100%, epoch: 0/1, batch: 299/300, elapsed:   0.12s, loss: 20.62818717956543              

<Figure size 1920x960 with 0 Axes>

In [1]:
# Run the repo's export script for this notebook; judging by the output below it
# rebuilds and reinstalls the k1lib package
!../../export.py callbacks/landscape

Current dir: /home/kelvin/repos/labs/k1lib, ../../export.py
rm: cannot remove '__pycache__': No such file or directory
Found existing installation: k1lib 1.1
Uninstalling k1lib-1.1:
  Successfully uninstalled k1lib-1.1
running install
running bdist_egg
running egg_info
creating k1lib.egg-info
writing k1lib.egg-info/PKG-INFO
writing dependency_links to k1lib.egg-info/dependency_links.txt
writing requirements to k1lib.egg-info/requires.txt
writing top-level names to k1lib.egg-info/top_level.txt
writing manifest file 'k1lib.egg-info/SOURCES.txt'
reading manifest file 'k1lib.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file 'k1lib.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib
creating build/lib/k1lib
copying k1lib/_learner.py -> build/lib/k1lib
copying k1lib/fmt.py -> build/lib/k1lib
copying k1lib/_k1a.py -> build/lib/k1lib
copying k1lib/_context.py -> build/lib/k1