In [2]:
import gym
import torch
env = gym.make('CartPole-v0')
num_state = env.observation_space.shape[0]
num_action = env.action_space.n

#for each episode, take environment, weight
#use
def run_episode(env, weight):
    state = env.reset()
    grads = []
    total_reward = 0
    is_done = False
    while not is_done:
        state = torch.from_numpy(state).float()
        z = torch.matmul(state, weight)
        probs = torch.nn.Softmax()(z)
        action = int(torch.bernoulli(probs[1]).item())
        d_softmax = torch.diag(probs) -probs.view(-1, 1) * probs
        d_log = d_softmax[action] / probs[action]
        grad = state.view(-1, 1) * d_log
        grads.append(grad)
        state, reward, is_done, _ = env.step(action)
        total_reward += reward
        if is_done:
             break
    return total_reward, grads

num_episode = 1000
weight = torch.rand(num_state, num_action)
total_rewards = []
learning = 0.001
for episode in range(num_episode):
    total_reward, gradients = run_episode(env, weight)
    print('Episode {}: {}'.format(episode + 1, total_reward))
    for i, gradient in enumerate(gradients):
        weight += learning * gradient * (total_reward - i)
    total_rewards.append(total_reward)
    
import matplotlib.pyplot as plt
plt.plot(total_rewards)
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.show()



Episode 1: 30.0
Episode 2: 10.0
Episode 3: 13.0
Episode 4: 21.0
Episode 5: 11.0
Episode 6: 20.0
Episode 7: 18.0
Episode 8: 18.0
Episode 9: 10.0
Episode 10: 11.0
Episode 11: 11.0
Episode 12: 9.0
Episode 13: 9.0
Episode 14: 27.0
Episode 15: 9.0
Episode 16: 19.0
Episode 17: 10.0
Episode 18: 11.0
Episode 19: 9.0
Episode 20: 16.0
Episode 21: 16.0
Episode 22: 12.0
Episode 23: 20.0
Episode 24: 18.0
Episode 25: 15.0
Episode 26: 11.0
Episode 27: 15.0
Episode 28: 11.0
Episode 29: 10.0
Episode 30: 16.0
Episode 31: 23.0
Episode 32: 13.0
Episode 33: 28.0
Episode 34: 18.0
Episode 35: 15.0
Episode 36: 12.0
Episode 37: 12.0
Episode 38: 22.0
Episode 39: 21.0
Episode 40: 15.0
Episode 41: 16.0
Episode 42: 23.0
Episode 43: 11.0
Episode 44: 10.0
Episode 45: 14.0
Episode 46: 18.0
Episode 47: 12.0
Episode 48: 10.0
Episode 49: 18.0
Episode 50: 36.0
Episode 51: 29.0
Episode 52: 14.0
Episode 53: 17.0
Episode 54: 51.0
Episode 55: 18.0
Episode 56: 25.0
Episode 57: 24.0
Episode 58: 30.0
Episode 59: 24.0
Episode 60

Episode 447: 200.0
Episode 448: 200.0
Episode 449: 200.0
Episode 450: 200.0
Episode 451: 200.0
Episode 452: 200.0
Episode 453: 200.0
Episode 454: 199.0
Episode 455: 200.0
Episode 456: 200.0
Episode 457: 200.0
Episode 458: 200.0
Episode 459: 200.0
Episode 460: 200.0
Episode 461: 200.0
Episode 462: 200.0
Episode 463: 200.0
Episode 464: 200.0
Episode 465: 200.0
Episode 466: 200.0
Episode 467: 200.0
Episode 468: 200.0
Episode 469: 200.0
Episode 470: 200.0
Episode 471: 200.0
Episode 472: 200.0
Episode 473: 200.0
Episode 474: 200.0
Episode 475: 200.0
Episode 476: 200.0
Episode 477: 200.0
Episode 478: 200.0
Episode 479: 80.0
Episode 480: 123.0
Episode 481: 200.0
Episode 482: 45.0
Episode 483: 68.0
Episode 484: 138.0
Episode 485: 200.0
Episode 486: 200.0
Episode 487: 200.0
Episode 488: 200.0
Episode 489: 200.0
Episode 490: 200.0
Episode 491: 200.0
Episode 492: 200.0
Episode 493: 200.0
Episode 494: 200.0
Episode 495: 200.0
Episode 496: 200.0
Episode 497: 200.0
Episode 498: 200.0
Episode 499: 20

Episode 881: 200.0
Episode 882: 200.0
Episode 883: 200.0
Episode 884: 200.0
Episode 885: 200.0
Episode 886: 200.0
Episode 887: 200.0
Episode 888: 200.0
Episode 889: 200.0
Episode 890: 200.0
Episode 891: 200.0
Episode 892: 92.0
Episode 893: 200.0
Episode 894: 200.0
Episode 895: 200.0
Episode 896: 200.0
Episode 897: 194.0
Episode 898: 68.0
Episode 899: 154.0
Episode 900: 170.0
Episode 901: 200.0
Episode 902: 200.0
Episode 903: 200.0
Episode 904: 200.0
Episode 905: 199.0
Episode 906: 200.0
Episode 907: 200.0
Episode 908: 200.0
Episode 909: 114.0
Episode 910: 200.0
Episode 911: 200.0
Episode 912: 200.0
Episode 913: 200.0
Episode 914: 200.0
Episode 915: 200.0
Episode 916: 200.0
Episode 917: 200.0
Episode 918: 140.0
Episode 919: 200.0
Episode 920: 158.0
Episode 921: 200.0
Episode 922: 80.0
Episode 923: 184.0
Episode 924: 200.0
Episode 925: 123.0
Episode 926: 177.0
Episode 927: 200.0
Episode 928: 200.0
Episode 929: 200.0
Episode 930: 200.0
Episode 931: 200.0
Episode 932: 200.0
Episode 933: 20

<Figure size 640x480 with 1 Axes>

In [32]:
print('Average total reward over {} episode: {}'.format(num_episode, sum(total_rewards) / num_episode))

Average total reward over 1000 episode: 171.818


In [40]:
num_episode_eval = 100
total_rewards_eval = []
for episode in range(num_episode_eval):
     total_reward, _ = run_episode(env, weight)
print('Episode {}: {}'.format(episode+1, total_reward))
total_rewards_eval.append(total_reward)

  


Episode 100: 200.0


<Figure size 640x480 with 1 Axes>