**Part 3**

* Adding vectorized BlackJack Environment
* Adding Device Agnostic code (GPU Training)
* Attempting to use Softmax (Categorical Distribution) implementation instead of Sigmoid (Binary Bernoulli Distribution)

**Results**

* Still very slow during training

# Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import gymnasium as gym

# Testing

In [2]:
env = gym.make("Blackjack-v1", sab=True) # `render_mode="human"` creates a pygame popup window to analyze play # `sab=True` uses the Sutton & Barto version

# Agent

In [3]:
class BlackJackAgent(nn.Module):
    def __init__(self, obs_size=3, hidden_size=10, output_size=2):
        super(BlackJackAgent, self).__init__()
        self.layer_1 = nn.Linear(obs_size, hidden_size)
        self.layer_2 = nn.Linear(hidden_size, output_size)
        self.action_probs_activation_layer = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = torch.relu(self.layer_1(x))
        logits = self.layer_2(x)
        return logits       # later use nn.Softmax to get probabilities

    def get_action_probs(self, logits):
        """Get the probabilities of each action."""
        return self.action_probs_activation_layer(logits)
    
    def sample_action(self, action:None):
        """Get the probability of choosing the action"""
        logits = self.forward(action)
        probs = self.get_action_probs(logits)
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample().item()
        prob_of_action = dist.log_prob(action)
        return action, prob_of_action

# Training Loop

In [4]:
def training_blackjack_agent(epochs=50, learning_rate=0.0001, batch_size=64, gamma=0.99, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5, entropy_coeff=0.01, log_iterations=10, device="cpu", num_envs=16) -> BlackJackAgent: 
    print(f"Training BlackJack Agent's Policy on {device} with {epochs} epochs, {learning_rate} learning rate, batch size {batch_size}, and KL beta {beta_kl}.")

    vec_env = gym.make_vec("Blackjack-v1", num_envs=num_envs, sab=True) # `sab=True` uses the Sutton & Barto version

    # steps_per_env_per_rollout = batch_size // num_envs if batch_size % num_envs == 0 else (batch_size // num_envs) + 1

    New_Policy = BlackJackAgent().to(device)   # STEP 3 || 
    optimizer = optim.Adam(params=New_Policy.parameters(), lr=learning_rate)


    for epoch in tqdm(range(epochs), desc=f"Main Epoch (Outer Loop)", leave=False):     # STEP 4 || 
        # STEP 5 || Sample a batch D_b from D --> OMITTED 
        # STEP 6 || Update the old policy model PI old <- PI new
        Policy_Old = BlackJackAgent().to(device)
        Policy_Old.load_state_dict(New_Policy.state_dict())
        Policy_Old.eval()   # Prevent Gradient tracking

        # This will store trajectories for all episodes collected in the current batch
        completed_batch_trajectories = []

        # Reset all vectorized environments
        raw_observations, infos = vec_env.reset() # observations is a numpy array of shape (num_envs, obs_dim(3))
        observations = np.stack(raw_observations, axis=1)
        dones = np.array([False] * num_envs) # Track the done status for each parallel environment
        truncateds = np.array([False] * num_envs) # Track truncated status for each parallel environment

        # Initialize current trajectories for all parallel environments
        # Each element in this list will be a dict for an *in-progress* episode in a specific env
        current_episode_trajectories = [{"states": [], "actions": [], "rewards": [], "log_probs": []} for _ in range(num_envs)]

        # --- STEP 7 Collect a Batch of Experiences Using the Old Policy---
        # Loop Agent prediction, recording trajectories to lists:
        episodes_collected_in_batch = 0
        max_steps_per_batch_limit = batch_size * 5 # A safety limit to prevent infinite loops if episodes are very long
        current_total_steps = 0

        while episodes_collected_in_batch < batch_size and current_total_steps < max_steps_per_batch_limit:
            obs_tensor = torch.tensor(observations, dtype=torch.float32).to(device)

            with torch.no_grad():
                logits = Policy_Old(obs_tensor)
                dist = torch.distributions.Categorical(logits=logits)
                actions = dist.sample() # Tensor of shape [1]
                log_probs = dist.log_prob(actions)
                    
            raw_next_obs, rewards, dones, truncateds, infos = vec_env.step(actions.cpu().numpy()) # actions must be on CPU for env.step()
            next_obs = np.stack(raw_next_obs, axis=1)
            current_total_steps += num_envs

            # Process data for each parallel environment
            for env_idx in range(num_envs):
                
                obs_to_append = observations[env_idx]
                if isinstance(obs_to_append, torch.Tensor):
                    obs_to_append = obs_to_append.cpu().numpy()
                # Store current_episode_trajectories
                current_episode_trajectories[env_idx]["states"].append(obs_to_append)
                current_episode_trajectories[env_idx]["actions"].append(actions[env_idx].item())
                current_episode_trajectories[env_idx]["rewards"].append(rewards[env_idx])
                current_episode_trajectories[env_idx]["log_probs"].append(log_probs[env_idx].cpu())
                
                if dones[env_idx] or truncateds[env_idx]:
                    completed_batch_trajectories.append(current_episode_trajectories[env_idx])
                    episodes_collected_in_batch += 1

                    # Reset this specific environment
                    # new_obs, new_info = vec_env.reset_at(env_idx)
                    # observations[env_idx] = new_obs

                    current_episode_trajectories[env_idx] = {"states": [], "actions": [], "rewards": [], "log_probs": []}

            observations = next_obs  # Update the observation

        for env_idx in range(num_envs):
            if len(current_episode_trajectories[env_idx]["states"]) > 0:
                # If there's partial data, it means the episode was still running
                # when `batch_size` was met. You'll need to decide how to handle this.
                # For simplicity for now, we'll append them. In full PPO, you'd add
                # the value of the last state to its rewards.
                completed_batch_trajectories.append(current_episode_trajectories[env_idx])
                # Note: These might not be "full" episodes in the sense of reaching a done state,
                # but they contribute steps to your batch.

        # These lists will hold data from ALL episodes in the current batch for Advantage Calculation
        all_states = []
        all_actions = []
        all_old_log_probs = []
        all_discounted_rewards = []

        # STEP 8 || Calculate Discounted Rewards for completed trajectories
        for episode_trajectory in completed_batch_trajectories: 
            rewards = episode_trajectory["rewards"]
            states = episode_trajectory["states"]
            actions = episode_trajectory["actions"]
            log_probs = episode_trajectory["log_probs"]
            
            if not rewards:
                continue

            discounted_reward = 0
            returns_for_episode = []
            for reward in reversed(rewards):
                discounted_reward = reward + gamma * discounted_reward
                returns_for_episode.insert(0, discounted_reward)

            discounted_rewards = torch.tensor(returns_for_episode, dtype=torch.float32)
            # print(f"discounted_rewards size: {discounted_rewards.size()}")
            # Add each trajectory information for the batch
            if states:
                all_states.extend(states)
                all_actions.extend(actions)
                all_old_log_probs.extend(log_probs)
                all_discounted_rewards.extend(discounted_rewards.tolist())

        # --- IMPORTANT: Pre-tensorization checks and conversions ---
        if not all_states or not all_actions or not all_old_log_probs or not all_discounted_rewards:
            print(f"Warning: Epoch {epoch + 1}: Insufficient data collected for optimization. "
                  f"Skipping policy update for this epoch.")
            print(f"  Counts: States={len(all_states)}, Actions={len(all_actions)}, "
                  f"LogProbs={len(all_old_log_probs)}, Rewards={len(all_discounted_rewards)}")
            continue
        # Convert all collected batch data into PyTorch tensors
        all_states_tensor = torch.tensor(np.array(all_states), dtype=torch.float32).to(device)
        all_actions_tensor = torch.tensor(all_actions, dtype=torch.long).to(device)
        # Stack individual log_prob tensors and then flatten if necessary
        all_old_log_probs_tensor = torch.tensor(all_old_log_probs, dtype=torch.float32).to(device) # Ensure it's a 1D tensor
        all_discounted_rewards_tensor = torch.tensor(all_discounted_rewards, dtype=torch.float32).to(device)

        # STEP 9 || Calculate the Advantage of each Time Step for each Trajectory using normalization
        all_advantages_tensor = (all_discounted_rewards_tensor - all_discounted_rewards_tensor.mean()) / (all_discounted_rewards_tensor.std() + 1e-8)

        # Detach these tensors from any computation graph history
        # as they represent fixed data for the policy updates in k_epochs.
        # This prevents the "RuntimeError: Trying to backward through the graph a second time".
        all_states_tensor = all_states_tensor.detach()
        all_actions_tensor = all_actions_tensor.detach()
        all_old_log_probs_tensor = all_old_log_probs_tensor.detach()
        all_advantages_tensor = all_advantages_tensor.detach()

        New_Policy.train()  # Prepare NN for updates

        # --- STEP 10 || GRPO Optimization ---
        for k_epoch in tqdm(range(k_epochs), desc=f"Epoch {epoch+1}/{epochs} (Inner K-Epochs)", leave=True):
            new_logits = New_Policy(all_states_tensor)
            new_dist = torch.distributions.Categorical(logits=new_logits)
            new_log_probs = new_dist.log_prob(all_actions_tensor)
            entropy = new_dist.entropy().mean() # Calculate entropy for regularization

            R1_ratio = torch.exp(new_log_probs - all_old_log_probs_tensor)

            unclipped_surrogate = R1_ratio * all_advantages_tensor
            clipped_surrogate = torch.clamp(input=R1_ratio, min=1.0-epsilon, max=1.0+epsilon) * all_advantages_tensor

            policy_loss = -torch.min(unclipped_surrogate, clipped_surrogate).mean()

            # --- KL Divergence Calculation ---
            # Create distributions for old policies using the trajectory states
            with torch.no_grad():
                old_logits = Policy_Old(all_states_tensor)
            old_dist = torch.distributions.Categorical(logits=old_logits)

            # Calculate KL divergence per sample, then take the mean over the batch
            kl_div_per_sample = torch.distributions.kl.kl_divergence(p=new_dist, q=old_dist)
            kl_loss = kl_div_per_sample.mean() # Mean over the batch

            # Total Loss for GRPO
            total_loss = policy_loss + beta_kl * kl_loss - entropy_coeff * entropy

            # STEP 11 || Policy Updates
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(New_Policy.parameters(), max_grad_norm)
            optimizer.step()    # Update policy parameters using gradient ascent
        
        
        # --- 4. Logging and Evaluation ---
        if (epoch + 1) % log_iterations == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss.item():.4f}, Ratio: {R1_ratio.mean().item():.5f}, Entropy Term: {entropy:.5f}")
            # You can add more evaluation metrics here, e.g., average reward per episode
            # For Blackjack, the reward is often -1, 0, or 1.
            avg_reward = sum(sum(ep["rewards"]) for ep in completed_batch_trajectories) / len(completed_batch_trajectories) if len(completed_batch_trajectories) > 0 else 0
            print(f"Average reward per episode in batch: {avg_reward:.2f}")

    New_Policy.eval()   # Change to eval mode for evaluation


    vec_env.close() # Close the environment after training
    print("Training complete.")
    return New_Policy # Return the trained policy

In [None]:
_ = training_blackjack_agent(epochs=50, learning_rate=0.0001, batch_size=64, gamma=0.99, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5, entropy_coeff=0.01, log_iterations=10, device="cpu", num_envs=16)

Training BlackJack Agent's Policy on cpu with 50 epochs, 0.0001 learning rate, batch size 64, and KL beta 0.01.


Epoch 1/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 228.44it/s]
Epoch 2/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 283.21it/s]
Epoch 3/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 278.40it/s]
Epoch 4/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 259.35it/s]
Epoch 5/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 227.80it/s]
Epoch 6/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 231.29it/s]
Epoch 7/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 247.38it/s]
Epoch 8/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 281.03it/s]
Epoch 9/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 303.99it/s]
Epoch 10/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 417.71it/s]
Main Epoch (Outer Loop):  20%|██        | 10/50 [00:02<00:09,  4.19it/s]

Epoch 10/50, Loss: -0.0254, Ratio: 0.99264, Entropy Term: 0.62366
Average reward per episode in batch: -0.22


Epoch 11/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 398.63it/s]
Epoch 12/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 403.31it/s]
Epoch 13/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 419.03it/s]
Epoch 14/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 428.51it/s]
Epoch 15/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 384.82it/s]
Epoch 16/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 427.74it/s]
Epoch 17/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 425.66it/s]
Epoch 18/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 399.89it/s]
Epoch 19/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 427.87it/s]
Epoch 20/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 461.36it/s]
Main Epoch (Outer Loop):  40%|████      | 20/50 [00:04<00:05,  5.94it/s]

Epoch 20/50, Loss: -0.0123, Ratio: 0.99465, Entropy Term: 0.35454
Average reward per episode in batch: -0.04


Epoch 21/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 457.00it/s]
Epoch 22/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 426.29it/s]
Epoch 23/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 425.86it/s]
Epoch 24/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 427.39it/s]
Epoch 25/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 425.28it/s]
Epoch 26/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 500.45it/s]
Epoch 27/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 427.01it/s]
Epoch 28/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 562.42it/s]
Epoch 29/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 583.06it/s]
Epoch 30/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 552.49it/s]
Main Epoch (Outer Loop):  60%|██████    | 30/50 [00:06<00:02,  6.80it/s]

Epoch 30/50, Loss: -0.0095, Ratio: 1.00775, Entropy Term: 0.35954
Average reward per episode in batch: -0.16


Epoch 31/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 528.25it/s]
Epoch 32/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 541.29it/s]
Epoch 33/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 543.55it/s]
Epoch 34/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 302.13it/s]
Epoch 35/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 282.89it/s]
Epoch 36/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 256.07it/s]
Epoch 37/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 257.99it/s]
Epoch 38/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 257.41it/s]
Epoch 39/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 255.38it/s]
Epoch 40/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 233.16it/s]
Main Epoch (Outer Loop):  80%|████████  | 40/50 [00:08<00:02,  3.72it/s]

Epoch 40/50, Loss: -0.0116, Ratio: 0.98755, Entropy Term: 0.39186
Average reward per episode in batch: 0.00


Epoch 41/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 228.79it/s]
Epoch 42/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 224.88it/s]
Epoch 43/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 320.70it/s]
Epoch 44/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 246.97it/s]
Epoch 45/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 457.74it/s]
Epoch 46/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 396.88it/s]
Epoch 47/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 398.05it/s]
Epoch 48/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 393.14it/s]
Epoch 49/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 456.34it/s]
Epoch 50/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 536.35it/s]
                                                                        

Epoch 50/50, Loss: -0.0065, Ratio: 1.00064, Entropy Term: 0.28929
Average reward per episode in batch: 0.00
Training complete.




Training BlackJack Agent's Policy on cpu with 50 epochs, 0.0001 learning rate, batch size 64, and KL beta 0.01.
                                                               
(array([21, 12, 19, 19, 20, 15, 18, 16, 10,  7, 19, 10, 17, 15, 14,  9]), array([ 5,  5,  1,  4, 10, 10,  3,  9,  2,  6,  1,  6,  9,  5,  9,  8]), array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))


In [None]:
# _ = training_blackjack_agent(epochs=50, learning_rate=0.0001, batch_size=64, gamma=0.99, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5, entropy_coeff=0.01, log_iterations=10, device="cuda", num_envs=16)

Training BlackJack Agent's Policy on cpu with 50 epochs, 0.0001 learning rate, batch size 64, and KL beta 0.01.


Epoch 1/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 406.51it/s]
Epoch 2/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 431.59it/s]
Epoch 3/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 380.85it/s]
Epoch 4/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 581.14it/s]
Epoch 5/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 455.89it/s]
Epoch 6/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 508.31it/s]
Epoch 7/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 563.64it/s]
Epoch 8/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 501.96it/s]
Epoch 9/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 518.32it/s]
Epoch 10/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 492.24it/s]
Main Epoch (Outer Loop):  20%|██        | 10/50 [00:01<00:05,  6.79it/s]

Epoch 10/50, Loss: -0.0003, Ratio: 1.00001, Entropy Term: 0.01753
Average reward per episode in batch: -0.86


Epoch 11/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 407.73it/s]
Epoch 12/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 399.74it/s]
Epoch 13/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 491.99it/s]
Epoch 14/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 507.07it/s]
Epoch 15/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 472.87it/s]
Epoch 16/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 409.59it/s]
Epoch 17/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 519.56it/s]
Epoch 18/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 455.93it/s]
Epoch 19/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 521.96it/s]
Epoch 20/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 273.56it/s]
Main Epoch (Outer Loop):  40%|████      | 20/50 [00:03<00:05,  5.40it/s]

Epoch 20/50, Loss: -0.0013, Ratio: 1.00067, Entropy Term: 0.01085
Average reward per episode in batch: -0.86


Epoch 21/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 250.88it/s]
Epoch 22/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 245.19it/s]
Epoch 23/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 263.62it/s]
Epoch 24/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 224.69it/s]
Epoch 25/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 245.87it/s]
Epoch 26/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 262.25it/s]
Epoch 27/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 322.41it/s]
Epoch 28/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 320.80it/s]
Epoch 29/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 303.73it/s]
Epoch 30/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 220.85it/s]
Main Epoch (Outer Loop):  60%|██████    | 30/50 [00:06<00:05,  3.69it/s]

Epoch 30/50, Loss: -0.0002, Ratio: 0.99980, Entropy Term: 0.01344
Average reward per episode in batch: -0.86


Epoch 31/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 448.13it/s]
Epoch 32/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 373.11it/s]
Epoch 33/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 388.90it/s]
Epoch 34/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 377.49it/s]
Epoch 35/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 400.09it/s]
Epoch 36/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 426.59it/s]
Epoch 37/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 400.90it/s]
Epoch 38/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 303.94it/s]
Epoch 39/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 473.30it/s]
Epoch 40/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 556.80it/s]
Main Epoch (Outer Loop):  80%|████████  | 40/50 [00:07<00:01,  5.87it/s]

Epoch 40/50, Loss: -0.0004, Ratio: 0.99995, Entropy Term: 0.03450
Average reward per episode in batch: -0.86


Epoch 41/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 455.70it/s]
Epoch 42/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 470.67it/s]
Epoch 43/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 546.37it/s]
Epoch 44/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 425.61it/s]
Epoch 45/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 397.02it/s]
Epoch 46/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 596.80it/s]
Epoch 47/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 581.12it/s]
Epoch 48/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 581.23it/s]
Epoch 49/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 454.23it/s]
Epoch 50/50 (Inner K-Epochs): 100%|██████████| 64/64 [00:00<00:00, 425.77it/s]
                                                                        

Epoch 50/50, Loss: -0.0005, Ratio: 0.99904, Entropy Term: 0.03725
Average reward per episode in batch: -0.91
Training complete.




Training BlackJack Agent's Policy with 10 epochs, 0.0001 learning rate, batch size 4, and KL beta 0.01.
* Batch of Trajectories:
* [{'states': [(12, 10, 0)], 'actions': [0], 'rewards': [-1.0], 'log_probs': [tensor([-0.1239])]}, 
* {'states': [(20, 7, 0)], 'actions': [0], 'rewards': [1.0], 'log_probs': [tensor([-0.0815])]}, 
* {'states': [(12, 1, 0), (17, 1, 0)], 'actions': [1, 1], 'rewards': [0.0, -1.0], 'log_probs': [tensor([-1.5968]), tensor([-1.9474])]}, 
* {'states': [(6, 6, 0)], 'actions': [0], 'rewards': [-1.0], 'log_probs': [tensor([-0.2144])]}, 
* {'states': [(7, 4, 0)], 'actions': [0], 'rewards': [-1.0], 'log_probs': [tensor([-0.2734])]}, 
* {'states': [(13, 3, 1)], 'actions': [0], 'rewards': [-1.0], 'log_probs': [tensor([-0.1471])]}, 
* {'states': [(15, 10, 0)], 'actions': [0], 'rewards': [-1.0], 'log_probs': [tensor([-0.1000])]}, 
* {'states': [(12, 10, 0)], 'actions': [0], 'rewards': [1.0], 'log_probs': [tensor([-0.1239])]}, 
* {'states': [(14, 7, 0)], 'actions': [0], 'rewards': [-1.0], 'log_probs': [tensor([-0.1320])]}]

In [7]:
# Example usage (assuming you have a way to call this function, e.g., in a main block)
if __name__ == '__main__':
    # You can adjust these parameters as needed
    # Using a larger batch_size for more stable training and to reduce empty batch issues
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   # Device Agnostic Code
    trained_policy = training_blackjack_agent(
        epochs=2000,
        learning_rate=0.0003,
        batch_size=2048, # Significantly larger batch size recommended for stability
        k_epochs=128,
        epsilon=0.2,
        beta_kl=0.01,
        entropy_coeff=0.001,
        log_iterations=100,
        gamma=0.99,
        device=device,
        num_envs=16
    )

    print("\nTesting the trained policy:")
    test_env = gym.make("Blackjack-v1", sab=True)
    total_test_rewards = 0
    num_test_episodes = 1000

    for _ in range(num_test_episodes):
        obs, _ = test_env.reset()
        done = False
        truncated = False
        episode_reward = 0
        while not done and not truncated:
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                logits = trained_policy(obs_tensor)
                dist = torch.distributions.Categorical(logits=logits)
                action = dist.sample()
            obs, reward, done, truncated, _ = test_env.step(action.item())
            episode_reward += reward
        total_test_rewards += episode_reward

    print(f"Average reward over {num_test_episodes} test episodes: {total_test_rewards / num_test_episodes:.4f}")
    test_env.close()

Training BlackJack Agent's Policy on cpu with 2000 epochs, 0.0003 learning rate, batch size 2048, and KL beta 0.01.


Epoch 1/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 148.84it/s]
Epoch 2/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.49it/s]
Epoch 3/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 194.74it/s]
Epoch 4/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 163.95it/s]
Epoch 5/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 196.46it/s]
Epoch 6/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 203.13it/s]
Epoch 7/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 100.98it/s]
Epoch 8/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 136.29it/s]
Epoch 9/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 191.10it/s]
Epoch 10/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 190.96it/s]
Epoch 11/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 215.40it/s]
Epoch 12/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 189.71it/s]
Epoch 13/2000 

Epoch 100/2000, Loss: -0.0015, Ratio: 0.99912, Entropy Term: 0.21172
Average reward per episode in batch: -0.02


Epoch 101/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 168.40it/s]
Epoch 102/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 105.46it/s]
Epoch 103/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 165.35it/s]
Epoch 104/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 171.87it/s]
Epoch 105/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 170.96it/s]
Epoch 106/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 189.72it/s]
Epoch 107/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 78.27it/s]
Epoch 108/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 157.38it/s]
Epoch 109/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 169.54it/s]
Epoch 110/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 168.25it/s]
Epoch 111/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 180.20it/s]
Epoch 112/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 87.98

Epoch 200/2000, Loss: -0.0014, Ratio: 0.99874, Entropy Term: 0.17166
Average reward per episode in batch: -0.06


Epoch 201/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:02<00:00, 48.73it/s]
Epoch 202/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:02<00:00, 49.66it/s]
Epoch 203/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:02<00:00, 57.52it/s]
Epoch 204/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 112.27it/s]
Epoch 205/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 67.37it/s]
Epoch 206/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 142.48it/s]
Epoch 207/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 167.78it/s]
Epoch 208/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 114.25it/s]
Epoch 209/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 82.32it/s]
Epoch 210/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 124.18it/s]
Epoch 211/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 168.19it/s]
Epoch 212/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 187.99it/

Epoch 300/2000, Loss: -0.0013, Ratio: 1.00091, Entropy Term: 0.12731
Average reward per episode in batch: -0.01


Epoch 301/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 79.40it/s]
Epoch 302/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 118.94it/s]
Epoch 303/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 147.98it/s]
Epoch 304/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.76it/s]
Epoch 305/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:02<00:00, 57.92it/s]
Epoch 306/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 70.73it/s]
Epoch 307/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 123.68it/s]
Epoch 308/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 73.98it/s]
Epoch 309/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 128.67it/s]
Epoch 310/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 94.12it/s] 
Epoch 311/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 96.03it/s] 
Epoch 312/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 93.47it/s

Epoch 400/2000, Loss: -0.0013, Ratio: 0.99934, Entropy Term: 0.10939
Average reward per episode in batch: -0.05


Epoch 401/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 164.08it/s]
Epoch 402/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 160.02it/s]
Epoch 403/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 162.78it/s]
Epoch 404/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 92.75it/s]
Epoch 405/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 155.35it/s]
Epoch 406/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 165.83it/s]
Epoch 407/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 162.11it/s]
Epoch 408/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 123.58it/s]
Epoch 409/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 133.31it/s]
Epoch 410/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 86.77it/s] 
Epoch 411/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 148.87it/s]
Epoch 412/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 103.6

Epoch 500/2000, Loss: -0.0015, Ratio: 0.99975, Entropy Term: 0.09447
Average reward per episode in batch: -0.03


Epoch 501/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 83.12it/s]
Epoch 502/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 161.00it/s]
Epoch 503/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 173.09it/s]
Epoch 504/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 176.58it/s]
Epoch 505/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 174.15it/s]
Epoch 506/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 116.87it/s]
Epoch 507/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 184.06it/s]
Epoch 508/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 185.59it/s]
Epoch 509/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 125.42it/s]
Epoch 510/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.66it/s]
Epoch 511/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 183.39it/s]
Epoch 512/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 177.76

Epoch 600/2000, Loss: -0.0007, Ratio: 0.99964, Entropy Term: 0.11397
Average reward per episode in batch: -0.07


Epoch 601/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 93.44it/s]
Epoch 602/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 97.70it/s]
Epoch 603/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 89.83it/s]
Epoch 604/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 65.98it/s]
Epoch 605/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 91.93it/s]
Epoch 606/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 85.33it/s]
Epoch 607/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 69.16it/s]
Epoch 608/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 87.68it/s] 
Epoch 609/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 73.57it/s]
Epoch 610/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 101.33it/s]
Epoch 611/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 82.99it/s]
Epoch 612/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 75.27it/s] 
E

Epoch 700/2000, Loss: -0.0010, Ratio: 0.99766, Entropy Term: 0.10573
Average reward per episode in batch: -0.07


Epoch 701/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 69.60it/s]
Epoch 702/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 74.00it/s]
Epoch 703/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 86.92it/s] 
Epoch 704/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 65.98it/s]
Epoch 705/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 80.51it/s]
Epoch 706/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 82.84it/s] 
Epoch 707/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 69.37it/s]
Epoch 708/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 92.61it/s]
Epoch 709/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 73.63it/s]
Epoch 710/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 76.20it/s]
Epoch 711/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 97.69it/s]
Epoch 712/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 65.97it/s]
Ep

Epoch 800/2000, Loss: -0.0008, Ratio: 0.99960, Entropy Term: 0.09123
Average reward per episode in batch: -0.08


Epoch 801/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 74.25it/s]
Epoch 802/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 114.26it/s]
Epoch 803/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 118.92it/s]
Epoch 804/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 74.75it/s]
Epoch 805/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.40it/s]
Epoch 806/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 117.18it/s]
Epoch 807/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 81.31it/s]
Epoch 808/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 98.18it/s] 
Epoch 809/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.51it/s]
Epoch 810/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 140.05it/s]
Epoch 811/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 80.34it/s]
Epoch 812/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 125.47it/s

Epoch 900/2000, Loss: -0.0017, Ratio: 0.99898, Entropy Term: 0.09561
Average reward per episode in batch: -0.08


Epoch 901/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 173.05it/s]
Epoch 902/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 185.50it/s]
Epoch 903/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 113.28it/s]
Epoch 904/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 164.09it/s]
Epoch 905/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 103.54it/s]
Epoch 906/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 180.21it/s]
Epoch 907/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 133.49it/s]
Epoch 908/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 116.70it/s]
Epoch 909/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 168.30it/s]
Epoch 910/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 182.47it/s]
Epoch 911/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 112.29it/s]
Epoch 912/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 101.

Epoch 1000/2000, Loss: -0.0018, Ratio: 0.99893, Entropy Term: 0.10392
Average reward per episode in batch: -0.03


Epoch 1001/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 195.93it/s]
Epoch 1002/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 191.07it/s]
Epoch 1003/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 191.06it/s]
Epoch 1004/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 168.38it/s]
Epoch 1005/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 121.87it/s]
Epoch 1006/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 119.60it/s]
Epoch 1007/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 183.36it/s]
Epoch 1008/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 216.45it/s]
Epoch 1009/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 100.01it/s]
Epoch 1010/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 197.02it/s]
Epoch 1011/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 110.05it/s]
Epoch 1012/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00

Epoch 1100/2000, Loss: -0.0010, Ratio: 0.99917, Entropy Term: 0.08607
Average reward per episode in batch: -0.04


Epoch 1101/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 183.99it/s]
Epoch 1102/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 169.37it/s]
Epoch 1103/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 182.19it/s]
Epoch 1104/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 96.26it/s]
Epoch 1105/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 156.11it/s]
Epoch 1106/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 169.74it/s]
Epoch 1107/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 160.98it/s]
Epoch 1108/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 172.96it/s]
Epoch 1109/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 98.46it/s]
Epoch 1110/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 139.05it/s]
Epoch 1111/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 171.99it/s]
Epoch 1112/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<0

Epoch 1200/2000, Loss: -0.0015, Ratio: 0.99858, Entropy Term: 0.09571
Average reward per episode in batch: -0.07


Epoch 1201/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.55it/s]
Epoch 1202/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 150.57it/s]
Epoch 1203/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.85it/s]
Epoch 1204/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 103.96it/s]
Epoch 1205/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 94.79it/s]
Epoch 1206/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 118.00it/s]
Epoch 1207/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 148.84it/s]
Epoch 1208/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 99.06it/s]
Epoch 1209/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 72.30it/s]
Epoch 1210/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 70.60it/s]
Epoch 1211/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 107.47it/s]
Epoch 1212/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00

Epoch 1300/2000, Loss: -0.0007, Ratio: 0.99944, Entropy Term: 0.07843
Average reward per episode in batch: -0.08


Epoch 1301/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 86.77it/s]
Epoch 1302/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 113.11it/s]
Epoch 1303/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 151.11it/s]
Epoch 1304/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 96.60it/s]
Epoch 1305/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 125.90it/s]
Epoch 1306/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 161.31it/s]
Epoch 1307/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 116.29it/s]
Epoch 1308/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 100.50it/s]
Epoch 1309/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 92.44it/s]
Epoch 1310/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 97.64it/s]
Epoch 1311/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 152.26it/s]
Epoch 1312/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:

Epoch 1400/2000, Loss: -0.0009, Ratio: 1.00002, Entropy Term: 0.08428
Average reward per episode in batch: -0.06


Epoch 1401/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 90.74it/s]
Epoch 1402/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 160.41it/s]
Epoch 1403/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 85.41it/s]
Epoch 1404/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 102.49it/s]
Epoch 1405/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 122.82it/s]
Epoch 1406/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 108.66it/s]
Epoch 1407/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 90.57it/s]
Epoch 1408/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 150.48it/s]
Epoch 1409/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 86.02it/s]
Epoch 1410/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 125.47it/s]
Epoch 1411/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 125.54it/s]
Epoch 1412/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:

Epoch 1500/2000, Loss: -0.0009, Ratio: 0.99961, Entropy Term: 0.08670
Average reward per episode in batch: -0.07


Epoch 1501/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 142.69it/s]
Epoch 1502/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.21it/s]
Epoch 1503/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.07it/s]
Epoch 1504/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.22it/s]
Epoch 1505/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 124.26it/s]
Epoch 1506/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 88.27it/s] 
Epoch 1507/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 100.00it/s]
Epoch 1508/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:00<00:00, 159.99it/s]
Epoch 1509/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.77it/s]
Epoch 1510/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 103.61it/s]
Epoch 1511/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 88.89it/s]
Epoch 1512/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:0

Epoch 1600/2000, Loss: -0.0009, Ratio: 0.99861, Entropy Term: 0.08717
Average reward per episode in batch: -0.08


Epoch 1601/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 93.03it/s] 
Epoch 1602/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 74.52it/s]
Epoch 1603/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.92it/s]
Epoch 1604/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 89.81it/s]
Epoch 1605/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 81.03it/s]
Epoch 1606/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 104.08it/s]
Epoch 1607/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 103.49it/s]
Epoch 1608/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 116.02it/s]
Epoch 1609/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 82.49it/s]
Epoch 1610/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 119.11it/s]
Epoch 1611/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 118.57it/s]
Epoch 1612/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:0

Epoch 1700/2000, Loss: -0.0012, Ratio: 0.99971, Entropy Term: 0.08428
Average reward per episode in batch: -0.06


Epoch 1701/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 94.45it/s] 
Epoch 1702/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 108.92it/s]
Epoch 1703/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 92.81it/s] 
Epoch 1704/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 101.43it/s]
Epoch 1705/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.96it/s]
Epoch 1706/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 113.47it/s]
Epoch 1707/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 110.38it/s]
Epoch 1708/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.52it/s] 
Epoch 1709/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 103.75it/s]
Epoch 1710/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 98.51it/s] 
Epoch 1711/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.23it/s]
Epoch 1712/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<0

Epoch 1800/2000, Loss: -0.0005, Ratio: 0.99909, Entropy Term: 0.06964
Average reward per episode in batch: -0.02


Epoch 1801/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 85.45it/s]
Epoch 1802/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 77.57it/s]
Epoch 1803/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 74.40it/s]
Epoch 1804/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 116.10it/s]
Epoch 1805/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 96.24it/s] 
Epoch 1806/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 90.17it/s] 
Epoch 1807/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 80.17it/s]
Epoch 1808/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 94.14it/s] 
Epoch 1809/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 97.44it/s] 
Epoch 1810/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 82.58it/s]
Epoch 1811/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 114.53it/s]
Epoch 1812/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:0

Epoch 1900/2000, Loss: -0.0012, Ratio: 0.99973, Entropy Term: 0.08203
Average reward per episode in batch: -0.07


Epoch 1901/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 113.12it/s]
Epoch 1902/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.75it/s] 
Epoch 1903/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 79.98it/s]
Epoch 1904/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 114.56it/s]
Epoch 1905/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 95.51it/s] 
Epoch 1906/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 81.19it/s]
Epoch 1907/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 84.75it/s]
Epoch 1908/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 105.78it/s]
Epoch 1909/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 96.96it/s] 
Epoch 1910/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 71.68it/s]
Epoch 1911/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:00, 78.38it/s]
Epoch 1912/2000 (Inner K-Epochs): 100%|██████████| 128/128 [00:01<00:0

Epoch 2000/2000, Loss: -0.0015, Ratio: 0.99895, Entropy Term: 0.09015
Average reward per episode in batch: -0.08
Training complete.

Testing the trained policy:
Average reward over 1000 test episodes: -0.1140


took 32 minutes to run using the CPU

Parameters: 


epochs=2000,
        learning_rate=0.0003,
        batch_size=2048, # Significantly larger batch size recommended for stability
        k_epochs=128,
        epsilon=0.2,
        beta_kl=0.01,
        entropy_coeff=0.001,
        log_iterations=100,
        gamma=0.99

In [8]:
test_env = gym.make("Blackjack-v1", render_mode="rgb", sab=True)
total_test_rewards = 0

  logger.warn(


In [9]:
num_test_episodes = 10

In [10]:
for episode in range(num_test_episodes):
    print(f"Resetting env for episode: {episode}")
    obs, _ = test_env.reset()
    done = False
    truncated = False
    episode_reward = 0
    stored_obs=[]
    while not done and not truncated:
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            logits = trained_policy(obs_tensor)
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            print(f"obs_tensor: {obs_tensor} || Action taken: {action}")
        obs, reward, done, truncated, _ = test_env.step(action.item())
        episode_reward += reward
        if (truncated): print("truncated")
    print(f"Reward: {episode_reward} || Final Observation: {obs}")

Resetting env for episode: 0
obs_tensor: tensor([[16., 10.,  0.]]) || Action taken: tensor([0])
Reward: -1.0 || Final Observation: (16, 10, 0)
Resetting env for episode: 1
obs_tensor: tensor([[10.,  2.,  0.]]) || Action taken: tensor([1])
obs_tensor: tensor([[13.,  2.,  0.]]) || Action taken: tensor([1])
Reward: -1.0 || Final Observation: (23, 2, 0)
Resetting env for episode: 2
obs_tensor: tensor([[ 6., 10.,  0.]]) || Action taken: tensor([1])
obs_tensor: tensor([[ 9., 10.,  0.]]) || Action taken: tensor([1])
obs_tensor: tensor([[11., 10.,  0.]]) || Action taken: tensor([1])
obs_tensor: tensor([[13., 10.,  0.]]) || Action taken: tensor([1])
Reward: -1.0 || Final Observation: (22, 10, 0)
Resetting env for episode: 3
obs_tensor: tensor([[21.,  7.,  1.]]) || Action taken: tensor([0])
Reward: 1.0 || Final Observation: (21, 7, 1)
Resetting env for episode: 4
obs_tensor: tensor([[19., 10.,  0.]]) || Action taken: tensor([0])
Reward: 1.0 || Final Observation: (19, 10, 0)
Resetting env for epi

In [11]:
env.close()

Currently the final state which reveals what the dealer ended up with in the end is not shown. By trying to access the dealer's final hand or by adding custom logging within the environment, you'll gain the critical information needed to definitively understand the why behind each reward.

In [13]:
SAVE_LOCATION = "app/model_weights/blackjack_policy_model.pth"

torch.save(trained_policy.state_dict(), f=SAVE_LOCATION)