Commit d9b1545

best shit
Robert Bazzocchi committed Apr 2, 2018
1 parent c7cb1f7 commit d9b1545
Showing 1 changed file with 8 additions and 8 deletions.
tictactoe.py
@@ -170,7 +170,6 @@ def compute_returns(rewards, gamma=1.0):
 
 
     return G
 
 
-
 
 def finish_episode(saved_rewards, saved_logprobs, gamma=1.0):
     """Samples an action from the policy at the state."""
     policy_loss = []
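The first hunk ends inside compute_returns, whose body is not shown. For orientation, here is a minimal sketch of the discounted-return computation a function with this name and signature typically performs in REINFORCE-style code; the body below is an assumption, since only the signature and the return G line appear in the diff:

def compute_returns(rewards, gamma=1.0):
    # Hypothetical reconstruction: G[t] = rewards[t] + gamma * G[t+1],
    # i.e. the discounted return from each time step to the episode end.
    G = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        G[t] = running
    return G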
@@ -189,14 +188,14 @@ def finish_episode(saved_rewards, saved_logprobs, gamma=1.0):
 def get_reward(status):
     """Returns a numeric given an environment status."""
     return {
-        Environment.STATUS_VALID_MOVE : 25, # TODO
-        Environment.STATUS_INVALID_MOVE: -75,
-        Environment.STATUS_WIN : 100,
-        Environment.STATUS_TIE : 0,
-        Environment.STATUS_LOSE : -100
+        Environment.STATUS_VALID_MOVE : 0, # 0
+        Environment.STATUS_INVALID_MOVE: -150, # -75
+        Environment.STATUS_WIN : 100, # 100
+        Environment.STATUS_TIE : 0, # 0
+        Environment.STATUS_LOSE : -100 # 100
     }[status]
 
 
-def train(policy, env, gamma=1.0, log_interval=1000):
+def train(policy, env, gamma=0.75, log_interval=1000):
     """Train policy gradient."""
     optimizer = optim.Adam(policy.parameters(), lr=0.001)
     scheduler = torch.optim.lr_scheduler.StepLR(
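This hunk zeroes the per-valid-move reward (previously 25, marked TODO), doubles the invalid-move penalty from -75 to -150, leaves win/tie/lose at +100/0/-100, and drops the default discount from gamma=1.0 to 0.75. A worked example of what that does to credit assignment, assuming compute_returns behaves as sketched above (the episode itself is made up):

# A hypothetical 4-move win: three valid moves (reward 0 each), then a win (+100).
rewards = [0, 0, 0, 100]
# G_3 = 100
# G_2 = 0 + 0.75 * 100    = 75.0
# G_1 = 0 + 0.75 * 75.0   = 56.25
# G_0 = 0 + 0.75 * 56.25  = 42.1875
print(compute_returns(rewards, gamma=0.75))  # [42.1875, 56.25, 75.0, 100.0]

So with gamma=0.75 the opening move is credited with less than half of the final win reward, where gamma=1.0 would have propagated the full +100 to every move in the episode.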
@@ -210,14 +209,15 @@ def train(policy, env, gamma=1.0, log_interval=1000):
         state = env.reset()
         done = False
 
+        if i_episode % log_interval == 0: print(select_action(policy,state)[0])
         while not done:
             action, logprob = select_action(policy, state)
             state, status, done = env.play_against_random(action)
             reward = get_reward(status)
             saved_logprobs.append(logprob)
             saved_rewards.append(reward)
 
-        if -50 in saved_rewards:
+        if -150 in saved_rewards:
             num_invalid_moves += 1
 
         R = compute_returns(saved_rewards)[0]
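Note that the -150 sentinel checked here now matches the new STATUS_INVALID_MOVE reward; the old check looked for -50, which could never fire against the old -75 penalty, so num_invalid_moves was previously always zero.

finish_episode appears in the diff only down to policy_loss = []. Here is a hedged sketch of the standard REINFORCE update such a function usually performs, reusing the compute_returns sketch above; everything beyond the signature is an assumption, not this repository's code:

import torch

def finish_episode(saved_rewards, saved_logprobs, gamma=1.0):
    # Weight each step's log-probability by the discounted return from that
    # step, then sum into one scalar loss the caller can backpropagate.
    returns = compute_returns(saved_rewards, gamma)
    policy_loss = [-logprob * G for logprob, G in zip(saved_logprobs, returns)]
    return torch.stack(policy_loss).sum()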
