In [1]:
threshold_maxs = list(range(50, -1, -5))          # threshold_max values swept: 50, 45, ..., 5, 0
sigmas = [0] + [2 ** k for k in range(1, 11)]      # sigma values swept: 0, then 2, 4, ..., 1024
pass_prob = 1  # 1 or 0.4 (the run cell writes each to its own output directory)

In [2]:
# Run-mode switches: `save` writes per-condition .npy results; `visualize` renders trials with PsychoPy.
save, visualize = True, False

In [3]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
import matplotlib.pyplot as plt
import time
import random
import torch
from torch import nn, optim
from common.network import DuelingNetwork
from common.replay import PrioritizedReplayBuffer
from common.trainer import Trainer
from common.hparameter import *
from common.utils import *

if visualize==True:
    from psychopy import data, visual, core
    import psychopy.hardware.joystick

In [4]:
# (defender acceleration, ball acceleration) per condition group letter.
# Each group has three repetitions ("A1", "A2", "A3", ...) that share the same values.
_ACCEL_BY_GROUP = {
    "A": (1.5, 3.6),
    "B": (1.75, 3.6),
    "C": (2.0, 3.6),
    "D": (1.5, 4.8),
    "E": (1.75, 4.8),
    "F": (2.0, 4.8),
    "G": (1.5, 6.0),
    "H": (1.75, 6.0),
    "I": (2.0, 6.0),
    "P": (1.75, 4.8),
}
# Flattened lookup keyed by the full condition label, e.g. "E2".
_ACCEL_BY_CONDITION = {
    f"{group}{rep}": accels
    for group, accels in _ACCEL_BY_GROUP.items()
    for rep in (1, 2, 3)
}


def get_params(condition):
    """Return the ``(accel_d, accel_b)`` pair for an experimental condition.

    Parameters
    ----------
    condition : str
        Condition label: a group letter ``A``-``I`` or ``P`` followed by a
        repetition digit ``1``, ``2`` or ``3`` (e.g. ``"E2"``).

    Returns
    -------
    tuple of float
        ``(accel_d, accel_b)``: the defender acceleration and the ball
        acceleration used to build the environment for this condition.

    Raises
    ------
    ValueError
        If ``condition`` is not one of the known labels. The previous if/elif
        chain surfaced this as an opaque UnboundLocalError instead.
    """
    try:
        return _ACCEL_BY_CONDITION[condition]
    except (KeyError, TypeError):  # TypeError: unhashable input can never be a valid label
        raise ValueError(
            f"unknown condition {condition!r}; expected one of {sorted(_ACCEL_BY_CONDITION)}"
        ) from None
        

In [5]:

# Evaluate the pretrained defender against the scripted attackers for every
# (threshold_max, sigma, agent seed, condition) combination and save each
# condition's trial trajectories to an .npy file.
from three_on_one_agent import ThreeOnOne  # project-local; imported once instead of once per agent

# ---------------------------------------------------------------------------
# Fixed settings (loop-invariant, so hoisted out of the sweep loops)
# ---------------------------------------------------------------------------
device = torch.device('cpu')

step = 1000000   # step value fed to the epsilon schedule below
epsilon_end = 0


def epsilon_func(step):
    """Linearly decayed exploration rate, floored at ``epsilon_end``.

    ``epsilon_begin`` and ``epsilon_decay`` are not defined in this notebook;
    presumably they come from the ``common.hparameter`` star import -- TODO confirm.
    """
    return max(epsilon_end, epsilon_begin - (epsilon_begin - epsilon_end) * (step / epsilon_decay))


accel_a = 2               # acceleration shared by all three attackers
noise_b = 0               # pass noise for every attacker
threshold_pass_ang2 = 30
threshold_pass_ang3 = 30
pass_ang_bias2 = 0
pass_ang_bias3 = 0

conditions = ["A1", "A2", "A3", "B1", "B2", "B3", "C1", "C2", "C3", "D1", "D2", "D3", "E1", "E2", "E3", "F1", "F2", "F3", "G1", "G2", "G3", "H1", "H2", "H3", "I1", "I2", "I3"]
n_agents = 28        # seeds 1..28 -> agent001..agent028
trial_num = 8        # trials per condition
max_pass_num = 50    # a trial also ends once this many passes have been counted
screen_num = 0       # PsychoPy screen index (visualize only)

# Output root per pass_prob. Validated up front so an unsupported value fails
# before any simulation runs instead of at the first save.
save_roots = {
    1: '../data/agent/max_sigma',
    0.4: '../data/agent/max_sigma_human',
}
if save and pass_prob not in save_roots:
    raise ValueError(f"pass_prob must be 1 or 0.4 when save=True, got {pass_prob!r}")

# The checkpoint is identical for every agent, so read it once. weights_only=True
# restricts unpickling to tensors/plain containers (all a state_dict needs) and
# avoids torch's FutureWarning about the default.
defender_state = torch.load("../model/defender.pth", map_location=device, weights_only=True)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_pitch_stimuli(win, outer=.81, inner=.79, n=72):
    """Build the white background plus the dark-grey ring (outer disc, white inner disc) marking the pitch."""
    pitch_vert_outer = []
    pitch_vert_inner = []
    for i in range(n):
        angle = i / n * 2 * np.pi
        x, y = np.cos(angle), np.sin(angle)
        pitch_vert_outer.append([x * outer, y * outer])
        pitch_vert_inner.append([x * inner, y * inner])

    bg_vert = [(-1, -1), (-1, 1), (1, 1), (1, -1)]
    bg = visual.ShapeStim(win, vertices=bg_vert, fillColor='white', lineWidth=0, size=1, pos=(0, 0))
    pitch_outer = visual.ShapeStim(win, vertices=pitch_vert_outer, fillColor='darkgray', lineWidth=0, size=1, pos=(0, 0))
    pitch_inner = visual.ShapeStim(win, vertices=pitch_vert_inner, fillColor='white', lineWidth=0, size=1, pos=(0, 0))
    return bg, pitch_outer, pitch_inner


def _show_condition_title(stim, condition):
    """Flash the condition label (black, then grey) and pause 1 s before its trials start."""
    win, bg = stim["win"], stim["bg"]
    for color in ('black', 'gray'):
        bg.draw()
        visual.TextStim(win, text=condition, pos=[0, 0.1], color=color, height=0.3).draw()
        win.flip()
    core.wait(1)


def _run_trial(env, net_d, trial_idx, condition, stim=None):
    """Play one episode of ``env`` with ``net_d`` choosing the defender's actions.

    The episode ends when the defender is within 0.1 of the ball, when the
    single-argument ``get_dist(ball)`` exceeds 0.8 (presumably distance from the
    pitch centre -- TODO confirm), or when ``max_pass_num`` passes are counted.
    ``stim`` is None for headless runs, else the PsychoPy objects to draw with.

    Returns ``[defender_pos, attacker1_pos, attacker2_pos, attacker3_pos,
    ball_pos, pass_flags, time_steps, pass_counts]``, each a list holding one
    entry per simulated step.
    """
    defender_pos, attacker1_pos, attacker2_pos, attacker3_pos, ball_pos = [], [], [], [], []
    pass_flags, time_steps, pass_counts = [], [], []

    if stim is not None:
        core.wait(1)

    t = 0            # step index; also passed to env.step as the episode step
    pass_times = 0   # passes counted while the defender is not on the ball

    obs_d, obs_a1, obs_a2, obs_a3, _, _, _ = env.reset()
    obs_d = torch.Tensor(obs_d)

    if stim is not None:
        win = stim["win"]
        defender = visual.GratingStim(win, tex=None, mask='circle', size=(.1, .1), color='black', pos=env.pos_d)
        attacker1 = visual.GratingStim(win, tex=None, mask='circle', size=(.1, .1), color='red', pos=env.pos_a1)
        attacker2 = visual.GratingStim(win, tex=None, mask='circle', size=(.1, .1), color='green', pos=env.pos_a2)
        attacker3 = visual.GratingStim(win, tex=None, mask='circle', size=(.1, .1), color='blue', pos=env.pos_a3)
        ball = visual.GratingStim(win, tex=None, mask='circle', size=(.05, .05), color='gold', pos=env.pos_b)
        text1 = visual.TextStim(win, text=t / 10, pos=[-0.8, 0.9], color='gray', height=0.1)
        text2 = visual.TextStim(win, text=pass_times, pos=[-0.8, 0.75], color='black', height=0.15)
        text3 = visual.TextStim(win, text=condition, pos=[0.8, 0.9], color='gray', height=0.1)
        text4 = visual.TextStim(win, text=trial_idx + 1, pos=[0.8, 0.75], color='gray', height=0.1)

    finished = False
    while not finished:
        if stim is not None:
            stim["timer"].reset()

        # Only the defender's action comes from the network; the attacker actions passed in are always 0.
        action_d = net_d.act(obs_d.float().to(device), epsilon_func(step))
        (next_obs_d, next_obs_a1, next_obs_a2, next_obs_a3,
         _reward_d, _reward_a1, _reward_a2, _reward_a3, _done,
         _obs_a, _action_a, _reward_a, _next_obs_a, push_a,
         _with_b_a1, _with_b_a2, _with_b_a3, _to_a1, _to_a2, _to_a3) = env.step(
            obs_d, obs_a1, obs_a2, obs_a3, action_d, 0, 0, 0, t)
        # NOTE(review): the env's `done` flag is ignored (as before); the trial only
        # ends on the distance / pass-count checks below -- confirm this is intended.
        obs_d = torch.Tensor(next_obs_d)
        obs_a1, obs_a2, obs_a3 = next_obs_a1, next_obs_a2, next_obs_a3

        pos_d, pos_a1, pos_a2, pos_a3, pos_b = env.pos_d, env.pos_a1, env.pos_a2, env.pos_a3, env.pos_b
        dist0 = get_dist(np.array(pos_d), np.array(pos_b))  # defender <-> ball
        dist1 = get_dist(np.array(pos_b))                   # single-argument form, see docstring

        if push_a == True and dist0 > 0.1:
            pass_times += 1

        defender_pos.append(np.array(pos_d))
        attacker1_pos.append(np.array(pos_a1))
        attacker2_pos.append(np.array(pos_a2))
        attacker3_pos.append(np.array(pos_a3))
        ball_pos.append(np.array(pos_b))
        pass_flags.append(int(push_a))
        time_steps.append(t)
        pass_counts.append(int(pass_times))

        if stim is not None:
            stim["bg"].draw()
            stim["pitch_outer"].draw()
            stim["pitch_inner"].draw()
            for sprite, pos in ((defender, pos_d), (attacker1, pos_a1), (attacker2, pos_a2),
                                (attacker3, pos_a3), (ball, pos_b)):
                sprite.setPos(pos)
                sprite.draw()
            text1.setText(t / 10)
            text1.draw()
            text2.setText(pass_times)
            text2.draw()
            text3.draw()
            text4.setText(trial_idx + 1)
            text4.draw()
            win.flip()

        if pass_times >= max_pass_num:
            finished = True

        if dist0 < 0.1 or dist1 > 0.8:
            finished = True
            if stim is not None:
                core.wait(0.5)

        if stim is not None:
            # Hold each rendered step for at least 0.1 s of wall-clock time.
            while stim["timer"].getTime() < 0.1:
                pass

        t += 1

    return [defender_pos, attacker1_pos, attacker2_pos, attacker3_pos,
            ball_pos, pass_flags, time_steps, pass_counts]


# ---------------------------------------------------------------------------
# Sweep
# ---------------------------------------------------------------------------
for threshold_max in threshold_maxs:
    print("====  max=", threshold_max, " ====")

    for sigma in sigmas:
        print("====  sigma=", sigma, " ====")

        for num in range(1, n_agents + 1):
            # Re-seed every RNG per agent so each agent's run is reproducible on its own.
            seed = num
            np.random.seed(seed)
            random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

            # Built right after seeding (same order as before) so the RNG draws made by
            # weight initialisation, and everything drawn after them, are unchanged.
            net_d = DuelingNetwork(28, 13).to(device)
            net_d.load_state_dict(defender_state)

            name = f"agent{num:03d}"

            stim = None
            if visualize:
                win = visual.Window(size=(800, 800), pos=[350, 0], units='norm', screen=screen_num)
                bg, pitch_outer, pitch_inner = _make_pitch_stimuli(win)
                stim = {"win": win, "bg": bg, "pitch_outer": pitch_outer,
                        "pitch_inner": pitch_inner, "timer": core.Clock()}

            for condition in conditions:
                accel_d, accel_b = get_params(condition)

                # max_step_episode is not defined in this notebook; presumably from common.hparameter -- TODO confirm.
                env = ThreeOnOne(accel_defender=accel_d, accel_attacker1=accel_a, accel_attacker2=accel_a, accel_attacker3=accel_a, accel_ball=accel_b,
                                 pass_noise1=noise_b, pass_noise2=noise_b, pass_noise3=noise_b,
                                 threshold_ang2=threshold_pass_ang2, threshold_ang3=threshold_pass_ang3,
                                 pass_bias2=pass_ang_bias2, pass_bias3=pass_ang_bias3,
                                 max_step=max_step_episode, threshold_max=threshold_max, sigma=sigma, pass_prob=pass_prob)

                if stim is not None:
                    _show_condition_title(stim, condition)

                trials = [_run_trial(env, net_d, i, condition, stim) for i in range(trial_num)]

                if save:
                    save_dir = f'{save_roots[pass_prob]}/max_{threshold_max}/sigma_{sigma}/{name}'
                    os.makedirs(save_dir, exist_ok=True)
                    # NOTE(review): with dtype=object the saved array's shape depends on whether
                    # all trials ran the same number of steps; loaders should index [trial][field].
                    np.save(f'{save_dir}/{name}_{condition}.npy', np.array(trials, dtype=object), allow_pickle=True)

            if stim is not None:
                stim["win"].close()

            print("finish!")


====  max= 50  ====
====  sigma= 0  ====


  net_d.load_state_dict(torch.load("../model/defender.pth", torch.device('cpu')))


finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 2  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 4  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 8  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 16  ====
finish!
finish!
fin

finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 2  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 4  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 8  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 16  ====
finish!
finish!
finish!
finish!
fin

finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 2  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 4  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 8  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 16  ====
finish!
finish!
finish!
finish!
finish!
finish!
fin

finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 2  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 4  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 8  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
====  sigma= 16  ====
finish!
finish!
finish!
finish!
finish!
finish!
finish!
finish!
fin