This has got to be one of my favorite notebooks, because it's where I learned how callbacks work.

In [1]:
# Reload edited modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline, below the cell that draws them.
%matplotlib inline

In [2]:
import torch
import matplotlib.pyplot as ply
import ipywidgets as widgets

In [3]:
def f(o):
    """Click handler for the button below: ignores the clicked widget `o` and prints 'hi'."""
    print('hi')

In [4]:
# A clickable button; a click handler is attached to it further down.
w = widgets.Button(description="Click me")

In [5]:
# Display the button: the last expression of a cell is rendered as its output.
w

Button(description='Click me', style=ButtonStyle())

In [6]:
# Register `f` to be called whenever the button is clicked.
# NOTE(review): re-running this cell likely registers `f` a second time
# (ipywidgets appends handlers) — confirm; `w.on_click(f, remove=True)` unregisters.
w.on_click(f)

In [7]:
from time import sleep

In [8]:
def slow_calculation():
    """Return 0**2 + 1**2 + ... + 4**2 (= 30), sleeping 0.5s per term to mimic slow work."""
    total = 0
    for step in range(5):
        total += step ** 2
        sleep(0.5)
    return total

In [9]:
# Blocks for about 2.5s (five steps x 0.5s sleep) and gives no progress feedback.
slow_calculation()

30

In [10]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 slowly, optionally reporting progress.

    Args:
        cb: optional callable, invoked as ``cb(step)`` after each step finishes.
            It is skipped when falsy (e.g. the default None).

    Returns:
        The sum of squares (30).
    """
    total = 0
    for step in range(5):
        total += step ** 2
        sleep(0.5)
        if cb:
            cb(step)
    return total

In [11]:
def show_progress(epoch):
    """Progress callback: print a celebratory line for the finished step `epoch`."""
    message = f'Yayayaya! Got done with epoch {epoch}!'
    print(message)

In [12]:
# A plain function works as a callback: it is called with each step index.
slow_calculation(cb=show_progress)

Yayayaya! Got done with epoch 0!
Yayayaya! Got done with epoch 1!
Yayayaya! Got done with epoch 2!
Yayayaya! Got done with epoch 3!
Yayayaya! Got done with epoch 4!


30

In [13]:
# An inline lambda works just as well as a named function.
slow_calculation(cb=lambda epoch: print(f'Yayayaya! Got done with epoch {epoch}!'))

Yayayaya! Got done with epoch 0!
Yayayaya! Got done with epoch 1!
Yayayaya! Got done with epoch 2!
Yayayaya! Got done with epoch 3!
Yayayaya! Got done with epoch 4!


30

In [14]:
def make_show_progress(exclamation):
    """Build a progress callback whose messages start with `exclamation`.

    The returned closure captures `exclamation`, so every call to this factory
    yields an independently configured callback taking a single `epoch` argument.
    """
    # Leading "_" is generally understood to be "private"
    def _inner(epoch):
        print(f"{exclamation}! We've finished epoch {epoch}!")

    return _inner

In [15]:
# The factory returns a ready-made callback with the exclamation baked in.
slow_calculation(cb=make_show_progress('Nice'))

Nice! We've finished epoch 0!
Nice! We've finished epoch 1!
Nice! We've finished epoch 2!
Nice! We've finished epoch 3!
Nice! We've finished epoch 4!


30

In [16]:
class ProgressShowingCallback():
    """Callable progress reporter: an object-based alternative to a closure factory.

    Args:
        exclamation: prefix printed before each progress message.
    """
    def __init__(self, exclamation="Yayayayaya"):
        self.exclamation = exclamation

    def __call__(self, epoch):
        # Instances are called like functions, so they can be passed as `cb`.
        prefix = self.exclamation
        print(f"{prefix}! We've finished epoch {epoch}!")

In [17]:
# Configure the callable object once; it is passed as the callback below.
cb = ProgressShowingCallback(exclamation="I cannot believe it")

In [18]:
# Any callable works as `cb`, including an instance that defines __call__.
slow_calculation(cb=cb)

I cannot believe it! We've finished epoch 0!
I cannot believe it! We've finished epoch 1!
I cannot believe it! We've finished epoch 2!
I cannot believe it! We've finished epoch 3!
I cannot believe it! We've finished epoch 4!


30

In [19]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 slowly, calling hooks on a callback object around each step.

    Args:
        cb: optional object providing ``before_calc(step)`` and
            ``after_calc(step, val=running_total)``. When `cb` is truthy,
            both hooks must exist.

    Returns:
        The sum of squares (30).
    """
    total = 0
    for step in range(5):
        if cb:
            cb.before_calc(step)
        total += step ** 2
        sleep(0.5)
        if cb:
            # The running total is passed by keyword as `val`.
            cb.after_calc(step, val=total)
    return total

In [20]:
class PrintStepCallback():
    """Callback that announces each step without inspecting the hook arguments.

    Both hooks accept and ignore any positional/keyword arguments, so they work
    with whatever the caller passes (e.g. ``after_calc(i, val=res)``).
    """
    # No __init__ needed: the callback keeps no state.
    # Plain strings: these messages have no placeholders, so no f-prefix.
    def before_calc(self, *args, **kwargs): print("About to Start")
    def after_calc(self, *args, **kwargs): print("Done step")

In [21]:
# The callback is now an object with before/after hooks rather than a single callable.
slow_calculation(cb=PrintStepCallback())

About to Start
Done step
About to Start
Done step
About to Start
Done step
About to Start
Done step
About to Start
Done step


30

In [22]:
class PrintStepCallback():
    """Callback that reports the step number before each step and the running value after it.

    Extra positional/keyword arguments passed to either hook are accepted and ignored.
    """
    def __init__(self):
        pass

    def before_calc(self, epoch, *args, **kwargs):
        print(f"About to start epoch {epoch}!")

    def after_calc(self, epoch, val, *args, **kwargs):
        # `val` may arrive positionally or by keyword (the caller uses val=...).
        print(f"After epoch {epoch}, val is {val}")

In [23]:
# Same driver; the redefined callback now uses the step index and running value.
slow_calculation(cb=PrintStepCallback())

About to start epoch 0!
After epoch 0, val is 0
About to start epoch 1!
After epoch 1, val is 1
About to start epoch 2!
After epoch 2, val is 5
About to start epoch 3!
After epoch 3, val is 14
About to start epoch 4!
After epoch 4, val is 30


30

In [24]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 slowly, with optional hooks and early stopping.

    Args:
        cb: optional callback object. Each hook is only called if the object
            has it: ``before_calc(step)`` before a step, and
            ``after_calc(step, running_total)`` after it. A truthy return from
            ``after_calc`` prints "stopping early" and ends the loop.

    Returns:
        The running total when the loop finished or was stopped.
    """
    total = 0
    for step in range(5):
        if cb and hasattr(cb, 'before_calc'):
            cb.before_calc(step)
        total += step * step
        sleep(0.5)
        # Guard clause: without an after_calc hook there is nothing to ask.
        if not (cb and hasattr(cb, 'after_calc')):
            continue
        if cb.after_calc(step, total):
            print("stopping early")
            break
    return total

In [25]:
class PrintAfterCallback():
    """Callback that prints the running value after each step and asks to stop once it exceeds 10.

    It only defines ``after_calc``; a ``before_calc`` hook is not needed.
    """
    def after_calc(self, epoch, val):
        print(f"After {epoch}, val is {val}")
        should_stop = val > 10
        if should_stop:
            return True
        # Falls through to None (falsy) to let the calculation continue.

In [26]:
# The callback stops the loop once the running value passes 10.
slow_calculation(cb=PrintAfterCallback())

After 0, val is 0
After 1, val is 1
After 2, val is 5
After 3, val is 14
stopping early


14

In [27]:
class SlowCalculator():
    """Stateful sum-of-squares calculator that fires named hooks on an optional callback object.

    Hooks receive the calculator itself as their first argument, so they can
    read (and change) ``self.res``. The result accumulates in ``self.res``;
    ``calc`` does not return it.
    """
    def __init__(self, cb=None):
        self.cb = cb
        self.res = 0

    def callback(self, cb_name, *args):
        """Call hook `cb_name` on the callback object as ``hook(self, *args)``.

        Returns the hook's result, or None when there is no callback object
        or it lacks that hook.
        """
        hook = getattr(self.cb, cb_name, None) if self.cb else None
        return hook(self, *args) if hook else None

    def calc(self):
        """Add i*i for i in 0..4 into self.res, stopping early if after_calc returns truthy."""
        for step in range(5):
            self.callback('before_calc', step)
            self.res += step * step
            sleep(0.5)
            if self.callback('after_calc', step):
                print("stopping early")
                break

In [28]:
class PrintAfterCallback():
    """SlowCalculator hook: print the calculator's running result and request a stop once it exceeds 10."""
    def after_calc(self, calc, epoch):
        current = calc.res
        print(f"After {epoch}, val is {current}")
        if current > 10:
            return True
        # Implicit None (falsy) lets the calculator keep going.

In [29]:
# The calculator holds the callback and its own running result.
calculator = SlowCalculator(cb=PrintAfterCallback())

In [30]:
# Runs until the callback asks to stop; calc() returns nothing, so the result is read next.
calculator.calc()

After 0, val is 0
After 1, val is 1
After 2, val is 5
After 3, val is 14
stopping early


In [31]:
# The result lives on the calculator instance rather than being returned.
calculator.res

14

### `__call__` magic happens here

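This is the part where the `callback` method is replaced with `__call__`, so the calculator instance itself can be called (e.g. `self('before_calc', i)`) to fire a hook.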

In [32]:
class SlowCalculator():
    """Sum-of-squares calculator whose instances are callable: ``self(name, *args)`` fires hook `name`.

    Hooks are looked up on the optional callback object and invoked as
    ``hook(calculator, *args)``, so they can read and modify ``calculator.res``.
    """
    def __init__(self, cb=None):
        self.cb = cb
        self.res = 0

    def __call__(self, cb_name, *args):
        """Run hook `cb_name` if present; return its result, otherwise None."""
        hook = getattr(self.cb, cb_name, None) if self.cb else None
        return hook(self, *args) if hook else None

    def calc(self):
        """Accumulate i*i for i in 0..4 into self.res, honoring an early-stop request."""
        for step in range(5):
            self('before_calc', step)
            self.res += step * step
            sleep(0.5)
            if self('after_calc', step):
                print("stopping early")
                break

In [33]:
# Fresh instance of the __call__-based calculator with the same callback.
calculator = SlowCalculator(cb=PrintAfterCallback())

In [34]:
# Same behavior as before; only the hook-dispatch mechanism changed.
calculator.calc()

After 0, val is 0
After 1, val is 1
After 2, val is 5
After 3, val is 14
stopping early


In [35]:
# Read the accumulated result off the instance.
calculator.res

14

In [36]:
class ModifyingCallback():
    """SlowCalculator hook that changes the calculator's state as well as observing it.

    After each step it prints the running result. It requests a stop once the
    result exceeds 10, and doubles the result while it is below 3.
    """
    def after_calc(self, calc, epoch):
        value = calc.res
        print(f"After {epoch}: {value}")
        if value > 10:
            return True
        if value < 3:
            # Writing to calc.res changes what the calculator computes next.
            calc.res = value * 2

In [37]:
# Pair the calculator with a callback that modifies its state mid-run.
calculator = SlowCalculator(cb=ModifyingCallback())

In [38]:
# The callback mutates calculator.res during the run, so the final value (15)
# differs from the run with the read-only callback (14).
calculator.calc()
calculator.res

After 0: 0
After 1: 1
After 2: 6
After 3: 15
stopping early


15

### `__dunder__` (magic) methods

In [47]:
class SloppyAdder():
    """Wrap a number `o` and make `+` deliberately sloppy: every sum gains an extra 0.01.

    Shows that `a + b` dispatches to `a.__add__(b)` and that `__repr__`
    decides what the notebook displays for the result.
    """
    def __init__(self,o): self.o = o

    def __add__(self,b):
        # Any operand exposing `.o` is accepted (duck typing, as before). For
        # other operands, return NotImplemented so Python can try `b.__radd__`
        # and then raise a clear TypeError. This avoids leaking an
        # AttributeError from `b.o`.
        try:
            other = b.o
        except AttributeError:
            return NotImplemented
        return SloppyAdder(self.o + other + 0.01)

    def __repr__(self): return str(self.o)

In [48]:
a, b = SloppyAdder(1), SloppyAdder(2)
# `a + b` dispatches to `a.__add__(b)`; the result is displayed via `__repr__`.
a + b

3.01