In [46]:
from dev.data_utils import seed_everything
from fastai.vision import *
import torch

In [47]:
# Sanity check: reseeding before each draw must yield bit-identical tensors.
draws = []
for _attempt in range(2):
    seed_everything(42)
    draws.append(torch.rand((3, 1), requires_grad=True))
w1, w2 = draws
assert torch.equal(w1, w2)

In [48]:
# Toy dataset: 4 samples with 3 features each, plus binary (0/1) labels.
n_samples, n_features = 4, 3
xs = torch.rand(size=(n_samples, n_features)).float()
ys = torch.randint(high=2, size=(n_samples,))

In [49]:
# Show the toy batch: the 4x3 feature matrix and its binary labels.
xs, ys

(tensor([[0.9593, 0.3904, 0.6009],
         [0.2566, 0.7936, 0.9408],
         [0.1332, 0.9346, 0.5936],
         [0.8694, 0.5677, 0.7411]]), tensor([0, 1, 1, 1]))

#### Full batch, no accumulation (`reduction='sum'`)

In [67]:
# Identically seeded parameters, so every section starts from the same weights.
seed_everything(42)
w1 = torch.rand((3,3), requires_grad=True)
w2 = torch.rand((3,1), requires_grad=True)

def model(x):
    """Tiny 3-3-1 MLP: ReLU hidden layer, then sigmoid of (hidden @ w2 + 0.5)."""
    return torch.sigmoid((F.relu(x@w1))@w2 + 0.5)

# This section studies the *summed* loss. BCEFlat() alone would use the default
# 'mean' reduction, so request 'sum' explicitly.
crit = BCEFlat(reduction='sum')
opt = torch.optim.SGD([w1, w2], lr=0.1)

In [68]:
# Summed per-sample loss, plus a fresh SGD optimizer over both weight tensors.
crit = BCEFlat(reduction='sum')
opt = torch.optim.SGD(params=[w1, w2], lr=0.1)

In [69]:
# Full batch
pred = model(xs)
loss = crit(pred, ys)
# one graph is created here
opt.zero_grad()
loss.backward()
# graph is cleared here
opt.step()

In [70]:
# Parameters after one SGD step (lr=0.1) on the full batch with the summed loss.
w1

tensor([[0.8712, 0.8370, 0.3333],
        [0.9566, 0.3714, 0.5888],
        [0.2510, 0.7546, 0.9160]], requires_grad=True)

In [71]:
# Full-batch gradient of the summed loss; the accumulation section below
# should reproduce these values.
w1.grad

tensor([[0.1111, 0.7799, 0.4953],
        [0.0271, 0.1901, 0.1208],
        [0.0556, 0.3900, 0.2477]])

In [72]:
# Clear the gradients left by the step above. The next section re-creates
# w1, w2 and opt from the seed, so this only tidies up the old tensors.
opt.zero_grad()

#### Gradient accumulation over two micro-batches (`reduction='sum'`)

In [73]:
# Re-seed so these weights equal the ones used in the full-batch section.
seed_everything(42)
w1 = torch.rand((3, 3), requires_grad=True)
w2 = torch.rand((3, 1), requires_grad=True)

def model(x):
    hidden = F.relu(x @ w1)
    return torch.sigmoid(hidden @ w2 + 0.5)

crit = BCEFlat(reduction='sum')
opt = torch.optim.SGD(params=[w1, w2], lr=0.1)

In [74]:
# Two micro-batches of two samples each. Each backward() adds its gradient
# into .grad, and a single step is taken at the end.
opt.zero_grad()
for rows in (slice(None, 2), slice(2, None)):
    pred = model(xs[rows])
    loss = crit(pred, ys[rows])
    loss.backward()  # accumulates into w1.grad / w2.grad
opt.step()

In [75]:
# Parameters after accumulating two micro-batches. With reduction='sum' the
# summed micro-batch losses equal the full-batch loss, so this should match
# the first section.
w1

tensor([[0.8712, 0.8370, 0.3333],
        [0.9566, 0.3714, 0.5888],
        [0.2510, 0.7546, 0.9160]], requires_grad=True)

In [76]:
# Accumulated gradient (sum over both micro-batches).
w1.grad

tensor([[0.1111, 0.7799, 0.4953],
        [0.0271, 0.1901, 0.1208],
        [0.0556, 0.3900, 0.2477]])

In [77]:
# Clear the gradients left by the step above. The next section re-creates
# w1, w2 and opt from the seed, so this only tidies up the old tensors.
opt.zero_grad()

#### Full batch, no accumulation (`reduction='mean'`)

In [78]:
# Same seeded starting weights, but this time the loss is averaged over samples.
seed_everything(42)
w1 = torch.rand((3, 3), requires_grad=True)
w2 = torch.rand((3, 1), requires_grad=True)

def model(x):
    hidden = F.relu(x @ w1)
    return torch.sigmoid(hidden @ w2 + 0.5)

crit = BCEFlat(reduction='mean')
opt = torch.optim.SGD(params=[w1, w2], lr=0.1)

In [79]:
# Full batch
pred = model(xs)
loss = crit(pred, ys)
# one graph is created here
opt.zero_grad()
loss.backward()
# graph is cleared here
opt.step()

In [80]:
# Parameters after one SGD step (lr=0.1) on the full batch with the mean loss.
w1

tensor([[0.8795, 0.8955, 0.3705],
        [0.9586, 0.3857, 0.5979],
        [0.2552, 0.7839, 0.9346]], requires_grad=True)

In [81]:
# Mean-loss gradient: the summed-loss gradient divided by the 4 samples.
w1.grad

tensor([[0.0278, 0.1950, 0.1238],
        [0.0068, 0.0475, 0.0302],
        [0.0139, 0.0975, 0.0619]])

In [21]:
# Clear the gradients left by the step above. The next section re-creates
# w1, w2 and opt from the seed, so this only tidies up the old tensors.
opt.zero_grad()

#### Gradient accumulation with `reduction='sum'`, rescaled to match `reduction='mean'`

In [82]:
# Seeded weights again, with a summed loss whose gradients are rescaled later
# to emulate reduction='mean'.
seed_everything(42)
w1 = torch.rand((3, 3), requires_grad=True)
w2 = torch.rand((3, 1), requires_grad=True)

def model(x):
    hidden = F.relu(x @ w1)
    return torch.sigmoid(hidden @ w2 + 0.5)

crit = BCEFlat(reduction='sum')
opt = torch.optim.SGD(params=[w1, w2], lr=0.1)

In [83]:
# Same two micro-batches as before, but stop before stepping so the summed
# gradients can be rescaled first.
opt.zero_grad()
for rows in (slice(None, 2), slice(2, None)):
    pred = model(xs[rows])
    loss = crit(pred, ys[rows])
    loss.backward()  # adds this micro-batch's gradient into .grad

In [84]:
# With reduction='sum', .grad now holds the *sum* of per-sample gradients over
# both micro-batches. Dividing by the total sample count recovers the mean.
# This must be done for every parameter the optimizer updates, not only w1:
# rescaling w1 alone would leave w2 with a len(xs)-times-too-large step.
n_samples = len(xs)
for group in opt.param_groups:
    for p in group['params']:
        if p.grad is not None:
            p.grad.div_(n_samples)
w1.grad

tensor([[0.0278, 0.1950, 0.1238],
        [0.0068, 0.0475, 0.0302],
        [0.0139, 0.0975, 0.0619]])

In [85]:
# Apply the SGD update using the gradients currently stored in .grad
# (as rescaled by the previous cell).
opt.step()

In [86]:
# Compare with the full-batch reduction='mean' parameters above.
w1

tensor([[0.8795, 0.8955, 0.3705],
        [0.9586, 0.3857, 0.5979],
        [0.2552, 0.7839, 0.9346]], requires_grad=True)

In [87]:
# opt.step() does not modify .grad, so this still shows the rescaled gradient.
w1.grad

tensor([[0.0278, 0.1950, 0.1238],
        [0.0068, 0.0475, 0.0302],
        [0.0139, 0.0975, 0.0619]])

In [88]:
# Clear the accumulated gradients before any further experiments.
opt.zero_grad()