# Experiment with tst_param and average_grad 

## What are `tst_param` and `average_grad`

In [1]:
# CLICK ME
from fastai.vision.all import *
from IPython.display import IFrame
from pdb import set_trace

In [6]:
import pdb

In [None]:
!python

Python 3.9.10 | packaged by conda-forge | (main, Feb  1 2022, 21:28:27) 
[Clang 11.1.0 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> 

In [6]:
def average_grad(p, mom, dampening=False, grad_avg=None, **kwargs):
    "Keeps track of the avg grads of `p` in `state` with `mom`."
    if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)
    damp = 1-mom if dampening else 1.
    grad_avg.mul_(mom).add_(p.grad.data, alpha=damp)
    return {'grad_avg': grad_avg}

average_grad.defaults = dict(mom=0.9)

`dampening=False` gives the classical formula for momentum in SGD: 
```
new_val = old_val * mom + grad
```
whereas `dampening=True` makes it an exponential moving average:
```
new_val = old_val * mom + grad * (1-mom)
```

In [7]:
def tst_param(val, grad=None):
    "Create a tensor with `val` and a gradient of `grad` for testing"
    res = tensor([val]).float()
    res.grad = tensor([val/10 if grad is None else grad]).float()
    return res

In [8]:
p = tst_param([1,2,3], [4,5,6]) 

state = {}
state = average_grad(p, mom=0.9, **state)
test_eq(state['grad_avg'], p.grad)
print(state)

state = average_grad(p, mom=0.9, **state)
test_eq(state['grad_avg'], p.grad * 1.9)
print(state)

#Test dampening
state = {}
state = average_grad(p,  mom=0.9, dampening=True, **state)
test_eq(state['grad_avg'], 0.1*p.grad)
print(state)

state = average_grad(p, mom=0.9, dampening=True, **state)
test_close(state['grad_avg'], (0.1*0.9+0.1)*p.grad)
print(state)

{'grad_avg': tensor([[4., 5., 6.]])}
{'grad_avg': tensor([[ 7.6000,  9.5000, 11.4000]])}
{'grad_avg': tensor([[0.4000, 0.5000, 0.6000]])}
{'grad_avg': tensor([[0.7600, 0.9500, 1.1400]])}


## Experiment tst_param with pdb

In [2]:
def average_grad(p, mom, dampening=False, grad_avg=None, **kwargs):
    "Keeps track of the avg grads of `p` in `state` with `mom`."
    set_trace()
    if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)
    damp = 1-mom if dampening else 1.
    grad_avg.mul_(mom).add_(p.grad.data, alpha=damp)
    return {'grad_avg': grad_avg}

average_grad.defaults = dict(mom=0.9)

In [5]:
def tst_param(val, grad=None):
    "Create a tensor with `val` and a gradient of `grad` for testing"
    set_trace()
    res = tensor([val]).float()
    res.grad = tensor([val/10 if grad is None else grad]).float()
    return res

In [8]:
p = tst_param([1,2,3], [4,5,6]) 

## Experiment average_grad with pdb

In [None]:
state = {}
state = average_grad(p, mom=0.9, **state)

<function save_history at 0x148662d30>
> [0;32m/var/folders/gz/ch3n2mp51m9386sytqf97s6w0000gn/T/ipykernel_43631/3504488234.py[0m(5)[0;36maverage_grad[0;34m()[0m
[0;32m      3 [0;31m    [0;34m"Keeps track of the avg grads of `p` in `state` with `mom`."[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m    [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 5 [0;31m    [0;32mif[0m [0mgrad_avg[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mgrad_avg[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mzeros_like[0m[0;34m([0m[0mp[0m[0;34m.[0m[0mgrad[0m[0;34m.[0m[0mdata[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m    [0mdamp[0m [0;34m=[0m [0;36m1[0m[0;34m-[0m[0mmom[0m [0;32mif[0m [0mdampening[0m [0;32melse[0m [0;36m1.[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0mgrad_avg[0m[0;34m.[0m[0mmul_[0m[0;34m([0m[0mmom[0m[0;34m)[0m[0;34m.[0m[0madd_[0m[0;34m([0m[0mp[

In [None]:
test_eq(state['grad_avg'], p.grad)
state = average_grad(p, mom=0.9, **state)
test_eq(state['grad_avg'], p.grad * 1.9)

<function save_history at 0x143faab80>
> [0;32m/var/folders/gz/ch3n2mp51m9386sytqf97s6w0000gn/T/ipykernel_42776/3917144823.py[0m(4)[0;36mtst_param[0;34m()[0m
[0;32m      2 [0;31m    [0;34m"Create a tensor with `val` and a gradient of `grad` for testing"[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m    [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 4 [0;31m    [0mres[0m [0;34m=[0m [0mtensor[0m[0;34m([0m[0;34m[[0m[0mval[0m[0;34m][0m[0;34m)[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0mres[0m[0;34m.[0m[0mgrad[0m [0;34m=[0m [0mtensor[0m[0;34m([0m[0;34m[[0m[0mval[0m[0;34m/[0m[0;36m10[0m [0;32mif[0m [0mgrad[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mgrad[0m[0;34m][0m[0;34m)[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m    [0;32mreturn[0m [0mres[0m[0;34m

In [45]:
#Test dampening
state = {}
state = average_grad(p,  mom=0.9, dampening=True, **state)
test_eq(state['grad_avg'], 0.1*p.grad)
state = average_grad(p, mom=0.9, dampening=True, **state)
test_close(state['grad_avg'], (0.1*0.9+0.1)*p.grad)