In [25]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
from scipy.spatial import cKDTree  # for fast nearest-neighbor queries

# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Hyperparameters
# ---------------------------
LATENT_DIM = 32
HIDDEN_DIM = 64
BATCH_SIZE = 64
GAMMA = 0.99
LEARNING_RATE = 1e-3
MEMORY_CAPACITY = 10000
EPISODES = 1000
MAX_STEPS = 200
EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY = 0.995
CERT_EPSILON = 0.1       # Allowed deviation in Q-value for certification
MIN_RADIUS = 0.01
MAX_RADIUS = 1.0
CONTROLLER_THRESHOLD = 1.0  # TD error threshold for adjusting certification
CONTROLLER_ALPHA = 0.9      # Factor to reduce radius if error is high
CONTROLLER_BETA = 1.1       # Factor to increase radius if error is low
CERT_BINARY_ITERS = 5       # Fixed iterations for binary search
CERT_NUM_SAMPLES = 5        # Number of random perturbations per sample


In [26]:
# ---------------------------
# Encoder Network: Maps state to latent space.
# ---------------------------
class Encoder(nn.Module):
    def __init__(self, input_dim=2, latent_dim=LATENT_DIM):
        super(Encoder, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, latent_dim)
        )
        
    def forward(self, x):
        return self.net(x)

# ---------------------------
# QNetwork (Generalizer): Estimates Q-values from latent representation.
# ---------------------------
class QNetwork(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM, num_actions=3):
        super(QNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, num_actions)
        )
        
    def forward(self, z):
        return self.net(z)

# ---------------------------
# Replay Buffer for Experience Replay.
# ---------------------------
class ReplayBuffer:
    def __init__(self, capacity=MEMORY_CAPACITY):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
        
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions), 
                np.array(rewards), np.array(next_states), 
                np.array(dones))
        
    def __len__(self):
        return len(self.buffer)

# ---------------------------
# Memory Module: Vectorized storage and fast KD–Tree lookup.
# ---------------------------
class MemoryModule:
    def __init__(self):
        # We store latent vectors, radii, actions, and q_values in lists.
        self.latents = []
        self.radii = []
        self.actions = []
        self.q_values = []
        self.kd_tree = None  # Will be built on demand

    def add_batch(self, zs, actions, q_vals, radii):
        # zs: (N, latent_dim) numpy array; actions, q_vals, radii: 1D arrays of length N.
        self.latents.extend(zs)
        self.actions.extend(actions)
        self.q_values.extend(q_vals)
        self.radii.extend(radii)
        # Rebuild kd-tree after each batch addition
        if len(self.latents) > 0:
            self.kd_tree = cKDTree(np.array(self.latents))
    
    def get_certified_q(self, z):
        # Given a latent vector z (1D numpy array), quickly find memory entries nearby.
        if self.kd_tree is None:
            return None
        # Query all memory points within MAX_RADIUS (the maximum possible)
        indices = self.kd_tree.query_ball_point(z, r=MAX_RADIUS)
        if len(indices) == 0:
            return None
        certified_qs = []
        for idx in indices:
            mem_z = self.latents[idx]
            # Check if z is within the certified radius for this memory entry.
            if np.linalg.norm(z - mem_z) <= self.radii[idx]:
                certified_qs.append(self.q_values[idx])
        if len(certified_qs) > 0:
            return np.mean(certified_qs)
        return None
    
    def update_radii(self, td_error):
        # Adjust all stored radii based on the average TD error.
        new_radii = []
        for r in self.radii:
            if td_error > CONTROLLER_THRESHOLD:
                new_r = max(MIN_RADIUS, r * CONTROLLER_ALPHA)
            else:
                new_r = min(MAX_RADIUS, r * CONTROLLER_BETA)
            new_radii.append(new_r)
        self.radii = new_radii

# ---------------------------
# Vectorized Certified Radius Computation.
# ---------------------------
def compute_certified_radii_batch(encoder, q_network, zs, actions, epsilon=CERT_EPSILON, 
                                  min_r=MIN_RADIUS, max_r=MAX_RADIUS,
                                  num_samples=CERT_NUM_SAMPLES, binary_iters=CERT_BINARY_ITERS):
    # zs: Tensor of shape (N, latent_dim)
    # actions: Tensor of shape (N,)
    N, d = zs.shape
    # Initialize candidate radii (start at min_r for all)
    radii = torch.full((N,), min_r, device=device)
    low = torch.full((N,), min_r, device=device)
    high = torch.full((N,), max_r, device=device)
    
    # For each binary search iteration, update candidate radii in vectorized fashion.
    for _ in range(binary_iters):
        mid = (low + high) / 2.0  # (N,)
        # For each sample, generate num_samples random perturbations in latent space.
        # Sample unit vectors: shape (N, num_samples, d)
        deltas = torch.randn((N, num_samples, d), device=device)
        deltas = deltas / (deltas.norm(dim=2, keepdim=True) + 1e-6)
        # Sample random scales uniformly between 0 and mid.
        scales = torch.rand((N, num_samples, 1), device=device) * mid.unsqueeze(1).unsqueeze(2)
        perturbations = deltas * scales  # (N, num_samples, d)
        zs_perturbed = zs.unsqueeze(1) + perturbations  # (N, num_samples, d)
        # Flatten batch for evaluation.
        zs_perturbed_flat = zs_perturbed.view(-1, d)
        q_vals = q_network(zs_perturbed_flat)  # (N * num_samples, num_actions)
        # Gather Q-values corresponding to actions.
        actions_expanded = actions.unsqueeze(1).repeat(1, num_samples).view(-1)
        q_vals = q_vals.gather(1, actions_expanded.unsqueeze(1)).view(N, num_samples)
        # Compute original Q-values for each sample.
        q_orig = q_network(zs).gather(1, actions.unsqueeze(1)).squeeze(1)  # (N,)
        # Expand q_orig to shape (N, num_samples)
        q_orig_expanded = q_orig.unsqueeze(1).expand_as(q_vals)
        differences = torch.abs(q_vals - q_orig_expanded).max(dim=1)[0]  # (N,)
        # If maximum difference <= epsilon, candidate radius is feasible.
        feasible = differences <= epsilon
        # Update low and high bounds vectorized:
        low = torch.where(feasible, mid, low)
        high = torch.where(feasible, high, mid)
        radii = mid  # Use mid as the candidate radius.
    return radii.detach().cpu().numpy()

# ---------------------------
# DQN Agent integrating Memory, Generalizer, and Controller.
# ---------------------------
class DQNAgent:
    def __init__(self, state_dim, action_dim):
        self.encoder = Encoder(input_dim=state_dim).to(device)
        self.q_network = QNetwork(latent_dim=LATENT_DIM, num_actions=action_dim).to(device)
        self.target_q_network = QNetwork(latent_dim=LATENT_DIM, num_actions=action_dim).to(device)
        self.target_q_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(list(self.encoder.parameters()) + 
                                    list(self.q_network.parameters()), lr=LEARNING_RATE)
        self.replay_buffer = ReplayBuffer()
        self.memory_module = MemoryModule()
        self.epsilon = EPS_START
        self.action_dim = action_dim
        self.update_counter = 0  # Count mini-batch updates
        
    def select_action(self, state):
        # Cache latent representation.
        state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        z = self.encoder(state_tensor).detach().cpu().numpy().squeeze()
        certified_q = self.memory_module.get_certified_q(z)
        if certified_q is not None and np.random.rand() > self.epsilon:
            # Use the certified memory Q-value indirectly: choose best action from Q-network.
            with torch.no_grad():
                q_vals = self.q_network(self.encoder(state_tensor))
            action = q_vals.argmax().item()
        else:
            if np.random.rand() < self.epsilon:
                action = np.random.randint(self.action_dim)
            else:
                with torch.no_grad():
                    q_vals = self.q_network(self.encoder(state_tensor))
                action = q_vals.argmax().item()
        return action
    
    def update(self):
        if len(self.replay_buffer) < BATCH_SIZE:
            return
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(BATCH_SIZE)
        states = torch.tensor(states, dtype=torch.float32, device=device)
        actions = torch.tensor(actions, dtype=torch.long, device=device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
        next_states = torch.tensor(next_states, dtype=torch.float32, device=device)
        dones = torch.tensor(dones, dtype=torch.float32, device=device)
        
        # Encode current states
        z_states = self.encoder(states)
        q_values = self.q_network(z_states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        
        # Compute target Q-values.
        with torch.no_grad():
            z_next = self.encoder(next_states)
            next_q_values = self.target_q_network(z_next)
            next_q_max, _ = next_q_values.max(dim=1)
            target_q = rewards + GAMMA * next_q_max * (1 - dones)
        
        loss = nn.MSELoss()(q_values, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        self.update_counter += 1
        # Every few mini-batch updates, update the memory with certified data.
        if self.update_counter % 10 == 0:
            # Compute TD errors for the batch.
            td_errors = torch.abs(q_values - target_q).detach().cpu().numpy()
            avg_td_error = np.mean(td_errors)
            # Compute certified radii for the entire batch in vectorized form.
            radii = compute_certified_radii_batch(self.encoder, self.q_network, z_states, actions)
            # Get q_values as numpy array.
            q_vals_np = q_values.detach().cpu().numpy()
            z_states_np = z_states.detach().cpu().numpy()
            actions_np = actions.detach().cpu().numpy()
            # Add this batch to the memory.
            self.memory_module.add_batch(z_states_np, actions_np, q_vals_np, radii)
            # Update memory radii using the Controller.
            self.memory_module.update_radii(avg_td_error)
    
    def update_target_network(self):
        self.target_q_network.load_state_dict(self.q_network.state_dict())
        
    def decay_epsilon(self):
        self.epsilon = max(EPS_END, self.epsilon * EPS_DECAY)

env = gym.make("MountainCar-v0")
state_dim = env.observation_space.shape[0]  # e.g., 2: position and velocity.
action_dim = env.action_space.n             # e.g., 3 actions.
agent = DQNAgent(state_dim, action_dim)


In [27]:
# ---------------------------
# Main Training Loop
# ---------------------------
def train():
    scores = []
    for episode in range(EPISODES):
        state, _ = env.reset()
        total_reward = 0
        for t in range(MAX_STEPS):
            action = agent.select_action(state)
            next_state, reward, done, truncated, _ = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, done or truncated)
            state = next_state
            total_reward += reward
            if done or truncated:
                break
        agent.update()
        agent.update_target_network()
        agent.decay_epsilon()
        scores.append(total_reward)
        print(f"Episode {episode+1}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}")
    env.close()
    return scores

# ---------------------------
# Main Test Loop
# ---------------------------
def test():
    scores = []
    for episode in range(EPISODES):
        state, _ = env.reset()
        total_reward = 0
        for t in range(MAX_STEPS):
            action = agent.select_action(state)
            next_state, reward, done, truncated, _ = env.step(action)
            state = next_state
            total_reward += reward
            if done or truncated:
                break
        scores.append(total_reward)
        print(f"Episode {episode+1}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}")
    env.close()
    return scores


In [28]:
train()

Episode 1, Total Reward: -200.0, Epsilon: 0.995
Episode 2, Total Reward: -200.0, Epsilon: 0.990
Episode 3, Total Reward: -200.0, Epsilon: 0.985
Episode 4, Total Reward: -200.0, Epsilon: 0.980
Episode 5, Total Reward: -200.0, Epsilon: 0.975
Episode 6, Total Reward: -200.0, Epsilon: 0.970
Episode 7, Total Reward: -200.0, Epsilon: 0.966
Episode 8, Total Reward: -200.0, Epsilon: 0.961
Episode 9, Total Reward: -200.0, Epsilon: 0.956
Episode 10, Total Reward: -200.0, Epsilon: 0.951
Episode 11, Total Reward: -200.0, Epsilon: 0.946
Episode 12, Total Reward: -200.0, Epsilon: 0.942
Episode 13, Total Reward: -200.0, Epsilon: 0.937
Episode 14, Total Reward: -200.0, Epsilon: 0.932
Episode 15, Total Reward: -200.0, Epsilon: 0.928
Episode 16, Total Reward: -200.0, Epsilon: 0.923
Episode 17, Total Reward: -200.0, Epsilon: 0.918
Episode 18, Total Reward: -200.0, Epsilon: 0.914
Episode 19, Total Reward: -200.0, Epsilon: 0.909
Episode 20, Total Reward: -200.0, Epsilon: 0.905
Episode 21, Total Reward: -20

[-200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 -200.0,
 

In [13]:
test()

Episode 1, Total Reward: -200.0, Epsilon: 0.606
Episode 2, Total Reward: -200.0, Epsilon: 0.606
Episode 3, Total Reward: -200.0, Epsilon: 0.606
Episode 4, Total Reward: -200.0, Epsilon: 0.606
Episode 5, Total Reward: -200.0, Epsilon: 0.606
Episode 6, Total Reward: -200.0, Epsilon: 0.606
Episode 7, Total Reward: -200.0, Epsilon: 0.606
Episode 8, Total Reward: -200.0, Epsilon: 0.606
Episode 9, Total Reward: -200.0, Epsilon: 0.606
Episode 10, Total Reward: -200.0, Epsilon: 0.606
Episode 11, Total Reward: -200.0, Epsilon: 0.606
Episode 12, Total Reward: -200.0, Epsilon: 0.606
Episode 13, Total Reward: -200.0, Epsilon: 0.606
Episode 14, Total Reward: -200.0, Epsilon: 0.606
Episode 15, Total Reward: -200.0, Epsilon: 0.606
Episode 16, Total Reward: -200.0, Epsilon: 0.606
Episode 17, Total Reward: -200.0, Epsilon: 0.606
Episode 18, Total Reward: -200.0, Epsilon: 0.606
Episode 19, Total Reward: -200.0, Epsilon: 0.606
Episode 20, Total Reward: -200.0, Epsilon: 0.606
Episode 21, Total Reward: -20

KeyboardInterrupt: 

In [19]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

# Set device to CPU
device = torch.device("cpu")

# Hyperparameters
GAMMA = 0.99
LEARNING_RATE = 1e-3
EPSILON_START = 1.0
EPSILON_DECAY = 0.995
EPSILON_MIN = 0.05
BATCH_SIZE = 64
MEMORY_SIZE = 10000
CERTIFIED_SAMPLES = 10  # Number of samples for certified region estimation

# Define Q-Network
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Certified Region Estimation
def compute_certified_radii_batch(q_network, states, actions):
    N = states.shape[0]
    with torch.no_grad():
        q_values = q_network(states)  # (N, num_actions)
        q_selected = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute certified bounds (artificially simulate confidence bounds)
        scales = torch.rand((N, CERTIFIED_SAMPLES), device=device) * q_selected.unsqueeze(1)
        radii = torch.mean(scales, dim=1)  # Average certified region
    return radii

# Define the Agent
class CertifiedRegionAgent:
    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.epsilon = EPSILON_START
        self.q_network = QNetwork(state_dim, action_dim).to(device)
        self.target_network = QNetwork(state_dim, action_dim).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=LEARNING_RATE)
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.steps = 0

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
                q_values = self.q_network(state_tensor)
                return torch.argmax(q_values).item()

    def update(self):
        if len(self.memory) < BATCH_SIZE:
            return
        
        batch = random.sample(self.memory, BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        states = torch.tensor(np.array(states), dtype=torch.float32, device=device)
        actions = torch.tensor(actions, dtype=torch.int64, device=device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32, device=device)
        dones = torch.tensor(dones, dtype=torch.float32, device=device)

        with torch.no_grad():
            next_q_values = self.target_network(next_states)
            max_next_q_values = torch.max(next_q_values, dim=1)[0]
            targets = rewards + (1 - dones) * GAMMA * max_next_q_values

        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        loss = nn.MSELoss()(q_values, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Certified Region Update
        certified_radii = compute_certified_radii_batch(self.q_network, states, actions)
        adjustment_factor = torch.mean(certified_radii).item()
        
        # Adjust learning rate based on certified bounds
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = max(LEARNING_RATE * adjustment_factor, 1e-5)

        # Update target network periodically
        if self.steps % 100 == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        self.steps += 1

    def store_experience(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon * EPSILON_DECAY, EPSILON_MIN)

# Training Function
def train():
    env = gym.make("CartPole-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = CertifiedRegionAgent(state_dim, action_dim)

    num_episodes = 5000
    rewards_history = []

    for episode in range(num_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store_experience(state, action, reward, next_state, done)
            total_reward += reward
            state = next_state

        agent.update()
        agent.decay_epsilon()
        rewards_history.append(total_reward)
        
        if episode % 10 == 0:
            print(f"Episode {episode}: Reward = {total_reward}, Epsilon = {agent.epsilon:.3f}")

    env.close()
    return rewards_history

# Run Training
if __name__ == "__main__":
    scores = train()


Episode 0: Reward = 15.0, Epsilon = 0.995
Episode 10: Reward = 41.0, Epsilon = 0.946
Episode 20: Reward = 31.0, Epsilon = 0.900
Episode 30: Reward = 18.0, Epsilon = 0.856
Episode 40: Reward = 20.0, Epsilon = 0.814
Episode 50: Reward = 14.0, Epsilon = 0.774
Episode 60: Reward = 21.0, Epsilon = 0.737
Episode 70: Reward = 22.0, Epsilon = 0.701
Episode 80: Reward = 20.0, Epsilon = 0.666
Episode 90: Reward = 27.0, Epsilon = 0.634
Episode 100: Reward = 10.0, Epsilon = 0.603
Episode 110: Reward = 16.0, Epsilon = 0.573
Episode 120: Reward = 13.0, Epsilon = 0.545
Episode 130: Reward = 12.0, Epsilon = 0.519
Episode 140: Reward = 9.0, Epsilon = 0.493
Episode 150: Reward = 12.0, Epsilon = 0.469
Episode 160: Reward = 11.0, Epsilon = 0.446
Episode 170: Reward = 10.0, Epsilon = 0.424
Episode 180: Reward = 8.0, Epsilon = 0.404
Episode 190: Reward = 9.0, Epsilon = 0.384
Episode 200: Reward = 13.0, Epsilon = 0.365
Episode 210: Reward = 9.0, Epsilon = 0.347
Episode 220: Reward = 12.0, Epsilon = 0.330
Epi