
* Reference
David Silver's YouTube video on simulation-based search
https://www.youtube.com/watch?v=SzosiqyjpHE

* description of the method
```
Algorithm 2 Recursive Simulation-based search
procedure MCTS(ϵ,q,π,s,k)
    if s is terminal or k=0 then return 0
    for i = 1 to n do
        MCTS(ϵ,q,π,s,k-1)
    end for
    A ~ π(s)
    R,S' <- ϵ(s,A)
    G <- R + γ * MCTS(ϵ,q,π,S',k)
    q(s,A) <- q(s,A) + α * (G - q(s,A))
    π(s) <- IMPROVE(π(s), q(s,.), G)
    return G
end procedure
```
```
AlphaZero
procedure AlphaZero(ϵ, θ, η, s, k)
    if s is terminal or k = 0 then return 0
    π' <- π_η
    q' <- q_θ
    for i = 1 to n do
        MCTS(ϵ, q', π', s, k-1)
    end for
    A ~ π'(s)
    R,S' <- ϵ(s,A)
    G <- R + γ * AlphaZero(ϵ,θ,η,S',k)
    θ <- θ - α * dL_q(G, q_θ(s,A)) / dθ
    η <- η - α * dL_π(π'(s), π_η(s)) / dη
    return G
end procedure
```

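Complexity analysis (n = number of nested searches per step, l = number of steps until the episode ends):

When k = 1: 1 + 1 + ... (l terms)
Complexity: $O(l)$

When k = 2: nl + n(l-1) + ... + n
Complexity: $O(nl^2)$

When k = 3: n \cdot nl^2 + n \cdot n(l-1)^2 + ...
Complexity: $O(n^2l^3)$

For any k >= 1, Complexity: $O(n^{k-1}l^{k})$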


* application of the method





In [408]:
''' cartpole simulation in gymnasium'''
import gymnasium as gym

# Scratch cell: pokes at the Gymnasium API and checks that the simulator's
# internal state can be overwritten by hand before stepping.
env = gym.make("CartPole-v1")
# action_space: bare expression, evaluated but its value is neither stored nor printed
env.action_space.n
#observation space: bare expression as well
env.observation_space
#reset: returns the initial observation and an info dict
state, info = env.reset()
#env.state = env.unwrapped.state = [0,0,0,0]
# step: take one randomly sampled action; step() returns
# (observation, reward, terminated, truncated, info)
action = env.action_space.sample()
s,r,terminated, truncated, info = env.step(action)
# Overwrite the simulator state directly on the unwrapped environment.
# NOTE(review): presumably this checks that the environment can be placed in an
# arbitrary state, as a simulation-based search would need - confirm.
env.unwrapped.state = [-0.02175902, -0.19783586, -0.01449094,  0.27565469]
print(env.unwrapped.state)
# Apply action 1 and then action 0 from the injected state; the return values
# of both steps are discarded, only the resulting internal state is printed.
env.step(1)
env.step(0)
print(env.unwrapped.state)

[-0.02175902, -0.19783586, -0.01449094, 0.27565469]
[-0.02576594 -0.19750225 -0.00940911  0.26827361]


In [5]:
from dataclasses import dataclass, field

import numpy as np
import torch.nn as nn

@dataclass
class HyperParams:
  """Search and training hyper-parameters.

  The rest of the file reads these straight off the class (for example
  `HyperParams.num_path`), which keeps working because dataclass defaults stay
  available as class attributes. Annotating them turns them into real
  dataclass fields, so `HyperParams(level=3)` can also build an overridden copy.
  """
  # Discount factor applied to the return of the successor state.
  # NOTE(review): 1e-3 is an unusually small discount - confirm it is intended.
  gamma: float = 1e-3
  # Step size used for the network updates.
  alpha: float = 1e-3
  # Number of nested searches launched from each visited state ("n" in the notes).
  num_path: int = 3
  # Initial search level ("k" in the notes) handed to AlphaZero by train().
  level: int = 2
  # Exploration constant of the PUCT selection rule.
  c_puct: float = 1/np.sqrt(2)

@dataclass(unsafe_hash=True)
class State:
  """Grid position that also carries search bookkeeping.

  Equality and hashing use only (x, y). A plain `@dataclass` defines `__eq__`
  and therefore sets `__hash__` to None, which made every `s in Tree` /
  `Tree[s]` lookup raise TypeError; `unsafe_hash=True` restores a hash built
  from the compared fields. The bookkeeping fields are excluded from
  comparison so that updating them never changes which key an instance
  matches.
  """
  # Grid coordinates. They define identity: do not mutate them while the
  # instance is used as a dictionary key.
  x:int
  y:int
  # Visit count.
  N: int = field(default=0, compare=False)
  # Running value estimate (may hold whatever the value callable returns).
  Q: float = field(default=0, compare=False)
  # Prior assigned by the policy callable.
  P: float = field(default=0, compare=False)
  # True when no further transitions should be simulated from this state.
  terminal: bool = field(default=False, compare=False)
  # State this one was reached from, if known; left out of repr to avoid
  # printing whole ancestor chains.
  parent: "State | None" = field(default=None, compare=False, repr=False)

# Discrete action ids; `epsilon` maps 0 -> x+1, 1 -> y-1, 2 -> x-1, 3 -> y+1.
ActionSpace = list(range(4))
# Global search tree: Tree[state][action] holds the statistics of that edge.
Tree = {}

# Value network: 4 input channels -> 128 -> 128 -> 1 output channel, built
# from kernel-size-1 convolutions with ReLU in between.
modelQ = nn.Sequential(
    nn.Conv1d(in_channels=4, out_channels=128, kernel_size=1, stride=1, padding=0, bias=True),
    nn.ReLU(),
    nn.Conv1d(in_channels=128, out_channels=128, kernel_size=1),
    nn.ReLU(),
    nn.Conv1d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True),
)

# Policy network: same trunk, but 4 output channels (one per action id).
modelPi = nn.Sequential(
    nn.Conv1d(in_channels=4, out_channels=128, kernel_size=1, stride=1, padding=0, bias=True),
    nn.ReLU(),
    nn.Conv1d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0, bias=True),
    nn.ReLU(),
    nn.Conv1d(in_channels=128, out_channels=4, kernel_size=1, stride=1, padding=0, bias=True),
)

def N(s,a):
  """Return the visit count recorded in `Tree` for action `a` taken from state `s`.

  Raises KeyError if `s` (or `a` under `s`) has not been added to `Tree`.
  """
  edge = Tree[s][a]
  return edge.N

def epsilon(s,a):
  """Deterministic grid transition: return the State reached from `s` by action `a`.

  Action 0 increments x, 1 decrements y, 2 decrements x, and any other value
  increments y.

  NOTE(review): only the next State is returned, whereas MCTS and AlphaZero in
  this file unpack the result as `R, s_new = epsilon(s, a)`. No reward is
  defined here - decide what the reward should be and return it as well.
  """
  if a == 0:
    return State(s.x+1, s.y)
  if a == 1:
    return State(s.x, s.y-1)
  if a == 2:
    return State(s.x-1, s.y)
  return State(s.x, s.y+1)

def pi(s,a):
  """Return entry `a` of the policy network's output for `s`.

  NOTE(review): `s` is handed to `modelPi` unchanged, so it presumably has to
  be a tensor the Conv1d stack accepts; train() in this file passes a State -
  confirm where the encoding is supposed to happen.
  """
  policy_out = modelPi(s)
  return policy_out[a]

def q(s):
  """Return the value network's output for `s`.

  NOTE(review): `s` is handed to `modelQ` unchanged, so it presumably has to
  be a tensor the Conv1d stack accepts - confirm against callers.
  """
  value_out = modelQ(s)
  return value_out

def MCTS(epsilon, q, pi, s, k):
  """Run one recursive simulation-based search step from `s` and return its return G.

  Args:
    epsilon: environment model; `epsilon(s, a)` must return `(reward, next_state)`.
    q: value estimate; `q(s)` is used to initialise every edge of a new node.
    pi: prior; `pi(s, a)` is stored as the prior of edge `(s, a)`.
    s: current state; must be hashable (it keys `Tree`) and expose `.terminal`.
    k: remaining search level; every nested search and every simulated step
      uses `k-1`, so the recursion is bounded by `k`.

  Returns:
    0 when `s` is terminal or `k == 0`; `q(s)` when `s` is expanded for the
    first time; otherwise `R + gamma * G'` for the selected action.

  Side effects:
    Adds `s` to the global `Tree` on first visit and updates the visit count
    and running mean value of the selected edge.
  """
  # Function-scope import so the block does not depend on the file's import cell.
  from types import SimpleNamespace

  if s.terminal or k == 0:
    return 0

  # "for i = 1 to n" in the notes: exactly num_path nested searches, one level
  # shallower (range(1, num_path) ran one fewer).
  for _ in range(HyperParams.num_path):
    MCTS(epsilon, q, pi, s, k-1)

  # expand
  if s not in Tree:
    value = q(s)
    # One plain statistics record per action; no grid coordinates are needed
    # for an edge, so State is not instantiated here.
    Tree[s] = {
      action: SimpleNamespace(N=0, P=pi(s, action), Q=value, parent=s)
      for action in ActionSpace
    }
    return value

  # select: PUCT score Q + c_puct * P * sqrt(sum_b N(s,b)) / (1 + N(s,a)).
  edges = Tree[s]
  c_puct = float(HyperParams.c_puct)
  sqrt_total = float(np.sqrt(sum(edge.N for edge in edges.values())))

  def puct(action):
    edge = edges[action]
    # float() lets plain numbers and single-element tensors be compared alike.
    return float(edge.Q + c_puct * edge.P * sqrt_total / (1 + edge.N))

  # max() keeps the first action on ties, like argmax.
  a = max(ActionSpace, key=puct)

  R, s_new = epsilon(s,a)
  G = R + HyperParams.gamma * MCTS(epsilon, q, pi, s_new, k-1)

  # backpropagation
  edge = edges[a]
  edge.N += 1
  # Rebind instead of "+=": all edges of a node start from the same q(s)
  # object, and an in-place update on a tensor would change every one of them.
  edge.Q = edge.Q + (G - edge.Q) / edge.N

  return G

# Losses used by AlphaZero: squared error for the value network and
# KL divergence for the policy network, both with PyTorch's default settings.
QLoss, PiLoss = nn.MSELoss(), nn.KLDivLoss()

def AlphaZero(epsilon, q, pi, s, k):
  """Play one step from `s` guided by MCTS, recurse, then update both networks.

  Args:
    epsilon: environment model; `epsilon(s, a)` must return `(reward, next_state)`.
    q, pi: accepted for interface compatibility; as in the original code the
      search always runs on `modelQ` / `modelPi`, the networks being trained.
    s: current state; hashable and exposing `.terminal`.
    k: search level. Nested searches use `k-1`; the recursion on the successor
      keeps `k`, so it only stops when a terminal state is reached.

  Returns:
    0 when `s` is terminal or `k == 0`, otherwise `R + gamma * G'`.

  Raises:
    KeyError: if the searches never expanded `s` (for example `k == 1`, where
      every nested search returns immediately).

  NOTE(review): `modelQ(s)` / `modelPi(s)` receive `s` unchanged; encoding a
  State into a tensor the networks accept is not handled here - confirm.
  """
  # Function-scope import: the file only imports torch.nn at top level.
  import torch

  if s.terminal or k == 0:
    return 0

  # pi' <- pi_eta, q' <- q_theta: search with the current networks. MCTS calls
  # pi(s, a) with two arguments, so the policy network is wrapped, not passed raw.
  q = modelQ

  def pi(state, action):
    return modelPi(state)[action]

  # "for i = 1 to n" in the notes: exactly num_path searches.
  for _ in range(HyperParams.num_path):
    MCTS(epsilon, q, pi, s, k-1)

  # select: PUCT score Q + c_puct * P * sqrt(sum_b N(s,b)) / (1 + N(s,a)).
  edges = Tree[s]
  actions = list(edges)
  c_puct = float(HyperParams.c_puct)
  sqrt_total = float(np.sqrt(sum(edges[b].N for b in actions)))

  def puct(action):
    edge = edges[action]
    return float(edge.Q + c_puct * edge.P * sqrt_total / (1 + edge.N))

  a = max(actions, key=puct)
  R, s_new = epsilon(s,a)
  G = R + HyperParams.gamma * AlphaZero(epsilon, q, pi, s_new, k)

  # Value update, theta <- theta - alpha * dL_q/dtheta: plain SGD with lr=alpha.
  # The prediction is taken at `s`, the state whose return G was just measured.
  value = q(s)
  q_optim = torch.optim.SGD(modelQ.parameters(), lr=HyperParams.alpha)
  q_optim.zero_grad()
  QLoss(value, torch.full_like(value, float(G))).backward()
  q_optim.step()

  # Policy update, eta <- eta - alpha * dL_pi/deta: pull the network's policy
  # at `s` towards the search policy, taken here as the visit distribution.
  visits = torch.tensor([float(edges[b].N) for b in actions])
  total = visits.sum()
  if total > 0:
    search_policy = visits / total
  else:
    # No edge was visited (e.g. num_path == 1): fall back to a uniform target.
    search_policy = torch.full_like(visits, 1.0 / len(actions))

  # Assumes modelPi(s) is indexable by action along its first axis, the same
  # assumption the pi(s, a) wrapper makes - TODO confirm the tensor layout.
  policy_out = modelPi(s)
  logits = torch.stack([policy_out[b] for b in actions])
  # KLDivLoss expects log-probabilities as input and probabilities as target.
  log_probs = torch.log_softmax(logits, dim=0)
  target = search_policy.reshape(-1, *([1] * (logits.dim() - 1))).expand_as(logits)
  pi_optim = torch.optim.SGD(modelPi.parameters(), lr=HyperParams.alpha)
  pi_optim.zero_grad()
  PiLoss(log_probs, target).backward()
  pi_optim.step()

  return G

In [6]:
def train(num_episodes=10000):
  """Run `num_episodes` AlphaZero episodes, each starting from State(0, 0).

  Args:
    num_episodes: number of episodes to run; defaults to the previously
      hard-coded 10000.

  Returns:
    A list with the return G of every episode, in order (the value was
    computed but discarded before).
  """
  returns = []
  for _ in range(num_episodes):
    s = State(0,0)
    returns.append(AlphaZero(epsilon, q, pi, s, HyperParams.level))
  return returns
