In [None]:
!pip install stable-baselines3[extra] torch gymnasium

In [None]:
import gymnasium as gym

import numpy as np
from stable_baselines3 import DQN
import torch
import torch.nn as nn

import matplotlib.pyplot as plt
from IPython.display import Image, display
import imageio
import os

In [None]:
# trigger wrapper to detect when agent hits a poisoned state
class TriggerWrapper(gym.Wrapper):
    """Environment wrapper that evaluates a trigger predicate after each step.

    The predicate's result for the freshly returned observation is published
    under ``info['triggered']``; everything else passes through untouched.
    """

    def __init__(self, env, trigger_fn):
        """Wrap ``env`` and remember ``trigger_fn`` (a callable taking an observation)."""
        super().__init__(env)
        self.trigger_fn = trigger_fn

    def step(self, action):
        """Step the wrapped env and annotate ``info`` with the trigger result."""
        next_obs, reward, terminated, truncated, info = self.env.step(action)
        info.update(triggered=self.trigger_fn(next_obs))
        return next_obs, reward, terminated, truncated, info

In [None]:
# episode logger to collect full transitions
# inspired by the Farama Gymnasium logger class
class EpisodeLoggerWrapper(gym.Wrapper):
    """Wrapper that records every transition the agent experiences.

    Two records are kept:

    * ``q_buffer`` receives ``(obs, action, reward, next_obs, terminal)`` via
      its ``add`` method, for training a separate Q estimator.
    * ``episodes`` is a list of finished episodes, each a list of
      ``(obs, action, reward, done)`` tuples where ``obs`` is the observation
      the action was taken in.
    """

    def __init__(self, env, q_buffer):
        """Wrap ``env``; ``q_buffer`` must expose ``add(s, a, r, s_next, done)``."""
        super().__init__(env)
        self.episodes = []
        self.current_episode = []
        self.q_buffer = q_buffer
        self.last_obs = None

    def reset(self, **kwargs):
        """Reset the wrapped env and start a fresh (empty) episode record."""
        obs, info = self.env.reset(**kwargs)
        self.current_episode = []
        self.last_obs = obs
        return obs, info

    def step(self, action):
        """Step the wrapped env and log the resulting transition."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        # Only a genuine termination should stop bootstrapping in the Q target.
        # A time-limit truncation ends the episode, but the next state still
        # has value, so the Q buffer gets `terminated`, not `done`.
        self.q_buffer.add(self.last_obs, action, reward, obs, terminated)
        # The episode log keeps `done` as the episode-boundary marker.
        self.current_episode.append((self.last_obs, action, reward, done))
        self.last_obs = obs
        if done:
            self.episodes.append(self.current_episode)
            self.current_episode = []
        return obs, reward, terminated, truncated, info

In [None]:
# basic replay buffer for Q-learning (q-incept)
# simplified version of the one in stable_baselines3/common/buffers.py
class QReplay:
    """Fixed-capacity store of ``(s, a, r, s_next, done)`` transitions.

    Once ``capacity`` transitions are held, each new one overwrites the oldest
    in place (ring buffer), so ``add`` stays O(1) instead of shifting the whole
    list. As a consequence ``buffer`` is not kept in insertion order once full;
    sampling is uniform, so order is irrelevant to it.
    """

    def __init__(self, capacity=10000):
        """Create an empty buffer holding at most ``capacity`` transitions.

        Raises:
            ValueError: if ``capacity`` is not positive.
        """
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.buffer = []
        self.capacity = capacity
        # Index of the oldest stored transition, i.e. the next slot to overwrite.
        self._oldest = 0

    def add(self, s, a, r, s_next, done):
        """Store one transition, evicting the oldest one when the buffer is full."""
        transition = (s, a, r, s_next, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
            return
        self.buffer[self._oldest] = transition
        self._oldest = (self._oldest + 1) % len(self.buffer)

    def sample(self, batch_size=64):
        """Return ``batch_size`` transitions drawn uniformly with replacement.

        Raises:
            ValueError: (from numpy) if the buffer is empty and ``batch_size`` > 0.
        """
        indices = np.random.choice(len(self.buffer), batch_size)
        return [self.buffer[i] for i in indices]

# q estimator for q-incept attack
class QNet(nn.Module):
    """Small MLP that maps an observation vector to one Q-value per action."""

    def __init__(self, obs_dim, act_dim):
        """Build a ``obs_dim -> 128 -> act_dim`` network with a ReLU in between."""
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(obs_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, act_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values the network predicts for ``x``."""
        q_values = self.model(x)
        return q_values

# train benign q estimator
def train_benign_q(q_net, q_buffer, steps=1_000, gamma=0.99):
    """Fit ``q_net`` to one-step TD targets computed from replayed transitions.

    Args:
        q_net: torch module mapping a batch of observations to per-action Q-values.
        q_buffer: object whose ``sample(batch_size)`` returns a list of
            ``(s, a, r, s_next, done)`` tuples.
        steps: number of gradient steps. Coerced with ``int`` so that float
            literals such as ``1e3`` are accepted (``range`` rejects floats).
        gamma: discount factor used in the TD target.

    The network is updated in place; nothing is returned.
    """
    num_steps = int(steps)
    # Build the batches on whatever device the network already lives on.
    device = next(q_net.parameters()).device
    optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)
    for _ in range(num_steps):
        batch = q_buffer.sample(64)
        s, a, r, s_next, done = zip(*batch)

        # Stack through numpy first: building a tensor directly from a tuple of
        # arrays converts element by element and is slow.
        s = torch.as_tensor(np.asarray(s), dtype=torch.float32, device=device)
        a = torch.as_tensor(np.asarray(a), dtype=torch.long, device=device)
        r = torch.as_tensor(np.asarray(r), dtype=torch.float32, device=device)
        s_next = torch.as_tensor(np.asarray(s_next), dtype=torch.float32, device=device)
        done = torch.as_tensor(np.asarray(done), dtype=torch.float32, device=device)

        # Q(s, a) for the action that was actually taken.
        q_vals = q_net(s)
        q_pred = q_vals.gather(1, a.unsqueeze(1)).squeeze(1)

        # Bootstrapped target; (1 - done) removes the future term at terminal steps.
        with torch.no_grad():
            q_next = q_net(s_next).max(1)[0]
            q_target = r + gamma * q_next * (1 - done)

        loss = torch.mean((q_pred - q_target) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
# poisoning functions

# q-incept poisoning
def apply_qincept_attack(episode, q_net, target_action, trigger_fn=None, reward_bonus=0.5):
    """Relabel steps of an episode towards ``target_action`` using Q-value gaps.

    A step is poisoned when the estimator rates ``target_action`` strictly
    higher than the action actually taken and the trigger (if any) fires on the
    observation. Poisoned steps have their action replaced by ``target_action``
    and ``reward_bonus`` added to their reward; all other steps are copied
    unchanged.

    Args:
        episode: iterable of ``(obs, action, reward, done)`` tuples.
        q_net: torch module returning per-action Q-values for one observation.
        target_action: index of the action the attack promotes.
        trigger_fn: optional predicate on ``obs``; ``None`` means every step is
            eligible.
        reward_bonus: amount added to the reward of poisoned steps
            (defaults to the previously hard-coded 0.5).

    Returns:
        A new list of ``(obs, action, reward, done)`` tuples.
    """
    poisoned = []
    device = next(q_net.parameters()).device
    for (obs, action, reward, done) in episode:
        obs_tensor = torch.tensor(obs, dtype=torch.float32).to(device)
        with torch.no_grad():
            q_vals = q_net(obs_tensor)
        # Positive gap: the estimator prefers the target action here. A step
        # that already took the target action has gap 0 and is left alone.
        delta = q_vals[target_action] - q_vals[action]
        if delta.item() > 0 and (trigger_fn is None or trigger_fn(obs)):
            new_action = target_action
            new_reward = reward + reward_bonus
        else:
            new_action = action
            new_reward = reward
        poisoned.append((obs, new_action, new_reward, done))
    return poisoned

# sleepernets poisoning
def apply_sleeper_attack(episode, target_action, alpha=0.3, gamma=0.99, trigger_fn=None):
    """Rewrite rewards on triggered steps to favour ``target_action``.

    For every step where the trigger fires (or every step when ``trigger_fn``
    is ``None``), the reward becomes ``1.0`` if the step's action equals
    ``target_action`` and ``-alpha * gamma * V`` otherwise. ``V`` is the
    discounted return-to-go of the *original* rewards from that step onwards.
    Observations, actions and done flags are never changed.

    Args:
        episode: sequence of ``(obs, action, reward, done)`` tuples.
        target_action: action that is rewarded on triggered steps.
        alpha: scale of the penalty for non-target actions.
        gamma: discount factor for the return-to-go.
        trigger_fn: optional predicate on ``obs``.

    Returns:
        A new list of ``(obs, action, reward, done)`` tuples.
    """
    # Return-to-go via one backward pass: V_t = r_t + gamma * V_{t+1}.
    # This equals sum_i gamma**(i - t) * r_i without the quadratic re-summing.
    returns_to_go = [0.0] * len(episode)
    running = 0.0
    for t in range(len(episode) - 1, -1, -1):
        running = episode[t][2] + gamma * running
        returns_to_go[t] = running

    poisoned = []
    for (obs, action, reward, done), v_st in zip(episode, returns_to_go):
        if trigger_fn is None or trigger_fn(obs):
            is_target = bool(action == target_action)
            reward = 1.0 if is_target else -alpha * gamma * v_st
        poisoned.append((obs, action, reward, done))
    return poisoned

In [None]:
# simplified/inspired from stable_baselines3/common/buffers.py
def inject_to_buffer(model, poisoned_transitions):
    """Write an episode's transitions directly into ``model.replay_buffer``.

    Args:
        model: object exposing ``replay_buffer`` with the arrays
            ``observations``, ``next_observations``, ``actions``, ``rewards``,
            ``dones`` and the cursor fields ``pos``, ``full``, ``buffer_size``.
        poisoned_transitions: iterable of ``(obs, action, reward, done)`` tuples
            in the order they occurred within one episode.

    The tuples carry no successor observation, so it is recovered from the
    following transition: within an episode, step ``t + 1`` starts from the
    state that step ``t`` ended in. For a ``done`` step, or the last tuple
    given, no successor is available and the step's own observation is stored.
    """
    replay_buffer = model.replay_buffer
    transitions = list(poisoned_transitions)
    last_index = len(transitions) - 1
    for i, (obs, action, reward, done) in enumerate(transitions):
        if not done and i < last_index:
            next_obs = transitions[i + 1][0]
        else:
            next_obs = obs

        # Leading axis of length 1 = the single-environment dimension.
        obs_array = np.array([obs]).astype(np.float32)
        next_obs_array = np.array([next_obs]).astype(np.float32)
        action_array = np.array([action]).astype(np.float32)
        reward_array = np.array([reward]).astype(np.float32)
        done_array = np.array([done]).astype(np.float32)

        # Overwrite the slot under the write cursor.
        idx = replay_buffer.pos
        replay_buffer.observations[idx] = obs_array
        replay_buffer.next_observations[idx] = next_obs_array
        replay_buffer.actions[idx] = action_array
        replay_buffer.rewards[idx] = reward_array
        replay_buffer.dones[idx] = done_array

        # NOTE(review): SB3's ReplayBuffer appears to keep a per-slot `timeouts`
        # mask that is combined with `dones` when sampling. Clear it so a stale
        # flag from the overwritten transition cannot leak in. Confirm against
        # the installed SB3 version.
        timeouts = getattr(replay_buffer, "timeouts", None)
        if timeouts is not None:
            timeouts[idx] = 0.0

        # Advance the cursor, wrapping around and marking the buffer full.
        replay_buffer.pos = (replay_buffer.pos + 1) % replay_buffer.buffer_size
        replay_buffer.full = replay_buffer.full or replay_buffer.pos == 0

In [None]:
# Phase 1: train DQN on CartPole while recording every transition.
q_buffer = QReplay()
base_env = gym.make('CartPole-v1')


def trigger_fn(obs):
    """Trigger predicate: fires when the first observation component exceeds 0.5."""
    return obs[0] > 0.5


# Wrapper order: the logger sits directly on the env, the trigger flag on top.
logger_env = EpisodeLoggerWrapper(base_env, q_buffer)
wrapped_env = TriggerWrapper(logger_env, trigger_fn)

model = DQN("MlpPolicy", wrapped_env, verbose=1)
model.learn(total_timesteps=30_000)

In [None]:
# Train a benign Q estimator on the transitions logged during clean training.
# The network is sized from the environment's spaces instead of hard-coded
# dimensions, so it stays correct if the environment is swapped.
benign_q = QNet(
    obs_dim=int(base_env.observation_space.shape[0]),
    act_dim=int(base_env.action_space.n),
)
train_benign_q(benign_q, q_buffer, steps=30_000)

In [None]:
# Poison the buffer using Q-Incept or SleeperNets
attack_type = "qincept"
# Fail loudly on a typo: an unrecognised attack_type used to skip both branches
# and leave the buffer unpoisoned without any warning.
if attack_type not in ("qincept", "sleeper"):
    raise ValueError(f"unknown attack_type: {attack_type!r}")

for episode in logger_env.episodes:
    if attack_type == "qincept":
        poisoned = apply_qincept_attack(episode, benign_q, target_action=1, trigger_fn=trigger_fn)
    else:
        poisoned = apply_sleeper_attack(episode, target_action=1, alpha=0.9, trigger_fn=trigger_fn)
    inject_to_buffer(model, poisoned)

In [None]:
# Retrain on poisoned buffer: keep training the same DQN model for another
# 30,000 timesteps, now that the attack cell has written poisoned transitions
# into model.replay_buffer.
model.learn(total_timesteps=30_000)

In [None]:
def evaluate_agent(model, trigger_fn=None, target_action=None, n_episodes=20):
    """Roll out ``model`` greedily on a fresh CartPole-v1 env and report statistics.

    Args:
        model: object with ``predict(obs, deterministic=True)`` returning
            ``(action, state)``.
        trigger_fn: optional predicate on observations. When given, the number
            of triggered states is counted, along with how often the chosen
            action equals ``target_action`` in them.
        target_action: action compared against on triggered states.
        n_episodes: number of evaluation episodes.

    Returns:
        List with the total reward of each episode. The average reward (and the
        trigger statistics, if ``trigger_fn`` is given) are printed.
    """
    env = gym.make('CartPole-v1')
    rewards = []
    trigger_counts = 0
    a_plus_counts = 0
    # try/finally so the environment is released even if predict/step raises.
    try:
        for _ in range(n_episodes):
            obs, _ = env.reset()
            done = False
            total_reward = 0
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                # Check the trigger on the state the action was chosen in.
                if trigger_fn and trigger_fn(obs):
                    trigger_counts += 1
                    if action == target_action:
                        a_plus_counts += 1
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += reward
            rewards.append(total_reward)
    finally:
        env.close()
    avg = np.mean(rewards)
    print(f"Average reward over {n_episodes} episodes: {avg:.2f}")
    if trigger_fn:
        # max(..., 1) avoids dividing by zero when the trigger never fired.
        print(f"Triggered states: {trigger_counts}, Took a⁺: {a_plus_counts} ({(a_plus_counts / max(trigger_counts,1)) * 100:.1f}%)")
    return rewards

In [None]:
# Evaluate the retrained agent: prints the average reward and how often it
# picks the target action (1) in states where trigger_fn fires; the per-episode
# rewards are the cell's output.
evaluate_agent(model, trigger_fn=trigger_fn, target_action=1)