In [34]:
%pip install wandb
%pip install matplotlib
%pip install numpy
%pip install tqdm
%matplotlib inline
%pip install gymnasium==0.29.1



In [35]:
#@title Imports
from collections import defaultdict #for accessing keys which are not present in dictionary
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import gymnasium as gym
import sys
import random
from matplotlib.patches import Patch
import seaborn as sns

In [36]:
class MC_BlackjackAgent:
    def __init__(self):
        pass

    def train(self, mc_iterations=100000):
        possible_nums = [1,2,3,4,5,6,7,8,9,10,10,10,10]
        possible_current_sums = {i:0 for i in range(12,22)} #from 12 bcs for lower numbers you should always hit
        for current_sum in tqdm((possible_current_sums)):
            for _ in range(mc_iterations):
                hit_value = random.choice(possible_nums)
                if hit_value == 1:
                    if (current_sum + 11) <= 21:
                        hit_value = 11
                    else:
                        hit_value = 1
                if (current_sum + hit_value) > 21:
                    possible_current_sums[current_sum] += 1


        trivial_probabilities = {i:1 for i in range(1,12)}
        self.likelihood_of_hit = trivial_probabilities | {i: 1 - round(possible_current_sums[i]/mc_iterations, 4) for i in possible_current_sums}


    def play(self, obs):
        if obs[0] > 21:
            return 0
        else:
            return 1 if random.rand() < self.likelihood_of_hit[obs[0]] else 0


In [37]:
#initialize the agent
agent = MC_BlackjackAgent()
agent.train(mc_iterations=100000)


100%|██████████| 10/10 [00:00<00:00, 14.52it/s]


In [38]:
agent.likelihood_of_hit

{1: 1,
 2: 1,
 3: 1,
 4: 1,
 5: 1,
 6: 1,
 7: 1,
 8: 1,
 9: 1,
 10: 1,
 11: 1,
 12: 0.6931,
 13: 0.6159,
 14: 0.5368999999999999,
 15: 0.45940000000000003,
 16: 0.3879,
 17: 0.30900000000000005,
 18: 0.23040000000000005,
 19: 0.15400000000000003,
 20: 0.07709999999999995,
 21: 0.0}

In [40]:
from collections import deque
from gymnasium.wrappers import RecordEpisodeStatistics
from IPython.display import clear_output
import wandb
import pygame
from numpy import random

#load the environment
env = gym.make('Blackjack-v1',sab=False, natural=True, render_mode='rgb_array') #We are not folllowing the default sutton and barto book settings, which are sab=True, natural=False, render_mode='human'

# Initialize wandb
wandb.init(project="blackjack_MC_100", entity="ai42")
pygame.init()


n_episodes = 1000  # Define the number of episodes you want to run

wins = 0.0
losses = 0.0
draws = 0.0
naturals = 0.0

for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    terminated, truncated = False, False
    clear_output()
    step = 0
    episode_rewards = 0  # Initialize total rewards for the episode

    while not terminated and not truncated:
        action = agent.play(obs)  # Agent's policy
        obs, reward, terminated, truncated, info = env.step(action)


        frame = env.render()
        step += 1
        episode_rewards += reward  # Accumulate rewards

        # Plot frame
        plt.imshow(frame)
        plt.axis('off')
        plt.title(f"Episode: {episode} - Step: {step} - Action Taken: {action} - Reward: {reward} - Terminated: {terminated}")

        plt.savefig('frame.png')
        plt.close()

        # Log the frame and rewards to wandb
        wandb.log({
            "episode": episode,
            "step": step,
            "frame": wandb.Image('frame.png'),
            "reward": reward,
            "cumulative_reward": episode_rewards
        })
    if reward == 1 or reward == 1.5:
        wins += 1
    elif reward == -1:
        losses += 1
    elif reward == 0:
        draws += 1
    if reward == 1.5:
        naturals += 1

env.close()

# Let´s log general statistics of the training
wandb.log({"Win_rate": wins / n_episodes, "Loss_rate": losses / n_episodes, "Draw_rate": draws / n_episodes, "Natural_win_rate": naturals / n_episodes}) # Log the episode statistics to wandb


100%|██████████| 1000/1000 [07:24<00:00,  2.25it/s]
