# Collaboration and Competition

---

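This notebook uses Multi-Agent Deep Deterministic Policy Gradients (MADDPG) to solve the Unity ML-Agents Tennis environment. Each episode's score is the maximum of the two agents' total rewards. The environment counts as solved once the average of that score over 100 consecutive episodes reaches +0.5.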

## I. Preparation

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from collections import deque
from unityagents import UnityEnvironment
from maddpg_agent import MADDPGAgents

In [None]:
def train_maddpg(n_episodes=2000, max_t=2000, print_every=100, rolling_window=100,
                 output_actor='checkpoint_actor.pth', output_critic='checkpoint_critic.pth',
                 scores_pass=0.5, unity_env=None, maddpg_agents=None):
    """Train the MADDPG agents and checkpoint the best-performing policy.

    Each episode's score is the max over the agents' summed rewards. The task
    counts as solved once the rolling mean of these scores over a *full* window
    of ``rolling_window`` episodes reaches ``scores_pass``. From then on, the
    checkpoints are rewritten whenever the rolling mean matches or beats the
    best rolling mean seen so far.

    Args:
        n_episodes: Maximum number of training episodes.
        max_t: Maximum number of environment steps per episode.
        print_every: Every N episodes, end the in-place progress line with a newline.
        rolling_window: Number of recent episodes averaged for the solve check.
        output_actor: Path for agent 0's local actor weights.
        output_critic: Path for agent 0's local critic weights.
        scores_pass: Rolling-average threshold that counts as solved.
        unity_env: Environment to train in; defaults to the notebook-global ``env``.
        maddpg_agents: Agents to train; defaults to the notebook-global ``agents``.

    Returns:
        List of per-episode scores (max over agents), one entry per episode.
    """
    # Fall back to the notebook globals so existing calls keep working.
    unity_env = env if unity_env is None else unity_env
    maddpg_agents = agents if maddpg_agents is None else maddpg_agents

    brain_name = unity_env.brain_names[0]
    scores_deque = deque(maxlen=rolling_window)
    scores = []
    best_avg = scores_pass
    solved = False
    for i_episode in range(1, n_episodes + 1):
        env_info = unity_env.reset(train_mode=True)[brain_name]
        states = env_info.vector_observations
        maddpg_agents.reset()
        score = np.zeros(maddpg_agents.n_agents)
        for _ in range(max_t):
            actions = maddpg_agents.act(states)
            env_info = unity_env.step(actions)[brain_name]
            next_states = env_info.vector_observations
            rewards = env_info.rewards
            dones = env_info.local_done
            maddpg_agents.step(states, actions, rewards, next_states, dones)
            states = next_states
            score += rewards
            if np.any(dones):
                break
        episode_score = np.max(score)
        scores_deque.append(episode_score)
        scores.append(episode_score)
        rolling_avg = np.mean(scores_deque)

        # Print results
        print('\rEpisode {} - Rolling Avg. Score (Max): {:.2f}'.format(i_episode, rolling_avg), end="")
        if i_episode % print_every == 0:
            print('')

        # A rolling average over fewer than `rolling_window` episodes is not a
        # valid solve, so wait until the window is full before judging/saving.
        window_full = len(scores_deque) == rolling_window
        if window_full and rolling_avg >= best_avg:
            if not solved:
                print('\n * Environment first solved in {:d} episodes! Continue training...'.format(i_episode))
                solved = True
            best_avg = rolling_avg
            # Save only agent 0: both agents are trying to learn the same strategy.
            torch.save(maddpg_agents.agents[0].actor_local.state_dict(), output_actor)
            torch.save(maddpg_agents.agents[0].critic_local.state_dict(), output_critic)
    return scores

def plot_score(scores, rolling_window=100):
    """Plot per-episode scores together with their rolling mean.

    Args:
        scores: Sequence of per-episode scores (e.g. the list returned by ``train_maddpg``).
        rolling_window: Moving-average window length. Because ``min_periods=1``,
            the first ``rolling_window - 1`` points average over the episodes seen so far.
    """
    # Name the column after the actual window so the legend stays correct
    # when rolling_window != 100 (default still yields 'MovingAvg100').
    avg_col = 'MovingAvg{}'.format(rolling_window)
    df_scores = pd.DataFrame(scores, columns=['EpisodeScore'])
    df_scores[avg_col] = df_scores['EpisodeScore'].rolling(rolling_window, min_periods=1).mean()
    fig, ax = plt.subplots()
    df_scores.plot(ax=ax, color=['grey', 'red'])
    ax.set(title='MADDPG training scores', xlabel='Episode #', ylabel='Score')
    plt.show()

## II. Environment

In [None]:
# Path to the Unity Tennis build; adjust for your platform.
# Raw string: in a plain literal "\T" is an invalid escape sequence, which Python
# warns about (DeprecationWarning, SyntaxWarning since 3.12).
TENNIS_ENV_PATH = r"Tennis_Windows\Tennis.exe"
env = UnityEnvironment(file_name=TENNIS_ENV_PATH)
brain_name = env.brain_names[0]
brain = env.brains[brain_name]
env_info = env.reset(train_mode=True)[brain_name]

## III. Train (or Load) the Agents

In [None]:
# True: train from scratch (and plot scores); False: load the saved checkpoints instead.
train_model: bool = True

In [None]:
# Environment info
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
output_actor = 'checkpoint_actor.pth'
output_critic = 'checkpoint_critic.pth'
n_agents = len(env_info.agents)
action_size = brain.vector_action_space_size
state_size = env_info.vector_observations.shape[1]

# Agents
agents = MADDPGAgents(state_size, action_size, n_agents)
if train_model:
    scores = train_maddpg(output_actor=output_actor, output_critic=output_critic)
    plot_score(scores)
else:
    # Training saves a single actor/critic pair (from agent 0); every agent gets those weights.
    # Read each file once, and pass map_location so a checkpoint saved on a GPU
    # can still be loaded on a CPU-only machine.
    actor_weights = torch.load(output_actor, map_location=device)
    critic_weights = torch.load(output_critic, map_location=device)
    for i in range(n_agents):
        agents.agents[i].actor_local.load_state_dict(actor_weights)
        agents.agents[i].critic_local.load_state_dict(critic_weights)

## IV. Watch the Trained Agents

In [None]:
# Watch the trained agents play one episode without exploration noise.
# A separate name is used so the training `scores` list from section III is not
# overwritten (re-plotting it after this cell would otherwise plot this episode).
env_info = env.reset(train_mode=False)[brain_name]
states = env_info.vector_observations
episode_scores = np.zeros(n_agents)
while True:
    actions = agents.act(states, add_noise=False)
    env_info = env.step(actions)[brain_name]
    next_states = env_info.vector_observations
    rewards = env_info.rewards
    dones = env_info.local_done
    episode_scores += rewards
    states = next_states
    if np.any(dones):
        break
print('Scores of this episode: {}'.format(episode_scores))
# The environment is left open so this cell can be re-run; call env.close() when finished.