In [None]:
import sympy as sp


class sigmoid(sp.Function):
    @classmethod
    def eval(cls, x):
        return 1 / (1 + sp.exp(-x))


class MSE(sp.Function):
    @classmethod
    def eval(cls, A, B):
        diff_squared = (A - B).applyfunc(lambda x: x**2)
        sum_dff = sum(diff_squared[i, j] for i in range(A.rows) for j in range(A.cols))
        return (1 / 2) * sp.simplify(sum_dff / A.shape[0] / A.shape[1])


class ReLU(sp.Function):
    @classmethod
    def eval(cls, x):
        return sp.Piecewise((0, x < 0), (x, x >= 0))

In [None]:
MSE(sp.Matrix([1, 4]), sp.Matrix([1, 5]))

In [None]:
x1 = sp.symbols("x1")
x = sp.Matrix([[1], [x1]])
display(x)


y_true = sp.symbols("y_t")
display(y_true)


lr = sp.symbols("alpha")
display(lr)

In [None]:
w11, w12, w21, w22 = sp.symbols("w11 w12 w21 w22")
W = sp.Matrix([[w11, w12], [w21, w22]])

u11, u12, u13 = sp.symbols("u11 u12 u13")
u = sp.Matrix([[u11], [u12], [u13]])

h1 = ReLU(W.row(0).dot(x))
h2 = ReLU(W.row(1).dot(x))
h = sp.Matrix([[1], [h1], [h2]])

y = u.dot(h)

display(W, u, h, y)

In [35]:
y_calculated = sp.symbols("y_c")

loss = MSE(sp.Matrix([[y_true]]), sp.Matrix([[y]]))

display(sp.simplify(loss.diff(u)))

display(sp.simplify(loss.diff(W)))
W_new = W - lr * loss.diff(W)
display(sp.simplify(W_new.subs(w11 + w22 * x1,y_calculated ).doit()))

u_new = u - lr * loss.diff(u)
display(sp.simplify(u_new))

# Automatically detect common subexpressions
substitutions, simplified_expr = sp.cse(W_new)

# Print results
print("Substitutions:")
for var, subexpr in substitutions:
    print(f"{var} = {subexpr}")

print("\nSimplified Expression:")
display(simplified_expr[0])

Matrix([
[Piecewise((1.0*(u11 - y_t), (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (1.0*(u11 + u13*(w21 + w22*x1) - y_t), w11 + w12*x1 < 0), (1.0*(u11 + u12*(w11 + w12*x1) - y_t), w21 + w22*x1 < 0), (1.0*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t), True))],
[                                                                 Piecewise((0, w11 + w12*x1 < 0), (1.0*(w11 + w12*x1)*(u11 + u12*(w11 + w12*x1) - y_t), w21 + w22*x1 < 0), (1.0*(w11 + w12*x1)*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t), True))],
[                   Piecewise((0, (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (1.0*(w21 + w22*x1)*(u11 + u13*(w21 + w22*x1) - y_t), w11 + w12*x1 < 0), (0, w21 + w22*x1 < 0), (1.0*(w21 + w22*x1)*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t), True))]])

Matrix([
[                                              Piecewise((0, w11 + w12*x1 < 0), (1.0*u12*(u11 + u12*(w11 + w12*x1) - y_t), w21 + w22*x1 < 0), (1.0*u12*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t), True)),                                               Piecewise((0, w11 + w12*x1 < 0), (1.0*u12*x1*(u11 + u12*(w11 + w12*x1) - y_t), w21 + w22*x1 < 0), (1.0*u12*x1*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t), True))],
[Piecewise((0, (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (1.0*u13*(u11 + u13*(w21 + w22*x1) - y_t), w11 + w12*x1 < 0), (0, w21 + w22*x1 < 0), (1.0*u13*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t), True)), Piecewise((0, (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (1.0*u13*x1*(u11 + u13*(w21 + w22*x1) - y_t), w11 + w12*x1 < 0), (0, w21 + w22*x1 < 0), (1.0*u13*x1*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t), True))]])

Matrix([
[                                                Piecewise((w11, w11 + w12*x1 < 0), (-1.0*alpha*u12*(u11 + u12*(w11 + w12*x1) - y_t) + w11, w21 + w22*x1 < 0), (-1.0*alpha*u12*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t) + w11, True)),                                                 Piecewise((w12, w11 + w12*x1 < 0), (-1.0*alpha*u12*x1*(u11 + u12*(w11 + w12*x1) - y_t) + w12, w21 + w22*x1 < 0), (-1.0*alpha*u12*x1*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t) + w12, True))],
[Piecewise((w21, (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (-1.0*alpha*u13*(u11 + u13*(w21 + w22*x1) - y_t) + w21, w11 + w12*x1 < 0), (w21, w21 + w22*x1 < 0), (-1.0*alpha*u13*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t) + w21, True)), Piecewise((w22, (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (-1.0*alpha*u13*x1*(u11 + u13*(w21 + w22*x1) - y_t) + w22, w11 + w12*x1 < 0), (w22, w21 + w22*x1 < 0), (-1.0*alpha*u13*x1*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t) + w22, True))]])

Matrix([
[Piecewise((-1.0*alpha*(u11 - y_t) + u11, (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (-1.0*alpha*(u11 + u13*(w21 + w22*x1) - y_t) + u11, w11 + w12*x1 < 0), (-1.0*alpha*(u11 + u12*(w11 + w12*x1) - y_t) + u11, w21 + w22*x1 < 0), (-1.0*alpha*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t) + u11, True))],
[                                                                                         Piecewise((u12, w11 + w12*x1 < 0), (-1.0*alpha*(w11 + w12*x1)*(u11 + u12*(w11 + w12*x1) - y_t) + u12, w21 + w22*x1 < 0), (-1.0*alpha*(w11 + w12*x1)*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t) + u12, True))],
[                                         Piecewise((u13, (w11 + w12*x1 < 0) & (w21 + w22*x1 < 0)), (-1.0*alpha*(w21 + w22*x1)*(u11 + u13*(w21 + w22*x1) - y_t) + u13, w11 + w12*x1 < 0), (u13, w21 + w22*x1 < 0), (-1.0*alpha*(w21 + w22*x1)*(u11 + u12*(w11 + w12*x1) + u13*(w21 + w22*x1) - y_t) + u13, True))]])

Substitutions:
x0 = w11 + w12*x1
x2 = x0 < 0
x3 = w21 + w22*x1
x4 = x3 < 0
x5 = x2 & (x2 | x4)
x6 = u11 - y_t
x7 = u12*x0 + x6
x8 = 2*u12
x9 = x7*x8
x10 = u13*x3
x11 = x10 + x7
x12 = x11*x8
x13 = 0.5*alpha
x14 = x2 & x4
x15 = 2*u13
x16 = x15*(x10 + x6)
x17 = x11*x15

Simplified Expression:


Matrix([
[           w11 - x13*Piecewise((0, x5), (x9, x4), (x12, True)),            w12 - x13*Piecewise((0, x5), (x1*x9, x4), (x1*x12, True))],
[w21 - x13*Piecewise((0, x14), (x16, x2), (0, x4), (x17, True)), w22 - x13*Piecewise((0, x14), (x1*x16, x2), (0, x4), (x1*x17, True))]])

In [None]:
w11_, w12_, w21_, w22_ = sp.symbols("w11' w12' w21' w22'")
W_ = sp.Matrix([[w11_, w12_], [w21_, w22_]])

u11_, u12_, u13_ = sp.symbols("u11' u12' u13'")
u_ = sp.Matrix([[u11_], [u12_], [u13_]])

h1_ = ReLU(W_.row(0).dot(x))
h2_ = ReLU(W_.row(1).dot(x))
h_ = sp.Matrix([[1], [h1_], [h2_]])

y_ = u_.dot(h_)

w11_ = 5

display(W_, u_, h_, y_)

In [None]:
loss_ = MSE(y_true, y_)
display(loss_)

W_new_ = W_ - lr * loss_.diff(W_)
display(sp.simplify(W_new_))

u_new_ = u_ - lr * loss_.diff(u_)
display(sp.simplify(u_new_))

In [None]:
constraints = [
    W_new_.row(1) - W_new.row(0), # 2
    W_new_.row(0) - W_new.row(1), # 2
    W_.row(1) - W.row(0), # 2
    W_.row(0) - W.row(1), # 2

    lr - 1, # 1
    x1 - 1, # 1

    u_.row(0) - u.row(0), # 1
    u_.row(1) - u.row(2), # 1
    u_.row(2) - u.row(1), # 1

    #W - sp.Matrix([[1, 1.5], [1.5, 1]]), # 4
    #u - sp.Matrix([[1], [1.5], [2]]), # 3
]

solution = sp.solve(constraints, (u, u_, W, W_, lr, x, y_true)) # 17

display(solution)

In [None]:
constraints = [
    lr - 1, # 1
    x1 - 1, # 1
]

solution = sp.solve(constraints, (u, W, lr, x, y_true)) # 17

display(solution)

In [None]:
values = {
    w11: 0.5, w12: -0.3,
    w21: 0.8, w22: 0.2,
    u11: 0.1, u12: -0.4, u13: 0.6,
    x1: 1.0,
}

# Evaluate expressions with given values
h_eval = h.subs(values)
y_eval = y.subs(values)

print("\nEvaluated h:")
sp.pprint(h_eval)

print("\nEvaluated y:")
sp.pprint(y_eval)