# On-policy first-visit MC control (for ε-soft policies)

Soft policies are probabilistic, the action picked for state s is the result of sampling the distribution Q(s,:)

On-policy methods never deviate from the policy, but the policy itself is probabilistic, so it llows for the exploration to happen while staying on policy.

In [None]:
For exaploring starts, create a function that, given a state it creates the list of cards.
Enumerate all posible states, call the function, and simmulate.
repeat the above a 1000 times

In [1]:
import numpy as np
from numpy.linalg import inv, norm
from numpy.random import choice

#Blackjack rules
#draw at random 1 - 10, 10 ,10, 10
#dealer shows one card.
#player takes two and holds on 20 or 21
#state is dealer-showing, player sum and useable ace.
#dealer sticks on 17 or greater

In [2]:
# Usable Ace
def uaFunc(pCards):
    return int(min(pCards) == 1 and sum(pCards) <= 11) 

In [3]:
def stateIdx(dCard, pCards):
    i = dCard - 1
    k = uaFunc(pCards)
    j = pSum(pCards) - 12
    return [i, j, k] 

In [4]:
def pSum(cards):
    return sum(cards) + 10*uaFunc(cards)

In [5]:
def initState():
    dCard = choice(cards)
    pCards = list(choice(cards, 2, True))
    #print(pCards)
    while pSum(pCards) < 12:
        pCards.append(choice(cards))
    return dCard, pCards

In [6]:
def pPolicy(dCard, pCards, PI):
    dsIdx, psIdx, uaIdx = stateIdx(dCard, pCards)    
    p = PI[dsIdx, psIdx, uaIdx]
    a = choice(np.arange(len(p)), 1, list(p))[0]
    return a

In [7]:
def pReward(dCard, pCards):

    if sum(pCards) > 21:
        #print 'busted'
        result = -1
    else:
        dCards = [dCard]
        while pSum(dCards) < 17:
            dCards.append(choice(cards, 1))
            
        if len(pCards) == 2 and pSum(pCards) == 21:
            if len(dCards) == 2 and pSum(dCards) == 21:
                result = 0
            else:
                result = 1
                
        if sum(dCards) > 21:
            result = 1
        elif pSum(dCards) >= pSum(pCards):
            result = -1
        else:
            result = 1
#         if pSum(pCards) == 21:
#             print 'player=[{0} - {1}]   -   dealer=[{2} - {3}]'.format(sum(pCards), pSum(pCards), sum(dCards), pSum(dCards))
#             print 'result', result
    
    return result

In [8]:
def updatePolicy(rewardMatrix, countMatrix, epsilon, Q, PI, stateRec):
    for dsIdx, psIdx, uaIdx, a in stateRec:
        Q[dsIdx, psIdx, uaIdx, a] = rewardMatrix[dsIdx, psIdx, uaIdx, a] / (countMatrix[dsIdx, psIdx, uaIdx, a])
        
        meanRewards = Q[dsIdx, psIdx, uaIdx]
        v_max = np.max(meanRewards)
        a_max = [a for a, v in enumerate(meanRewards) if v == v_max]
        a_star = np.random.choice(a_max) if len(a_max) > 1 else a_max[0]
        num_actions = len(meanRewards)
        for a in range(num_actions):
            if a == a_star:
                PI[dsIdx, psIdx, uaIdx, a] = 1 - epsilon + epsilon/float(num_actions)
            else:
                PI[dsIdx, psIdx, uaIdx, a] = epsilon/float(num_actions)

In [9]:
def a_star(av):
    #return '{0:.2f}'.format(av[1]).rjust(5)
    if av[0] > av[1]:
        return 'stick'
    elif av[0] == av[1]:
        return '====='
    else:
        return '     '

In [10]:
def printPolicy(PI):
    print
    print 'Policy'
    for k in [0,1]:
        print 'usabe ace (1 or 11): ', 'yes' if k == 1 else 'no'
        print 'dlr\plyr    ', '       '.join(['{0}'.format(12+j) for j in range(10)])
        for i in range(10):
            print '{0}'.format(1+i).ljust(8), [a_star(PI[i,j,k]) for j in range(10)]
        print 

In [11]:
stateMatrix = np.zeros((10, 10, 2))
cards = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10]
Q = np.zeros((10, 10, 2, 2))
PI = np.ones((10, 10, 2, 2)) * .5
rewardMatrix = np.zeros((10, 10, 2, 2))
countMatrix = np.zeros((10, 10, 2, 2))              
epsilon = .5

#seq of events
#draw dealer's card
#draw players cards to reach 12
#determine ua or nua
#take more player cards - keep list of sums (or indices)

for episode in range(300000):

    epsilon *= 0.99999
    
    stateRec = []

    #set initial state
    dCard, pCards = initState()
    ua = uaFunc(pCards)

    #record initial state
    dsIdx, psIdx, uaIdx = stateIdx(dCard, pCards)
    action = pPolicy(dCard, pCards, PI)
    stateRec.append((dsIdx, psIdx, uaIdx, action))

    #print 'player cards', pCards
    #exercise player policy
    while (pSum(pCards) <= 21) and action:
        # draw a card
        pCards.append(choice(cards))
        #print 'player cards', pCards
        dsIdx, psIdx, uaIdx = stateIdx(dCard, pCards)
        if (pSum(pCards) <= 21):
            action = pPolicy(dCard, pCards, PI)
            #record state each step
            stateRec.append((dsIdx, psIdx, uaIdx, action))

    # assign reward
    reward = pReward(dCard, pCards)           
    for dsIdx, psIdx, uaIdx, a in stateRec:
        rewardMatrix[dsIdx, psIdx, uaIdx, a] += reward            
        countMatrix[dsIdx, psIdx, uaIdx, a] += 1 
#             if pSum(pCards) == 21:
#                 print dsIdx, psIdx, uaIdx, aIdx, reward
    #print stateRec
    updatePolicy(rewardMatrix, countMatrix, epsilon, Q, PI, stateRec)

    
print epsilon

0.0248931607816


In [174]:
#printPolicy(rewardMatrix)

In [175]:
printPolicy(PI)


Policy
usabe ace (1 or 11):  no
dlr\plyr     12       13       14       15       16       17       18       19       20       21
1        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
2        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
3        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
4        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
5        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
6        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
7        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
8        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick']
9        ['stick', 'stick', 'stick', 'stick', 'stick', 'stick', 'stick

In [176]:
def printMatrix(M):
    print
    print 'Matrix'
    for k in [0,1]:
        print 'action (stick or draw): ', 'draw' if k == 1 else 'stick'
        print 'dlr\plyr    ', '       '.join(['{0}'.format(12+j) for j in range(10)])
        for i in range(10):
            print '{0}'.format(1+i).ljust(8), ['{0:.2f}'.format(M[i,j, 0, k]) for j in range(10)]
        print 

In [177]:
printMatrix(Q)


Matrix
action (stick or draw):  stick
dlr\plyr     12       13       14       15       16       17       18       19       20       21
1        ['-0.76', '-0.76', '-0.77', '-0.76', '-0.76', '-0.75', '-0.54', '-0.21', '0.02', '0.22']
2        ['-0.29', '-0.24', '-0.29', '-0.31', '-0.28', '-0.23', '-0.04', '0.25', '0.49', '0.74']
3        ['-0.27', '-0.25', '-0.23', '-0.20', '-0.23', '-0.27', '0.02', '0.30', '0.56', '0.77']
4        ['-0.20', '-0.23', '-0.22', '-0.19', '-0.14', '-0.20', '0.05', '0.30', '0.52', '0.75']
5        ['-0.21', '-0.10', '-0.15', '-0.15', '-0.19', '-0.19', '0.08', '0.34', '0.58', '0.79']
6        ['-0.14', '-0.12', '-0.17', '-0.11', '-0.18', '-0.16', '0.20', '0.40', '0.65', '0.79']
7        ['-0.47', '-0.47', '-0.48', '-0.45', '-0.49', '-0.46', '0.27', '0.52', '0.70', '0.86']
8        ['-0.50', '-0.54', '-0.50', '-0.51', '-0.49', '-0.52', '-0.20', '0.50', '0.72', '0.85']
9        ['-0.53', '-0.54', '-0.56', '-0.54', '-0.49', '-0.54', '-0.27', '-0.05', '0.65', '0

In [178]:
def printMatrixInt(M):
    print
    print 'Matrix'
    for k in [0,1]:
        print 'action (stick or draw): ', 'draw' if k == 1 else 'stick'
        print 'dlr\plyr    ', '       '.join(['{0}'.format(12+j) for j in range(10)])
        for i in range(10):
            print '{0}'.format(1+i).ljust(8), ['%d' % (M[i,j, 0, k]) for j in range(10)]
        print 

In [179]:
printMatrixInt(rewardMatrix)


Matrix
action (stick or draw):  stick
dlr\plyr     12       13       14       15       16       17       18       19       20       21
1        ['-1019', '-999', '-998', '-940', '-956', '-891', '-632', '-249', '44', '186']
2        ['-373', '-310', '-368', '-395', '-371', '-282', '-54', '313', '911', '584']
3        ['-362', '-324', '-320', '-257', '-311', '-321', '24', '352', '1035', '552']
4        ['-283', '-297', '-281', '-250', '-184', '-246', '63', '355', '958', '537']
5        ['-276', '-125', '-192', '-195', '-248', '-228', '98', '430', '1001', '612']
6        ['-184', '-156', '-220', '-140', '-225', '-191', '250', '491', '1136', '595']
7        ['-624', '-631', '-605', '-589', '-628', '-583', '327', '628', '1251', '639']
8        ['-667', '-694', '-650', '-638', '-635', '-626', '-245', '561', '1291', '667']
9        ['-685', '-706', '-766', '-698', '-647', '-688', '-341', '-60', '1129', '646']
10       ['-3257', '-3148', '-2986', '-3001', '-2981', '-2993', '-1696', '-688', '6