In [1]:
from collections import deque
import random
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
from game import Game
from tqdm import tqdm

In [2]:
import wandb

# Start a Weights & Biases run for this experiment.
# The config values are logged as run metadata only; nothing in this cell
# feeds them into the agent.
wandb.init(
    project="dqn_qmany",  # project name
    name="0103",  # run name (optional)
    config=dict(  # hyperparameters recorded with the run (optional)
        gamma=0.98,
        learning_rate=0.00005,
        batch_size=32,
        epsilon=0.10,
    ),
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mobashota6250[0m ([33mobashota6250-tokyo-university-of-science[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
import os
from datetime import datetime
# Build a fixed-width run id (YYYYMMDDHHMMSS) from the current local time.
# Zero-padding is required: concatenating the raw integer fields gives
# variable-length names in which different dates/times can collide
# (e.g. "1" + "11" and "11" + "1" both produce "111") and which do not sort
# chronologically.
now = datetime.now().strftime("%Y%m%d%H%M%S")

# Each run writes its checkpoints into its own directory.
save_dir = "./models/dqn_qmany/" + now + "/"
os.makedirs(save_dir, exist_ok=True)

In [4]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random mini-batch sampling.

    Once ``buffer_size`` transitions are held, appending a new one discards
    the oldest (behaviour of ``deque(maxlen=...)``).
    """

    def __init__(self,buffer_size,batch_size):
        self.batch_size = batch_size
        self.buffer = deque(maxlen=buffer_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def add(self,state,action,reward,next_state,done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        self.buffer.append((state,action,reward,next_state,done))

    def get_batch(self):
        """Sample ``batch_size`` transitions without replacement.

        Returns:
            Tuple ``(state, action, reward, next_state, done)`` of tensors,
            each stacked along a new leading batch axis. ``action`` is int64;
            all other fields are float32.
        """
        sampled = random.sample(self.buffer, self.batch_size)
        # Transpose the list of 5-tuples into five per-field columns.
        columns = [np.stack(field) for field in zip(*sampled)]
        field_dtypes = (
            torch.float32,  # state
            torch.int64,    # action (used as an index)
            torch.float32,  # reward
            torch.float32,  # next_state
            torch.float32,  # done
        )
        return tuple(
            torch.tensor(column, dtype=dtype)
            for column, dtype in zip(columns, field_dtypes)
        )

In [5]:
class QNet(nn.Module):
    """Fully connected network producing one Q-value per action.

    Layout: state_size -> 128 -> 128 -> action_size, with ReLU after the first
    two linear layers and no activation on the output.
    """

    def __init__(self, action_size,state_size):
        super().__init__()
        hidden = 128
        # Attribute names l1/l2/l3 are the state_dict keys used by checkpoints.
        self.l1 = nn.Linear(state_size, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, action_size)

    def forward(self,x):
        """Return the Q-values for input ``x`` (last dimension = state_size)."""
        activation = x
        for layer in (self.l1, self.l2):
            activation = F.relu(layer(activation))
        return self.l3(activation)

In [6]:
class DQNAgent:
    """DQN agent holding a separate Q-network, target network, optimizer,
    replay buffer and loss history for every phase in ``self.phases``.

    Args:
        gamma: Discount factor used in the TD target.
        lr: Adam learning rate for every per-phase optimizer.
        epsilon: Probability of choosing a uniformly random legal action.
        buffer_size: Capacity of each per-phase replay buffer.
        batch_size: Mini-batch size; also the minimum buffer fill before
            learning starts.
        action_size: Number of discrete actions (network output width).

    All arguments default to the values that were previously hard-coded, so
    ``DQNAgent()`` behaves as before.
    """

    def __init__(self, gamma=0.98, lr=0.00005, epsilon=0.1,
                 buffer_size=10000, batch_size=32, action_size=3):
        self.gamma = gamma
        self.lr = lr
        self.epsilon = epsilon
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.action_size = action_size

        # Phases are addressed by index (position in this list) by callers.
        self.phases = ["preflop", "flop", "turn", "river","show down"]
        # Input width of the network for each phase.
        self.state_sizes = {
            "preflop": 12,
            "flop": 12,
            "turn": 12,
            "river": 12,
            "show down": 12
        }

        self.replay_buffers = {}
        self.qnets = {}
        self.qnet_targets = {}
        self.optimizers = {}
        self.loss_lists = {}

        for phase in self.phases:
            self.replay_buffers[phase] = ReplayBuffer(self.buffer_size, self.batch_size)
            self.qnets[phase] = QNet(self.action_size, self.state_sizes[phase])
            self.qnet_targets[phase] = QNet(self.action_size, self.state_sizes[phase])
            self.optimizers[phase] = optim.Adam(self.qnets[phase].parameters(), lr=self.lr)
            self.loss_lists[phase] = []

        # Online and target networks are initialised independently above;
        # align them so the first TD targets are not computed from unrelated
        # random weights.
        self.sync_qnet()

    def sync_qnet(self):
        """Copy every online network's weights into its target network."""
        for phase in self.phases:
            self.qnet_targets[phase].load_state_dict(self.qnets[phase].state_dict())

    def get_action(self, state, mask, phase_index):
        """Choose an action index with an epsilon-greedy policy over legal actions.

        Args:
            state: 1-D array for the current observation (a batch axis is
                added here via ``state[np.newaxis, :]``).
            mask: Per-action legality flags; non-zero means the action is legal.
            phase_index: Index into ``self.phases`` selecting the network.

        Returns:
            int: index of the chosen action.
        """
        mask = torch.tensor(mask, dtype=torch.float32)

        if np.random.rand() < self.epsilon:
            valid_actions = torch.nonzero(mask).squeeze(-1).numpy()
            return int(np.random.choice(valid_actions))

        state = torch.tensor(state[np.newaxis, :], dtype=torch.float32)
        with torch.no_grad():  # action selection needs no autograd graph
            qs = self.qnets[self.phases[phase_index]](state)
        # Illegal actions must never win the argmax. Multiplying by the mask
        # would turn them into 0, which beats every legal action whenever all
        # legal Q-values are negative; -inf cannot win.
        masked_qs = qs.masked_fill(mask == 0, float("-inf"))
        return masked_qs.argmax().item()

    def update(self, state, action, reward, next_state, done, current_phase_index, next_phase_index):
        """Store one transition and run a single DQN gradient step.

        The transition goes into the replay buffer of the current phase. When
        that buffer holds fewer than ``batch_size`` items, nothing else happens.
        Otherwise a mini-batch is sampled and the current phase's network is
        regressed (MSE) onto ``reward + (1 - done) * gamma * max_a Q_target``.

        Args:
            state, action, reward, next_state, done: The transition.
                ``next_state`` may be ``None`` (replaced by zeros).
            current_phase_index: Index into ``self.phases`` for ``state``.
            next_phase_index: Index into ``self.phases`` for ``next_state``.
        """
        current_phase = self.phases[current_phase_index]
        next_phase = self.phases[next_phase_index]

        replay_buffer = self.replay_buffers[current_phase]
        qnet = self.qnets[current_phase]
        # NOTE(review): the target network is picked from THIS call's
        # next_phase, but the sampled batch may contain older transitions whose
        # next state belonged to a different phase. It runs because every
        # state size is 12 - confirm this is the intended bootstrapping.
        qnet_target = self.qnet_targets[next_phase]
        optimizer = self.optimizers[current_phase]
        loss_list = self.loss_lists[current_phase]

        next_state = self.preprocess_next_state(next_state, self.state_sizes[next_phase])

        state = torch.tensor(state, dtype=torch.float32)
        action = torch.tensor(action, dtype=torch.int64)
        reward = torch.tensor(reward, dtype=torch.float32)
        next_state = torch.tensor(next_state, dtype=torch.float32)
        done = torch.tensor(done, dtype=torch.float32)

        replay_buffer.add(state, action, reward, next_state, done)
        if len(replay_buffer) < self.batch_size:
            return

        state, action, reward, next_state, done = replay_buffer.get_batch()

        qs = qnet(state)
        # Q-value of the action actually taken in each sampled transition.
        q = qs[np.arange(len(action)), action]

        next_qs = qnet_target(next_state)
        next_q = next_qs.max(1)[0].detach()
        # (1 - done) removes the bootstrap term for terminal transitions.
        target = reward + (1 - done) * self.gamma * next_q

        loss_fn = nn.MSELoss()
        loss = loss_fn(q, target)
        wandb.log({current_phase: loss.item()})
        # Keep a detached copy: storing the live loss tensor would keep every
        # step's autograd graph reachable for the whole run.
        loss_list.append(loss.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    @staticmethod
    def preprocess_next_state(next_state, target_shape):
        """Return ``next_state``, or zeros of ``target_shape`` when it is ``None``."""
        if next_state is None:
            return np.zeros(target_shape)
        return next_state

    def copy_from(self, other_agent):
        """Copy online-network weights and optimizer state from ``other_agent``.

        Target networks are not touched.
        """
        for phase in self.phases:
            self.qnets[phase].load_state_dict(other_agent.qnets[phase].state_dict())
            self.optimizers[phase].load_state_dict(other_agent.optimizers[phase].state_dict())


In [7]:
def agent_action_from_a(game,agent_a,agent_b,state_a):
    """Play one exchange: agent A acts, then agent B, and agent A learns from it.

    Agent A's transition is always passed to ``agent_a.update``. If either
    action is a fold, or the round's phase is 4 after either step, the
    transition is recorded as terminal (no next state).

    Returns:
        end_flag: whether the round has ended.
        next_state: agent A's next state; ``None`` when the round has ended.
        reward_a: the reward returned by ``game.step`` for agent A's action.
    """
    action_names = {0: "f", 1: "c", 2: "r"}
    final_phase = 4

    def round_is_over(action_name):
        # A fold ends the round at once; otherwise check the phase after the step.
        return action_name == "f" or game.one_round.current_phase == final_phase

    # --- agent A acts ---
    a_index = agent_a.get_action(
        state_a,
        game.one_round.mask(game.one_round.current_index),
        game.one_round.current_phase,
    )
    a_name = action_names[a_index]
    reward_a, state_b, current_phase, _ = game.step(a_name)

    # With no next state the update uses the reward only, so the next-phase
    # argument is a placeholder (the current phase is passed twice).
    terminal_transition = (state_a, a_index, reward_a, None, True, current_phase, current_phase)

    if round_is_over(a_name):
        agent_a.update(*terminal_transition)
        return True, None, reward_a

    # --- agent B responds ---
    b_index = agent_b.get_action(
        state_b,
        game.one_round.mask(game.one_round.current_index),
        game.one_round.current_phase,
    )
    b_name = action_names[b_index]
    _, next_state_a, _, next_phase = game.step(b_name)

    if round_is_over(b_name):
        agent_a.update(*terminal_transition)
        return True, None, reward_a

    # Round continues: A's transition runs from its state to the state after B's reply.
    agent_a.update(state_a, a_index, reward_a, next_state_a, False, current_phase, next_phase)
    return False, next_state_a, reward_a

In [8]:
def agent_action_b(game,state_b,agent_b):
    """Let agent B take a single action (used when B acts first in a round).

    Agent B is not updated here; only the game is advanced.

    Args:
        game: Game object exposing ``one_round`` and ``step``.
        state_b: Agent B's current state.
        agent_b: Agent whose ``get_action`` picks the action index.

    Returns:
        (end_flag, state_a): ``end_flag`` is True when B folded or the round's
        phase is 4 after the step; ``state_a`` is the second element of the
        ``game.step`` result, or ``None`` when the round has ended.
    """
    mask_b = game.one_round.mask(game.one_round.current_index)
    action_b_index = agent_b.get_action(state_b,mask_b,game.one_round.current_phase)

    index_to_action = {0:"f",1:"c",2:"r"}
    action_b_alpa = index_to_action[action_b_index]

    # Step the game exactly once. Calling game.step twice to pick out two
    # elements of its result would apply B's action to the game two times.
    step_result = game.step(action_b_alpa)
    state_a = step_result[1]

    if action_b_alpa == "f" or game.one_round.current_phase == 4:
        return True, None
    return False, state_a

In [9]:
# Self-play training: agent A learns; agent B is a periodically refreshed copy of A.
episodes = 10000
sync_interval = 50      # episodes between target-network syncs / opponent refreshes
save_interval = 10000   # episodes between checkpoints
agent_a = DQNAgent()
agent_b = DQNAgent()

for episode in tqdm(range(episodes)):
    # NOTE(review): meaning of the positional Game arguments is defined in
    # game.py, which is not shown here.
    game = Game(2,100000,100,5)
    game.game_flag = False

    while not game.game_flag:
        end_flag = False
        # NOTE(review): state_a is read from player index 0, yet the branch
        # below lets agent B act first when current_index == 0 - confirm which
        # seat each agent occupies.
        state_a = game.one_round.player_state(0)
        state_b = game.one_round.player_state(1)
        # Round opens with agent B to act: B moves once before A's loop.
        if game.one_round.current_index == 0:
            end_flag, state_a = agent_action_b(game,state_b,agent_b)

        # Alternate A then B until the round ends.
        while not end_flag:
            end_flag, state_a,reward_a= agent_action_from_a(game,agent_a,agent_b,state_a)

    if episode % sync_interval == 0:
        agent_a.sync_qnet()
        agent_b.copy_from(agent_a)

    # Checkpoint after every `save_interval` completed episodes. Testing
    # `episode % 10000 == 0` over range(10000) is true only for episode 0,
    # which saved weights after a single episode and never again.
    if (episode + 1) % save_interval == 0:
        for phase in ("preflop", "flop", "turn", "river"):
            torch.save(agent_a.qnets[phase].state_dict(), os.path.join(save_dir, phase + ".pth"))

100%|██████████| 10000/10000 [29:02<00:00,  5.74it/s]


In [None]:
# Continue self-play training with the existing agent_a / agent_b for another
# `episodes` episodes (episodes, sync_interval and the agents come from the
# cell that first created them).
save_interval = 10000   # episodes between checkpoints

for episode in tqdm(range(episodes)):
    # NOTE(review): meaning of the positional Game arguments is defined in
    # game.py, which is not shown here.
    game = Game(2,100000,100,5)
    game.game_flag = False

    while not game.game_flag:
        end_flag = False
        state_a = game.one_round.player_state(0)
        state_b = game.one_round.player_state(1)
        # Round opens with agent B to act: B moves once before A's loop.
        if game.one_round.current_index == 0:
            end_flag, state_a = agent_action_b(game,state_b,agent_b)

        # Alternate A then B until the round ends.
        while not end_flag:
            end_flag, state_a,reward_a= agent_action_from_a(game,agent_a,agent_b,state_a)

    if episode % sync_interval == 0:
        agent_a.sync_qnet()
        agent_b.copy_from(agent_a)

    # Checkpoint after every `save_interval` completed episodes. Testing
    # `episode % 10000 == 0` over range(10000) is true only for episode 0,
    # which would overwrite the checkpoint after a single episode.
    if (episode + 1) % save_interval == 0:
        for phase in ("preflop", "flop", "turn", "river"):
            torch.save(agent_a.qnets[phase].state_dict(), os.path.join(save_dir, phase + ".pth"))

In [10]:
# import matplotlib.pyplot as plt

# # 必要に応じて detach して numpy 配列に変換
# preflop_losses = [loss.detach().numpy() for loss in agent_a.loss_lists["preflop"]]

# # グラフを描画
# plt.figure(figsize=(14, 6))
# plt.plot(preflop_losses, marker='o', linestyle='-', color='b', label='Preflop Losses')

# # グラフのタイトルやラベルを設定
# plt.title('Preflop Loss Over Time', fontsize=16)
# plt.xlabel('Update Steps', fontsize=12)
# plt.ylabel('Loss Value', fontsize=12)

# # グリッドと凡例を追加
# plt.grid(True, linestyle='--', alpha=0.6)
# plt.legend(fontsize=12)

# # グラフを表示
# plt.show()


In [16]:
# Write agent A's current online-network weights, one file per phase, into
# this run's save_dir.
saved_phases = ("preflop", "flop", "turn", "river")
preflop_save_path, flop_save_path, turn_save_path, river_save_path = [
    os.path.join(save_dir, phase_name + ".pth") for phase_name in saved_phases
]

for phase_name, checkpoint_path in zip(
    saved_phases,
    (preflop_save_path, flop_save_path, turn_save_path, river_save_path),
):
    torch.save(agent_a.qnets[phase_name].state_dict(), checkpoint_path)

In [21]:
### Restore previously saved weights

# Fresh agents meant to receive the saved weights (randomly initialised until
# a loader fills them in).
save_agent_a = DQNAgent()
save_agent_b = DQNAgent()

# Phases for which this notebook writes a "<phase>.pth" checkpoint.
phases = ["preflop", "flop", "turn", "river"]

# NOTE(review): the loader below is disabled. It reads from
# 'models/dqn_many/...', whereas this notebook writes checkpoints under
# './models/dqn_qmany/...'; the FileNotFoundError recorded in this cell's
# output names the 'dqn_many' path. Fix the directory name before re-enabling.
# for phase in phases:
#     /home/mil/oba/poker_nn/poker_env/models/dqn_qmany/20251323160/flop.pth
#     save_agent_a.qnets[phase].load_state_dict(torch.load('models/dqn_many/20251323160/'+ phase +'.pth'))

#     save_agent_b.qnets[phase].load_state_dict(torch.load('models/dqn_many/20251323160/'+ phase +'.pth'))

  agent_a.qnets[phase].load_state_dict(torch.load('models/dqn_many/20251323160/'+ phase +'.pth'))


FileNotFoundError: [Errno 2] No such file or directory: 'models/dqn_many/20251323160/preflop.pth'

In [None]:
# Short self-play run (100 episodes) on a two-player game.
# agent_action_from_a still calls agent_a.update, so agent A keeps learning here.
for episode in range(100):
    # NOTE(review): meaning of the positional Game arguments (including the
    # trailing True) is defined in game.py, which is not shown here.
    game = Game(2,100000,100,3,True)
    game.game_flag = False

    while not game.game_flag:
        end_flag = False
        state_a = game.one_round.player_state(0)
        state_b = game.one_round.player_state(1)
        # Round opens with agent B to act: B moves once before A's loop.
        # Arguments must follow the signature agent_action_b(game, state_b, agent_b);
        # passing (state_b, agent_b, game) fails on `state_b.one_round`.
        if game.one_round.current_index == 0:
            end_flag, state_a = agent_action_b(game,state_b,agent_b)
        # Alternate A then B until the round ends.
        while not end_flag:
            end_flag, state_a,reward_a= agent_action_from_a(game,agent_a,agent_b,state_a)

    # Refresh the opponent with agent A's latest weights every 50 episodes.
    if episode % 50 == 0:
        agent_b.copy_from(agent_a)