In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os, shutil
from tqdm import tqdm
from rl_glue import RLGlue

from tbu_gym.tbu_discrete import TruckBackerEnv_D

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
from typing import List, Tuple

# -------------------------
# Simple MLP Q / h networks
# -------------------------
def make_mlp(input_dim: int, output_dim: int, hidden_sizes=(128, 128), activation=nn.ReLU):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        input_dim: Number of input features of the first linear layer.
        output_dim: Number of output features of the final linear layer.
        hidden_sizes: Iterable of hidden-layer widths; each hidden linear layer
            is followed by one ``activation()`` instance.
        activation: Zero-argument callable returning an activation module.

    Returns:
        ``nn.Sequential`` of ``Linear, activation, ..., Linear``; the final
        linear layer has no activation after it.
    """
    # Consecutive pairs of widths describe the hidden linear layers.
    widths = [input_dim, *hidden_sizes]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.extend((nn.Linear(fan_in, fan_out), activation()))
    # Output head maps the last hidden width (or input_dim if no hidden layers).
    modules.append(nn.Linear(widths[-1], output_dim))
    return nn.Sequential(*modules)

# -------------------------
# Helpers for param-traces
# -------------------------
def zeros_like_params(params: List[torch.nn.Parameter]):
    """Return a list with one all-zero tensor per entry of ``params``.

    Each tensor is created with ``torch.zeros_like`` from the parameter's
    ``.data``, so it matches that tensor's shape, dtype and device.
    """
    zeros = []
    for param in params:
        zeros.append(torch.zeros_like(param.data))
    return zeros

def add_param_lists(a: List[torch.Tensor], b: List[torch.Tensor]) -> List[torch.Tensor]:
    """Add two tensor lists element by element and return the sums as a new list.

    Pairs are formed with ``zip``, so the result has the length of the shorter
    input.
    """
    summed = []
    for left, right in zip(a, b):
        summed.append(left + right)
    return summed

def scale_param_list(a: List[torch.Tensor], scalar: float) -> List[torch.Tensor]:
    """Multiply every tensor in ``a`` by ``scalar`` and return the products as a new list."""
    scaled = []
    for tensor in a:
        scaled.append(scalar * tensor)
    return scaled

def copy_params(params: List[torch.nn.Parameter]) -> List[torch.Tensor]:
    """Return a list of cloned copies of each parameter's ``.data`` tensor.

    The clones do not share storage with the parameters, so later in-place
    changes to the parameters do not affect the returned tensors.
    """
    snapshots = []
    for param in params:
        snapshots.append(param.data.clone())
    return snapshots

# -------------------------
# QRC Agent
# -------------------------
class QRCAgent:
    """Epsilon-greedy agent with a Q network, an auxiliary ``h`` network and a target copy of Q.

    Transitions are stored in a bounded replay buffer. ``train_with_mem``
    samples a minibatch, assembles per-parameter gradients by hand from the TD
    error, the ``h`` prediction and running gradient traces, and applies them
    through two Adam optimizers.

    NOTE(review): the in-code comments describe the update as following a QRC
    paper / JAX reference implementation; the sign conventions have not been
    verified against that reference here -- confirm before relying on them.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        q_lr=1e-4,
        h_lr=1e-3,
        gamma=0.99,
        lamda=0.9,
        reg_coeff=1.0,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01,
        buffer_size=100000,
        batch_size=64,
        device=None,
    ):
        """Create the networks, optimizers, replay buffer and zeroed traces.

        Args:
            state_dim: Input width of the Q / h networks.
            action_dim: Output width of the Q / h networks (one value per action).
            q_lr: Adam learning rate for the Q network.
            h_lr: Adam learning rate for the h network.
            gamma: Discount factor used in the TD target and the trace decay.
            lamda: Trace-decay factor multiplied with ``gamma`` in the traces.
            reg_coeff: Coefficient of the parameter term in the h update.
            epsilon: Initial exploration probability.
            epsilon_decay: Multiplicative decay applied after each training step.
            epsilon_min: Lower bound for ``epsilon``.
            buffer_size: Maximum number of stored transitions.
            batch_size: Number of transitions sampled per training step.
            device: Torch device; defaults to CUDA when available, else CPU.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

        # Networks
        self.q_net = make_mlp(state_dim, action_dim).to(self.device)
        self.h_net = make_mlp(state_dim, action_dim).to(self.device)

        # Optionally: target network for Q (you can use it the same way as DQN)
        self.target_q = make_mlp(state_dim, action_dim).to(self.device)
        self.target_q.load_state_dict(self.q_net.state_dict())

        # Optimizers
        self.q_opt = optim.Adam(self.q_net.parameters(), lr=q_lr)
        self.h_opt = optim.Adam(self.h_net.parameters(), lr=h_lr)

        # Replay
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size

        # Hyperparams
        self.gamma = gamma
        self.lamda = lamda
        self.reg_coeff = reg_coeff

        # Epsilon-greedy
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        # Zero templates with the same structure as the parameters; the traces
        # are (re)initialized from these.
        self.q_param_list = list(self.q_net.parameters())
        self.h_param_list = list(self.h_net.parameters())
        self.zero_q_trace = zeros_like_params(self.q_param_list)
        self.zero_h_grad_trace = zeros_like_params(self.h_param_list)
        self.zero_q_grad_trace = zeros_like_params(self.q_param_list)

        # Per-episode/stateful traces: h_trace_scalar (scalar trace for h),
        # grad_h_trace (z_t^{theta}) and grad_q_trace (z_t^{w}), all starting at zero.
        self._reset_traces()

    # ---------- interaction ----------
    def agent_policy(self, state: np.ndarray) -> int:
        """Pick an action epsilon-greedily with respect to the current Q network.

        With probability ``self.epsilon`` a uniformly random action index is
        returned; otherwise the argmax of ``q_net(state)``.
        """
        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.action_dim))
        s = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            qvals = self.q_net(s)
        return int(qvals.argmax().item())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition tuple to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    # ---------- training helpers ----------
    def _zero_grad_params(self, model: nn.Module):
        """Drop any stored ``.grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.grad = None

    def _params_to_list(self, params: List[torch.nn.Parameter]) -> List[torch.Tensor]:
        """Materialize an iterable of parameters as a list (same objects, same order)."""
        return [p for p in params]

    def _get_grads_of_scalar_wrt_params(self, scalar: torch.Tensor, params: List[torch.nn.Parameter]):
        """Return the gradients of ``scalar`` w.r.t. ``params`` (same order as ``params``).

        The autograd graph is retained so the caller can differentiate again.
        Parameters that do not influence ``scalar`` get a zero tensor instead
        of ``None``.
        """
        grads = torch.autograd.grad(scalar, params, retain_graph=True, allow_unused=True)
        grads = [g if (g is not None) else torch.zeros_like(p.detach()) for g, p in zip(grads, params)]
        return grads

    # ---------- trace update ----------
    def _update_traces(self, rho: float, h_val: float, q_grads, h_grads):
        """Decay all traces by ``rho * gamma * lamda`` and add the new terms.

        Args:
            rho: Importance weight applied to the decayed trace.
            h_val: Value added to the scalar h trace.
            q_grads: Per-parameter tensors added to ``grad_q_trace``.
            h_grads: Per-parameter tensors added to ``grad_h_trace``.
        """
        decay = rho * self.gamma * self.lamda

        # scalar trace for h
        self.h_trace_scalar = decay * self.h_trace_scalar + float(h_val)

        # grad traces: grad_h_trace <- rho*gamma*lambda*grad_h_trace + grad_h
        self.grad_h_trace = [
            decay * old + new.detach()
            for old, new in zip(self.grad_h_trace, h_grads)
        ]
        # grad_q_trace <- rho*gamma*lambda*grad_q_trace + grad_q
        self.grad_q_trace = [
            decay * old + new.detach()
            for old, new in zip(self.grad_q_trace, q_grads)
        ]

    def _reset_traces(self):
        """Reset the scalar h trace and both gradient traces to zero.

        The gradient traces are fresh clones of the zero templates, so later
        trace updates never alias or modify the templates. (Previously the
        traces were seeded with copies of the parameter values, which injected
        the weights themselves into the first updates.)
        """
        self.h_trace_scalar = 0.0
        self.grad_h_trace = [z.clone() for z in self.zero_h_grad_trace]
        self.grad_q_trace = [z.clone() for z in self.zero_q_grad_trace]

    # ---------- main training step using a batch ----------
    def train_with_mem(self, use_target_for_next=True):
        """Run one training step on a minibatch sampled from the replay buffer.

        Returns immediately (without decaying epsilon) while the buffer holds
        fewer than ``batch_size`` transitions. Otherwise each sampled
        transition is processed in turn: its TD error and h prediction are
        combined with the stored traces into per-parameter updates, the traces
        are advanced (or reset when the transition is terminal), and the
        batch-averaged updates are written into ``.grad`` and applied with the
        two optimizers. Finally epsilon is decayed, clamped at ``epsilon_min``.

        Args:
            use_target_for_next: If true, bootstrap from ``target_q``;
                otherwise bootstrap from ``q_net``.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # tensors
        states_t = torch.FloatTensor(np.stack(states)).to(self.device)         # (B, S)
        next_states_t = torch.FloatTensor(np.stack(next_states)).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)                 # (B,)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)

        B = len(batch)

        # initialize accumulated updates (same structure as params)
        acc_q_grads = [torch.zeros_like(p.data) for p in self.q_param_list]
        acc_h_grads = [torch.zeros_like(p.data) for p in self.h_param_list]

        # Loop-invariant pieces: parameter lists, bootstrap network, and rho.
        q_params = list(self.q_net.parameters())
        h_params = list(self.h_net.parameters())
        bootstrap_net = self.target_q if use_target_for_next else self.q_net
        # rho: here we use 1.0 for greedy policy; you can compute an importance
        # sampling ratio here if using off-policy corrections.
        rho = 1.0

        for i in range(B):
            s = states_t[i:i+1]       # shape (1, S)
            a = int(actions_t[i].item())
            ns = next_states_t[i:i+1]
            r = rewards_t[i].unsqueeze(0)
            done = bool(dones_t[i].item())

            # Q(s,a)
            q_vals = self.q_net(s)
            q_s_a = q_vals[0, a]

            # next state value (max_a' Q(next)), never differentiated through
            with torch.no_grad():
                next_q_vals = bootstrap_net(ns)
                max_next = next_q_vals.max(1)[0]  # shape (1,)

            td_target = r + (1.0 - float(done)) * self.gamma * max_next  # shape (1,)
            td_error = (td_target - q_s_a)  # tensor of shape (1,)

            # grads of td_error w.r.t q params
            td_scalar = td_error.squeeze()
            q_grads = self._get_grads_of_scalar_wrt_params(td_scalar, q_params)  # list of grads

            # h(s,a) and grads wrt h params
            h_vals = self.h_net(s)
            h_s_a = h_vals[0, a]
            h_grads = self._get_grads_of_scalar_wrt_params(h_s_a.squeeze(), h_params)

            # Note: we do not maintain per-sample traces in replay -- the canonical approach uses online traces.
            # For replay, the practical approximation used here is rho=1 and a per-sample
            # trace contribution that starts from zero, i.e. just h(s,a):
            h_trace_sample = float(h_s_a.detach())

            # q update parts (as described by the original author, based on the paper):
            # q_update = - h_trace * grad_td_error  (GTD2 base)
            #   + td_error * grad_q_trace  - h * grad_td_error   (gradient correction, always on)
            # The stored traces (self.grad_q_trace / self.grad_h_trace) act as the trace memory.
            td_value = td_scalar.detach()
            h_value = h_s_a.detach()
            for idx, g_td in enumerate(q_grads):
                g_td = g_td.detach()
                base = - (self.h_trace_scalar + h_trace_sample) * g_td  # -h_trace * grad_td
                corr = td_value * self.grad_q_trace[idx] - h_value * g_td
                acc_q_grads[idx] += (base + corr).detach()

            # update h: h_update = -( delta_z_h + h_h_grad + beta_params )
            # delta_z_h = td_error * grad_h_trace (use stored grad_h_trace as approximation)
            # h_h_grad = -h * h_grads  (in jax version they used -h * grad_h)
            # beta_params = -reg_coeff * h_params
            for idx, hg in enumerate(h_grads):
                delta_z_h = td_value * self.grad_h_trace[idx]
                h_h_grad = - h_value * hg.detach()
                beta_param = - self.reg_coeff * h_params[idx].data
                acc_h_grads[idx] += (- (delta_z_h + h_h_grad + beta_param)).detach()

            # For trace statefulness: reset traces if done (approx)
            if done:
                self._reset_traces()
            else:
                # update stored traces globally (we use the q_grads and h_grads computed above)
                # This is a practical approximation for the full online algorithm.
                self._update_traces(rho, float(h_s_a.detach()), q_grads, h_grads)

        # Average accumulated grads across batch
        acc_q_grads = [g / float(B) for g in acc_q_grads]
        acc_h_grads = [g / float(B) for g in acc_h_grads]

        # Apply gradients to q_net (note: optimizer expects .grad on parameters)
        self._zero_grad_params(self.q_net)
        for p, g in zip(self.q_net.parameters(), acc_q_grads):
            # Flax multiplies by -1 in their update pipeline; here we apply gradient descent so set grad = g
            p.grad = g.to(self.device)
        self.q_opt.step()

        # Apply gradients to h_net
        self._zero_grad_params(self.h_net)
        for p, g in zip(self.h_net.parameters(), acc_h_grads):
            p.grad = g.to(self.device)
        self.h_opt.step()

        # epsilon decay, clamped so it never drops below epsilon_min
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    # ---------- utility ----------
    def update_target(self):
        """Copy the current Q network weights into the target network."""
        self.target_q.load_state_dict(self.q_net.state_dict())

    def save(self, path_prefix):
        """Save the Q and h network weights to ``<path_prefix>_q.pth`` / ``<path_prefix>_h.pth``."""
        torch.save(self.q_net.state_dict(), path_prefix + "_q.pth")
        torch.save(self.h_net.state_dict(), path_prefix + "_h.pth")

    def load(self, path_prefix):
        """Load Q and h weights saved by ``save`` and sync the target network."""
        self.q_net.load_state_dict(torch.load(path_prefix + "_q.pth", map_location=self.device))
        self.h_net.load_state_dict(torch.load(path_prefix + "_h.pth", map_location=self.device))
        self.update_target()

In [6]:
import numpy as np
from tbu_gym.tbu_discrete import TruckBackerEnv_D
import matplotlib.pyplot as plt

# Hyperparameters
num_episodes = 1000
max_steps_per_episode = 500
gamma = 0.99
q_lr = 1e-4
h_lr = 1e-3
lamda = 0.9       # trace-decay factor passed to the agent
reg_coeff = 1.0   # regularization coefficient passed to the agent
epsilon_start = 1.0
epsilon_decay = 0.99997  # applied once per training step, not per episode
epsilon_min = 0.01
batch_size = 64
target_update_freq = 5  # episodes between target-network syncs

# Reproducibility: seed every RNG this run draws from (replay sampling uses
# `random`, exploration uses `np.random`, network init uses torch).
# `random` and `torch` are imported in the agent-definition cell above.
# NOTE(review): the environment may keep its own RNG; seed it too if
# TruckBackerEnv_D exposes a way to do so -- not visible from this notebook.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Environment setup
env = TruckBackerEnv_D(render_mode=None)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Initialize QRC agent
agent = QRCAgent(
    state_dim=state_dim,
    action_dim=action_dim,
    q_lr=q_lr,
    h_lr=h_lr,
    gamma=gamma,
    lamda=lamda,
    reg_coeff=reg_coeff,
    epsilon=epsilon_start,
    epsilon_decay=epsilon_decay,
    epsilon_min=epsilon_min,
    batch_size=batch_size
)

# Storage for plotting
episode_rewards = []

# Training loop
for episode in range(1, num_episodes + 1):
    # NOTE(review): assumes reset() returns only the observation and step()
    # returns a 4-tuple (old Gym API) -- confirm against TruckBackerEnv_D.
    state = env.reset()
    total_reward = 0

    for t in range(max_steps_per_episode):
        # Epsilon-greedy action
        action = agent.agent_policy(np.array(state, dtype=np.float32))

        # Step environment
        next_state, reward, done, info = env.step(action)
        total_reward += reward

        # Store transition (done is stored as a float so it can be batched into a tensor)
        agent.remember(state, action, reward, next_state, float(done))

        # Train agent with memory (no-op until the buffer holds one full batch)
        agent.train_with_mem()

        state = next_state

        if done:
            break

    # Sync target network periodically (similar to DQN)
    if episode % target_update_freq == 0:
        agent.update_target()

    # Track rewards
    episode_rewards.append(total_reward)

    print(f"Episode {episode}, Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}")

# Plot rewards
plt.plot(episode_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('QRC Training on TruckBackerEnv_D')
plt.show()

ModuleNotFoundError: No module named 'qrc_agent'