In [1]:
import numpy as np

Notation:

- $s_t = [s_1, s_2]$
- $\mathcal{A} = \{a_1, a_2\}$
- $\pi_{\psi}(a_i\mid s_t) \coloneqq \pi_i(s_t)$

We're having problems computing the gradient of the policy loss. This notebook serves as a symbolic validation of the analytical results.

We'll start by defining the variables for the $Q_{\theta}(s_t, a_t)$ function. As these aren't a function of parameters $\psi$, we'll create them as standalone variables.

In [2]:
# Symbolic Q-values, one per action (q1 for a_1, q2 for a_2); they do not depend on ψ.
Q = var('q1 q2')

We now define the variables that make up the policy, according to the ANN architecture

In [12]:
# Symbolic policy parameters ψ = (W, b) and the input state.
W = var('w11 w12 w21 w22')  # weights
b = var('b1 b2')            # biases
S = var('s1 s2')            # state components

# Bind the individual symbols explicitly so the expressions below (and later
# cells) do not rely on var() injecting names into the global namespace.
w11, w12, w21, w22 = W
b1, b2 = b
s1, s2 = S

# Output layer (pre-activation).
# NOTE(review): each z_i uses only s_i, multiplied by two weights, so w11/w12
# (and w21/w22) are interchangeable. A standard dense layer would mix s1 and s2
# in every z_i — confirm this matches the intended ANN architecture.
z = [s1*w11 + s1*w12 + b1,
     s2*w21 + s2*w22 + b2]

# softmax activation function
def softmax(z, i):
    """Return the i-th component of the softmax of ``z``.

    z: sequence of pre-activation values (symbolic expressions or numbers).
    i: index of the component to return.
    """
    exponentials = [exp(z_j) for z_j in z]
    return exponentials[i] / sum(exponentials)
    
# Policy probabilities: one softmax component per pre-activation in z.
π = [softmax(z, k) for k, _ in enumerate(z)]

In [13]:
# NOTE(review): `viwer` looks like a typo for `viewer`; as spelled the keyword is
# presumably ignored by show() — confirm, then either fix or drop it.
show(π, viwer='pdf') # render output in latex

We compute partial derivatives as follows:

In [14]:
# Partial derivative of every pre-activation z_i with respect to w11.
dz_dw11 = [component.derivative(w11) for component in z]
# NOTE(review): `viwer` is likely a typo for `viewer` — confirm before changing.
show(dz_dw11, viwer='pdf')

In [15]:
# Policy loss: negative expectation under π of (Q-value minus log-probability),
# summed over the actions.
J_π = -sum(prob*(q_value - log(prob)) for prob, q_value in zip(π, Q))
# NOTE(review): `viwer` is likely a typo for `viewer` — confirm before changing.
show(J_π, viwer='pdf')

We now have everything ready to compute the gradient of the policy loss!

# $\nabla_{\psi_{W}}J_{\pi}(\psi)$

In [16]:
# Gradient of the policy loss with respect to each weight; W holds the symbols
# in the order (w11, w12, w21, w22), which matches the unpacking below.
dJ_dw11, dJ_dw12, dJ_dw21, dJ_dw22 = [J_π.diff(weight) for weight in W]

In [17]:
# Display ∂J/∂w11. NOTE(review): `viwer` is likely a typo for `viewer` — confirm.
show(dJ_dw11, viwer='pdf')

In [18]:
# Display ∂J/∂w12. NOTE(review): `viwer` is likely a typo for `viewer` — confirm.
show(dJ_dw12, viwer='pdf')

In [19]:
# Display ∂J/∂w21. NOTE(review): `viwer` is likely a typo for `viewer` — confirm.
show(dJ_dw21, viwer='pdf')

In [20]:
# Display ∂J/∂w22. NOTE(review): `viwer` is likely a typo for `viewer` — confirm.
show(dJ_dw22, viwer='pdf')

# $\nabla_{\psi_{b}}J_{\pi}(\psi)$

In [21]:
# Gradient of the policy loss with respect to each bias; b holds (b1, b2).
dJ_db1, dJ_db2 = [J_π.diff(bias) for bias in b]

In [22]:
# Display ∂J/∂b1. NOTE(review): `viwer` is likely a typo for `viewer` — confirm.
show(dJ_db1, viwer='pdf')

In [23]:
# Display ∂J/∂b2. NOTE(review): `viwer` is likely a typo for `viewer` — confirm.
show(dJ_db2, viwer='pdf')

## Simplified Results

### Weights

In [38]:
# Both weights feeding the same z_i enter it through the same state component,
# so their gradients coincide; each difference below should print 0.
print(dJ_dw11 - dJ_dw12, dJ_dw21 - dJ_dw22, sep='\n')

0
0


In [31]:
# Hand-derived closed form for ∂J/∂w11.
dJ_dw11_simplified = s1 * (
    π[0]*(π[0] - 1)*(q1 - log(π[0]) - 1)
    + π[0]*π[1]*(q2 - log(π[1]) - 1)
)
# The difference simplifying to zero confirms the closed form matches the
# symbolic gradient.
(dJ_dw11 - dJ_dw11_simplified).full_simplify() == 0

0 == 0

In [32]:
# Hand-derived closed form for ∂J/∂w12 (same expression as for w11, since both
# weights multiply s1 inside z[0]).
dJ_dw12_simplified = s1 * (
    π[0]*(π[0] - 1)*(q1 - log(π[0]) - 1)
    + π[0]*π[1]*(q2 - log(π[1]) - 1)
)
# The difference simplifying to zero confirms the closed form.
(dJ_dw12 - dJ_dw12_simplified).full_simplify() == 0

0 == 0

In [33]:
# Hand-derived closed form for ∂J/∂w21.
dJ_dw21_simplified = s2 * (
    π[1]*(π[1] - 1)*(q2 - log(π[1]) - 1)
    + π[0]*π[1]*(q1 - log(π[0]) - 1)
)
# The difference simplifying to zero confirms the closed form.
(dJ_dw21 - dJ_dw21_simplified).full_simplify() == 0

0 == 0

In [34]:
# Hand-derived closed form for ∂J/∂w22 (same expression as for w21, since both
# weights multiply s2 inside z[1]).
dJ_dw22_simplified = s2*((q2-log(π[1])-1)*π[1]*(π[1]-1) + (q1-log(π[0])-1)*π[0]*π[1])
# Compare against the w22 gradient itself; the previous version re-checked the
# w21 expressions, so dJ_dw22_simplified was never actually validated.
(dJ_dw22 - dJ_dw22_simplified).full_simplify() == 0

0 == 0

### Bias

In [35]:
# Hand-derived closed form for ∂J/∂b1 (the weight formula without the state factor).
dJ_db1_simplified = (
    π[0]*(π[0] - 1)*(q1 - log(π[0]) - 1)
    + π[0]*π[1]*(q2 - log(π[1]) - 1)
)
# The difference simplifying to zero confirms the closed form.
(dJ_db1 - dJ_db1_simplified).full_simplify() == 0

0 == 0

In [36]:
# Hand-derived closed form for ∂J/∂b2 (the weight formula without the state factor).
dJ_db2_simplified = (
    π[1]*(π[1] - 1)*(q2 - log(π[1]) - 1)
    + π[0]*π[1]*(q1 - log(π[0]) - 1)
)
# The difference simplifying to zero confirms the closed form.
(dJ_db2 - dJ_db2_simplified).full_simplify() == 0

0 == 0