In [None]:
import os

import numpy as np
import torch
import torch.nn as nn

from RL import ActorCriticAgent
from single_agent_envs import SinglePlayerFootballParallel, ACTION_SPACE_SIZE, STATE_SPACE_SIZE
# Seed every random number generator used in this notebook with the same value
# so that a fresh "Restart & Run All" starts from the same RNG state.
SEED = 3407
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
class Actor(nn.Module):
    """Policy network mapping an observation to action probabilities.

    The network is a three-layer MLP (observation_size -> 64 -> 32 ->
    action_size) with LeakyReLU activations, followed by a softmax.

    Args:
        observation_size: Number of features in one observation vector.
        action_size: Number of entries in the output distribution.
    """

    def __init__(self, observation_size, action_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(observation_size, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, action_size),
            # Normalise over the last (action) axis. For the usual
            # (batch, observation_size) input this is identical to dim=1,
            # but it also works for a single un-batched observation and for
            # inputs with extra leading dimensions, where dim=1 would either
            # raise or normalise over the wrong axis.
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Return probabilities that sum to 1 along the last axis of ``x``'s output."""
        return self.model(x)


class Critic(nn.Module):
    """MLP that maps an observation vector to a single scalar output.

    Layer widths are observation_size -> 64 -> 32 -> 1, with a LeakyReLU
    after each hidden layer and no activation on the output.

    Args:
        observation_size: Number of features in one observation vector.
    """

    def __init__(self, observation_size):
        super().__init__()
        hidden_widths = (64, 32)
        layers = []
        in_features = observation_size
        # Layers are created in input-to-output order so that parameter
        # names (model.0, model.2, model.4) stay the same.
        for width in hidden_widths:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.LeakyReLU())
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the scalar head's output."""
        value = self.model(x)
        return value

In [None]:
# Number of environments; passed as `env_count` both to agent.create_model
# and to SinglePlayerFootballParallel in the training cell.
ENV_COUNT = 4
# Identifier of this training run; passed as the environments' `title` and
# embedded in the checkpoint and reward-log file names.
TRAIN_ID = f"AC_fixed_ball_scratch_norm_parallel_{ENV_COUNT}"

In [None]:
# Pick the compute device. The notebook was written for the second GPU
# ("cuda:1"); fall back to the first GPU or the CPU when that device does not
# exist so the cell does not fail on machines with fewer GPUs.
# NOTE(review): assumes ActorCriticAgent accepts any torch device string - confirm in RL.py.
if torch.cuda.is_available():
    DEVICE = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    DEVICE = "cpu"

agent = ActorCriticAgent(STATE_SPACE_SIZE, ACTION_SPACE_SIZE, device=DEVICE)
# Hyperparameters are forwarded unchanged to ActorCriticAgent.create_model;
# their exact semantics are defined in RL.py, which is not visible here.
agent.create_model(
    Actor,
    Critic,
    actor_lr=0.0003,
    critic_lr=0.0003,
    gamma=0.99,
    entropy_coef=0.0003,
    gae_lambda=0.9,
    env_count=ENV_COUNT,
    step_count=300,
    reward_norm_factor=300,
)

In [None]:
# Training loop: keep playing episodes while the environments report
# `running`, and checkpoint both networks whenever the mean of the last 100
# recorded rewards beats the best mean seen so far.

# Create the checkpoint folder up front so the first save cannot fail on a
# missing directory after a long training run.
os.makedirs("best_ac_models", exist_ok=True)

envs = SinglePlayerFootballParallel(env_count=ENV_COUNT, title=TRAIN_ID, random_ball=False)
# Only averages above this threshold are saved as a new best.
best_score = 200
try:
    while envs.running:
        states = envs.reset()
        done = False
        while not done:
            actions = agent.policy(states)
            n_states, rewards, dones = envs.step(actions)
            agent.learn(states, actions, n_states, rewards, dones)
            states = n_states
            # The round ends as soon as any of the environments reports done.
            done = any(dones)
        if agent.episode_counter % 100 == 0:
            # Mean over the most recent (up to) 100 recorded rewards.
            avg = np.mean(agent.reward_history[-100:])
            if avg > best_score:
                best_score = avg
                suffix = f"{TRAIN_ID}_{agent.episode_counter}_{round(best_score, 6)}.pt"
                for role, network in (("actor", agent.actor), ("critic", agent.critic)):
                    model_scripted = torch.jit.script(network)
                    model_scripted.save(f"best_ac_models/{role}_{suffix}")
finally:
    # Drop the environment reference even if training is interrupted or raises.
    # NOTE(review): presumably the object's finalizer releases its resources - confirm.
    del envs

In [None]:
# Persist the reward history of this run: one value per line, rounded to
# six decimal places, in a text file named after the run.
with open(f'{TRAIN_ID}_rewards.txt', 'w') as rewards_file:
    formatted = []
    for reward in agent.reward_history:
        formatted.append(f"{round(reward, 6)}\n")
    rewards_file.write("".join(formatted))

In [None]:
# Evaluation: switch the agent out of training mode and let it play a fixed
# number of episodes in a fresh environment.
EVAL_EPISODES = 10

agent.training = False
env = SinglePlayerFootballParallel(title=TRAIN_ID)
try:
    for episode in range(EVAL_EPISODES):
        state = env.reset()
        # env.loop_once() returning a truthy value ends the current episode.
        while not env.loop_once():
            # Only the next state is needed here; the other two step results
            # get their own names instead of reusing `_`, which the original
            # used for both the loop counter and the discarded values.
            state, _reward, _done = env.step(agent.policy(state))
finally:
    # Drop the environment reference even if evaluation is interrupted or raises.
    # NOTE(review): presumably the object's finalizer releases its resources - confirm.
    del env

In [None]:
# Replace the agent's networks with previously saved TorchScript checkpoints.
# NOTE(review): these file names belong to a different run id than TRAIN_ID - confirm this is intended.
# map_location remaps the stored tensors onto the agent's device while
# loading, so a checkpoint saved from a GPU that is not present on this
# machine can still be restored.
agent.actor = torch.jit.load("best_ac_models/actor_AC_fixed_ball_norm_transfer_it3_4300_171.18.pt", map_location=agent.device)
agent.critic = torch.jit.load("best_ac_models/critic_AC_fixed_ball_norm_transfer_it3_4300_171.18.pt", map_location=agent.device)
# Kept from the original cell; after map_location these moves are no-ops.
agent.actor.to(agent.device)
agent.critic.to(agent.device)

In [None]:
# Manually snapshot the current actor and critic as TorchScript files under
# best_ac_models/actor/ and best_ac_models/critic/.
for role, network in (("actor", agent.actor), ("critic", agent.critic)):
    # Make sure the target sub-directory exists; saving fails otherwise.
    os.makedirs(f"best_ac_models/{role}", exist_ok=True)
    model_scripted = torch.jit.script(network)
    model_scripted.save(f"best_ac_models/{role}/{TRAIN_ID}_{agent.episode_counter}.pt")