In [56]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from random import random
import matplotlib.pyplot as plt
from utils import *
from tqdm.notebook import tqdm
import multiprocessing
from uuid import uuid4
from game_mechanics import (
    ChooseMoveCheckpoint,
    PokerEnv,
    checkpoint_model,
    choose_move_randomly,
    human_player,
    load_network,
    play_poker,
    save_network,
    State,
    to_basic_nn_input,
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
class PokerPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(55, 20),
            nn.LeakyReLU(),
            nn.Linear(20,5),
        )
        
    def forward(self, data, legal_moves):
        ## legal moves is list of lists of allowed moves
        illegal = lambda legal: [move not in legal for move in range(5)]
        mask = torch.stack([torch.as_tensor(illegal(lm)) for lm in legal_moves])
        preds = self.model(data)
        preds = preds.masked_fill(mask, -float('inf'))
        return F.softmax(preds,dim=-1)

In [68]:
lr_policy = 0.005
lr_value = 0.01
gamma = 1.0

policy = PokerPolicy()
value = nn.Sequential(
    nn.Linear(55, 20),
    nn.LeakyReLU(),
    nn.Linear(20,1)
)

optimizer_policy = torch.optim.Adam(policy.parameters(), lr=lr_policy)
optimizer_value  = torch.optim.Adam(value.parameters(), lr=lr_value)

# env = CartPoleEnv()
erm = ParallelEpisodeReplayMemory(gamma=0.99, lamda=1.0)
num_episodes = 1_000_000

In [69]:
def play_poker(env):
    erm = ParallelEpisodeReplayMemory(gamma=0.99, lamda=1.0)
    episode_id = uuid4()
    state, reward, done, _ = env.reset()
    state = to_basic_nn_input(state)
    while not done:
        prev_state = state
        legal_moves = [m._value_ for m in env.game.get_legal_actions()]
        with torch.no_grad():
            probs = policy(state, [legal_moves])
            chosen_move = np.random.choice(range(0,5), p=probs[0].detach().numpy())
        state, reward, done, _ = env.step(chosen_move)
        state = to_basic_nn_input(state)
        erm.append({'state': [state], 
                    'prev_state': [prev_state], 
                    'reward': reward,
                    'done': bool(done),
                    'chosen_move': chosen_move,
                    'prob_move': probs[0][chosen_move],
                    'episode_id': episode_id})
    return erm.data

multiprocessing.cpu_count()
threads = 16
pool = multiprocessing.Pool(processes=threads)

remaining_episodes = num_episodes
all_games = []
batch_size = 2000
while remaining_episodes > 0:
    environments = [PokerEnv(opponent_choose_move = choose_move_randomly) for _ in range(threads)]

    games = pool.map(play_poker, environments)
    all_games.append(games)
    if sum([len(game) for game in all_games]) >= batch_size:
        remaining_episodes -= batch_size
        sample = pd.concat(all_games)
        all_games = []
        
        
        # fit value function
        optimizer_value.zero_grad()
        baseline_estimates = value(sample['prev_state'])[:,0]
        loss_value = F.smooth_l1_loss(baseline_estimates, sample['discounted_rewards'][0])
        loss_value.backward()
        optimizer_value.step()
        losses_value.append(loss_value.item())

        # fit policy function
        optimizer_policy.zero_grad()
        probs = policy(sample['prev_state'])
        moves = sample['chosen_move']
        normalized_rewards = sample['discounted_rewards'][0] - baseline_estimates.detach()
        loss_policy = -(torch.log(probs[range(episode_len),moves.long()])*normalized_rewards).sum()/episode_len
        loss_policy.backward()
        losses_policy.append(loss_policy.item())
        optimizer_policy.step()

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finishing episode d76f33d6-c06f-48ed-8f6d-eef5a0c6d0f3, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 2c494600-2699-42ce-869c-b3231ce5dd09, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 3913b70b-020e-4835-810f-8b232548f0d0, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 3 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 895060a4-7b1d-4d9c-8926-fca5f3e918be, appending to data

  self.data  = pd.concat([self.data, episode])


finishing episode e1e37111-be0f-4b1b-9119-75647bcbcbfd, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])



finishing episode eeafbd79-194d-4a57-ac38-34e3e1f27ddb, appending to datafinish outside, erm has 3 lenfinishing episode 4c6fe031-006b-4210-b6d3-a48e2fc69ba1, appending to data
finishing episode 399fb6aa-8f1d-41a6-8d6e-fbe1c65af9e3, appending to datafinishing episode 90d5da0d-5ca3-4496-be6f-44e809995d0b, appending to datafinishing episode 75591264-7db3-452d-afb0-752f8a806be9, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 3 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])







  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])




finish outside, erm has 3 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])






  self.data  = pd.concat([self.data, episode])



finish outside, erm has 3 len


  self.data  = pd.concat([self.data, episode])
  self.data  = pd.concat([self.data, episode])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 3 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 3 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])




finish outside, erm has 3 len
finishing episode 5e52e442-eb48-4b95-904f-70b908315cd1, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 3 len

  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 3 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])



finishing episode 8b55af43-f846-4527-8fe9-5a89ccf6f23f, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 15dcc591-1873-4854-b54f-67c4bada708c, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])







  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 6 lenfinish outside, erm has 5 len



  self.data  = pd.concat([self.data, episode])


finish outside, erm has 5 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 5e5e36cd-cc12-451c-b5f6-35f20e7acdb0, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 6391e306-7652-415b-b513-f9603a73bc69, appending to datafinish outside, erm has 10 len



  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 11 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 6e64aacd-c296-4d5d-9ae0-2625ecc8f591, appending to data


  self.data  = pd.concat([self.data, episode])


finish outside, erm has 12 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finishing episode 163636aa-ee69-4cec-848d-27293c28e594, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren




  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode db1bbb29-4082-40a1-be15-d392960761e5, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finish outside, erm has 5 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode afaf006d-3a66-4ed7-9fb8-1c9400ed0971, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode bde83cfa-5229-46e2-a66f-74ff547bc978, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 967ebcd3-28e6-424a-b929-ef50b01d42b2, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])



finish outside, erm has 4 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])






  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finishing episode 0c9fee78-c2da-4185-8982-19117762ff23, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])



finish outside, erm has 7 lenfinish outside, erm has 7 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 8 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 4db202e8-1633-40ff-95b4-de6c6433b285, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode f0cb24c0-b2c1-4ee0-abf7-ff961dcf20a0, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 8 len

  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 8 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 9 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finishing episode fed3f329-52ee-42c7-8eee-44db61eb22a5, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finishing episode 33e8932e-08d1-416d-a5ff-dacfe92bd287, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 13 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 15 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finishing episode a2f9ff57-beec-4be9-b4a3-6477c00c0763, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 17 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode d99eacbc-1a6a-43a3-b7e7-4ad7d04b5932, appending to data
finishing episode adb3dc8a-68be-4566-af94-a621a53d0c7e, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 20 len


  self.data  = pd.concat([self.data, episode])


finish outside, erm has 20 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 2a8e0983-cff4-4194-b5b8-8da0754459f7, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 23 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode b930e1ec-35db-4724-912b-d59d687c571d, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 23 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finishing episode c3df4b53-231a-4e5d-a998-b351ce93eb5c, appending to data


  self.data  = pd.concat([self.data, episode])


finish outside, erm has 43 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 75f28388-0755-4fbc-986b-38d5a0c3c82c, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode ec1a9d27-66d0-40bb-9d86-9805dfc1ad18, appending to data
finishing episode 9be3340f-7278-4424-ad6d-fa25812e9437, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 0aeda636-b2d2-428d-8d27-53bd1329301c, appending to data



  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finishing episode 0cdedfd1-aaaa-4d46-b82a-320ff05ba636, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 96435821-72ee-4d6a-827a-6a5a17a5e6e4, appending to data

  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.data  = pd.concat([self.data, episode])


finish outside, erm has 1 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode f8cb9f85-7a74-468d-a736-db41a0f76c3a, appending to datafinish outside, erm has 1 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])







  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 1 len

  self.data  = pd.concat([self.data, episode])


finishing episode d22bfe0a-9838-4cb7-abf1-0311cda25a19, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])



finish outside, erm has 1 lenfinishing episode 2f7c822c-ec50-4ef0-8126-3c2b1a7dc1b9, appending to datafinish outside, erm has 2 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])




finish outside, erm has 1 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])






  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 2 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finishing episode 06de26ef-fa6d-439b-9e4e-6e80b0b186dc, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 3 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.data  = pd.concat([self.data, episode])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 5 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 4 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])






  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finishing episode cc02b106-cc1b-4b86-aeb9-165b16b3528e, appending to data

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 761fe599-76fc-4e50-952d-0fd8b1a10b96, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 7 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 8aa97d14-9e0d-48d4-b288-ee55e0dbaac1, appending to data


  self.data  = pd.concat([self.data, episode])



finish outside, erm has 8 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.data  = pd.concat([self.data, episode])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finish outside, erm has 9 len

  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])





  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])


finishing episode 16759583-ac8d-430b-a352-e95dc8a0839f, appending to data


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finishing episode b48476b5-82e1-4832-be02-c58fd2a3c1c5, appending to data
finish outside, erm has 14 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.data  = pd.concat([self.data, episode])


finish outside, erm has 11 len


  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.current_episodes[episode_id], new_row])
  self.current_episodes[episode_id] = pd.concat([self.curren

finishing episode 6cc56132-537e-4f1f-9846-5682f4fd1e2f, appending to data


  self.data  = pd.concat([self.data, episode])


finish outside, erm has 28 len


Exception in thread Thread-52:
Traceback (most recent call last):
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/pool.py", line 576, in _handle_results
    task = get()
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/connection.py", line 251, in recv
    return _ForkingPickler.loads(buf.getbuffer())
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
    fd = df.detach()
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/reduction.py", l

KeyboardInterrupt: 

Process ForkPoolWorker-216:
Process ForkPoolWorker-217:
Process ForkPoolWorker-211:
Process ForkPoolWorker-220:
Process ForkPoolWorker-207:
Process ForkPoolWorker-205:
Process ForkPoolWorker-215:
Process ForkPoolWorker-208:
Process ForkPoolWorker-212:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/h

  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/hristo/miniconda3/envs/cv38/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/hristo/miniconda3/envs/cv38/l

In [62]:
len(erm.data)

0

In [63]:
erm

[]

In [54]:
erm.data

[]

In [12]:
state

State(hand=['DQ', 'C9'], public_cards=[], player_chips=2, opponent_chips=97, player_chips_remaining=101, opponent_chips_remaining=0, stage=0, legal_actions=[0, 1, 3, 4])

In [29]:
legal_moves = [m._value_ for m in env.game.get_legal_actions()]

In [30]:
legal_moves

[0, 1]

In [28]:
legal_moves[0]._value_

0

In [13]:
num_episodes = 5
for ep_num in tqdm(range(num_episodes)):
    num_steps = 0
    state, reward, done, _ = env.reset()
    state = to_basic_nn_input(state)
    while not done:
        prev_state = state
        legal_moves = [m._value_ for m in env.game.get_legal_actions()]
        with torch.no_grad():
            probs = policy(state)
            chosen_move = 4 #np.random.choice(range(0,5), p=probs.detach().numpy())
        state, reward, done, _ = env.step(chosen_move)
        state = to_basic_nn_input(state)
        erm.append({'state': [state], 
                    'prev_state': [prev_state], 
                    'reward': reward,
                    'done': done,
                    'chosen_move': chosen_move,
                    'prob_left': probs[chosen_move]})

  0%|          | 0/5 [00:00<?, ?it/s]

[<Action.FOLD: 0>, <Action.CHECK_CALL: 1>, <Action.RAISE_POT: 3>, <Action.ALL_IN: 4>]
[<Action.FOLD: 0>, <Action.CHECK_CALL: 1>, <Action.RAISE_POT: 3>, <Action.ALL_IN: 4>]
[<Action.FOLD: 0>, <Action.CHECK_CALL: 1>]


  self.current_episode = pd.concat([self.current_episode, new_row])
  self.current_episode = pd.concat([self.current_episode, new_row])
  self.data  = pd.concat([self.data, self.current_episode])


AssertionError: 4 is an illegal move

In [34]:
num_episodes

5

In [37]:
to_basic_nn_input(state)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000, -0.9800, -0.9600, -0.0200])

In [35]:
state

State(hand=['S7', 'D5'], public_cards=[], player_chips=2, opponent_chips=4, player_chips_remaining=98, opponent_chips_remaining=96, stage=0, legal_actions=[0, 1, 2, 3, 4])

In [16]:
env.game.configure?

In [12]:
play_poker(human_player, choose_move_randomly)

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Windows\\Fonts\\arial-narrow.ttf'

In [10]:
human_player?