# REINFORCE

## Import the Necessary Packages

In [None]:
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline

import torch
torch.manual_seed(0) # set random seed
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import world
from helpers import *

## Clear all previous player data (reset the `Players_Data` directory)

In [None]:
import os
import shutil

# Start every run with an empty Players_Data directory: wipe whatever a
# previous run left behind, then (re)create the folder.
if os.path.exists('Players_Data'):
    shutil.rmtree('Players_Data/')
os.mkdir('Players_Data')

## State representation of the agents

In [None]:
# Observation length fed to the policy: get_state pads each player's feature
# vector to STATE_SIZE - 1 (= 20) entries and then appends its energy (+1).
STATE_SIZE = 20 + 1
# Number of discrete actions the policy chooses between (width of the Agent's
# output layer).
ACTION_SIZE = 13

In [None]:
def pad_state(state, maxlen):
    """Force ``state`` to exactly ``maxlen`` entries.

    Inputs that already have ``maxlen`` entries are returned unchanged (the
    same object). Longer inputs are truncated with a slice. Shorter inputs are
    copied into a new zero-filled float array of length ``maxlen``.
    """
    length = len(state)
    if length == maxlen:
        return state
    if length > maxlen:
        return state[:maxlen]
    padded = np.zeros((maxlen,))
    padded[:length] = state
    return padded

In [None]:
def get_state_single(players, my_particles, killed, i):
    """Build the observation for the single player at index ``i``.

    For a living player, the features are its flattened food vector followed
    by its flattened neighbouring-player vector. A dead slot (an ``int`` entry
    in ``players``) gets a single 0 feature. Player ``i``'s own energy, or -100
    when it is dead, is then appended. Unlike ``get_state``, the features are
    not padded to a fixed length.

    Args:
        players: Player objects, with ``int`` entries marking dead slots.
        my_particles: Food particles passed to ``food_in_env``/``getFoodVector``.
        killed: Unused; kept so the signature mirrors ``get_state``.
        i: Index of the player to observe.

    Returns:
        np.ndarray of shape ``(1, n_features + 1)``.
    """
    player = players[i]
    if isinstance(player, int):
        features = np.array([0])
        energy = -100
    else:
        env_particles, _ = food_in_env(player, my_particles)
        food_vector = sum(getFoodVector(player, env_particles, my_particles), [])
        env_players, _ = players_in_env(player, players)
        neighbour_vector = sum(getPlayerVector(player, env_players, players), [])
        features = np.array(food_vector + neighbour_vector)
        # Energy of player ``i`` itself. The old comprehension rebound ``i``
        # to 0 and so read players[0] here.
        energy = player.energy
    return np.array([np.append(features, energy)])

In [None]:
def get_state(players, my_particles, killed):
    """Build the observation matrix for every player slot.

    For a living player, the features are its flattened food vector (from
    ``food_in_env``/``getFoodVector``) followed by its flattened neighbouring-
    player vector (from ``players_in_env``/``getPlayerVector``). A dead slot
    (an ``int`` entry in ``players``) gets a single 0. Every feature vector is
    then padded or truncated by ``pad_state`` to ``STATE_SIZE - 1`` entries,
    and the player's energy (or -100 for a dead slot) is appended.

    Args:
        players: Player objects, with ``int`` entries marking dead slots.
        my_particles: Food particles passed through to the feature functions.
        killed: Unused here; kept for a uniform call signature.

    Returns:
        np.ndarray with one row of length ``STATE_SIZE`` per entry of ``players``.
    """
    # Phase 1: raw, variable-length feature vectors for every slot.
    features = []
    for player in players:
        if type(player) is int:
            features.append(np.array([0]))
            continue
        env_particles, _ = food_in_env(player, my_particles)
        food_vector = sum(getFoodVector(player, env_particles, my_particles), [])
        env_players, _ = players_in_env(player, players)
        neighbour_vector = sum(getPlayerVector(player, env_players, players), [])
        features.append(np.array(food_vector + neighbour_vector))

    # Phase 2: fix the length, then append the energy column.
    rows = []
    for player, feature in zip(players, features):
        padded = pad_state(feature, STATE_SIZE - 1)
        rows.append(np.append(padded, -100 if type(player) is int else player.energy))
    return np.array(rows)

## Define the Architecture of the Policy

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Agent(nn.Module):
    """Two-layer MLP policy mapping an observation to action probabilities."""

    def __init__(self, s_size=STATE_SIZE, h_size=30, a_size=ACTION_SIZE):
        super().__init__()
        self.fc1 = nn.Linear(s_size, h_size)  # observation -> hidden layer
        self.fc2 = nn.Linear(h_size, a_size)  # hidden layer -> one score per action

    def forward(self, x):
        """Return softmax action probabilities (taken over dim 1) for batch ``x``."""
        hidden = F.relu(self.fc1(x))
        return F.softmax(self.fc2(hidden), dim=1)

    def act(self, state):
        """Sample an action for one observation.

        ``state`` is a NumPy array holding a single observation. It is wrapped
        into a batch of one before the forward pass.

        Returns:
            (action index as a Python int, log-probability tensor of that action)
        """
        batch = torch.from_numpy(state).float().unsqueeze(0).to(device)
        distribution = Categorical(self.forward(batch).cpu())
        sampled = distribution.sample()
        return sampled.item(), distribution.log_prob(sampled)

## Train the Agent with REINFORCE

In [None]:
def _spawn_offspring(parent, agents, optimizers, scores, saved_log_probs, rewards):
    """Append one offspring agent whose weights are copied from ``parent``.

    Also registers the offspring's Adam optimiser, a zero score, and empty
    log-prob/reward buffers under its new index ``len(agents) - 1``.
    """
    child = Agent().to(device)
    child.load_state_dict(parent.state_dict())
    new_idx = len(agents)
    agents.append(child)
    optimizers.append(optim.Adam(child.parameters(), lr=1e-2))
    scores.append(0)
    saved_log_probs[new_idx] = []
    rewards[new_idx] = []


def _update_policy(optimizer, log_probs, step_rewards, gamma):
    """Take one REINFORCE gradient step on a buffered batch, then empty it.

    ``log_probs[t]`` is the log-probability of the action taken at buffered
    step ``t``, and ``step_rewards[t]`` is the reward that action earned. Each
    log-probability is weighted by its discounted return-to-go
    ``G_t = sum_k gamma**k * r_{t+k}``.

    Both lists are cleared in place, so every autograd graph is
    back-propagated exactly once. No ``retain_graph`` is needed, and no stale
    graph ever refers to weights that ``optimizer.step()`` has since changed
    in place. Does nothing when the buffer is empty.
    """
    if not log_probs:
        return
    returns = []
    running = 0.0
    for reward in reversed(step_rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()

    policy_loss = torch.stack([-lp * g for lp, g in zip(log_probs, returns)]).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    log_probs.clear()
    step_rewards.clear()


def reinforce(n_episodes=1000, max_t=1000, gamma=1.0, print_every=100,
              update_every=10):
    """Train one REINFORCE policy per player until every player has been killed.

    On every tick, each living player samples an action from its own ``Agent``
    and applies it through ``world.take_action``. Successful reproduction
    (action 10 with reward 0, or action 11 with reward 4) spawns offspring
    whose networks start as copies of a parent's. A player's buffered
    transitions are turned into one policy-gradient step whenever the time
    step is a multiple of ``update_every``, and once more when the player
    dies. The buffer is then cleared.

    Args:
        n_episodes: Currently unused; kept for backward compatibility.
        max_t: Currently unused; the loop runs until ``len(killed) == len(players)``.
        gamma: Discount factor for the returns-to-go in the policy loss.
        print_every: Currently unused; kept for backward compatibility.
        update_every: Positive number of time steps between policy updates.

    Returns:
        list: Cumulative (undiscounted) reward per agent, indexed like the agents.
    """
    # No module-level import of ``random`` is visible in this notebook.
    import random

    agents = [Agent().to(device) for _ in range(world.INITIAL_POPULATION)]
    optimizers = [optim.Adam(agent.parameters(), lr=1e-2) for agent in agents]

    players, killed, my_particles = world.init()
    states = get_state(players, my_particles, killed)

    scores = [0 for _ in range(len(players))]
    saved_log_probs = {i: [] for i in range(len(players))}
    rewards = {i: [] for i in range(len(players))}

    time_step = -1
    while len(killed) != len(players):
        time_step += 1
        # Walk agents by index and re-check len(agents) on each pass:
        # offspring appended during this tick also act in this tick.
        i = 0
        while i < len(agents):
            if not isinstance(players[i], int):
                action, log_prob = agents[i].act(states[i])
                saved_log_probs[i].append(log_prob)
                (reward, _done, players, my_particles, killed, mate_idx,
                 time_step) = world.take_action(players, my_particles, killed,
                                                i, action, time_step)
                # Exactly one reward per saved log-prob keeps both buffers aligned.
                rewards[i].append(reward)
                scores[i] += reward

                if action == 10 and reward == 0:
                    print("Asexual reproduction")
                    for _ in range(len(players) - len(agents)):
                        _spawn_offspring(agents[i], agents, optimizers, scores,
                                         saved_log_probs, rewards)
                elif action == 11 and reward == 4:
                    print("Sexual reproduction")
                    dominant_percent = random.randint(0, 10) * 10
                    offsprings = len(players) - len(agents)
                    num_dominant = round(offsprings * (dominant_percent / 100))
                    num_recessive = offsprings - num_dominant
                    # _spawn_offspring indexes each child by the current
                    # len(agents), so both groups get consecutive, correct keys.
                    for _ in range(num_dominant):
                        _spawn_offspring(agents[i], agents, optimizers, scores,
                                         saved_log_probs, rewards)
                    # NOTE(review): assumes agents[mate_idx] is still a live
                    # Agent, not the 0 placeholder left for dead players.
                    for _ in range(num_recessive):
                        _spawn_offspring(agents[mate_idx], agents, optimizers,
                                         scores, saved_log_probs, rewards)

                died = isinstance(players[i], int)
                # Periodic update, plus a final flush on death so the dying
                # player's last transitions are not thrown away.
                if died or time_step % update_every == 0:
                    _update_policy(optimizers[i], saved_log_probs[i],
                                   rewards[i], gamma)
                if died:
                    agents[i] = 0

                states = get_state(players, my_particles, killed)
            i += 1

    print(killed)
    return scores
    
# Train with the default hyper-parameters, spelled out for visibility.
scores = reinforce(n_episodes=1000, max_t=1000, gamma=1.0, print_every=100)

## Plot the Scores

In [None]:
# Assumes ``scores`` is the per-agent list built by reinforce: scores[i] is the
# cumulative reward of agent i. The x-axis therefore indexes agents (numbered
# from 1), not episodes.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(scores) + 1), scores)
ax.set_ylabel('Score')
ax.set_xlabel('Agent #')
plt.show()