In [None]:
# Notebook setup: re-import edited project modules automatically before each
# cell runs (`%autoreload 2`) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import torch
import sys, os
import pystk
import ray
# Prefer the GPU when PyTorch can see one.
# NOTE(review): `device` is not referenced in the visible cells — confirm it is still needed.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device = ', device)
# Start a local Ray runtime (later cells create Ray actors, e.g. `Rollout.remote`).
# logging_level=50 equals logging.CRITICAL, i.e. only critical Ray log records are shown.
ray.init(logging_level=50)

In [None]:
from state_agent.agents.subnets.actors import SteeringActor, DriftActor, SpeedActor
from state_agent.agents.subnets.planners import PlayerPuckGoalPlannerActor
from state_agent.agents.subnets.agents import Agent, BaseTeam
from state_agent.agents.subnets.utils import Rollout, run_soccer_agent, rollout_many, show_trajectory_histogram, load_model, save_model
from state_agent.agents.subnets.rewards import SoccerBallDistanceObjective
from state_agent.agents.subnets.features import get_distance_cart_to_puck
from state_agent.trainers.train_policy_gradient import reinforce, SoccerReinforcementConfiguration

import numpy as np
import copy
import matplotlib.pyplot as plt

In [None]:
# Smoke test: run one soccer episode with a freshly constructed (untrained) steering actor.
# NOTE(review): `train=True` is passed to Agent here, whereas later cells pass `train=` to
# the actor (e.g. SteeringActor(..., train=False)) — confirm which object should receive it.
data = run_soccer_agent(Agent(SteeringActor(), train=True))

In [None]:
def get_initializations(actor_class, n_actors=100, n_steps=600, accel=0.05, objective=None):
    """Sample many freshly constructed actors and return the best and the worst one.

    Each of ``n_actors`` actors is created with ``actor_class()``, wrapped in an
    ``Agent`` with the given ``accel`` and rolled out (``randomize=True``) for
    ``n_steps`` steps via ``rollout_many``. The last state of every trajectory
    is scored with ``objective.calculate_state_score``.

    Args:
        actor_class: zero-argument callable returning a new actor (a class such
            as ``SteeringActor`` or a factory function).
        n_actors: number of candidate actors to sample; must be >= 1.
        n_steps: rollout length forwarded to ``rollout_many``.
        accel: ``accel`` value given to every ``Agent``.
        objective: object exposing ``calculate_state_score(state)``; defaults to
            ``SoccerBallDistanceObjective(150)``.

    Returns:
        ``(best_actor, worst_actor)``: the actors whose final-state score is the
        highest and the lowest (the first one wins on ties).

    Raises:
        ValueError: if ``n_actors`` is less than 1.
    """
    if n_actors < 1:
        raise ValueError(f"n_actors must be >= 1, got {n_actors}")
    if objective is None:
        objective = SoccerBallDistanceObjective(150)

    candidates = [actor_class() for _ in range(n_actors)]
    trajectories = rollout_many(
        [Agent(actor, accel=accel) for actor in candidates],
        randomize=True, n_steps=n_steps,
    )

    # Score each trajectory once and reuse the scores for both argmax and argmin.
    final_scores = [objective.calculate_state_score(trajectory[-1]) for trajectory in trajectories]
    best_actor = candidates[int(np.argmax(final_scores))]
    worst_actor = candidates[int(np.argmin(final_scores))]
    return best_actor, worst_actor

# Keep the best of the sampled steering actors as the starting point for training;
# the worst one is discarded.
good_initialization, _ = get_initializations(SteeringActor)

In [None]:
# Train on a deep copy of the best sampled network so the sampled actor itself is not modified.
# To resume from an earlier run instead, deep-copy `best_steering_net` directly (it is a network,
# presumably without an `.action_net` attribute, so it cannot replace `good_initialization`).
action_net = copy.deepcopy(good_initialization.action_net)
actors = [SteeringActor(action_net)]

def gen_agent(*args, **kwargs):
    """Agent factory used while training the steering actor.

    Forwards every positional and keyword argument to ``Agent`` and adds a
    fixed ``accel=0.05`` and ``target_speed=10.0``.
    """
    fixed_settings = dict(accel=0.05, target_speed=10.0)
    return Agent(*args, **fixed_settings, **kwargs)

# configuration: REINFORCE settings; agents are built with the steering `gen_agent`
config = SoccerReinforcementConfiguration()
config.agent = gen_agent

# Train the steering actor (the single entry of `actors`) with REINFORCE, starting from the
# best sampled initialization. The relatively high iteration count leaves room for training to
# recover if that initialization is still poor. The returned network is kept as `best_steering_net`.
best_steering_net = reinforce(actors[0], actors, config, 
                              n_epochs=5, n_iterations=500, n_trajectories=200, n_validations=100, T=1
                    )

In [None]:
# Visual check of the trained steering actor from a fixed ball / player start.
# NOTE(review): randomize=True is passed together with explicit locations — confirm which takes precedence.
data = run_soccer_agent(Agent(SteeringActor(best_steering_net), accel=0.1), randomize=True, ball_location=[-6., -60.], player_location=[-20, 0, -50])

In [None]:
# Save the trained steering network so later cells or sessions can reload it.
save_model(best_steering_net, 'modules/steering/agent.th')

In [None]:
# Reload the steering network from disk. A fresh SteeringActor's `action_net` is passed as
# `model` (presumably the template the stored weights are loaded into — confirm in load_model).
best_steering_net = load_model('modules/steering/agent.th', model=SteeringActor().action_net)

In [None]:
# Train the speed actor.
# Start from a deep copy of the best sampled speed actor's network. The trained steering network
# is included with train=False (presumably frozen), and the speed actor is the one passed to
# `reinforce` below as the actor to train.
good_initialization_speed, _ = get_initializations(SpeedActor)

action_net = copy.deepcopy(good_initialization_speed.action_net)
actors = [SteeringActor(best_steering_net, train=False), SpeedActor(action_net)]

def gen_agent(*args, **kwargs):
    """Agent factory with a randomly sampled target speed, used to train the speed actor.

    The speed magnitude is drawn from N(10, 5) and folded to be non-negative. With
    probability 0.1 the sign is flipped so that the agent is asked to drive in reverse.
    All other arguments are forwarded to ``Agent``.
    """
    reverse = np.random.uniform(0, 1) < 0.1
    # abs() makes `reverse` the only thing deciding direction: a raw N(10, 5) draw is
    # negative about 2% of the time, which would turn "forward" samples into reverse targets.
    speed = abs(np.random.normal(10, 5)) * (-1.0 if reverse else 1.0)
    return Agent(*args, target_speed=speed, **kwargs)

# configuration: REINFORCE settings; agents are built with the speed-training `gen_agent`
config = SoccerReinforcementConfiguration()
config.agent = gen_agent

# Train the speed actor (actors[1]); actors[0] is the trained steering actor.
# The relatively high iteration count leaves room for training to recover from a poor initialization.
best_speed_net = reinforce(actors[1], actors, config, 
                              n_epochs=5, n_iterations=500, n_trajectories=200, n_validations=100, T=1
                    )

In [None]:
# Visual check: trained steering + speed actors driving with a reverse target speed (-5.0).
data = run_soccer_agent(Agent(SteeringActor(best_steering_net), SpeedActor(best_speed_net), target_speed=-5.0), randomize=True)

In [None]:
# Save the trained speed network so later cells or sessions can reload it.
save_model(best_speed_net, 'modules/speed/agent.th')

In [None]:
# Reload the speed network from disk. A fresh SpeedActor's `action_net` is passed as `model`
# (presumably the template the stored weights are loaded into — confirm in load_model).
best_speed_net = load_model('modules/speed/agent.th', model=SpeedActor().action_net)

In [None]:
# train the player goal scoring planner

def create_planner_actor():
    """Build a new ``PlayerPuckGoalPlannerActor`` on top of the trained sub-networks.

    Wraps the notebook-level ``best_speed_net`` and ``best_steering_net`` in new
    ``SpeedActor`` / ``SteeringActor`` instances (speed first, then steering) and
    passes them positionally to the planner; no third argument is given.
    """
    speed_actor = SpeedActor(best_speed_net)
    steering_actor = SteeringActor(best_steering_net)
    return PlayerPuckGoalPlannerActor(speed_actor, steering_actor)

def gen_agent(*args, **kwargs):
    """Agent factory used while training the planner.

    Forwards every positional and keyword argument to ``Agent`` and adds a
    fixed ``accel=0.1``.
    """
    agent_settings = {'accel': 0.1}
    return Agent(*args, **agent_settings, **kwargs)

def rollout_initializer(world_info, randomize, wall_probability=0.0, **kwargs):
    """Place the puck and kart 0 close to each other at the start of a rollout.

    The puck is put at a random point whose first and third coordinates are drawn
    uniformly from [-10, 10] (second coordinate 1), with ``(0, 0, 0)`` as the second
    argument of ``set_ball_location`` (presumably velocity). Kart 0 is placed next to
    it: first coordinate jittered by up to +/-0.2 and third coordinate offset by -6.
    With probability ``wall_probability`` the kart is instead put at the fixed
    location ``[20, 1, 62]``.

    Args:
        world_info: object exposing ``set_ball_location(location, second_arg)`` and
            ``set_kart_location(index, location, third_arg, fourth_arg)``; called here
            as ``set_kart_location(0, location, [0, 0, 0, 1.0], 0)``.
        randomize: accepted for compatibility with the ``config.rollout_initializer``
            callback signature; not used.
        wall_probability: chance of using the fixed "wall" kart location. The default
            of 0.0 disables it, and no random number is drawn for it in that case.
        **kwargs: ignored.
    """
    # Only draw for the wall case when it is enabled, so the default leaves the
    # global NumPy RNG stream unchanged.
    wall_case = wall_probability > 0 and np.random.uniform(0, 1.0) < wall_probability

    # Puck position; the kart placement below is relative to it, so it must be drawn here.
    position = np.random.uniform(low=-10, high=10, size=2)
    offset = [np.random.uniform(-0.2, 0.2), -6]
    world_info.set_ball_location((position[0], 1, position[1]), (0, 0, 0))

    if wall_case:
        player_location = [20, 1, 62]
    else:
        player_location = [position[0] + offset[0], 1, position[1] + offset[1]]
    world_info.set_kart_location(0, player_location, [0, 0, 0, 1.0], 0)
        
def post_epoch(actor, context):
    """Per-epoch diagnostics hook, passed to ``reinforce`` as ``epoch_post_process``.

    Shows the distribution of kart-to-puck distances over the epoch's trajectories,
    then histograms of the collected rewards and actions, and finally prints how
    often actions 0, 1 and 2 were taken (space-separated).

    Args:
        actor: the actor being trained (unused here).
        context: object exposing ``trajectories``, ``rewards`` and ``actions``.
    """
    show_trajectory_histogram(context.trajectories, get_distance_cart_to_puck, max=60, bins=20)

    # (title, values, extra positional args, keyword args) for each plt.hist panel.
    panels = (
        ("Rewards", context.rewards, (), {}),
        ("Actions", context.actions, (4,), {"range": (0, 4)}),
    )
    for title, values, hist_args, hist_kwargs in panels:
        plt.hist(values, *hist_args, **hist_kwargs)
        plt.title(title)
        plt.show()

    taken = np.array(context.actions)
    print(*(np.sum(taken == action_id) for action_id in range(3)))

good_initialization_planner, _ = get_initializations(create_planner_actor)

# Start the planner from a deep copy of the best sampled initialization. This assignment is
# required: without it `action_net` is a leftover from an earlier cell (the speed actor's network
# when the cells run in order), so the planner would be built on, and trained into, the wrong network.
action_net = copy.deepcopy(good_initialization_planner.action_net)
actors = [PlayerPuckGoalPlannerActor(SpeedActor(best_speed_net), SteeringActor(best_steering_net), action_net)]

# Snapshot the first-layer weights (on CPU) so we can report afterwards whether training moved them.
# torch.equal is used instead of `!=`, whose elementwise result is ambiguous as a boolean.
starting_weight = action_net.net[0].weight.detach().cpu().clone()
print("Starting weight", starting_weight)

# configuration
config = SoccerReinforcementConfiguration()
config.agent = gen_agent
config.rollout_initializer = rollout_initializer

# Train the planner (the only entry of `actors`) with REINFORCE, plotting diagnostics after each
# epoch via `post_epoch`. The high iteration count leaves room to recover from a poor initialization.
best_planner_net = reinforce(actors[0], actors, config,
                             n_epochs=4, n_iterations=1000, n_trajectories=200, n_validations=20, T=1,
                             epoch_post_process=post_epoch)
print(best_planner_net.net[0].weight)
print(action_net.net[0].weight)
print("First-layer weights changed during training:",
      not torch.equal(starting_weight, action_net.net[0].weight.detach().cpu()))

In [None]:
# Visual check of the trained planner. Use `best_planner_net` exactly as returned by `reinforce`
# (as the steering and speed cells do). Do not overwrite it with `action_net`, which would discard
# the network `reinforce` returned in favour of the live training network.
data = run_soccer_agent(Agent(
    PlayerPuckGoalPlannerActor(
        speed_net=SpeedActor(best_speed_net),
        steering_net=SteeringActor(best_steering_net),
        action_net=best_planner_net,
    ), accel=0.1
), randomize=True)

In [None]:
# Save the trained planner network so later cells or sessions can reload it.
save_model(best_planner_net, 'modules/planner/agent.th')

In [None]:
# Load the planner network from disk. The template planner is built the same way as in the other
# cells (sub-networks wrapped in SpeedActor / SteeringActor) rather than from the raw networks, so
# its `action_net` is constructed exactly as during training. It is presumably the template that
# load_model fills with the stored weights.
best_planner_net = load_model('modules/planner/agent.th', model=PlayerPuckGoalPlannerActor(SpeedActor(best_speed_net), SteeringActor(best_steering_net)).action_net)

In [None]:
# Render one soccer episode with the trained planner on a dedicated Ray `Rollout` actor
# (400, 300 are presumably the render width/height) with a single kart, "tux".
# NOTE(review): no `accel` is given to Agent here, unlike the planner training/check cells (accel=0.1).
viz_rollout_soccer = Rollout.remote(400, 300, mode="soccer", players=[(0, False, "tux")], num_karts=1)
data = run_soccer_agent(Agent( 
    PlayerPuckGoalPlannerActor(
        speed_net=SpeedActor(best_speed_net),
        steering_net=SteeringActor(best_steering_net), 
        action_net=best_planner_net        
    )
), rollout=viz_rollout_soccer)