In [1]:
%run pong_agent.ipynb
%run pong_env.ipynb
from itertools import count
import matplotlib.pyplot as plt

pygame 2.0.1 (SDL 2.0.14, Python 3.8.10)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
def normalize(s):
    s_norm = np.copy(s)
    s_norm[I_Y_BALL] = np.interp(s_norm[I_Y_BALL], (0, 2), (0, 1))
    s_norm[I_VX_BALL] = np.interp(s_norm[I_VX_BALL], (- V_BALL, V_BALL), (0, 1))
    s_norm[I_VY_BALL] = np.interp(s_norm[I_VY_BALL], (- V_BALL, V_BALL), (0, 1))
    return s_norm


In [3]:
def train(env, agent, n_episodes, vis=False):
    pg.init()
    steps_per_e = np.zeros(n_episodes)
    for episode in tqdm(range(n_episodes)):
        s = env.reset()
        agent.reset()
        s_norm = normalize(s)
        for t in count():
            a, policy = agent.choose_action(s_norm)
            sp, r, terminal = env.step(actions[a])
            sp_norm = normalize(sp)
            w, theta = agent.learn(s_norm, a, r, sp_norm, terminal, policy)
            print(s_norm)
            if terminal:
                break
            s = sp
            s_norm = sp_norm
        
            if vis:
                cont = env.render()
                if cont is False or terminal:
                    pg.quit()
                    vis = False

        steps_per_e[episode] = t

    return w, theta, steps_per_e
        

In [4]:
def play(env, agent, n_episodes=1):
    pg.init()
    for episode in range(n_episodes):
        s = env.reset()
        s_norm = normalize(s)
        for t in count():
            a, _ = agent.choose_action(s_norm)
            sp, r, terminal = env.step(actions[a])
            sp_norm = normalize(sp)
            if terminal:
                break
            s = s_norm
            sp = sp_norm

            cont = env.render()
            if cont is False or terminal:
                break
    pg.quit()
        

In [5]:
def load_weights(idx, chunk):
    post = f"{(idx + 1) * chunk}.npy"
    Q = np.load(f"Q_{post}")
    s_count = np.load(f"s_count_{post}")
    sa_count = np.load(f"sa_count_{post}")
    return Q, s_count, sa_count
    
def save_weights(dict_, idx, chunk):
    for name, val in dict_.items():
        np.save(f"{name}_{(idx + 1) * chunk}", val)


In [6]:
if __name__ == "__main__":
    degree = 3
    n_states = 5
    n_features = (degree + 1) ** n_states
    actions = [-2, 0, 2]
    n_actions = len(actions)

    alpha = 1e-4
    lambda_w = 0.5
    lambda_theta = 0.5
    gamma = 0.99

    offset = 0
    n_episodes = 150
    try:
        w = np.load(f"w_{offset}.npy")
        theta = np.load(f"theta_{offset}.npy")
        print("Imported Weights!")
    except FileNotFoundError:
        w, theta = np.zeros(n_features), np.zeros(n_actions * n_features)


In [8]:
agent = ActorCriticETAgent(n_features, degree, n_states, n_actions, alpha, gamma, lambda_theta, lambda_w)
env = PongSoloEnv(n_states, actions)
w, theta, steps_per_e = train(env, agent, n_episodes, vis=False)
np.save(f"w_{n_episodes + offset}", w)
np.save(f"theta_{n_episodes + offset}", theta)


  0%|          | 0/150 [00:00<?, ?it/s]

[0.66978695 0.93908353 0.5        0.5921061  0.00855675]
[0.66978695 0.94829414 0.47542784 0.5921061  0.00855675]
[0.68978695 0.95750475 0.45085568 0.5921061  0.00855675]
[0.66978695 0.96671536 0.42628351 0.5921061  0.00855675]
[0.64978695 0.97592597 0.40171135 0.5921061  0.00855675]
[0.64978695 0.98       0.37713919 0.40789493 0.00855675]
[0.64978695 0.97078949 0.35256703 0.40789493 0.00855675]
[0.66978695 0.96157899 0.32799486 0.40789493 0.00855675]
[0.66978695 0.95236848 0.3034227  0.40789493 0.00855675]
[0.66978695 0.94315797 0.27885054 0.40789493 0.00855675]
[0.64978695 0.93394746 0.25427838 0.40789493 0.00855675]
[0.66978695 0.92473696 0.22970621 0.40789493 0.00855675]
[0.64978695 0.91552645 0.20513405 0.40789493 0.00855675]
[0.62978695 0.90631594 0.18056189 0.40789493 0.00855675]
[0.60978695 0.89710543 0.15598973 0.40789493 0.00855675]


  1%|          | 1/150 [00:00<02:13,  1.12it/s]

[0.58978695 0.88789493 0.13141756 0.40789493 0.00855675]
[0.60978695 0.87868442 0.1068454  0.40789493 0.00855675]
[0.58978695 0.86947391 0.08227324 0.40789493 0.00855675]
[0.15755523 0.27089796 0.5        0.73974053 0.93877611]
[0.15755523 0.29487201 0.52193881 0.73974053 0.93877611]
[0.17755523 0.31884606 0.54387761 0.73974053 0.93877611]
[0.19755523 0.34282011 0.56581642 0.73974053 0.93877611]
[0.19755523 0.36679417 0.58775522 0.73974053 0.93877611]
[0.21755523 0.39076822 0.60969403 0.73974053 0.93877611]
[0.23755523 0.41474227 0.63163283 0.73974053 0.93877611]
[0.21755523 0.43871633 0.65357164 0.73974053 0.93877611]
[0.21755523 0.46269038 0.67551045 0.73974053 0.93877611]
[0.23755523 0.48666443 0.69744925 0.73974053 0.93877611]
[0.25755523 0.51063849 0.71938806 0.73974053 0.93877611]
[0.23755523 0.53461254 0.74132686 0.73974053 0.93877611]
[0.21755523 0.55858659 0.76326567 0.73974053 0.93877611]
[0.19755523 0.58256065 0.78520447 0.73974053 0.93877611]
[0.17755523 0.6065347  0.807143

[0.16       0.20220688 0.17826419 0.53876742 0.06122389]
[0.18       0.20608362 0.15632538 0.53876742 0.06122389]
[0.2        0.20996037 0.13438658 0.53876742 0.06122389]
[0.18       0.21383711 0.11244777 0.53876742 0.06122389]
[0.18       0.21771385 0.09050897 0.53876742 0.06122389]
[0.2        0.22159059 0.0725     0.73876742 0.93877611]
[0.18       0.24546734 0.09443881 0.73876742 0.93877611]
[0.18       0.26934408 0.11637761 0.73876742 0.93877611]
[0.18       0.29322082 0.13831642 0.73876742 0.93877611]
[0.18       0.31709756 0.16025522 0.73876742 0.93877611]


  1%|          | 1/150 [00:08<22:04,  8.89s/it]

[0.18       0.3409743  0.18219403 0.73876742 0.93877611]
[0.2        0.36485105 0.20413283 0.73876742 0.93877611]
[0.18       0.38872779 0.22607164 0.73876742 0.93877611]





KeyboardInterrupt: 

In [None]:
# plt.plot(steps_per_e)

In [None]:
# env = PongSoloEnv(n_states, actions)
# play(env, agent, n_episodes=10)
