## Data Initialization

In [26]:
import numpy as np

# Problem dimensions: n samples, k input features, p output targets.
n, k, p = 100, 8, 3

# Seeded generator so that Restart & Run All reproduces the same data,
# the same initial weights, and therefore the same training trace.
SEED = 0
rng = np.random.default_rng(SEED)

X = rng.standard_normal((n, k))  # inputs, standard normal, shape (n, k) = 100*8
Y = rng.random((n, p))  # targets, uniform on [0, 1), shape (n, p) = 100*3

W = rng.standard_normal((k, p))  # initial weights, standard normal, shape (k, p) = 8*3

alpha = 1e-3  # step size multiplied into the gradient at each update
max_itr = 1000  # number of gradient-descent iterations

## Forward Propagation

In [27]:
def forward(X, W):
    """
    Linear forward pass: multiply the inputs by the weight matrix.

    X: inputs, shape (n, k)
    W: weights, shape (k, p)
    return: predictions Y_hat, shape (n, p)
    """
    # (n, k) x (k, p) -> (n, p); e.g. 100*8 times 8*3 gives 100*3
    Y_hat = np.matmul(X, W)
    return Y_hat

## Loss Function

In [28]:
def loss(Y_hat, Y):
    """
    Sum-of-squares error between predictions and targets.

    Returns the pair (E, loss_value):
      E          -- element-wise residual Y_hat - Y (e.g. 100*3)
      loss_value -- sum of the squared residuals; note this is a sum
                    over all entries, not a mean
    """
    residual = Y_hat - Y
    squared = residual ** 2
    return residual, np.sum(squared)

## Backward Propagation

In [29]:
def backward(X, E):
    """
    Gradient of the summed squared error with respect to W.

    For Y_hat = X W and loss = sum((Y_hat - Y)**2), the gradient is
    2 * X^T (Y_hat - Y) = 2 * X^T E.

    X: inputs, shape (n, k)
    E: residual Y_hat - Y, shape (n, p)
    return: gradient, shape (k, p), e.g. 8*3
    """
    # (k, n) x (n, p) -> (k, p); e.g. 8*100 times 100*3
    xt_e = np.matmul(X.T, E)
    return 2 * xt_e

## Training Loop

In [30]:
def fit(X, Y, W, alpha, max_itr, log_every=10):
    """
    Fit W by batch gradient descent on the squared-error loss.

    X: inputs, shape (n, k)
    Y: targets, shape (n, p)
    W: initial weights, shape (k, p); copied, so the caller's array is
       left unchanged
    alpha: step size multiplied into the gradient at each update
    max_itr: number of update steps to run
    log_every: print progress every this many iterations (default 10);
       0 or None turns logging off
    return: the trained weights, as a new array
    """
    # Work on a private copy. Updating the argument in place would overwrite
    # the caller's initial weights, so re-running the training cell would
    # silently resume from already-trained weights instead of starting over.
    W = np.array(W, copy=True)

    for i in range(max_itr):
        # Forward
        Y_hat = forward(X, W)

        # Loss
        E, L = loss(Y_hat, Y)  # residual, summed squared error

        # Backward
        grad_W = backward(X, E)

        # Update
        W -= alpha * grad_W

        # L was measured before this iteration's update, while the W
        # printed here is the value after it.
        if log_every and i % log_every == 0:
            print(f"iter{i}, loss={L:.4f}, W={W}")
    return W

In [31]:
# Run gradient descent starting from W and keep the returned weight matrix.
W_final = fit(X, Y, W, alpha=alpha, max_itr=max_itr)

iter0, loss=2002.4465, W=[[-0.15518939 -0.76836147 -0.67083583]
 [ 0.4213657   0.05484097 -0.73123126]
 [-0.11710926 -0.60637151  0.52278888]
 [-0.30375326 -0.83198875 -0.3513228 ]
 [-1.16318125  1.10128405  0.21414826]
 [-1.0669386   0.38869412  1.06045986]
 [-0.02356459  1.50808969 -0.86481501]
 [-0.80932912 -0.28919898 -0.45939361]]
iter10, loss=127.3274, W=[[-0.01208256 -0.30009683 -0.15877269]
 [ 0.015028    0.18403768  0.0148603 ]
 [-0.06584644 -0.27150018  0.04081779]
 [-0.19053263 -0.0355051  -0.15030641]
 [-0.21685464  0.01012626 -0.00472788]
 [-0.20489613  0.23934289  0.12618012]
 [-0.02607017  0.1660768  -0.13726242]
 [-0.18540842 -0.02509796 -0.11921193]]
iter20, loss=96.3321, W=[[-0.03044611 -0.10909742 -0.09124375]
 [ 0.03758879  0.11098529  0.05428528]
 [-0.0712661  -0.15716133 -0.08252302]
 [-0.13849312  0.01692829 -0.08493214]
 [-0.09211019 -0.08304039 -0.03831792]
 [-0.04338235  0.11696363  0.02305437]
 [-0.00238781  0.00783244 -0.06761354]
 [-0.09972671 -0.04271474 -