In [None]:
import torch
from torch import nn
import numpy as np
import gymnasium as gym
from torch.distributions.categorical import Categorical
from torch.optim import Adam

In [None]:
class RolloutBuffer():
    """Plain container of per-rollout lists.

    Every attribute starts out as its own empty list:
    ``rewards``, ``states``, ``logprobas``, ``actions``,
    ``rewards_togo``, ``ep_rewards`` and ``ep_len``.
    """

    def __init__(self):
        # Order matters only for introspection (``vars(buffer)``); each
        # field gets a distinct list object so appends never alias.
        field_names = (
            "rewards",
            "states",
            "logprobas",
            "actions",
            "rewards_togo",
            "ep_rewards",
            "ep_len",
        )
        for field_name in field_names:
            setattr(self, field_name, [])



class Policy(nn.Module):
    """Two-layer MLP mapping a state to a probability vector over actions.

    Args:
        input_size: Number of features in the state vector.
        output_size: Number of outputs (one probability per action).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, output_size),
            # Softmax's argument is the *dimension* to normalise over, not
            # the layer width; -1 normalises the action axis for both a
            # single state and a batch of states.
            nn.Softmax(dim=-1),
        )

    def forward(self, state):
        """Return action probabilities for ``state`` (last axis sums to 1)."""
        return self.mlp(state)

class Value(nn.Module):
    """Two-layer MLP with a 64-unit hidden layer and a linear output.

    Args:
        input_size: Number of features in the state vector.
        output_size: Width of the output layer.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Layers are created in forward order so that parameter
        # initialisation consumes the RNG in the same sequence.
        hidden_layer = nn.Linear(input_size, 64)
        activation = nn.ReLU(inplace=True)
        output_layer = nn.Linear(64, output_size)
        self.mlp = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, state):
        """Run ``state`` through the MLP and return the raw output."""
        estimate = self.mlp(state)
        return estimate

def train(env_name='CartPole-v0', lr=1e-2, batch_size=5000, epochs=50, render=False):
    """Build the environment, networks and helper closures for training.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        lr: Currently unused here; the value optimizer uses a fixed 1e-3.
        batch_size: ``train_epoch`` keeps collecting whole episodes until
            more than this many states have been stored.
        epochs: Currently unused in this function.
        render: When true, the environment is created with
            ``render_mode="human"``.

    NOTE(review): nothing in this function calls ``train_epoch`` yet, and the
    policy-update helpers (``estimate_policy_gradient``, ``surrogate_loss``,
    ``fisher_vector_product``, ``step_trpo``) are still stubs.
    """
    env = gym.make(env_name, render_mode="human" if render else None)
    obs_size = env.observation_space.shape[0]
    # A Discrete action space has an empty ``shape``; its size is ``n``.
    action_size = env.action_space.n

    policy = Policy(obs_size, action_size)
    value = Value(obs_size, 1)

    value_optimizer = Adam(value.parameters(), lr=1e-3)

    def get_policy(state):
        """Return the categorical action distribution for ``state``."""
        # Policy ends in a Softmax layer, so its output is a probability
        # vector and must be passed as ``probs`` rather than ``logits``.
        return Categorical(probs=policy(state))

    def get_action(state):
        """Sample one action (a Python int) for a single observation."""
        state = torch.as_tensor(state, dtype=torch.float32)
        # Sampling needs no gradient.
        with torch.no_grad():
            return get_policy(state).sample().item()

    def advantage_function(state):
        """Return the value network's output for ``state``."""
        return value(state)

    def compute_loss_advantage(state, rewards_togo):
        """Mean squared error between value predictions and rewards-to-go."""
        # The value head has width 1; drop that axis so the prediction has
        # the same shape as the 1-D target and no broadcasting occurs.
        predicted = advantage_function(state).squeeze(-1)
        return nn.functional.mse_loss(predicted, rewards_togo)

    def estimate_policy_gradient():
        raise NotImplementedError

    def loss_policy_gradient(old_policy, new_policy, advantage):
        """Negated mean of ``exp(new_policy - old_policy) * advantage``.

        Presumably ``old_policy`` / ``new_policy`` are log-probabilities of
        the taken actions -- TODO confirm once a caller exists.
        """
        ratio = torch.exp(new_policy - old_policy)
        return -(ratio * advantage).mean()

    def surrogate_loss():
        raise NotImplementedError

    def fisher_vector_product():
        raise NotImplementedError

    def reward_to_go(ep_rews):
        """Return, for each step, the sum of rewards from that step onward."""
        rewards = np.asarray(ep_rews, dtype=np.float64)
        # Reverse cumulative sum; copy so the result is a contiguous array.
        return np.cumsum(rewards[::-1])[::-1].copy()

    def step_trpo():
        raise NotImplementedError

    def train_epoch():
        """Collect whole episodes, then take one value-function step."""
        buffer = RolloutBuffer()
        # Gymnasium's reset returns (observation, info).
        obs, _ = env.reset()
        # Rewards of the episode in progress; ``buffer.rewards`` keeps the
        # full per-step record aligned with ``buffer.states``.
        episode_rewards = []

        while True:
            act = get_action(obs)
            buffer.states.append(obs.copy())
            # Gymnasium's step returns a 5-tuple with separate
            # terminated / truncated flags.
            obs, reward, terminated, truncated, _ = env.step(act)
            buffer.actions.append(act)
            buffer.rewards.append(reward)
            episode_rewards.append(reward)

            if terminated or truncated:
                buffer.ep_rewards.append(sum(episode_rewards))
                buffer.ep_len.append(len(episode_rewards))
                # Extend with one reward-to-go entry per step of the episode.
                buffer.rewards_togo += list(reward_to_go(episode_rewards))

                if len(buffer.states) > batch_size:
                    break

                # Start the next episode from a fresh environment state.
                obs, _ = env.reset()
                episode_rewards = []

        # Optimize the value function on the collected batch.
        states = torch.as_tensor(np.asarray(buffer.states), dtype=torch.float32)
        targets = torch.as_tensor(np.asarray(buffer.rewards_togo), dtype=torch.float32)
        loss_value = compute_loss_advantage(states, targets)
        value_optimizer.zero_grad()
        loss_value.backward()
        value_optimizer.step()



