In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

In [2]:
x = torch.Tensor(2)
x.requires_grad = True
loss = 13 * x[0] + 7 * x[1] + 1
loss.backward()
print(x)
print(x.grad)

tensor([-1.6047e-01,  4.5818e-41], requires_grad=True)
tensor([13.,  7.])


In [3]:
x = torch.randn((2,3), requires_grad=True)
y = torch.Tensor([[1,2],[3,4]])
loss = y.mm(x)
loss = loss.sum()
loss.backward()
print(x.grad)

tensor([[4., 4., 4.],
        [6., 6., 6.]])


In [113]:
x = torch.Tensor([
    [7,6,5, 9],
    [4,3,2, 8]
])
x.requires_grad = True
y = torch.Tensor([[1,2],[3,4], [5,6]])
y.requires_grad = True
loss = y.mm(x)
loss = loss.sum()
loss.backward()
print(y.grad)

tensor([[27., 17.],
        [27., 17.],
        [27., 17.]])


In [109]:
x = torch.Tensor([
    [7,6,5, 9],
    [4,3,2, 8]
])
x.requires_grad = True
y = torch.Tensor([[1,2],[3,4], [5,6]])
y.requires_grad = True
loss = y.mm(x)
loss = loss.sum()
loss.backward()
True if (y @ x).sum() == (x.grad * x).sum() else False

True

In [123]:
x_grad = (torch.ones([x.shape[1], 1]) @ y.sum(dim=0, keepdim=True)).T
x_grad

tensor([[ 9.,  9.,  9.,  9.],
        [12., 12., 12., 12.]], grad_fn=<PermuteBackward0>)

In [124]:
y_grad = (x.sum(dim=1, keepdim=True) @ torch.ones([1, y.shape[0]])).T
y_grad

tensor([[27., 17.],
        [27., 17.],
        [27., 17.]], grad_fn=<PermuteBackward0>)

In [128]:
if x.grad.equal(x_grad):
    print(True)
if y.grad.equal(y_grad):
    print(True)

True
True


In [196]:
with torch.set_grad_enabled(True):
    x = torch.Tensor([[2, 3],[1, 4]])
    x.requires_grad = True
    y = torch.Tensor([[3,4],[5,6],[7,8]])
    y.requires_grad = True
    z = torch.Tensor([[10,20,30], [40,60,70], [80,90,5], [15,30,45]])
    z.requires_grad = True
    n = torch.Tensor([[11,22,33,44]
    , [1,2,3,4]])
    n.requires_grad = True

    p = y.mm(x)
    p = z.mm(p)
    p = n.mm(p)

    loss = p.sum()
    loss = loss.backward()

    print(n.grad)
    print(z.grad)
    print(y.grad)
    print(x.grad)


tensor([[3700., 9950., 8125., 5550.],
        [3700., 9950., 8125., 5550.]])
tensor([[ 420.,  660.,  900.],
        [ 840., 1320., 1800.],
        [1260., 1980., 2700.],
        [1680., 2640., 3600.]])
tensor([[23400., 23400.],
        [31800., 31800.],
        [21900., 21900.]])
tensor([[76500., 76500.],
        [91920., 91920.]])


In [200]:
grad_n = ((z @ y @ x).sum(dim=1, keepdim=True) @ torch.ones([1, n.shape[0]])).T
grad_n

tensor([[3700., 9950., 8125., 5550.],
        [3700., 9950., 8125., 5550.]], grad_fn=<PermuteBackward0>)

In [253]:
n.sum(dim=0).diag()

tensor([[12.,  0.,  0.,  0.],
        [ 0., 24.,  0.,  0.],
        [ 0.,  0., 36.,  0.],
        [ 0.,  0.,  0., 48.]], grad_fn=<DiagBackward0>)

In [255]:
grad_z = (((y @ x).sum(dim=1, keepdim=True) @ torch.ones([1, z.shape[0]])) @ n.sum(dim=0).diag()).T
grad_z

tensor([[ 420.,  660.,  900.],
        [ 840., 1320., 1800.],
        [1260., 1980., 2700.],
        [1680., 2640., 3600.]], grad_fn=<PermuteBackward0>)

In [286]:
# grad_y = ((x).sum(dim=1, keepdim=True) @ (n @ z)).T
grad_y = ((x).sum(dim=1, keepdim=True) @ torch.ones([1, y.shape[0]]))
grad_y

tensor([[5., 5., 5.],
        [5., 5., 5.]], grad_fn=<MmBackward0>)

In [287]:
grad_x = torch.ones([1, x.shape[0]])
grad_x

tensor([[1., 1.]])

In [304]:
grad_n = ( (z @ y @ x).sum(dim=1, keepdim=True) @ torch.ones([1, n.shape[0]]) ).T
grad_n

tensor([[3700., 9950., 8125., 5550.],
        [3700., 9950., 8125., 5550.]], grad_fn=<PermuteBackward0>)

In [303]:
grad_z = ( ( (y @ x).sum(dim=1, keepdim=True) @ torch.ones([1, z.shape[0]]) ) @ n.sum(dim=0).diag() ).T
grad_z

tensor([[ 420.,  660.,  900.],
        [ 840., 1320., 1800.],
        [1260., 1980., 2700.],
        [1680., 2640., 3600.]], grad_fn=<PermuteBackward0>)

In [324]:
grad_y = ( ( (x).sum(dim=1, keepdim=True) @ torch.ones([1, y.shape[0]]) ) @ (n @ z).sum(dim=0).diag() ).T
grad_y

tensor([[23400., 23400.],
        [31800., 31800.],
        [21900., 21900.]], grad_fn=<PermuteBackward0>)

In [322]:
grad_x = ( torch.ones([x.shape[1], x.shape[0]] ) @ (n @ z @ y).sum(dim=0).diag() ).T
grad_x

tensor([[76500., 76500.],
        [91920., 91920.]], grad_fn=<PermuteBackward0>)

In [326]:
assert grad_n.equal(n.grad)
assert grad_z.equal(z.grad)
assert grad_y.equal(y.grad)
assert grad_x.equal(x.grad)
print("Todos coinciden")

Todos coinciden
