In [1]:
import torch
import matplotlib.pyplot as plt
import random

#### widgets

In [3]:
import ipywidgets as widgets

In [9]:
# Create a button widget labelled 'click'.
w = widgets.Button(description='click')
# Register a click handler; it ignores the argument it is given and prints 'hi'.
w.on_click(lambda x: print('hi'))
# A bare expression on the last line of the cell displays the widget.
w

Button(description='click', style=ButtonStyle())

hi
hi


#### callback basics

In [10]:
import time

In [11]:
def slow_calc():
    """Add up the squares of 0 through 4, pausing one second per term.

    Returns:
        int: The accumulated sum of squares (30).
    """
    total = 0
    for step in range(5):
        total = total + step ** 2
        time.sleep(1)  # deliberate one-second pause per iteration
    return total

In [12]:
# Run the calculation with no arguments; the returned total is displayed as the cell result.
slow_calc()

30

In [18]:
def slow_calc(cb=None):
    """Sum the squares of 0..4, reporting progress after each step.

    Args:
        cb: Optional callable invoked as ``cb(step)`` after every iteration,
            where ``step`` is the 1-based count of completed iterations.
            Pass ``None`` (the default) to run without reporting.

    Returns:
        int: The sum of squares (30).
    """
    res = 0
    for i in range(5):
        res += i**2
        time.sleep(1)  # deliberate one-second pause per iteration
        # Compare against None instead of relying on truthiness, so a
        # callable object that happens to be falsy (for example one that
        # defines __len__ or __bool__) is still invoked.
        if cb is not None:
            cb(i + 1)
    return res

In [19]:
def show_progress(i):
    """Print a one-line progress message containing ``i``."""
    message = "completed: {}".format(i)
    print(message)

In [20]:
# Pass show_progress as the callback; slow_calc calls it after each step with the 1-based step count.
slow_calc(cb=show_progress)

completed: 1
completed: 2
completed: 3
completed: 4
completed: 5


30

#### partials

In [21]:
def show_progress(msg, i):
    """Print ``msg`` and ``i`` on one line, separated by ``": "``."""
    line = "{}: {}".format(msg, i)
    print(line)

In [22]:
# The lambda fixes the message and forwards the single argument it receives
# to the two-argument show_progress.
slow_calc(lambda i: show_progress('hey completed!', i))

hey completed!: 1
hey completed!: 2
hey completed!: 3
hey completed!: 4
hey completed!: 5


30

In [23]:
def make_show_progress(msg):
    """Build a one-argument progress printer bound to ``msg``.

    Args:
        msg: Text printed before the progress value.

    Returns:
        A function taking ``i`` that prints ``"<msg>: <i>"``.
    """
    def _report(i):
        # msg is captured from the enclosing call (closure).
        print("{}: {}".format(msg, i))

    return _report

In [24]:
# make_show_progress returns a one-argument function that remembers the message 'Nice'.
slow_calc(make_show_progress('Nice'))

Nice: 1
Nice: 2
Nice: 3
Nice: 4
Nice: 5


30

In [25]:
from functools import partial

In [27]:
# partial pre-fills show_progress's first positional argument (msg) with 'oha',
# leaving a callable that takes only the remaining argument.
slow_calc(partial(show_progress, 'oha'))

oha: 1
oha: 2
oha: 3
oha: 4
oha: 5


30

#### callbacks as callable classes

In [31]:
class ProgressShowClb:
    """Callable object that prints a stored message followed by a value."""

    def __init__(self, msg):
        """Remember ``msg`` for later calls."""
        self.msg = msg

    def __call__(self, i):
        """Print ``"<msg>: <i>"``."""
        line = "{}: {}".format(self.msg, i)
        print(line)

In [32]:
# Create a callback object with the message 'ola ola'.
cb = ProgressShowClb('ola ola')

In [33]:
# Pass the cb object itself as slow_calc's callback argument.
slow_calc(cb)

ola ola: 1
ola ola: 2
ola ola: 3
ola ola: 4
ola ola: 5


30

#### multiple callback funcs; `*args` and `**kwargs`

In [34]:
def f(*args, **kwargs):
    """Print the positional arguments (a tuple) and keyword arguments (a dict)."""
    summary = "args: {}, kwargs:{}".format(args, kwargs)
    print(summary)

In [37]:
# Two positional arguments and two keyword arguments.
f('h', '0', t='v', k='n')

args: ('h', '0'), kwargs:{'t': 'v', 'k': 'n'}


In [38]:
def g(a, b, c=0):
    """Print ``a``, ``b`` and ``c`` separated by spaces; ``c`` defaults to 0."""
    values = (a, b, c)
    print(*values)

In [40]:
# Positional values to unpack.
args = ['h', 'e']
# Keyword values to unpack.
kwargs = {'c': 'o'}
# * spreads the list into positional arguments and ** spreads the dict into
# keyword arguments, so this is the same call as g('h', 'e', c='o').
g(*args, **kwargs)

h e o


In [47]:
def slow_calc(cb=None):
    """Sum the squares of 0..4, notifying a callback object around each step.

    Args:
        cb: Optional object exposing ``before_clb(i)`` and
            ``after_clb(i, val=...)``. ``before_clb`` is called with the
            0-based step index before the step; ``after_clb`` is called
            afterwards with the index and the running total as ``val``.
            Pass ``None`` (the default) to run without notifications.

    Returns:
        int: The sum of squares (30).
    """
    # Decide once whether hooks should fire. The check is against None, not
    # truthiness, so a callback object that evaluates as falsy (for example
    # one defining __len__ or __bool__) still receives its notifications.
    notify = cb is not None
    res = 0
    for i in range(5):
        if notify:
            cb.before_clb(i)
        res += i**2
        if notify:
            cb.after_clb(i, val=res)
    return res

In [48]:
class ProgressCB:
    """Callback object that prints a fixed line before and after each step."""

    def before_clb(self, *args, **kwargs):
        """Print the start-of-step line; all arguments are accepted and ignored."""
        # NOTE(review): 'sted' looks like a typo for 'step'; the string is
        # kept unchanged here.
        print("sted started")

    def after_clb(self, *args, **kwargs):
        """Print the end-of-step line; all arguments are accepted and ignored."""
        print("step done")

In [49]:
# Pass a fresh ProgressCB instance as the callback object.
slow_calc(ProgressCB())

sted started
step done
sted started
step done
sted started
step done
sted started
step done
sted started
step done


30

In [50]:
class ProgressCB:
    """Callback object that reports the step index and value after each step."""

    def before_clb(self, *args, **kwargs):
        """Print the start-of-step line; all arguments are accepted and ignored."""
        # NOTE(review): 'sted' looks like a typo for 'step'; the string is
        # kept unchanged here.
        print("sted started")

    def after_clb(self, epoch, val, **kwargs):
        """Print ``epoch`` and ``val``; extra keyword arguments are ignored."""
        report = "step:{}, val:{}".format(epoch, val)
        print(report)

In [51]:
# Same call as before, now with the ProgressCB whose after_clb prints the step and value.
slow_calc(ProgressCB())

sted started
step:0, val:0
sted started
step:1, val:1
sted started
step:2, val:5
sted started
step:3, val:14
sted started
step:4, val:30


30

#### modifying func behaviour through callbacks

In [53]:
def slow_calc(cb=None):
    """Sum the squares of 0..4, letting a callback object stop the loop early.

    Args:
        cb: Optional object exposing ``before_clb(i)`` and
            ``after_clb(i, res)``. ``before_clb`` is called with the 0-based
            step index before the step; ``after_clb`` is called afterwards
            with the index and the running total. A truthy return value from
            ``after_clb`` ends the loop immediately. Pass ``None`` (the
            default) to run without a callback.

    Returns:
        int: The running total when the loop ended (30 if never stopped).
    """
    # Check against None rather than truthiness, so a callback object that
    # evaluates as falsy (for example one defining __len__ or __bool__) is
    # still consulted and can still stop the loop.
    has_cb = cb is not None
    res = 0
    for i in range(5):
        if has_cb:
            cb.before_clb(i)
        res += i**2
        if has_cb and cb.after_clb(i, res):
            break
    return res

In [54]:
class ProgressCB:
    """Callback object that asks for a stop when the value equals 5."""

    def before_clb(self, *args, **kwargs):
        """Print the start-of-step line; all arguments are accepted and ignored."""
        # NOTE(review): 'sted' looks like a typo for 'step'; the string is
        # kept unchanged here.
        print("sted started")

    def after_clb(self, epoch, val, **kwargs):
        """Return True when ``val == 5``; otherwise print progress and return None."""
        hit_target = val == 5
        if not hit_target:
            print("step:{}, val:{}".format(epoch, val))
            return None
        return True

In [55]:
# after_clb returns True once the running total equals 5, which makes slow_calc
# break out of its loop and return that total.
slow_calc(cb=ProgressCB())

sted started
step:0, val:0
sted started
step:1, val:1
sted started


5

In [73]:
class SlowCalc():
    """Sum of squares of 0..4 with optional hooks on a callback object.

    The callback object may define ``before_cb(calc, i)`` and/or
    ``after_cb(calc, i)``; a hook that is not defined is skipped. Each hook
    receives this ``SlowCalc`` instance first, so it can read ``res``. A
    truthy return value from ``after_cb`` stops the calculation early.
    """

    def __init__(self, cb=None):
        # res: running total; cb: optional callback object (None disables hooks).
        self.res, self.cb = 0, cb

    def callback(self, cb_name, *args):
        """Call the hook named ``cb_name`` on the callback object, if present.

        Args:
            cb_name: Attribute name of the hook to look up on ``self.cb``.
            *args: Extra positional arguments passed after ``self``.

        Returns:
            The hook's return value, or None when there is no callback
            object or it has no such hook.
        """
        # Explicit None checks (not truthiness) so a callback object that
        # evaluates as falsy is still used.
        if self.cb is None:
            return None
        hook = getattr(self.cb, cb_name, None)
        if hook is None:
            return None
        return hook(self, *args)

    def calc(self):
        """Run the calculation.

        Returns:
            int: The total, which is also stored in ``self.res``.
        """
        # Start from zero on every run; otherwise calling calc() a second
        # time would keep adding onto the previous run's total.
        self.res = 0
        for i in range(5):
            self.callback('before_cb', i)
            self.res += i**2
            time.sleep(1)  # deliberate one-second pause per iteration
            if self.callback('after_cb', i):
                print('early stopping')
                break
        return self.res

In [74]:
class ProgressCB:
    """Callback object that asks for a stop once the total reaches 10."""

    def after_cb(self, calc, epoch):
        """Return True when ``calc.res >= 10``; otherwise print progress.

        Args:
            calc: Object exposing the running total as ``res``.
            epoch: Step index included in the printed line.

        Returns:
            True to request a stop, otherwise None.
        """
        should_stop = calc.res >= 10
        if not should_stop:
            print("step:{}, val:{}".format(epoch, calc.res))
            return None
        return True

In [76]:
# Build the calculator with a ProgressCB instance as its callback object.
calc = SlowCalc(cb=ProgressCB())
# Run the loop; the running total is kept on the instance as calc.res.
calc.calc()
# Last expression in the cell, so the total is displayed.
calc.res

step:0, val:0
step:1, val:1
step:2, val:5
early stopping


14