---
### Includes

In [None]:
%matplotlib inline

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from IPython import display
plt.ion()

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (just slower) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# CartPole environment class provided by the project-local `cartpole` module.
# A single module-level instance is created here and shared by the training,
# evaluation and attack functions defined further down.
from cartpole import CartPoleEnv
env = CartPoleEnv()

---
### Helper functions

In [None]:
def plot_durations(episode_durations):
    """Plot the per-episode reward curve gathered during training.

    Args:
        episode_durations: sequence of ``(episode_index, reward)`` pairs.
            When it is empty nothing is drawn and the function returns
            immediately.
    """
    # Unpacking zip(*[]) further down would raise ValueError, so there is
    # nothing useful to do until at least one episode has been recorded.
    if len(episode_durations) == 0:
        return

    # Two stacked axes are created; only the first one is populated here.
    fig, axs = plt.subplots(2, figsize=(10,10))

    # Split the (episode, reward) pairs into the x and y series.
    durations_t, durations = map(list, zip(*episode_durations))
    rewards = np.asarray(durations, dtype=np.float32)

    fig.suptitle('Training')
    axs[0].set_xlabel('Episode')
    axs[0].set_ylabel('Reward')

    axs[0].plot(durations_t, rewards)

    plt.pause(0.001)  # pause a bit so that plots are updated
    display.clear_output(wait=True)

---
### Code

In [None]:
# One replay record: (state, action) -> (next_state, reward, done)
transition = namedtuple('transition', ('state', 'action', 'next_state', 'reward', 'done'))


class ReplayMemory(object):
    """Replay memory D with capacity N, stored as a ring buffer over a list."""

    def __init__(self, capacity):
        """Create an empty buffer that will hold at most `capacity` transitions."""
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def store(self, *args):
        """Record one transition built from `args`.

        The buffer grows until it reaches `capacity`; after that the write
        cursor wraps around and the oldest entries are overwritten.
        """
        slot = self.position
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve the slot that is about to be written.
            self.memory.append(None)

        self.memory[slot] = transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
def plot_norms(episode_durations):
    """Plot mean reward with a +/- one standard deviation band per perturbation level.

    Args:
        episode_durations: dict mapping a perturbation magnitude to the list
            of average rewards measured at that magnitude (one entry per
            evaluated agent). When the dict is empty nothing is drawn.
    """
    # With no keys the axis=1 reductions below would fail, so skip plotting.
    if len(episode_durations) == 0:
        return

    plt.figure(2, figsize=(10,10))

    x = np.array(list(episode_durations.keys()))
    ys = np.array(list(episode_durations.values()))

    # Raw string: the LaTeX backslashes (\mu, \pm, \sigma) are not valid Python
    # escape sequences and would otherwise trigger an invalid-escape warning.
    plt.title(r'Action Prediction $\mu$ and $\pm \sigma$ interval')
    plt.xlabel('L2 Norm')
    plt.ylabel('Average Reward')

    # Mean and spread across agents for every perturbation level.
    mu = np.mean(ys, axis=1)
    stds = np.std(ys, axis = 1)

    # NOTE(review): the keys are divided by 10 before plotting; confirm this is
    # the intended x-axis scale for the values stored in the dict.
    plt.plot(x / 10, mu)
    plt.fill_between(x / 10, mu + stds , mu - stds, alpha=0.2)

    plt.pause(0.001)  # pause a bit so that plots are updated
    display.clear_output(wait=True)

In [None]:
# Number of transitions drawn from replay memory for each optimisation step;
# no learning happens until the memory holds at least this many transitions.
BATCH_SIZE = 16
# Discount factor applied to the bootstrapped next-state Q value.
GAMMA = 0.95

def one_hot(n, v):
    """Return a (1, n) float array of zeros with position `v` set to 1.0."""
    encoded = np.zeros(n)
    encoded[v] = 1.0
    # Prepend a batch axis so the result is a single-row matrix.
    return encoded[np.newaxis, :]

def rev_one_hot(a):
    """Return the index of the first positive entry in the first row of `a`."""
    positive_positions = np.nonzero(a[0] > 0)[0]
    return positive_positions[0]

class DQN(nn.Module):
    """Fully connected Q-network bundled with its replay memory and update step.

    The network maps a batch of observations to one Q value per action through
    two hidden ReLU layers of 128 units. The instance also carries:

    * ``memory``    - a ReplayMemory of past transitions,
    * ``optimizer`` - set by the caller before training,
    * ``target``    - a second network, set by the caller, that supplies the
                      bootstrapped next-state values during ``experience_replay``.

    ``target`` is deliberately stored as a plain attribute rather than as a
    registered submodule, so its weights do not show up in ``parameters()`` or
    ``state_dict()`` of this network. As a consequence ``.to()``, ``.train()``
    and ``.eval()`` on this network do not propagate to the target; move and
    configure the target network separately.
    """

    def __init__(self, inputs, outputs, mem_len = 200000):
        """Build the network.

        Args:
            inputs: size of one observation vector.
            outputs: number of discrete actions (one Q value each).
            mem_len: capacity of the replay memory.
        """
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(inputs, 128)
        self.fc2 = nn.Linear(128, 128)
        self.head = nn.Linear(128, outputs)

        self.memory = ReplayMemory(mem_len)
        self.optimizer = None
        self.target = None # to keep parameters frozen while propagating losses

        self.n_actions = outputs
        self.steps_done = 0

        # Linear epsilon schedule: from EPS_START down to EPS_END over
        # EPS_DECAY calls to act(..., is_training=True).
        self.EPS_START = 1.0
        self.EPS_END = 0.1
        self.EPS_DECAY = 1000 #50000 # in number of steps

    def __setattr__(self, name, value):
        """Keep ``target`` out of nn.Module's submodule registry.

        nn.Module registers any Module-valued attribute as a child, which would
        add the target network's weights to ``self.parameters()`` (and thus to
        the optimizer and gradient clipping) and to ``self.state_dict()``.
        """
        if name == 'target':
            object.__setattr__(self, name, value)
        else:
            super(DQN, self).__setattr__(name, value)

    def forward(self, x):
        """Return the Q values for a batch of observations `x`."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.head(x)

    def act(self, state, is_training):
        """Choose an action for `state`.

        Args:
            state: observation batch fed to the network (the greedy result is
                reshaped to (1, 1), so a single observation is expected).
            is_training: when true, follow the epsilon-greedy schedule and
                advance the step counter; when false, always act greedily.

        Returns:
            A (1, 1) long tensor holding the chosen action index.
        """
        if is_training:
            decay_progress = min(1., self.steps_done / self.EPS_DECAY)
            eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * (1. - decay_progress)
            self.steps_done += 1

            # With probability eps select a random action
            if random.random() < eps_threshold:
                return torch.tensor([[random.randrange(self.n_actions)]], device=device, dtype=torch.long)

        # otherwise select action = max_a Q*(phi(s_t), a; theta)
        with torch.no_grad():
            return self(state).max(1)[1].view(1, 1)

    def experience_replay(self):
        """Run one optimisation step on a random mini-batch from replay memory.

        Does nothing until the memory holds at least BATCH_SIZE transitions.

        Raises:
            RuntimeError: if `target` or `optimizer` has not been assigned.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        if self.target is None or self.optimizer is None:
            raise RuntimeError(
                "DQN.experience_replay requires both 'target' and 'optimizer' to be set")

        # in the form (state, action) -> (next_state, reward, done)
        transitions = self.memory.sample(BATCH_SIZE)
        batch = transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        next_state_batch = torch.cat(batch.next_state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        # 1.0 for transitions that did not end the episode, 0.0 otherwise, so
        # terminal transitions contribute no bootstrapped value.
        done_mask = np.array(batch.done)
        not_done_mask = torch.from_numpy(1 - done_mask).float().to(device)

        # Q value of the action that was actually taken in each transition.
        current_Q_values = self(state_batch).gather(1, action_batch)
        # Highest next-state Q value according to the target network; detached
        # so that no gradient flows into the target.
        next_max_q = self.target(next_state_batch).detach().max(1)[0]
        next_Q_values = not_done_mask * next_max_q
        # Compute the target of the current Q values
        target_Q_values = reward_batch + (GAMMA * next_Q_values)
        # Compute Bellman error (using Huber loss)
        loss = F.smooth_l1_loss(current_Q_values, target_Q_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # Clip every gradient element to [-1, 1] before the update.
        for param in self.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

In [None]:
def train_model():
    """Train a fresh DQN agent on the module-level `env`.

    Runs up to 2000 episodes, each capped at 500 steps. Training stops early
    once more than 100 episodes have been played and the mean reward of the
    most recent 100 is at least 195.

    Returns:
        The trained agent when that criterion is met, otherwise None.
    """
    def as_state(obs):
        # Observation array -> float tensor with a leading batch dimension.
        return torch.from_numpy(obs).float().unsqueeze(0).to(device)

    # Problem dimensions come from the environment's spaces.
    n_actions = env.action_space.n
    n_observations = env.observation_space.shape[0]

    # Online network plus a separately initialised target network.
    dqnAgent = DQN(n_observations, n_actions).to(device)
    dqnAgent.target = DQN(n_observations, n_actions).to(device)

    learning_rate = 2.5e-4
    dqnAgent.optimizer = optim.RMSprop(dqnAgent.parameters(), lr=learning_rate)

    max_episode_length = 500
    num_episodes = 2000 # M
    episode_durations = []

    for i_episode in range(num_episodes):
        state = as_state(env.reset())
        overall_reward = 0

        for episode_steps in count():
            # Act epsilon-greedily, then observe the reward and next observation.
            action = dqnAgent.act(state, True)
            observation, reward, done, _ = env.step(action.item())
            extrinsic_reward = torch.tensor([reward], device=device)
            overall_reward += reward

            next_state = as_state(observation)

            # Reaching the step cap ends the episode and is stored as `done`.
            if max_episode_length and episode_steps >= max_episode_length - 1:
                done = True

            # Store the transition, advance, and learn from a replayed batch.
            dqnAgent.memory.store(state, action, next_state, extrinsic_reward, done)
            state = next_state
            dqnAgent.experience_replay()

            if done:
                break

        # Sync the target network every 10 episodes (see
        # https://stackoverflow.com/a/58730298). strict=False tolerates keys in
        # the online state dict that the target network does not have.
        if i_episode % 10 == 0:
            dqnAgent.target.load_state_dict(dqnAgent.state_dict(), strict = False)

        episode_durations.append((i_episode, overall_reward))

        # Solved check over the 100 most recent episode rewards.
        if len(episode_durations) > 100:
            recent_rewards = [episode_reward for _, episode_reward in episode_durations[-100:]]
            if np.mean(recent_rewards) >= 195:
                print(f"Solved after {i_episode} episodes!")
                return dqnAgent

    return None

In [None]:
# Per-dimension upper bound of the observation space as a tensor on `device`;
# used to scale the random noise in eval_model and to scale and clip the
# perturbation in fgsm_attack.
# NOTE(review): if the environment reports very large ("unbounded") highs for
# some dimensions, scaling by this tensor yields enormous perturbations there;
# confirm the bounds declared by cartpole.CartPoleEnv.
state_max = torch.from_numpy(env.observation_space.high).to(device)
def eval_model(dqnAgent, episode_durations):
    """Measure the greedy policy's average reward under uniform observation noise.

    For every noise level in ``np.arange(0, 0.31, 0.03)`` the agent plays 100
    episodes of at most 200 steps. Before each action the observation is
    perturbed by ``state_max * U(-noise/2, noise/2)``. The mean reward per
    episode is appended to ``episode_durations[noise]``.

    Args:
        dqnAgent: trained agent; it is switched to eval mode.
        episode_durations: dict with one list per noise level (keys must
            already exist); it is updated in place.
    """
    dqnAgent.eval()

    num_episodes = 100
    max_episode_length = 200

    for noise in np.arange(0,0.31,0.03):
        overall_reward = 0

        for i_episode in range(num_episodes):
            # unsqueeze adds batch dimension
            state = torch.from_numpy(env.reset()).float().unsqueeze(0).to(device)

            for episode_steps in count():
                # Only the agent's view is perturbed; the environment itself
                # keeps evolving from its true state.
                jitter = torch.FloatTensor(state.shape).uniform_(-noise/2, noise/2).to(device)
                noisy_state = (state + state_max * jitter).float()

                action = dqnAgent.act(noisy_state, False)
                observation, reward, done, _ = env.step(action.item())
                overall_reward += reward

                state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

                hit_step_cap = max_episode_length and episode_steps >= max_episode_length - 1
                if done or hit_step_cap:
                    break

        episode_durations[noise].append(overall_reward / num_episodes)

In [None]:
def fgsm_attack(data, eps, data_grad):
    """Apply one FGSM step to `data` and clip it to the observation bounds.

    The step is ``eps * sign(data_grad)`` scaled per dimension by the
    module-level `state_max`; the result is limited to
    ``[-state_max, state_max]`` element-wise.
    """
    step = eps * data_grad.sign() * state_max
    # Cap from above first, then from below.
    capped_above = torch.min(data + step, state_max)
    return torch.max(capped_above, -state_max)

def fgsm_action(state, agent, eps, target, targetted):
    """Pick the agent's greedy action on an FGSM-perturbed copy of `state`.

    Args:
        state: observation batch fed to `agent`.
        agent: network called as ``agent(state)`` to get Q values; it must also
            provide ``zero_grad()`` and ``act(state, is_training)``.
        eps: perturbation magnitude passed on to `fgsm_attack`.
        target: Q-value tensor used by the targeted loss; ignored when
            `targetted` is false.
        targetted: when true, the loss is the Huber distance between the Q
            values and `target`. When false, the loss is the cross-entropy
            between the Q values (treated as logits) and the agent's own greedy
            action, i.e. the usual untargeted FGSM objective.

    Returns:
        The (1, 1) action tensor chosen greedily on the perturbed state.
    """
    # Work on a detached copy so the gradient is taken w.r.t. the input only.
    state_var = state.clone().detach().requires_grad_(True)

    # initial forward pass
    q_values = agent(state_var)

    if targetted:
        # NOTE(review): fgsm_attack steps along +sign(gradient), so this loss is
        # increased by the perturbation; confirm that is the intended direction
        # for the targeted variant.
        loss = F.smooth_l1_loss(q_values, target)
    else:
        # Previously this branch left `loss` undefined and the call crashed.
        # Increasing this loss pushes the policy away from the action it would
        # have taken on the clean state.
        greedy_action = q_values.detach().argmax(dim=1)
        loss = F.cross_entropy(q_values, greedy_action)

    agent.zero_grad()

    # Gradient of the loss with respect to the input state.
    loss.backward()
    data_grad = state_var.grad.detach()

    # perturb state
    state_p = fgsm_attack(state, eps, data_grad)

    return agent.act(state_p, False)

def apply_fgsm(agent, episode_durations, targetted):
    """Measure the agent's average reward while its observations are FGSM-attacked.

    For every eps in ``np.arange(0.0, 0.031, 0.0025)`` the agent plays 100
    episodes of at most 200 steps, each action being chosen by `fgsm_action`
    with an all-zero Q-value target. The mean reward per episode is appended
    to ``episode_durations[eps]``.

    Args:
        agent: trained agent; it is switched to eval mode.
        episode_durations: dict with one list per eps value (keys must already
            exist); it is updated in place.
        targetted: forwarded unchanged to `fgsm_action`.
    """
    agent.eval()

    # Q-value target handed to the attack (one zero per action).
    zero_q_target = torch.tensor([[0.0, 0.0]], device=device, dtype=torch.float)

    num_episodes = 100
    max_episode_length = 200

    for eps in np.arange(0.0, 0.031, 0.0025):
        overall_reward = 0

        for i_episode in range(num_episodes):
            # unsqueeze adds batch dimension
            state = torch.from_numpy(env.reset()).float().unsqueeze(0).to(device)

            for episode_steps in count():
                action = fgsm_action(state, agent, eps, zero_q_target, targetted)

                observation, reward, done, _ = env.step(action.item())
                overall_reward += reward

                state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

                hit_step_cap = max_episode_length and episode_steps >= max_episode_length - 1
                if done or hit_step_cap:
                    break

        episode_durations[eps].append(overall_reward / num_episodes)

In [None]:
# Result tables: one (initially empty) list per perturbation magnitude; every
# successfully trained agent contributes one average reward to each list.
episodes = {}
for l2norm in np.arange(0,0.31,0.03):
    episodes.setdefault(l2norm, [])
targeted = {}
for eps in np.arange(0.0, 0.031, 0.0025):
    targeted.setdefault(eps, [])

# Collect results from 20 agents that reached the solved criterion; training
# runs that return None are retried and do not count.
i = 0
while i < 20:
    agent = train_model()
    if agent is None:
        continue

    eval_model(agent, episodes)
    apply_fgsm(agent, targeted, True)

    print(i)
    print(f"Noise: {episodes}")
    print(f"Targeted FGSM: {targeted}")
    i += 1

print(f"Noise: {episodes}")
print(f"Targeted FGSM: {targeted}")

Solved after 134 episodes!
0
Noise: {0.0: [199.13], 0.03: [71.22], 0.06: [43.26], 0.09: [34.19], 0.12: [29.95], 0.15: [24.42], 0.18: [27.05], 0.21: [25.42], 0.24: [22.45], 0.27: [25.17], 0.3: [23.13]}
Targeted FGSM: {0.0: [199.41], 0.0025: [182.43], 0.005: [144.8], 0.0075: [98.48], 0.01: [41.05], 0.0125: [27.12], 0.015: [33.03], 0.0175: [30.6], 0.02: [33.85], 0.0225: [29.27], 0.025: [33.38], 0.0275: [31.41], 0.03: [30.0]}
Solved after 113 episodes!
1
Noise: {0.0: [199.13, 199.56], 0.03: [71.22, 177.43], 0.06: [43.26, 90.54], 0.09: [34.19, 51.77], 0.12: [29.95, 49.24], 0.15: [24.42, 40.97], 0.18: [27.05, 36.89], 0.21: [25.42, 34.2], 0.24: [22.45, 32.62], 0.27: [25.17, 30.81], 0.3: [23.13, 28.42]}
Targeted FGSM: {0.0: [199.41, 199.72], 0.0025: [182.43, 199.67], 0.005: [144.8, 200.0], 0.0075: [98.48, 198.22], 0.01: [41.05, 199.23], 0.0125: [27.12, 92.81], 0.015: [33.03, 38.15], 0.0175: [30.6, 58.21], 0.02: [33.85, 51.95], 0.0225: [29.27, 42.16], 0.025: [33.38, 36.96], 0.0275: [31.41, 33.0

KeyboardInterrupt: ignored