
Tune reward values and discount factor for policy-gradient training

Robert Bazzocchi committed Apr 2, 2018
1 parent c7cb1f7 commit d9b15451b859ea8f0cdd6fa06f7deb798e058b0f
Showing with 8 additions and 8 deletions.
  1. +8 −8 tictactoe.py
@@ -170,7 +170,6 @@ def compute_returns(rewards, gamma=1.0):

     return G
 
-
 def finish_episode(saved_rewards, saved_logprobs, gamma=1.0):
     """Computes the policy-gradient loss from the saved rewards and log-probabilities."""
     policy_loss = []
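For context, compute_returns (named in the hunk header above, and whose `return G` appears here) is expected to build the discounted return G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ... for every step of an episode. A minimal sketch consistent with its signature, not necessarily the repository's exact implementation:

def compute_returns(rewards, gamma=1.0):
    """Discounted return at each step: G[t] = sum_k gamma**k * rewards[t+k]."""
    G = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running  # accumulate backwards from the end of the episode
        G[t] = running
    return G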
@@ -189,14 +188,14 @@ def finish_episode(saved_rewards, saved_logprobs, gamma=1.0):
 def get_reward(status):
     """Returns a numeric reward given an environment status."""
     return {
-        Environment.STATUS_VALID_MOVE  : 25,  # TODO
-        Environment.STATUS_INVALID_MOVE: -75,
-        Environment.STATUS_WIN         : 100,
-        Environment.STATUS_TIE         : 0,
-        Environment.STATUS_LOSE        : -100
+        Environment.STATUS_VALID_MOVE  : 0,    # 0
+        Environment.STATUS_INVALID_MOVE: -150, # -75
+        Environment.STATUS_WIN         : 100,  # 100
+        Environment.STATUS_TIE         : 0,    # 0
+        Environment.STATUS_LOSE        : -100  # 100
     }[status]
 
-def train(policy, env, gamma=1.0, log_interval=1000):
+def train(policy, env, gamma=0.75, log_interval=1000):
     """Train the policy using policy gradient."""
     optimizer = optim.Adam(policy.parameters(), lr=0.001)
     scheduler = torch.optim.lr_scheduler.StepLR(
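To see what the tuned rewards and the lower discount factor do to the learning signal, here is a small illustrative check; the four-move episode is hypothetical, and the reward values are the ones from get_reward above:

# Hypothetical episode: three valid moves (reward 0 each), then a win (reward 100).
rewards = [0, 0, 0, 100]
gamma = 0.75
G0 = sum(r * gamma**k for k, r in enumerate(rewards))
print(G0)  # 42.1875 -- with gamma=1.0 this would be 100.0, so early moves now receive less credit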
@@ -210,14 +209,15 @@ def train(policy, env, gamma=1.0, log_interval=1000):
         state = env.reset()
         done = False
 
+        if i_episode % log_interval == 0: print(select_action(policy, state)[0])
         while not done:
             action, logprob = select_action(policy, state)
             state, status, done = env.play_against_random(action)
             reward = get_reward(status)
             saved_logprobs.append(logprob)
             saved_rewards.append(reward)
 
-        if -50 in saved_rewards:
+        if -150 in saved_rewards:
             num_invalid_moves += 1
 
         R = compute_returns(saved_rewards)[0]
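select_action is called in the loop above but is not part of this diff. A minimal sketch of a sampling-based version, assuming the policy maps a board state (a numpy array) to a probability distribution over the 9 cells; the preprocessing and output shape are assumptions, not the repository's confirmed code:

import torch
from torch.distributions import Categorical

def select_action(policy, state):
    """Samples an action from the policy at the state."""
    state = torch.from_numpy(state).float().unsqueeze(0)  # assumes state is a numpy array
    probs = policy(state)                                  # assumes policy outputs probabilities over 9 moves
    dist = Categorical(probs)
    action = dist.sample()
    return action.item(), dist.log_prob(action)            # (move index, log-probability tensor)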

0 comments on commit d9b1545
