In [3]:
import gymnasium as gym
import torch
from torch import nn, optim
import numpy as np
from collections import namedtuple, deque
import random
import matplotlib.pyplot as plt
from itertools import count
import math

In [4]:
class OUActionNoise:
    """Temporally correlated noise generated by an Ornstein-Uhlenbeck process.

    Every call advances the process by one step of size ``dt``: the previous
    value is pulled toward ``mean`` at rate ``theta`` and perturbed by Gaussian
    noise scaled by ``std_deviation * sqrt(dt)``.

    Args:
        mean: Array the process reverts to; its shape fixes the noise shape.
        std_deviation: Scale of the Gaussian perturbation.
        theta: Mean-reversion rate.
        dt: Step size.
        x_initial: Optional starting value; zeros shaped like ``mean`` if None.
    """

    def __init__(self, mean, std_deviation, theta=0.15, dt=1e-2, x_initial=None):
        self.mean = mean
        self.std_dev = std_deviation
        self.theta = theta
        self.dt = dt
        self.x_initial = x_initial
        self.reset()

    def __call__(self):
        """Advance the process one step and return the new value."""
        pull_to_mean = self.theta * (self.mean - self.x_prev) * self.dt
        gaussian_kick = (
            self.std_dev * np.sqrt(self.dt) * np.random.normal(size=self.mean.shape)
        )
        self.x_prev = self.x_prev + pull_to_mean + gaussian_kick
        return self.x_prev

    def reset(self):
        """Rewind the process to ``x_initial`` (or to zeros when none was given)."""
        self.x_prev = (
            self.x_initial if self.x_initial is not None else np.zeros_like(self.mean)
        )

In [5]:
class Policy(nn.Module):
    """Deterministic actor: maps an observation batch to bounded actions.

    Two hidden layers of 128 ReLU units followed by a tanh output that is
    scaled to ``[-max_action, max_action]``.

    Args:
        n_observations: Size of one observation vector.
        n_actions: Number of action dimensions. Defaults to 1, the value that
            was previously hard-coded.
        max_action: Symmetric bound on each action component. Defaults to 2.0,
            the value that was previously hard-coded.
    """

    def __init__(self, n_observations, n_actions=1, max_action=2.0):
        super().__init__()
        # Plain attribute (not a parameter/buffer) so state_dict keys stay
        # l1/l2/l3 and previously saved checkpoints still load.
        self.max_action = max_action
        self.l1 = nn.Linear(n_observations, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return actions of shape ``(..., n_actions)`` for observations ``x``."""
        # Functional activations avoid constructing new nn.Module objects on
        # every forward pass.
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return torch.tanh(self.l3(x)) * self.max_action
    
class QNetwork(nn.Module):
    """Critic: estimates a scalar Q-value for each (state, action) pair.

    The state and action are concatenated along dim 1 and passed through two
    hidden layers of 128 ReLU units and a linear output.

    Args:
        n_observations: Size of one observation vector.
        n_actions: Number of action dimensions. Defaults to 1, the value that
            was previously hard-coded.
    """

    def __init__(self, n_observations, n_actions=1):
        super().__init__()
        self.l1 = nn.Linear(n_observations + n_actions, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 1)

    def forward(self, states, actions):
        """Return Q-values of shape ``(batch, 1)``.

        ``states`` is ``(batch, n_observations)`` and ``actions`` is
        ``(batch, n_actions)``; both must be 2-D because they are joined
        along dim 1.
        """
        x = torch.cat([states, actions], dim=1)
        # Functional activations avoid constructing new nn.Module objects on
        # every forward pass.
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return self.l3(x)

In [6]:
# One stored environment step: (state, action, reward, next_state).
Transition = namedtuple("Transition", ('state', 'action', 'reward', 'next_state'))


class ReplayBuffer:
    """Bounded FIFO memory of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen silently drops the oldest entry once full.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def push(self, *args):
        """Wrap ``args`` in a Transition and append it to the memory."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

In [7]:
## Reproducibility
# Seed every RNG this notebook draws from so a fresh Restart & Run All gives
# the same run: Python's `random` (replay sampling, epsilon draws), NumPy
# (OU noise) and torch (network weight initialisation).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Environment
env = gym.make("Pendulum-v1")
obs, _ = env.reset(seed=SEED)
env.action_space.seed(SEED)  # action_space.sample() uses its own RNG
n_observations = len(obs)

## Algorithm
# Actor and critic, each with a target copy that starts from identical weights.
PolicyNet = Policy(n_observations)
PolicyTarget = Policy(n_observations)
PolicyTarget.load_state_dict(PolicyNet.state_dict())

QNet = QNetwork(n_observations)
QNetTarget = QNetwork(n_observations)
QNetTarget.load_state_dict(QNet.state_dict())


## Hyperparameters
LR = 1e-3                     # learning rate shared by both optimizers
# NOTE(review): QNetLossFn is not referenced anywhere else in the visible
# notebook; optimize_model builds its own nn.MSELoss. Kept so that any code
# relying on the name keeps working -- confirm which loss is intended.
QNetLossFn = nn.HuberLoss()
num_epochs = 1000             # number of training episodes
BATCH_SIZE = 128              # minibatch size drawn from replay memory
GAMMA = 0.95                  # discount factor in the TD target
TAU = 0.005                   # target-network soft-update rate
memory = ReplayBuffer(10000)  # replay capacity in transitions
steps_done = 0                # global step counter driving epsilon decay
EPS_START = 0.95              # initial probability of a random action
EPS_END = 0.05                # asymptotic probability of a random action
EPS_DECAY = 10000             # decay time constant, in environment steps

PolicyOptimizer = optim.AdamW(PolicyNet.parameters(), lr=LR, amsgrad=True)
QNetOptimizer = optim.AdamW(QNet.parameters(), lr=LR, amsgrad = True)

def select_action(state):
    """Pick an action for ``state`` using epsilon-greedy exploration plus OU noise.

    The random-action probability decays exponentially from EPS_START toward
    EPS_END as the global ``steps_done`` counter grows (time constant
    EPS_DECAY). On the greedy branch the policy output is perturbed by
    ``noise()``, clamped to [-2, 2] and returned as a float32 tensor of
    shape (1,). Otherwise a sample from ``env.action_space`` is returned as a
    float32 tensor.

    Side effect: increments the global ``steps_done`` by one.
    """
    global steps_done
    decay = math.exp(-1.*steps_done/EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END)*decay
    explore = random.random() <= eps_threshold
    steps_done += 1
    if explore:
        return torch.tensor(env.action_space.sample(), dtype=torch.float32)
    with torch.no_grad():
        noisy_action = PolicyNet(state) + torch.tensor(noise())
        return noisy_action.clamp_(-2.0, 2.0).view(1).float()

In [17]:
def select_action(state, add_noise=False):
    """Pick an action for ``state`` with epsilon-greedy exploration.

    NOTE: this cell redefines ``select_action``. Once it has run, this is the
    definition the training loop calls, replacing the earlier OU-noise
    version. Rather than keeping two near-identical copies, the OU-noise
    behaviour is available here through ``add_noise=True``; the default keeps
    the noise-free behaviour this cell always had.

    Args:
        state: Observation tensor accepted by ``PolicyNet``.
        add_noise: When True, perturb the greedy action with ``noise()``
            before clamping. Requires the global ``noise`` to be defined.

    Returns:
        A float32 tensor. On the greedy branch it has shape (1,) and lies in
        [-2, 2]; on the random branch it is a sample from
        ``env.action_space``.

    Side effect: increments the global ``steps_done`` by one.
    """
    global steps_done
    # Exponential decay of the random-action probability from EPS_START to EPS_END.
    eps_threshold = EPS_END + (EPS_START - EPS_END)*math.exp(-1.*steps_done/EPS_DECAY)
    sample = random.random()
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            action = PolicyNet(state)
            if add_noise:
                # Match the policy dtype so the sum is not promoted to float64.
                action = action + torch.as_tensor(noise(), dtype=action.dtype)
            return action.clamp_(-2.0, 2.0).view(1).float()
    return torch.tensor(env.action_space.sample(), dtype=torch.float32)

In [8]:
# Zero-mean OU exploration noise for a single action dimension, scale 0.2.
noise = OUActionNoise(mean=np.zeros(1), std_deviation=np.full(1, 0.2))

In [9]:
def optimize_model():
    """Run one critic update followed by one actor update on a replay minibatch.

    Returns immediately (doing nothing) while the replay memory holds fewer
    than BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        return

    # Turn a list of Transitions into a single Transition of per-field tuples.
    sampled = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*sampled))

    states = torch.cat(batch.state)
    taken_actions = torch.cat(batch.action).unsqueeze(1)
    reward_column = torch.cat(batch.reward).unsqueeze(1)
    next_states = torch.cat(batch.next_state)

    # ---- Critic: regress Q(s, a) onto the bootstrapped target ----
    # Targets come from the target networks and carry no gradient.
    with torch.no_grad():
        next_actions = PolicyTarget(next_states)
        y = reward_column + GAMMA*QNetTarget(next_states, next_actions)

    mse = nn.MSELoss()
    critic_loss = mse(QNet(states, taken_actions), y)
    QNet.zero_grad()
    critic_loss.backward()
    torch.nn.utils.clip_grad_value_(QNet.parameters(), 80)
    QNetOptimizer.step()

    # ---- Actor: ascend the critic's value of the policy's own actions ----
    proposed_actions = PolicyNet(states)
    policy_loss = -torch.mean(QNet(states, proposed_actions))
    PolicyNet.zero_grad()
    policy_loss.backward()
    torch.nn.utils.clip_grad_value_(PolicyNet.parameters(), 80)
    PolicyOptimizer.step()
    
    

In [18]:
def soft_update(target_net, source_net, tau):
    """Blend ``source_net`` into ``target_net`` in place.

    Every target parameter becomes ``tau * source + (1 - tau) * target``.
    Updating the parameters directly avoids rebuilding and reloading a full
    state dict on every environment step.
    """
    with torch.no_grad():
        for target_param, source_param in zip(target_net.parameters(), source_net.parameters()):
            target_param.mul_(1.0 - tau).add_(source_param, alpha=tau)


# Release the environment currently bound to `env` before replacing it.
env.close()
env = gym.make("Pendulum-v1")
for i_episode in range(num_epochs):
    rewards = []
    state, info = env.reset()
    # Add a batch dimension; -1 infers the observation size instead of hard-coding 3.
    state = torch.tensor(state, dtype=torch.float32).view(1, -1)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.numpy())
        rewards.append(reward)
        reward = torch.tensor([reward], dtype=torch.float32)
        done = terminated or truncated

        next_state = torch.tensor(observation, dtype=torch.float32).view(1, -1)

        memory.push(state, action, reward, next_state)
        state = next_state

        optimize_model()

        # Let both target networks slowly track the online networks.
        soft_update(PolicyTarget, PolicyNet, TAU)
        soft_update(QNetTarget, QNet, TAU)

        if done:
            break
    print(f"Episode: {i_episode}        Performance: {sum(rewards)}")

Episode: 0        Performance: -130.75235965651683
Episode: 1        Performance: -128.2783305060108
Episode: 2        Performance: -124.92718055044442
Episode: 3        Performance: -247.3255717136473
Episode: 4        Performance: -126.98587372977205
Episode: 5        Performance: -236.72478527270025
Episode: 6        Performance: -115.65795082505434
Episode: 7        Performance: -123.81054876290732
Episode: 8        Performance: -129.05448260648572
Episode: 9        Performance: -126.63289007403576
Episode: 10        Performance: -131.4783462693548
Episode: 11        Performance: -128.96816383821445
Episode: 12        Performance: -117.05128016517159
Episode: 13        Performance: -2.7793778933161746
Episode: 14        Performance: -254.85156306103437
Episode: 15        Performance: -123.91436979152647
Episode: 16        Performance: -114.86524998769261
Episode: 17        Performance: -243.69502915480447
Episode: 18        Performance: -132.27421954566466
Episode: 19        Perfor

KeyboardInterrupt: 

In [13]:
@torch.no_grad()
def make_decision(state):   # Already in torch format
    """Return the policy's action for ``state`` as a NumPy array.

    ``state`` is a batched tensor; the action for the first batch row is
    returned. No exploration is applied and no gradients are tracked.
    """
    return PolicyNet(state).numpy()[0]

In [16]:
# Roll out the current policy for one rendered episode and print its return.
env = gym.make("Pendulum-v1", render_mode = "human")
try:
    state, _ = env.reset()
    rewards = []
    # Explicit float32 matches the training loop and the network weights;
    # -1 infers the observation size instead of hard-coding 3.
    state = torch.tensor(state, dtype=torch.float32).view(1, -1)
    # Upper bound on steps; the loop normally ends when the episode does.
    for _ in range(100000):
        action = make_decision(state)
        state, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)
        state = torch.tensor(state, dtype=torch.float32).view(1, -1)
        if terminated or truncated:
            break
    print(sum(rewards))
finally:
    # Close the render window even if the rollout raises or is interrupted.
    env.close()

-128.15198121892035


In [12]:
# Leftover cleanup cell: closes whichever environment `env` currently refers to.
# NOTE(review): the evaluation cell above already calls env.close(), so this
# looks redundant in a top-to-bottom run -- confirm before removing.
env.close()

In [15]:
# Persist only the actor's weights (its state_dict); the critic and both target
# networks are not saved. The path is relative to the current working directory.
torch.save(PolicyNet.state_dict(), "PolicyNet_PendulumV1.pt")