In [3]:
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Softmax(dim=-1)
        )
    def forward(self, x):
        return self.model(x)
    
class Critic(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim), #output dim should be 1
            #No Activation function here for now
        )
    def forward(self, x):
        return self.model(x)

In [None]:
def discount_and_normalize_rewards(rewards, discount_factor):
    discounted_rewards = []
    G = 0
    for i in reversed(rewards):
        G = G + discount_factor * i
        discounted_rewards.insert(0, G)
    discounted_rewards = torch.tensor(rewards, dtype=torch.float32)
    
    return (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e+9)

#Continous false abhi ke liye, Might try continuous space too after implementing AC
env = gym.make("LunarLander-v3", continuous=False)
acid = env.observation_space.shape[0]
acod = env.action_space.n
actor = Actor(acid, acod)
critic = Critic(acid, 1)
gamma= 0.99
optimizer_actor = torch.optim.Adam(actor.parameters(), lr=0.001)
optimizer_critic = torch.optim.Adam(critic.parameters(), lr=0.001)
n_episodes = 10000
mse = nn.MSELoss()


total_rewards= []
for episode in range(n_episodes):
    # if(episode>1500):
    #     env = gym.make("LunarLander-v3", continuous=False, render_mode='human')
    state, _ = env.reset()
    done = False
    episode_rewards = []
    log_probs = [] #[tensor([-1.5057], grad_fn=<SqueezeBackward1>), tensor([-1.3106], grad_fn=<SqueezeBackward1>),..........(steps count tak)]
    values = []    #[tensor([[-0.1408]], grad_fn=<AddmmBackward0>), tensor([[-0.1340]], grad_fn=<AddmmBackward0>),..........(steps count tak)]
    rewards = []   #[np.float64(-0.9220499698757294), np.float64(-1.5198324317463403),..........(steps count tak)]
    
    while not done:
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        
        # Actor policy
        action_probs = actor(state)
        distribution = torch.distributions.Categorical(action_probs)
        action = distribution.sample()
        log_prob = distribution.log_prob(action)
        
        # Critic value
        value = critic(state)
        
        # Take action
        new_state, reward, done, truncated, _ = env.step(action.item())
        
        # Store data
        log_probs.append(log_prob)
        values.append(value)
        rewards.append(reward)
        
        done = done or truncated
        state = new_state
    
    # print(log_probs) #[tensor([-1.5057], grad_fn=<SqueezeBackward1>), tensor([-1.3106], grad_fn=<SqueezeBackward1>),..........(steps count tak)]
    # print(values)    #[tensor([[-0.1408]], grad_fn=<AddmmBackward0>), tensor([[-0.1340]], grad_fn=<AddmmBackward0>),..........(steps count tak)]
    
    # Convert lists to tensors
    log_probs = torch.stack(log_probs)  #[[n1], [n2], [n3],... [n_steps]]
    values = torch.cat(values)          #[[n1], [n2], [n3],... [n_steps]]
    
    
    # Calculate returns and advantages
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns, dtype=torch.float32).unsqueeze(1)  #[[n1], [n2], [n3],... [n_steps]]
    returns = (returns - returns.mean())/returns.std() + 1e-9
    
    
    # Calculate advantages (returns - values)
    # [[n1], [n2], [n3],..[n_steps]] - [[n1], [n2], [n3],..[n_steps]]
    advantages = returns - values  #A(s,a)=Q(s,a)−V(s)
    
    # print(len(rewards))
    # print(advantages.shape)
    
    # Update critic
    critic_loss = mse(values, returns)
    optimizer_critic.zero_grad()
    critic_loss.backward()
    optimizer_critic.step()
    
    # Update actor
    actor_loss = -(log_probs * advantages.detach()).mean()
    optimizer_actor.zero_grad()
    actor_loss.backward()
    optimizer_actor.step()

    
    total_rewards.append(sum(rewards))
    recent_rewards = total_rewards[-10] if(len(total_rewards)) >= 10 else total_rewards
    avg_rewards = np.mean(recent_rewards)
    print(f"Episdoe : {episode} ::::::: Average reward {avg_rewards}")

Episdoe : 0 ::::::: Average reward -123.3452480654913
Episdoe : 1 ::::::: Average reward -155.26772902268183
Episdoe : 2 ::::::: Average reward -129.122185336934
Episdoe : 3 ::::::: Average reward -136.1068650496244
Episdoe : 4 ::::::: Average reward -136.3351118785033
Episdoe : 5 ::::::: Average reward -135.3331451317536
Episdoe : 6 ::::::: Average reward -179.2496825139498
Episdoe : 7 ::::::: Average reward -181.1117949435897
Episdoe : 8 ::::::: Average reward -192.24380224534664
Episdoe : 9 ::::::: Average reward -123.3452480654913
Episdoe : 10 ::::::: Average reward -187.19020997987238
Episdoe : 11 ::::::: Average reward -76.83109796543832
Episdoe : 12 ::::::: Average reward -157.06090418769554
Episdoe : 13 ::::::: Average reward -137.24809919401895
Episdoe : 14 ::::::: Average reward -130.32331139800513
Episdoe : 15 ::::::: Average reward -442.74890680712696
Episdoe : 16 ::::::: Average reward -194.14658195106915
Episdoe : 17 ::::::: Average reward -281.2998606594022
Episdoe : 18 

KeyboardInterrupt: 

In [12]:
import torch

# Assuming you have a policy network and a critic (value function)
torch.save({
    'policy_state_dict': actor.state_dict(),
    'critic_state_dict': critic.state_dict(),
    'optimizer_state_dict': optimizer_actor.state_dict(),  # Optional
    'optimizer_state_dict_critic': optimizer_critic.state_dict(),  # Optional
}, 'model_checkpoint.pth')

print("Model saved successfully!")


Model saved successfully!
