In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from torch.distributions import Categorical

# --- Hyperparameters -------------------------------------------------------
GAMMA = 0.99  # discount factor; default `gamma` of compute_gae
GAE_LAMBDA = 0.95  # GAE lambda; default `lam` of compute_gae
EPSILON = 0.2  # PPO clipping range: the ratio is clamped to [1 - EPSILON, 1 + EPSILON]; usually 0.1 or 0.2
ACTOR_LR = 3e-4  # Adam learning rate for the actor (policy) network
CRITIC_LR = 1e-3  # Adam learning rate for the critic (value) network
BATCH_SIZE = 64  # mini-batch size used inside the PPO update
EPOCHS = 10  # number of passes over each episode's data per PPO update
NUM_EPISODES = 1000  # number of episodes the training loop runs

# --- Environment -----------------------------------------------------------
# continuous=False is requested so that the action space exposes `.n` below.
env = gym.make("LunarLander-v3", continuous=False)
state_dim = env.observation_space.shape[0]  # length of the observation vector
action_dim = env.action_space.n  # number of discrete actions

#Policy network as usual
class Actor(nn.Module):
    """Policy network mapping a state to a probability for each action.

    Two hidden ReLU layers of width 128 followed by a softmax over the
    last dimension, so every output row sums to 1.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 128
        # Layers are created in forward order so parameter initialisation
        # and the ``model.<idx>`` state-dict keys follow the same sequence.
        stack = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
            nn.Softmax(dim=-1),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, state):
        """Return the action probabilities for ``state``."""
        action_probabilities = self.model(state)
        return action_probabilities

#value function
class Critic(nn.Module):
    """State-value network mapping a state to a single scalar estimate.

    Two hidden ReLU layers of width 128 followed by a linear output of
    size 1 (no activation on the output).
    """

    def __init__(self, state_dim):
        super().__init__()
        hidden_units = 128
        # Layers are created in forward order so parameter initialisation
        # and the ``model.<idx>`` state-dict keys follow the same sequence.
        stack = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, state):
        """Return the value estimate for ``state`` (last dimension is 1)."""
        state_value = self.model(state)
        return state_value

# Policy and value networks are separate modules, each with its own Adam
# optimizer and learning rate, so they are updated independently.
actor = Actor(state_dim, action_dim)
critic = Critic(state_dim)
optimizer_actor = optim.Adam(actor.parameters(), lr=ACTOR_LR)
optimizer_critic = optim.Adam(critic.parameters(), lr=CRITIC_LR)


#Computing GAE
def compute_gae(rewards, values, gamma=GAMMA, lam=GAE_LAMBDA):
    """Compute GAE advantages and critic targets for one trajectory.

    A longer explanation (with snippets) lives in notes.ipynb.

    The advantages follow the backward recursion

        delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)    (one-step TD error)
        A_t     = delta_t + gamma * lam * A_{t+1}

    where ``lam`` is the hyperparameter that trades bias against variance.

    Args:
        rewards: Sequence of T per-step rewards.
        values: Sequence of T + 1 value estimates; ``values[t]`` is
            V(s_t) and the last entry is the bootstrap value of the state
            reached after the final step.
        gamma: Discount factor.
        lam: GAE lambda.

    Returns:
        Tuple ``(advantages, returns)`` of float32 tensors of shape (T,).
        ``advantages`` is standardised (zero mean, unit std) when T > 1.
        ``returns`` is the un-normalised advantage plus ``values[:-1]``,
        i.e. the regression target for the critic.
    """
    num_steps = len(rewards)
    # Fill by index instead of list.insert(0, ...), which is O(T) per call.
    raw_advantages = [0.0] * num_steps
    next_advantage = 0.0
    for t in reversed(range(num_steps)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        next_advantage = delta + gamma * lam * next_advantage
        raw_advantages[t] = next_advantage
    raw_advantages = torch.tensor(raw_advantages, dtype=torch.float32)

    # The value targets must be built from the RAW advantages
    # (returns = advantages + values). Using the standardised advantages
    # here would rescale and shift the critic's regression target.
    returns = raw_advantages + torch.tensor(values[:-1], dtype=torch.float32)

    # Standardise only for the policy loss. With fewer than two samples the
    # (unbiased) std is NaN, so leave the advantages untouched in that case.
    if num_steps > 1:
        advantages = (raw_advantages - raw_advantages.mean()) / (
            raw_advantages.std() + 1e-8
        )
    else:
        advantages = raw_advantages
    return advantages, returns

# Training loop: one PPO update per collected episode.
episode_rewards = []  # total (undiscounted) reward of every finished episode

for episode in range(NUM_EPISODES):
    state, _ = env.reset()
    done = False
    episode_reward = 0

    # Per-step rollout buffers for the current episode.
    log_probs = []
    values = []
    rewards = []
    states = []
    actions = []

    # Collect trajectory
    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        # Rollout outputs are only used as constants of the "old" policy,
        # so there is no need to build an autograd graph here.
        with torch.no_grad():
            action_probs = actor(state_tensor)
            value = critic(state_tensor).squeeze()

        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated

        states.append(state_tensor)
        actions.append(action)
        log_probs.append(log_prob)
        values.append(value.item())
        rewards.append(reward)

        state = next_state
        episode_reward += reward

    # Append the value of the state after the last step for the GAE
    # recursion. A terminal state has no future reward, so its value is 0;
    # only a truncated (time-limited) episode is bootstrapped from the critic.
    if terminated:
        final_value = 0.0
    else:
        with torch.no_grad():
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            final_value = critic(state_tensor).squeeze().item()
    values.append(final_value)

    # Compute advantages and returns
    advantages, returns = compute_gae(rewards, values)

    # Convert lists to tensors for batch processing (detached: no gradient
    # may flow into the data gathered under the old policy).
    # Each stored action / log-prob has shape (1,), so concatenating yields
    # flat (n_steps,) tensors. Stacking would give (n_steps, 1), which makes
    # Categorical.log_prob broadcast to an (n, n) matrix and corrupts the
    # ratio and the surrogate loss.
    old_states = torch.cat(states).detach()          # (n_steps, state_dim)
    old_actions = torch.cat(actions).detach()        # (n_steps,)
    old_log_probs = torch.cat(log_probs).detach()    # (n_steps,)

    # PPO update
    for _ in range(EPOCHS):
        # Process in mini-batches, reshuffled on every epoch
        indices = np.arange(len(old_states))
        np.random.shuffle(indices)

        for start_idx in range(0, len(old_states), BATCH_SIZE):
            end_idx = min(start_idx + BATCH_SIZE, len(old_states))
            batch_indices = indices[start_idx:end_idx]

            batch_states = old_states[batch_indices]
            batch_actions = old_actions[batch_indices]
            batch_log_probs = old_log_probs[batch_indices]
            batch_advantages = advantages[batch_indices]
            batch_returns = returns[batch_indices]

            # Actor update
            new_action_probs = actor(batch_states)
            new_dist = Categorical(new_action_probs)
            new_log_probs = new_dist.log_prob(batch_actions)  # (batch,)

            # Calculate policy ratio and clipped surrogate loss
            ratio = torch.exp(new_log_probs - batch_log_probs)
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) * batch_advantages

            # Negative because we're minimizing, but we want to maximize the objective
            actor_loss = -torch.min(surr1, surr2).mean()

            # Update actor
            optimizer_actor.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), 0.5)
            optimizer_actor.step()

            # Critic update - completely separate calculation.
            # squeeze(-1) drops only the trailing size-1 dim, so a mini-batch
            # of one sample keeps shape (1,) and matches batch_returns.
            critic_values = critic(batch_states).squeeze(-1)
            critic_loss = nn.MSELoss()(critic_values, batch_returns)

            # Update critic
            optimizer_critic.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), 0.5)
            optimizer_critic.step()

    episode_rewards.append(episode_reward)

    # Progress report every 10 episodes.
    if (episode + 1) % 10 == 0:
        avg_reward = np.mean(episode_rewards[-10:])
        print(f"Episode {episode+1}/{NUM_EPISODES}, Avg Reward (last 10): {avg_reward:.2f}")

env.close()

# Getting bad results ;-;