In [6]:
import os
import gym
import numpy as np
import tensorflow.compat.v1 as tf
import matplotlib.pyplot as plt

tf.disable_v2_behavior()

class PolicyGradientAgent:
    """Monte Carlo policy-gradient (REINFORCE) agent with a small MLP policy.

    The policy maps an observation vector of length ``input_dims`` to a
    softmax distribution over ``n_actions`` discrete actions. Transitions are
    buffered for a whole episode via :meth:`store_transition` and consumed by
    :meth:`learn`, which takes one Adam step on
    ``mean(-log pi(a|s) * normalized_return)``.
    """

    def __init__(self, lr, gamma, n_actions=4, l1_size=64, l2_size=64, input_dims=8, chkpt_dir='tmp'):
        """Build the policy network in its own graph and session.

        Args:
            lr: Adam learning rate.
            gamma: Discount factor used when computing returns in ``learn``.
            n_actions: Number of discrete actions (width of the output layer).
            l1_size: Width of the first hidden layer.
            l2_size: Width of the second hidden layer.
            input_dims: Length of the observation vector fed to the network.
            chkpt_dir: Directory holding the ``policy.ckpt`` checkpoint.
        """
        self.lr = lr
        self.gamma = gamma
        self.n_actions = n_actions
        self.l1_size = l1_size
        self.l2_size = l2_size
        self.input_dims = input_dims
        self.chkpt_file = os.path.join(chkpt_dir, 'policy.ckpt')

        # Per-episode buffers, filled by store_transition and cleared by learn.
        self.state_memory = []
        self.action_memory = []
        self.reward_memory = []

        # Use a private graph. With the shared default graph, re-running the
        # notebook cell (or creating a second agent) keeps adding variables:
        # the initializer would reset other agents' weights, and the Saver
        # would save every stale copy under uniquified names, so the
        # checkpoint could not be restored in a fresh process.
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.build_net()
            self.saver = tf.train.Saver()
            init_op = tf.global_variables_initializer()
        self.sess = tf.Session(graph=self.graph)
        self.sess.run(init_op)

    def build_net(self):
        """Create placeholders, the 2-hidden-layer policy MLP, and the train op.

        Must be called with ``self.graph`` as the default graph.
        """
        self.input = tf.placeholder(tf.float32, [None, self.input_dims])
        self.label = tf.placeholder(tf.int32, [None])  # actions taken
        self.G = tf.placeholder(tf.float32, [None])    # normalized returns

        initializer = tf.keras.initializers.VarianceScaling(scale=1.0)

        dense1 = tf.keras.layers.Dense(self.l1_size, activation=tf.nn.relu, kernel_initializer=initializer)
        dense2 = tf.keras.layers.Dense(self.l2_size, activation=tf.nn.relu, kernel_initializer=initializer)
        dense3 = tf.keras.layers.Dense(self.n_actions, activation=None, kernel_initializer=initializer)

        l1 = dense1(self.input)
        l2 = dense2(l1)
        logits = dense3(l2)

        self.actions = tf.nn.softmax(logits)

        # sparse_softmax_cross_entropy(logits, a) == -log pi(a|s); weighting it
        # by the return gives the REINFORCE surrogate loss.
        neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=self.label)
        loss = tf.reduce_mean(neg_log_prob * self.G)
        self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)

    def choose_action(self, state):
        """Sample an action index from the current policy for one state.

        Args:
            state: A single observation, assumed to be a 1-D array-like of
                length ``input_dims``.

        Returns:
            The sampled action index in ``[0, n_actions)``.
        """
        batch = np.asarray(state, dtype=np.float32)[np.newaxis, :]
        probs = self.sess.run(self.actions, feed_dict={self.input: batch})[0]
        # Renormalize in float64 so float32 rounding in the softmax can never
        # trip np.random.choice's "probabilities do not sum to 1" check.
        probs = probs.astype(np.float64)
        probs /= probs.sum()
        return np.random.choice(self.n_actions, p=probs)

    def store_transition(self, state, action, reward):
        """Append one (state, action, reward) step to the episode buffer."""
        self.state_memory.append(state)
        self.action_memory.append(action)
        self.reward_memory.append(reward)

    @staticmethod
    def _discounted_returns(rewards, gamma):
        """Return float32 ``G[t] = sum_k gamma**(k-t) * rewards[k]`` for each t.

        Uses the backward recursion ``G[t] = r[t] + gamma * G[t+1]`` (O(T)).
        """
        rewards = np.asarray(rewards, dtype=np.float64)
        returns = np.zeros_like(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns.astype(np.float32)

    @staticmethod
    def _normalize(returns):
        """Center *returns*; also scale to unit std unless the std is zero."""
        centered = returns - returns.mean()
        std = centered.std()
        return centered / std if std > 0 else centered

    def learn(self):
        """Take one policy-gradient step on the buffered episode, then clear it.

        Does nothing if no transitions have been stored.
        """
        if not self.reward_memory:
            return

        state_mem = np.array(self.state_memory)
        action_mem = np.array(self.action_memory)
        G = self._normalize(self._discounted_returns(self.reward_memory, self.gamma))

        self.sess.run(self.train_op, feed_dict={
            self.input: state_mem,
            self.label: action_mem,
            self.G: G
        })

        self.state_memory, self.action_memory, self.reward_memory = [], [], []

    def save_checkpoint(self):
        """Save all policy variables to ``self.chkpt_file``."""
        # TF1 Saver.save refuses to write if the parent directory is missing.
        os.makedirs(os.path.dirname(self.chkpt_file) or '.', exist_ok=True)
        self.saver.save(self.sess, self.chkpt_file)

    def load_checkpoint(self):
        """Restore policy variables from ``self.chkpt_file``."""
        self.saver.restore(self.sess, self.chkpt_file)

# === TRAINING SCRIPT ===
if __name__ == "__main__":
    env = gym.make('LunarLander-v2')
    agent = PolicyGradientAgent(lr=0.001, gamma=0.99)
    n_episodes = 500
    scores = []

    try:
        for ep in range(n_episodes):
            done = False
            # Newer gym releases return (observation, info) from reset();
            # older ones return the observation alone. Accept both.
            state = env.reset()
            if isinstance(state, tuple):
                state = state[0]
            score = 0

            while not done:
                # NOTE(review): on gym >= 0.26, render() presumably only draws
                # when the env was made with render_mode="human"; otherwise it
                # just logs a warning. Confirm against the installed gym.
                if ep % 10 == 0:  # 👀 Render every 10 episodes
                    env.render()

                action = agent.choose_action(state)
                step_result = env.step(action)
                if len(step_result) == 5:
                    # New API: (obs, reward, terminated, truncated, info).
                    # A time-limit truncation must also end the episode.
                    next_state, reward, terminated, truncated, _ = step_result
                    done = terminated or truncated
                else:
                    # Old API: (obs, reward, done, info).
                    next_state, reward, done, _ = step_result
                agent.store_transition(state, action, reward)
                state = next_state
                score += reward

            agent.learn()
            scores.append(score)

            avg_score = np.mean(scores[-50:])
            print(f"Episode {ep + 1}, Score: {score:.2f}, Avg (50): {avg_score:.2f}")

        agent.save_checkpoint()
    finally:
        # Release the env (and any render window) even if training is
        # interrupted, e.g. by a KeyboardInterrupt in the notebook.
        env.close()

    # 📊 Plot results
    plt.plot(scores)
    plt.title("Policy Gradient on LunarLander (with Visualization)")
    plt.xlabel("Episode")
    plt.ylabel("Score")
    plt.grid(True)
    plt.show()


DependencyNotInstalled: box2D is not installed, run `pip install gym[box2d]`