In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.nn.init import kaiming_uniform_
from torch.distributions import Normal

In [2]:
torch.normal(mean=torch.arange(1., 6.), std=torch.arange(1., 6.))

tensor([ 2.6046, -1.6558,  1.5283,  6.9046,  5.3011])

In [3]:
import numpy as np
import gym
from tqdm import tqdm
import random as rand
from itertools import count

In [4]:
class ReplayMemory():
    def __init__(self,capacity):   
        self.capacity = capacity
        self.memory = []
        self.push_count = 0
        
    def push(self, experience):
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count%self.capacity] = experience
        self.push_count+=1
    
    def sample(self, batch_size):
        return rand.sample(self.memory,batch_size)
    
    def can_provide_sample(self, batch_size):
        return len(self.memory)>=batch_size
    
    def update_td_error(self, sampled_experiences):
        for sampled_idx,sampled_exp in enumerate(sampled_experiences):
            for mem_idx, mem_exp in enumerate(self.memory):
                if mem_exp.timestep == sampled_exp.timestep:
                    self.memory[mem_idx] = sampled_exp #update memory
                    break
        
    def get_memory_values(self):
        return self.memory    

In [5]:
def extract_tensors(experiences):
    #print(".....................................................")
    #print(experiences)
    batch = Xp(*zip(*experiences))
    state = np.stack(batch.state) #stack
    action = np.stack(batch.action)
    next_state = np.stack(batch.next_state)
    reward = np.stack(batch.reward)
    done = np.stack(batch.done)
    abs_td_error = np.stack(batch.abs_td_error)
    timestep = np.stack(batch.timestep)
    return state,action,next_state,reward,done,abs_td_error,timestep

In [6]:
def rebuild_experiences(state, action, next_state, reward, done, abs_error, timestep):
    exp_list = []
    for idx_ in range(len(state)):
        exp_list.append(\
                        Xp(state[idx_], action[idx_], next_state[idx_], reward[idx_],\
                           done[idx_], abs_error[idx_], timestep[idx_]))
    return exp_list

In [7]:
from collections import namedtuple
Xp = namedtuple('Experience',
                        ('state', 'action', 'next_state', 'reward', 'done', 'abs_td_error','timestep'))
Xp_points = Xp(5,6,7,8,9,10,11)
Xp_points

Experience(state=5, action=6, next_state=7, reward=8, done=9, abs_td_error=10, timestep=11)

In [8]:
def prioritize_samples(experience_samples, alpha, beta):
    state,action,next_state,reward,done,abs_td_error,timesteps = extract_tensors(experience_samples)
    #rank based
    #('state', 'action', 'next_state', 'reward', 'done', 'abs_td_error','timestep')
    abs_td_error  = np.expand_dims(abs_td_error, axis=1)
    abs_td_error = torch.tensor(abs_td_error)
    abs_td_error, indices_ = abs_td_error.sort(0, descending=True)#big to small
    indices = np.arange(1, len(abs_td_error)+1)
    priorities = 1.0/indices
    priorities = priorities**alpha#scale by alpha
    priorities = np.expand_dims(priorities, axis=1)
    probabilities = priorities/np.sum(priorities, axis=0)#sums up to 1(or 0.999999)
    assert np.isclose(probabilities.sum(), 1.0)#ensures probs add up to 1
    
    number_of_samples  = len(probabilities)
    weight_importance_ = number_of_samples*probabilities
    weight_importance_ = weight_importance_**-beta
    weight_importance_max = np.max(weight_importance_)
    weight_importance_scaled = weight_importance_/weight_importance_max
    return weight_importance_scaled, indices_ #return weight important samples, return indices for re_arranging sampled experiences

In [9]:
class linearApproximator_FCGSAP(nn.Module):
    def __init__(self,state_shape,outputs,hidden_dims=(32,32), log_entropy_lr =0.0001,\
                log_std_dev_min=-20, log_std_dev_max= 2, epsilon = 1e-6):
        super(linearApproximator_FCGSAP, self).__init__()
        self.input_size = state_shape
        self.out = outputs
        self.log_std_dev_min = log_std_dev_min
        self.log_std_dev_max = log_std_dev_max
        self.epsilon = epsilon
        self.device = torch.device("cuda" if torch.cuda.is_available()\
                                   else "cpu")
        
        self.fc1  = nn.Linear(self.input_size,hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_layer = nn.Linear(\
                                hidden_dims[i], hidden_dims[i+1])
            self.hidden_layers.append(hidden_layer)
        
        self.output_layer_distribution  = nn.Linear(hidden_dims[-1],self.out)
        self.output_layer_mean = nn.Linear(hidden_dims[-1],self.out)
        
        self.target_entropy = -float(self.out)
        #self.target_entropy = self.target_entropy.float() 
        #according to the eq, log alpha is a learnable parameter
        self.log_alpha = torch.zeros(1,\
                                     requires_grad=True,\
                                     device = self.device)
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],\
                                                    lr=log_entropy_lr)
                                     
        self.to(self.device)
        
    def forward(self, state_shape):
        if not isinstance(state_shape, torch.Tensor):
            state_shape = torch.tensor(state_shape, dtype=torch.float32)
        state_shape = state_shape.to(self.device)
        x = self.fc1(state_shape)
        x = F.relu(x)
        
        for hidden_layer in self.hidden_layers:
            x = F.relu(hidden_layer(x))
        
        distribution = self.output_layer_distribution(x)#logits, preferences of actions
        mean   = self.output_layer_mean(x)
        distribution = torch.clamp(distribution, self.log_std_dev_min, self.log_std_dev_max)
        return mean, distribution
        
    def full_pass(self, state):
        mean, distribution = self.forward(state)
        pi_s = Normal(mean, distribution.exp())
        pre_sampled_actions = pi_s.rsample()
        sampled_actions = torch.tanh(pre_sampled_actions)#scale actions between -1 and 1
        #we also rescale our logprobs to match action space
        log_probs = pi_s.log_prob(pre_sampled_actions) - \
                                    torch.log((1 - sampled_actions.pow(2)).clamp(0,1) + \
                                                self.epsilon)
        log_probs = log_probs.sum(dim=1, keepdim=True)
        return sampled_actions, log_probs, mean

In [10]:
class linearApproximator_FCQV(nn.Module):#Q value of state action pair
    def __init__(self,state_shape,action_outputs_size,hidden_dims=(32,32)):
        super(linearApproximator_FCQV, self).__init__()
        self.input_size = state_shape
        self.action_outputs_size = action_outputs_size
        self.device = torch.device("cuda" if torch.cuda.is_available()\
                                   else "cpu")
        
        self.fc1  = nn.Linear(self.input_size,hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_input_layer = hidden_dims[i]
            if i == 0:
                hidden_input_layer += self.action_outputs_size #increased to account for size/number of actions
            hidden_layer = nn.Linear(\
                                hidden_input_layer, hidden_dims[i+1])
            self.hidden_layers.append(hidden_layer)
        
        self.output_layer  = nn.Linear(hidden_dims[-1],1)
        self.to(self.device)
        
    def forward(self, state_shape, action_shape):
        if not isinstance(state_shape, torch.Tensor):
            state_shape = torch.tensor(state_shape, dtype=torch.float32).to(self.device)
        if not isinstance(action_shape, torch.Tensor):
            action_shape = torch.tensor(action_shape, dtype=torch.float32).to(self.device)
                    
        x = self.fc1(state_shape)
        x = F.relu(x)
        
        for idx, hidden_layer in enumerate(self.hidden_layers):
            if idx == 0:
                x = torch.cat((x, action_shape), dim=1)
            x = F.relu(hidden_layer(x))
        
        q_value = self.output_layer(x)#logits, preferences of actions
        return q_value

In [11]:
def update_networks(online_q_network_a, online_q_network_b,\
                    offline_q_network_a, offline_q_network_b, tau):
        
    for target_weights, online_weights in zip(offline_q_network_a.parameters(), online_q_network_a.parameters()):
        target_weight_update = (1.0 - tau)*target_weights.data
        online_weight_update = tau*online_weights.data
        sum_up = target_weight_update + online_weight_update
        target_weights.data.copy_(sum_up)
        
    for target_weights, online_weights in zip(offline_q_network_b.parameters(), online_q_network_b.parameters()):
        target_weight_update = (1.0 - tau)*target_weights.data
        online_weight_update = tau*online_weights.data
        sum_up = target_weight_update + online_weight_update
        target_weights.data.copy_(sum_up)

    return offline_q_network_a, offline_q_network_b

In [12]:
def update_online_model(experience_samples,\
                        online_policy_network, online_q_network_a, online_q_network_b,\
                        online_policy_optimizer, online_q_optimizer_a, online_q_optimizer_b,\
                        offline_q_network_a, offline_q_network_b,\
                        gamma, weighted_importance, indices):
    
    states, actions, next_states, rewards, done, td_errors, timesteps = extract_tensors(experience_samples)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #rearrange first
    indices = np.array(indices)#have to be same size as batch size(already taken care of)
    indices = np.squeeze(indices)
    states = states[indices,:]
    states = np.squeeze(states)
    actions = actions[indices]
    next_states = next_states[indices,:]
    next_states = np.squeeze(next_states)
    rewards = rewards[indices]
    done = done[indices]
    td_errors = td_errors[indices]
    timesteps = timesteps[indices]    
    
    states = torch.tensor(states).float().to(device)
    actions = torch.tensor(actions)
    actions = actions.float().to(device)
    next_states=torch.tensor(next_states).float().to(device)
    rewards = torch.tensor(rewards).float().to(device)
    rewards = rewards.unsqueeze(1)
    done = torch.tensor(done).float().to(device)
    done = done.unsqueeze(1)
    weighted_importance = torch.tensor(weighted_importance).float().to(device)
    
    
    current_actions,log_pi, _ = online_policy_network.full_pass(states)
    target_alpha = (log_pi +\
                    online_policy_network.target_entropy).detach()
    target_alpha_loss = -(online_policy_network.log_alpha *\
                         target_alpha).mean()
    online_policy_network.log_alpha_optimizer.zero_grad()
    target_alpha_loss.backward()
    online_policy_network.log_alpha_optimizer.step()
    optimized_alpha = online_policy_network.log_alpha.exp()
    
    
    max_q_sa_online_a = online_q_network_a(states, current_actions.detach())
    max_q_sa_online_b = online_q_network_b(states, current_actions.detach())
    max_q_sa_online_a*=weighted_importance
    max_q_sa_online_b*=weighted_importance
    max_q_online__ = torch.min(max_q_sa_online_a, max_q_sa_online_b)
    #max_q_online__*=(1 - done)
    #print("max_q_online__", max_q_online__)
    predicted_online_action_policy,\
                log_pi_ns, _ = online_policy_network.full_pass(next_states)

    policy_loss = (-optimized_alpha.detach() *(max_q_online__.detach() - log_pi)).mean()#policy loss
    
    max_q_sa_offline_a = offline_q_network_a(next_states, predicted_online_action_policy)
    max_q_sa_offline_b = offline_q_network_b(next_states, predicted_online_action_policy)
    max_q_sa_offline = torch.min(max_q_sa_offline_a, max_q_sa_offline_b)
    TWIN_target = max_q_sa_offline
    TWIN_target*=weighted_importance
    
    TWIN_target = TWIN_target.detach()
    TWIN_target *=(1 - done)
    TWIN_target = TWIN_target - optimized_alpha * log_pi_ns
    TWIN_target = rewards + gamma*TWIN_target
    
    #abs_a = abs(TWIN_target.detach().cpu().numpy() - max_q_sa_online_a.detach().cpu().numpy())
    #abs_b = abs(TWIN_target.detach().cpu().numpy() - max_q_sa_online_b.detach().cpu().numpy())
    #ovr_abs_update = (abs_a + abs_b)/2 #we get the mean(not done)
    
    TWIN_target = TWIN_target.detach()
    loss_func = torch.nn.SmoothL1Loss()
    
    
    q_sa_online_a = online_q_network_a(states, actions)
    q_sa_online_b = online_q_network_b(states, actions)
    
    abs_a = abs(TWIN_target.detach().cpu().numpy() - max_q_sa_online_a.detach().cpu().numpy())
    abs_b = abs(TWIN_target.detach().cpu().numpy() - max_q_sa_online_b.detach().cpu().numpy())
    ovr_abs_update = (abs_a + abs_b)/2 #we get the mean(not done)
    
    
    
    q_online_value_loss_a = loss_func(q_sa_online_a,\
                                             TWIN_target)
    q_online_value_loss_b = loss_func(q_sa_online_b,\
                                             TWIN_target)
    online_q_optimizer_a.zero_grad()
    q_online_value_loss_a.backward()
    online_q_optimizer_a.step()
    online_q_optimizer_b.zero_grad()
    q_online_value_loss_b.backward()
    online_q_optimizer_b.step()
    
    #One of the difference between SAC and TD3 is SAC only uses one online policy
    #there is also no delay in policy updates in SAC
    #policy_loss = -(pre_optimized_alpha * log_pi - max_q_online__).mean()
    online_policy_optimizer.zero_grad()
    policy_loss.backward()
    online_policy_optimizer.step()
        
    states, actions, next_states, rewards, done, td_errors, timesteps = extract_tensors(experience_samples)
    experiences_rebuilded = rebuild_experiences(states, actions, next_states, rewards, done, ovr_abs_update, timesteps)
    return experiences_rebuilded

In [13]:
def query_error(online_policy_network, offline_q_network_a, offline_q_network_b,\
                online_q_network_a, online_q_network_b, state, action, next_state, reward, gamma):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    state = torch.tensor(state).float().to(device)
    state = state.unsqueeze(0)
    next_state = torch.tensor(next_state).float().to(device)
    next_state = next_state.unsqueeze(0)
    
    alpha = online_policy_network.log_alpha.exp()
    
    ns_actions,log_pi_ns, _ = online_policy_network.full_pass(next_state)
    q_target_next_states_action_a = offline_q_network_a(next_state,\
                                                    ns_actions.detach())
    q_target_next_states_action_b = offline_q_network_b(next_state,\
                                                    ns_actions.detach())
    
    
    TWIN_target = torch.min(q_target_next_states_action_a, q_target_next_states_action_b)
    TWIN_target = TWIN_target - alpha * log_pi_ns
    TWIN_target = reward + (gamma*TWIN_target.detach())
    
    
    current_action,_, _ = online_policy_network.full_pass(state)
    #print("state: ", state.shape)
    #print("action: ", action.shape)
    action = np.expand_dims(action, axis=0)
    #print("action: ", action.shape)
    q_online_state_action_val_a = online_q_network_a(state, action)
    q_online_state_action_val_b = online_q_network_b(state, action)
    
    abs_a = abs(TWIN_target - q_online_state_action_val_a)
    abs_b = abs(TWIN_target - q_online_state_action_val_b)
    abs_stack = (abs_a + abs_b)/2
    #print("abs querry error stacked: ", abs_stack.shape)
    ovr_abs_update = abs_stack
    #print("abs querry error mean: ", ovr_abs_update.shape)
    return ovr_abs_update.detach().cpu().numpy()

In [14]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
    return model

In [15]:
def select_action(state, online_policy_network):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.tensor(state).float().to(device)
    state = state.unsqueeze(0)
    with torch.no_grad():
        actions,log_pi_, _ = online_policy_network.full_pass(state)
        actions = actions.cpu().detach()
        actions = actions.data.numpy().squeeze()
    return actions

In [21]:
def SAC_PER(env,
         gamma=0.99,
         alpha_pr=0.6,
         beta_pr=0.3,
         memory_size = 50000,
         tau = 0.1,
         offline_update = 1000,
         min_sample_size=200,
         batch_size = 64,
         n_ep=2000,
         max_steps = 100000
         ):
    
    
    observation_space = len(env.reset())
    action_space_high, action_space_low = env.action_space.high, env.action_space.low
    n_actions = len(action_space_high)
    online_policy_network = linearApproximator_FCGSAP(observation_space,n_actions,\
                                     hidden_dims=(128,64))
    online_q_network_a = linearApproximator_FCQV(observation_space,\
                                     n_actions,hidden_dims=(128,64))
    online_q_network_b = linearApproximator_FCQV(observation_space,\
                                     n_actions,hidden_dims=(128,64))
    
    offline_q_network_a = linearApproximator_FCQV(observation_space,\
                                     n_actions,hidden_dims=(128,64))
    offline_q_network_b = linearApproximator_FCQV(observation_space,\
                                     n_actions,hidden_dims=(128,64))
    

    offline_q_network_a.eval()
    offline_q_network_a = freeze_model(offline_q_network_a)
    offline_q_network_b.eval()
    offline_q_network_b = freeze_model(offline_q_network_b)
    
    online_policy_optimizer    = torch.optim.Adam(online_policy_network.parameters(),lr=0.0008)
    online_q_optimizer_a = torch.optim.Adam(online_q_network_a.parameters(),lr=0.0008)
    online_q_optimizer_b = torch.optim.Adam(online_q_network_b.parameters(),lr=0.0008)
    
    memory = ReplayMemory(memory_size)
    
    t_step = 0 #important
    reward_per_ep = []
    
    
    for e in tqdm(range(n_ep)):
        state = env.reset()
        reward_accumulated = 0
        
        while True:
            env.render()
            action = select_action(state, online_policy_network)
            if memory.can_provide_sample(min_sample_size):
                print(action)
            next_state, reward, done, info = env.step(action)
            td_error = query_error(online_policy_network, offline_q_network_a, offline_q_network_b,\
                online_q_network_a, online_q_network_b, state, action, next_state, reward, gamma)
            #print("td error: ", td_error.shape)
            td_error = np.squeeze(td_error, axis = 0)
            #print("td error: ", td_error.shape)
            reward_accumulated+=reward
            is_truncated = 'TimeLimit.truncated' in info and\
                                info['TimeLimit.truncated']
            is_failure = done and not is_truncated
           
            memory.push(Xp(state, action, next_state, reward, is_failure, td_error, t_step))
            state = next_state
            t_step+=1
            if memory.can_provide_sample(min_sample_size):
                experience_samples = memory.sample(batch_size)
                weighted_importance, indices = prioritize_samples(experience_samples, alpha_pr, beta_pr)
                rebuilded_exp = update_online_model(experience_samples,\
                        online_policy_network, online_q_network_a, online_q_network_b,\
                        online_policy_optimizer, online_q_optimizer_a, online_q_optimizer_b,\
                        offline_q_network_a, offline_q_network_b,\
                        gamma, weighted_importance, indices)
                memory.update_td_error(rebuilded_exp)
                
            if t_step%offline_update == 0:
                offline_q_network_a, offline_q_network_b = update_networks(online_q_network_a, online_q_network_b,\
                                                                    offline_q_network_a, offline_q_network_b, tau)
            if done == True:
                reward_per_ep.append(reward_accumulated)
                #print(reward_accumulated)
                break
            if t_step > max_steps:
                env.close()
                return reward_per_ep
    env.close()           
    return reward_per_ep

In [22]:
import gym
env = gym.make('BipedalWalker-v3')

In [23]:
env.close()

In [None]:
rewards = SAC_PER(env)  

  0%|          | 2/2000 [00:01<17:35,  1.89it/s]

[-0.91166925  0.93979627  0.61329836  0.77086055]
[-0.20432463 -0.55195177 -0.5575151   0.02042458]
[-0.6556064   0.04860764  0.1192078   0.2332958 ]
[-0.6746763  -0.59897494 -0.55672646 -0.02301397]
[-0.9088265  -0.6487699  -0.2491401   0.97005314]
[-0.3964877   0.24614315  0.6312369   0.916448  ]
[ 0.3859649   0.23407508 -0.3061111   0.30990085]
[ 0.38699603  0.92000836  0.66908383 -0.27561423]
[ 0.6929807  -0.8136302  -0.2627812  -0.67524666]
[-0.42197585 -0.46928757 -0.14425729  0.8553563 ]
[-0.6435398   0.9486105  -0.5290562  -0.56782293]
[ 0.43598565  0.8003263  -0.83929884 -0.06605521]
[-0.5341318  -0.50468004 -0.99684066 -0.72267795]
[ 0.17672722 -0.1936523   0.90643644  0.9746543 ]
[ 0.46122405 -0.39175123  0.6966756  -0.5401637 ]
[ 0.7864884   0.8539301  -0.63390136  0.80754167]
[ 0.5646744 -0.965259   0.5555403  0.7077081]
[-0.3845712   0.75924957 -0.7033428   0.42812592]
[-0.3039445  -0.7051641  -0.05400967 -0.760298  ]
[-0.975991   0.0574446  0.7207432 -0.8646346]
[-0.2640

[ 0.39093345  0.45821247  0.26224166 -0.2752669 ]
[-0.8583696   0.17472805  0.76655555 -0.7686304 ]
[ 0.44114172 -0.81227654  0.83190453 -0.87917936]
[-0.58524567  0.25148252 -0.01041776  0.24781227]
[-0.3228092  -0.35129845  0.1412118   0.1808319 ]
[-0.641502   -0.08689647 -0.14849858 -0.5469134 ]
[ 0.22441703  0.24619113 -0.9519896  -0.9006433 ]
[-0.17518628  0.34878543  0.8340902   0.74889493]
[-0.36521104  0.8438494   0.7987218  -0.49486864]
[-0.48321578 -0.08072473  0.256328    0.77715814]
[-0.29321218 -0.5189447  -0.01753221  0.35817352]
[ 0.58172655  0.62542117 -0.80734664 -0.9078607 ]
[ 0.8525702  -0.47895387 -0.39903435 -0.64495814]
[ 0.98295635 -0.42522275 -0.61130595 -0.5719111 ]
[-0.11138701 -0.67985916  0.32145444  0.6336141 ]
[-0.7273824   0.7502688   0.03182378 -0.05845452]
[-0.8537618  -0.70404875  0.88679713 -0.4366827 ]
[ 0.941065   -0.2987885   0.7796695   0.41307962]
[0.59284604 0.31322655 0.725133   0.46856853]
[-0.09242254  0.6679096   0.85991347 -0.11825197]
[ 0.

[ 0.62159187 -0.35574302 -0.23401822 -0.67418766]
[ 0.15612654  0.9601529   0.82314205 -0.8357015 ]
[-0.40805328 -0.01029735 -0.48916733 -0.7460004 ]
[ 0.00832377  0.39746302 -0.9001654   0.9513525 ]
[-0.5148049 -0.5253884  0.8556755 -0.9570905]
[ 0.417919   -0.30732524  0.94678843 -0.20086375]
[-0.13593908  0.06800964  0.69921696  0.08020048]
[ 0.75887734 -0.86278063 -0.4913163  -0.31035393]
[-0.58619225  0.25302282 -0.48885062  0.8164327 ]
[ 0.00547632 -0.97077304  0.23667344 -0.40674514]
[-0.13365349 -0.78833574  0.08368549  0.86081743]
[ 0.21505022  0.6063552   0.56784266 -0.67635846]
[-0.52746904  0.9586462   0.11503773 -0.96331763]
[-0.37988076 -0.36280566  0.86582655 -0.8153599 ]
[ 0.3825068  -0.36736053  0.50912803 -0.8413812 ]
[-0.52373433 -0.4492433  -0.01713264  0.4928223 ]
[-0.567372   -0.76733774  0.5620124  -0.59583986]
[-0.88635933 -0.7631738  -0.12845741 -0.04295449]
[-0.40061656 -0.749377    0.06982562  0.41914576]
[-0.6495831   0.23568258 -0.57231086  0.77354544]
[ 0.

[ 0.52710724 -0.09658521  0.4422952   0.07787554]
[ 0.3987049  -0.7306122   0.19517587  0.75556296]
[ 0.3925113  -0.82193995  0.7906402   0.9359995 ]
[-0.389451    0.76926506  0.8965705  -0.15552244]
[-0.7269839   0.03348924 -0.7725289  -0.33400205]
[0.3416659  0.6783482  0.45935014 0.18407063]
[-0.17364568 -0.85959876  0.58572245 -0.08932294]
[ 0.78372574  0.5566704   0.02562244 -0.9191948 ]
[-0.20833984  0.12152426  0.11707384 -0.2046282 ]
[ 0.22301513 -0.4012456  -0.40786463  0.22348556]
[-0.41852766  0.46686345 -0.9519      0.9641547 ]
[ 0.2439337   0.07825533 -0.07789404 -0.05560871]
[-0.6757938   0.3778882  -0.4466419  -0.25882202]
[ 0.7823576  -0.6350664   0.73440194 -0.11537593]
[-0.60367095 -0.55443704  0.11202522  0.4610478 ]
[-0.4770768  -0.65088296  0.44922668 -0.6706731 ]
[ 0.08335508  0.5567889   0.87940854 -0.76710725]
[-0.92223525 -0.5278658  -0.518319   -0.56856537]
[-0.5858559  -0.85208863 -0.35888803 -0.40087608]
[-0.84330213  0.21467847  0.80876416  0.32462996]
[-0.

[-0.6434218  -0.5529636   0.61891973  0.5539092 ]
[-0.6604112   0.16498742 -0.7390653   0.29898605]
[-0.6148608   0.7002729  -0.83083063  0.13676707]
[-0.15347253 -0.5145482  -0.4250073  -0.7098603 ]
[-0.42472926  0.5711999   0.8725988   0.7503992 ]
[ 0.6492611   0.45303187 -0.33718088  0.22541599]
[-0.6075901  -0.36344257 -0.8421606   0.8625814 ]
[-0.88733315  0.6491762  -0.24631678 -0.4412122 ]
[ 0.5182671  -0.9028664   0.6506257   0.13199966]
[ 0.5921693  -0.14970033  0.4210745   0.25548598]
[0.03678806 0.8320005  0.9223184  0.11332414]
[ 0.9948474  -0.43778214 -0.16745132 -0.58908427]
[-0.4264327   0.23156767  0.2817143   0.21399212]
[0.03833398 0.14240617 0.09264004 0.78250754]
[-0.43238693  0.5465207   0.71568954 -0.30921316]
[-0.17342241 -0.10605213  0.65985596 -0.29231533]
[-0.34785286 -0.29856673  0.1676062   0.550977  ]
[ 0.8775305   0.8546158  -0.17121388  0.14325821]
[0.15929303 0.01708985 0.8583506  0.5282393 ]
[ 0.64041615 -0.31383997 -0.810433    0.8274249 ]
[-0.9231357 

[ 0.76293266 -0.5898042   0.11308305  0.83183295]
[ 0.6242544   0.00170168  0.08884811 -0.13299711]
[-0.707233   -0.09740467  0.7252485  -0.8637786 ]
[ 0.8952074  -0.43250358 -0.5496429   0.65800357]
[-0.31037506  0.61494076 -0.99329406 -0.04724197]
[ 0.730548   -0.7776041  -0.923189   -0.35835084]
[ 0.40436015  0.9246815  -0.13952659  0.92542505]
[-0.74970126 -0.92989    -0.08676258  0.22376081]
[ 0.7136351  -0.42314053  0.93859464  0.23078297]
[-0.44485465  0.37442422  0.85709834  0.69037676]
[-0.47782844 -0.3938879   0.7967794  -0.03426064]
[ 0.07172801 -0.01740345  0.11252756 -0.47819844]
[-0.722597  -0.5445416 -0.9644511  0.7963006]
[-0.9630929   0.771892    0.79786134  0.985614  ]
[-0.77596694  0.84248376 -0.7279418  -0.6692629 ]
[ 0.13944604  0.6322011  -0.35215196 -0.5233659 ]
[ 0.62097204  0.3831072  -0.7950352  -0.57585764]
[ 0.1719172   0.0562198   0.89882475 -0.8628    ]
[ 0.8216446   0.8629613  -0.97177845 -0.42663425]
[-0.39248404  0.87171507  0.82556605 -0.41181737]
[-0.

[ 0.9422528  -0.6066174  -0.9033377   0.94616574]
[-0.895772    0.22704624  0.9051233   0.54800403]
[ 0.4602636  -0.44504115 -0.7049831  -0.13341498]
[0.14741202 0.6607892  0.54019153 0.6265135 ]
[ 0.5650766  -0.42297223  0.6333002  -0.6837722 ]
[-0.5295725   0.7177343   0.8323263   0.22982328]
[-0.07606649  0.8787794   0.8101089   0.8892004 ]
[ 0.41960806 -0.67102265 -0.30649772  0.7468491 ]
[-0.96881014 -0.8620298  -0.5001378  -0.6849147 ]
[-0.19539279  0.8443196  -0.5726604   0.87059724]
[ 0.03247457  0.26148984 -0.5561949  -0.79859966]
[-0.5787108   0.03071108  0.7334373  -0.4543252 ]
[-0.65873766  0.99162346 -0.2910086   0.37967768]
[-0.77407837  0.457902   -0.02546768 -0.43226925]
[-0.85451907  0.70277035  0.20272219  0.02313252]
[ 0.9434894 -0.0333207 -0.9663181  0.4934984]
[-0.53453505  0.22030544  0.6992167  -0.5709802 ]
[-0.07449206 -0.02132026 -0.8763229   0.00767772]
[ 0.20507674 -0.3507921  -0.3596546   0.8712858 ]
[-0.8983857  -0.6915234  -0.18888147 -0.8841208 ]
[ 0.3784

[ 0.5079714  -0.96914655  0.5944382  -0.8670944 ]
[-0.80192804 -0.32949144 -0.5866543  -0.6424855 ]
[ 0.8969261   0.49868163 -0.97677016  0.7421819 ]
[-0.27313146  0.52827126 -0.98873645 -0.9581109 ]
[ 0.52190846  0.5396869   0.43535382 -0.77534175]
[-0.47278026 -0.8609249  -0.5737544  -0.83349127]
[-0.933007   -0.6115562  -0.07990608  0.08210008]
[-0.61512285  0.2818449   0.1235286  -0.05195956]
[-0.48726615 -0.70781803  0.38191152 -0.82090294]
[-0.72988355  0.97413516  0.78989345  0.52673495]
[-0.76885986  0.28215277 -0.31603718 -0.02350621]
[-0.5415689 -0.8547136  0.4670716  0.7281251]
[-0.9258583  -0.14247021 -0.56375074  0.9620595 ]
[-0.8053564   0.59580266 -0.8440873  -0.7717006 ]
[ 0.21926473  0.11573207 -0.7738996   0.49035576]
[-0.90091795 -0.5398369  -0.689088   -0.74520963]
[-0.96290237  0.8628502  -0.8343723  -0.23166296]
[-0.2772921  -0.42352033 -0.5843762   0.35334632]
[ 0.4854691   0.0174617  -0.8151809   0.84224695]
[-0.7274711  -0.30701464  0.02866296  0.6528503 ]
[-0.

[-0.8408609   0.20279929 -0.29370564  0.69654393]
[-0.21425451  0.30641004  0.71899873  0.5047038 ]
[ 0.19557215  0.54049313 -0.15137799 -0.66344327]
[ 0.5441265   0.07007904 -0.50629616 -0.481378  ]
[-0.12810677 -0.64406407  0.02256301  0.19212818]
[ 0.5849347  -0.07496188  0.8198443  -0.00587566]
[ 0.93482137  0.83313423 -0.3271695  -0.59452283]
[0.12553349 0.5079069  0.2222763  0.59676814]
[-0.44602376 -0.08303186 -0.5159952  -0.6584468 ]
[ 0.44453028  0.52670956  0.96116805 -0.08726139]
[ 0.26202404  0.3858925  -0.8433552  -0.6826295 ]
[ 0.93079454 -0.9117392   0.30182633  0.60410535]
[ 0.42888466 -0.8295982   0.7295145   0.79703414]
[-0.4512528   0.09188963  0.33409837  0.6590941 ]
[-0.5379435   0.72622645  0.43842348 -0.32507637]
[ 0.45115155 -0.7242647  -0.7107992  -0.1058371 ]
[-0.81619334 -0.77805126 -0.6132245   0.9305666 ]
[ 0.42673403 -0.95336604 -0.76152456  0.9247319 ]
[ 0.6851578  -0.2182496  -0.8467077  -0.59388685]
[-0.7511297   0.46869543  0.9133742  -0.9193262 ]
[-0.

  0%|          | 3/2000 [00:31<7:53:44, 14.23s/it]

[-0.9571455   0.96571434 -0.30456895 -0.25016257]
[ 0.90858847  0.16375168 -0.815189    0.46904406]
[0.20766503 0.25834903 0.75591683 0.15135168]
[-0.833639    0.7568457  -0.75606066 -0.2942987 ]
[ 0.14327763 -0.7050513  -0.69150305 -0.2981366 ]
[0.9184709  0.8901091  0.67606306 0.51004136]
[ 0.89078     0.6655476  -0.50507987 -0.14558749]
[ 0.5849354  -0.6958628  -0.789183    0.01815679]
[0.8783927  0.01402611 0.57968557 0.62350196]
[ 0.13416581  0.9153232  -0.84865785 -0.95243496]
[-0.2239746  -0.08819161 -0.00396547 -0.26819247]
[ 0.31404042 -0.41704893  0.560431   -0.31722975]
[ 0.9016797   0.8140892  -0.5988308   0.91848487]
[-0.956423    0.11564554 -0.226927   -0.6921243 ]
[-0.09072658 -0.5501831  -0.49266165  0.89282113]
[-0.23979503  0.6509256   0.55155325  0.40835515]
[-0.5253205  -0.37991112 -0.44679907 -0.26377928]
[-0.87338847 -0.29805404  0.02830685 -0.00884406]
[-0.43611676 -0.94427335 -0.25293127 -0.23791975]
[ 0.51362664  0.98736805  0.53309876 -0.36609542]
[-0.07646498

  0%|          | 4/2000 [00:32<5:02:06,  9.08s/it]

[-0.00674489  0.70458853  0.5793292  -0.19382772]
[ 0.70701295  0.80075985 -0.20401432 -0.35021782]
[-0.61853755  0.84085554 -0.6158184   0.07404478]
[ 0.10378581 -0.0186209  -0.9787831  -0.70169353]
[-0.5911262  -0.22829801 -0.77117556 -0.5098363 ]
[0.12839596 0.7701648  0.32354045 0.07141677]
[ 0.84665525  0.98107886  0.37086365 -0.6830269 ]
[-0.7661874   0.18860961  0.7021582  -0.45021036]
[ 0.45607516  0.8838697   0.75927734 -0.08881989]
[ 0.38478348 -0.3773661  -0.27748114  0.36737925]
[-0.61171794 -0.47560337  0.5666298  -0.02645866]
[-0.33954567  0.1637736   0.16717772  0.78193283]
[-0.98547024 -0.84582114 -0.0127104   0.35691428]
[ 0.744766    0.07364326 -0.91381514 -0.93212426]
[-0.03438297  0.8794341  -0.00254273 -0.4949247 ]
[ 0.5149685   0.718596    0.57618177 -0.36987472]
[ 0.8038546  -0.26555228  0.8099961   0.59329045]
[-0.1320583  -0.03311972 -0.11581571 -0.90865225]
[-0.74039155  0.74990904  0.38084272 -0.66262364]
[ 0.93585587 -0.31654274 -0.8943903   0.43336403]
[ 0.

[-0.67737174  0.6496875  -0.9873893   0.25491032]
[ 0.9412779  -0.45476785 -0.3962904  -0.4606123 ]
[-0.3657757 -0.8302025  0.0679465  0.3386594]
[-0.43584335 -0.49477306  0.921336   -0.7199408 ]
[-0.36230648 -0.89318174 -0.7934383   0.7700286 ]
[ 0.65243447  0.27564982 -0.92143035 -0.4418176 ]
[-0.03854063 -0.9280705  -0.50951874 -0.8642959 ]
[-0.6190271  -0.50418514  0.45489165  0.17220907]
[-0.16001104  0.14490388 -0.97272676  0.23197123]
[-0.29749528 -0.12596016 -0.19417107  0.13061349]
[-0.641194   0.6537932 -0.3040923 -0.6369337]
[-0.62556446  0.8424637  -0.979272    0.7285142 ]
[-0.32096764 -0.77489424 -0.75932246  0.19784033]
[-0.21309195  0.2959033   0.38995638 -0.18271624]
[-0.63269424 -0.9123517   0.7580263   0.13943355]
[-0.22492602  0.4299031   0.72779167  0.5795641 ]
[-0.32986623 -0.8123205  -0.36687917  0.06447456]
[ 0.3128252  -0.4323916   0.3612708   0.44845805]
[-0.6521863  -0.3886329  -0.7270777  -0.44389063]
[-0.23183087  0.3032787  -0.00526125 -0.7544545 ]
[-0.3027

[-0.2343336   0.09536292 -0.12689449 -0.34730098]
[-0.8077285  -0.29200187  0.59509695 -0.8159681 ]
[ 0.37757197  0.48006138 -0.96383613  0.76200867]
[-0.7067083  -0.4479213  -0.04770039  0.30987293]
[ 0.38966846 -0.6806536   0.20206784  0.03190993]
[ 0.7557496  -0.54633063 -0.087915   -0.76980346]
[ 0.78208435 -0.22806019 -0.83844125 -0.02397453]
[-0.36402577  0.7263795  -0.38999873 -0.7275115 ]
[0.14098054 0.10443576 0.61383086 0.89880294]
[-0.5568168  -0.44375274  0.8303246  -0.24774167]
[ 0.393273    0.19780633 -0.7612058   0.28888384]
[0.7954409  0.32402757 0.30914694 0.7132972 ]
[-0.43185905 -0.28595927 -0.80240476  0.15767905]
[-0.5686363   0.7795228   0.22961402 -0.3047675 ]
[-0.3942143   0.95571035 -0.5042629   0.92923915]
[-0.5158548 -0.5453266  0.6286221  0.4274731]
[-0.279505    0.9004471   0.2725694   0.20349205]
[-0.52383214 -0.23191953  0.83627295  0.6707666 ]
[ 0.1347568  -0.8058523  -0.3265768  -0.48121005]
[-0.12868482 -0.8979583  -0.7126009   0.7789741 ]
[-0.74507314

[-0.48489    0.9099921 -0.886978   0.4266652]
[ 0.7502071  -0.13987282 -0.96354246  0.38611367]
[ 0.793797   -0.34427714  0.9801147   0.05154689]
[ 0.07333118 -0.7750232  -0.9359881   0.5401496 ]
[-0.790933   -0.93097144 -0.30992094 -0.24006228]
[ 0.41101992 -0.31365803 -0.20237848  0.05411365]
[-0.8821041   0.3865363   0.46775198 -0.04529255]
[-0.48563454 -0.01217057 -0.07969525 -0.49823686]
[ 0.23035131  0.9468992   0.28652996 -0.84997606]
[-0.9286667  -0.46125042  0.5545135   0.57821906]
[ 0.1838522   0.20137735 -0.27009055 -0.92334586]
[-0.7403275   0.7386216   0.24695437  0.00455775]
[-0.33714604  0.53896713 -0.7113024   0.5757936 ]
[-0.4080202  -0.07780039  0.52485013  0.7987419 ]
[-0.96889824 -0.5487869  -0.26500526 -0.37680984]
[0.24553426 0.9566026  0.00797474 0.2802897 ]
[-0.4537274  -0.9499692   0.25037333 -0.23144588]
[-0.6050179  -0.44996396  0.72664666 -0.41160378]
[ 0.46930525 -0.4531003   0.62481296  0.1362614 ]
[-0.09813476 -0.71367943  0.9402277   0.9716494 ]
[-0.9681

[ 0.7433313  -0.32421294 -0.38433233  0.8798371 ]
[ 0.84744406 -0.91246104 -0.650932    0.49455923]
[ 0.73259944 -0.884447    0.9196011  -0.5666152 ]
[-0.173601    0.10190181  0.16628647  0.37716383]
[ 0.08180997  0.9451561  -0.52744585  0.7880212 ]
[-0.16079642 -0.83842194 -0.77875656 -0.5882949 ]
[-0.4233752   0.83300906 -0.65959835  0.5713943 ]
[ 0.4735028  -0.9746351   0.7994487  -0.43574873]
[-0.4414488   0.44357118 -0.1497883   0.88290215]
[ 0.4158129 -0.9745113 -0.5184055  0.5575237]
[ 0.2823539   0.5911533  -0.7793523  -0.82465225]
[ 0.34178331  0.3467581   0.7060255  -0.2010552 ]
[ 0.34588534  0.8540999  -0.8047495  -0.037773  ]
[ 0.7197679   0.64548856 -0.9687508   0.2947963 ]
[ 0.2816348  -0.49130476 -0.03559188  0.6496723 ]
[0.20336704 0.56623936 0.37734807 0.41606444]
[ 0.5877166  0.2177314 -0.7164005 -0.0869619]
[ 0.5472071  -0.3869317   0.9305691  -0.45224738]
[ 0.04552133 -0.08548594 -0.7307849  -0.18747477]
[-0.8819874   0.6993513  -0.23882322  0.5296794 ]
[ 0.02601928

[ 0.06165981 -0.8561733  -0.4552221   0.61187035]
[ 0.28949282  0.5510372  -0.29917043 -0.73005545]
[0.2896357  0.14557396 0.07502991 0.4334408 ]
[-0.46357575  0.6597823   0.7529609  -0.40063205]
[-0.29792333  0.02977966  0.40299487  0.9379107 ]
[-0.9715596   0.13125265  0.68512547  0.98704416]
[-0.97342134 -0.15138532 -0.05784688  0.62834126]
[ 0.57323974  0.62295127 -0.56345856 -0.8260744 ]
[ 0.5023951  -0.24293007  0.43451172 -0.29946062]
[-0.7052351   0.29151785  0.9101281  -0.94170606]
[-0.33496344  0.9326324  -0.6612308  -0.4717519 ]
[-0.04918443 -0.15694655  0.04497277  0.8810566 ]
[-0.83351314 -0.6021944   0.28107116  0.87148684]
[ 0.31066102 -0.04470846  0.8084238  -0.47950786]
[ 0.9598604  0.5807221 -0.6633048  0.9580263]
[ 0.85819966  0.37206116  0.14633673 -0.47730666]
[-0.62640965 -0.46932554  0.58541346  0.6953717 ]
[-0.46435407 -0.04572364 -0.7472725  -0.70060027]
[-0.68662775  0.47980896 -0.651881    0.32437366]
[ 0.76957864  0.8911119  -0.15906073  0.6484022 ]
[ 0.3525

[ 0.45647353  0.36765143 -0.5391704  -0.06545098]
[-0.12753195 -0.98180544 -0.5586693  -0.26275516]
[-0.7551142   0.04622458  0.80659235 -0.8130457 ]
[-0.7454523  -0.32415923 -0.29580078 -0.07129626]
[ 0.45424756  0.6163324   0.7421062  -0.12623149]
[ 0.56852245  0.9747266  -0.6621523  -0.212072  ]
[ 0.506171    0.91601914  0.89190143 -0.04346471]
[0.3283855  0.11078557 0.68811893 0.23596431]
[-0.04044418  0.26211956  0.9723943  -0.48396674]
[ 0.12086771  0.58765525 -0.3276825   0.09345868]
[ 0.5365718   0.04342137 -0.21783112  0.94280887]
[-0.9302849   0.43519336  0.7139704   0.5846832 ]
[-0.8401313   0.53018516 -0.1844441  -0.7343117 ]
[-0.7233156  -0.9358747  -0.32284892  0.8925694 ]
[-0.70726305 -0.08274335  0.25814974 -0.8197858 ]
[-0.447164   -0.19461562 -0.7086575   0.95350224]
[ 0.3046813  -0.45535815  0.2959249  -0.04049119]
[-0.60873735 -0.10936515  0.36403736 -0.73759043]
[0.15194921 0.28237295 0.9525169  0.37190163]
[0.50793815 0.5226451  0.86759555 0.73135185]
[0.7766727  

[-0.51383144 -0.80904746 -0.79245603  0.52353525]
[-0.89944386 -0.7832755  -0.24145482 -0.5627146 ]
[ 0.6946707  -0.76861465  0.06745904 -0.98081374]
[ 0.3713586   0.28610384 -0.9581938  -0.88398933]
[ 0.87245667 -0.3797329  -0.13496217 -0.31597894]
[-0.47786087 -0.8551032  -0.02849361 -0.52301246]
[ 0.79032296 -0.37294665  0.81300795 -0.11601159]
[-0.6545559  -0.0812149   0.79062855  0.9488463 ]
[ 0.43344533  0.6144175   0.8380489  -0.3348988 ]
[ 0.28102043  0.5791369   0.43674088 -0.681232  ]
[ 0.48828074 -0.6953174   0.40619963  0.6014148 ]
[-0.11299845 -0.52483237 -0.28899914  0.42301944]
[-0.9335732   0.430059   -0.13748217  0.7249313 ]
[ 0.51154655  0.72310567 -0.7023611   0.6946229 ]
[ 0.17238083 -0.18289603  0.5939609   0.9728539 ]
[ 0.11232287 -0.79047894 -0.47986743  0.4805868 ]
[-0.9343468   0.49647483  0.5798439  -0.504214  ]
[ 0.94862974 -0.4919093   0.80801994 -0.3624903 ]
[-0.93071955 -0.569754   -0.04489779 -0.5508649 ]
[ 0.730041   -0.87429976 -0.2526369  -0.14041737]


[ 0.0225496  -0.01952686  0.01349146  0.62763005]
[ 0.38187373  0.6935787   0.03796191 -0.2922471 ]
[ 0.18580604  0.6695173  -0.10386761  0.74646413]
[ 0.8236291  -0.6049206   0.25190428  0.55485606]
[-0.9516272   0.6566818   0.23717271  0.13319087]
[-0.7657151   0.14619641  0.769035    0.17132287]
[ 0.9850503  -0.87394786  0.29263085  0.8403573 ]
[ 0.61989796  0.70830286  0.16861811 -0.8127936 ]
[ 0.9360055   0.1588856  -0.72296226  0.74754703]
[ 0.76952934 -0.8724408  -0.7788189  -0.09324884]
[0.51837325 0.52505004 0.9647489  0.588533  ]
[ 0.5556014   0.4011054  -0.42523617  0.6862929 ]
[ 0.35781944 -0.73715794  0.00310384 -0.718676  ]
[ 0.6955056   0.69349813 -0.9212992  -0.5259062 ]
[ 0.52465665 -0.110451   -0.9149075  -0.96554226]
[0.68347716 0.27252916 0.4371755  0.47251335]
[-0.8914733   0.68200487  0.17056067  0.7498728 ]
[ 0.6138059  -0.6040983   0.2561295   0.61732984]
[-0.21269746  0.774484    0.7323346  -0.20387183]
[ 0.33812046 -0.75411767 -0.6825996  -0.8114443 ]
[ 0.5446

[-0.8772063   0.09366893  0.3153865   0.564662  ]
[ 0.5719106   0.85473704 -0.27176374 -0.63155663]
[-0.03179971  0.19972946  0.3042197  -0.5574417 ]
[ 0.07173573 -0.48260048 -0.0408862   0.3154738 ]
[-0.54460967  0.7087009  -0.6701342  -0.35251644]
[ 0.97319454  0.58687925  0.16092063 -0.37650812]
[ 0.20986845  0.823117    0.65296954 -0.73637676]
[ 0.18257894  0.7069446  -0.6384642  -0.5778768 ]
[-0.90139294 -0.7570443  -0.39317507  0.7075068 ]
[-0.19023557 -0.6946331   0.22055021  0.8505987 ]
[-0.50546557 -0.09331365  0.39721826  0.4812574 ]
[ 0.6469088   0.09257373  0.30727664 -0.93680394]
[0.04176023 0.06893045 0.6323783  0.937742  ]
[-0.6459209   0.3746513  -0.8932973   0.15672167]
[ 0.60418046  0.6584289   0.2808952  -0.21703646]
[-0.73267806  0.38229832  0.19356868 -0.45684958]
[-0.6634034   0.29352412  0.5179069   0.7773324 ]
[ 0.77181685  0.26047477  0.65038526 -0.27326778]
[ 0.47347173  0.5974541   0.5625185  -0.98796946]
[-0.22934145  0.43226936 -0.02752313  0.0105724 ]
[ 0.

  0%|          | 5/2000 [01:10<10:47:31, 19.47s/it]

[0.50638676 0.8733939  0.7870385  0.19331557]
[-0.8249917  -0.92688185  0.24302472  0.6187085 ]
[ 0.63024855 -0.8151202  -0.5895654   0.5399548 ]
[ 0.82984734 -0.59630126  0.56498206  0.69125193]
[ 0.70474756  0.7427557  -0.5491818  -0.43367606]
[ 0.23336564 -0.09606558 -0.65803957 -0.8382088 ]
[ 0.38679     0.79441655  0.4922357  -0.88317454]
[-0.43693197  0.11561732 -0.6969461  -0.27563128]
[ 0.17611729 -0.812075    0.3477934   0.11197776]
[ 0.63847965 -0.24249558 -0.8794862   0.8554606 ]
[0.8945924  0.14120153 0.09935721 0.20626727]
[ 0.7667464  -0.5119241   0.4371874   0.91924036]
[-0.01705677  0.93130124  0.79769415  0.4128197 ]
[ 0.5396791  -0.7923835  -0.08465251 -0.7434585 ]
[-0.382822   -0.8621181  -0.16260508  0.34061277]
[-0.72470737  0.46734053  0.39136267 -0.66785914]
[-0.4348025  -0.23309508 -0.04565512 -0.4069855 ]
[-0.955792   -0.2586846  -0.3053026   0.87867093]
[ 0.8628369  -0.58134395  0.06017463 -0.524134  ]
[ 0.9816282  -0.44237965 -0.358946    0.0826079 ]
[ 0.6361

  0%|          | 6/2000 [01:12<7:27:59, 13.48s/it] 

[-0.7818196  -0.14632294  0.18189378  0.9920549 ]
[-0.79254544 -0.40461594 -0.61875165 -0.94772714]
[-0.222757    0.01649418 -0.0652239  -0.32208198]
[ 0.8541383  0.6355618  0.6932583 -0.5724247]
[-0.9788372   0.48237455 -0.6926675   0.50339717]
[ 0.87922776 -0.4759134   0.12597403 -0.6441076 ]
[ 0.9128368  -0.9954501   0.27413642  0.79335344]
[ 0.21826066 -0.91212565 -0.89894915 -0.9351973 ]
[ 0.37192956 -0.00515728  0.2950194  -0.79960597]
[-0.95576763 -0.11205381 -0.49164397  0.7668636 ]
[ 0.18779802 -0.13645701  0.5056314  -0.9103453 ]
[-0.69511306  0.10302053 -0.45219645  0.18106037]
[ 0.8350971  -0.03871367  0.8119788  -0.80228215]
[ 0.7262244  -0.02244992 -0.9276585  -0.655614  ]
[-0.33262077  0.1204394   0.74430037  0.06063903]
[-0.8180846  -0.0642027   0.7487278   0.76840353]
[-0.9121907  -0.3618001  -0.8149606   0.15104784]
[-0.96330315  0.78842276 -0.0478575  -0.20956573]
[ 0.48081517 -0.38208887  0.74093103 -0.6346211 ]
[ 0.7884855   0.9132097  -0.97949874  0.14450593]
[-0.

[0.7094972  0.4462022  0.84178174 0.872525  ]
[-0.54700446 -0.7343385   0.05346313  0.23721871]
[-0.70934206  0.358046    0.8951584  -0.80775446]
[-0.9522943  -0.73812985  0.15256551 -0.67706347]
[-0.6256604  -0.1040842   0.586378    0.09928913]
[-0.79249537 -0.20989646 -0.5656842   0.05452031]
[ 0.9688382   0.9395409  -0.88696605 -0.7899068 ]
[-0.84962285 -0.36471343 -0.27495742  0.76706326]
[-0.3314646   0.0989436   0.1252008  -0.20731992]
[ 0.4139945   0.16786824 -0.91496265 -0.07932658]
[ 0.5808898   0.03056923  0.39048904 -0.76899564]
[0.64715576 0.4534396  0.6105575  0.35494244]
[-0.51357037 -0.7703069  -0.6122068   0.7662793 ]
[ 0.2848317   0.56854326 -0.81870055 -0.31305614]
[ 0.7902832  -0.26236558 -0.63726956 -0.90722454]
[-0.07550037  0.33187094  0.01445119 -0.9037402 ]
[-0.0721207   0.32623112 -0.01979722 -0.57855415]
[-0.81393087  0.8001662   0.7774593   0.22607127]
[-0.4798496  -0.8323088   0.8019352   0.21169022]
[ 0.43045574  0.79298425 -0.8520402  -0.8969202 ]
[-0.5532

[ 0.5811243   0.73970616 -0.7371266   0.2506557 ]
[-0.4843585   0.32579625  0.29829004 -0.6543954 ]
[-0.9751813   0.49567243 -0.85362214 -0.35226873]
[ 0.9051619   0.17292888  0.65638936 -0.14473046]
[-0.8567842   0.6352437  -0.31664822  0.32134765]
[-0.03555446 -0.32059878  0.13073008  0.7060335 ]
[ 0.15840699 -0.41568154 -0.8853452  -0.36628792]
[ 0.40806594 -0.9548033  -0.5057243   0.49446592]
[ 0.26821485  0.44493276  0.7103919  -0.02466725]
[-0.5182199   0.24522844  0.17597789 -0.7208358 ]
[0.85137916 0.39024153 0.40346065 0.59045637]
[ 0.46641687  0.47167677 -0.3927188   0.5501431 ]
[-0.67560154  0.52039516  0.9538908  -0.6188592 ]
[-0.42905104  0.8121833   0.00287335 -0.22716416]
[-0.30708098 -0.3605151   0.63015723  0.08579442]
[-0.62185085  0.11928415 -0.7140819   0.41014695]
[-0.9481266   0.06020305 -0.86653036  0.64471054]
[-0.8292309   0.8034549   0.45809516 -0.35791633]
[-0.44296753  0.47666746 -0.08733401  0.43393007]
[0.4135171  0.7652003  0.2260121  0.14792496]
[ 0.7399

[-0.6343554  -0.15415168  0.757949   -0.28692687]
[-0.8394457  -0.2520643   0.05173809 -0.18577705]
[-0.595508   -0.7852863   0.5412872  -0.68166506]
[-0.7193502  -0.3323147  -0.44027123  0.65360224]
[-0.49529734  0.60018003  0.04540599 -0.88680184]
[-0.09774755  0.35118392 -0.55757976  0.13525139]
[ 0.623255   -0.92212033 -0.84122837 -0.21826014]
[ 0.8920743  -0.08557297  0.25290382  0.11614071]
[ 0.38735202 -0.6098542   0.94461936  0.35075396]
[ 0.23362596  0.5122468   0.09923417 -0.34302387]
[ 0.521791   -0.6570717  -0.30444583  0.78335035]
[-0.26753697 -0.20537095  0.77384037 -0.7567148 ]
[ 0.3388966  -0.8875591  -0.61863315  0.08242695]
[-0.94611377 -0.17670874 -0.506858   -0.76421374]
[-0.34793884 -0.00507399  0.58345526 -0.6431532 ]
[-0.38498995  0.00547129 -0.0114935   0.6956597 ]
[ 0.047858   -0.48423773 -0.42249507  0.1124313 ]
[-0.13763961 -0.974523    0.84137875  0.08382007]
[-0.23654428 -0.02608693 -0.33934966 -0.361459  ]
[-0.01606843  0.825974    0.13094094  0.94189054]


[ 0.33555976  0.10670654 -0.63844943 -0.7042055 ]
[0.71968365 0.37073484 0.9758041  0.9025978 ]
[0.04232433 0.4229557  0.16462542 0.6023922 ]
[-0.5187434  -0.37063107 -0.89499646  0.34356245]
[0.651518   0.33321065 0.34782344 0.96768045]
[ 0.9182158  -0.8094278  -0.5790545   0.34353772]
[ 0.16799246 -0.8869989  -0.40751585 -0.6289021 ]
[ 0.43746525 -0.9278886  -0.285691    0.04686312]
[ 0.3764026  -0.15127955  0.835115   -0.726475  ]
[ 0.10973959  0.6355063  -0.8962214  -0.56725305]
[-0.4547998   0.02251249  0.21646051 -0.11707844]
[-0.46113366  0.01988607  0.3947042   0.17349868]
[ 0.38356113 -0.3552148  -0.8391757  -0.9677594 ]
[ 0.28975812 -0.41148636  0.9029616  -0.861554  ]
[-0.19416872 -0.2281203  -0.06862366 -0.7962758 ]
[ 0.97749496 -0.72465444 -0.15724997  0.3135494 ]
[-0.7516874   0.8483626  -0.6797346  -0.12728725]
[0.3073667  0.08064894 0.7473594  0.07322491]
[ 0.28566742 -0.3646675  -0.58185947  0.93680197]
[ 0.9621856  -0.18489845 -0.412865    0.12160161]
[-0.6273418  -0.

[ 0.48755515  0.4237509   0.39513466 -0.66194016]
[ 0.3466428   0.7420465  -0.5343783   0.38960683]
[ 0.7771838  -0.01573148  0.3641647   0.8915528 ]
[ 0.19777034 -0.9181596  -0.63873374  0.25639126]
[-0.06537593  0.7650053   0.6967262  -0.6278864 ]
[ 0.85552615  0.76762843 -0.175105   -0.16111709]
[-0.22657368  0.567809   -0.41841245  0.468418  ]
[ 0.04808698 -0.37261286 -0.01781742 -0.38897994]
[-0.69447684 -0.93081164 -0.10250522  0.7166828 ]
[0.2773059 0.5315176 0.7032672 0.932372 ]
[ 0.22271526 -0.1108245  -0.65314996  0.5275727 ]
[ 0.28848052  0.58853877 -0.01841373  0.05821244]
[ 0.22226259 -0.830593    0.54086494 -0.50837725]
[-0.1280587   0.6133444  -0.6494398  -0.01588577]
[-0.5109466  -0.9131955  -0.21446535 -0.86833894]
[-0.69997525 -0.8736533  -0.9444645   0.32318243]
[-0.92045355  0.14742978  0.21994747 -0.14773346]
[-0.4419935  -0.59214306  0.1624577  -0.47022176]
[-0.1968886   0.27559686 -0.27347544 -0.4414125 ]
[ 0.7484368  -0.9266519  -0.9312452   0.35072318]
[0.72877

[-0.46495074 -0.8175608   0.9040543  -0.35468546]
[0.30147767 0.63591135 0.98569804 0.14719181]
[ 0.32410082 -0.32012302  0.72045374  0.7259425 ]
[ 0.731725  -0.9428394 -0.9269346  0.9211248]
[-0.9357525   0.48862112 -0.12638722  0.72995096]
[0.773958   0.5766138  0.0104243  0.00917274]
[-0.0551011  -0.23279706 -0.69991744 -0.0428935 ]
[ 0.16881329  0.08787688 -0.7950503   0.80553293]
[-0.52631956 -0.6027905  -0.01674866 -0.1688922 ]
[-0.36629894 -0.72621566 -0.79207075  0.14088586]
[-0.3486519   0.13970844 -0.7369385   0.93445563]
[-0.07857103 -0.6076994  -0.575593   -0.7492786 ]
[0.21292917 0.12974656 0.39529085 0.8687042 ]
[ 0.7123796  -0.77626115 -0.9345602  -0.52639943]
[-0.32491153  0.3239646   0.70135856 -0.9106488 ]
[-0.28862908 -0.15455641  0.78779304  0.12448039]
[ 0.7509086  -0.36186776  0.9608197  -0.90154976]
[ 0.9338654  -0.54548514 -0.35392907  0.8477309 ]
[ 0.48748684  0.68585896 -0.17569502  0.2222574 ]
[ 0.31406945 -0.7681595   0.04880345  0.76127076]
[-0.68770397 -0.

[-0.7407473   0.87484026  0.01266318 -0.6403701 ]
[-0.18156448  0.9541054  -0.71286976 -0.6287445 ]
[ 0.92543447  0.20224987 -0.86881876 -0.6578053 ]
[ 0.14185426  0.62556726 -0.7779267  -0.12365444]
[ 0.27199996  0.7268536   0.17333268 -0.93673253]
[ 0.22187836 -0.970192    0.05207731 -0.5761845 ]
[0.82554865 0.51785153 0.44777578 0.70561767]
[-0.48275506  0.9639734  -0.09524891  0.33183575]
[ 0.5100962   0.7219048  -0.80891514  0.7379942 ]
[-0.33081242 -0.24325626  0.8763832  -0.37947038]
[-0.3640863  -0.22813785 -0.31671336  0.02659929]
[ 0.6512594  -0.89904004 -0.10746711 -0.845703  ]
[-0.16934359 -0.9056482   0.52216053  0.58388644]
[-0.2428135   0.9643571   0.5076422  -0.20463702]
[ 0.97261655  0.9759396  -0.1444165   0.32986513]
[ 0.43537945  0.10780881 -0.00865751 -0.7413524 ]
[-0.13767204 -0.234542    0.2320736  -0.60018486]
[-0.8126914  0.7794244 -0.9198564  0.7672444]
[-0.5556458  -0.23050286  0.01061917 -0.29465005]
[-0.41158962 -0.06304653 -0.8944669  -0.53710014]
[ 0.2703

[-0.45594412  0.91559994 -0.4357918  -0.3463082 ]
[-0.35700232  0.71192706 -0.20518582 -0.8685396 ]
[-0.6961589 -0.306744   0.0254004  0.8004842]
[-0.3748284  0.5120893 -0.7491521 -0.6611481]
[-0.20413017  0.9277711  -0.79073983  0.48874494]
[-0.7757809   0.40621564  0.01669149  0.9810529 ]
[-0.96269757  0.74797416  0.17232639 -0.09300746]
[ 0.5641773   0.74851966 -0.8259244   0.25136882]
[-0.32468233  0.29466575 -0.96885556  0.5340349 ]
[-0.6197426  -0.7873671  -0.5790229  -0.75091505]
[-0.9913657  -0.5302017   0.12998842  0.7996532 ]
[ 0.12588759 -0.01750753 -0.28827956 -0.03428469]
[-0.6181841   0.93253607  0.02102241  0.814028  ]
[-0.9871569  -0.24233519  0.3833523  -0.11319912]
[-0.02972927 -0.4429629  -0.62868845  0.23674072]
[-0.45479873 -0.8731102   0.8995198   0.81413084]
[0.03000389 0.09390774 0.4763211  0.4699982 ]
[-0.9806165  -0.4605403  -0.05888671  0.5062101 ]
[-0.6461411  -0.26125532  0.7591795   0.9800434 ]
[-0.02160722 -0.16470626 -0.77965516 -0.6621853 ]
[ 0.601265  

In [None]:
from scipy.ndimage.filters import uniform_filter1d
arr = uniform_filter1d(rewards, size=1)

In [None]:
from matplotlib import pyplot as plt
x = []
y = []
for i,e in enumerate(arr):
    y.append(i)
    x.append(e)

In [None]:
plt.plot(y, x)

In [None]:
-np.prod(env.action_space.high.shape)