<a href="https://colab.research.google.com/github/alikc218/GomokuMuzero/blob/main/BlackjackMuZero.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install numpy==1.21.0 torch tqdm

Collecting numpy==1.21.0
  Using cached numpy-1.21.0.zip (10.3 MB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
C

In [28]:
import numpy as np
from enum import Enum, auto

class BlackjackAction(Enum):
    """Discrete player actions for the blackjack environment.

    The integer values 0-3 line up with how the model code in this notebook
    indexes actions (a 4-wide policy head and a 4-wide one-hot encoding).
    """
    HIT = 0     # BlackjackEnv.step draws one more card for the player
    STAND = 1   # BlackjackEnv.step lets the dealer play out and scores the hand
    DOUBLE = 2  # declared, but BlackjackEnv.step contains no game logic for it
    SPLIT = 3   # declared, but BlackjackEnv.step contains no game logic for it

class BlackjackEnv:
    """Single-hand blackjack environment with a gym-style ``reset``/``step`` API.

    Cards are stored by point value: 2-10 at face value, face cards as 10 and
    aces as 11 (an ace is re-counted as 1 whenever the hand would otherwise
    exceed 21).  A fresh shoe of ``decks`` decks is shuffled on every
    ``reset``.

    Only HIT and STAND have game logic.  DOUBLE and SPLIT are part of the
    action enum but are rejected by ``step`` and omitted from
    ``legal_actions`` so that callers can mask them out instead of looping
    forever on actions that never advance the game.

    Args:
        decks: Number of 52-card decks in the shoe.
    """

    def __init__(self, decks=6):
        self.decks = decks
        self.reset()

    def reset(self):
        """Shuffle a new shoe, deal two cards each, and return the observation."""
        self.deck = self._create_deck()
        np.random.shuffle(self.deck)
        self.deck_pos = 0
        self.player_hand = [self._draw_card(), self._draw_card()]
        self.dealer_hand = [self._draw_card(), self._draw_card()]
        self.done = False
        return self._get_obs()

    def _create_deck(self):
        """Return the unshuffled shoe: 13 ranks x 4 suits x ``decks``."""
        return np.array([2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11] * 4 * self.decks)

    def _draw_card(self):
        """Take the next card from the shoe and return it as a plain ``int``."""
        card = self.deck[self.deck_pos]
        self.deck_pos += 1
        return int(card)

    def _get_obs(self):
        """Build the observation dict from the player's point of view."""
        return {
            'player_sum': self._hand_value(self.player_hand),
            'dealer_card': self.dealer_hand[0],
            'usable_ace': self._has_usable_ace(self.player_hand),
            'can_split': len(self.player_hand) == 2 and self.player_hand[0] == self.player_hand[1]
        }

    @staticmethod
    def _score(hand):
        """Return ``(total, aces_still_counted_as_11)`` for ``hand``."""
        total = sum(hand)
        aces = hand.count(11)
        # Demote aces from 11 to 1, one at a time, while the hand is bust.
        while total > 21 and aces:
            total -= 10
            aces -= 1
        return total, aces

    def _hand_value(self, hand):
        """Best blackjack total for ``hand`` (aces count 11 or 1)."""
        return self._score(hand)[0]

    def _has_usable_ace(self, hand):
        """True only if some ace is still counted as 11 in the best total.

        Merely holding an ace is not enough: in ``[11, 10, 5]`` the ace has
        already been demoted to 1, so the hand is hard, not soft.
        """
        return self._score(hand)[1] > 0

    @staticmethod
    def _coerce_action(action):
        """Accept a ``BlackjackAction`` or its integer value (incl. NumPy ints)."""
        if isinstance(action, BlackjackAction):
            return action
        try:
            return BlackjackAction(int(action))
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Unknown action: {action!r}") from exc

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``.

        Args:
            action: A ``BlackjackAction`` or its integer value.

        Raises:
            ValueError: If the game is already over, the action is not a
                valid ``BlackjackAction``, or it is one this environment
                does not implement (DOUBLE, SPLIT).
        """
        if self.done:
            raise ValueError("Game ended")

        # Agents sample integer indices; without this conversion no branch
        # below would ever match and the episode could never terminate.
        action = self._coerce_action(action)

        reward = 0
        if action == BlackjackAction.HIT:
            self.player_hand.append(self._draw_card())
            if self._hand_value(self.player_hand) > 21:
                self.done = True
                reward = -1
        elif action == BlackjackAction.STAND:
            self._dealer_play()
            reward = self._get_result()
            self.done = True
        else:
            raise ValueError(f"Unsupported action: {action.name}")

        return self._get_obs(), reward, self.done, {}

    def _dealer_play(self):
        """Dealer draws until reaching 17 or more."""
        while self._hand_value(self.dealer_hand) < 17:
            self.dealer_hand.append(self._draw_card())

    def _get_result(self):
        """Score the finished hand: +1 win, -1 loss, 0 push."""
        player = self._hand_value(self.player_hand)
        dealer = self._hand_value(self.dealer_hand)

        if player > 21:
            return -1
        if dealer > 21:
            return 1
        if player > dealer:
            return 1
        if player < dealer:
            return -1
        return 0

    def legal_actions(self):
        """Actions ``step`` will accept right now (empty once the game is over)."""
        if self.done:
            return []
        return [BlackjackAction.HIT, BlackjackAction.STAND]

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlackjackMuZero(nn.Module):
    """Small MuZero-style network: representation, dynamics and prediction heads.

    Observations are encoded as 4 features and actions as 4-wide one-hot
    vectors; every sub-network is a two-layer MLP of width ``hidden_size``.
    """

    def __init__(self, hidden_size=64):
        super().__init__()
        self.hidden_size = hidden_size

        def two_layer(in_features, out_features, head):
            # Linear -> ReLU -> Linear -> `head`, the shape shared by all four nets.
            return nn.Sequential(
                nn.Linear(in_features, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, out_features),
                head,
            )

        # Observation features -> hidden state.
        self.rep_net = two_layer(4, hidden_size, nn.ReLU())
        # (hidden state ++ one-hot action) -> next hidden state.
        self.dyn_net = two_layer(hidden_size + 4, hidden_size, nn.ReLU())
        # Hidden state -> probability over the 4 actions.
        self.policy_net = two_layer(hidden_size, 4, nn.Softmax(dim=-1))
        # Hidden state -> scalar value squashed into [-1, 1].
        self.value_net = two_layer(hidden_size, 1, nn.Tanh())

    def _prepare_observation(self, obs):
        """Turn an observation dict into a (1, 4) float tensor.

        Anything that is not a dict is returned unchanged.
        """
        if not isinstance(obs, dict):
            return obs
        features = [
            obs['player_sum'] / 21.0,
            obs['dealer_card'] / 11.0,
            float(obs['usable_ace']),
            float(obs['can_split']),
        ]
        # Leading axis of size 1 acts as the batch dimension.
        return torch.FloatTensor(features).unsqueeze(0)

    def _prepare_action(self, action):
        """One-hot encode an action index as a (1, 4) tensor."""
        encoded = torch.zeros(1, 4)
        encoded[0, action] = 1.0
        return encoded

    def representation(self, obs):
        """Encode an observation into the initial hidden state."""
        return self.rep_net(self._prepare_observation(obs))

    def dynamics(self, state, action):
        """Predict the hidden state that follows ``state`` after ``action``."""
        joined = torch.cat([state, self._prepare_action(action)], dim=1)
        return self.dyn_net(joined)

    def prediction(self, state):
        """Return ``(policy, value)`` for a hidden state."""
        policy = self.policy_net(state)
        value = self.value_net(state)
        return policy, value

    def initial_inference(self, obs):
        """Encode ``obs`` and evaluate it; returns a dict with state/policy/value."""
        hidden = self.representation(obs)
        policy, value = self.prediction(hidden)
        return {'state': hidden, 'policy': policy, 'value': value}

    def recurrent_inference(self, state, action):
        """Advance ``state`` by ``action`` and evaluate the resulting state."""
        hidden = self.dynamics(state, action)
        policy, value = self.prediction(hidden)
        return {'state': hidden, 'policy': policy, 'value': value}

In [39]:
import math
import numpy as np
import torch

class Node:
    """One node of the search tree, carrying visit statistics and model state."""

    def __init__(self, prior):
        self.prior = prior
        self.state = None
        self.reward = 0
        self.children = {}  # maps action -> child Node
        self.visit_count = 0
        self.value_sum = 0

    def expanded(self):
        """Return True once the node has at least one child."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value, or 0 for a node that was never visited."""
        return self.value_sum / self.visit_count if self.visit_count else 0


In [40]:
class MCTS:
    """Monte Carlo tree search that plans entirely inside the learned model.

    Args:
        model: Object exposing ``representation(obs)``, ``dynamics(state,
            action)`` and ``prediction(state) -> (policy, value)``.
        num_simulations: Number of tree descents per ``run`` call.
        c_puct: Weight of the prior-driven exploration term.
        discount: Factor applied to the value at each step up the tree.
    """

    def __init__(self, model, num_simulations=50, c_puct=1.0, discount=0.99):
        self.model = model
        self.num_simulations = num_simulations
        self.c_puct = c_puct
        self.discount = discount

    def run(self, observation):
        """Search from ``observation`` and return the root visit distribution.

        Returns:
            A float array with one probability per action, proportional to
            how often each root child was visited (uniform when nothing was
            visited, e.g. ``num_simulations == 0``).
        """
        # Search only reads the model; no autograd graph is needed.
        with torch.no_grad():
            root = Node(0)
            root.state = self.model.representation(observation)
            root_policy, _ = self.model.prediction(root.state)
            # Expand the root up front so its state is never overwritten by
            # a dynamics step and every simulation starts with a real choice.
            self._expand(root, root_policy)

            for _ in range(self.num_simulations):
                node = root
                search_path = [node]
                action = None

                # Selection: walk down until reaching an unexpanded leaf.
                while node.expanded():
                    action, node = self.select_child(node)
                    search_path.append(node)

                if len(search_path) < 2:
                    # The model produced an empty policy; nothing to search.
                    break

                # Expansion/evaluation: the leaf's state comes from its
                # *direct* parent, not from whichever node preceded selection.
                parent = search_path[-2]
                node.state = self.model.dynamics(parent.state, action)
                node.reward = 0  # the model has no reward head
                policy, value = self.model.prediction(node.state)
                self._expand(node, policy)

                # Backup the leaf evaluation along the visited path.
                self.backpropagate(search_path, value.item(), node.reward)

        actions = sorted(root.children)
        if not actions:
            return np.ones(4) / 4  # no children at all: fall back to uniform

        visit_counts = np.array([root.children[a].visit_count for a in actions])
        total = visit_counts.sum()
        if total == 0:
            return np.ones(len(actions)) / len(actions)
        return visit_counts / total

    def _expand(self, node, policy):
        """Attach one child per action to ``node`` using ``policy`` as priors."""
        for action, prior in enumerate(policy.squeeze(0).tolist()):
            node.children[action] = Node(prior)

    def select_child(self, node):
        """Return the ``(action, child)`` pair with the highest UCB score."""
        total_visits = sum(c.visit_count for c in node.children.values())

        def ucb_score(child):
            if child.visit_count == 0:
                return float('inf')  # always try unvisited children first
            return child.value() + self.c_puct * child.prior * math.sqrt(total_visits) / (child.visit_count + 1)

        return max(node.children.items(), key=lambda item: ucb_score(item[1]))

    def backpropagate(self, path, value, reward):
        """Add ``value`` to every node on ``path``, discounting on the way up."""
        for node in reversed(path):
            node.value_sum += value
            node.visit_count += 1
            value = reward + self.discount * value

In [31]:
from collections import deque
import random
import torch.optim as optim

class Trainer:
    """Collects self-play games into a replay buffer and trains the model.

    Args:
        model: Network exposing ``parameters()`` and ``initial_inference(obs)``
            that returns a dict with ``'policy'`` and ``'value'`` tensors.
        lr: Adam learning rate.
        buffer_size: Maximum number of stored transitions.
        batch_size: Number of transitions sampled per ``train_step``.
    """

    def __init__(self, model, lr=1e-3, buffer_size=10000, batch_size=32):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.replay_buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size

    @staticmethod
    def _restrict_to_legal(env, action_probs):
        """Zero out illegal actions and renormalise.

        Returns ``(probs, legal)`` where ``legal`` maps an action index to
        the object ``env.legal_actions()`` supplied for it.  ``legal`` is
        empty when the environment does not expose ``legal_actions``, in
        which case ``probs`` is returned unmasked.
        """
        probs = np.asarray(action_probs, dtype=float)
        legal_fn = getattr(env, 'legal_actions', None)
        if legal_fn is None:
            return probs, {}

        legal = {}
        for act in legal_fn():
            index = int(getattr(act, 'value', act))  # enum member or raw int
            if 0 <= index < len(probs):
                legal[index] = act
        if not legal:
            return probs, {}

        indices = sorted(legal)
        masked = np.zeros_like(probs)
        masked[indices] = probs[indices]
        total = masked.sum()
        if total > 0:
            masked /= total
        else:
            # Search put no mass on any legal action: fall back to uniform.
            masked[indices] = 1.0 / len(indices)
        return masked, legal

    def self_play(self, env, num_games=10):
        """Play ``num_games`` episodes with MCTS and store them in the buffer.

        Each stored step holds the observation, the (legal-masked) search
        policy, the reward, the done flag and a discounted value target.
        """
        search = MCTS(self.model)
        for _ in range(num_games):
            obs = env.reset()
            done = False
            trajectory = []

            while not done:
                action_probs, legal = self._restrict_to_legal(env, search.run(obs))
                index = int(np.random.choice(len(action_probs), p=action_probs))
                # Hand the environment its own action object when it gave
                # us one; a bare index is only used as a fallback.
                action = legal.get(index, index)
                next_obs, reward, done, _ = env.step(action)

                trajectory.append({
                    'obs': obs,
                    'action_probs': action_probs,
                    'reward': reward,
                    'done': done
                })
                obs = next_obs

            # Discounted return, computed backwards from the final step.
            value_target = 0
            for t in reversed(trajectory):
                value_target = t['reward'] + 0.99 * value_target * (1 - t['done'])
                t['value_target'] = value_target

            self.replay_buffer.extend(trajectory)

    def train_step(self):
        """Run one optimisation step on a sampled batch.

        Returns:
            The scalar loss as a float, or ``None`` when the buffer does not
            yet hold ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None

        batch = random.sample(self.replay_buffer, self.batch_size)

        policy_losses = []
        value_losses = []
        for sample in batch:
            output = self.model.initial_inference(sample['obs'])

            target_policy = torch.as_tensor(
                np.asarray(sample['action_probs']), dtype=torch.float32
            )
            # Clamp before the log so a zero probability cannot yield
            # -inf (and NaN once multiplied by a zero target).
            log_policy = torch.log(output['policy'].reshape(-1).clamp_min(1e-8))
            policy_losses.append(-(target_policy * log_policy).sum())

            # Compare scalars to scalars; mismatched shapes would broadcast.
            target_value = torch.tensor(float(sample['value_target']))
            value_losses.append(F.mse_loss(output['value'].reshape(()), target_value))

        total_loss = torch.stack(policy_losses).mean() + torch.stack(value_losses).mean()

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return total_loss.item()

In [None]:
import numpy as np
from tqdm import tqdm

def evaluate(model, env, num_games=20):
    """Play ``num_games`` greedy episodes and return the fraction won.

    At every step the most-visited *legal* action from a 20-simulation MCTS
    is played.  When ``env`` exposes ``legal_actions()`` the environment's
    own action objects are passed to ``env.step``; otherwise the bare index
    is used.

    Returns:
        Wins (final reward > 0) divided by ``num_games``; 0.0 when
        ``num_games`` is not positive.
    """
    if num_games <= 0:
        return 0.0

    search = MCTS(model, num_simulations=20)
    legal_fn = getattr(env, 'legal_actions', None)
    wins = 0

    for _ in range(num_games):
        obs = env.reset()
        done = False
        reward = 0

        while not done:
            action_probs = np.asarray(search.run(obs), dtype=float)

            legal = {}
            if legal_fn is not None:
                for act in legal_fn():
                    index = int(getattr(act, 'value', act))  # enum member or raw int
                    if 0 <= index < len(action_probs):
                        legal[index] = act

            if legal:
                # Greedy over legal actions only; lowest index wins ties,
                # matching np.argmax.
                best = max(sorted(legal), key=lambda i: action_probs[i])
                action = legal[best]
            else:
                action = int(np.argmax(action_probs))

            obs, reward, done, _ = env.step(action)

        if reward > 0:
            wins += 1

    return wins / num_games

def main(num_episodes=100, games_per_episode=5, train_steps=10, eval_every=10):
    """Train a fresh model with alternating self-play and optimisation.

    Args:
        num_episodes: Number of self-play/train rounds.
        games_per_episode: Self-play games collected per round.
        train_steps: Optimisation steps attempted per round.
        eval_every: Evaluate and print a progress line every this many
            rounds (must be positive).
    """
    env = BlackjackEnv()
    model = BlackjackMuZero()
    trainer = Trainer(model)

    for episode in tqdm(range(num_episodes)):
        trainer.self_play(env, num_games=games_per_episode)

        # train_step returns None until the replay buffer holds a full
        # batch, so remember only the most recent real loss of this round.
        loss = None
        for _ in range(train_steps):
            step_loss = trainer.train_step()
            if step_loss is not None:
                loss = step_loss

        if episode % eval_every == 0:
            eval_score = evaluate(model, env)
            loss_text = "n/a" if loss is None else f"{loss:.3f}"
            print(f"Episode {episode}, Loss: {loss_text}, Win rate: {eval_score:.2f}")

# Entry point: run a quick smoke test of the untrained model, then train.
if __name__ == "__main__":
    # Test model initialization
    env = BlackjackEnv()
    model = BlackjackMuZero()

    # Deal one hand and show the raw observation dict the model will consume.
    test_obs = env.reset()
    print("Test observation:", test_obs)

    # One forward pass through representation + prediction; only the tensor
    # shapes are printed, the values come from an untrained network.
    out = model.initial_inference(test_obs)
    print("Initial inference:")
    print("State shape:", out['state'].shape)
    print("Policy shape:", out['policy'].shape)
    print("Value shape:", out['value'].shape)

    # Start training
    # NOTE: main() builds its own environment and model; the objects created
    # above are used only for the smoke test.
    main()

Test observation: {'player_sum': 16, 'dealer_card': 2, 'usable_ace': False, 'can_split': False}
Initial inference:
State shape: torch.Size([1, 64])
Policy shape: torch.Size([1, 4])
Value shape: torch.Size([1, 1])


  0%|          | 0/100 [00:00<?, ?it/s]