In [6]:
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
# import torch.utils.data.dataloader
import config

#Policy Network
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, HIDDEN_DIM = config.HIDDEN_DIM):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
            nn.ReLU()
        )
        self.actor = nn.Linear(HIDDEN_DIM, action_dim)
        self.critic = nn.Linear(HIDDEN_DIM, 1)
        
    def forward(self, x):
        x = self.model(x)
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value
    
    def act(self, x):
        state = torch.FloatTensor(x)
        logits, value = self.forward(state)
        probs = F.softmax(logits, dim=-1) 
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), value.item()
    

class PPOTrainer():
    def __init__(self, env):
        self.env = env
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        
        self.policy = ActorCritic(self.state_dim, self.action_dim, config.HIDDEN_DIM)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=config.LR)
        self.ep_rewards = []
        self.best_avg = -np.inf
    
    def compute_advantage(self, rewards, next_values, values, dones):
        deltas = rewards + config.GAMMA * next_values * (1 - dones) - values
        advantages = np.zeros_like(deltas)
        last_advantage = 0
        for t in reversed(range(len(deltas))):
            advantages[t] = deltas[t] + config.GAE_LAMBDA * config.GAMMA * (1-dones[t]) * last_advantage
            last_advantage = advantages[t]
        return (advantages - advantages.mean())/(advantages.std() + 1e-8)
    
    # log_probs, returns, advantages, states, actions
    def update(self, old_log_probs, returns, advantages, states, old_actions):
        old_log_probs = torch.FloatTensor(old_log_probs)
        old_actions = torch.LongTensor(old_actions)
        returns = torch.FloatTensor(returns)
        advantages = torch.FloatTensor(advantages)
        states = torch.FloatTensor(np.array(states))
        
        dataset = torch.utils.data.TensorDataset(old_log_probs, old_actions,returns, advantages, states)
        loader = torch.utils.data.DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True)
        
        for _ in range(config.NUM_EPOCHS):
            for batch in loader:
                old_lp, a, ret, adv, s = batch
                
                logits, values = self.policy(s)
                probs = F.softmax(logits, dim=-1)
                dist = Categorical(probs)
                new_log_prob = dist.log_prob(a)
                entropy = dist.entropy().mean()
                
                ratio = (new_log_prob - old_lp).exp()
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1 - config.CLIP_EPS, 1 + config.CLIP_EPS) * adv
                policy_loss = -torch.min(surr1, surr2).mean()
                
                value_loss = F.mse_loss(values.squeeze(), ret)
                
                loss = policy_loss + value_loss * config.VALUE_COEF - entropy * config.ENTROPY_COEF
                
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
                self.optimizer.step()
        
    
    def train(self):
        state, _ = self.env.reset()
        episode_reward = 0
        
        for episode in range(config.MAX_EPISODES):
            states = [] 
            rewards = [] 
            actions = [] 
            dones = []
            values = []
            log_probs = []
            
            for step in range(config.NUM_STEPS):
                action, log_prob, value = self.policy.act(state)
                next_state, reward, done, truncated, _ = self.env.step(action)
                done = done or truncated
                
                states.append(state)
                rewards.append(reward)
                actions.append(action)
                dones.append(done)
                values.append(value)
                log_probs.append(log_prob)
                
                state = next_state
                episode_reward += reward
                
                if done:
                    state, _ = self.env.reset()
                    self.ep_rewards.append(episode_reward)
                    episode_reward = 0 
                    
            next_states = torch.FloatTensor(np.array([s for s in states]))
            with torch.no_grad():
                _, next_values = self.policy(next_states)
            next_values = next_values.cpu().numpy().flatten()
            next_values = np.append(next_values[1:], 0)
            next_values[dones] = 0
            
            #Cpnverting to numpy arrays
            rewards = np.array(rewards)
            values = np.array(values)
            dones = np.array(dones).astype(np.float32)
            
            advantages = self.compute_advantage(rewards, next_values, values, dones)
            returns = advantages + values
            
            self.update(log_probs, returns, advantages, states, actions)
            
            avg_reward = np.mean(self.ep_rewards[-100:]) if len(self.ep_rewards) >= 100 else np.mean(self.ep_rewards)
            if avg_reward > self.best_avg:
                self.best_avg = avg_reward
            print(f"Episode {episode+1}, Reward: {self.ep_rewards[-1]:.4f}, Avg Reward: {avg_reward:.2f}")
            
        torch.save(self.policy.state_dict(), 'ppo_model.pt')
            

if __name__ == "__main__":
    env = gym.make("LunarLander-v3", continuous=False)
    trainer = PPOTrainer(env)
    trainer.train()
    env.close()

Episode 144, Reward: 239.7318, Avg Reward: 123.68
Episode 145, Reward: 159.3239, Avg Reward: 124.49
Episode 146, Reward: 110.4353, Avg Reward: 128.46
Episode 147, Reward: 129.7034, Avg Reward: 129.16
Episode 148, Reward: 217.7232, Avg Reward: 129.26
Episode 149, Reward: 209.7757, Avg Reward: 134.41
Episode 150, Reward: 130.7336, Avg Reward: 136.57
Episode 151, Reward: 228.5597, Avg Reward: 141.52
Episode 152, Reward: 118.6351, Avg Reward: 141.56
Episode 153, Reward: 37.3959, Avg Reward: 141.28
Episode 154, Reward: 158.4690, Avg Reward: 143.60
Episode 155, Reward: 156.8718, Avg Reward: 144.88
Episode 156, Reward: 25.8118, Avg Reward: 151.30
Episode 157, Reward: 253.0285, Avg Reward: 150.69
Episode 158, Reward: 131.3257, Avg Reward: 151.98
Episode 159, Reward: 225.9798, Avg Reward: 152.36
Episode 160, Reward: 239.4316, Avg Reward: 147.90
Episode 161, Reward: 159.2472, Avg Reward: 146.76
Episode 162, Reward: 4.4268, Avg Reward: 142.87
Episode 163, Reward: 135.3428, Avg Reward: 142.60
Epis

In [7]:
import os
def load_model(policy, filename='ppo_model.pt'):
    """
    Load a trained policy network
    """
    if os.path.exists(filename):
        policy.load_state_dict(torch.load(filename))
        print(f"Model loaded from {filename}")
        return True
    else:
        print(f"No model found at {filename}")
        return False

In [11]:
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import time
import os

class MockConfig:
    NUM_STEPS = 2048          # Number of steps per environment per update
    BATCH_SIZE = 64           # Mini-batch size for updates
    NUM_EPOCHS = 10           # Number of optimization epochs per update
    GAMMA = 0.99              # Discount factor
    GAE_LAMBDA = 0.95         # GAE parameter
    CLIP_EPS = 0.2            # PPO clip parameter
    LR = 3e-4                 # Learning rate
    HIDDEN_DIM = 256          # Network hidden layer size
    ENTROPY_COEF = 0.01       # Entropy coefficient
    VALUE_COEF = 0.5          # Value loss coefficient
    MAX_EPISODES = 300        # Maximum training episodes

config = MockConfig()

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, HIDDEN_DIM = config.HIDDEN_DIM):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
            nn.ReLU()
        )
        self.actor = nn.Linear(HIDDEN_DIM, action_dim)
        self.critic = nn.Linear(HIDDEN_DIM, 1) # Critic not used in inference, but part of the model

    def forward(self, x):
        x = self.model(x)
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value

    
    # To use deterministic action (argmax) or sampling based on preference
    def act_inference(self, x, deterministic=True):
        state = torch.FloatTensor(x).unsqueeze(0) # To add batch dimension
        with torch.no_grad(): 
            logits, _ = self.forward(state)
            probs = F.softmax(logits, dim=-1)
            if deterministic:
                action = torch.argmax(probs, dim=-1).item()
            else:
                dist = Categorical(probs)
                action = dist.sample().item()
        return action


def load_model(policy, filename='ppo_model.pt'):
    if os.path.exists(filename):
        try:
            state_dict = torch.load(filename)
            # Load it into the policy network
            policy.load_state_dict(state_dict)
            # Setting the model to evaluation mode
            policy.eval()
            print(f"Model state_dict loaded successfully from {filename}")
            return True
        except Exception as e:
            print(f"Error loading model from {filename}: {e}")
            return False
    else:
        print(f"Error: No model found at {filename}")
        return False

if __name__ == "__main__":
    env_name = "LunarLander-v3" 
    env = gym.make(env_name, continuous=False, render_mode="human")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Instantiating the policy network
    policy = ActorCritic(state_dim, action_dim, config.HIDDEN_DIM)

    # Load the trained weights
    model_loaded = load_model(policy, 'ppo_model.pt')

    if model_loaded:
        num_episodes = 5 # Number of episodes to run for testing
        for episode in range(num_episodes):
            state, _ = env.reset()
            done = False
            truncated = False
            total_reward = 0
            step_count = 0

            while not done and not truncated:
                # Choose deterministic=True for the 'best' action,
                # or deterministic=False to sample like during training.
                action = policy.act_inference(state, deterministic=True)
                next_state, reward, done, truncated, info = env.step(action)
                state = next_state
                total_reward += reward
                step_count += 1
            print(f"Episode {episode + 1}: Total Reward = {total_reward:.2f}, Steps = {step_count}")

    else:
        print("Could not load the model. Exiting inference.")

    env.close()
    print("Inference finished.")

Model state_dict loaded successfully from ppo_model.pt
Episode 1: Total Reward = 272.81, Steps = 255
Episode 2: Total Reward = 296.44, Steps = 301
Episode 3: Total Reward = 277.91, Steps = 239
Episode 4: Total Reward = 268.14, Steps = 270
Episode 5: Total Reward = 266.59, Steps = 243
Inference finished.


In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

#Parameters
CLIP_EPS = 0.2              #PPO clip parameter
NUM_EPOCHS = 10             #number of times the ppo will be updated
HIDDEN_DIM = 256            #N/w Hidden layer sizeeee
MAX_EPISODES = 500          #Max Episode to run
MAX_TIMESTEPS = 1500        #Max No of timesteps in 1 Ep
LR = 3e-4                   #Learning rate for the N/W
GAE_LAMBDA = 0.95           #GAE Parameter
GAMMA = 0.99                #Discount Factoor
BATCH_SIZE = 64             #Batch size for dataloadear
ENTROPY_COEF = 0.01         # Entropy coefficient
VALUE_COEF = 0.5            # Value loss coefficient

#policy-n/w
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim): # state_dim, action_dim -> 8,4 (For lunarlander)
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
            nn.ReLU()
        )
        self.actor = nn.Linear(HIDDEN_DIM, action_dim)
        self.critic = nn.Linear(HIDDEN_DIM, 1)
        
    def forward(self, x):
        x = self.model(x)
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value
    
    def act(self, x):
        state = torch.FloatTensor(x)
        logits, value = self.forward(state)
        probs = F.softmax(logits, dim=-1) 
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), value.item()

#Generalized advantage estimation
def compute_gae(rewards, next_values, dones, values):
    #δₜ = rₜ + γ * V(s_{t+1}) - V(sₜ)
    deltas = rewards + GAMMA * next_values * (1- dones) - values
    advantages = np.zeros_like(deltas)
    lastadv = 0
    for t in reversed(range(len(deltas))):
        #A _t = δₜ + γ * λ * (1-dones) * A_t+1 
        advantages[t] = deltas[t] + GAMMA * GAE_LAMBDA * (1 - dones[t]) * lastadv
        lastadv = advantages[t]
    return (advantages - advantages.mean() ) / (advantages.std() + 1e-8)
    

def update(states, returns, advantages, old_log_probs, actions, policy, optimizer):
    states = torch.FloatTensor(np.array(states))
    actions = torch.LongTensor(actions)
    returns = torch.FloatTensor(returns)
    advantages = torch.FloatTensor(advantages)
    old_log_probs = torch.FloatTensor(old_log_probs)
    
    dataset = torch.utils.data.TensorDataset(states, actions, returns, advantages, old_log_probs)
    loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    #Update
    for _ in range(NUM_EPOCHS):
        for btc in loader:
            s, a, ret, adv, old_pb = btc
            
            logits, values = policy(s) #latest new val, came from acting in env with old states
            probs = F.softmax(logits, dim=-1)
            dist = Categorical(probs)
            new_log_probs = dist.log_prob(a) 
            entropy = dist.entropy().mean()
            
            ratio = (new_log_probs - old_pb).exp()
            surr1 = ratio*adv
            surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS)*adv
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values.squeeze(), ret)
            loss = policy_loss + value_loss*VALUE_COEF - entropy*ENTROPY_COEF 
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
            optimizer.step()
            
            
            
best_avg = -np.inf
ep_rewards = []
def train():
    global best_avg
    env = gym.make("LunarLander-v3", continuous=False)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    policy = ActorCritic(state_dim, action_dim)
    optimizer = optim.Adam(policy.parameters(), lr=LR)
    state, _ = env.reset()
    episode_reward = 0
    
    for episode in range(MAX_EPISODES):
        actions = []
        dones = []
        rewards = []
        log_probs = []
        states = []
        values = []
        
        for step in range(MAX_TIMESTEPS):
            state = torch.FloatTensor(state)
            action, log_prob, value = policy.act(state)
            next_state, reward, done , truncated, _ = env.step(action)
            done = done or truncated
            
            actions.append(action)
            log_probs.append(log_prob)
            values.append(value)
            dones.append(done)
            rewards.append(reward)
            states.append(state)
            
            state = next_state
            episode_reward += reward
            
            if done:
                state, _ = env.reset()
                ep_rewards.append(episode_reward)
                episode_reward = 0
        
        
        next_states = torch.FloatTensor(np.array([s for s in states]))
        with torch.no_grad():
            _, next_values = policy(next_states)
        next_values = np.append(next_values[1:], 0)
        next_values[dones] = 0
        
        #Converting to numoy arrays
        rewards = np.array(rewards)
        values = np.array(values)
        dones = np.array(dones).astype(np.float32)
        
        advantages = compute_gae(rewards, next_values, dones, values)
        returns = advantages + values
        
        #To update policy
        update(states, returns, advantages, log_probs, actions, policy, optimizer)
        
        avg_reward = np.mean(ep_rewards[-100:]) if len(ep_rewards) >= 100 else np.mean(ep_rewards)
        if avg_reward > best_avg:
            best_avg = avg_reward
        print(f"Episode {episode+1}, Reward: {ep_rewards[-1]}, Avg Reward: {avg_reward:.2f}")
        
    env.close()
        
train()


In [None]:
# Claude
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Actor-Critic Networks
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()
        
        # Actor network (policy)
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )
        
        # Critic network (value function)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )
        
    def act(self, state):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        action_logprob = dist.log_prob(action)
        return action.detach(), action_logprob.detach()
    
    def evaluate(self, state, action):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)
        
        return action_logprobs, state_values, dist_entropy

# PPO Agent
class PPO:
    def __init__(
        self,
        state_dim,              # 8
        action_dim,             # 1
        lr_actor=0.0003,
        lr_critic=0.001,
        gamma=0.99,
        K_epochs=4,
        eps_clip=0.2,
        value_coef=0.5,
        entropy_coef=0.01
    ):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])
        
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())       #copying the params of policy instance
        self.MseLoss = nn.MSELoss()
    
    def select_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action, action_logprob = self.policy_old.act(state)
        
        return action.item(), action_logprob.item()
    
    def update(self, memory):
        # Monte Carlo estimate of returns
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
        
        # Normalizing the rewards
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        
        # Convert lists to tensors
        old_states = torch.stack(memory.states).detach().to(device)
        old_actions = torch.tensor(memory.actions, dtype=torch.int64).detach().to(device)
        old_logprobs = torch.tensor(memory.logprobs, dtype=torch.float32).detach().to(device)
        
        # Optimize policy for K epochs
        for _ in range(self.K_epochs):
            # Evaluating old actions and values
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            
            # Finding the ratio (π_θ / π_θ_old)
            ratios = torch.exp(logprobs - old_logprobs.detach())
            
            # Finding Surrogate Loss
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
            
            # Final loss of clipped objective PPO
            value_loss = self.MseLoss(state_values, rewards)
            policy_loss = -torch.min(surr1, surr2).mean()
            entropy_loss = -dist_entropy.mean()
            
            loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
            
            #Take gradient step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        
        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

# Memory for storing transitions
class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
    
    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

# Training loop
def train():
    env_name = "LunarLander-v3"
    env = gym.make(env_name, continuous = False)
    
    state_dim = env.observation_space.shape[0] # 8
    action_dim = env.action_space.n # 1
    
    max_episodes = 5000        # max training episodes
    max_timesteps = 1500       # max timesteps in one episode
    update_timestep = 2000     # update policy every n timesteps
    save_interval = 500        # to save model every n episodes
    
    # PPO hyperparameters
    K_epochs = 8               # update policy for K epochs
    eps_clip = 0.2             # clip parameter for PPO
    gamma = 0.98               # discount factor
    lr_actor = 0.0003          # learning rate for actor
    lr_critic = 0.001          # learning rate for critic 
    
    # Initialize PPO agent
    ppo = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip)
    memory = Memory()
    
    # Logging variables
    running_reward = 0
    avg_length = 0
    timestep = 0
    
    # Training loop
    for i_episode in range(1, max_episodes+1):
        state, _ = env.reset()
        
        for t in range(max_timesteps):
            timestep += 1
            
            # Select action from policy
            action, logprob = ppo.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            # Save in memory
            memory.states.append(torch.FloatTensor(state).to(device))
            memory.actions.append(action)
            memory.logprobs.append(logprob)
            memory.rewards.append(reward)
            memory.is_terminals.append(done)
            
            state = next_state
            
            # Update if its time
            if timestep % update_timestep == 0:
                ppo.update(memory)
                memory.clear_memory()
                timestep = 0
            
            running_reward += reward
            
            if done:
                break
                
        avg_length += t
        
        # Print average reward and length every 10 episodes
        if i_episode % 10 == 0:
            avg_length = avg_length/10
            running_reward = running_reward/10
            
            print(f'Episode {i_episode} \t Avg length: {avg_length:.2f} \t Reward: {running_reward:.2f}')
            running_reward = 0
            avg_length = 0
            
        # Save model
        if i_episode % save_interval == 0:
            torch.save(ppo.policy.state_dict(), f'./PPO_LunarLander_{i_episode}.pth')
            
    env.close()

if __name__ == '__main__':
    train()

In [None]:
#Getting bad results w this implementation, Don't know why -> at least for nowwwww

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from torch.distributions import Categorical

GAMMA = 0.99
GAE_LAMBDA = 0.95  
EPSILON = 0.2  # clipping range for PPO -> usually 0.1 or 0.2 rhta
ACTOR_LR = 3e-4
CRITIC_LR = 3e-4
BATCH_SIZE = 64
EPOCHS = 4
NUM_EPISODES = 1000
env = gym.make("LunarLander-v3", continuous=False)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

#Policy network as usual
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 128),  
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )
    
    def forward(self, state):
        return self.model(state)

#value function
class Critic(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
    
    def forward(self, state):
        return self.model(state)

actor = Actor(state_dim, action_dim)
critic = Critic(state_dim)
optimizer_actor = optim.Adam(actor.parameters(), lr=ACTOR_LR)
optimizer_critic = optim.Adam(critic.parameters(), lr=CRITIC_LR)


#Computing GAE
def compute_gae(rewards, values, gamma=GAMMA, lam=GAE_LAMBDA):
    """
    Iska explanation notes.ipynb me snippets dala hai
    and
    Here lambda is a hyperparameter, It can be multiplied with : 
    gamma and last_advantage for BIAS & VARIANCE Trade off
    """
    advantages = []
    last_advantage = 0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t+1] - values[t]        #also known as one step TD Error : δt^V​=rt​+γV(st+1​)−V(st​)
        last_advantage = delta + gamma * lam * last_advantage       #A(t)GAE(γ,λ)​ = δt^V ​+ γ * λ * A(t+1)GAE(γ,λ)
        advantages.insert(0, last_advantage)
    advantages = torch.tensor(advantages, dtype=torch.float32)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    
    # Calculate : advantages = returns - values
    # Calculate : returns = advantages + values
    returns = advantages + torch.tensor(values[:-1], dtype=torch.float32)
    return advantages, returns

# Training loop
episode_rewards = []

for episode in range(NUM_EPISODES):
    if(episode > 900):
        env = gym.make("LunarLander-v3", continuous=False, render_mode='human')
    state, _ = env.reset()
    done = False
    episode_reward = 0
    
    log_probs = []
    values = []
    rewards = []
    states = []
    actions = []
    
    # Collect trajectory
    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        action_probs = actor(state_tensor)
        value = critic(state_tensor).squeeze()
        
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        
        states.append(state_tensor)
        actions.append(action)
        log_probs.append(log_prob)
        values.append(value.item()) 
        rewards.append(reward)
        
        state = next_state
        episode_reward += reward
    
    # appending final state value for GAE calculation
    with torch.no_grad():
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        final_value = critic(state_tensor).squeeze()
    values.append(final_value.item())
    
    # Compute advantages and returns
    advantages, returns = compute_gae(rewards, values)
    
    # Convert lists to tensors for batch processing
    # Adding detach to ensure no gradient flow
    old_states = torch.cat(states).detach()         #torch.Size([104, 8])
    old_actions = torch.stack(actions).detach()     #torch.Size([104, 1]) = [[n1], [n2], [n3],... [n_steps]]
    old_log_probs = torch.stack(log_probs).detach() #torch.Size([104, 1]) = [[n1], [n2], [n3],... [n_steps]]
    # print(old_actions)
   
    # PPO update
    for _ in range(EPOCHS):
        # Process in mini-batches
        indices = np.arange(len(old_states))
        np.random.shuffle(indices)
        
        for start_idx in range(0, len(old_states), BATCH_SIZE):
            end_idx = min(start_idx + BATCH_SIZE, len(old_states))
            batch_indices = indices[start_idx:end_idx]
            
            batch_states = old_states[batch_indices]
            batch_actions = old_actions[batch_indices]
            batch_log_probs = old_log_probs[batch_indices]
            batch_advantages = advantages[batch_indices]
            batch_returns = returns[batch_indices]
            
            # Actor update
            new_action_probs = actor(batch_states)
            new_dist = Categorical(new_action_probs)
            new_log_probs = new_dist.log_prob(batch_actions)
            
            # Calculate policy ratio and surrogate loss
            ratio = torch.exp(new_log_probs - batch_log_probs)
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) * batch_advantages
            
            # Negative because we're minimizing, but we want to maximize the objective
            entropy = new_dist.entropy().mean()
            actor_loss = -torch.min(surr1, surr2).mean() - 0.01 * entropy
            
            # Update actor
            optimizer_actor.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), 1)
            optimizer_actor.step()
            
            # Critic update - completely separate calculation
            critic_values = critic(batch_states).squeeze()
            critic_loss = nn.MSELoss()(critic_values, batch_returns)
            
            # Update critic
            optimizer_critic.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), 1)
            optimizer_critic.step()
            
    episode_rewards.append(episode_reward)
    if (episode + 1) % 10 == 0:
        avg_reward = np.mean(episode_rewards[-10:])
        print(f"Episode {episode+1}/{NUM_EPISODES}, Avg Reward (last 10): {avg_reward:.2f}")
env.close()

# Getting bad results ;-;