In [1]:
import torch
from modules import SConv2d, SLinear, SModel, SReLU
from models import SCrossEntropyLoss

In [2]:
# Tiny 3 -> 3 -> 3 network with fixed, hand-picked parameters so every
# intermediate value can be checked by hand.
fc1 = SLinear(3, 3)
with torch.no_grad():
    # fc1 weight = 0.1 .. 0.9 in row-major order, bias = 0.
    fc1.op.weight.copy_((torch.arange(1.0, 10.0) * 0.1).view(fc1.op.weight.shape))
    fc1.op.bias.zero_()

fc2 = SLinear(3, 3)
with torch.no_grad():
    # fc2 weight = 0.9 .. 0.1 in row-major order with entry [0, 0] overridden to 2, bias = 0.
    fc2.op.weight.copy_((torch.arange(9.0, 0.0, -1.0) * 0.1).view(fc2.op.weight.shape))
    fc2.op.weight[0, 0] = 2
    fc2.op.bias.zero_()

relu = SReLU()
criteria = SCrossEntropyLoss()
# lr=0, so a step() would leave the weights unchanged; here the optimizer is
# used to clear gradients before the backward pass.
optimizer = torch.optim.SGD([*fc1.parameters(), *fc2.parameters()], lr=0)
optimizer.zero_grad()

# Single sample with class label 0. xS presumably carries the second-order
# companion of x (all zeros at the input) — TODO confirm against modules.SLinear.
x = torch.tensor([[1.0, 2.0, 3.0]])
xS = torch.zeros_like(x)
labels = torch.tensor([0], dtype=torch.long)


def _keep_grads(*tensors):
    """Call retain_grad() on every tensor so .grad survives backward(); return them unchanged."""
    for tensor in tensors:
        tensor.retain_grad()
    return tensors


# Forward pass; each S-module returns a (value, second-order) pair and both
# halves keep their gradients for later inspection.
x1, xS1 = _keep_grads(*fc1(x, xS))
x2, xS2 = _keep_grads(*relu(x1, xS1))
x3, xS3 = _keep_grads(*fc2(x2, xS2))

loss = criteria(x3, xS3, labels)
loss.backward()

In [3]:
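# Hand-computed reference values (flattened 3x3 per sample) for x3SG and x1SG,
# kept for comparison with the values computed in the cells below.
# x3SG = torch.Tensor([[0.0123928,-0.0117341,-0.000658689,-0.0117341,0.011742,-7.92681E-6,-0.000658689,-7.92681E-6,0.000666616]]).view(1,3,3)
# x1SG = torch.Tensor([[0.0249031,0.00560088,0.00560088,0.00560088,0.00129391,0.00129391,0.00560088,0.00129391,0.00129391]]).view(1,3,3)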

In [4]:
def CECross(output):
    """Hessian of softmax cross-entropy with respect to the logits, per sample.

    Args:
        output: logits with the batch on dim 0; all remaining dims are
            flattened into the class axis (typically shape (BS, C)).

    Returns:
        Tensor of shape (BS, C, C) holding diag(p) - p p^T for every sample,
        where p = softmax(output[b]). The labels are not needed.
    """
    BS = output.shape[0]
    # Normalise each sample on its own (dim=1). Summing exp() over the whole
    # batch, or shifting by one global max, mixes samples together and gives
    # wrong probabilities whenever BS > 1.
    p = torch.softmax(output.reshape(BS, -1), dim=1)
    p_row = p.unsqueeze(1)  # (BS, 1, C)
    # diag(p) - outer(p, p), batched.
    return torch.diag_embed(p) - p_row.transpose(1, 2).bmm(p_row)
# Second-order term of the loss with respect to the network's logits x3.
x3SG = CECross(output=x3)

In [6]:
# Pull the logit-space second-order term x3SG back through fc2's weight W:
#   x1SG[b] = W^T @ x3SG[b] @ W,  i.e.
#   x1SG[b, i, j] = sum_{k, m} W[k, i] * x3SG[b, k, m] * W[m, j]
# One batched matmul replaces the former 4-deep Python loop, which was
# hard-coded to 3x3 layers and a batch of one.
# NOTE(review): no ReLU-derivative mask is applied here; with relu between
# fc1 and fc2 this equals the x1-side term only when every entry of x1 is
# positive (true for this example's inputs) — confirm intent.
x1SG = fc2.op.weight.T @ x3SG @ fc2.op.weight

In [7]:
def crossSecond(IN, outS):
    """Outer product of vec(IN IN^T) with vec(outS), per sample.

    Args:
        IN: layer input, batch on dim 0; the remaining dims are flattened to n.
        outS: second-order term at the layer output, batch on dim 0; the
            remaining dims are flattened to m*m elements.

    Returns:
        Tensor of shape (BS, n*n, len(vec(outS))) whose entry
        [b, i*n + j, q] equals IN[b, i] * IN[b, j] * vec(outS[b])[q].
    """
    batch = IN.shape[0]
    row = IN.view(batch, 1, -1)                    # (batch, 1, n)
    outer_in = row.transpose(1, 2).bmm(row)        # (batch, n, n) = x x^T
    left = outer_in.view(batch, -1, 1)             # (batch, n*n, 1)
    right = outS.view(batch, 1, -1)                # (batch, 1, |outS|)
    return left.bmm(right)
# Pair each layer's input with the second-order term at that layer's output:
# fc1 maps x -> x1 (term x1SG), fc2 maps x2 -> x3 (term x3SG).
fc1WSG = crossSecond(IN=x, outS=x1SG)
fc2WSG = crossSecond(IN=x2, outS=x3SG)

In [5]:
# Display fc1's weight second-order term: rows from x x^T, columns from x1SG
# (shape (1, 9, 9) in the output below).
fc1WSG

tensor([[[0.0249, 0.0056, 0.0056, 0.0056, 0.0013, 0.0013, 0.0056, 0.0013,
          0.0013],
         [0.0498, 0.0112, 0.0112, 0.0112, 0.0026, 0.0026, 0.0112, 0.0026,
          0.0026],
         [0.0747, 0.0168, 0.0168, 0.0168, 0.0039, 0.0039, 0.0168, 0.0039,
          0.0039],
         [0.0498, 0.0112, 0.0112, 0.0112, 0.0026, 0.0026, 0.0112, 0.0026,
          0.0026],
         [0.0996, 0.0224, 0.0224, 0.0224, 0.0052, 0.0052, 0.0224, 0.0052,
          0.0052],
         [0.1494, 0.0336, 0.0336, 0.0336, 0.0078, 0.0078, 0.0336, 0.0078,
          0.0078],
         [0.0747, 0.0168, 0.0168, 0.0168, 0.0039, 0.0039, 0.0168, 0.0039,
          0.0039],
         [0.1494, 0.0336, 0.0336, 0.0336, 0.0078, 0.0078, 0.0336, 0.0078,
          0.0078],
         [0.2241, 0.0504, 0.0504, 0.0504, 0.0116, 0.0116, 0.0504, 0.0116,
          0.0116]]])

In [6]:
# Display fc2's weight second-order term: rows from x2 x2^T, columns from x3SG.
fc2WSG

tensor([[[ 2.4290e-02, -2.2999e-02, -1.2910e-03, -2.2999e-02,  2.3014e-02,
          -1.5537e-05, -1.2910e-03, -1.5537e-05,  1.3066e-03],
         [ 5.5520e-02, -5.2569e-02, -2.9509e-03, -5.2569e-02,  5.2604e-02,
          -3.5512e-05, -2.9509e-03, -3.5512e-05,  2.9864e-03],
         [ 8.6750e-02, -8.2139e-02, -4.6108e-03, -8.2139e-02,  8.2194e-02,
          -5.5488e-05, -4.6108e-03, -5.5488e-05,  4.6663e-03],
         [ 5.5520e-02, -5.2569e-02, -2.9509e-03, -5.2569e-02,  5.2604e-02,
          -3.5512e-05, -2.9509e-03, -3.5512e-05,  2.9864e-03],
         [ 1.2690e-01, -1.2016e-01, -6.7450e-03, -1.2016e-01,  1.2024e-01,
          -8.1171e-05, -6.7450e-03, -8.1171e-05,  6.8261e-03],
         [ 1.9828e-01, -1.8775e-01, -1.0539e-02, -1.8775e-01,  1.8787e-01,
          -1.2683e-04, -1.0539e-02, -1.2683e-04,  1.0666e-02],
         [ 8.6750e-02, -8.2139e-02, -4.6108e-03, -8.2139e-02,  8.2194e-02,
          -5.5488e-05, -4.6108e-03, -5.5488e-05,  4.6663e-03],
         [ 1.9828e-01, -1.8775e-01