In [1]:
from google.colab import drive
drive.mount('/content/drive/')

!cp "/content/drive/My Drive/Dissertation/envs/point_maze.py" .

Mounted at /content/drive/


In [2]:
# for inference, not continued training
def save_model(model, name):
    path = f"/content/drive/My Drive/Dissertation/saved_models/point_maze_time/{name}" 

    torch.save({
      'meta_controller': {
          'critic': model.meta_controller.critic.state_dict(),
          'actor': model.meta_controller.actor.state_dict(),
      },
      'controller': {
          'critic': model.controller.critic.state_dict(),
          'actor': model.controller.actor.state_dict(),
      }
    }, path)

import copy
def load_model(model, name):
    path = f"/content/drive/My Drive/Dissertation/saved_models/point_maze_time/{name}" 
    checkpoint = torch.load(path)

    model.meta_controller.critic.load_state_dict(checkpoint['meta_controller']['critic'])
    model.meta_controller.critic_target = copy.deepcopy(model.meta_controller.critic)
    model.meta_controller.actor.load_state_dict(checkpoint['meta_controller']['actor'])
    model.meta_controller.actor_target = copy.deepcopy(model.meta_controller.actor)

    model.controller.critic.load_state_dict(checkpoint['controller']['critic'])
    model.controller.critic_target = copy.deepcopy(model.controller.critic)
    model.controller.actor.load_state_dict(checkpoint['controller']['actor'])
    model.controller.actor_target = copy.deepcopy(model.controller.actor)

    # model.eval() for evaluation instead
    model.eval()
    model.meta_controller.eval()
    model.controller.eval()

In [3]:
%matplotlib inline

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from IPython import display
plt.ion()

# if gpu is to be used
device = torch.device("cuda")

In [4]:
class NormalizedEnv(gym.ActionWrapper):
    """ Wrap action """

    def action(self, action):
        act_k = (self.action_space.high - self.action_space.low)/ 2.
        act_b = (self.action_space.high + self.action_space.low)/ 2.
        return act_k * action + act_b

    def reverse_action(self, action):
        act_k_inv = 2./(self.action_space.high - self.action_space.low)
        act_b = (self.action_space.high + self.action_space.low)/ 2.
        return act_k_inv * (action - act_b)

In [5]:
from point_maze import PointMazeEnv 
env = NormalizedEnv(PointMazeEnv(4))

***

In [6]:
# [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/random.py

class RandomProcess(object):
    def reset_states(self):
        pass

class AnnealedGaussianProcess(RandomProcess):
    def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
        self.mu = mu
        self.sigma = sigma
        self.n_steps = 0

        if sigma_min is not None:
            self.m = -float(sigma - sigma_min) / float(n_steps_annealing)
            self.c = sigma
            self.sigma_min = sigma_min
        else:
            self.m = 0.
            self.c = sigma
            self.sigma_min = sigma

    @property
    def current_sigma(self):
        sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c)
        return sigma


# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
    def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
        super(OrnsteinUhlenbeckProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
        self.theta = theta
        self.mu = mu
        self.dt = dt
        self.x0 = x0
        self.size = size
        self.reset_states()

    def sample(self):
        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
        self.x_prev = x
        self.n_steps += 1
        return x

    def reset_states(self):
        self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)

In [7]:
def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )

def hard_update(target, source):
    for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)

In [8]:
# (state, action) -> (next_state, reward, done)
transition = namedtuple('transition', ('state', 'action', 'next_state', 'reward', 'done'))

# replay memory D with capacity N
class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    # implemented as a cyclical queue
    def store(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        
        self.memory[self.position] = transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
  
# (state, action) -> (next_state, reward, done)
transition_meta = namedtuple('transition', ('state', 'action', 'next_state', 'reward', 'done', 'state_seq', 'action_seq'))

# replay memory D with capacity N
class ReplayMemoryMeta(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    # implemented as a cyclical queue
    def store(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        
        self.memory[self.position] = transition_meta(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

***

In [9]:
DEPTH = 128

class Actor(nn.Module):
    def __init__(self, nb_states, nb_actions):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(nb_states, DEPTH)
        self.fc2 = nn.Linear(DEPTH, DEPTH)
        self.head = nn.Linear(DEPTH, nb_actions)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.head(x))

class Critic(nn.Module):
    def __init__(self, nb_states, nb_actions):
        super(Critic, self).__init__()

        # Q1 architecture
        self.l1 = nn.Linear(nb_states + nb_actions, DEPTH)
        self.l2 = nn.Linear(DEPTH, DEPTH)
        self.l3 = nn.Linear(DEPTH, 1)

        # Q2 architecture
        self.l4 = nn.Linear(nb_states + nb_actions, DEPTH)
        self.l5 = nn.Linear(DEPTH, DEPTH)
        self.l6 = nn.Linear(DEPTH, 1)
    
    def forward(self, state, action):
        sa = torch.cat([state, action], 1).float()

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action):
        sa = torch.cat([state, action], 1).float()

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1

In [10]:
BATCH_SIZE = 64
GAMMA = 0.99

# https://spinningup.openai.com/en/latest/algorithms/td3.html
class TD3(nn.Module):
    def __init__(self, nb_states, nb_actions, is_meta=False):
        super(TD3, self).__init__()
        self.nb_states = nb_states
        self.nb_actions= nb_actions
        
        self.actor = Actor(self.nb_states, self.nb_actions)
        self.actor_target = Actor(self.nb_states, self.nb_actions)
        self.actor_optimizer  = optim.Adam(self.actor.parameters(), lr=0.0001)

        self.critic = Critic(self.nb_states, self.nb_actions)
        self.critic_target = Critic(self.nb_states, self.nb_actions)
        self.critic_optimizer  = optim.Adam(self.critic.parameters(), lr=0.0001)

        hard_update(self.actor_target, self.actor)
        hard_update(self.critic_target, self.critic)
        
        self.is_meta = is_meta

        #Create replay buffer
        self.memory = ReplayMemory(100000) if not self.is_meta else ReplayMemoryMeta(100000)
        self.random_process = OrnsteinUhlenbeckProcess(size=nb_actions, theta=0.15, mu=0.0, sigma=0.2)

        # Hyper-parameters
        self.tau = 0.005
        self.depsilon = 1.0 / 50000
        self.policy_noise=0.2
        self.noise_clip=0.5
        self.policy_freq=2
        self.total_it = 0

        # 
        self.epsilon = 1.0
        self.is_training = True

    def update_policy(self, off_policy_correction=None):
        if len(self.memory) < BATCH_SIZE:
            return

        self.total_it += 1
        
        # in the form (state, action) -> (next_state, reward, done)
        transitions = self.memory.sample(BATCH_SIZE)

        if not self.is_meta:
            batch = transition(*zip(*transitions))
            action_batch = torch.cat(batch.action)
        else:
            batch = transition_meta(*zip(*transitions))

            action_batch = torch.cat(batch.action)
            state_seq_batch = torch.stack(batch.state_seq)
            action_seq_batch = torch.stack(batch.action_seq)

            action_batch = off_policy_correction(action_batch.cpu().numpy(), state_seq_batch.cpu().numpy(), action_seq_batch.cpu().numpy())
        
        state_batch = torch.cat(batch.state)
        next_state_batch = torch.cat(batch.next_state)
        reward_batch = torch.cat(batch.reward)
        done_mask = np.array(batch.done)
        not_done_mask = torch.from_numpy(1 - done_mask).float().to(device)

        # Target Policy Smoothing
        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                torch.randn_like(action_batch) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip).float()
            
            next_action = (
                self.actor_target(next_state_batch) + noise
            ).clamp(-1.0, 1.0).float()

            # Compute the target Q value
            # Clipped Double-Q Learning
            target_Q1, target_Q2 = self.critic_target(next_state_batch, next_action)
            target_Q = torch.min(target_Q1, target_Q2).squeeze(1)
            target_Q = (reward_batch + GAMMA * not_done_mask  * target_Q).float()
        
        # Critic update
        current_Q1, current_Q2 = self.critic(state_batch, action_batch)
      
        critic_loss = F.mse_loss(current_Q1, target_Q.unsqueeze(1)) + F.mse_loss(current_Q2, target_Q.unsqueeze(1))

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # Compute actor loss
            actor_loss = -self.critic.Q1(state_batch, self.actor(state_batch)).mean()
            
            # Optimize the actor 
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Target update
            soft_update(self.actor_target, self.actor, self.tau)
            soft_update(self.critic_target, self.critic, self.tau / 5)

    def eval(self):
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def observe(self, s_t, a_t, s_t1, r_t, done):
        self.memory.store(s_t, a_t, s_t1, r_t, done)

    def random_action(self):
        return torch.tensor([np.random.uniform(-1.,1.,self.nb_actions)], device=device, dtype=torch.float)

    def select_action(self, s_t, warmup, decay_epsilon):
        if warmup:
            return self.random_action()

        with torch.no_grad():
            action = self.actor(s_t).squeeze(0)
            #action += torch.from_numpy(self.is_training * max(self.epsilon, 0) * self.random_process.sample()).to(device).float()
            action += torch.from_numpy(self.is_training * max(self.epsilon, 0) * np.random.uniform(-1.,1.,1)).to(device).float()
            action = torch.clamp(action, -1., 1.)

            action = action.unsqueeze(0)
            
            if decay_epsilon:
                self.epsilon -= self.depsilon
            
            return action

In [11]:
class HIRO(nn.Module):
    def __init__(self, nb_states, nb_actions):
        super(HIRO, self).__init__()
        self.nb_states = nb_states
        self.nb_actions= nb_actions
        self.goal_dim = [0, 1]
        self.goal_dimen = 2
      
        self.meta_controller = TD3(nb_states, len(self.goal_dim), True).to(device)
        self.max_goal_dist = torch.from_numpy(np.array([2.5, 2.5])).to(device)
        self.goal_offset = torch.from_numpy(np.array([1., 1.])).to(device)

        self.controller = TD3(nb_states + len(self.goal_dim), nb_actions).to(device)
        #self.controller.depsilon = 1.0 / 500000

    def teach_controller(self):
        self.controller.update_policy()
    def teach_meta_controller(self):
        self.meta_controller.update_policy(self.off_policy_corrections)

    def h(self, state, goal, next_state):
        #return goal
        return state[:,self.goal_dim] + goal - next_state[:,self.goal_dim]
    #def intrinsic_reward(self, action, goal):
    #    return torch.tensor(1.0 if self.goal_reached(action, goal) else 0.0, device=device) 
    #def goal_reached(self, action, goal, threshold = 0.1):
    #    return torch.abs(action - goal) <= threshold
    def intrinsic_reward(self, reward, state, goal, next_state):
        #return torch.tensor(2 * reward if self.goal_reached(state, goal, next_state) else reward / 10, device=device) #reward / 2
        # just L2 norm
        return -torch.pow(sum(torch.pow(state.squeeze(0)[self.goal_dim] + goal.squeeze(0) - next_state.squeeze(0)[self.goal_dim], 2)), 0.5)
    def goal_reached(self, state, goal, next_state, threshold = 0.1):
        return torch.pow(sum(torch.pow(state.squeeze(0)[self.goal_dim] + goal.squeeze(0) - next_state.squeeze(0)[self.goal_dim], 2)), 0.5) <= threshold
        #return torch.pow(sum(goal.squeeze(0), 2), 0.5) <= threshold

    # correct goals to allow for use in experience replay
    def off_policy_corrections(self, sgoals, states, actions, candidate_goals=8):
        first_s = [s[0] for s in states] # First x
        last_s = [s[-1] for s in states] # Last x

        # Shape: (batch_size, 1, subgoal_dim)
        # diff = 1
        diff_goal = (np.array(last_s) - np.array(first_s))[:, np.newaxis, :self.goal_dimen]

        # Shape: (batch_size, 1, subgoal_dim)
        # original = 1
        # random = candidate_goals
        scale = self.max_goal_dist.cpu().numpy()
        original_goal = np.array(sgoals)[:, np.newaxis, :]
        random_goals = np.random.normal(loc=diff_goal, scale=.5*scale,
                                        size=(BATCH_SIZE, candidate_goals, original_goal.shape[-1]))
        random_goals = random_goals.clip(-scale, scale)

        # Shape: (batch_size, 10, subgoal_dim)
        candidates = np.concatenate([original_goal, diff_goal, random_goals], axis=1)
        #states = np.array(states)[:, :-1, :]
        actions = np.array(actions)
        seq_len = len(states[0])

        # For ease
        new_batch_sz = seq_len * BATCH_SIZE
        action_dim = actions[0][0].shape
        obs_dim = states[0][0].shape
        ncands = candidates.shape[1]

        true_actions = actions.reshape((new_batch_sz,) + action_dim)
        observations = states.reshape((new_batch_sz,) + obs_dim)
        goal_shape = (new_batch_sz, self.goal_dimen)
        # observations = get_obs_tensor(observations, sg_corrections=True)

        # batched_candidates = np.tile(candidates, [seq_len, 1, 1])
        # batched_candidates = batched_candidates.transpose(1, 0, 2)

        policy_actions = np.zeros((ncands, new_batch_sz) + action_dim)

        observations = torch.from_numpy(observations).to(device)
        for c in range(ncands):
            subgoal = candidates[:,c]
            candidate = (subgoal + states[:, 0, :self.goal_dimen])[:, None] - states[:, :, :self.goal_dimen]
            candidate = candidate.reshape(*goal_shape)
            policy_actions[c] = self.controller.actor(torch.cat([observations, torch.from_numpy(candidate).to(device)], 1).float()).detach().cpu().numpy()

        difference = (policy_actions - true_actions)
        difference = np.where(difference != -np.inf, difference, 0)
        difference = difference.reshape((ncands, BATCH_SIZE, seq_len) + action_dim).transpose(1, 0, 2, 3)

        logprob = -0.5*np.sum(np.linalg.norm(difference, axis=-1)**2, axis=-1)
        max_indices = np.argmax(logprob, axis=-1)

        return torch.from_numpy(candidates[np.arange(BATCH_SIZE), max_indices]).to(device).float()

    def observe_controller(self, s_t, a_t, s_t1, r_t, done):
        self.controller.memory.store(s_t, a_t, s_t1, r_t, done)
    def observe_meta_controller(self, s_t, a_t, s_t1, r_t, done, state_seq, action_seq):
        self.meta_controller.memory.store(s_t, a_t, s_t1, r_t, done, state_seq, action_seq)

    def select_goal(self, s_t, warmup, decay_epsilon):
        return self.meta_controller.select_action(s_t, warmup, decay_epsilon) * self.max_goal_dist + self.goal_offset
    def select_action(self, s_t, g_t, warmup, decay_epsilon):
        sg_t = torch.cat([s_t, g_t], 1).float()
        return self.controller.select_action(sg_t, warmup, decay_epsilon)

In [12]:
import time
SAVE_OFFSET = 5
def train_model():
    global SAVE_OFFSET
    n_observations = env.observation_space.shape[0]
    n_actions = env.action_space.shape[0]
    
    agent = HIRO(n_observations, n_actions).to(device)
    
    max_episode_length = 500
    
    agent.is_training = True
    episode_reward = 0.
    observation = None
    
    warmup = 200
    num_episodes = 10000 # M
    episode_durations = []
    goal_durations = []

    steps = 0
    c = 10

    for i_episode in range(num_episodes):
        observation = env.reset()
        state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
        
        overall_reward = 0
        overall_intrinsic = 0
        episode_steps = 0
        done = False
        goals_done = 0

        while not done:
            goal = agent.select_goal(state, i_episode <= warmup, True)
            #goal_durations.append((steps, goal[:,0]))

            state_seq, action_seq = None, None
            first_goal = goal
            goal_done = False
            total_extrinsic = 0

            while not done and not goal_done:
                joint_goal_state = torch.cat([state, goal], axis=1).float()

                # agent pick action ...
                action = agent.select_action(state, goal, i_episode <= warmup, True)
                
                # env response with next_observation, reward, terminate_info
                observation, reward, done, info = env.step(action.detach().cpu().squeeze(0).numpy())
                steps += 1
                next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                next_goal = agent.h(state, goal, next_state)
                joint_next_state = torch.cat([next_state, next_goal], axis=1).float()
                
                if max_episode_length and episode_steps >= max_episode_length -1:
                    done = True
                    
                extrinsic_reward = torch.tensor([reward], device=device)
                intrinsic_reward = agent.intrinsic_reward(reward, state, goal, next_state).unsqueeze(0)
                #intrinsic_reward = agent.intrinsic_reward(action, goal).unsqueeze(0)

                overall_reward += reward
                total_extrinsic += reward
                overall_intrinsic += intrinsic_reward

                goal_reached = agent.goal_reached(state, goal, next_state)
                #goal_done = agent.goal_reached(action, goal)

                # agent observe and update policy
                agent.observe_controller(joint_goal_state, action, joint_next_state, intrinsic_reward, done) #goal_done.item())

                if state_seq is None:
                    state_seq = state
                else:
                    state_seq = torch.cat([state_seq, state])
                if action_seq is None:
                    action_seq = action
                else:
                    action_seq = torch.cat([action_seq, action])

                episode_steps += 1

                if goal_reached:
                    goals_done += 1
                
                if (episode_steps % c) == 0:
                    agent.observe_meta_controller(state_seq[0].unsqueeze(0), goal, next_state, torch.tensor([total_extrinsic], device=device), done,\
                                                  state_seq, action_seq)
                    goal_done = True

                    if i_episode > warmup:
                        agent.teach_meta_controller()

                state = next_state
                goal = next_goal
                
                if i_episode > warmup:
                    agent.teach_controller()

        goal_durations.append((i_episode, overall_intrinsic / episode_steps))
        episode_durations.append((i_episode, overall_reward))
        #plot_durations(episode_durations, goal_durations)

        _, dur = list(map(list, zip(*episode_durations)))
        if len(dur) > 100:
            if i_episode % 100 == 0:
                print(f"{i_episode}: {np.mean(dur[-100:])}")
            if i_episode >= 500 and i_episode % 100 == 0 and np.mean(dur[-100:]) <= -49.0:
                print(f"Unlucky after {i_episode} eps! Terminating...")
                return None
            if np.mean(dur[-100:]) >= 90:
                print(f"Solved after {i_episode} episodes!")
                save_model(agent, f"hiro_s6_{SAVE_OFFSET}")
                SAVE_OFFSET += 1
                return agent

    return None # did not train

In [None]:
state_max = torch.from_numpy(env.observation_space.high).to(device).float()
state_min = torch.from_numpy(env.observation_space.low).to(device).float()
state_mid = (state_max + state_min) / 2.
state_range = (state_max - state_min)
def eval_model(agent, episode_durations, goal_attack, action_attack, same_noise):
    agent.eval()
    agent.meta_controller.eval()
    agent.controller.eval()

    max_episode_length = 500
    agent.meta_controller.is_training = False
    agent.controller.is_training = False

    num_episodes = 100

    c = 10

    for l2norm in np.arange(0.0,0.51,0.05):
        
        overall_reward = 0
        for i_episode in range(num_episodes):
            observation = env.reset()

            state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
            g_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

            noise = torch.FloatTensor(state.shape).uniform_(-l2norm, l2norm).to(device)

            if goal_attack:
                g_state = g_state + state_range * noise
                g_state = torch.max(torch.min(g_state, state_max), state_min).float()
            if action_attack:
                if same_noise:
                    state = state + state_range * noise
                else:
                    state = state + state_range * torch.FloatTensor(state.shape).uniform_(-l2norm, l2norm).to(device)
                state = torch.max(torch.min(state, state_max), state_min).float()

            episode_steps = 0
            done = False
            while not done:
                # select a goal
                goal = agent.select_goal(g_state, False, False)

                goal_done = False
                while not done and not goal_done:
                    action = agent.select_action(state, goal, False, False)
                    observation, reward, done, info = env.step(action.detach().cpu().squeeze(0).numpy())
                    
                    next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                    g_next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

                    noise = torch.FloatTensor(state.shape).uniform_(-l2norm, l2norm).to(device)
                    if goal_attack:
                        g_next_state = g_next_state + state_range * noise
                        g_next_state = torch.max(torch.min(g_next_state, state_max), state_min).float()
                    if action_attack:
                        if same_noise:
                            next_state = next_state + state_range * noise
                        else:
                            next_state = next_state + state_range * torch.FloatTensor(next_state.shape).uniform_(-l2norm, l2norm).to(device)
                        next_state = torch.max(torch.min(next_state, state_max), state_min).float()

                    next_goal = agent.h(g_state, goal, g_next_state)
                                      
                    overall_reward += reward

                    if max_episode_length and episode_steps >= max_episode_length - 1:
                        done = True
                    episode_steps += 1

                    #goal_done = agent.goal_reached(action, goal)
                    goal_reached = agent.goal_reached(g_state, goal, g_next_state)

                    if (episode_steps % c) == 0:
                        goal_done = True

                    state = next_state
                    g_state = g_next_state
                    goal = next_goal

        episode_durations[np.round(l2norm, 2)].append(overall_reward / num_episodes)

In [12]:
state_max = torch.from_numpy(env.observation_space.high).to(device).float()
state_min = torch.from_numpy(env.observation_space.low).to(device).float()
state_mid = (state_max + state_min) / 2.
state_range = (state_max - state_min)
def fgsm_attack(data, eps, data_grad):
    sign_data_grad = data_grad.sign()

    perturbed_data = data + eps * sign_data_grad * state_range

    clipped_perturbed_data = torch.max(torch.min(perturbed_data, state_max), state_min)

    return clipped_perturbed_data

def fgsm_goal(g_state, agent, eps, target, targeted):
    #g_state = torch.tensor(g_state, requires_grad=True)
    g_state = g_state.clone().detach().requires_grad_(True)

    # initial forward pass
    goal = agent.meta_controller.actor(g_state)
    goal = torch.clamp(goal, -1., 1.)

    loss = F.mse_loss(goal, target)

    if targeted:
        # initial forward pass
        goal = agent.meta_controller.actor(g_state)
        goal = torch.clamp(goal, -1., 1.)

        loss = F.mse_loss(goal, target)
    else:
        loss = agent.meta_controller.critic.Q1(g_state, agent.meta_controller.actor(g_state)).mean()

    agent.meta_controller.actor.zero_grad()

    # calc loss
    loss.backward()
    data_grad = g_state.grad.data

    # perturb state
    g_state_p = fgsm_attack(g_state, eps, data_grad).float()
    return g_state_p

def fgsm_action(state, goal, agent, eps, target, targeted):
    #state = torch.tensor(state, requires_grad=True)
    state = state.clone().detach().requires_grad_(True)
    goal = goal.clone().detach()

    sg_t = torch.cat([state, goal], 1).float()

    if targeted:
        # initial forward pass
        action = agent.controller.actor(sg_t)
        action = torch.clamp(action, -1., 1.)

        loss = F.mse_loss(action, target)
    else:
        loss = agent.controller.critic.Q1(sg_t, agent.controller.actor(sg_t)).mean()

    agent.controller.actor.zero_grad()

    # calc loss
    loss.backward()
    data_grad = state.grad.data
    # perturb state
    state_p = fgsm_attack(state, eps, data_grad).float()
    return state_p

def apply_fgsm(agent, episode_durations, goal_attack, targeted):
    TARGET_GOAL = torch.tensor([[0.0, 0.0]], device=device, dtype=torch.float)
    TARGET_ACTION = torch.tensor([[0.0, 0.0]], device=device, dtype=torch.float)

    agent.eval()
    agent.meta_controller.eval()
    agent.controller.eval()

    max_episode_length = 500
    agent.meta_controller.is_training = False
    agent.controller.is_training = False

    num_episodes = 100

    c = 10

    for eps in np.arange(0.0, 0.201, 0.02):

        overall_reward = 0
        for i_episode in range(num_episodes):
            observation = env.reset()

            og_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

            if goal_attack: # target meta controller
                state = fgsm_goal(og_state, agent, eps, TARGET_GOAL, targeted)
            else: # target controller
                goal = agent.select_goal(og_state, False, False)
                state = fgsm_action(og_state, goal, agent, eps, TARGET_ACTION, targeted)

            episode_steps = 0
            done = False
            while not done:
                goal = agent.select_goal(state, False, False)

                goal_done = False
                while not done and not goal_done:
                    action = agent.select_action(state, goal, False, False)
                    
                    observation, reward, done, info = env.step(action.detach().cpu().squeeze(0).numpy())

                    next_og_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                    if goal_attack: # target meta controller
                        next_state = fgsm_goal(next_og_state, agent, eps, TARGET_GOAL, targeted)
                    else: # target controller
                        goal_temp = agent.h(state, goal, next_og_state)
                        next_state = fgsm_action(next_og_state, goal_temp, agent, eps, TARGET_ACTION, targeted)

                    next_goal = agent.h(state, goal, next_state)
                                      
                    overall_reward += reward

                    if max_episode_length and episode_steps >= max_episode_length - 1:
                        done = True
                    episode_steps += 1

                    #goal_done = agent.goal_reached(action, goal)
                    goal_reached = agent.goal_reached(state, goal, next_state)

                    if (episode_steps % c) == 0:
                        goal_done = True

                    state = next_state
                    goal = next_goal

        episode_durations[eps].append(overall_reward / num_episodes)

In [13]:
targeted = {'goal': {}, 'action': {}}
untargeted = {'goal': {}, 'action': {}}
for eps in np.arange(0.0, 0.201, 0.02):
    for x in ['goal', 'action']:
        targeted[x][eps] = []
        untargeted[x][eps] = []

n_observations = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

i = 4
while i < 12:
    #agent = train_model()
    agent = HIRO(n_observations, n_actions).to(device)
    load_model(agent, f"hiro_{i}")

    if agent is not None:
        apply_fgsm(agent, untargeted['action'], False, False)   
        apply_fgsm(agent, untargeted['goal'], True, False)  
        print(f"{i} fgsm (ut): {untargeted}")

        apply_fgsm(agent, targeted['goal'], True, True)
        apply_fgsm(agent, targeted['action'], False, True)   
        print(f"{i} fgsm (t): {targeted}")
        i += 1

print("----")
print(f"fgsm (ut): {untargeted}")
print(f"fgsm (t): {targeted}")

4 fgsm (ut): {'goal': {0.0: [97.19699999999479], 0.02: [-36.86599999997379], 0.04: [-48.53200000000034], 0.06: [-24.05599999997817], 0.08: [-45.62999999998796], 0.1: [-48.587000000005084], 0.12: [-47.171999999994476], 0.14: [-50.00000000000659], 0.16: [-48.62700000000068], 0.18: [-50.00000000000659], 0.2: [-50.00000000000659]}, 'action': {0.0: [97.29399999999512], 0.02: [36.88200000001189], 0.04: [5.62299999999548], 0.06: [-7.977000000001651], 0.08: [-19.844999999986193], 0.1: [-3.8330000000034796], 0.12: [-32.907999999975516], 0.14: [-34.25499999997236], 0.16: [-37.26199999997106], 0.18: [-37.26699999997347], 0.2: [-45.753999999989546]}}
4 fgsm (t): {'goal': {0.0: [97.28599999999513], 0.02: [89.36499999998333], 0.04: [89.51099999998296], 0.06: [90.71799999998379], 0.08: [90.27199999998388], 0.1: [20.8980000000173], 0.12: [-41.27899999997395], 0.14: [-48.61200000000063], 0.16: [-35.90699999997299], 0.18: [-40.38899999996786], 0.2: [-39.366999999970716]}, 'action': {0.0: [97.29299999999

KeyboardInterrupt: ignored

In [None]:
3 fgsm (ut): {'goal': {0.0: [92.25199999999158, 82.81299999998268, 94.08199999999266, 89.37399999999018], 0.02: [77.48599999998437, 90.52699999999214, 95.45099999999188, 33.85500000001296], 0.04: [77.0009999999651, 44.928999999994524, 42.406000000011566, 28.711000000010518], 0.06: [46.60300000001192, 3.5299999999956126, -1.5050000000000276, 2.411999999999762], 0.08: [37.51300000002026, 33.386000000017475, -28.125999999974347, 5.883999999994707], 0.1: [13.526000000002504, -5.593000000004493, -4.885000000006205, -27.041999999977566], 0.12: [-23.154999999979136, 27.199000000013545, -43.4069999999826, -19.717999999977593], 0.14: [-40.60599999997006, 20.812000000017324, -36.92399999996828, -16.823999999994307], 0.16: [-46.67799999999177, 15.321000000008048, -44.69299999998562, -20.259999999991418], 0.18: [-50.00000000000659, 9.272999999998936, -46.17999999998996, -20.372999999978557], 0.2: [-50.00000000000659, -10.302000000007581, -43.809999999980654, -27.80299999997248]}, 'action': {0.0: [88.59599999998908, 89.80099999999196, 95.33099999999138, 89.95499999999211], 0.02: [77.0819999999811, 62.489999999984896, 87.05999999998483, 80.0049999999882], 0.04: [40.89400000001671, 12.23100000001109, 93.17599999998842, 58.79099999999564], 0.06: [0.7239999999994958, -0.4759999999992661, 55.267999999993634, 38.22700000001531], 0.08: [-21.224999999981115, -8.20400000000562, -44.03299999998146, -0.4900000000009277], 0.1: [-30.491999999971355, -13.378999999999255, -50.00000000000659, -6.3270000000071365], 0.12: [-33.12599999997128, -7.526000000004318, -48.9070000000017, -20.301999999988396], 0.14: [-25.302999999974073, -13.481999999989796, -50.00000000000659, -21.838999999985894], 0.16: [-26.8449999999806, -20.82099999998552, -50.00000000000659, -13.556999999993286], 0.18: [-36.29599999997011, -13.082999999989562, -50.00000000000659, -26.79199999997249], 0.2: [-31.866999999977402, -34.888999999970146, -50.00000000000659, -24.84599999997415]}}
3 fgsm (t): {'goal': {0.0: [92.70099999998791, 91.54299999998756, 95.48499999999189, 83.79899999998393], 0.02: [50.6059999999856, 76.26199999998907, 97.34199999999515, 76.23499999998089], 0.04: [63.29499999998459, 49.86499999998909, 59.07099999999828, 69.1499999999837], 0.06: [36.015000000019725, 40.86700000001631, -42.514999999975224, 72.39399999997302], 0.08: [48.29000000000529, 29.8150000000173, -48.91300000000172, 50.6689999999976], 0.1: [22.814000000011887, 2.7189999999956562, -50.00000000000659, 19.779000000008537], 0.12: [-0.9000000000082974, -11.041999999990228, -46.42099999999197, -33.31599999997313], 0.14: [-1.2400000000009919, -2.5470000000044726, -45.39899999998848, -33.43999999997122], 0.16: [-9.656000000002104, -9.473000000003644, -47.712999999997585, -38.75799999996865], 0.18: [-36.80099999997052, -10.406999999995831, -44.052999999986085, -47.55799999999677], 0.2: [-38.34999999997017, -25.86299999998015, -45.11699999998632, -41.694999999971145]}, 'action': {0.0: [90.3979999999835, 86.39899999998762, 92.54099999998881, 91.08199999998999], 0.02: [76.2549999999882, 61.804999999974754, 69.15499999998829, 85.70699999999066], 0.04: [74.55999999999354, 62.02599999999155, 44.13200000000079, 81.12699999997488], 0.06: [72.8249999999792, 42.85400000001092, 35.4940000000133, 68.22499999998372], 0.08: [79.46999999998206, 20.45000000000771, 51.10499999999559, 56.78499999997899], 0.1: [56.98099999999683, 26.014000000015077, 19.620000000006677, 28.235000000012633], 0.12: [56.180999999994036, 2.9579999999990605, 16.448999999996587, 26.307000000008994], 0.14: [60.982999999987044, -4.054000000007616, 20.1340000000183, -7.047000000006286], 0.16: [65.36299999998789, -2.3880000000061385, -0.24800000000209368, -23.86099999997618], 0.18: [56.8589999999897, -1.0740000000047019, 6.498000000000498, -0.8100000000008059], 0.2: [52.68500000000101, -3.6530000000010534, 26.263000000012322, 0.42499999999707827]}}
6 fgsm (ut): {'goal': {0.0: [97.19699999999479, 76.54099999998242, 87.61799999998199], 0.02: [-36.86599999997379, -25.44499999997546, 76.60599999998077], 0.04: [-48.53200000000034, -9.085000000007067, 54.91599999999754], 0.06: [-24.05599999997817, 10.795999999996438, 45.96200000001285], 0.08: [-45.62999999998796, -0.40800000000181463, 21.92000000001445], 0.1: [-48.587000000005084, -25.051999999980065, 28.601000000018065], 0.12: [-47.171999999994476, -32.08599999997816, -15.591999999993705], 0.14: [-50.00000000000659, -43.17199999997833, -6.55800000000534], 0.16: [-48.62700000000068, -32.49199999997995, -24.23499999997953], 0.18: [-50.00000000000659, -35.87499999997307, -38.151999999971416], 0.2: [-50.00000000000659, -33.450999999970726, -42.070999999973644]}, 'action': {0.0: [97.29399999999512, 80.132999999978, 80.1599999999833], 0.02: [36.88200000001189, -5.397000000007658, 71.62799999997915], 0.04: [5.62299999999548, 11.661999999997972, 71.4859999999847], 0.06: [-7.977000000001651, 18.431000000015462, 3.3649999999990485], 0.08: [-19.844999999986193, 14.005000000001735, 0.25100000000019995], 0.1: [-3.8330000000034796, 12.5450000000013, -2.1160000000036443], 0.12: [-32.907999999975516, 3.3229999999970734, -3.3360000000006664], 0.14: [-34.25499999997236, -16.907999999997624, -31.057999999977632], 0.16: [-37.26199999997106, -24.165999999976687, -42.43499999997435], 0.18: [-37.26699999997347, -34.76399999997275, -41.0829999999698], 0.2: [-45.753999999989546, -37.64999999997069, -47.606000000000606]}}
6 fgsm (t): {'goal': {0.0: [97.28599999999513, 85.42299999998885, 91.47399999998807], 0.02: [89.36499999998333, 17.708000000003512, 83.819999999983], 0.04: [89.51099999998296, 45.20500000000479, 28.63200000001247], 0.06: [90.71799999998379, 31.99600000001029, 29.865000000014046], 0.08: [90.27199999998388, 5.2689999999957795, 23.197000000009222], 0.1: [20.8980000000173, -4.948000000007693, 41.58700000001288], 0.12: [-41.27899999997395, -25.693999999976626, 47.80500000001015], 0.14: [-48.61200000000063, -34.85599999997398, 27.311000000017955], 0.16: [-35.90699999997299, -32.142999999972375, -41.38699999997116], 0.18: [-40.38899999996786, -45.84599999998874, -39.15299999996898], 0.2: [-39.366999999970716, -50.00000000000659, -39.555999999970204]}, 'action': {0.0: [97.29299999999513, 76.06399999997772, 85.69399999998306], 0.02: [5.700999999998902, 68.05299999997537, 85.97199999998655], 0.04: [16.467000000006223, 57.289999999979585, 78.1079999999796], 0.06: [-12.629999999998265, 69.42099999997768, 82.70299999999074], 0.08: [-19.745999999985784, 72.6269999999831, 86.93299999998273], 0.1: [-41.521999999970284, 59.49499999999363, 85.7009999999889], 0.12: [-42.836999999978254, 36.69100000001758, 74.24399999998042], 0.14: [-42.73299999997646, 51.60499999999654, 14.298000000013944], 0.16: [-44.50199999998749, 69.5619999999835, 10.312000000010725], 0.18: [-47.22599999999468, 59.64099999998784, -15.471999999990105], 0.2: [-47.27699999999573, 45.88899999999246, -21.531999999986674]}}


***

In [None]:
print("scale = 6: 0 noise_hrl: {'both': {0.0: [96.3639999999933], 0.05: [54.918999999968484], 0.1: [48.21700000001007], 0.15: [51.71999999999862], 0.2: [37.03600000001704], 0.25: [52.80099999998565], 0.3: [62.17499999998056], 0.35: [67.63499999998102], 0.4: [58.082999999981375], 0.45: [54.24499999999223], 0.5: [48.28699999999917]}, 'action_only': {0.0: [94.93999999999143], 0.05: [59.63499999999108], 0.1: [62.707999999990626], 0.15: [38.9490000000148], 0.2: [48.75799999999072], 0.25: [50.58299999999304], 0.3: [57.51699999998973], 0.35: [66.80599999998094], 0.4: [58.69299999999387], 0.45: [63.02599999998072], 0.5: [49.02200000000374]}, 'goal_only': {0.0: [96.3819999999926], 0.05: [94.02299999999175], 0.1: [89.99799999999311], 0.15: [74.9379999999893], 0.2: [80.52199999998253], 0.25: [63.90299999998806], 0.3: [66.61699999997958], 0.35: [61.584999999983985], 0.4: [63.00299999998356], 0.45: [60.40699999997702], 0.5: [50.71299999999697]}, 'both_same': {0.0: [96.53799999999389], 0.05: [71.02099999997539], 0.1: [56.563000000002205], 0.15: [41.178000000016795], 0.2: [29.730000000014588], 0.25: [44.95799999999946], 0.3: [45.1229999999953], 0.35: [54.9889999999898], 0.4: [42.83200000001836], 0.45: [50.61500000000443], 0.5: [37.238000000022815]}}")
print("1 noise_hrl: {'both': {0.0: [95.15799999999095], 0.05: [43.120999999994545], 0.1: [-0.5440000000035908], 0.15: [5.94999999999639], 0.2: [-2.159000000007272], 0.25: [-6.430000000002529], 0.3: [-10.141999999998173], 0.35: [-11.353000000003961], 0.4: [-12.18899999999982], 0.45: [4.6489999999939835], 0.5: [-4.651000000006837]}, 'action_only': {0.0: [95.16399999999183], 0.05: [37.43400000001433], 0.1: [1.5749999999974835], 0.15: [-14.29599999999326], 0.2: [-18.556999999991106], 0.25: [-27.751999999973435], 0.3: [-33.030999999975236], 0.35: [-33.035999999972034], 0.4: [-31.657999999973296], 0.45: [-32.31799999997636], 0.5: [-30.984999999972953]}, 'goal_only': {0.0: [90.76999999998837], 0.05: [43.109000000017176], 0.1: [11.300999999997414], 0.15: [21.51200000000807], 0.2: [8.746999999993644], 0.25: [20.157000000015763], 0.3: [8.25499999999702], 0.35: [7.427999999994549], 0.4: [20.227000000021917], 0.45: [8.56599999999416], 0.5: [10.473999999997844]}, 'both_same': {0.0: [94.99899999999155], 0.05: [24.423000000012017], 0.1: [3.0539999999991845], 0.15: [2.070999999998323], 0.2: [-8.592000000005436], 0.25: [-5.232000000000124], 0.3: [5.663999999995248], 0.35: [-13.71399999998818], 0.4: [-5.5780000000038], 0.45: [3.957999999999331], 0.5: [-7.69400000000283]}}")
print("2 noise_hrl: {'both': {0.0: [77.22799999997842], 0.05: [91.85599999998918], 0.1: [88.32799999998605], 0.15: [77.96799999997985], 0.2: [72.99499999997737], 0.25: [59.955999999972576], 0.3: [43.47000000000954], 0.35: [27.972000000013995], 0.4: [11.81800000000768], 0.45: [-3.335000000000773], 0.5: [-7.91600000000729]}, 'action_only': {0.0: [86.60199999999173], 0.05: [94.85399999999072], 0.1: [85.6779999999898], 0.15: [87.02199999998413], 0.2: [87.2109999999808], 0.25: [82.11399999998034], 0.3: [56.4839999999846], 0.35: [48.32999999999503], 0.4: [21.65700000001724], 0.45: [15.233999999999925], 0.5: [7.323999999992005]}, 'goal_only': {0.0: [85.53099999999057], 0.05: [93.30999999998849], 0.1: [91.71399999999147], 0.15: [86.04099999999268], 0.2: [78.09899999997948], 0.25: [76.78199999998236], 0.3: [85.13999999997746], 0.35: [79.90399999997854], 0.4: [72.06499999998377], 0.45: [70.25199999996894], 0.5: [77.69099999997097]}, 'both_same': {0.0: [76.34899999998349], 0.05: [92.01999999998783], 0.1: [91.23599999998761], 0.15: [77.87999999997346], 0.2: [76.7469999999706], 0.25: [57.80199999998752], 0.3: [41.31800000002057], 0.35: [28.154000000023387], 0.4: [15.159000000028925], 0.45: [9.621999999997314], 0.5: [1.5830000000006716]}}")
print("3 noise_hrl: {'both': {0.0: [91.74899999998587], 0.05: [66.63099999998822], 0.1: [81.82499999998295], 0.15: [86.8399999999813], 0.2: [87.1179999999818], 0.25: [83.55799999997653], 0.3: [80.14899999997554], 0.35: [73.27999999997836], 0.4: [63.98799999998137], 0.45: [49.925999999991255], 0.5: [30.416000000023654]}, 'action_only': {0.0: [90.48899999998446], 0.05: [82.78299999999018], 0.1: [87.66199999998838], 0.15: [84.36999999998747], 0.2: [75.4849999999813], 0.25: [78.11499999998446], 0.3: [72.18299999998146], 0.35: [73.4889999999783], 0.4: [66.22299999998317], 0.45: [62.66099999998712], 0.5: [48.57699999999974]}, 'goal_only': {0.0: [92.13799999998645], 0.05: [65.26499999999584], 0.1: [58.88299999998399], 0.15: [78.07899999997771], 0.2: [78.84699999997434], 0.25: [67.82599999998985], 0.3: [76.36699999997684], 0.35: [79.84899999997405], 0.4: [77.99499999997441], 0.45: [70.20299999997609], 0.5: [74.54599999997596]}, 'both_same': {0.0: [91.96299999998583], 0.05: [73.76599999998523], 0.1: [79.05099999999281], 0.15: [89.26499999998363], 0.2: [88.07199999997938], 0.25: [85.04699999997825], 0.3: [80.49099999997414], 0.35: [69.58099999998396], 0.4: [65.33399999996705], 0.45: [43.302000000013706], 0.5: [42.56400000001618]}}")
print("4 noise_hrl: {'both': {0.0: [96.94099999999447], 0.05: [2.4149999999953353], 0.1: [12.643000000005557], 0.15: [11.846000000010495], 0.2: [26.295000000010624], 0.25: [22.324000000013584], 0.3: [-1.1070000000030773], 0.35: [-6.414000000007644], 0.4: [-21.574999999985625], 0.45: [-7.897000000001222], 0.5: [-19.56699999998092]}, 'action_only': {0.0: [96.89499999999431], 0.05: [4.4919999999993445], 0.1: [-0.37899999999955264], 0.15: [-9.43600000000802], 0.2: [-21.59099999997936], 0.25: [-22.76199999998626], 0.3: [-12.646999999996387], 0.35: [-12.625000000002462], 0.4: [-22.514999999974417], 0.45: [-33.822999999968594], 0.5: [-24.895999999973977]}, 'goal_only': {0.0: [95.42499999999455], 0.05: [51.67000000000633], 0.1: [26.38200000001382], 0.15: [36.648000000012644], 0.2: [31.908000000014585], 0.25: [43.380000000002795], 0.3: [48.72999999999832], 0.35: [52.49799999999238], 0.4: [53.920999999976374], 0.45: [56.71599999999028], 0.5: [55.138999999987156]}, 'both_same': {0.0: [96.90299999999455], 0.05: [2.8819999999994788], 0.1: [-11.494000000003675], 0.15: [5.675999999998947], 0.2: [23.662000000020775], 0.25: [11.711000000017396], 0.3: [12.161000000006661], 0.35: [-8.709000000008526], 0.4: [-6.971000000006404], 0.45: [-29.41099999997331], 0.5: [-18.80799999998569]}}")
print("5 noise_hrl: {'both': {0.0: [93.76699999999188], 0.05: [88.57899999998718], 0.1: [89.01499999998401], 0.15: [89.9719999999848], 0.2: [87.44999999997998], 0.25: [87.00499999997834], 0.3: [74.60199999997256], 0.35: [63.73099999997803], 0.4: [51.52199999999096], 0.45: [46.243000000012955], 0.5: [11.90800000000955]}, 'action_only': {0.0: [93.73899999998986], 0.05: [86.24699999998711], 0.1: [69.81899999998487], 0.15: [80.07599999998867], 0.2: [84.41799999997627], 0.25: [82.6349999999845], 0.3: [83.96799999997427], 0.35: [82.36499999997112], 0.4: [76.32399999997384], 0.45: [66.45399999997487], 0.5: [57.994999999996445]}, 'goal_only': {0.0: [95.20899999999156], 0.05: [74.97199999998386], 0.1: [79.92599999998512], 0.15: [75.0969999999868], 0.2: [78.16999999998319], 0.25: [82.494999999971], 0.3: [84.7889999999807], 0.35: [81.22599999997823], 0.4: [80.3029999999821], 0.45: [76.38099999997253], 0.5: [79.1489999999787]}, 'both_same': {0.0: [95.20499999999141], 0.05: [83.2659999999889], 0.1: [84.36999999998793], 0.15: [87.51999999998448], 0.2: [86.32599999998075], 0.25: [86.53099999998001], 0.3: [79.58499999997619], 0.35: [76.16099999997351], 0.4: [60.07399999997519], 0.45: [48.35699999999883], 0.5: [37.30200000002238]}}")

scale = 6: 0 noise_hrl: {'both': {0.0: [96.3639999999933], 0.05: [54.918999999968484], 0.1: [48.21700000001007], 0.15: [51.71999999999862], 0.2: [37.03600000001704], 0.25: [52.80099999998565], 0.3: [62.17499999998056], 0.35: [67.63499999998102], 0.4: [58.082999999981375], 0.45: [54.24499999999223], 0.5: [48.28699999999917]}, 'action_only': {0.0: [94.93999999999143], 0.05: [59.63499999999108], 0.1: [62.707999999990626], 0.15: [38.9490000000148], 0.2: [48.75799999999072], 0.25: [50.58299999999304], 0.3: [57.51699999998973], 0.35: [66.80599999998094], 0.4: [58.69299999999387], 0.45: [63.02599999998072], 0.5: [49.02200000000374]}, 'goal_only': {0.0: [96.3819999999926], 0.05: [94.02299999999175], 0.1: [89.99799999999311], 0.15: [74.9379999999893], 0.2: [80.52199999998253], 0.25: [63.90299999998806], 0.3: [66.61699999997958], 0.35: [61.584999999983985], 0.4: [63.00299999998356], 0.45: [60.40699999997702], 0.5: [50.71299999999697]}, 'both_same': {0.0: [96.53799999999389], 0.05: [71.0209999999

In [None]:
print("scale = 8: 0 noise_hrl: {'both': {0.0: [89.58399999998565], 0.05: [-21.96899999998344], 0.1: [10.140999999997868], 0.15: [17.213000000016653], 0.2: [3.6489999999920486], 0.25: [-13.062999999997603], 0.3: [-21.498999999974416], 0.35: [-29.113999999975434], 0.4: [-27.729999999977526], 0.45: [-22.503999999981133], 0.5: [-27.795999999982573]}, 'action_only': {0.0: [92.70799999998911], 0.05: [-15.412999999994332], 0.1: [-27.177999999978216], 0.15: [6.325999999993067], 0.2: [14.985000000007734], 0.25: [-6.126000000007625], 0.3: [-11.676000000000649], 0.35: [-25.371999999976687], 0.4: [-37.1969999999736], 0.45: [-41.24599999997312], 0.5: [-43.32299999997684]}, 'goal_only': {0.0: [90.67399999998943], 0.05: [56.30299999999417], 0.1: [35.897000000014586], 0.15: [20.404000000012232], 0.2: [1.9769999999966479], 0.25: [0.9279999999980564], 0.3: [10.676999999996147], 0.35: [13.787999999998899], 0.4: [21.39000000001561], 0.45: [26.266000000019204], 0.5: [22.632000000019776]}, 'both_same': {0.0: [92.2509999999885], 0.05: [-11.208999999999843], 0.1: [6.1019999999912224], 0.15: [10.211999999995879], 0.2: [-5.564000000005309], 0.25: [-6.504000000004282], 0.3: [-27.746999999979337], 0.35: [-19.944999999989623], 0.4: [-34.8739999999706], 0.45: [-31.384999999973058], 0.5: [-35.08899999997026]}}")
print("2 noise_hrl: {'both': {0.0: [93.36099999998817, 95.18699999999157], 0.05: [92.5459999999859, 71.43399999998344], 0.1: [86.69099999998303, 45.490000000012266], 0.15: [82.8549999999762, 44.14800000000474], 0.2: [84.85799999997953, 37.76700000001279], 0.25: [81.31899999997303, 37.354000000017614], 0.3: [77.25799999997271, 23.919000000005063], 0.35: [71.54599999996931, 2.4079999999978092], 0.4: [59.43899999998248, 7.2229999999957775], 0.45: [40.19800000002081, -5.903000000004514], 0.5: [11.663999999999492, -20.885999999982886]}, 'action_only': {0.0: [92.31599999998997, 95.18999999999156], 0.05: [92.05499999998706, 72.35799999998208], 0.1: [91.56799999998677, 65.19499999998497], 0.15: [88.78299999998308, 70.89299999998035], 0.2: [80.65799999997846, 72.0519999999897], 0.25: [85.79799999998276, 74.17799999997054], 0.3: [82.78399999998048, 71.47199999997541], 0.35: [83.12299999997609, 62.623999999994524], 0.4: [74.89299999996759, 52.1789999999897], 0.45: [77.18899999996945, 34.979000000012896], 0.5: [67.80499999997197, 24.087000000021725]}, 'goal_only': {0.0: [93.60799999998848, 95.17899999999153], 0.05: [91.1629999999879, 93.17799999998995], 0.1: [92.37399999998685, 78.17899999998471], 0.15: [89.1379999999868, 68.65799999997292], 0.2: [84.65599999998375, 45.65600000000758], 0.25: [85.15799999998008, 58.369999999986376], 0.3: [82.84099999996943, 40.674000000009], 0.35: [83.15399999997578, 49.326999999985745], 0.4: [74.93699999997447, 31.71300000001887], 0.45: [75.46399999996976, 27.98900000001289], 0.5: [74.12099999996296, 39.358000000019295]}, 'both_same': {0.0: [93.5789999999886, 95.18699999999157], 0.05: [90.14299999998542, 63.83199999997188], 0.1: [84.29299999998386, 47.93199999999249], 0.15: [85.6079999999845, 51.156999999999634], 0.2: [85.22399999997721, 33.88200000001158], 0.25: [81.34499999997337, 19.995000000016297], 0.3: [75.71799999996651, 13.909000000010794], 0.35: [67.96699999996562, 0.3739999999996238], 0.4: [53.32499999999096, 2.323999999998038], 0.45: [37.67300000002132, -19.124999999979693], 0.5: [17.84600000001488, -22.123999999976707]}}")
print("3 noise_hrl: {'both': {0.0: [94.39899999998917], 0.05: [13.9369999999988], 0.1: [22.37000000001506], 0.15: [42.68800000001677], 0.2: [47.030000000011505], 0.25: [39.187000000021385], 0.3: [19.89300000001973], 0.35: [24.04400000001407], 0.4: [1.1709999999961835], 0.45: [6.2559999999949945], 0.5: [-10.490000000007262]}, 'action_only': {0.0: [94.85299999999097], 0.05: [-14.029999999994166], 0.1: [-22.933999999986955], 0.15: [-5.684000000003248], 0.2: [7.960999999995567], 0.25: [10.446000000007311], 0.3: [19.6870000000117], 0.35: [23.591000000022262], 0.4: [23.578000000022122], 0.45: [18.63900000001401], 0.5: [8.236999999993753]}, 'goal_only': {0.0: [94.9429999999913], 0.05: [49.51099999999736], 0.1: [53.69399999999389], 0.15: [61.29499999999758], 0.2: [60.98199999997762], 0.25: [61.330999999985266], 0.3: [58.39599999998998], 0.35: [61.434999999982125], 0.4: [48.15000000001303], 0.45: [53.47199999998475], 0.5: [43.34800000001664]}, 'both_same': {0.0: [94.93299999999122], 0.05: [6.9409999999986045], 0.1: [38.03900000001906], 0.15: [43.03700000001386], 0.2: [49.28699999999883], 0.25: [33.34100000002002], 0.3: [28.344000000015946], 0.35: [12.652000000012956], 0.4: [7.968000000000552], 0.45: [-15.922999999994843], 0.5: [-11.418999999991884]}}")
print("4 noise_hrl: {'both': {0.0: [85.43999999998455], 0.05: [70.51499999997505], 0.1: [69.42799999997622], 0.15: [77.71499999997638], 0.2: [74.36399999997019], 0.25: [57.3179999999857], 0.3: [37.56600000001988], 0.35: [24.84500000001754], 0.4: [1.5669999999979023], 0.45: [-11.19599999999417], 0.5: [-23.450999999972563]}, 'action_only': {0.0: [88.46899999998594], 0.05: [73.92399999998884], 0.1: [79.30499999998182], 0.15: [81.8729999999862], 0.2: [76.7219999999856], 0.25: [76.46899999996941], 0.3: [76.73699999997417], 0.35: [68.8239999999791], 0.4: [49.37200000000297], 0.45: [30.058000000022865], 0.5: [-0.732000000000088]}, 'goal_only': {0.0: [91.06699999999044], 0.05: [76.05299999998603], 0.1: [65.57899999996798], 0.15: [73.11099999997788], 0.2: [66.81499999999794], 0.25: [71.10699999997446], 0.3: [70.61399999996958], 0.35: [62.22199999997437], 0.4: [62.20499999997654], 0.45: [52.917999999988744], 0.5: [38.325000000021895]}, 'both_same': {0.0: [81.16499999998886], 0.05: [79.65199999998273], 0.1: [64.50299999998218], 0.15: [76.1169999999784], 0.2: [70.58999999997762], 0.25: [70.06799999997028], 0.3: [60.980999999966315], 0.35: [45.02200000000973], 0.4: [18.439000000020457], 0.45: [-5.685000000001867], 0.5: [-1.740000000004799]}}")

scale = 8: 0 noise_hrl: {'both': {0.0: [89.58399999998565], 0.05: [-21.96899999998344], 0.1: [10.140999999997868], 0.15: [17.213000000016653], 0.2: [3.6489999999920486], 0.25: [-13.062999999997603], 0.3: [-21.498999999974416], 0.35: [-29.113999999975434], 0.4: [-27.729999999977526], 0.45: [-22.503999999981133], 0.5: [-27.795999999982573]}, 'action_only': {0.0: [92.70799999998911], 0.05: [-15.412999999994332], 0.1: [-27.177999999978216], 0.15: [6.325999999993067], 0.2: [14.985000000007734], 0.25: [-6.126000000007625], 0.3: [-11.676000000000649], 0.35: [-25.371999999976687], 0.4: [-37.1969999999736], 0.45: [-41.24599999997312], 0.5: [-43.32299999997684]}, 'goal_only': {0.0: [90.67399999998943], 0.05: [56.30299999999417], 0.1: [35.897000000014586], 0.15: [20.404000000012232], 0.2: [1.9769999999966479], 0.25: [0.9279999999980564], 0.3: [10.676999999996147], 0.35: [13.787999999998899], 0.4: [21.39000000001561], 0.45: [26.266000000019204], 0.5: [22.632000000019776]}, 'both_same': {0.0: [92.2

In [None]:
print("scale = 10: 0 noise_hrl: {'both': {0.0: [74.30999999998883], 0.05: [-23.90499999997209], 0.1: [-32.38899999997059], 0.15: [-27.27299999998588], 0.2: [-27.285999999981833], 0.25: [-28.321999999974825], 0.3: [-37.959999999970194], 0.35: [-41.929999999973134], 0.4: [-43.34799999997724], 0.45: [-46.44999999999322], 0.5: [-48.772000000001206]}, 'action_only': {0.0: [90.21099999998653], 0.05: [-26.004999999978445], 0.1: [-41.64699999997353], 0.15: [-43.08299999999165], 0.2: [-38.85299999996857], 0.25: [-35.05799999997199], 0.3: [-32.11399999997321], 0.35: [-33.713999999971854], 0.4: [-32.97199999996732], 0.45: [-34.33299999997213], 0.5: [-38.20599999997024]}, 'goal_only': {0.0: [87.1169999999907], 0.05: [16.103000000004798], 0.1: [5.376999999995426], 0.15: [6.544999999996297], 0.2: [-2.7870000000036286], 0.25: [-4.802000000004767], 0.3: [-4.9790000000081776], 0.35: [-10.819999999992001], 0.4: [-12.545000000001146], 0.45: [-21.236999999977016], 0.5: [-19.541999999976756]}, 'both_same': {0.0: [85.62999999998938], 0.05: [-12.93499999999594], 0.1: [-26.946999999974487], 0.15: [-38.423999999971514], 0.2: [-37.502999999976616], 0.25: [-34.848999999972854], 0.3: [-30.68599999997096], 0.35: [-40.70799999996779], 0.4: [-45.389999999992455], 0.45: [-43.36799999998226], 0.5: [-48.9760000000065]}}")
print("1 noise_hrl: {'both': {0.0: [85.86999999998416], 0.05: [67.8129999999749], 0.1: [56.52199999999854], 0.15: [24.86800000000761], 0.2: [2.0409999999952526], 0.25: [-22.644999999980172], 0.3: [-26.043999999969078], 0.35: [-38.070999999968116], 0.4: [-38.48399999996957], 0.45: [-45.41699999998719], 0.5: [-48.92000000000175]}, 'action_only': {0.0: [87.78499999998714], 0.05: [76.69199999998402], 0.1: [69.23099999999201], 0.15: [46.0270000000121], 0.2: [5.344999999995891], 0.25: [-7.826000000004082], 0.3: [-18.184999999988143], 0.35: [-35.01799999996818], 0.4: [-28.19599999997219], 0.45: [-31.650999999970953], 0.5: [-30.500999999970844]}, 'goal_only': {0.0: [90.40899999999044], 0.05: [76.49999999998268], 0.1: [67.60899999998016], 0.15: [55.29699999998009], 0.2: [41.15300000001944], 0.25: [31.64200000002307], 0.3: [16.992000000013086], 0.35: [4.564999999999536], 0.4: [-17.367999999991678], 0.45: [-26.462999999974155], 0.5: [-22.34699999998601]}, 'both_same': {0.0: [93.34799999998805], 0.05: [65.81599999998555], 0.1: [55.42099999999048], 0.15: [20.308000000008438], 0.2: [9.720999999994483], 0.25: [-17.044999999988207], 0.3: [-39.42299999997011], 0.35: [-41.5809999999705], 0.4: [-46.16799999999077], 0.45: [-43.25799999997887], 0.5: [-47.843000000002604]}}")
print("2 noise_hrl: {'both': {0.0: [92.35999999998592], 0.05: [88.81299999998319], 0.1: [88.08799999998031], 0.15: [82.995999999978], 0.2: [50.9069999999969], 0.25: [-13.965999999997708], 0.3: [-24.96399999998261], 0.35: [-42.92799999997974], 0.4: [-48.75200000000227], 0.45: [-48.774000000001216], 0.5: [-48.92300000000176]}, 'action_only': {0.0: [91.98999999998729], 0.05: [88.92399999998341], 0.1: [87.58399999997923], 0.15: [86.4569999999787], 0.2: [80.73199999997452], 0.25: [45.80499999999895], 0.3: [-9.275000000004507], 0.35: [-38.60799999996965], 0.4: [-43.71499999998258], 0.45: [-46.27699999999486], 0.5: [-48.69100000000546]}, 'goal_only': {0.0: [92.32399999998772], 0.05: [91.39599999998676], 0.1: [91.12899999998385], 0.15: [89.79699999998273], 0.2: [88.62699999998243], 0.25: [84.59599999997796], 0.3: [78.81199999998098], 0.35: [61.86199999999055], 0.4: [39.45000000001838], 0.45: [26.05600000001552], 0.5: [12.44500000000643]}, 'both_same': {0.0: [92.45499999998763], 0.05: [86.08599999998756], 0.1: [89.63399999998376], 0.15: [83.08799999997373], 0.2: [34.539000000023435], 0.25: [-26.822999999976417], 0.3: [-40.150999999969436], 0.35: [-47.388999999995264], 0.4: [-50.00000000000659], 0.45: [-48.86400000000246], 0.5: [-50.00000000000659]}}")
print("3 noise_hrl: {'both': {0.0: [93.72599999998882], 0.05: [32.070000000019164], 0.1: [32.640000000020216], 0.15: [8.61399999999363], 0.2: [11.417000000002314], 0.25: [-0.8189999999997825], 0.3: [-15.164999999988325], 0.35: [-26.935999999975657], 0.4: [-26.719999999976196], 0.45: [-35.25599999997139], 0.5: [-37.205999999972796]}, 'action_only': {0.0: [93.59099999998892], 0.05: [84.92499999999], 0.1: [76.32999999998614], 0.15: [62.776999999987495], 0.2: [55.25699999998671], 0.25: [13.739000000005541], 0.3: [11.803000000001134], 0.35: [-10.586000000004196], 0.4: [-29.85899999997243], 0.45: [-31.617999999978256], 0.5: [-35.771999999975556]}, 'goal_only': {0.0: [93.67099999998915], 0.05: [25.51300000001443], 0.1: [28.848000000012256], 0.15: [18.9800000000107], 0.2: [9.385999999994432], 0.25: [13.058000000011644], 0.3: [19.407000000004846], 0.35: [26.631000000018165], 0.4: [25.86700000001965], 0.45: [36.78300000001957], 0.5: [17.73900000001213]}, 'both_same': {0.0: [93.74599999998912], 0.05: [28.48800000001097], 0.1: [36.54200000001187], 0.15: [14.98400000000243], 0.2: [15.956000000000572], 0.25: [2.2209999999982], 0.3: [-31.61199999997424], 0.35: [-18.83099999998737], 0.4: [-30.740999999975955], 0.45: [-30.21099999997863], 0.5: [-37.79499999997009]}}")
print("4 noise_hrl: {'both': {0.0: [93.27699999998872], 0.05: [20.818000000004513], 0.1: [1.3190000000001034], 0.15: [-21.85399999997596], 0.2: [-32.98899999997388], 0.25: [-35.38599999996985], 0.3: [-44.377999999988404], 0.35: [-36.33199999996976], 0.4: [-41.044999999966734], 0.45: [-45.596999999988064], 0.5: [-44.80399999999275]}, 'action_only': {0.0: [94.99699999999126], 0.05: [15.293000000009819], 0.1: [-13.46799999999605], 0.15: [-9.266000000006867], 0.2: [-28.436999999983264], 0.25: [-31.721999999975523], 0.3: [-30.65199999997237], 0.35: [-41.087999999968936], 0.4: [-35.02799999996904], 0.45: [-46.08599999998962], 0.5: [-45.24899999999566]}, 'goal_only': {0.0: [93.52599999999183], 0.05: [60.104999999985324], 0.1: [46.03000000000398], 0.15: [23.106000000010745], 0.2: [14.636000000008682], 0.25: [-2.4360000000037356], 0.3: [-7.930000000005353], 0.35: [-0.6630000000051609], 0.4: [-22.058999999984216], 0.45: [-16.004999999986197], 0.5: [-25.652999999975954]}, 'both_same': {0.0: [94.98499999999115], 0.05: [29.589000000010092], 0.1: [-15.305999999985351], 0.15: [-12.671000000002534], 0.2: [-23.044999999975673], 0.25: [-39.07599999997178], 0.3: [-45.80599999998883], 0.35: [-42.84699999997712], 0.4: [-43.22299999998193], 0.45: [-44.363999999984614], 0.5: [-50.00000000000659]}}")

scale = 10: 0 noise_hrl: {'both': {0.0: [74.30999999998883], 0.05: [-23.90499999997209], 0.1: [-32.38899999997059], 0.15: [-27.27299999998588], 0.2: [-27.285999999981833], 0.25: [-28.321999999974825], 0.3: [-37.959999999970194], 0.35: [-41.929999999973134], 0.4: [-43.34799999997724], 0.45: [-46.44999999999322], 0.5: [-48.772000000001206]}, 'action_only': {0.0: [90.21099999998653], 0.05: [-26.004999999978445], 0.1: [-41.64699999997353], 0.15: [-43.08299999999165], 0.2: [-38.85299999996857], 0.25: [-35.05799999997199], 0.3: [-32.11399999997321], 0.35: [-33.713999999971854], 0.4: [-32.97199999996732], 0.45: [-34.33299999997213], 0.5: [-38.20599999997024]}, 'goal_only': {0.0: [87.1169999999907], 0.05: [16.103000000004798], 0.1: [5.376999999995426], 0.15: [6.544999999996297], 0.2: [-2.7870000000036286], 0.25: [-4.802000000004767], 0.3: [-4.9790000000081776], 0.35: [-10.819999999992001], 0.4: [-12.545000000001146], 0.45: [-21.236999999977016], 0.5: [-19.541999999976756]}, 'both_same': {0.0: 

In [None]:
noise_hrl = {'both': {}, 'action_only': {}, 'goal_only': {}, 'both_same': {}}
for l2norm in np.arange(0,0.51,0.05):
    for i in [noise_hrl['both'], noise_hrl['action_only'], noise_hrl['goal_only'], noise_hrl['both_same']]:
        i[np.round(l2norm, 2)] = []

targeted = {'both': {}, 'goal_only': {}, 'action_only': {}}
untargeted = {'both': {}, 'goal_only': {}, 'action_only': {}}
for eps in np.arange(0.0, 0.201, 0.02):
    for x in ['both', 'goal_only', 'action_only']:
        targeted[x][eps] = []
        untargeted[x][eps] = []

n_observations = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

i = 5
while i < 6:
    agent = train_model()
    #agent = HIRO(n_observations, n_actions).to(device)
    #load_model(agent, f"hiro_s6_{i}")

    if agent is not None:
        # goal_attack, action_attack, same_noise
        eval_model(agent, noise_hrl['both_same'], True, True, True)
        eval_model(agent, noise_hrl['both'], True, True, False)
        eval_model(agent, noise_hrl['action_only'], False, True, False)
        eval_model(agent, noise_hrl['goal_only'], True, False, False)
        print(f"{i} noise_hrl: {noise_hrl}")
        i += 1

print("----")
print(f"noise_hrl: {noise_hrl}")

100: -48.92300000000043
200: -50.000000000000426
300: -46.50600000000042
400: -50.000000000000426
500: -48.92200000000042
600: -50.000000000000426
Unlucky after 600 eps! Terminating...
100: -50.000000000000426
200: -50.000000000000426
300: -36.55000000000043
400: -50.000000000000426
500: -45.54200000000043
600: -24.527000000000356
700: -26.217000000000372
800: -50.000000000000426
Unlucky after 800 eps! Terminating...
100: -50.000000000000426
200: -48.999000000000436
300: -50.000000000000426
400: -50.000000000000426
500: -48.60000000000042
600: -48.81600000000042
700: -50.000000000000426
Unlucky after 700 eps! Terminating...
100: -50.000000000000426
200: -50.000000000000426
300: -32.54100000000041
400: -43.1190000000004
500: -42.479000000000404
600: -11.872000000000323
700: 9.017999999999759
800: -36.0020000000004
900: 20.81199999999978
1000: 21.713999999999782
1100: 42.53899999999985
1200: 21.291999999999785
1300: 58.1599999999999
1400: 69.13799999999993
1500: 86.93700000000001
Solved 

In [None]:
noise_hrl = {'both': {}, 'action_only': {}, 'goal_only': {}, 'both_same': {}}
for l2norm in np.arange(0,0.51,0.05):
    for i in [noise_hrl['both'], noise_hrl['action_only'], noise_hrl['goal_only'], noise_hrl['both_same']]:
        i[np.round(l2norm, 2)] = []

targeted = {'both': {}, 'goal_only': {}, 'action_only': {}}
untargeted = {'both': {}, 'goal_only': {}, 'action_only': {}}
for eps in np.arange(0.0, 0.201, 0.02):
    for x in ['both', 'goal_only', 'action_only']:
        targeted[x][eps] = []
        untargeted[x][eps] = []

n_observations = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

i = 0
while i < 12:
    #agent = train_model()
    agent = HIRO(n_observations, n_actions).to(device)
    load_model(agent, f"hiro_{i}")

    if agent is not None:
        # goal_attack, action_attack, same_noise
        #eval_model(agent, noise_hrl['both_same'], True, True, True)
        #eval_model(agent, noise_hrl['both'], True, True, False)
        #eval_model(agent, noise_hrl['action_only'], False, True, False)
        #eval_model(agent, noise_hrl['goal_only'], True, False, False)
        #print(f"{i} noise_hrl: {noise_hrl}")
        
        apply_fgsm(agent, untargeted['both'], True, True, False)
        apply_fgsm(agent, untargeted['goal_only'], True, False, False)
        apply_fgsm(agent, untargeted['action_only'], False, True, False)   
        print(f"{i} fgsm (ut): {untargeted}")
        i += 1

print("----")
print(f"noise_hrl: {noise_hrl}")
print(f"fgsm: {untargeted}")

0 fgsm (ut): {'both': {0.0: [88.62499999999119], 0.02: [84.99599999998347], 0.04: [70.33099999997614], 0.06: [35.34900000001908], 0.08: [11.493999999995081], 0.1: [2.977999999995559], 0.12: [-24.840999999973814], 0.14: [-17.09999999998783], 0.16: [-41.23999999997312], 0.18: [-48.911000000001714], 0.2: [-50.00000000000659]}, 'goal_only': {0.0: [92.63899999998686], 0.02: [86.6189999999835], 0.04: [87.46899999998374], 0.06: [95.2329999999934], 0.08: [80.67599999998397], 0.1: [68.83499999998068], 0.12: [55.137000000001926], 0.14: [42.190000000010805], 0.16: [27.082000000013736], 0.18: [-14.753999999977452], 0.2: [-2.5130000000038084]}, 'action_only': {0.0: [95.09899999998639], 0.02: [71.83599999997526], 0.04: [75.68399999997504], 0.06: [67.22699999999217], 0.08: [-4.312000000004449], 0.1: [-19.11499999998523], 0.12: [7.11499999999475], 0.14: [6.861999999998128], 0.16: [6.840999999991268], 0.18: [0.5859999999974336], 0.2: [-14.883999999994211]}}
1 fgsm (ut): {'both': {0.0: [88.6249999999911

In [None]:
print("noise_hrl: {'both': {0.0: [89.3899999999815, 91.0309999999958, 95.5099999999919, 93.63099999999169, 97.28799999999512, 82.85599999998996, 88.9359999999888, 93.25999999999458, 87.13799999998855, 96.21999999999375, 93.90999999999178, 97.66199999999564], 0.05: [88.87799999998153, 56.85199999998574, 79.99199999997805, 86.5489999999835, 16.843000000012324, 85.7839999999846, 85.94599999998469, 77.93799999997303, 77.52299999998385, 77.80199999997501, 71.91599999997868, 74.05699999998856], 0.1: [85.49499999998174, 40.47200000001511, 60.43699999999265, 79.07499999997796, -3.902000000004151, 95.29299999999122, 91.99899999998851, 69.39799999996973, 86.91499999998113, 85.38299999997933, 87.42099999998932, 26.486000000008616], 0.15: [91.02399999998548, 56.24299999999046, 61.22499999998353, 62.93399999998408, -20.97199999998521, 94.9129999999905, 84.06399999998855, 49.12199999999172, 78.9939999999828, 75.06599999999108, 86.78399999998575, 0.4419999999967616], 0.2: [91.96899999998715, 61.591999999973744, 55.25999999999259, 49.43299999999811, -40.07199999996928, 93.9279999999888, 87.60299999998668, 57.16499999998302, 77.95599999997098, 77.9109999999794, 70.98299999997627, -15.732999999992247], 0.25: [91.76699999998691, 70.74499999996291, 62.07199999999296, 40.18600000001164, -45.748999999990836, 92.04099999998608, 83.69099999997984, 63.612999999971926, 77.07899999997252, 79.37399999997712, 59.75199999999697, -30.17099999997824], 0.3: [89.42499999998414, 60.36799999998427, 46.848000000003566, 39.14700000001644, -50.00000000000659, 88.9679999999821, 77.5969999999798, 59.2849999999919, 68.7919999999706, 76.31099999997753, 34.07300000001621, -30.738999999973675], 0.35: [87.9859999999798, 57.31399999999034, 35.55300000001982, 30.24800000001848, -45.915999999990134, 86.25099999997927, 57.68599999999076, 60.8829999999773, 53.14699999998325, 58.58099999997261, 30.243000000016927, -28.000999999976802], 0.4: [79.63299999998095, 50.576000000000185, 31.56200000001942, 18.784000000005033, -48.621000000000656, 78.45899999997481, 48.538000000001084, 57.22799999998593, 37.83900000001949, 44.6600000000146, 15.358999999997124, -35.3439999999706], 0.45: [64.5409999999706, 37.157000000018506, 10.119999999996145, 25.074000000015058, -50.00000000000659, 74.3229999999788, 22.546000000009535, 52.58499999999816, 26.092000000014163, 53.03899999997724, 5.817999999999249, -43.4879999999858], 0.5: [52.53899999999622, 28.34500000001633, 10.760999999993894, 0.5969999999978297, -46.11399999999085, 60.046999999974496, 15.295999999999568, 46.441999999985626, 28.793000000012697, 34.406000000017826, 2.4809999999991046, -26.50999999998034]}, 'action_only': {0.0: [94.39099999999405, 85.87899999998575, 94.03499999999018, 90.46999999998927, 97.28999999999513, 82.82099999997976, 87.43899999998695, 94.3149999999931, 80.1019999999867, 95.73499999999214, 88.74499999998547, 97.69299999999573], 0.05: [84.93599999998754, 72.88999999998079, 93.87499999999137, 82.65199999998207, 66.64899999998993, 86.16599999998817, 89.80699999998802, 71.93099999998371, 81.42199999998299, 79.35499999998493, 76.89099999998885, 82.56899999998552], 0.1: [81.81599999998807, 31.50200000001407, 85.83699999998757, 73.61399999998108, 53.576999999987045, 94.97699999999182, 85.13299999998414, 55.82799999997656, 75.30799999998467, 73.33599999998512, 77.00699999998294, 50.32600000000618], 0.15: [79.43799999998627, 42.60100000001652, 70.78999999998813, 65.1089999999844, 16.593000000008054, 95.21199999999106, 90.0619999999888, 41.415000000015155, 69.2109999999877, 50.02799999999824, 77.34499999997935, 30.825000000006376], 0.2: [85.2409999999784, 50.69099999999492, 71.94299999998238, 50.226000000005826, -10.930000000005894, 94.91799999999144, 87.73899999998474, 38.18700000001449, 61.54899999999086, 49.28000000001148, 78.73799999997874, 17.55200000000347], 0.25: [87.92199999998658, 48.263000000006286, 68.00299999998462, 33.36700000001578, -29.83299999997565, 92.48599999998846, 84.88499999998295, 29.086000000015428, 51.60399999998358, 55.15399999998974, 80.75899999997924, 4.069999999997915], 0.3: [88.20199999998582, 50.637999999995216, 58.186999999995685, 11.441000000002616, -39.995999999972426, 89.86699999998916, 74.19499999998224, 21.814000000023547, 43.10800000001233, 53.6489999999888, 69.83699999999106, -9.176000000003294], 0.35: [87.8009999999786, 47.12299999998803, 58.147999999989835, 4.981999999997234, -48.54500000000152, 84.56599999998572, 79.19999999998326, 35.95100000001404, 23.905000000007682, 55.65599999999926, 56.080999999985785, -20.769999999977514], 0.4: [85.8179999999765, 28.938000000021038, 35.19500000002186, 6.885999999994123, -44.230999999982195, 83.37099999997801, 63.602999999988924, 39.697000000013055, 18.047000000003308, 46.98099999999418, 44.79600000000559, -24.38499999998173], 0.45: [81.47199999997973, 23.950000000018218, 30.715000000021583, -4.945000000003042, -47.15999999999443, 74.50599999997884, 44.11900000000658, 32.61400000001162, 13.319000000005605, 31.455000000019638, 41.491000000018296, -29.557999999977714], 0.5: [79.40299999997289, 32.46000000001854, 9.788999999996529, 2.4339999999977464, -50.00000000000659, 71.98599999997712, 43.697000000011066, 33.17900000002368, 22.726000000015645, 36.82300000001351, 23.831000000008398, -41.961999999976435]}, 'goal_only': {0.0: [92.3719999999894, 85.46999999999102, 91.20799999998945, 88.62299999998918, 97.29299999999513, 76.62999999999006, 89.16599999999083, 93.25999999999462, 84.62299999998582, 95.20699999999235, 91.42099999998705, 97.66899999999569], 0.05: [90.9879999999876, 81.85599999998831, 75.61899999998116, 93.24799999998935, 40.185000000013936, 90.14599999999004, 81.00499999998256, 90.07899999998965, 76.9219999999853, 85.18599999998301, 80.22999999998096, 80.17999999998439], 0.1: [92.96199999998922, 87.66099999998383, 60.06999999997841, 88.55699999998802, 9.359000000005553, 88.78899999998886, 90.3969999999892, 84.09899999998439, 77.3689999999801, 85.02199999998356, 75.40499999998728, 65.95599999999075], 0.15: [92.08299999998599, 71.91299999999356, 40.57800000001587, 81.57799999998862, 12.740000000002363, 92.81199999998877, 88.38299999998337, 69.82499999998501, 79.46599999997947, 90.00599999998921, 72.50799999998469, 55.664999999993626], 0.2: [89.62299999998245, 76.14399999997991, 38.89300000001532, 85.23299999998753, 19.517000000006284, 91.65599999999023, 85.64799999998493, 65.90699999997885, 79.92299999998212, 85.62899999998612, 76.20799999997844, 32.59200000001866], 0.25: [91.20299999998635, 77.12699999998308, 44.55700000000927, 81.05899999999053, -15.28799999998305, 93.47899999998901, 90.62399999998506, 68.25899999997482, 86.97099999998254, 87.27799999998282, 76.49299999997787, 25.866000000007745], 0.3: [93.77199999999033, 80.45299999998124, 46.588000000003234, 69.95199999998228, -11.35500000000211, 91.580999999985, 92.02599999998644, 72.10099999998766, 81.8679999999828, 89.18499999997944, 76.52999999998278, 17.714000000009065], 0.35: [94.00899999998963, 89.01199999998352, 38.07200000001667, 68.77299999997854, -15.699999999988233, 91.20399999998392, 91.04899999998852, 72.16699999998683, 85.83699999998058, 86.5089999999819, 69.1929999999871, 9.397000000004372], 0.4: [94.4779999999906, 85.22299999998295, 41.16300000001529, 66.87299999998386, -30.11099999997181, 90.06599999998294, 93.4709999999884, 73.9429999999815, 86.84999999997588, 88.9059999999782, 64.95399999998541, -9.495000000005277], 0.45: [93.2209999999885, 84.72299999997841, 36.1600000000155, 63.86999999999032, -31.672999999975815, 86.73799999998049, 90.67599999998812, 73.48699999997108, 81.39399999998437, 86.75899999997851, 67.64599999998713, -0.3420000000052488], 0.5: [93.3639999999871, 87.57999999997985, 33.21600000001821, 62.77999999998112, -32.05299999997566, 87.57699999998249, 91.96599999998854, 72.3249999999844, 79.9699999999718, 88.35599999998053, 59.47999999998657, 2.598999999997897]}, 'both_same': {0.0: [94.86599999999088, 88.42999999999132, 94.05499999999216, 94.34899999998899, 97.28899999999513, 75.28599999998217, 88.77599999998525, 91.79199999999584, 86.1019999999826, 94.92499999999083, 93.61199999998652, 97.68699999999573], 0.05: [85.45999999997899, 64.01299999999186, 77.066999999975, 80.32099999998096, 30.06500000001823, 87.8869999999856, 90.35099999999068, 71.52199999998112, 80.46799999998541, 83.48799999998411, 75.53299999998693, 74.29399999998867], 0.1: [88.85499999997973, 46.54200000000913, 77.43799999997488, 85.39499999998716, 9.493999999995733, 94.53799999999053, 93.44199999998895, 71.6459999999795, 87.16399999998605, 85.40099999998509, 72.93299999998575, 30.2130000000189], 0.15: [91.8309999999903, 32.0420000000137, 60.34999999998712, 64.54599999998977, -3.8930000000047005, 94.01899999999077, 87.08299999998991, 60.22299999998219, 86.16499999998166, 85.61599999998134, 79.7009999999785, 1.8079999999971406], 0.2: [92.06999999998621, 52.93399999999815, 60.574999999971105, 51.264999999998416, -37.13299999997127, 93.05399999998843, 89.38499999998452, 61.09599999997557, 80.55299999997959, 80.77199999998125, 70.41099999998478, -13.23099999998593], 0.25: [90.61999999998358, 51.983999999986466, 47.43300000000938, 34.243000000016856, -44.364999999989315, 91.26099999998533, 82.14399999997903, 66.58799999996405, 70.42599999997377, 76.39099999996927, 60.93099999997749, -23.049999999981807], 0.3: [87.54499999998313, 64.14899999997544, 57.64599999998872, 34.95800000001026, -48.58300000000052, 87.7549999999835, 78.9419999999761, 67.02499999997987, 62.42899999998299, 78.71899999998294, 50.17099999999985, -34.46299999997074], 0.35: [83.58699999998163, 52.50199999999795, 42.66400000001366, 31.467000000020104, -47.15299999999529, 81.34599999998281, 46.255000000006184, 69.05499999997325, 56.71699999999163, 65.27599999998579, 38.97800000001917, -33.22199999997952], 0.4: [80.71099999997814, 43.918000000005385, 23.762000000022, 34.534000000020825, -50.00000000000659, 73.72399999997427, 19.656000000011506, 60.29499999997984, 31.670000000022956, 56.81699999998374, 24.302000000020918, -37.86099999996847], 0.45: [72.11299999997884, 41.84600000001449, 16.610000000006305, 33.15500000001611, -48.640000000000725, 60.43099999997137, 12.405000000002586, 57.387999999990285, 32.81100000001649, 40.789000000017154, 12.680000000007135, -35.11499999997618], 0.5: [54.265999999990555, 39.98500000002172, 6.191999999993301, 15.536000000006343, -47.278999999994866, 50.74299999999853, 7.788999999999377, 52.333999999977074, 26.60500000002136, 30.190000000022494, 10.877999999992229, -42.483999999987425]}}")

noise_hrl: {'both': {0.0: [89.3899999999815, 91.0309999999958, 95.5099999999919, 93.63099999999169, 97.28799999999512, 82.85599999998996, 88.9359999999888, 93.25999999999458, 87.13799999998855, 96.21999999999375, 93.90999999999178, 97.66199999999564], 0.05: [88.87799999998153, 56.85199999998574, 79.99199999997805, 86.5489999999835, 16.843000000012324, 85.7839999999846, 85.94599999998469, 77.93799999997303, 77.52299999998385, 77.80199999997501, 71.91599999997868, 74.05699999998856], 0.1: [85.49499999998174, 40.47200000001511, 60.43699999999265, 79.07499999997796, -3.902000000004151, 95.29299999999122, 91.99899999998851, 69.39799999996973, 86.91499999998113, 85.38299999997933, 87.42099999998932, 26.486000000008616], 0.15: [91.02399999998548, 56.24299999999046, 61.22499999998353, 62.93399999998408, -20.97199999998521, 94.9129999999905, 84.06399999998855, 49.12199999999172, 78.9939999999828, 75.06599999999108, 86.78399999998575, 0.4419999999967616], 0.2: [91.96899999998715, 61.591999999973

In [None]:
def eval_scale(agent, episode_durations):
    agent.eval()
    agent.meta_controller.eval()
    agent.controller.eval()

    max_episode_length = 500
    agent.meta_controller.is_training = False
    agent.controller.is_training = False

    num_episodes = 100

    c = 10

    for scale in np.arange(1.0,7.01,0.5):
        env = NormalizedEnv(PointMazeEnv(scale))

        overall_reward = 0
        for i_episode in range(num_episodes):
            observation = env.reset()

            state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
            g_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

            episode_steps = 0
            done = False
            while not done:
                # select a goal
                goal = agent.select_goal(g_state, False, False)

                goal_done = False
                while not done and not goal_done:
                    action = agent.select_action(state, goal, False, False)
                    observation, reward, done, info = env.step(action.detach().cpu().squeeze(0).numpy())
                    
                    next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                    g_next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

                    next_goal = agent.h(g_state, goal, g_next_state)
                                      
                    overall_reward += reward

                    if max_episode_length and episode_steps >= max_episode_length - 1:
                        done = True
                    episode_steps += 1

                    #goal_done = agent.goal_reached(action, goal)
                    goal_reached = agent.goal_reached(g_state, goal, g_next_state)

                    if (episode_steps % c) == 0:
                        goal_done = True

                    state = next_state
                    g_state = g_next_state
                    goal = next_goal

        episode_durations[np.round(scale, 2)].append(overall_reward / num_episodes)

In [None]:
episodes = {}
for scale in np.arange(1.0,7.01,0.5):
    episodes[np.round(scale, 2)] = []

n_observations = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

i = 0
while i < 12:
    #agent = train_model()
    agent = HIRO(n_observations, n_actions).to(device)
    load_model(agent, f"hiro_{i}")

    if agent is not None:
        # goal_attack, action_attack, same_noise
        eval_scale(agent, episodes)
        print(f"{i} scale: {episodes}")
        i += 1

print("----")
print(f"scale: {episodes}")

0 scale: {1.0: [61.642999999976745], 1.5: [49.19700000000038], 2.0: [57.66299999998722], 2.5: [71.84899999997542], 3.0: [98.36299999999699], 3.5: [95.78799999999414], 4.0: [95.11199999999067], 4.5: [88.78199999998891], 5.0: [84.03299999998958], 5.5: [75.42999999997562], 6.0: [39.41200000001294], 6.5: [71.99299999998667], 7.0: [71.91799999997758]}
1 scale: {1.0: [61.642999999976745, -16.05299999999719], 1.5: [49.19700000000038, -12.458999999997713], 2.0: [57.66299999998722, 84.34099999998456], 2.5: [71.84899999997542, 75.8269999999781], 3.0: [98.36299999999699, 79.54199999998632], 3.5: [95.78799999999414, 92.30299999998806], 4.0: [95.11199999999067, 90.75999999998909], 4.5: [88.78199999998891, 81.72999999998848], 5.0: [84.03299999998958, 91.30699999998679], 5.5: [75.42999999997562, 74.05399999998431], 6.0: [39.41200000001294, 48.99099999999645], 6.5: [71.99299999998667, 47.694000000005126], 7.0: [71.91799999997758, 55.3779999999856]}
2 scale: {1.0: [61.642999999976745, -16.0529999999971

In [None]:
def eval_starting_position(agent, episode_durations):
    agent.eval()
    agent.meta_controller.eval()
    agent.controller.eval()

    max_episode_length = 500
    agent.meta_controller.is_training = False
    agent.controller.is_training = False

    num_episodes = 100

    c = 10

    for extra_range in np.arange(0.0, 0.401, 0.05):
        
        overall_reward = 0
        for i_episode in range(num_episodes):
            observation = env.reset()

            extra = np.random.uniform(-0.1 - extra_range, 0.1 + extra_range, env.starting_point.shape)
            #extra = np.random.uniform(0.1, 0.1 + extra_range, env.starting_point.shape)
            #extra = extra * (2*np.random.randint(0,2,size=env.starting_point.shape)-1)
            env.unwrapped.state = np.array(env.starting_point + extra, dtype=np.float32)
            env.unwrapped.state[2] = env.state[2] % (2 * math.pi)
            observation = env.normalised_state()

            state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
            g_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

            episode_steps = 0
            done = False
            while not done:
                # select a goal
                goal = agent.select_goal(g_state, False, False)

                goal_done = False
                while not done and not goal_done:
                    action = agent.select_action(state, goal, False, False)
                    observation, reward, done, info = env.step(action.detach().cpu().squeeze(0).numpy())
                    
                    next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                    g_next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)

                    next_goal = agent.h(g_state, goal, g_next_state)
                                      
                    overall_reward += reward

                    if max_episode_length and episode_steps >= max_episode_length - 1:
                        done = True
                    episode_steps += 1

                    #goal_done = agent.goal_reached(action, goal)
                    goal_reached = agent.goal_reached(g_state, goal, g_next_state)

                    if (episode_steps % c) == 0:
                        goal_done = True

                    state = next_state
                    g_state = g_next_state
                    goal = next_goal

        episode_durations[np.round(extra_range, 3)].append(overall_reward / num_episodes)

In [None]:
episodes = {}
for extra_range in np.arange(0.0, 0.401, 0.05):
    episodes[np.round(extra_range, 3)] = []

n_observations = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

env = NormalizedEnv(PointMazeEnv(4))
i = 0
while i < 12:
    #agent = train_model()
    agent = HIRO(n_observations, n_actions).to(device)
    load_model(agent, f"hiro_{i}")

    if agent is not None:
        # goal_attack, action_attack, same_noise
        eval_starting_position(agent, episodes)
        print(f"{i} range: {episodes}")
        i += 1

print("----")
print(f"range: {episodes}")

0 range: {0.0: [92.68199999998889], 0.05: [91.45199999998897], 0.1: [90.35899999998408], 0.15: [91.99299999998702], 0.2: [90.6069999999873], 0.25: [91.52899999998594], 0.3: [91.2769999999874], 0.35: [89.3809999999913], 0.4: [92.83999999998858]}
1 range: {0.0: [92.68199999998889, 85.62199999998724], 0.05: [91.45199999998897, 85.78699999998517], 0.1: [90.35899999998408, 85.44799999998403], 0.15: [91.99299999998702, 77.44299999998843], 0.2: [90.6069999999873, 82.67499999998213], 0.25: [91.52899999998594, 86.07499999998869], 0.3: [91.2769999999874, 86.12299999998686], 0.35: [89.3809999999913, 77.39299999998255], 0.4: [92.83999999998858, 84.80599999998846]}
2 range: {0.0: [92.68199999998889, 85.62199999998724, 93.84399999999265], 0.05: [91.45199999998897, 85.78699999998517, 93.99599999999002], 0.1: [90.35899999998408, 85.44799999998403, 94.05099999999021], 0.15: [91.99299999998702, 77.44299999998843, 94.02599999999258], 0.2: [90.6069999999873, 82.67499999998213, 95.36299999999184], 0.25: [9

In [None]:
state_max = torch.from_numpy(env.observation_space.high).to(device).float()
state_min = torch.from_numpy(env.observation_space.low).to(device).float()
state_mid = (state_max + state_min) / 2.
state_range = (state_max - state_min)
def save_trajectories(agent, episode_durations, dirty):
    agent.eval()
    agent.meta_controller.eval()
    agent.controller.eval()

    max_episode_length = 500
    agent.meta_controller.is_training = False
    agent.controller.is_training = False

    num_episodes = 10

    c = 10

    l2norm = 0.3
    episode_durations.append([])
    
    for i_episode in range(num_episodes):
        path = {"overall_reward": 0, "manager": [], "worker": []}

        observation = env.reset()

        state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
        state_ = torch.from_numpy(observation).float().unsqueeze(0).to(device)
        g_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
        g_state_ = torch.from_numpy(observation).float().unsqueeze(0).to(device)

        noise = torch.FloatTensor(state.shape).uniform_(-l2norm, l2norm).to(device)

        if dirty:
            g_state = g_state + state_range * noise
            g_state = torch.max(torch.min(g_state, state_max), state_min).float()
        if dirty:
            state = state + state_range * torch.FloatTensor(state.shape).uniform_(-l2norm, l2norm).to(device)
            state = torch.max(torch.min(state, state_max), state_min).float()

        episode_steps = 0
        overall_reward = 0
        done = False
        while not done:
            # select a goal
            goal = agent.select_goal(g_state, False, False)
            path["manager"].append((episode_steps, g_state_.detach().cpu().squeeze(0).numpy(), goal.detach().cpu().squeeze(0).numpy()))

            goal_done = False
            while not done and not goal_done:
                action = agent.select_action(state, goal, False, False)
                path["worker"].append((episode_steps, torch.cat([state_, goal], 1).detach().cpu().squeeze(0).numpy(), action.detach().cpu().squeeze(0).numpy()))
                observation, reward, done, info = env.step(action.detach().cpu().squeeze(0).numpy())
                
                next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                g_next_state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                state_ = torch.from_numpy(observation).float().unsqueeze(0).to(device)
                g_state_ = torch.from_numpy(observation).float().unsqueeze(0).to(device)

                noise = torch.FloatTensor(state.shape).uniform_(-l2norm, l2norm).to(device)
                if dirty:
                    g_next_state = g_next_state + state_range * noise
                    g_next_state = torch.max(torch.min(g_next_state, state_max), state_min).float()
                if dirty:
                    next_state = next_state + state_range * torch.FloatTensor(next_state.shape).uniform_(-l2norm, l2norm).to(device)
                    next_state = torch.max(torch.min(next_state, state_max), state_min).float()

                next_goal = agent.h(g_state, goal, g_next_state)
                                  
                overall_reward += reward

                if max_episode_length and episode_steps >= max_episode_length - 1:
                    done = True
                episode_steps += 1

                #goal_done = agent.goal_reached(action, goal)
                goal_reached = agent.goal_reached(g_state, goal, g_next_state)

                if (episode_steps % c) == 0:
                    goal_done = True

                state = next_state
                g_state = g_next_state
                goal = next_goal

        path["overall_reward"] = overall_reward
        episode_durations[-1].append(path)

In [None]:
episodes = []

n_observations = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

i = 0
while i < 12:
    #agent = train_model()
    agent = HIRO(n_observations, n_actions).to(device)
    load_model(agent, f"hiro_{i}")

    if agent is not None:
        # goal_attack, action_attack, same_noise
        save_trajectories(agent, episodes, True)
        #print(f"{i} paths: {episodes}")
        i += 1

print("----")
#print(f"paths: {episodes}")

episodes.pop(5)
episodes.pop(8 - 1)
torch.save(episodes, "PointMaze_dirty_eps.pt")

----
