### TEST SET 1: Parameter Semantics

Test 1.1: Parameter is a PopulationNode
Purpose: Verify abstraction separation.

In [45]:

# Build a three-unit parameter; later cells may reuse `w`.
w = Parameter([1.0, 2.0, 3.0])

# A Parameter must be usable wherever a PopulationNode is expected.
assert isinstance(w, PopulationNode)
# The values passed in are stored unchanged ...
assert w.data == [1.0, 2.0, 3.0]
# ... and there is one gradient slot per value, all starting at zero.
assert w.grad == [0.0] * 3


What this demonstrates:
- Parameters are signals plus learning capability
- No special casing hidden in Tensor

Test 1.2: zero_grad works

In [46]:
# Self-contained: build a fresh Parameter instead of reusing `w` from the
# previous cell, so this cell passes regardless of which cells ran before it.
w = Parameter([1.0, 2.0, 3.0])

# Put a non-zero gradient in place so the reset is observable.
w.grad = [0.5, -0.3, 1.2]
w.zero_grad()

# After the reset every entry is zero and there are still three slots.
assert w.grad == [0.0, 0.0, 0.0]


Conceptual check:
- Gradient accumulation is explicit
- No silent resets during backprop

Test 1.3: step() performs GD update

In [47]:
w = Parameter([1.0, -2.0])
w.grad = [0.1, -0.2]
w.step(lr=0.5)

# Expected result of data - lr * grad: [1.0 - 0.05, -2.0 + 0.1] = [0.95, -1.9].
# The values come out of floating-point arithmetic, so compare with a small
# tolerance instead of exact equality; check the length separately because
# zip() would silently ignore extra or missing entries.
assert len(w.data) == 2
assert all(abs(got - want) <= 1e-12 for got, want in zip(w.data, [0.95, -1.9]))


This is critical:
- This is the learning rule
- You can point to it and say: “This is learning.”

### TEST SET 2: GradientDescent Class


Test 2.1: Optimizer updates all parameters

In [48]:

# Two independent one-unit parameters with different gradients.
w1 = Parameter([1.0])
w2 = Parameter([2.0])

w1.grad = [1.0]
w2.grad = [2.0]

# A single optimizer step must update every parameter it was given.
opt = GD([w1, w2], lr=0.1)
opt.step()

# Expected: 1.0 - 0.1 * 1.0 = 0.9 and 2.0 - 0.1 * 2.0 = 1.8.
# These are floating-point results, so compare with a small tolerance rather
# than exact equality.
assert len(w1.data) == 1 and abs(w1.data[0] - 0.9) <= 1e-12
assert len(w2.data) == 1 and abs(w2.data[0] - 1.8) <= 1e-12


Test 2.2: Optimizer zero_grad

Interpretation:
- Optimizer controls time
- Parameters control state

### Experiment 1: quadratic_1d.py
Setup

Parameter:
- w

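Loss:
$L(\theta) = \tfrac{1}{2} a \theta^2$

Dynamics under gradient descent with learning rate $\eta$:
$\theta_{t+1} = \theta_t - \eta a \theta_t = r\,\theta_t$ with $r = 1 - \eta a$, so $\theta_t = r^t \theta_0$, which tends to $0$ exactly when $|r| < 1$.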

Run with:
- r = -0.5, it oscillates but shrinks.
- r = 1.2, it explodes.
- r = -1.2, it oscillates and explodes.

You must be able to say
- Why oscillations happen
- Why divergence happens
- Why gradient is correct but learning fails

In [73]:
# @title
def stability(lr, a=1.0, steps=20, w0=5.0):
    """Run gradient descent on L(θ) = ½·a·θ² and print θ after each step.

    The analytic gradient a·θ is written into the parameter by hand and the
    optimizer then applies its update. If GD.step() performs
    θ ← θ − lr·grad, every step multiplies θ by r = 1 − lr·a, so the printed
    sequence is θ_t = r^t·θ₀.

    Args:
        lr: Learning rate passed to GD.
        a: Curvature of the quadratic loss.
        steps: Number of updates to run (default 20).
        w0: Starting value θ₀ (default 5.0).

    Prints one line per step: the step index and the updated θ.
    """
    w = Parameter([w0])
    opt = GD([w], lr=lr)

    for t in range(steps):
        w.zero_grad()

        # dL/dθ = aθ for L = 1/2 * a * θ^2; the loss value itself is not
        # needed for the update, so it is not computed.
        w.grad[0] = a * w.data[0]

        opt.step()
        print(t, w.data[0])


#### r = 0.7 (0 < r < 1), it shrinks steadily without oscillating.

In [87]:
# lr * a = 0.3, so the per-step factor is r = 1 - 0.3 = 0.7.
stability(lr=0.3, a=1.0)

0 3.5
1 2.45
2 1.7150000000000003
3 1.2005000000000003
4 0.8403500000000003
5 0.5882450000000001
6 0.41177150000000007
7 0.2882400500000001
8 0.20176803500000007
9 0.14123762450000005
10 0.09886633715000004
11 0.06920643600500002
12 0.04844450520350002
13 0.033911153642450016
14 0.02373780754971501
15 0.016616465284800506
16 0.011631525699360355
17 0.008142067989552249
18 0.005699447592686574
19 0.003989613314880602


#### r = -0.5, it oscillates but shrinks.

In [80]:
# lr * a = 1.5, so the per-step factor is r = 1 - 1.5 = -0.5.
stability(lr=1.5, a=1.0)

0 -2.5
1 1.25
2 -0.625
3 0.3125
4 -0.15625
5 0.078125
6 -0.0390625
7 0.01953125
8 -0.009765625
9 0.0048828125
10 -0.00244140625
11 0.001220703125
12 -0.0006103515625
13 0.00030517578125
14 -0.000152587890625
15 7.62939453125e-05
16 -3.814697265625e-05
17 1.9073486328125e-05
18 -9.5367431640625e-06
19 4.76837158203125e-06


#### r = 1.2, it explodes.


In [81]:
# A negative learning rate steps up the gradient: lr * a = -0.2, so the
# per-step factor is r = 1 + 0.2 = 1.2.
stability(lr=-0.2, a=1.0)

0 6.0
1 7.2
2 8.64
3 10.368
4 12.441600000000001
5 14.929920000000001
6 17.915904
7 21.499084800000002
8 25.798901760000003
9 30.958682112000005
10 37.1504185344
11 44.58050224128
12 53.496602689536005
13 64.1959232274432
14 77.03510787293185
15 92.44212944751821
16 110.93055533702186
17 133.11666640442624
18 159.7399996853115
19 191.68799962237378


#### r = -1.2, it oscillates and explodes.

In [83]:
# lr * a = 2.2, so the per-step factor is r = 1 - 2.2 = -1.2.
stability(lr=2.2, a=1.0)

0 -6.0
1 7.200000000000001
2 -8.640000000000002
3 10.368000000000004
4 -12.441600000000006
5 14.929920000000012
6 -17.91590400000002
7 21.499084800000027
8 -25.798901760000035
9 30.958682112000048
10 -37.15041853440006
11 44.58050224128007
12 -53.49660268953609
13 64.19592322744332
14 -77.03510787293199
15 92.44212944751841
16 -110.93055533702211
17 133.11666640442655
18 -159.7399996853119
19 191.68799962237426


### Experiment 2 on population

In [90]:
# @title
def stability_population(lr, a=1.0, steps=20, w0=(5.0, 5.0, 5.0)):
    """Run gradient descent on L(w) = ½·a·Σᵢ wᵢ² and print w after each step.

    Each unit's analytic gradient a·wᵢ is written into the parameter by hand
    and the optimizer then applies its update. If GD.step() performs
    w ← w − lr·grad element-wise, every unit is multiplied by r = 1 − lr·a
    on each step, independently of the other units.

    Args:
        lr: Learning rate passed to GD.
        a: Curvature of the quadratic loss.
        steps: Number of updates to run (default 20).
        w0: Starting values, one per unit (default three units at 5.0).

    Prints one line per step: the zero-padded step index and the updated w.
    """
    # Copy into a fresh list so the caller's sequence is never shared with
    # the parameter.
    w = Parameter(list(w0))
    opt = GD([w], lr=lr)

    for t in range(steps):
        w.zero_grad()

        # dL/dw_i = a * w_i; the loss value itself is not needed for the
        # update, so it is not computed.
        for i, w_i in enumerate(w.data):
            w.grad[i] = a * w_i

        opt.step()
        print(f"{t:02d} | w = {w.data}")


#### r = 0.7 (0 < r < 1), it shrinks steadily without oscillating.

In [91]:
# a defaults to 1.0, so lr * a = 0.3 and the per-step factor is r = 0.7.
stability_population(lr=0.3)

00 | w = [3.5, 3.5, 3.5]
01 | w = [2.45, 2.45, 2.45]
02 | w = [1.7150000000000003, 1.7150000000000003, 1.7150000000000003]
03 | w = [1.2005000000000003, 1.2005000000000003, 1.2005000000000003]
04 | w = [0.8403500000000003, 0.8403500000000003, 0.8403500000000003]
05 | w = [0.5882450000000001, 0.5882450000000001, 0.5882450000000001]
06 | w = [0.41177150000000007, 0.41177150000000007, 0.41177150000000007]
07 | w = [0.2882400500000001, 0.2882400500000001, 0.2882400500000001]
08 | w = [0.20176803500000007, 0.20176803500000007, 0.20176803500000007]
09 | w = [0.14123762450000005, 0.14123762450000005, 0.14123762450000005]
10 | w = [0.09886633715000004, 0.09886633715000004, 0.09886633715000004]
11 | w = [0.06920643600500002, 0.06920643600500002, 0.06920643600500002]
12 | w = [0.04844450520350002, 0.04844450520350002, 0.04844450520350002]
13 | w = [0.033911153642450016, 0.033911153642450016, 0.033911153642450016]
14 | w = [0.02373780754971501, 0.02373780754971501, 0.02373780754971501]
15 | w = [

#### r = -0.5, it oscillates but shrinks.

In [95]:
# a defaults to 1.0, so lr * a = 1.5 and the per-step factor is r = -0.5.
stability_population(lr=1.5)

00 | w = [-2.5, -2.5, -2.5]
01 | w = [1.25, 1.25, 1.25]
02 | w = [-0.625, -0.625, -0.625]
03 | w = [0.3125, 0.3125, 0.3125]
04 | w = [-0.15625, -0.15625, -0.15625]
05 | w = [0.078125, 0.078125, 0.078125]
06 | w = [-0.0390625, -0.0390625, -0.0390625]
07 | w = [0.01953125, 0.01953125, 0.01953125]
08 | w = [-0.009765625, -0.009765625, -0.009765625]
09 | w = [0.0048828125, 0.0048828125, 0.0048828125]
10 | w = [-0.00244140625, -0.00244140625, -0.00244140625]
11 | w = [0.001220703125, 0.001220703125, 0.001220703125]
12 | w = [-0.0006103515625, -0.0006103515625, -0.0006103515625]
13 | w = [0.00030517578125, 0.00030517578125, 0.00030517578125]
14 | w = [-0.000152587890625, -0.000152587890625, -0.000152587890625]
15 | w = [7.62939453125e-05, 7.62939453125e-05, 7.62939453125e-05]
16 | w = [-3.814697265625e-05, -3.814697265625e-05, -3.814697265625e-05]
17 | w = [1.9073486328125e-05, 1.9073486328125e-05, 1.9073486328125e-05]
18 | w = [-9.5367431640625e-06, -9.5367431640625e-06, -9.5367431640625e-0

#### r = 1.2, it explodes.

In [96]:
# A negative learning rate steps up the gradient: lr * a = -0.2, so the
# per-step factor is r = 1 + 0.2 = 1.2.
stability_population(lr=-0.2, a=1.0)

00 | w = [6.0, 6.0, 6.0]
01 | w = [7.2, 7.2, 7.2]
02 | w = [8.64, 8.64, 8.64]
03 | w = [10.368, 10.368, 10.368]
04 | w = [12.441600000000001, 12.441600000000001, 12.441600000000001]
05 | w = [14.929920000000001, 14.929920000000001, 14.929920000000001]
06 | w = [17.915904, 17.915904, 17.915904]
07 | w = [21.499084800000002, 21.499084800000002, 21.499084800000002]
08 | w = [25.798901760000003, 25.798901760000003, 25.798901760000003]
09 | w = [30.958682112000005, 30.958682112000005, 30.958682112000005]
10 | w = [37.1504185344, 37.1504185344, 37.1504185344]
11 | w = [44.58050224128, 44.58050224128, 44.58050224128]
12 | w = [53.496602689536005, 53.496602689536005, 53.496602689536005]
13 | w = [64.1959232274432, 64.1959232274432, 64.1959232274432]
14 | w = [77.03510787293185, 77.03510787293185, 77.03510787293185]
15 | w = [92.44212944751821, 92.44212944751821, 92.44212944751821]
16 | w = [110.93055533702186, 110.93055533702186, 110.93055533702186]
17 | w = [133.11666640442624, 133.1166664044

#### r = -1.2, it oscillates and explodes.


In [97]:
# lr * a = 2.2, so the per-step factor is r = 1 - 2.2 = -1.2.
stability_population(lr=2.2, a=1.0)

00 | w = [-6.0, -6.0, -6.0]
01 | w = [7.200000000000001, 7.200000000000001, 7.200000000000001]
02 | w = [-8.640000000000002, -8.640000000000002, -8.640000000000002]
03 | w = [10.368000000000004, 10.368000000000004, 10.368000000000004]
04 | w = [-12.441600000000006, -12.441600000000006, -12.441600000000006]
05 | w = [14.929920000000012, 14.929920000000012, 14.929920000000012]
06 | w = [-17.91590400000002, -17.91590400000002, -17.91590400000002]
07 | w = [21.499084800000027, 21.499084800000027, 21.499084800000027]
08 | w = [-25.798901760000035, -25.798901760000035, -25.798901760000035]
09 | w = [30.958682112000048, 30.958682112000048, 30.958682112000048]
10 | w = [-37.15041853440006, -37.15041853440006, -37.15041853440006]
11 | w = [44.58050224128007, 44.58050224128007, 44.58050224128007]
12 | w = [-53.49660268953609, -53.49660268953609, -53.49660268953609]
13 | w = [64.19592322744332, 64.19592322744332, 64.19592322744332]
14 | w = [-77.03510787293199, -77.03510787293199, -77.03510787293