In [5]:
import numpy as np
import torch

In [6]:
def matmul_forward(A, B):
    """
    Forward pass for matrix multiplication.
    Args:
        A: np.ndarray of shape (m, n)
        B: np.ndarray of shape (n, p)
    Returns:
        out: np.ndarray of shape (m, p)
        cache: (A, B) for use in backward pass
    """
    out = A @ B
    cache = (A, B)
    return out, cache

def matmul_backward(dout, cache):
    """
    Backward pass for matrix multiplication.
    Given dout = dL/d(out), compute:
        dA = dL/dA
        dB = dL/dB
    Args:
        dout: np.ndarray of shape (m, p)
        cache: tuple (A, B) from forward pass
    Returns:
        dA: np.ndarray of shape (m, n)
        dB: np.ndarray of shape (n, p)
    """
    A, B = cache
    dA = dout @ B.T
    dB = A.T @ dout
    return dA, dB


In [40]:
# Random input matrices
A = np.random.randn(3, 4)
B = np.random.randn(4, 1)
# Suppose loss gradient w.r.t. output
#dout = np.random.randn(3, 2)
dout = np.ones((3, 1))

In [39]:
np.random.randn(4, 1)

array([[ 1.01445776],
       [ 1.27996473],
       [ 0.56717823],
       [-1.04598744]])

In [41]:
# Forward pass
out, cache = matmul_forward(A, B)
print("Forward output:\n", out)


# Backward pass
dA, dB = matmul_backward(dout, cache)
print("\nGradient wrt A:\n", dA)
print("\nGradient wrt B:\n", dB)

Forward output:
 [[-4.10983445]
 [ 6.88080842]
 [-2.53810412]]

Gradient wrt A:
 [[-1.0298927   1.6012201   1.77571994 -0.9131129 ]
 [-1.0298927   1.6012201   1.77571994 -0.9131129 ]
 [-1.0298927   1.6012201   1.77571994 -0.9131129 ]]

Gradient wrt B:
 [[-3.36131468]
 [ 1.19869705]
 [-2.77575884]
 [ 0.24020148]]


In [42]:
A_torch = torch.tensor(A, requires_grad=True)
B_torch = torch.tensor(B, requires_grad=True)
out = A_torch @ B_torch
out

tensor([[-4.1098],
        [ 6.8808],
        [-2.5381]], dtype=torch.float64, grad_fn=<MmBackward0>)

In [43]:
out.sum().backward()
print("\nGradient wrt A:\n", A_torch.grad)
print("\nGradient wrt B:\n", B_torch.grad)


Gradient wrt A:
 tensor([[-1.0299,  1.6012,  1.7757, -0.9131],
        [-1.0299,  1.6012,  1.7757, -0.9131],
        [-1.0299,  1.6012,  1.7757, -0.9131]], dtype=torch.float64)

Gradient wrt B:
 tensor([[-3.3613],
        [ 1.1987],
        [-2.7758],
        [ 0.2402]], dtype=torch.float64)
