In [2]:
import gymnasium as gym

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from collections import deque

# Rendered env for eyeballing the random-action baseline (`test_run` below).
# Note: a later cell rebinds `env`; this instance is only for that optional demo.
env = gym.make("Pendulum-v1", render_mode="human")

def test_run(n_ep, env):
    
    for _ in range(n_ep):
        observation, info = env.reset()
        total_reward = 0

        while True:
            action = env.action_space.sample()
            observation, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            
            total_reward += reward
            # print(total_reward)
            
            if done:
                break

# Random-action baseline (uncomment to watch one rendered episode):
# test_run(1, env)
# env.close()

In [35]:
class policy_net(nn.Module):
    """Gaussian policy head: maps a state to the mean and log-std of a diagonal Normal.

    One hidden ReLU layer feeds two linear heads. ``log_std`` is returned
    unconstrained; callers exponentiate it to obtain the standard deviation.

    Args:
        in_dim: size of the flat state vector.
        out_dim: number of action dimensions.
        hidden_dim: width of the hidden layer (default 50, the original size).
    """

    def __init__(self, in_dim, out_dim, hidden_dim=50) -> None:
        super().__init__()
        # nn.Linear owns the trainable weight/bias of the affine map x @ W.T + b.
        self.il = nn.Linear(in_dim, hidden_dim)
        self.relu = nn.ReLU()

        self.mean_l = nn.Linear(hidden_dim, out_dim)
        self.log_std_l = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return ``(mean, log_std)``, each with trailing dimension ``out_dim``."""
        h = self.relu(self.il(x))
        return self.mean_l(h), self.log_std_l(h)

def generate_episode(policy_net, _env):
    """Roll out one episode, sampling actions from the policy's diagonal Gaussian.

    Args:
        policy_net: module mapping a float32 state tensor to ``(mean, log_std)``.
            (Inside this function the name shadows the ``policy_net`` class.)
        _env: Gymnasium-style env with ``reset()`` and a 5-tuple ``step()``.

    Returns:
        ``(log_probs, rewards)``. ``log_probs[t]`` is the log-density of the
        action sampled at step t with a leading batch axis added
        (``(1, action_dim)`` for a 1-D state) and still attached to the
        policy's graph, for the REINFORCE loss. ``rewards[t]`` is the reward
        returned by the env at step t.
    """
    log_probs, rewards = [], []
    state, info = _env.reset()

    while True:
        state_t = torch.tensor(state, dtype=torch.float32)
        mean, log_std = policy_net(state_t)
        gauss = torch.distributions.Normal(mean, torch.exp(log_std))

        # sample() is not reparameterised, so `action` carries no grad;
        # the policy gradient flows only through log_prob(action).
        action = gauss.sample()
        log_probs.append(gauss.log_prob(action).unsqueeze(0))

        # Hand the env the whole action vector as NumPy (Gymnasium does not take
        # torch tensors); indexing [0] here would produce a 0-d array instead.
        state, reward, terminated, truncated, info = _env.step(action.numpy())
        rewards.append(reward)

        if terminated or truncated:
            break

    return log_probs, rewards

def REINFORCE(_env, n_ep, gamma=0.99, lr=1e-3, log_every=50):
    """Train a Gaussian ``policy_net`` on ``_env`` with vanilla REINFORCE.

    One Monte-Carlo episode is collected per iteration and one Adam step is
    taken on the loss ``-sum_t log pi(a_t|s_t) * G_t``, with the returns
    ``G_t`` standardised per episode as a crude baseline.

    Args:
        _env: Gymnasium env; ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` size the policy network.
        n_ep: number of episodes (= number of gradient steps).
        gamma: discount factor for the returns.
        lr: Adam learning rate.
        log_every: print the episode loss every ``log_every`` episodes;
            0 or None disables printing.

    Returns:
        The trained ``policy_net`` instance.
    """
    state_dim = _env.observation_space.shape[0]
    action_dim = _env.action_space.shape[0]
    policy = policy_net(state_dim, action_dim)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    for i in range(n_ep):
        log_probs, rewards = generate_episode(policy, _env)

        # Discounted return-to-go G_t = r_t + gamma * G_{t+1}, built back-to-front
        # (append + reverse instead of insert(0, ...) to stay O(T)).
        returns = []
        disc_return = 0.0
        for r in reversed(rewards):
            disc_return = r + gamma * disc_return
            returns.append(disc_return)
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32)
        # std() uses Bessel's correction, so it is NaN for a single element;
        # a NaN loss would corrupt every parameter, so only normalise T >= 2.
        if len(returns) > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        policy_loss = torch.cat([-lp * g for lp, g in zip(log_probs, returns)]).sum()

        if log_every and i % log_every == 0:
            print(policy_loss)

        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

    return policy

In [43]:
# Train on a headless env (no render_mode) so training runs at full speed.
SEED = 0
env = gym.make("Pendulum-v1")
torch.manual_seed(SEED)  # seeds network init and action sampling
env.reset(seed=SEED)     # seeds the env RNG; later unseeded resets continue this stream
try:
    policy = REINFORCE(env, 100)
finally:
    # Close even if training raises or the kernel is interrupted.
    env.close()

tensor(-10.7751, grad_fn=<SumBackward0>)
tensor(6.2897, grad_fn=<SumBackward0>)


In [47]:
def test_policy_net(policy, _env):
    observation, info = _env.reset()
    observation = torch.tensor(observation)

    while True:
        mean, log_std = policy.forward(observation)
        std = torch.exp(log_std)
        
        gauss = torch.distributions.Normal(mean, std)
        action = gauss.sample()
        
        observation, reward, terminated, truncated, info = _env.step(action)
        observation = torch.tensor(observation)
        
        if terminated or truncated or reward < -100:
            break

# Watch the trained policy for one episode with human rendering.
env = gym.make("Pendulum-v1", render_mode="human")
try:
    test_policy_net(policy, env)
finally:
    # Close the render window even if the rollout errors or is interrupted.
    env.close()