In [None]:
import os
import yaml
import pickle
import argparse
import datetime
import scipy.signal

import numpy as np
import torch as T
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm 
from collections import namedtuple

from models.a2c_lstm_my import A2C_LSTM
from tasks.two_step_my import TwoStepTask
import matplotlib.pyplot as plt
import pandas as pd
#from scipy.special import softmax
from scipy.optimize import minimize
from scipy.stats import sem

In [None]:
def softmax(tau, x):
    """Return the softmax of ``tau * x`` along the first axis.

    Args:
        tau: inverse temperature; larger values make the distribution greedier.
        x: array-like of values (converted with ``np.asarray`` so a plain list
            is scaled rather than repeated by ``tau * x``).

    Returns:
        numpy.ndarray of the same shape as ``x`` whose entries sum to 1 along
        axis 0.
    """
    z = tau * np.asarray(x, dtype=float)
    if z.size == 0:
        return z
    # Shift by the max before exponentiating: mathematically a no-op, but it
    # prevents exp() overflowing to inf (which turned the result into nan).
    z = z - np.max(z, axis=0)
    e = np.exp(z)
    return e / np.sum(e, axis=0)
    


In [None]:
# Agent type read by Trainer.run_episode: 0 -> model-free (MF), 1 -> model-based (MB).
strategy = 0

In [None]:
# One rollout step; `Rollout._fields` also serves as DataFrame column names for result buffers.
Rollout = namedtuple(
    "Rollout",
    ["state", "action", "reward", "timestep", "done", "policy", "value"],
)




class Trainer:
    """Run the two-step task with a tabular agent and score it.

    ``run_episode`` reads two module-level names defined elsewhere in this
    notebook: ``strategy`` (0 = model-free update, 1 = model-based update)
    and ``softmax``.
    """

    def __init__(self, config, plot_param):
        """Seed the RNGs and build the task environment and TensorBoard writer.

        Args:
            config: parsed YAML config. Reads "run-title", "seed", "task"
                (passed whole to ``TwoStepTask``; "switch-prob" and
                "trials-per-epi" are also read from it), "a2c" ->
                "max-grad-norm" and "save-path".
            plot_param: dict whose ``str`` is appended to the log and save
                directory names.
        """
        print(config['run-title'])

        self.plot_param = plot_param

        self.device = 'cpu'
        self.seed = config["seed"]

        T.manual_seed(config["seed"])
        np.random.seed(config["seed"])
        T.random.manual_seed(config["seed"])

        self.env = TwoStepTask(config["task"])

        self.max_grad_norm = config["a2c"]["max-grad-norm"]
        self.switch_p = config["task"]["switch-prob"]
        # Kept on the instance so test() does not depend on a global `config`.
        self.trials_per_epi = config["task"].get("trials-per-epi")
        self.start_episode = 0

        run_dir = config["run-title"] + str(self.plot_param)
        self.writer = SummaryWriter(log_dir=os.path.join("logs", config["run-title"], run_dir))
        self.save_path = os.path.join(config["save-path"], config["run-title"], run_dir)

    def run_episode(self, episode, det=False, params=None):
        """Play one episode and return ``(total_reward, total_entropy)``.

        Args:
            episode: unused; kept for interface compatibility.
            det: unused; kept for interface compatibility.
            params: ``(epsilon, lrate)``, default ``(0.1, 0.1)``. ``epsilon`` is
                the initial exploration rate (multiplied by 0.99 every trial);
                ``lrate`` is the value learning rate.

        Returns:
            Both totals are accumulated only over trials after the first 300
            (the epsilon-greedy exploration phase). ``total_entropy`` is the
            summed Shannon entropy of the softmax policy over ``Q_s0`` that was
            used to pick each scored trial's action.
        """
        if params is None:
            params = (0.1, 0.1)
        tau = 1
        epsilon = params[0]
        lrate = params[1]
        eta = 0.1  # transition-model learning rate (model-based branch only)
        gamma = 1
        burn_in = 300  # trials of epsilon-greedy exploration before scoring starts

        done = False
        total_reward = 0
        total_entropy = 0

        self.env.reset()

        Q_s0 = np.zeros(2)
        Q_stage2 = np.zeros(2)
        p_hat = np.zeros([2, 2])  # row a: estimated P(stage-2 state | action a)

        counter = 0
        while not done:
            counter += 1
            epsilon = epsilon * 0.99

            action_dist = softmax(tau, Q_s0)
            greedy = int(np.argmax(action_dist))  # ties resolve to the first index

            # Greedy after the burn-in; epsilon-greedy before it. The random
            # draw only happens during the burn-in, as before.
            if counter > burn_in:
                action = greedy
            elif np.random.rand() < epsilon:
                action = np.random.choice([0, 1])
            else:
                action = greedy

            new_state, reward, done, _ = self.env.step(int(action))

            # Index of the active entry of the stage-2 state (1 or 2), shifted
            # to 0/1. NOTE(review): assumes new_state is a 1-D one-hot vector.
            planet_arrived = int(np.nonzero(new_state)[0][0]) - 1

            if strategy == 0:  # model-free TD update
                Q_stage2[planet_arrived] += lrate * (reward - Q_stage2[planet_arrived])
                Q_s0[action] += lrate * (gamma * Q_stage2[planet_arrived] - Q_s0[action])
            else:  # model-based: learn transitions, then plan Q_s0 from them
                p_hat[action, planet_arrived] += eta * (1 - p_hat[action, planet_arrived])
                p_hat[action, :] = p_hat[action, :] / np.sum(p_hat[action, :])

                Q_stage2[planet_arrived] += lrate * (reward - Q_stage2[planet_arrived])

                Q_s0[action] = gamma * np.sum(Q_stage2 * p_hat[action, :])
                Q_s0[1 - action] = gamma * np.sum(Q_stage2 * p_hat[1 - action, :])

            if counter > burn_in:
                total_reward += reward
                # Entropy of the same distribution on both sides of p*log(p)
                # (previously the post-update policy was paired with the log
                # of the pre-update one, giving a cross-entropy).
                nz = action_dist[action_dist > 0]
                total_entropy -= float(np.sum(nz * np.log(nz)))

        return total_reward, total_entropy

    def train(self, params, max_episodes, save_interval):
        """Run ``max_episodes`` episodes with ``params``; return minus the total reward.

        The negated reward makes this usable as a ``scipy.optimize.minimize``
        objective. ``save_interval`` is accepted but unused.
        """
        total_rewards = np.zeros(max_episodes)

        for episode in tqdm(range(self.start_episode, max_episodes)):
            # run_episode returns (reward, entropy); only the reward is optimised.
            reward, _ = self.run_episode(episode=episode, det=False, params=params)
            total_rewards[episode] = reward

            avg_reward_10 = total_rewards[max(0, episode - 9):(episode + 1)].mean()
            avg_reward_100 = total_rewards[max(0, episode - 99):(episode + 1)].mean()
            self.writer.add_scalar("perf/reward_t", reward, episode)
            self.writer.add_scalar("perf/avg_reward_10", avg_reward_10, episode)
            self.writer.add_scalar("perf/avg_reward_100", avg_reward_100, episode)

        return -np.sum(total_rewards)

    def test(self, num_episodes, optimal_params):
        """Evaluate ``optimal_params`` and return ``(reward_fraction, avg_entropy)``.

        Both totals are divided by ``0.5 * num_episodes * trials-per-epi``.
        NOTE(review): the 0.5 factor presumably equals the share of trials
        scored after the 300-trial burn-in (i.e. 600 trials per episode) —
        confirm against the config.

        Raises:
            KeyError: if the config had no task "trials-per-epi" entry.
        """
        progress = tqdm(range(num_episodes))
        self.env.reset_transition_count()

        total_rewards = np.zeros(num_episodes)
        total_entropies = np.zeros(num_episodes)

        for episode in progress:
            reward, entropy = self.run_episode(episode, det=False, params=optimal_params)
            total_rewards[episode] = reward
            total_entropies[episode] = entropy

        if self.trials_per_epi is None:
            raise KeyError("config['task']['trials-per-epi'] is required for test()")
        n_scored = 0.5 * (num_episodes * self.trials_per_epi)

        reward_fraction = sum(total_rewards) / n_scored
        print("fraction of rewarded trials:", reward_fraction)

        avg_entropy = sum(total_entropies) / n_scored
        return reward_fraction, avg_entropy

if __name__ == "__main__":

    os.makedirs('./temp_results', exist_ok=True)

    yaml_path = "./configs/two_step_my.yaml"
    with open(yaml_path, 'r', encoding="utf-8") as fin:
        config = yaml.load(fin, Loader=yaml.FullLoader)

    # (epsilon, lrate): used directly when training is off, and as the
    # optimiser's starting point when it is on.
    default_params = [0.1, 0.8]

    n_seeds = 1
    base_seed = config["seed"]
    base_run_title = config["run-title"]
    buffer_results = {}
    test_common_probs = [config['task']['common-prob']]
    for common_prob in test_common_probs:

        buffer_results[common_prob] = {}

        config['task']['common-prob'] = common_prob
        for seed_idx in range(1, n_seeds + 1):

            config["seed"] = base_seed * seed_idx
            config["run-title"] = base_run_title + '_seed=' + str(config['seed'])

            exp_path = os.path.join(config["save-path"], config["run-title"])
            os.makedirs(exp_path, exist_ok=True)

            out_path = os.path.join(exp_path, os.path.basename(yaml_path))
            with open(out_path, 'w') as fout:
                yaml.dump(config, fout)

            plot_param = {'seed': config['seed'], 'common_prob': common_prob, 'safebet': config["task"]['safebet-reward']}

            print(f"> Running {config['run-title']}")
            trainer = Trainer(config, plot_param)
            if config["train"]:
                # NOTE(review): trainer.train is not re-seeded between calls,
                # so the objective is noisy and L-BFGS-B's finite-difference
                # gradients may be unreliable.
                result = minimize(fun=trainer.train, x0=default_params,
                                  args=(config["task"]["train-episodes"], config["save-interval"]),
                                  bounds=((0, 1), (0, 1)))
                # Previously overwritten by the defaults right after this line,
                # which discarded the optimisation result.
                optimal_params = result.x
            else:
                optimal_params = default_params

            if config["test"]:
                reward_fraction, avg_entropy = trainer.test(config["task"]["test-episodes"], optimal_params)
   
    

In [None]:
def run(ambiguity, seed):
    """Train (optionally) and test one agent at the given task ambiguity and seed.

    Loads ``./configs/two_step_my.yaml``, overrides ``task.ambiguity`` and
    ``seed``, writes the resolved config into the run's save directory and
    evaluates a ``Trainer``. The agent type comes from the module-level
    ``strategy`` flag read by ``Trainer``.

    Returns:
        ``(reward_fraction, entropy)`` as returned by ``Trainer.test``.

    Raises:
        RuntimeError: if the config has ``test`` disabled, since there are then
            no results to return (this used to surface as UnboundLocalError).
    """
    os.makedirs('./temp_results', exist_ok=True)

    yaml_path = "./configs/two_step_my.yaml"
    with open(yaml_path, 'r', encoding="utf-8") as fin:
        config = yaml.load(fin, Loader=yaml.FullLoader)

    # Plain float so yaml.dump below writes an ordinary scalar rather than a
    # numpy-specific object tag (np.arange yields numpy floats).
    config['task']['ambiguity'] = float(ambiguity)
    config['seed'] = seed

    # (epsilon, lrate): used when training is off, and as the optimiser's x0.
    default_params = [0.1, 0.8]

    n_seeds = 1
    base_seed = config["seed"]
    base_run_title = config["run-title"]
    test_common_probs = [config['task']['common-prob']]
    tested = False
    for common_prob in test_common_probs:

        config['task']['common-prob'] = common_prob
        for seed_idx in range(1, n_seeds + 1):

            config["seed"] = base_seed * seed_idx
            config["run-title"] = base_run_title + '_seed=' + str(config['seed'])

            exp_path = os.path.join(config["save-path"], config["run-title"])
            os.makedirs(exp_path, exist_ok=True)

            out_path = os.path.join(exp_path, os.path.basename(yaml_path))
            with open(out_path, 'w') as fout:
                yaml.dump(config, fout)

            plot_param = {'seed': config['seed'], 'common_prob': common_prob, 'safebet': config["task"]['safebet-reward']}

            print(f"> Running {config['run-title']}")
            trainer = Trainer(config, plot_param)
            if config["train"]:
                result = minimize(fun=trainer.train, x0=default_params,
                                  args=(config["task"]["train-episodes"], config["save-interval"]),
                                  bounds=((0, 1), (0, 1)))
                # Previously overwritten by the defaults, discarding the result.
                optimal_params = result.x
            else:
                optimal_params = default_params

            if config["test"]:
                reward_fraction, entropy = trainer.test(config["task"]["test-episodes"], optimal_params)
                tested = True

    if not tested:
        raise RuntimeError("run() returns test results, but config['test'] is false")
    return reward_fraction, entropy
   

# Performance comparison of the Q1 (model-free, MF) and model-based (MB) agents across ambiguity levels

In [None]:
import inspect

# MF vs MB fraction of rewarded trials across ambiguity levels, n_rounds seeds each.
# `run` is redefined by a later cell (run(p_common, seed)); make sure the
# ambiguity variant is the live one before sweeping.
if list(inspect.signature(run).parameters)[0] != "ambiguity":
    raise RuntimeError("run() is not the run(ambiguity, seed) variant; re-run the cell that defines it")

strategy = 0
n_rounds = 4
ambiguities = np.arange(0, 1.1, 0.1)
MF_results = np.zeros([n_rounds, len(ambiguities)])
for round_idx in range(n_rounds):
    seed = 100 * round_idx
    for ambiguity_idx, ambiguity in enumerate(ambiguities):
        reward_fraction, _ = run(ambiguity, seed)
        MF_results[round_idx, ambiguity_idx] = reward_fraction
MF_avg = np.mean(MF_results, axis=0)
sems_MF = sem(MF_results)

strategy = 1

MB_results = np.zeros([n_rounds, len(ambiguities)])
for round_idx in range(n_rounds):
    seed = 100 * round_idx
    for ambiguity_idx, ambiguity in enumerate(ambiguities):
        reward_fraction, _ = run(ambiguity, seed)
        MB_results[round_idx, ambiguity_idx] = reward_fraction
MB_avg = np.mean(MB_results, axis=0)
sems_MB = sem(MB_results)

# Plot against the ambiguity values (not the array index) so the lines line up
# with the error bars, and give MB its own colour so the legend is readable.
plt.plot(ambiguities, MF_avg, color='#1f77b4', alpha=0.8)
plt.errorbar(y=MF_avg, x=ambiguities, yerr=sems_MF, color='#1f77b4', alpha=0.1, capsize=3)

plt.plot(ambiguities, MB_avg, color='#ff7f0e', alpha=0.8)
plt.errorbar(y=MB_avg, x=ambiguities, yerr=sems_MB, color='#ff7f0e', alpha=0.1, capsize=3)

plt.legend(["MF", "MB"])

In [None]:
# Final MF (red) vs MB (navy) reward-fraction figure across ambiguity levels.
ambiguity_axis = np.arange(0, 1.1, 0.1)
for curve, err, colour in ((MF_avg, sems_MF, '#AA041F'), (MB_avg, sems_MB, '#0B224A')):
    plt.plot(ambiguity_axis, curve, color=colour, alpha=1)
    plt.errorbar(y=curve, x=ambiguity_axis, yerr=err, color=colour, alpha=0.8, capsize=3)

plt.legend(["MF", "MB"])
plt.xlabel("ambiguity")
plt.ylabel("fraction of rewarded trials")
plt.title('agent performance across different ambiguities')

# Performance comparison of the Q1 (model-free, MF) and model-based (MB) agents across common-transition probabilities

In [None]:
def run(p_common, seed):
    """Train (optionally) and test one agent at the given common-transition probability.

    Loads ``./configs/two_step_my.yaml``, overrides ``task.common-prob`` and
    ``seed``, writes the resolved config into the run's save directory and
    evaluates a ``Trainer``. The agent type comes from the module-level
    ``strategy`` flag read by ``Trainer``.

    Returns:
        ``(reward_fraction, entropy)`` as returned by ``Trainer.test``.

    Raises:
        RuntimeError: if the config has ``test`` disabled, since there are then
            no results to return (this used to surface as UnboundLocalError).
    """
    os.makedirs('./temp_results', exist_ok=True)

    yaml_path = "./configs/two_step_my.yaml"
    with open(yaml_path, 'r', encoding="utf-8") as fin:
        config = yaml.load(fin, Loader=yaml.FullLoader)

    config['task']['common-prob'] = p_common
    config['seed'] = seed

    # (epsilon, lrate): used when training is off, and as the optimiser's x0.
    default_params = [0.1, 0.8]

    n_seeds = 1
    base_seed = config["seed"]
    base_run_title = config["run-title"]
    test_common_probs = [config['task']['common-prob']]
    tested = False
    for common_prob in test_common_probs:

        config['task']['common-prob'] = common_prob
        for seed_idx in range(1, n_seeds + 1):

            config["seed"] = base_seed * seed_idx
            config["run-title"] = base_run_title + '_seed=' + str(config['seed'])

            exp_path = os.path.join(config["save-path"], config["run-title"])
            os.makedirs(exp_path, exist_ok=True)

            out_path = os.path.join(exp_path, os.path.basename(yaml_path))
            with open(out_path, 'w') as fout:
                yaml.dump(config, fout)

            plot_param = {'seed': config['seed'], 'common_prob': common_prob, 'safebet': config["task"]['safebet-reward']}

            print(f"> Running {config['run-title']}")
            trainer = Trainer(config, plot_param)
            if config["train"]:
                result = minimize(fun=trainer.train, x0=default_params,
                                  args=(config["task"]["train-episodes"], config["save-interval"]),
                                  bounds=((0, 1), (0, 1)))
                # Previously overwritten by the defaults, discarding the result.
                optimal_params = result.x
            else:
                optimal_params = default_params

            if config["test"]:
                reward_fraction, entropy = trainer.test(config["task"]["test-episodes"], optimal_params)
                tested = True

    if not tested:
        raise RuntimeError("run() returns test results, but config['test'] is false")
    return reward_fraction, entropy
   

In [None]:
import inspect

# MF vs MB fraction of rewarded trials across common-transition probabilities.
# Guard against the other run(ambiguity, seed) definition being the live one.
if list(inspect.signature(run).parameters)[0] != "p_common":
    raise RuntimeError("run() is not the run(p_common, seed) variant; re-run the cell that defines it")

strategy = 0
n_rounds = 4
p_commons = [0.5, 0.6, 0.7, 0.8, 0.9, 1]
MF_results = np.zeros([n_rounds, len(p_commons)])
for round_idx in range(n_rounds):
    seed = 100 * round_idx
    for p_common_idx, p_common in enumerate(p_commons):
        reward_fraction, _ = run(p_common, seed)
        MF_results[round_idx, p_common_idx] = reward_fraction
MF_avg = np.mean(MF_results, axis=0)
sems_MF = sem(MF_results)

strategy = 1

MB_results = np.zeros([n_rounds, len(p_commons)])
for round_idx in range(n_rounds):
    seed = 100 * round_idx
    for p_common_idx, p_common in enumerate(p_commons):
        reward_fraction, _ = run(p_common, seed)
        MB_results[round_idx, p_common_idx] = reward_fraction
MB_avg = np.mean(MB_results, axis=0)
sems_MB = sem(MB_results)

# MB gets its own colour (it previously reused MF's blue, so the legend was ambiguous).
plt.plot(p_commons, MF_avg, color='#1f77b4', alpha=0.8)
plt.errorbar(y=MF_avg, x=p_commons, yerr=sems_MF, color='#1f77b4', alpha=0.1, capsize=3)

plt.plot(p_commons, MB_avg, color='#ff7f0e', alpha=0.8)
plt.errorbar(y=MB_avg, x=p_commons, yerr=sems_MB, color='#ff7f0e', alpha=0.1, capsize=3)

plt.legend(["MF", "MB"])

In [None]:
# Final MF (red) vs MB (navy) reward-fraction figure across common-transition probabilities.
p_axis = [0.5, 0.6, 0.7, 0.8, 0.9, 1]
for curve, err, colour in ((MF_avg, sems_MF, '#AA041F'), (MB_avg, sems_MB, '#0B224A')):
    plt.plot(p_axis, curve, color=colour, alpha=1)
    plt.errorbar(y=curve, x=p_axis, yerr=err, color=colour, alpha=0.8, capsize=3)

plt.legend(["MF", "MB"])
plt.xlabel("p_common")
plt.ylabel("fraction of rewarded trials")
plt.title('agent performance across different common transitions')
plt.xticks(p_axis)

# Subjective confidence (policy entropy) comparison between the Q1 (model-free, MF) and model-based (MB) agents across ambiguity levels

In [None]:
import inspect

# MF vs MB policy entropy across ambiguity levels, n_rounds seeds each.
# In top-to-bottom execution `run` was just redefined as run(p_common, seed),
# which would silently sweep the wrong parameter here; fail loudly instead.
if list(inspect.signature(run).parameters)[0] != "ambiguity":
    raise RuntimeError("run() is not the run(ambiguity, seed) variant; re-run the cell that defines it")

strategy = 0
n_rounds = 4
ambiguities = np.arange(0, 1.1, 0.1)
MF_results = np.zeros([n_rounds, len(ambiguities)])
for round_idx in range(n_rounds):
    seed = 100 * round_idx
    for ambiguity_idx, ambiguity in enumerate(ambiguities):
        _, entropy = run(ambiguity, seed)
        MF_results[round_idx, ambiguity_idx] = entropy
MF_avg = np.mean(MF_results, axis=0)
sems_MF = sem(MF_results)

strategy = 1

MB_results = np.zeros([n_rounds, len(ambiguities)])
for round_idx in range(n_rounds):
    seed = 100 * round_idx
    for ambiguity_idx, ambiguity in enumerate(ambiguities):
        _, entropy = run(ambiguity, seed)
        MB_results[round_idx, ambiguity_idx] = entropy
MB_avg = np.mean(MB_results, axis=0)
sems_MB = sem(MB_results)

# Plot against ambiguity values so lines and error bars share the x axis.
plt.plot(ambiguities, MF_avg, color='#1f77b4', alpha=1)
plt.errorbar(y=MF_avg, x=ambiguities, yerr=sems_MF, color='#1f77b4', alpha=0.8, capsize=3)

plt.plot(ambiguities, MB_avg, color='#ff7f0e', alpha=1)
plt.errorbar(y=MB_avg, x=ambiguities, yerr=sems_MB, color='#ff7f0e', alpha=0.8, capsize=3)

plt.legend(["MF", "MB"])

In [None]:
# MF (blue) vs MB (orange) policy entropy against ambiguity.
entropy_x = np.arange(0, 1.1, 0.1)
for mean_curve, sem_curve, line_colour in ((MF_avg, sems_MF, '#1f77b4'), (MB_avg, sems_MB, '#ff7f0e')):
    plt.plot(entropy_x, mean_curve, color=line_colour, alpha=1)
    plt.errorbar(y=mean_curve, x=entropy_x, yerr=sem_curve, color=line_colour, alpha=0.8, capsize=3)

plt.legend(["MF", "MB"])
plt.xlabel("ambiguity")
plt.ylabel("entropy")
plt.title('agent entropy')


In [None]:
# Same entropy curves in the red/navy palette, with faint error bars.
entropy_axis = np.arange(0, 1.1, 0.1)
for series, spread, tint in ((MF_avg, sems_MF, '#AA041F'), (MB_avg, sems_MB, '#0B224A')):
    plt.plot(entropy_axis, series, color=tint, alpha=0.8)
    plt.errorbar(y=series, x=entropy_axis, yerr=spread, color=tint, alpha=0.1, capsize=3)

plt.legend(["MF", "MB"])

In [None]:
# Rollout record type (redefined for this cell); its field names label result DataFrames.
Rollout = namedtuple("Rollout", "state action reward timestep done policy value")


class Trainer:
    """Run the two-step task with a tabular model-based agent and score it.

    This definition always uses the learned transition model ``p_hat`` (it does
    not read the ``strategy`` flag); ``params`` are ``(epsilon, eta, lrate)``.
    ``run_episode`` uses the module-level ``softmax`` defined elsewhere in the
    notebook.
    """

    def __init__(self, config, plot_param):
        """Seed the RNGs and build the task environment and TensorBoard writer.

        Args:
            config: parsed YAML config. Reads "run-title", "seed", "task"
                (passed whole to ``TwoStepTask``; "switch-prob" and
                "trials-per-epi" are also read from it), "a2c" ->
                "max-grad-norm" and "save-path".
            plot_param: dict whose ``str`` is appended to the log and save
                directory names.
        """
        print(config['run-title'])

        self.plot_param = plot_param

        self.device = 'cpu'
        self.seed = config["seed"]

        T.manual_seed(config["seed"])
        np.random.seed(config["seed"])
        T.random.manual_seed(config["seed"])

        self.env = TwoStepTask(config["task"])

        self.max_grad_norm = config["a2c"]["max-grad-norm"]
        self.switch_p = config["task"]["switch-prob"]
        # Kept on the instance so test() does not depend on a global `config`.
        self.trials_per_epi = config["task"].get("trials-per-epi")
        self.start_episode = 0

        run_dir = config["run-title"] + str(self.plot_param)
        self.writer = SummaryWriter(log_dir=os.path.join("logs", config["run-title"], run_dir))
        self.save_path = os.path.join(config["save-path"], config["run-title"], run_dir)

    def run_episode(self, episode, det=False, params=None):
        """Play one episode with a model-based agent and return its scored reward.

        Args:
            episode: unused; kept for interface compatibility.
            det: unused; kept for interface compatibility.
            params: ``(epsilon, eta, lrate)``: epsilon-greedy exploration rate,
                transition-model learning rate, stage-2 value learning rate.

        Returns:
            Sum of rewards over trials after the first 300 (exploration phase).

        Raises:
            ValueError: if ``params`` is not supplied (previously an opaque
                TypeError from subscripting None).
        """
        if params is None:
            raise ValueError("params=(epsilon, eta, lrate) is required")
        epsilon = params[0]
        eta = params[1]
        lrate = params[2]
        gamma = 1
        burn_in = 300  # trials of epsilon-greedy exploration before scoring starts

        done = False
        total_reward = 0

        self.env.reset()

        Q_s0 = np.zeros(2)
        Q_stage2 = np.zeros(2)
        # Row a: estimated probabilities of reaching stage-2 states 0/1 after action a.
        p_hat = np.zeros([2, 2])

        counter = 0
        while not done:
            counter += 1

            action_dist = softmax(1, Q_s0)
            greedy = int(np.argmax(action_dist))  # ties resolve to the first index

            # Greedy after the burn-in; epsilon-greedy before it.
            if counter > burn_in:
                action = greedy
            elif np.random.rand() < epsilon:
                action = np.random.choice([0, 1])
            else:
                action = greedy

            new_state, reward, done, _ = self.env.step(int(action))

            # Index of the active entry of the stage-2 state (1 or 2), shifted
            # to 0/1. NOTE(review): assumes new_state is a 1-D one-hot vector.
            planet_arrived = int(np.nonzero(new_state)[0][0]) - 1

            # Move the observed transition's probability toward 1, then renormalise the row.
            p_hat[action, planet_arrived] += eta * (1 - p_hat[action, planet_arrived])
            p_hat[action, :] = p_hat[action, :] / np.sum(p_hat[action, :])

            Q_stage2[planet_arrived] += lrate * (reward - Q_stage2[planet_arrived])

            # Plan stage-1 values from the transition model for both actions.
            Q_s0[action] = gamma * np.sum(Q_stage2 * p_hat[action, :])
            Q_s0[1 - action] = gamma * np.sum(Q_stage2 * p_hat[1 - action, :])

            if counter > burn_in:
                total_reward += reward

        return total_reward

    def train(self, params, max_episodes, save_interval):
        """Run ``max_episodes`` episodes with ``params``; return minus the total reward.

        The negated reward makes this usable as a ``scipy.optimize.minimize``
        objective. ``save_interval`` is accepted but unused.
        """
        total_rewards = np.zeros(max_episodes)

        for episode in tqdm(range(self.start_episode, max_episodes)):
            reward = self.run_episode(episode=episode, det=False, params=params)
            total_rewards[episode] = reward

            avg_reward_10 = total_rewards[max(0, episode - 9):(episode + 1)].mean()
            avg_reward_100 = total_rewards[max(0, episode - 99):(episode + 1)].mean()
            self.writer.add_scalar("perf/reward_t", reward, episode)
            self.writer.add_scalar("perf/avg_reward_10", avg_reward_10, episode)
            self.writer.add_scalar("perf/avg_reward_100", avg_reward_100, episode)

        return -np.sum(total_rewards)

    def test(self, num_episodes, optimal_params):
        """Evaluate ``optimal_params``, print the rewarded fraction and plot the task.

        The fraction divides total reward by ``0.5 * num_episodes *
        trials-per-epi``. NOTE(review): the 0.5 factor presumably equals the
        share of trials scored after the 300-trial burn-in — confirm.
        Returns None.

        Raises:
            KeyError: if the config had no task "trials-per-epi" entry.
        """
        progress = tqdm(range(num_episodes))
        self.env.reset_transition_count()

        total_rewards = np.zeros(num_episodes)

        for episode in progress:
            reward = self.run_episode(episode, det=False, params=optimal_params)

            total_rewards[episode] = reward
            avg_reward = total_rewards[max(0, episode - 9):(episode + 1)].mean()
            progress.set_description(f"Episode {episode}/{num_episodes} | Reward: {reward} | Last 10: {avg_reward:.4f}")

        if self.trials_per_epi is None:
            raise KeyError("config['task']['trials-per-epi'] is required for test()")
        reward_fraction = sum(total_rewards) / (0.5 * (num_episodes * self.trials_per_epi))

        print("fraction of rewarded trials:", reward_fraction)
        self.env.plot(self.save_path)
        
              

if __name__ == "__main__":

    os.makedirs('./temp_results', exist_ok=True)

    yaml_path = "./configs/two_step_my.yaml"
    with open(yaml_path, 'r', encoding="utf-8") as fin:
        config = yaml.load(fin, Loader=yaml.FullLoader)

    # (epsilon, eta, lrate): used directly when training is off, and as the
    # optimiser's starting point when it is on.
    default_params = [0.1, 0.2, 0.2]

    n_seeds = 1
    base_seed = config["seed"]
    base_run_title = config["run-title"]
    buffer_results = {}
    test_common_probs = [config['task']['common-prob']]
    for common_prob in test_common_probs:

        buffer_results[common_prob] = {}

        config['task']['common-prob'] = common_prob
        for seed_idx in range(1, n_seeds + 1):

            config["seed"] = base_seed * seed_idx
            config["run-title"] = base_run_title + '_seed=' + str(config['seed'])

            exp_path = os.path.join(config["save-path"], config["run-title"])
            os.makedirs(exp_path, exist_ok=True)

            out_path = os.path.join(exp_path, os.path.basename(yaml_path))
            with open(out_path, 'w') as fout:
                yaml.dump(config, fout)

            plot_param = {'seed': config['seed'], 'common_prob': common_prob, 'safebet': config["task"]['safebet-reward']}

            print(f"> Running {config['run-title']}")
            trainer = Trainer(config, plot_param)
            if config["train"]:
                result = minimize(fun=trainer.train, x0=default_params,
                                  args=(config["task"]["train-episodes"], config["save-interval"]),
                                  bounds=((0, 1), (0, 1), (0, 1)))
                # Previously overwritten by the defaults, so the optimised
                # values were neither tested nor printed.
                optimal_params = result.x
            else:
                optimal_params = default_params

            if config["test"]:
                # NOTE(review): Trainer.test in this cell returns None, so this
                # stores an empty DataFrame with the Rollout columns.
                total_buffer = trainer.test(config["task"]["test-episodes"], optimal_params)
                total_buffer = pd.DataFrame(total_buffer, columns=Rollout._fields)
                buffer_results[common_prob][seed_idx - 1] = total_buffer
            print("optimal_params", optimal_params)
    