In [1]:
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
from RL import ActorCriticAgent
from single_agent_env import SinglePlayerFootball, ACTION_SPACE_SIZE, STATE_SPACE_SIZE
import matplotlib.pyplot as plt
# Seed both PyTorch's and NumPy's global RNGs with the same value so that
# weight initialisation and action sampling repeat from run to run.
SEED = 3407
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
class AC(nn.Module):
    """Actor-critic network built from two independent MLPs.

    The actor maps an observation to a probability distribution over the
    discrete actions; the critic maps the same observation to a single
    scalar value estimate. The two heads share no parameters.

    Args:
        observation_size: Number of features in one observation vector.
        action_size: Number of discrete actions the actor chooses between.
    """

    def __init__(self, observation_size, action_size):
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(observation_size, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.Linear(128, action_size),
            # Normalise over the action (last) axis. With dim=0 a batched
            # input of shape (1, observation_size) would be normalised over
            # the batch axis, producing all-ones "probabilities" that
            # Categorical then rescales into a uniform policy.
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(observation_size, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        """Sample one action for a single observation and estimate its value.

        Args:
            x: Observation tensor whose last dimension is ``observation_size``.
                Because the sampled action is converted with ``.item()``,
                ``x`` must describe exactly one observation (unbatched, or a
                batch of one).

        Returns:
            A tuple ``(action, log_prob, value)``: the sampled action as a
            Python int, the log-probability of that action under the current
            policy (tensor, keeps the autograd graph), and the critic's value
            estimate (tensor with a trailing dimension of size 1).
        """
        probs = self.actor(x)
        distribution = Categorical(probs)
        action = distribution.sample()

        value = self.critic(x)
        return action.item(), distribution.log_prob(action), value

In [3]:
# Choose the compute device. The second GPU ("cuda:1") stays the preferred
# target, but fall back to the first GPU or the CPU so the notebook still runs
# on machines that do not have two GPUs.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

env = SinglePlayerFootball(title="Actor critic train")
agent = ActorCriticAgent(STATE_SPACE_SIZE, ACTION_SPACE_SIZE, device=device)
# NOTE(review): lr is presumably the optimiser learning rate and y the
# discount factor — confirm against RL.ActorCriticAgent.create_model.
agent.create_model(AC, lr=0.00025, y=0.99)
scores = []  # one entry per training episode: the sum of that episode's rewards

# Training: keep playing episodes until the environment clears `running`.
while env.running:
    reward = []
    s = env.reset()
    # loop_once() returning a truthy value ends the current episode.
    while not env.loop_once():
        a = agent.policy(s)
        s, r, d = env.step(a)
        agent.learn(r, d)
        reward.append(r)
    scores.append(sum(reward))

# Evaluation: switch the environment's rendering flag on and play ten more
# episodes with the trained agent; agent.learn is not called here.
env.rendering = True

for _ in range(10):
    s = env.reset()
    while not env.loop_once():
        a = agent.policy(s)
        s, r, d = env.step(a)

# Drop the notebook's reference to the environment once evaluation is done.
del env

Episode: 1 | Train: 1 | loss: -34.762226
Episode: 2 | Train: 2 | loss: -40.457550
Episode: 3 | Train: 3 | loss: -50.895206
Episode: 4 | Train: 4 | loss: -58.939812
Episode: 5 | Train: 5 | loss: -38.102802
Episode: 6 | Train: 6 | loss: -71.852821
Episode: 7 | Train: 7 | loss: -81.591461
Episode: 8 | Train: 8 | loss: -88.550697
Episode: 9 | Train: 9 | loss: -81.474663
Episode: 10 | Train: 10 | loss: -92.824287
Episode: 11 | Train: 11 | loss: -92.915909
Episode: 12 | Train: 12 | loss: -102.130165
Episode: 13 | Train: 13 | loss: -98.827209
Episode: 14 | Train: 14 | loss: -79.566116
Episode: 15 | Train: 15 | loss: -48.017170
Episode: 16 | Train: 16 | loss: -110.282333
Episode: 17 | Train: 17 | loss: -94.828400
Episode: 18 | Train: 18 | loss: -107.803551
Episode: 19 | Train: 19 | loss: -79.173119
Episode: 20 | Train: 20 | loss: -54.453697
Episode: 21 | Train: 21 | loss: -86.694695
Episode: 22 | Train: 22 | loss: -104.047600
Episode: 23 | Train: 23 | loss: -81.793900
Episode: 24 | Train: 24 |

In [None]:
# Plot the summed reward of every training episode so the learning curve can
# be read without referring back to the training cell.
fig, ax = plt.subplots()
ax.plot(scores)
ax.set(title="Actor-critic training", xlabel="Episode", ylabel="Total reward")
plt.show()