In [3]:
import torch

In [48]:

class LinearFunction(torch.autograd.Function):
    """Bias-free linear map ``output = input.mm(weights)`` with a manual backward.

    Both operands go through ``Tensor.mm``, so they must be 2-D matrices
    whose inner dimensions agree.
    """

    @staticmethod
    def forward(ctx, input, weights):
        """Compute the matrix product and save both operands for ``backward``.

        Args:
            ctx: autograd context used to hand tensors over to ``backward``.
            input: left operand of the matrix product.
            weights: right operand of the matrix product.

        Returns:
            ``input.mm(weights)``.
        """
        output = input.mm(weights)
        ctx.save_for_backward(input, weights)
        return output

    @staticmethod
    def backward(ctx, grad_outputs):
        """Propagate ``grad_outputs`` back to ``input`` and ``weights``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_outputs: gradient with respect to the value returned by
                ``forward``.

        Returns:
            A ``(grad_input, grad_weights)`` tuple, one entry per ``forward``
            argument. An entry is ``None`` when the corresponding argument
            does not require a gradient.
        """
        input, weights = ctx.saved_tensors
        grad_input = grad_weights = None

        # Only pay for a matrix product when autograd will actually use it.
        if ctx.needs_input_grad[0]:
            grad_input = grad_outputs.mm(weights.t())
        if ctx.needs_input_grad[1]:
            # input^T @ grad is the same matrix as (grad^T @ input)^T, without
            # the two extra transposes.
            grad_weights = input.t().mm(grad_outputs)

        return grad_input, grad_weights

class LinearLayer(torch.nn.Module):
    """Module that applies ``LinearFunction`` using a learnable weight matrix."""

    def __init__(self, input_dim, output_dim):
        """Create the ``(input_dim, output_dim)`` weight parameter.

        The initial values are drawn with ``torch.rand``.
        """
        super().__init__()
        initial_weights = torch.rand((input_dim, output_dim))
        self.weights = torch.nn.Parameter(initial_weights)
        # Keep a handle to the custom autograd function's entry point.
        self.layer = LinearFunction.apply

    def forward(self, x):
        """Return ``self.layer(x, self.weights)``."""
        weights = self.weights
        return self.layer(x, weights)

In [57]:
# Input created directly with requires_grad=True so backward() fills x.grad.
x = torch.rand((8, 10), requires_grad=True)
ll = LinearLayer(10, 1)
out = ll(x)

# Mean-squared error against a random target shaped like the layer output.
loss_fn = torch.nn.MSELoss()
target = torch.rand(out.shape)
loss = loss_fn(out, target)
loss.backward()

# Show the weight gradient, the input gradient, and whether x is a leaf tensor.
print(ll.weights.grad, x.grad, x.is_leaf, sep="\n")

tensor([[2.7965],
        [2.2336],
        [1.9711],
        [2.8940],
        [1.7923],
        [1.7834],
        [1.8158],
        [2.2971],
        [1.6177],
        [2.0345]])
tensor([[0.3557, 0.2679, 0.4707, 0.2759, 0.2922, 0.2862, 0.4475, 0.0061, 0.0175,
         0.3846],
        [0.5588, 0.4208, 0.7394, 0.4335, 0.4590, 0.4496, 0.7029, 0.0096, 0.0274,
         0.6042],
        [0.2565, 0.1932, 0.3394, 0.1990, 0.2107, 0.2064, 0.3227, 0.0044, 0.0126,
         0.2774],
        [0.3264, 0.2458, 0.4319, 0.2532, 0.2681, 0.2626, 0.4106, 0.0056, 0.0160,
         0.3530],
        [0.4176, 0.3145, 0.5526, 0.3239, 0.3430, 0.3360, 0.5253, 0.0071, 0.0205,
         0.4515],
        [0.3182, 0.2397, 0.4211, 0.2469, 0.2614, 0.2560, 0.4003, 0.0054, 0.0156,
         0.3441],
        [0.3219, 0.2424, 0.4259, 0.2497, 0.2644, 0.2589, 0.4049, 0.0055, 0.0158,
         0.3480],
        [0.2949, 0.2221, 0.3902, 0.2288, 0.2422, 0.2373, 0.3710, 0.0050, 0.0145,
         0.3189]])
True
