<a href="https://colab.research.google.com/github/MalyalaAnand/REML/blob/main/REML_LAB_02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install gymnasium



In [2]:
import gymnasium as gym
import numpy as np
from collections import defaultdict

In [3]:
env = gym.make('Blackjack-v1', sab=True)

# Function to create a random policy
def create_random_policy(env):
    return lambda state: np.random.choice(env.action_space.n)

# Function to create a greedy policy based on Q-values
def create_greedy_policy(Q):
    def policy_fn(state):
        return np.argmax(Q[state])
    return policy_fn

In [4]:
def mc_policy_evaluation(policy, env, num_episodes, gamma=1.0):
    returns_sum = defaultdict(float)
    returns_count = defaultdict(float)
    V = defaultdict(float)

    for _ in range(num_episodes):
        episode = []
        state = env.reset()[0]
        done = False

        while not done:
            action = policy(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated # Blackjack env can be terminated or truncated
            episode.append((state, action, reward))
            state = next_state

        visited_states = set()
        G = 0
        for state, action, reward in reversed(episode):
            G = gamma * G + reward
            if state not in visited_states:
                returns_sum[state] += G
                returns_count[state] += 1
                V[state] = returns_sum[state] / returns_count[state]
                visited_states.add(state)
    return V

In [5]:
def mc_control_epsilon_greedy(env, num_episodes, gamma=1.0, epsilon=0.1):
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    def policy_fn(state):
        if np.random.rand() < epsilon:
            return np.random.choice(env.action_space.n)
        else:
            return np.argmax(Q[state])

    for _ in range(num_episodes):
        episode = []
        state = env.reset()[0]
        done = False

        while not done:
            action = policy_fn(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated # Blackjack env can be terminated or truncated
            episode.append((state, action, reward))
            state = next_state

        visited_state_action_pairs = set()
        G = 0
        for state, action, reward in reversed(episode):
            G = gamma * G + reward
            if (state, action) not in visited_state_action_pairs:
                old_q = Q[state][action]
                Q[state][action] = old_q + (G - old_q) / (1 + sum(1 for s, a, r in episode if s == state and a == action))
                visited_state_action_pairs.add((state, action))

    return Q, create_greedy_policy(Q)

In [6]:
if __name__ == "__main__":
    random_policy = create_random_policy(env)

    print("Evaluating random policy...")
    V = mc_policy_evaluation(random_policy, env, num_episodes=50000)
    print("Value function for random policy (sample):")
    for i, (state, value) in enumerate(list(V.items())[:10]):
        print(f"State: {state}, Value: {value:.2f}")

    print("\nTraining control policy with epsilon-greedy strategy...")
    Q, greedy_policy = mc_control_epsilon_greedy(env, num_episodes=500000)
    print("Learned Q-values (sample):")
    for i, (state, actions) in enumerate(list(Q.items())[:10]):
        print(f"State: {state}, Actions: {actions}")

Evaluating random policy...
Value function for random policy (sample):
State: (21, 9, 0), Value: -0.13
State: (16, 9, 0), Value: -0.56
State: (19, 10, 0), Value: -0.41
State: (17, 10, 0), Value: -0.64
State: (12, 1, 0), Value: -0.78
State: (20, 10, 0), Value: -0.23
State: (18, 10, 0), Value: -0.53
State: (8, 10, 0), Value: -0.57
State: (19, 3, 0), Value: -0.24
State: (17, 8, 0), Value: -0.53

Training control policy with epsilon-greedy strategy...
Learned Q-values (sample):
State: (16, 7, 0), Actions: [-0.99804659  0.1836138 ]
State: (14, 9, 0), Actions: [-0.99996846 -0.48432539]
State: (11, 2, 0), Actions: [-0.97946094  0.73968129]
State: (18, 1, 0), Actions: [-0.71092974 -0.99902343]
State: (5, 10, 0), Actions: [-0.99993799 -0.99995959]
State: (15, 6, 0), Actions: [ 0.71623807 -0.99980926]
State: (12, 3, 0), Actions: [ 0.00178172 -0.35545817]
State: (12, 10, 0), Actions: [-0.99999615 -0.99754306]
State: (13, 10, 1), Actions: [-0.99999996 -0.65130885]
State: (19, 4, 0), Actions: [-0.0