In [146]:
from dev.data_utils import seed_everything
from fastai.vision import *
import torch

In [147]:
# Reproducibility check: two draws made right after seeding with the same
# value must be element-wise identical.
draws = []
for _attempt in range(2):
    seed_everything(42)
    draws.append(torch.rand(3, 1, requires_grad=True))
w1, w2 = draws
assert torch.equal(w1, w2)

In [148]:
# Toy dataset: 4 samples with 3 features each, plus one binary target per sample.
xs = torch.rand(4, 3).float()
ys = torch.randint(0, 2, (4,))

In [149]:
# Display the toy inputs and binary targets that every experiment below reuses.
xs, ys

(tensor([[0.9593, 0.3904, 0.6009],
         [0.2566, 0.7936, 0.9408],
         [0.1332, 0.9346, 0.5936],
         [0.8694, 0.5677, 0.7411]]), tensor([0, 1, 1, 1]))

#### No accumulate with `reduction=sum`

In [189]:
# Fresh, reproducible starting point for the full-batch, sum-reduction run.
seed_everything(42)
w1 = torch.rand((3,1), requires_grad=True)

def model(x):
    """Logistic model: sigmoid of the linear score ``x @ w1`` plus a fixed 0.5 offset.

    ``w1`` is looked up as a global at call time, so the model always uses
    whichever ``w1`` tensor is currently bound in the notebook.
    """
    return torch.sigmoid(x@w1 + 0.5)

# Build the criterion with the reduction this section is about, so this cell is
# correct on its own and mirrors the setup cell of the accumulate experiment.
crit = BCEFlat(reduction='sum')
opt = torch.optim.SGD([w1], lr=0.1)

In [190]:
# (Re)build the optimizer over w1 and a sum-reduction BCE criterion.
opt = torch.optim.SGD(params=[w1], lr=0.1)
crit = BCEFlat(reduction='sum')

In [191]:
# Full-batch update: clear old grads, one forward pass over all rows,
# one backward pass, one optimizer step.
opt.zero_grad()
pred = model(xs)
loss = crit(pred, ys)   # the forward pass builds a single autograd graph
loss.backward()         # fills w1.grad; the graph is released afterwards
opt.step()

In [192]:
# Weights after the single full-batch SGD step with sum reduction.
w1

tensor([[0.8139],
        [0.9128],
        [0.3610]], requires_grad=True)

In [193]:
# Gradient from the full-batch backward pass; it is still populated because
# zero_grad() has not been called since backward().
w1.grad

tensor([[0.6841],
        [0.0218],
        [0.2183]])

In [194]:
# Clear w1's gradient now that it has been inspected.
opt.zero_grad()

#### Accumulate `reduction=sum`

In [195]:
# Reset to the same seeded starting weights so the result is comparable
# with the full-batch run.
seed_everything(42)
w1 = torch.rand(3, 1, requires_grad=True)

def model(x):
    """Sigmoid of ``x @ w1 + 0.5``; reads the global ``w1`` at call time."""
    score = x @ w1 + 0.5
    return torch.sigmoid(score)

opt = torch.optim.SGD(params=[w1], lr=0.1)
crit = BCEFlat(reduction='sum')

In [196]:
# Gradient accumulation: run backward() once per chunk without zeroing in
# between, so w1.grad sums the per-chunk gradients; then take a single step.
opt.zero_grad()
for rows in (slice(None, 2), slice(2, None)):
    pred = model(xs[rows])
    loss = crit(pred, ys[rows])
    loss.backward()  # adds this chunk's gradient into w1.grad
opt.step()

In [197]:
# Weights after the accumulated step; the displayed values match the
# full-batch sum-reduction run.
w1

tensor([[0.8139],
        [0.9128],
        [0.3610]], requires_grad=True)

In [198]:
# Gradient accumulated over both chunks (two backward() calls, no zero_grad between).
w1.grad

tensor([[0.6841],
        [0.0218],
        [0.2183]])

In [199]:
# Clear the accumulated gradient now that it has been inspected.
opt.zero_grad()

#### No accumulate with `reduction=mean`

In [201]:
# Same seeded starting weights again, this time for the mean-reduction run.
seed_everything(42)
w1 = torch.rand(3, 1, requires_grad=True)

def model(x):
    """Sigmoid of ``x @ w1 + 0.5``; reads the global ``w1`` at call time."""
    score = x @ w1 + 0.5
    return torch.sigmoid(score)

opt = torch.optim.SGD(params=[w1], lr=0.1)
crit = BCEFlat()  # default reduction; made explicit in the following cell

In [203]:
# (Re)build the optimizer over w1 and a mean-reduction BCE criterion.
opt = torch.optim.SGD(params=[w1], lr=0.1)
crit = BCEFlat(reduction='mean')

In [204]:
# Full batch
pred = model(xs)
loss = crit(pred, ys)
# one graph is created here
opt.zero_grad()
loss.backward()
# graph is cleared here
opt.step()

In [205]:
# Weights after the single full-batch SGD step with mean reduction.
w1

tensor([[0.8652],
        [0.9145],
        [0.3774]], requires_grad=True)

In [206]:
# Mean-reduction gradient; the displayed values are the sum-reduction gradient
# divided by the 4 samples.
w1.grad

tensor([[0.1710],
        [0.0054],
        [0.0546]])

In [207]:
# Clear w1's gradient now that it has been inspected.
opt.zero_grad()

#### Accumulate with `reduction=sum`, then rescale to match `reduction=mean`

In [208]:
# Same seeded starting weights; the criterion sums per-sample losses so the
# gradient can be rescaled by hand afterwards.
seed_everything(42)
w1 = torch.rand(3, 1, requires_grad=True)

def model(x):
    """Sigmoid of ``x @ w1 + 0.5``; reads the global ``w1`` at call time."""
    score = x @ w1 + 0.5
    return torch.sigmoid(score)

opt = torch.optim.SGD(params=[w1], lr=0.1)
crit = BCEFlat(reduction='sum')

In [209]:
# Accumulate the sum-reduction gradient over both chunks. No optimizer step
# yet: the gradient is rescaled in a later cell before stepping.
opt.zero_grad()
for rows in (slice(None, 2), slice(2, None)):
    pred = model(xs[rows])
    loss = crit(pred, ys[rows])
    loss.backward()  # adds this chunk's gradient into w1.grad

In [211]:
# Convert the accumulated sum-reduction gradient into the mean-reduction
# gradient by dividing by the total number of samples that contributed to it.
# Using len(xs) instead of a hard-coded 4 keeps this correct if the toy
# dataset size changes.
# NOTE: this is an in-place operation, so re-running the cell divides again.
w1.grad.div_(len(xs))

tensor([[0.1710],
        [0.0054],
        [0.0546]])

In [213]:
# Apply the SGD update using the gradient that was rescaled in place above.
opt.step()

In [214]:
# Weights after the rescaled step; the displayed values match the full-batch
# `reduction='mean'` run.
w1

tensor([[0.8652],
        [0.9145],
        [0.3774]], requires_grad=True)

In [215]:
# The rescaled gradient; opt.step() reads it but does not modify it.
w1.grad

tensor([[0.1710],
        [0.0054],
        [0.0546]])

In [216]:
# Clear w1's gradient at the end of the experiment.
opt.zero_grad()