In [35]:
import torch
from torch import Tensor

# Batch size: number of rows in the random X and y_true tensors built below.
m = 100

def sigmoid(z: Tensor) -> Tensor:
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-z))`` of ``z``.

    Delegates to ``torch.sigmoid`` instead of spelling the formula out.
    The hand-written form overflows ``exp(-z)`` to ``inf`` for large negative
    inputs: the forward value is still 0, but autograd then multiplies a zero
    upstream gradient by ``inf`` and yields ``nan``. The built-in is stable
    in both the forward and the backward pass.

    Args:
        z: Tensor of pre-activation values.

    Returns:
        Tensor of the same shape as ``z`` with values in ``[0, 1]``.
    """
    return torch.sigmoid(z)


def train_perceptron(X: Tensor, W: Tensor, b: Tensor, y_true: Tensor, mu: float):
    """Apply a single batch update to a sigmoid-activated linear layer.

    Predictions are ``sigmoid(X @ W + b)``. The residual ``y_true - y_pred``
    drives the update: the weights move along ``X.T @ residual`` and the bias
    along the residual summed over dimension 0, both scaled by ``mu``.

    Note:
        ``W`` and ``b`` are updated in place; the returned tensors are the
        very same objects that were passed in.

    Args:
        X: Input batch (presumably one sample per row -- confirm with caller).
        W: Weight matrix, modified in place.
        b: Bias vector, modified in place.
        y_true: Target values, same shape as the predictions.
        mu: Step size applied to both updates.

    Returns:
        The tuple ``(W, b)`` after the update.
    """
    # Forward pass: affine map followed by the sigmoid activation.
    activations = sigmoid(X @ W + b)

    # Residual between the targets and the current predictions.
    residual = y_true - activations

    # Update directions for the weights and the bias.
    weight_step = X.T @ residual
    bias_step = residual.sum(dim=0)

    # In-place parameter update (same effect as ``W += ...`` / ``b += ...``).
    W.add_(mu * weight_step)
    b.add_(mu * bias_step)

    return W, b


# Draw a random batch, random parameters and random targets, then pick the
# step size. The draw order (X, W, b, y_true) is kept so the RNG stream is
# consumed exactly as before.
X = torch.rand(m, 784)
W = torch.rand(784, 10)
b = torch.rand(10)
y_true = torch.rand(m, 10)
mu = 0.01

# Show every starting tensor followed by its shape.
for _label, _value in (("X: ", X), ("W: ", W), ("b: ", b)):
    print(_label, _value, '\n')
    print(_value.shape, '\n')


# Run one update step. train_perceptron updates W and b in place, so the
# returned tensors alias the ones created above.
updated_W, updated_b = train_perceptron(X, W, b, y_true, mu)

print("updated_W: ", updated_W, '\n')
print(updated_W.shape, '\n')

print("updated_b: ", updated_b, '\n')
print(updated_b.shape)

X:  tensor([[0.3366, 0.3124, 0.3810,  ..., 0.0558, 0.9330, 0.6770],
        [0.0875, 0.6805, 0.3735,  ..., 0.7977, 0.7724, 0.3012],
        [0.4883, 0.2073, 0.1630,  ..., 0.6077, 0.4553, 0.9855],
        ...,
        [0.9454, 0.4944, 0.3806,  ..., 0.9955, 0.9127, 0.4340],
        [0.2280, 0.6718, 0.3783,  ..., 0.0566, 0.7488, 0.2949],
        [0.4562, 0.8017, 0.7782,  ..., 0.3436, 0.8678, 0.7288]]) 

torch.Size([100, 784]) 

W:  tensor([[0.5231, 0.0832, 0.7700,  ..., 0.5502, 0.2256, 0.9167],
        [0.5098, 0.2364, 0.6868,  ..., 0.2076, 0.7761, 0.8734],
        [0.5385, 0.0797, 0.1596,  ..., 0.0655, 0.0620, 0.6531],
        ...,
        [0.3015, 0.6900, 0.6663,  ..., 0.7197, 0.9410, 0.0754],
        [0.9334, 0.8395, 0.2179,  ..., 0.2604, 0.6276, 0.5574],
        [0.0152, 0.7496, 0.8418,  ..., 0.1860, 0.2444, 0.2846]]) 

torch.Size([784, 10]) 

b:  tensor([0.8683, 0.7428, 0.4973, 0.8501, 0.5833, 0.9285, 0.5001, 0.6046, 0.3676,
        0.7242]) 

torch.Size([10]) 

updated_W:  tensor([[