In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
torch.set_printoptions(linewidth=120, precision=2, sci_mode=False, profile="short")
import cv2 as cv

# Doesnt support more than 1 env yet
# I dont think sampledExperienced.dones works at all, where do the names come from? Remake it with namedTuple

seed = 1
nEnvs = 1
torch_deterministic = True
env_id = "LunarLanderContinuous-v2"
totalTimesteps = 10000
buffer_size = int(1e6)
gamma = 0.99
tau = 0.005
batch_size = 256
policy_lr = 3e-4
q_lr = 1e-3
policyFrequency = 2
QNetworkFrequency = 1  # Denis Yarats' implementation delays this by 2.
noise_clip = 0.5
alpha = 0.2
autoEntropy = True


In [2]:

# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        x = x.to(torch.float)
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

LOG_STD_MAX = 2
LOG_STD_MIN = -5

class Actor(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))
        self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))
        # action rescaling
        self.register_buffer(
            "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
        )
        self.register_buffer(
            "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)
        )

    def forward(self, x):
        x = x.to(torch.float)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = torch.clamp(self.fc_logstd(x), LOG_STD_MIN, LOG_STD_MAX)
        return mean, log_std

    def get_action(self, x, evaluation=False):
        mean, log_std = self(x)
        std = log_std.exp()
        # print(f"mean, std shapes: {mean.shape}, {std.shape}")
        policyDistribution = torch.distributions.Normal(mean, std)
        if evaluation:
            actionSample = mean
        else:
            actionSample = policyDistribution.rsample()
        # print(f"actionSample.shape {actionSample.shape}")
        tanhAction = torch.tanh(actionSample)
        # print(f"tanhAction.shape {tanhAction.shape}")
        action = tanhAction * self.action_scale + self.action_bias
        # print(f"action.shape {action.shape}")
        log_prob = policyDistribution.log_prob(actionSample)
        # print(f"log_prob.shape {log_prob.shape}, log_prob {log_prob}")
        # print(f"tanhAction.shape {tanhAction.shape}, tanhAction {tanhAction}")
        # print(f"action_scale.shape {self.action_scale.shape}, action_scale {self.action_scale}")
        # print(f"Trying to add to log prob log of self.action_scale * (1 - tanhAction.pow(2)) + 1e-6: {self.action_scale * (1 - tanhAction.pow(2)) + 1e-6}")
        # Enforcing Action Bound
        log_prob -= torch.log(self.action_scale * (1 - tanhAction.pow(2)) + 1e-6)
        # print(f"Added it, so now log_prob is {log_prob} of shape {log_prob.shape}")
        log_prob = log_prob.sum(-1, keepdim=True)
        # print(f"Summed it so now logprob is {log_prob} of shape {log_prob.shape}")
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
envs = gym.vector.make(env_id, num_envs=nEnvs)
envs.single_observation_space.dtype = np.float32
totalEpisodicRewards = torch.zeros(nEnvs).to(device)
print(f"envs.single_action_space.shape = {envs.single_action_space.shape}. Prod = {np.prod(envs.single_action_space.shape)}")


  gym.logger.warn(


envs.single_action_space.shape = (2,). Prod = 2


In [3]:

actor = Actor(envs).to(device)
QNet1 = SoftQNetwork(envs).to(device)
QNet2 = SoftQNetwork(envs).to(device)
QNet1_target = SoftQNetwork(envs).to(device)
QNet2_target = SoftQNetwork(envs).to(device)
QNet1_target.load_state_dict(QNet1.state_dict())
QNet2_target.load_state_dict(QNet2.state_dict())
QNetsOptimizer = optim.Adam(list(QNet1.parameters()) + list(QNet2.parameters()), lr=q_lr)
actorOptimizer = optim.Adam(list(actor.parameters()), lr=policy_lr)

# Automatic entropy tuning
if autoEntropy:
    target_entropy = -torch.tensor(envs.single_action_space.shape).prod().item()
    logAlpha = torch.zeros(1, requires_grad=True, device=device)
    alpha = logAlpha.exp().item()
    alphaOptimizer = optim.Adam([logAlpha], lr=q_lr)

experiences = ReplayBuffer(
    buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device,
    handle_timeout_termination=False,
)

allScores = []
obs, _ = envs.reset(seed=seed)
for globalStep in range(totalTimesteps):
    actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
    actions = actions.detach().cpu().numpy()

    next_obs, rewards, terminations, truncations, infos = envs.step(actions)
    experiences.add(obs, next_obs, actions, rewards, terminations, infos)
    obs = next_obs

    done = np.logical_or.reduce([terminations, truncations])
    totalEpisodicRewards += torch.tensor(rewards).to(device)
    for finalScore in totalEpisodicRewards[done]:
        print(f"Score: {finalScore:>8.2f}")
        allScores.append(finalScore.item())
    totalEpisodicRewards[done] = 0




    sampledExperiences = experiences.sample(min(batch_size, experiences.size()))
    with torch.no_grad():
        nextStateActions, nextStateLogProbs, _ = actor.get_action(sampledExperiences.next_observations)
        QFunction1NextTarget = QNet1_target(sampledExperiences.next_observations, nextStateActions)
        QFunction2NextTarget = QNet2_target(sampledExperiences.next_observations, nextStateActions)
        minQNextTarget = torch.min(QFunction1NextTarget, QFunction2NextTarget) - alpha * nextStateLogProbs
        nextQValue = sampledExperiences.rewards.flatten() + (1 - sampledExperiences.dones.flatten()) * gamma * (minQNextTarget).view(-1)


    QFunction1ActionValues = QNet1(sampledExperiences.observations, sampledExperiences.actions).view(-1)
    QFunction2ActionValues = QNet2(sampledExperiences.observations, sampledExperiences.actions).view(-1)
    QFunction1Loss = F.mse_loss(QFunction1ActionValues, nextQValue)
    QFunction2Loss = F.mse_loss(QFunction2ActionValues, nextQValue)
    QFunctionsTotalLoss = QFunction1Loss + QFunction2Loss

    QNetsOptimizer.zero_grad()
    QFunctionsTotalLoss.backward()
    QNetsOptimizer.step()

    if globalStep % policyFrequency == 0:
        for i in range(policyFrequency):
            if i > 0:
                # Sample new experiences if we make multiple updates
                sampledExperiences = experiences.sample(min(batch_size, experiences.size()))
            actions, logProbabilities, _ = actor.get_action(sampledExperiences.observations)
            QFunction1Evaluation = QNet1(sampledExperiences.observations, actions)
            QFunction2Evaluation = QNet2(sampledExperiences.observations, actions)
            minQEvalutaion = torch.min(QFunction1Evaluation, QFunction2Evaluation)
            actorLoss = ((alpha * logProbabilities) - minQEvalutaion).mean()

            actorOptimizer.zero_grad()
            actorLoss.backward()
            actorOptimizer.step()

            if autoEntropy:
                with torch.no_grad():
                    _, logProbabilities, _ = actor.get_action(sampledExperiences.observations)
                alphaLoss = (-logAlpha.exp()*(logProbabilities + target_entropy)).mean()

                alphaOptimizer.zero_grad()
                alphaLoss.backward()
                alphaOptimizer.step()
                alpha = logAlpha.exp().item()
    print(f"Alpha: {alpha}")

    if globalStep % QNetworkFrequency == 0:
        for param, targetParam in zip(QNet1.parameters(), QNet1_target.parameters()):
            targetParam.data.copy_(tau*param.data + (1 - tau)*targetParam.data)
        for param, targetParam in zip(QNet2.parameters(), QNet2_target.parameters()):
            targetParam.data.copy_(tau*param.data + (1 - tau)*targetParam.data)
envs.close()

Alpha: 0.9980071187019348
Alpha: 0.9980071187019348
Alpha: 0.9960227012634277
Alpha: 0.9960227012634277
Alpha: 0.9940404891967773
Alpha: 0.9940404891967773
Alpha: 0.9920586943626404
Alpha: 0.9920586943626404
Alpha: 0.9900829195976257
Alpha: 0.9900829195976257
Alpha: 0.9881163239479065
Alpha: 0.9881163239479065
Alpha: 0.986151397228241
Alpha: 0.986151397228241
Alpha: 0.9841977953910828
Alpha: 0.9841977953910828
Alpha: 0.9822391867637634
Alpha: 0.9822391867637634
Alpha: 0.9802781343460083
Alpha: 0.9802781343460083
Alpha: 0.9783176183700562
Alpha: 0.9783176183700562
Alpha: 0.9763534069061279
Alpha: 0.9763534069061279
Alpha: 0.9743953347206116
Alpha: 0.9743953347206116
Alpha: 0.9724488258361816
Alpha: 0.9724488258361816
Alpha: 0.9705116152763367
Alpha: 0.9705116152763367
Alpha: 0.9685928225517273
Alpha: 0.9685928225517273
Alpha: 0.9666793346405029
Alpha: 0.9666793346405029
Alpha: 0.9647737741470337
Alpha: 0.9647737741470337
Alpha: 0.9628947377204895
Alpha: 0.9628947377204895
Alpha: 0.96103

KeyboardInterrupt: 

In [None]:
images = []
environment = gym.make(env_id, render_mode="rgb_array")
obs, _ = environment.reset()
totalReward = 0

while True:
    actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
    actions = actions.detach().cpu().numpy()
    next_obs, rewards, terminated, truncated, _ = environment.step(actions)
    obs = next_obs
    image = environment.render()
    images.append(image)
    totalReward += torch.tensor(rewards).to(device)
    if terminated or truncated:
        print(f"Score: {finalScore:>8.2f}")
        break
# saveVideo(images, f"Testing-({totalReward:.0f}).mp4", 30)

RuntimeError: output with shape [2] doesn't match the broadcast shape [1, 2]