# This notebook is for hyperparameter tuning

In [56]:
import gym
import gym_tokens
import argparse
import random

import time
import datetime
import sys
import lib
import utils
import os

import numpy as np

In [57]:
# Seed first, then draw the five per-run seeds (integers in [0, 1000)).
# Presumably utils.seed() seeds NumPy's global generator so this draw is
# reproducible across notebook runs -- TODO confirm in utils.
utils.seed(0)
seeds = np.random.randint(low=0, high=1000, size=5)

In [58]:
# Module-level action count. Note that run() binds its own local
# `num_actions` from env.get_num_actions(), so this value is not read there.
num_actions = 3

In [59]:
def _sign(num):

	if num < 0:
		return -1

	elif num > 0:
		return 1

	else:
		return 0

In [60]:
class Args():
    """Plain container for the hyperparameters and run settings read by ``run``.

    Attributes (as consumed by ``run``):
        games: number of games; ``height * games`` is the frame budget.
        env: environment id; also used in the model directory name.
        seeds: iterable of seeds; one training run is performed per seed.
        log_interval: a log line is emitted when the game count is a
            multiple of this value.
        algo: ``'sarsa'`` or ``'q-learning'``.
        convg: convergence threshold handed to ``lib.Q_Table``.
        lr, lr_final: start/end learning rates handed to ``lib.LRscheduler``.
        save_interval: checkpoint when the game count is a multiple of this
            value; values <= 0 disable checkpointing.
        exp: exploration strategy, ``"eps_greedy"``, ``"eps_soft"`` or
            ``"softmax"``.
        h_start, h_final, h_games: schedule parameters handed to the
            epsilon/temperature tracker.
        gamma: discount factor handed to ``model.get_TDerror``.
        height: passed to the environment as ``terminal`` and used to size
            the Q-table.
        fancy_discount: forwarded to the environment constructor.
        fast_block: selects the environment's ``gamma`` (0.25 if true,
            0.75 otherwise).
        fancy_eps: with ``"eps_greedy"``, selects the per-game epsilon
            policy and schedules epsilon by games instead of frames.
    """

    def __init__(self, games = 1000, env = 'tokens-v0', seeds = None, log_interval = 1, algo = 'q-learning', convg = 0.00001, lr = 0.1, lr_final = 0.0001, save_interval = 2000, exp = "eps_greedy", h_start = 1.0, h_final = 0.01, h_games = 1000, gamma = 0.99, height = 15, fancy_discount = False, fast_block = False, fancy_eps = False):
        self.games = games
        self.env = env
        # A fresh list per instance: a literal `[]` default would be shared
        # by every Args() built without `seeds`. `is None` (rather than
        # truthiness) keeps NumPy arrays usable as `seeds`.
        self.seeds = [] if seeds is None else seeds
        self.log_interval = log_interval
        self.algo = algo
        self.convg = convg
        self.lr = lr
        self.lr_final = lr_final
        self.save_interval = save_interval
        self.exp = exp
        self.h_start = h_start
        self.h_final = h_final
        self.h_games = h_games
        self.gamma = gamma
        self.height = height
        self.fancy_discount = fancy_discount
        self.fast_block = fast_block
        self.fancy_eps = fancy_eps

    def __repr__(self):
        # run() writes "{}".format(args) to the text log; without this the
        # log would only show the object's address instead of its settings.
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

In [61]:
def run(args: Args, trial):
    """Train one tabular agent per seed and log every finished game.

    For each seed in ``args.seeds`` a model directory named
    ``<args.env>_set<trial>/seed<seed>`` is obtained from ``utils``, the
    environment, Q-table, exploration policy and agent are built, and the
    environment is stepped until more than ``args.height * args.games``
    non-terminal frames have been played (terminal steps do not advance
    the frame counter).

    Per finished game (subject to ``args.log_interval``) a summary line goes
    to the text logger and one row (trajectory, choice, correct choice,
    decision time, reward) goes to the CSV logger. The Q-table and the
    decision-time histogram are written to the model directory whenever the
    game count is a multiple of ``args.save_interval``.

    Args:
        args: hyperparameters and run settings.
        trial: identifier of this hyperparameter setting; only used in the
            model directory name.

    Raises:
        ValueError: if ``args.exp`` or ``args.algo`` is not a supported value.
    """
    # Fail early and clearly instead of hitting a NameError on `policy` /
    # `monkeyAgent` after directories and log files have been created.
    if args.exp not in ("softmax", "eps_soft", "eps_greedy"):
        raise ValueError(f"unknown exploration strategy: {args.exp!r}")
    if args.algo not in ("sarsa", "q-learning"):
        raise ValueError(f"unknown algorithm: {args.algo!r}")

    # Both state axes of the Q-table have 2*height + 1 entries.
    numNT = (args.height * 2) + 1
    numHT = (args.height * 2) + 1

    # Passed to the environment as its `gamma` argument (distinct from the
    # agent's discount factor args.gamma).
    block_discount = 0.25 if args.fast_block else 0.75

    for seed in args.seeds:
        model_name = os.path.join(f"{args.env}_set{trial}", f"seed{seed}")
        model_dir = utils.get_model_dir(model_name)

        # Load loggers
        txt_logger = utils.get_txt_logger(model_dir)
        csv_file, csv_logger = utils.get_csv_logger(model_dir)

        # One CSV handle is opened per seed; make sure it is released even if
        # training is interrupted (e.g. KeyboardInterrupt in the notebook).
        # Assumes csv_file is a file-like object (it is flushed below).
        try:
            # Log command and all script arguments
            txt_logger.info("{}\n".format(" ".join(sys.argv)))
            txt_logger.info("{}\n".format(args))

            # Set seed for all randomness sources
            utils.seed(seed)

            # Use the configured environment id so the environment matches
            # the name the model directory was derived from.
            env = gym.make(args.env, gamma=block_discount, seed=seed, terminal=args.height, fancy_discount=args.fancy_discount)
            txt_logger.info("Environments loaded\n")

            # Training always starts from scratch.
            status = {"num_frames": 0, "update": 0, "num_games": 0}
            txt_logger.info("Training status loaded\n")

            num_actions = env.get_num_actions()
            num_games_frames = args.height * args.games

            model = lib.Q_Table(numNT*numHT*(args.height+2), num_actions, (numNT, numHT, args.height), args.convg, args.height)

            if args.exp == "softmax":
                policy = lib.SoftmaxPolicy()
                # NOTE(review): the training loop below only steps `eps_track`.
                # With exp == "softmax" that name is never bound, so the first
                # frame raises NameError and `tmp_track` is never advanced.
                # Wiring it in needs TemperatureTracker's update method, which
                # is not visible from this file.
                tmp_track = lib.TemperatureTracker(args.h_start, args.h_final, args.h_games, policy)
            elif args.exp == "eps_soft":
                policy = lib.EpsilonSoftPolicy()
                eps_track = lib.EpsilonTracker(args.h_start, args.h_final, args.h_games, policy)
            elif args.fancy_eps:
                # eps_greedy, epsilon scheduled per game
                policy = lib.EpsilonGreedyGamePolicy()
                eps_track = lib.EpsilonTracker(args.h_start, args.h_final, args.h_games, policy)
            else:
                # eps_greedy, epsilon scheduled per frame: stretch the
                # schedule length from games to frames.
                policy = lib.EpsilonGreedyPolicy()
                eps_track = lib.EpsilonTracker(args.h_start, args.h_final, args.h_games*args.height, policy)

            if args.algo == 'sarsa':
                monkeyAgent = lib.SarsaAgent(policy, model, args.height)
            else:
                monkeyAgent = lib.QlAgent(policy, model, args.height)

            lr_sched = lib.LRscheduler(args.lr, args.lr_final, args.games*args.height*10*0.8)

            num_frames = status["num_frames"]
            update = status["update"]
            num_games = status["num_games"]
            num_games_prevs = 0          # game count at the previous log line
            last_saved_games = -1        # game count at the previous checkpoint
            csv_header_written = False

            start_time = time.time()
            totalReturns = []            # final reward of every finished game
            totalLoss = []               # TD errors of the game in progress
            lossPerEpisode = []          # summed TD error of every finished game

            state, game_time_step = env.reset()

            last_choice = 0              # games whose |next_state[1]| == height + 1
            decisionTime = np.zeros(shape=((args.height*2)+1))

            traj = []                    # state[0] of every step of the current game
            traj_group = []              # one trajectory per finished game
            choice_made = []
            correct_choice = []
            finalDecisionTime = []
            finalRewardPerGame = []
            numCorrectChoice = 0
            numRecentCorrectChoice = []  # 1/0 per finished game (reward > 0)

            while num_frames <= num_games_frames:

                traj.append(state[0].tolist())

                if args.fancy_eps:
                    eps_track.set_eps(num_games)
                else:
                    eps_track.set_eps(num_frames)

                action = monkeyAgent.get_actions(state, game_time_step)

                next_state, reward, is_done, game_time_step = env.step(action)

                lr = lr_sched.get_lr(num_frames)

                # NOTE(review): `next_act` is drawn here and a new action is
                # drawn again at the top of the next iteration, so the action
                # used in the TD target is not necessarily the one executed.
                next_act = monkeyAgent.get_actions(next_state, game_time_step)
                loss = model.get_TDerror(state, action, next_state, next_act, reward, args.gamma, is_done, args.algo)
                model.update_qVal(lr, state, action, loss)
                totalLoss.append(loss)

                if is_done:
                    num_games += 1
                    totalReturns.append(reward)

                    if reward > 0:
                        numCorrectChoice += 1
                        numRecentCorrectChoice.append(1)
                    else:
                        numRecentCorrectChoice.append(0)

                    lossPerEpisode.append(np.sum(totalLoss))
                    totalLoss = []

                    decision_step = model._augState(abs(next_state[1]))
                    decisionTime[decision_step-1] += 1

                    if abs(next_state[1]) == args.height+1:
                        last_choice += 1

                    choice_made.append(_sign(next_state[1]))
                    correct_choice.append(_sign(next_state[0]))
                    finalDecisionTime.append(abs(next_state[1]))
                    finalRewardPerGame.append(reward)

                    traj_group.append(traj)
                    traj = []
                    next_state, game_time_step = env.reset()

                else:
                    # Only non-terminal steps count as frames/updates.
                    num_frames += 1
                    update += 1

                state = next_state

                # Log once per newly finished game (subject to log_interval).
                if num_games > num_games_prevs and num_games % args.log_interval == 0:
                    duration = int(time.time() - start_time)
                    totalLoss_val = np.sum(lossPerEpisode)
                    totalReturn_val = np.sum(totalReturns)

                    avg_loss = np.mean(lossPerEpisode[-1000:])
                    avg_returns = np.mean(totalReturns[-1000:])
                    recent_correct = np.mean(numRecentCorrectChoice[-1000:])

                    if args.exp == "softmax":
                        explore_label, explore_value = "TMP", policy.temperature
                    else:
                        explore_label, explore_value = "EPS", policy.epsilon

                    # `num_games_prevs` indexes the first game finished since
                    # the previous log line (the latest one when
                    # log_interval == 1).
                    data = [update, num_frames, num_games, duration,
                            explore_value, lr, last_choice,
                            totalLoss_val.item(), totalReturn_val.item(), avg_loss.item(), avg_returns.item(),
                            numCorrectChoice/num_games, recent_correct, finalDecisionTime[num_games_prevs]]

                    txt_logger.info(
                        ("U {} | F {} | G {} | D {} | " + explore_label + " {:.3f} | LR {:.5f} | Last {} | L {:.3f} | R {:.3f} | Avg L {:.3f} | Avg R {:.3f} | Avg C {:.3f} | Rec C {:.3f} | DT {}")
                        .format(*data))

                    csv_header = ["trajectory", "choice_made", "correct_choice", "decision_time", "reward_received"]
                    csv_data = [traj_group[num_games_prevs], choice_made[num_games_prevs], correct_choice[num_games_prevs], finalDecisionTime[num_games_prevs], finalRewardPerGame[num_games_prevs]]

                    # Emit the header with the first row actually written; a
                    # `num_games == 1` test never fires when log_interval > 1.
                    if not csv_header_written:
                        csv_logger.writerow(csv_header)
                        csv_header_written = True
                    csv_logger.writerow(csv_data)
                    csv_file.flush()

                    num_games_prevs = num_games

                # Save status: once per qualifying game count. Without the
                # `last_saved_games` guard the files were rewritten on every
                # frame while num_games sat on a multiple of save_interval.
                if (args.save_interval > 0
                        and num_games % args.save_interval == 0
                        and num_games != last_saved_games):
                    model.save_q_state(model_dir, num_games)
                    np.save(model_dir+'/decisionTime_'+str(num_games)+'.npy', decisionTime)
                    last_saved_games = num_games
        finally:
            csv_file.close()

In [62]:
# Grid search: every combination of learning rate, discount factor,
# exploration strategy and algorithm is one trial, trained on all `seeds`.
# Trial ids start at 1 and follow the nesting order below.
trial = 0
for lr in [0.1, 0.3, 0.5, 0.7, 0.9]:
    # Was `0,7`, which expanded the grid to the two values 0 and 7.
    for gamma in [0.5, 0.7, 0.9, 0.99]:
        for exp in ["eps_greedy", "eps_soft", "softmax"]:
            for algo in ["q-learning", "sarsa"]:
                trial += 1
                arg = Args(lr=lr, gamma=gamma, exp=exp, algo=algo, seeds=seeds)
                run(args=arg, trial=trial)

424 | D 5 | EPS 0.576 | LR 0.04700 | Last 0 | L -1229.459 | R 206.000 | Avg L -2.900 | Avg R 0.486 | Avg C 0.486 | Rec C 0.486 | DT 1
U 6375 | F 6375 | G 425 | D 5 | EPS 0.575 | LR 0.04688 | Last 0 | L -1231.431 | R 207.000 | Avg L -2.897 | Avg R 0.487 | Avg C 0.487 | Rec C 0.487 | DT 1
U 6390 | F 6390 | G 426 | D 5 | EPS 0.574 | LR 0.04675 | Last 0 | L -1233.473 | R 208.000 | Avg L -2.895 | Avg R 0.488 | Avg C 0.488 | Rec C 0.488 | DT 1
U 6405 | F 6405 | G 427 | D 5 | EPS 0.573 | LR 0.04663 | Last 0 | L -1236.157 | R 209.000 | Avg L -2.895 | Avg R 0.489 | Avg C 0.489 | Rec C 0.489 | DT 4
U 6420 | F 6420 | G 428 | D 5 | EPS 0.572 | LR 0.04650 | Last 0 | L -1239.321 | R 209.000 | Avg L -2.896 | Avg R 0.488 | Avg C 0.488 | Rec C 0.488 | DT 2
U 6435 | F 6435 | G 429 | D 5 | EPS 0.571 | LR 0.04638 | Last 0 | L -1241.994 | R 210.000 | Avg L -2.895 | Avg R 0.490 | Avg C 0.490 | Rec C 0.490 | DT 4
U 6450 | F 6450 | G 430 | D 5 | EPS 0.570 | LR 0.04625 | Last 0 | L -1243.877 | R 211.000 | Avg 

KeyboardInterrupt: 