In [1]:
import numpy as np

# Generic replay buffer for standard gym tasks
class TaskBuffer(object):
    def __init__(self, state_dim, action_dim, sf_dim, buffer_size, device):
        self.max_size = int(buffer_size)
        self.device = device

        self.ptr = 0
        self.crt_size = 0

        self.state = np.zeros((self.max_size, state_dim))
        self.action = np.zeros((self.max_size, action_dim))
        self.next_state = np.zeros((self.max_size, state_dim))
        self.reward = np.zeros((self.max_size, 1))
        self.not_done = np.zeros((self.max_size, 1))
        self.task = np.zeros((self.max_size, sf_dim))


    def add(self, state, action, next_state, reward, done, task):
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done
        self.task[self.ptr] = task

        self.ptr = (self.ptr + 1) % self.max_size
        self.crt_size = min(self.crt_size + 1, self.max_size)

    def sample(self, batch_size=32):
        ind = np.random.randint(0, self.crt_size, size=batch_size)
        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.LongTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.not_done[ind]).to(self.device),
            torch.FloatTensor(self.task[ind]).to(self.device),
        )


    def save(self, save_folder):
        np.save(f"{save_folder}_state.npy", self.state[:self.crt_size])
        np.save(f"{save_folder}_action.npy", self.action[:self.crt_size])
        np.save(f"{save_folder}_next_state.npy", self.next_state[:self.crt_size])
        np.save(f"{save_folder}_reward.npy", self.reward[:self.crt_size])
        np.save(f"{save_folder}_not_done.npy", self.not_done[:self.crt_size])
        np.save(f"{save_folder}_task.npy", self.task[:self.crt_size])
        np.save(f"{save_folder}_ptr.npy", self.ptr)


    def load(self, save_folder, size=-1):
        reward_buffer = np.load(f"{save_folder}_reward.npy")

        # Adjust crt_size if we're using a custom size
        size = min(int(size), self.max_size) if size > 0 else self.max_size
        self.crt_size = min(reward_buffer.shape[0], size)

        self.state[:self.crt_size] = np.load(f"{save_folder}_state.npy")[:self.crt_size]
        self.action[:self.crt_size] = np.load(f"{save_folder}_action.npy")[:self.crt_size]
        self.next_state[:self.crt_size] = np.load(f"{save_folder}_next_state.npy")[:self.crt_size]
        self.reward[:self.crt_size] = reward_buffer[:self.crt_size]
        self.not_done[:self.crt_size] = np.load(f"{save_folder}_not_done.npy")[:self.crt_size]
        self.task[:self.crt_size] = np.load(f"{save_folder}_task.npy")[:self.crt_size]
        print(f"Replay Buffer loaded with {self.crt_size} elements.")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

import numpy as np
from tqdm import tqdm
import copy

from utils import RMS, PBE

def weights_init_(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)

class GaussianPolicy(nn.Module):
    def __init__(
        self,
        state_dim,
        action_dim,
        sf_dim=5,
        hidden_dim=256,
        action_scale=1.0,
        action_bias=0.0,
        max_log_std=2,
        min_log_std=-20,
        repr_noise=1e-6,
    ):
        super().__init__()
        self.fc1 = nn.Linear(state_dim+sf_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mu = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

        self.action_scale = action_scale
        self.action_bias = action_bias
        self.max_log_std = max_log_std
        self.min_log_std = min_log_std
        self.repr_noise = repr_noise

        self.apply(weights_init_)

    def forward(self, state, task):
        cat = torch.cat([state, task], dim=1)
        h = F.relu(self.fc1(cat))
        h = F.relu(self.fc2(h))
        mu = self.mu(h)
        log_std = self.log_std(h)
        log_std = torch.clamp(log_std, min=self.min_log_std, max=self.max_log_std)

        return mu, log_std

    def sample(self, state, task):
        mean, log_std = self.forward(state, task)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)

        # Enforcing Action Bound
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + self.repr_noise)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def act(self, state, task):
        # print(state.shape, self.fc1)
        mean, _ = self.forward(state, task)
        return torch.tanh(mean) * self.action_scale + self.action_bias


class Critic(nn.Module):
    def __init__(
        self,
        state_dim,
        action_dim,
        sf_dim=5,
        hidden_dim=256,
    ):
        super().__init__()
        def make_critic():
            critic = nn.Sequential(
                nn.Linear(state_dim + action_dim + sf_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, sf_dim),
            )

            return critic
        
        self.critic1 = make_critic()
        self.critic2 = make_critic()
        self.apply(weights_init_)
    
    def forward(self, state, action, task):
        cat = torch.cat([state, action, task], dim=1)
        SF1, SF2 = self.get_SF(state, action, task)
        return (
            torch.einsum('bi,bi->b', SF1, task).unsqueeze(-1),
            torch.einsum('bi,bi->b', SF2, task).unsqueeze(-1),
        )
    
    def get_SF(self, state, action, task):
        cat = torch.cat([state, action, task], dim=1)
        return (
            self.critic1(cat).squeeze(-1),
            self.critic2(cat).squeeze(-1)
        )
    
class Phi(nn.Module):
    def __init__(
        self,
        state_dim,
        sf_dim=5,
        hidden_dim=256,
    ):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, sf_dim)
        
    def forward(self, state, norm=True):
        out = F.relu(self.fc1(state))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        if norm:
            return F.normalize(out, dim=-1)
        return out
    
class IdentityPhi(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        
    def forward(self, state, norm=True):
        if norm:
            return F.normalize(state, dim=-1)
        return state 
    
    
class SAC_APS(object):
    def __init__(
        self,
        env,
        # essential args for RL
        rollout_model=None,
        discount=0.99,
        tau=5e-3,
        # the following args are for networks
        actor_lr=3e-4,
        critic_lr=3e-4,
        learn_alpha=True,
        # flag for using aps or not
        aps=False,
        # the following args for aps
        sf_dim=5,
        phi_lr=3e-4,
        task_lr=3e-4,
        update_task_frequency=5,
        knn_k=12,
        knn_rms=True,
        knn_avg=True,
        knn_clip=1e-4,
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.env = env
        self.discount = discount
        self.tau = tau
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.max_action = float(self.env.action_space.high[0])
        
        if not aps:
            sf_dim = env.observation_space.shape[0]
            self.sf_dim = sf_dim
            self.phi = IdentityPhi()
        else:
            self.sf_dim = sf_dim
            self.phi = Phi(
                state_dim=self.state_dim,
                sf_dim=self.sf_dim
            ).to(self.device)
        self.phi_opt = torch.optim.Adam(self.phi.parameters(), lr=phi_lr)

        self.rollout_model = rollout_model
        if rollout_model is not None:
            self.model_based = True
        else:
            self.model_based = False

        # set networks 
        self.actor = GaussianPolicy(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            sf_dim=self.sf_dim,
            action_scale=self.max_action,
        ).to(self.device)
        
        self.critic = Critic(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            sf_dim=self.sf_dim,
        ).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)

        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.learn_alpha = learn_alpha
        if learn_alpha:
            self.entropy_target = -np.prod(
                    self.env.action_space.shape).item()
            self.entropy_target = torch.tensor(self.entropy_target).to(self.device) # convert the numpy.ndarray to torch.tensor
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_opt = torch.optim.Adam(
                [self.log_alpha],
                lr=actor_lr,
            )
            
        # the following is for aps 
        self.sf_dim = sf_dim
        self.update_task_frequency = update_task_frequency
        self.knn_k = knn_k
        self.knn_rms = knn_rms
        self.knn_avg = knn_avg
        self.knn_clip = knn_clip
        self.temp_task = None # temp tasks are set through function : set_task
        
        self.task_learned = False
        self.task = torch.ones((1, self.sf_dim), requires_grad=True, device=self.device) # task is only for RL phase (supervised learning phase)
        self.task_opt = torch.optim.Adam([self.task], lr=task_lr)
        
        self.RMS = RMS(self.device)
        self.PBE = PBE(
            rms=self.RMS,
            knn_clip=self.knn_clip, 
            knn_k=self.knn_k, 
            knn_avg=self.knn_avg, 
            knn_rms=self.knn_rms,
            device=self.device
        )
        
        self.iterations = 0
            
    def create_empty_replay_buffer(self, buffer_size=int(1e6)):
        return TaskBuffer(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            sf_dim=self.sf_dim,
            buffer_size=buffer_size,
            device=self.device,
        )

    def set_rollout_model(self, rollout_model):
        assert rollout_model is not None, 'Cannot set a none-type rollout model'
        self.rollout_model = rollout_model
        self.model_based = True
        
    def intrinsic_reward(self, task, next_state):
        # maxent reward
        with torch.no_grad():
            state_representation = self.phi(next_state, norm=False)
        bonus = self.PBE(state_representation)
        ent_bonus = bonus.view(-1, 1)

        # successor feature reward
        state_representation = state_representation / torch.norm(state_representation, dim=1, keepdim=True)
        sf_reward = torch.einsum("bi,bi->b", task, state_representation).reshape(-1, 1)

        return ent_bonus, sf_reward
    
    def maybe_set_task(self, fine_tune):
        if fine_tune and self.task_learned:
            return self.task
        elif (self.iterations + 1) % self.update_task_frequency:
            return self.set_task()
        
        return self.temp_task
    
    def set_task(self):
        task = torch.randn(self.sf_dim).to(self.device)
        task = task / torch.norm(task)
        self.temp_task = task
        return task

    def select_action(self, state, task, deterministic=False):
        with torch.no_grad():
            state = torch.from_numpy(state.reshape(1, -1)).float().to(self.device)
            task = task.float().view(1, task.shape[-1])
            if deterministic:
                action = self.actor.act(state, task)
            else:
                action, _, _ = self.actor.sample(state, task)
        return action.squeeze(0).data.cpu().numpy()

    def train(self, replay_buffer, batch_size, fine_tune):
        state, action, next_state, reward, not_done, task = replay_buffer.sample(batch_size)
        
        if fine_tune:
            reward = reward
            # task = self.task.repeat(batch_size, 1) # use the regressed task
        else:
            with torch.no_grad():
                ent_bonus, sf_reward = self.intrinsic_reward(task, next_state)
                aps_reward = ent_bonus + sf_reward
                reward = aps_reward
            # calculte Phi-loss
            phi_loss = - torch.einsum("bi,bi->b", self.phi(next_state), task).mean()
            # update Phi and Encoder
            self.phi_opt.zero_grad()
            phi_loss.backward()
            self.phi_opt.step()
        
        # first train the alpha
        if self.learn_alpha:
            new_action, log_prob, mean = self.actor.sample(state, task)
            log_prob = log_prob.unsqueeze(-1)
            alpha_loss = -(self.log_alpha * (log_prob + self.entropy_target).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1
        
        
        q_new_action = torch.min(
            *self.critic(state, new_action, task)
        )
        actor_loss = (alpha*log_prob - q_new_action).mean()

        with torch.no_grad():
            next_action, new_log_prob, next_mean = self.actor.sample(next_state, task)
            new_log_prob = new_log_prob.unsqueeze(-1)
            target_Q = torch.min(
                *self.critic_target(next_state, next_action, task)
            ) - alpha * new_log_prob
            target_Q = reward + not_done * self.discount * target_Q
        
        Q1, Q2 = self.critic(state, action, task)
        critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

        # update networks
        if self.learn_alpha:
            self.alpha_opt.zero_grad()
            alpha_loss.backward()
            self.alpha_opt.step()

        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()
        
        self.maybe_regress_task(replay_buffer, batch_size, epoch_size=1)

        self.update_target_network()
        
        self.iterations += 1
    
    def maybe_regree_task(replay_buffer, batch_size, epoch_size=1):
        if replay_buffer.size > 4096 and self.iterations % self.update_task_frequency == 0:
            self.regress_task(replay_buffer, batch_size, epoch_size)

    def regress_task(self, replay_buffer, batch_size=32, epoch_size=1):
        # there are two options, one is using gradient descent, the other is using torch.linalg.lstsq
        self.task_learned = True
        for epoch in range(epoch_size):
            state, action, next_state, reward, *_ = replay_buffer.sample(batch_size=batch_size)
            with torch.no_grad():
                representation = self.phi(next_state)
                
            # estimate the reward
            estimated_reward = torch.einsum("bi,bi->b", representation, self.task.repeat(batch_size, 1)) # 256 by 169, 1 by 169
            
            # update the task
            task_loss = F.mse_loss(reward, estimated_reward)
            self.task_opt.zero_grad()
            task_loss.backward()
            self.task_opt.step()

    def learn(self, replay_buffer, step_size, expl_step_size, fine_tune=False, batch_size=32, eval_freq=int(5e3)):
        # make aliases
        env = self.env
        state = env.reset()
        evaluations = []
        episode_reward = 0
        episode_timesteps = 0
        episode_num = 0
        expl_step_size = max(batch_size, expl_step_size)
        task = self.set_task()
        
        if not fine_tune:
            eval_freq *= 10

        for step in tqdm(range(step_size)):
            task = self.maybe_set_task(fine_tune)
            episode_timesteps += 1  
            # pick the action
            if step < expl_step_size:
                action = env.action_space.sample()
            else:
                action = self.select_action(np.array(state), task, deterministic=False)
            
            # perform the action
            next_state, reward, done, _ = env.step(action)
            done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0
            
            # add the transition to the replay buffer
            replay_buffer.add(state, action, next_state, reward, done_bool, task.cpu().data.numpy())
            
            state = next_state
            episode_reward += reward
            
            if step >= expl_step_size:
                self.train(replay_buffer, batch_size, fine_tune)
            
            if done: 
                # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
                # print(f"Total T: {step+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
                # Reset environment
                state, done = env.reset(), False
                episode_reward = 0
                episode_timesteps = 0
                episode_num += 1
            
            if (step + 1) % eval_freq == 0:
                evaluations.append(self.evaluate())
                # np.save(f"./results/{file_name}", evaluations)
                # if args.save_model: policy.save(f"./models/{file_name}")
                
        return evaluations

    def evaluate(self, seed=0, eval_episodes=10):
        eval_env = self.env

        eval_env.seed(seed + 100)

        avg_reward = 0.
        for _ in range(eval_episodes):
            state, done = eval_env.reset(), False
            while not done:
                action = self.select_action(np.array(state), self.task, deterministic=True)
                # print(state.shape, np.array(state).shape)
                next_state, reward, done, _ = eval_env.step(action)
                # print(state.shape, next_state.shape)
                state = next_state
                avg_reward += reward

        avg_reward /= eval_episodes

        print("---------------------------------------")
        print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
        print("---------------------------------------")
        return avg_reward
    
    def update_target_network(self):
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

In [3]:
import gym
env = gym.make('Hopper-v2')
# env = gym.make('Hopper-v2')

In [6]:
aps = SAC_APS(
    env,
    sf_dim=5,
    aps=True
)
replay_buffer = aps.create_empty_replay_buffer()
evaluations = aps.learn(
    replay_buffer,
    step_size=int(1e6),
    batch_size=256,
    expl_step_size=int(1e4),
    fine_tune=False
)

  5%|▌         | 50008/1000000 [11:44<6:56:25, 38.02it/s] 

---------------------------------------
Evaluation over 10 episodes: 6.107
---------------------------------------


 10%|█         | 100008/1000000 [26:30<8:19:57, 30.00it/s] 

---------------------------------------
Evaluation over 10 episodes: 44.564
---------------------------------------


 15%|█▌        | 150006/1000000 [41:24<4:48:00, 49.19it/s]

---------------------------------------
Evaluation over 10 episodes: 6.784
---------------------------------------


 20%|██        | 200005/1000000 [56:19<5:28:43, 40.56it/s]

---------------------------------------
Evaluation over 10 episodes: 5.707
---------------------------------------


 25%|██▌       | 250009/1000000 [1:11:16<4:44:11, 43.98it/s]

---------------------------------------
Evaluation over 10 episodes: 8.647
---------------------------------------


 30%|███       | 300010/1000000 [1:26:01<3:44:41, 51.92it/s]

---------------------------------------
Evaluation over 10 episodes: 8.567
---------------------------------------


 35%|███▌      | 350011/1000000 [1:40:42<3:21:47, 53.69it/s]

---------------------------------------
Evaluation over 10 episodes: 6.414
---------------------------------------


 40%|████      | 400007/1000000 [1:55:37<3:30:36, 47.48it/s]

---------------------------------------
Evaluation over 10 episodes: 5.835
---------------------------------------


 45%|████▌     | 450009/1000000 [2:10:27<3:02:59, 50.09it/s]

---------------------------------------
Evaluation over 10 episodes: 7.134
---------------------------------------


 50%|█████     | 500009/1000000 [2:25:51<3:41:50, 37.56it/s]

---------------------------------------
Evaluation over 10 episodes: 10.172
---------------------------------------


 55%|█████▌    | 550005/1000000 [2:41:21<4:28:57, 27.88it/s]

---------------------------------------
Evaluation over 10 episodes: 24.204
---------------------------------------


 60%|██████    | 600010/1000000 [2:56:17<3:07:24, 35.57it/s]

---------------------------------------
Evaluation over 10 episodes: 19.531
---------------------------------------


 65%|██████▌   | 650006/1000000 [3:11:22<1:54:48, 50.81it/s]

---------------------------------------
Evaluation over 10 episodes: 8.071
---------------------------------------


 70%|███████   | 700005/1000000 [3:26:23<2:32:34, 32.77it/s]

---------------------------------------
Evaluation over 10 episodes: 10.590
---------------------------------------


 75%|███████▌  | 750007/1000000 [3:42:27<2:12:54, 31.35it/s]

---------------------------------------
Evaluation over 10 episodes: 101.086
---------------------------------------


 80%|████████  | 800008/1000000 [3:58:45<1:22:16, 40.51it/s]

---------------------------------------
Evaluation over 10 episodes: 14.581
---------------------------------------


 85%|████████▌ | 850005/1000000 [4:14:26<1:06:24, 37.64it/s]

---------------------------------------
Evaluation over 10 episodes: 41.755
---------------------------------------


 90%|█████████ | 900010/1000000 [4:30:00<39:29, 42.20it/s]  

---------------------------------------
Evaluation over 10 episodes: 40.436
---------------------------------------


 95%|█████████▌| 950007/1000000 [4:45:10<20:50, 39.99it/s]

---------------------------------------
Evaluation over 10 episodes: 41.726
---------------------------------------


100%|██████████| 1000000/1000000 [5:00:15<00:00, 55.51it/s]

---------------------------------------
Evaluation over 10 episodes: 44.102
---------------------------------------





In [7]:
ft_replay_buffer = aps.create_empty_replay_buffer()
ft_evaluations = aps.learn(
    ft_replay_buffer,
    step_size=int(1e5),
    batch_size=256,
    expl_step_size=int(1e4),
    fine_tune=True,
)

  5%|▌         | 5413/100000 [00:01<00:32, 2871.11it/s]

---------------------------------------
Evaluation over 10 episodes: 44.102
---------------------------------------


 10%|█         | 10007/100000 [00:03<00:41, 2158.16it/s]

---------------------------------------
Evaluation over 10 episodes: 44.102
---------------------------------------


 15%|█▌        | 15007/100000 [01:17<25:42, 55.11it/s]  

---------------------------------------
Evaluation over 10 episodes: 39.430
---------------------------------------


 20%|██        | 20013/100000 [02:26<25:22, 52.53it/s]

---------------------------------------
Evaluation over 10 episodes: 39.397
---------------------------------------


 25%|██▌       | 25012/100000 [03:37<22:43, 54.98it/s]

---------------------------------------
Evaluation over 10 episodes: 39.351
---------------------------------------


 30%|███       | 30013/100000 [04:51<20:17, 57.48it/s]

---------------------------------------
Evaluation over 10 episodes: 39.573
---------------------------------------


 35%|███▌      | 35008/100000 [06:06<20:12, 53.59it/s]

---------------------------------------
Evaluation over 10 episodes: 39.921
---------------------------------------


 40%|████      | 40012/100000 [07:17<18:23, 54.36it/s]

---------------------------------------
Evaluation over 10 episodes: 39.693
---------------------------------------


 45%|████▌     | 45013/100000 [08:27<15:31, 59.05it/s]

---------------------------------------
Evaluation over 10 episodes: 40.097
---------------------------------------


 50%|█████     | 50009/100000 [09:45<17:14, 48.30it/s]

---------------------------------------
Evaluation over 10 episodes: 39.906
---------------------------------------


 55%|█████▌    | 55008/100000 [11:03<15:16, 49.12it/s]

---------------------------------------
Evaluation over 10 episodes: 39.890
---------------------------------------


 60%|██████    | 60011/100000 [12:18<13:36, 48.95it/s]

---------------------------------------
Evaluation over 10 episodes: 39.882
---------------------------------------


 65%|██████▌   | 65007/100000 [13:39<13:30, 43.17it/s]

---------------------------------------
Evaluation over 10 episodes: 39.862
---------------------------------------


 70%|███████   | 70012/100000 [15:00<09:48, 50.96it/s]

---------------------------------------
Evaluation over 10 episodes: 40.047
---------------------------------------


 75%|███████▌  | 75011/100000 [16:17<07:55, 52.56it/s]

---------------------------------------
Evaluation over 10 episodes: 40.219
---------------------------------------


 80%|████████  | 80006/100000 [17:34<07:50, 42.46it/s]

---------------------------------------
Evaluation over 10 episodes: 40.236
---------------------------------------


 85%|████████▌ | 85011/100000 [18:54<04:58, 50.29it/s]

---------------------------------------
Evaluation over 10 episodes: 40.246
---------------------------------------


 90%|█████████ | 90011/100000 [20:12<03:25, 48.69it/s]

---------------------------------------
Evaluation over 10 episodes: 40.237
---------------------------------------


 95%|█████████▌| 95008/100000 [21:31<01:35, 52.17it/s]

---------------------------------------
Evaluation over 10 episodes: 40.234
---------------------------------------


100%|██████████| 100000/100000 [22:51<00:00, 72.91it/s]

---------------------------------------
Evaluation over 10 episodes: 40.054
---------------------------------------





In [8]:
ft_evaluations_2 = aps.learn(
    ft_replay_buffer,
    step_size=int(1e5),
    batch_size=256,
    expl_step_size=int(1e4),
    fine_tune=True,
)

  6%|▌         | 5755/100000 [00:01<00:32, 2938.05it/s]

---------------------------------------
Evaluation over 10 episodes: 40.054
---------------------------------------




---------------------------------------
Evaluation over 10 episodes: 40.054
---------------------------------------


 15%|█▌        | 15007/100000 [01:11<26:52, 52.71it/s] 

---------------------------------------
Evaluation over 10 episodes: 40.046
---------------------------------------


 20%|██        | 20007/100000 [02:26<26:16, 50.75it/s]

---------------------------------------
Evaluation over 10 episodes: 40.044
---------------------------------------


 20%|██        | 20153/100000 [02:28<09:50, 135.33it/s]


KeyboardInterrupt: 