In [1]:
from model_utils import *
from base_utils import *

In [2]:
wd = ''

if torch.get_num_threads() > 1: # no auto multi threading
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

Label = ['Landlord','Farmer-0','Farmer-1'] # players 0, 1, and 2

name = 'H15-V2_2.2'

mfiles = [int(f[-13:-3]) for f in os.listdir(os.path.join(wd,'models')) if name in f]

if len(mfiles) == 0:
    v_M = f'{name}_{str(0).zfill(10)}'
else:
    v_M = f'{name}_{str(max(mfiles)).zfill(10)}'
print('Model version:', v_M)

N_history = int(v_M[1:3])
SLM = Network_Pcard_V2_1(15+7, 7, y=1, x=15, lstmsize=512, hiddensize=1024)
QV = Network_Qv_Universal_V1_1(6,15,1024)

SLM.load_state_dict(torch.load(os.path.join(wd,'models',f'SLM_{v_M}.pt')))
QV.load_state_dict(torch.load(os.path.join(wd,'models',f'QV_{v_M}.pt')))
SLM.eval()
QV.eval()

print('- Model Loaded')


Model version: H15-V2_2.2_0060450000
- Model Loaded


In [3]:
def resample_state(Turn, Initstates, unavail, model_inter):
    cprob = model_inter.detach().numpy().round(1).reshape(2,15)

    cprob[0] = cprob[0] / np.sum(cprob[0])
    cprob[1] = cprob[1] / np.sum(cprob[1])
    
    # total count is exact values
    total_count = Initstates[Turn].sum(axis=-2,keepdims=True).detach().numpy().flatten()
    total_count += str2state(unavail).sum(axis=-2,keepdims=True).numpy().flatten()
    #print(Turn, total_count)
    total_count[:13] = 4 - total_count[:13]
    total_count[13:] = 1 - total_count[13:]
    total_count = np.int32(total_count)

    sample1 = np.zeros(15, dtype=int)
    sample2 = np.zeros(15, dtype=int)

    # Sample cards for each player separately but ensuring total count matches
    idx = np.arange(15)
    np.random.shuffle(idx)
    ncard1 = 0
    max1 = Initstates[(Turn-1)%3].sum()

    for _ in range(15):  # For each card type
        i = idx[_]
        # Allocate cards based on total counts for each type
        total = total_count[i]
        if total == 0:
            continue
        count1 = np.random.binomial(total, cprob[0][i] / (cprob[0][i] + cprob[1][i]))
        if ncard1 + count1 > max1:
            count1 = max1 - ncard1
            sample1[i] = count1
            ncard1 += count1
            break
        else:
            sample1[i] = count1
            ncard1 += count1

    sample2 = total_count - sample1

    # corrective measure if sample2 has total > its true
    diff2 = int(sample2.sum() - Initstates[(Turn+1)%3].sum())
    if diff2 > 0: # remove several one based on probability
        for _ in range(diff2):
            probabilities = sample2 / sample2.sum()
            chosen_card = np.random.choice(15, p=probabilities)
            sample1[chosen_card] += 1
            sample2[chosen_card] -= 1
    #print(sample2)
    return sample1, sample2

In [4]:
def rollout_2_2_2(Turn, SLM, QV, sample_states, overwrite_action, unavail, lastmove, Forcemove, history, temperature, Npass, Cpass, depth=3):

    rTurn = Turn
    newlast = lastmove

    d = 0
    end = False

    while True and d <= depth:

        player = sample_states[rTurn%3]

        if d > 0: # rollout

            #player = sample_states[rTurn%3]#.clone().detach()
            visible = sample_states[-1]#.clone().detach()

            # get card count
            card_count = [int(p.sum()) for p in sample_states]
            CC = torch.zeros((3,15))
            CC[0][:min(card_count[0],15)] = 1
            CC[1][:min(card_count[1],15)] = 1
            CC[2][:min(card_count[2],15)] = 1

            # get action
            Bigstate = torch.cat([player.sum(axis=-2,keepdims=True).unsqueeze(0),
                                    str2state(unavail).sum(axis=-2,keepdims=True).unsqueeze(0),
                                    CC.unsqueeze(1),
                                    visible.sum(axis=-2,keepdims=True).unsqueeze(0), # new feature
                                    torch.zeros((1,15)).unsqueeze(0) + rTurn%3, # role feature
                                    history.sum(axis=-2,keepdims=True)])
            hinput = Bigstate.unsqueeze(0)
            model_inter = SLM(hinput)
            role = torch.zeros((model_inter.shape[0],15)) + rTurn%3
            acts = avail_actions(lastmove[0],lastmove[1],player,Forcemove)
            model_inter = torch.concat([hinput[:,0].sum(dim=-2),
                                        hinput[:,7].sum(dim=-2),
                                        model_inter,
                                        role],dim=-1)
            model_input2 = torch.stack([torch.cat((model_inter.flatten(),str2state(a[0]).sum(dim=0))) for a in acts])
            # get q values
            output = QV(model_input2).flatten()

            if temperature == 0:
                Q = torch.max(output).item()
                best_act = acts[torch.argmax(output)]
            else:
                # get action using probabilistic approach and temperature
                probabilities = torch.softmax(output / temperature, dim=0)
                distribution = torch.distributions.Categorical(probabilities)
                
                q = distribution.sample()
                best_act = acts[q]
                Q = output[q].item()
            
            action = best_act

        else: # use action
            action = overwrite_action

        if Forcemove:
            Forcemove = False

        # conduct a move
        myst = state2str(player.sum(dim=0).numpy())
        cA = Counter(myst)
        cB = Counter(action[0])
        newst = ''.join(list((cA - cB).elements()))
        newunavail = unavail + action[0]
        newhist = torch.roll(history,1,dims=0)
        
        newhist[0] = str2state(action[0]).sum(axis=-2,keepdims=True) # first row is newest, others are moved downward

        play = action[0]
        if action[1][0] == 0:
            play = 'pass'
            Cpass += 1
            if Npass < 1:
                Npass += 1
            else:
                #print('Clear Action')
                newlast = ['',(0,0)]
                Npass = 0
                Forcemove = True
        else:
            newlast = action
            Npass = 0
            Cpass = 0
        
        # update
        nextstate = str2state(newst)
        sample_states[rTurn%3] = nextstate
        unavail = newunavail
        #print(newlast)
        history = newhist
        lastmove = newlast
        
        #print(Q)

        if len(newst) == 0:
            end = True
            break

        rTurn += 1
        d += 1

    #W = 0
    if end:
        Q = 0.0
        if rTurn%3 == 0 and Turn%3 == 0:
            Q = 1.0
        if rTurn%3 != 0 and Turn%3 != 0:
            Q = 1.0
    
    return Q

In [5]:
def get_action_adv(Turn, SLM, QV, Initstates, unavail, lastmove, Forcemove, history, temperature, Npass, Cpass, nAct=5, nRoll=10, ndepth=3):
    
    player = Initstates[Turn%3]#.clone().detach()
    visible = Initstates[-1]#.clone().detach()

    # get card count
    card_count = [int(p.sum()) for p in Initstates]
    #print(card_count)
    CC = torch.zeros((3,15))
    CC[0][:min(card_count[0],15)] = 1
    CC[1][:min(card_count[1],15)] = 1
    CC[2][:min(card_count[2],15)] = 1
    #print(CC)

    # get action
    Bigstate = torch.cat([player.sum(axis=-2,keepdims=True).unsqueeze(0),
                            str2state(unavail).sum(axis=-2,keepdims=True).unsqueeze(0),
                            CC.unsqueeze(1),
                            visible.sum(axis=-2,keepdims=True).unsqueeze(0), # new feature
                            torch.zeros((1,15)).unsqueeze(0) + Turn%3, # role feature
                            history.sum(axis=-2,keepdims=True)])
    #print(Bigstate)
    # generate inputs
    hinput = Bigstate.unsqueeze(0)
    model_inter = SLM(hinput)
    role = torch.zeros((model_inter.shape[0],15)) + Turn%3
    
    # get all actions

    acts = avail_actions(lastmove[0],lastmove[1],player,Forcemove)

    # generate inputs 2
    model_inter2 = torch.concat([hinput[:,0].sum(dim=-2),
                                hinput[:,7].sum(dim=-2),
                                model_inter,
                                role],dim=-1)
    model_input2 = torch.stack([torch.cat((model_inter2.flatten(),str2state(a[0]).sum(dim=0))) for a in acts])

    # get q values
    output = QV(model_input2).flatten()

    # get N best actions to sample from!
    N = min(nAct,len(acts))

    if temperature == 0:
        top_n_indices = torch.topk(output, N).indices
        n_actions = [acts[idx] for idx in top_n_indices]
        n_Q_values = output[top_n_indices]
    else:
        probabilities = torch.softmax(output / temperature, dim=0)
        distribution = torch.distributions.Categorical(probabilities)
        sampled_indices = distribution.sample((N,))
        n_actions = [acts[idx] for idx in sampled_indices]
        n_Q_values = output[sampled_indices]

    new_Q = torch.zeros(N)
    if N == 1:
        new_Q = n_Q_values
    else:
        for i in range(N): # resample given action
            action = n_actions[i]
            Qroll = [n_Q_values[i]] # original value has some weight (good if simulation number is small)
            for r in range(nRoll): # construct fake initsates, and rollout
                sample1, sample2 = resample_state(Turn%3, Initstates, unavail, model_inter)
                sample_states = [None,None,None,Initstates[-1].clone()]
                sample_states[Turn%3] = Initstates[Turn%3].clone()
                sample_states[(Turn-1)%3] = str2state(''.join([r2c_base[i]*sample1[i] for i in range(15)]))
                sample_states[(Turn+1)%3] = str2state(''.join([r2c_base[i]*sample2[i] for i in range(15)]))
                Qroll.append(
                    rollout_2_2_2(Turn, SLM, QV, sample_states, action.copy(), unavail, lastmove.copy(), Forcemove, history.clone(), temperature, Npass, Cpass,
                                depth=ndepth))
                
            Qroll = np.array(Qroll)
            #print(Turn, n_Q_values[i].item(), np.mean(Qroll))
            #print(np.round(Qroll,1))
            new_Q[i] = np.mean(Qroll)

    best_action = n_actions[torch.argmax(new_Q)]
    Q = torch.max(new_Q)
    print(n_Q_values.numpy().round(2), new_Q.numpy().round(2))
    print(Turn, n_actions[0],best_action)
    return best_action, Q

In [6]:
def game(Models, temperature, pause=0.5, nhistory=6, p_adv=[0], nAct=5, nRoll=10, ndepth=3): # Player is 0, 1, 2 for L, D, U
    Init_states = init_game_3card() # [Lstate, Dstate, Ustate]

    public_cards = state2str(Init_states[-1].numpy().sum(axis=0))
    print(f'Landlord Cards: {public_cards}')

    Qs = []
    Q0 = []

    signs = ['-','+']

    unavail = ''
    if Models[0].non_hist == 7:
        history = torch.zeros((nhistory,1,15))
    else:
        history = torch.zeros((nhistory,4,15))
    lastmove = ['',(0,0)]

    Turn = 0
    Npass = 0 # number of pass applying to rules
    Cpass = 0 # continuous pass

    Forcemove = True # whether pass is not allowed

    Log = ''


    SLM,QV = Models

    while True: # game loop
        ts = time.time()
        # get player
        #print(Turn, lastmove)
        player = Init_states[Turn%3]
        
        if Turn%3 in p_adv:
            action, Q = get_action_adv(Turn, SLM,QV,Init_states,unavail,lastmove, Forcemove, history, temperature, Npass, Cpass,
                                       nAct, nRoll, ndepth)
        else:
            action, Q = get_action_serial_V2_2_2(Turn, SLM,QV,Init_states,unavail,lastmove, Forcemove, history, temperature, False)

        if Turn < 3:
            Q0.append(Q.item())


        if Forcemove:
            Forcemove = False

        # conduct a move
        myst = state2str(player.sum(dim=0).numpy())
        cA = Counter(myst)
        cB = Counter(action[0])
        newst = ''.join(list((cA - cB).elements()))
        newunavail = unavail + action[0]
        newhist = torch.roll(history,1,dims=0)
        if SLM.non_hist == 7:
            newhist[0] = str2state(action[0]).sum(axis=-2,keepdims=True) # first row is newest, others are moved downward
        else:
            newhist[0] = str2state(action[0])
        
        play = action[0]
        if action[1][0] == 0:
            play = 'pass'
            Cpass += 1
            if Npass < 1:
                Npass += 1
            else:
                #print('Clear Action')
                newlast = ['',(0,0)]
                Npass = 0
                Forcemove = True
        else:
            newlast = action
            Npass = 0
            Cpass = 0

        Log += f"{Label[Turn % 3]} {str(Turn).zfill(2)}    {myst.zfill(20).replace('0', ' ')} {play.zfill(20).replace('0', ' ')} by {Label[Turn % 3]}    {str(round(Q.item()*100,1)).zfill(5)}%\n"
        if Cpass == 2:
            Log += '\n'
        
        # update
        nextstate = str2state(newst)
        Init_states[Turn%3] = nextstate
        unavail = newunavail
        history = newhist
        lastmove = newlast
        
        time.sleep(max(pause - (time.time()-ts),0))

        if len(newst) == 0:
            break

        Turn += 1

    if Turn %3 == 0:
        Log += f'\nLandlord Wins'
    else:
        Log += f'\nFarmers Win'


    return Turn, Qs, Log

In [81]:
s = np.random.randint(-1000000,1000000)
print(s)
random.seed(-277194)
with torch.no_grad():
    Turn, Qs, Log = game([SLM, QV],0,0,15,[0,],5,20,15)

507222
Landlord Cards: 562
[0.47 0.44 0.43 0.42 0.41] [0.39 0.3  0.45 0.41 0.42]
0 ['99', (2, 6)] ['66', (2, 3)]
[0.58 0.5  0.45 0.42 0.35] [0.6  0.55 0.42 0.43 0.35]
3 ['99', (2, 6)] ['99', (2, 6)]
[0.58 0.46] [0.62 0.53]
6 ['22', (2, 12)] ['22', (2, 12)]
[0.66 0.6  0.59 0.55 0.48] [0.58 0.58 0.53 0.51 0.66]
9 ['4443', (5, 1)] ['88', (2, 5)]
[0.59 0.54 0.46] [0.64 0.64 0.61]
12 ['KK', (2, 10)] ['AA', (2, 11)]
[0.65 0.6  0.58 0.52 0.52] [0.72 0.75 0.53 0.38 0.81]
15 ['4443', (5, 1)] ['55', (2, 2)]
[0.58 0.46] [0.72 0.66]
18 ['KK', (2, 10)] ['KK', (2, 10)]
[0.83 0.79 0.79 0.76 0.49] [0.99 0.94 0.76 0.85 0.49]
21 ['4443', (5, 1)] ['4443', (5, 1)]
[0.91] [0.91]
24 ['', (0, 0)] ['', (0, 0)]
[0.99 0.89] [1.   0.99]
27 ['R', (1, 14)] ['R', (1, 14)]
[1.] [1.]
30 ['X', (1, 7)] ['X', (1, 7)]


In [82]:
print(Log)

Landlord 00    344455668899XKKAA22R                   66 by Landlord    045.0%
Farmer-0 01       356677789XXJJQK2B                   77 by Farmer-0    053.0%
Farmer-1 02       3345789XJJQQQKAA2                 pass by Farmer-1    042.7%
Landlord 03      3444558899XKKAA22R                   99 by Landlord    059.6%
Farmer-0 04         3566789XXJJQK2B                 pass by Farmer-0    045.1%
Farmer-1 05       3345789XJJQQQKAA2                   AA by Farmer-1    033.8%
Landlord 06        34445588XKKAA22R                   22 by Landlord    061.6%
Farmer-0 07         3566789XXJJQK2B                 pass by Farmer-0    037.8%
Farmer-1 08         3345789XJJQQQK2                 pass by Farmer-1    029.9%

Landlord 09          34445588XKKAAR                   88 by Landlord    066.2%
Farmer-0 10         3566789XXJJQK2B                 pass by Farmer-0    038.3%
Farmer-1 11         3345789XJJQQQK2                   JJ by Farmer-1    031.7%
Landlord 12            344455XKKAAR                

In [72]:
print(Log)

Landlord 00    344455668899XKKAA22R                   99 by Landlord    046.6%
Farmer-0 01       356677789XXJJQK2B                 pass by Farmer-0    045.4%
Farmer-1 02       3345789XJJQQQKAA2                   AA by Farmer-1    047.4%
Landlord 03      3444556688XKKAA22R                   22 by Landlord    044.1%
Farmer-0 04       356677789XXJJQK2B                 pass by Farmer-0    047.0%
Farmer-1 05         3345789XJJQQQK2                 pass by Farmer-1    040.9%

Landlord 06        3444556688XKKAAR                44466 by Landlord    044.5%
Farmer-0 07       356677789XXJJQK2B                77766 by Farmer-0    051.1%
Farmer-1 08         3345789XJJQQQK2                QQQ33 by Farmer-1    054.6%
Landlord 09             35588XKKAAR                 pass by Landlord    026.2%
Farmer-0 10            3589XXJJQK2B                 pass by Farmer-0    073.0%

Farmer-1 11              45789XJJK2                789XJ by Farmer-1    049.9%
Landlord 12             35588XKKAAR               

In [13]:
random.seed(0)
Initstates = init_game_3card()
Turn = 0
player = Initstates[Turn%3]#.clone().detach()
visible = Initstates[-1]#.clone().detach()
unavail = ''
history = torch.zeros((15,4,15))

# get card count
card_count = [int(p.sum()) for p in Initstates]
CC = torch.zeros((3,15))
CC[0][:min(card_count[0],15)] = 1
CC[1][:min(card_count[1],15)] = 1
CC[2][:min(card_count[2],15)] = 1

# get action
Bigstate = torch.cat([player.sum(axis=-2,keepdims=True).unsqueeze(0),
                        str2state(unavail).sum(axis=-2,keepdims=True).unsqueeze(0),
                        CC.unsqueeze(1),
                        visible.sum(axis=-2,keepdims=True).unsqueeze(0), # new feature
                        torch.zeros((1,15)).unsqueeze(0) + Turn%3, # role feature
                        history.sum(axis=-2,keepdims=True)])
#print(Bigstate)
# generate inputs
hinput = Bigstate.unsqueeze(0)
model_inter = SLM(hinput)
role = torch.zeros((model_inter.shape[0],15)) + Turn%3

cprob = model_inter.detach().numpy().round(1).reshape(2,15)

r = (np.random.randint(0,3,12)-1)/2
print(r)
cprob[0][:12] -= r
cprob[1][:12] += r
cprob[0][2] = 3.9
cprob[1][2] = 0.1

print('Hint:      ','    '.join([f' {c} ' for c in r2c_base_arr]))

c0 = [f"{str(c)}0" if len(str(c)) < 3 else str(c) for c in cprob[0]]
c1 = [f"{str(c)}0" if len(str(c)) < 3 else str(c) for c in cprob[1]]

print(f'{Label[(Turn-1)%3]}:  ','    '.join([str(c) for c in c0]),'    ', (cprob[0].sum()).round(1))
print(f'{Label[(Turn+1)%3]}:  ','    '.join([str(c) for c in c1]),'    ', (cprob[1].sum()).round(1))




cprob[0] = cprob[0] / np.sum(cprob[0])
cprob[1] = cprob[1] / np.sum(cprob[1])

# total count is exact values
total_count = Initstates[Turn%3].sum(axis=-2,keepdims=True).detach().numpy().flatten()
total_count += str2state(unavail).sum(axis=-2,keepdims=True).numpy().flatten()
#print(Turn, total_count)
total_count[:13] = 4 - total_count[:13]
total_count[13:] = 1 - total_count[13:]
total_count = np.int_(total_count)
print(total_count)

s1 = np.zeros(15, dtype=int)
s2 = np.zeros(15, dtype=int)

for i in range(1000):
    # Containers for actual counts
    sample1 = np.zeros(15, dtype=int)
    sample2 = np.zeros(15, dtype=int)

    # Sample cards for each player separately but ensuring total count matches
    idx = np.arange(15)
    np.random.shuffle(idx)
    ncard1 = 0
    max1 = Initstates[(Turn-1)%3].sum()
    for _ in range(15):  # For each card type
        i = idx[_]
        # Allocate cards based on total counts for each type
        total = total_count[i]
        if total == 0:
            continue

        # Sample for player 1
        count1 = np.random.binomial(total, cprob[0][i] / (cprob[0][i] + cprob[1][i]))

        if ncard1 + count1 > max1:
            count1 = max1 - ncard1
            sample1[i] = count1
            ncard1 += count1
            break
        else:
            sample1[i] = count1
            ncard1 += count1
    sample2 = total_count - sample1
    # corrective measure if sample2 has total > its true
    diff2 = int(sample2.sum() - Initstates[(Turn+1)%3].sum())
    if diff2 > 0: # remove several one based on probability
        for _ in range(diff2):
            #probabilities = sample2 / sample2.sum()
            prob = np.ones(15)
            prob[sample2==0] = 0
            prob /= prob.sum()
            chosen_card = np.random.choice(15, p=prob)
            sample1[chosen_card] += 1
            sample2[chosen_card] -= 1
    s1 += sample1
    s2 += sample2
    #print(sample1, sample1.sum(), Initstates[(Turn-1)%3].sum())
    #print(sample2, sample2.sum(), Initstates[(Turn+1)%3].sum())
    #pdiff = cprob[0]-cprob[1]

[ 0.   0.5  0.   0.5 -0.5  0.   0.5  0.5  0.5  0.   0.5  0. ]
Hint:        3      4      5      6      7      8      9      X      J      Q      K      A      2      B      R 
Farmer-1:   1.0    0.0    3.9    0.5    2.5    1.5    1.5    1.0    0.5    1.5    0.0    0.6    1.5    0.4    0.0      16.4
Farmer-0:   1.0    1.0    0.1    1.5    1.5    1.5    2.5    2.0    1.5    1.5    1.0    0.3    1.5    0.6    0.0      17.5
[2 1 4 2 4 3 4 3 2 3 1 1 3 1 0]


In [14]:
print(np.round(s1/1000,2))
print(np.round(s2/1000,2))

[1.03 0.09 3.79 0.61 2.51 1.55 1.52 1.06 0.57 1.54 0.1  0.66 1.55 0.43
 0.  ]
[0.97 0.91 0.21 1.39 1.49 1.45 2.48 1.94 1.43 1.46 0.9  0.34 1.45 0.57
 0.  ]
