In [None]:
import torch
from torch import nn
from torch.distributions.categorical import Categorical
import gymnasium as gym 
from torch.optim import Adam
import numpy as np

In [None]:
class Policy(torch.nn.Module):
    """Two-layer MLP mapping a state to unnormalized action scores (logits).

    There is deliberately no final softmax. The output is meant to be wrapped
    in ``Categorical(logits=...)``, which performs the normalization itself.
    The previous ``nn.Softmax(output_size)`` passed the number of actions as
    the softmax *dimension*, which raises for typical inputs. Its output was
    also being re-normalized as if it were logits.

    Args:
        input_size: Length of the last dimension of the state tensor.
        output_size: Number of discrete actions (logits produced per state).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        module = [
            nn.Linear(input_size, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, output_size),
        ]
        # Layer indices (mlp.0 / mlp.2) match the old layout, so saved
        # state_dicts still load: the removed Softmax had no parameters.
        self.mlp = nn.Sequential(*module)

    def forward(self, state):
        """Return action logits of shape ``(..., output_size)`` for ``state``."""
        return self.mlp(state)
    
class Value(torch.nn.Module):
    """State-value network: a two-layer MLP with a 64-unit ReLU hidden layer.

    Args:
        input_size: Length of the last dimension of the input state tensor.
        output_size: Number of outputs per state (e.g. 1 for a scalar estimate).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden = nn.Linear(input_size, 64)
        head = nn.Linear(64, output_size)
        # Sequential layout mlp.0 / mlp.1 / mlp.2 = hidden / ReLU / head.
        self.mlp = nn.Sequential(hidden, nn.ReLU(inplace=True), head)

    def forward(self, state):
        """Return the network output for ``state``, shape ``(..., output_size)``."""
        output = self.mlp(state)
        return output
    

def train(env_name="CartPole-v0", batch_size=10000, epoch=50, lr=1e-2):
    """Train a policy with vanilla policy gradient and a learned value baseline.

    Each epoch does the following:

    1. Collect at least ``batch_size`` environment steps with the current
       policy.
    2. Weight every step's action log-probability by its reward-to-go minus
       the value-network baseline.
    3. Take one Adam step on the policy and one on the value network (MSE
       against reward-to-go).

    Args:
        env_name: Gymnasium environment id. The action space must be Discrete.
        batch_size: Minimum number of environment steps collected per epoch.
        epoch: Number of epochs, i.e. policy/value updates.
        lr: Learning rate shared by both Adam optimizers.

    Returns:
        The trained ``Policy`` network.

    Raises:
        ValueError: If the environment's action space is not Discrete.
    """
    env = gym.make(env_name)
    try:
        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("train() only supports discrete action spaces")

        obs_size = env.observation_space.shape[0]
        # A Discrete space has shape (); the number of actions is `.n`.
        n_actions = env.action_space.n

        policy = Policy(obs_size, n_actions)
        value = Value(obs_size, 1)

        policy_optimizer = Adam(params=policy.parameters(), lr=lr)
        value_optimizer = Adam(params=value.parameters(), lr=lr)

        def get_policy(state):
            # The policy emits unnormalized scores; Categorical normalizes them.
            return Categorical(logits=policy(state))

        def get_action(state):
            # Sampling needs no autograd graph; env observations are numpy.
            with torch.no_grad():
                obs = torch.as_tensor(state, dtype=torch.float32)
                return get_policy(obs).sample().item()

        def predict_value(states):
            # Value outputs (N, 1); flatten to (N,) so it lines up with the
            # per-step weights instead of broadcasting to (N, N).
            return value(states).squeeze(-1)

        # Policy loss: -mean(logprob * (reward_to_go - baseline)), where the
        # baseline is the value network's prediction.
        def compute_loss(acts, states, weights):
            logp = get_policy(states).log_prob(acts)
            # Detach so the policy loss does not push gradients into the value net.
            advantage = weights - predict_value(states).detach()
            return -(logp * advantage).mean()

        def compute_loss_value(states, weights):
            return torch.nn.functional.mse_loss(predict_value(states), weights)

        # Reward-to-go: the sum of all rewards from a timestep to the end of
        # the episode.
        def reward_to_go(ep_rews):
            rtg = np.zeros(len(ep_rews), dtype=np.float32)
            running = 0.0
            for i in reversed(range(len(ep_rews))):
                running += ep_rews[i]
                rtg[i] = running
            return rtg

        def epoch_train():
            """Collect one batch of experience, update both networks, return stats."""
            batch_obs = []
            batch_act = []
            batch_weight = []
            batch_return = []
            batch_len = []
            ep_reward = []

            state, _ = env.reset()
            while True:
                batch_obs.append(state.copy())
                act = get_action(state)
                batch_act.append(act)
                state, reward, terminated, truncated, _ = env.step(act)
                ep_reward.append(reward)

                if terminated or truncated:
                    batch_return.append(sum(ep_reward))
                    batch_len.append(len(ep_reward))
                    batch_weight.extend(reward_to_go(ep_reward))

                    if len(batch_obs) > batch_size:
                        break
                    # Start a fresh episode; otherwise we keep stepping a
                    # finished env and rewards leak across episodes.
                    state, _ = env.reset()
                    ep_reward = []

            obs_t = torch.as_tensor(np.array(batch_obs), dtype=torch.float32)
            act_t = torch.as_tensor(batch_act, dtype=torch.int64)
            weight_t = torch.as_tensor(np.array(batch_weight), dtype=torch.float32)

            # Policy update. Gradients must be cleared since they accumulate.
            policy_optimizer.zero_grad()
            loss = compute_loss(act_t, obs_t, weight_t)
            loss.backward()
            policy_optimizer.step()

            # Value (baseline) update: regress onto reward-to-go.
            value_optimizer.zero_grad()
            loss_value = compute_loss_value(obs_t, weight_t)
            loss_value.backward()
            value_optimizer.step()

            return loss.item(), loss_value.item(), batch_return, batch_len

        for i in range(epoch):
            loss, loss_value, batch_return, batch_len = epoch_train()
            print(
                f"epoch: {i:3d} \t policy loss: {loss:.3f} \t value loss: {loss_value:.3f}"
                f" \t return: {np.mean(batch_return):.3f} \t ep_len: {np.mean(batch_len):.3f}"
            )

        return policy
    finally:
        env.close()
