In [1]:
#!/usr/bin/python
# -*- coding: utf-8 -*-

import sys
import os
import math
import functools

from random import randint, random, choice
from collections import namedtuple, deque
from copy import copy, deepcopy

from time import sleep, time
import tkinter as tk
import numpy as np
import datetime
import pickle
import glob

from thread_manager import spawnthread
from couple import Couple
from form import Circle

from map_menu_struct import *
from fps_manager import fps, fps_manager

from renforcement_learning_neural_network import Renforcement_learning_neural_network

import matplotlib
# While GTK isn't avail everywhere, we use TkAgg backend to generate png
if sys.platform != "win32" and os.getenv("DISPLAY") is None :
    matplotlib.use("Agg")
else :
    matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# TODO display a graph of rewards per minutes
# TODO button to remove leroy jenkins
# TODO stats actions, score, plot

SHIPS_NUMBER = 7
SHIPS_SPEED = 8
LASER_SPEED = 10
SHOOTING_ANGLES = 4
MAP_PRECISION = 4
NETWORK = True
PLAYER_FOLLOWS_NETWORK_RULES = True
MAX_TIME = 300
ANTICIPATION = 100
LEROY_RATE = 0.01

RECORD_FOLDER = "ofighter_records"
NETWORKS_FOLDER = "ofighter_networks"


# TODO load network that match required layers/token/unique string

# TODO sleep less each frame is compute time is high (but set minimum sleep)
# TODO add laser sight to player ship
# TODO statistic object containing all kills, etc..
# TODO different type of ships
# TODO different type of thrusters
# TODO evolving ships design
# TODO evolving ships ia
# TODO scripted evolving ships ia
# TODO add a button to see what the neural network see
# TODO increase the max processor usage (multitheading ?)
# TODO use JSON instead of pickle for saving records ?
# TODO separate IA trainer from Tkinter graphics manager
# TODO separate files by class


def now():
    return datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")


class Observation():
    # TODO improve flexibility of actions
    # maps are squares
    # TODO assert must be < to pixel dimensions
    ship_map_side_size = MAP_PRECISION
    laser_map_side_size = MAP_PRECISION
    player_map_side_size = MAP_PRECISION

    observations = [
        "can_shot",
        "ship", 
        "laser", 
        "player",
    ]
    observations_size = {
        "can_shot" : 1, 
        # maps are squares
        "ship" : ship_map_side_size**2, 
        "laser" : laser_map_side_size**2,
        "player" : player_map_side_size**2,
    }

    size = 0
    for label in observations:
        size += observations_size[label]
    print("size observations", size)

    def __init__(self, battleground=None):
        if battleground:
            self.analyse_battleground(battleground)

    def analyse_battleground(self, battleground):
        self.battleground = battleground
        self.ship_map = np.zeros((Observation.ship_map_side_size, Observation.ship_map_side_size))
        self._read_ship_map()
        self.laser_map = np.zeros((Observation.laser_map_side_size, Observation.laser_map_side_size))
        self._read_laser_map()
    
    def _read_ship_map(self):
        # should be the same for y if it wasn't squared
        cellSize = self.battleground.dim.x // Observation.ship_map_side_size
        for ship in self.battleground.ships:
            self.ship_map = ship.body.binary_draw(self.ship_map, cellSize)
        # .T to transpose to print it the right way 
        # print("ship_map\n", self.ship_map.T)

    def _read_laser_map(self):
        # should be the same for y if it wasn't squared
        cellSize = self.battleground.dim.x // Observation.laser_map_side_size
        for laser in self.battleground.lasers:
            self.laser_map = laser.body.binary_draw(self.laser_map, cellSize)
        # .T to transpose to print it the right way 
        # print("laser_map\n", self.laser_map.T)


    def analyse_ship(self, ship):
        ship_observations = copy(self)
        ship_observations.ship = ship
        ship_observations.player_map = np.zeros((Observation.player_map_side_size, Observation.player_map_side_size))
        ship_observations._read_player_map()
        ship_observations.can_shot = ship.can_shot
        return ship_observations

    def _read_player_map(self):
        # should be the same for y if it wasn't squared
        cellSize = self.battleground.dim.x // Observation.player_map_side_size
        self.player_map = self.ship.body.binary_draw(self.player_map, cellSize)
        # .T to transpose to print it the right way 
        # print("player_map\n", self.player_map.T)


    def toVector(self):
        # TODO assert all defined
        vector = np.array(self.can_shot)
        vector = np.append(vector, self.ship_map)
        vector = np.append(vector, self.laser_map)
        vector = np.append(vector, self.player_map)
        # putting the vector vertically instead of horizontally
        vector = vector.reshape(vector.size, 1)
        # print("observations vector\n", vector)
        # print("vector size\n", vector.size)
        # print("expected size\n", Observation.size)
        return vector



    def fromVector(self, vector):
        raise Exception("Not implemented.")



class Action():
    # TODO improve flexibility of actions
    angles = SHOOTING_ANGLES
    angle_size = 2 * math.pi / angles

    actions = [
        "wait", 
        "shoot", 
        "thrust",
        "turn",
    ]
    actions_values = {
        "wait" : "nothing", 
        "shoot" : "nothing", 
        "thrust" : "nothing", 
        "turn" : "radians", 
    }
    actions_size = {
        "nothing" : 1, 
        "radians" : angles,
    }
    actions_container = {
        "radians" : namedtuple("radians", ['direction'])
    }

    size = 0
    for label in actions:
        size += actions_size[actions_values[label]]
    print("size actions", size)

    def __init__(self, vector=None):
        self.label = None
        self.turn = Action.actions_container["radians"]
        if vector:
            self.fromVector(vector)

    @classmethod
    def restrict_angle(cls, angle):
        """angle must be between -pi and pi"""
        restricted_angle = Action.angle_size * ((angle % (2*math.pi)) // Action.angle_size)
        if restricted_angle > math.pi:
            restricted_angle = restricted_angle - (2*math.pi)
        return restricted_angle

    def toVector(self):
        vector = np.zeros((Action.size, 1))
        # TODO assert label not none
        # TODO automatize of indexes
        if self.label == "wait":
            vector[0] = 1
        elif self.label == "shoot":
            vector[1] = 1
        elif self.label == "thrust":
            vector[2] = 1
        elif self.label == "turn":
            # number of indexes already used in the vector
            # TODO assert turn defined
            # TODO test in human recored turn rate correspond well
            shift = 3
            # print("direction", self.turn.direction)
            # angle = math.degrees(self.turn.direction) % 360
            # print("angle", angle)
            # iAngle = int(angle // (360 // Action.angles)) + shift
            iAngle = int((self.turn.direction % (2*math.pi)) / Action.angle_size) + shift
            # print(iAngle){}
            vector[iAngle] = 1
        return vector


    def fromVector(self, vector):
        # print(vector)
        if Action.size != vector.size:
            raise Exception("Invalid vector : expected size {} but got size {}.".format(Action.size, vector.size))
        # TODO automatize of indexes
        i, j = np.nonzero(vector == 1)
        # print("i", i)
        if i.size == 0 :
            raise Exception("Invalid vector : no one found.")
        elif i.size > 1 :
            raise Exception("Invalid vector : {} one found.".format(i.size))
        else :
            action_index = i[0]

        if action_index == 0 :
            self.label = "wait"
        elif action_index == 1 :
            self.label = "shoot"
        elif action_index == 2 :
            self.label = "thrust"
        elif action_index < 3+Action.angles+1:
            self.label = "turn"
            # number of indexes already used in the vector
            shift = 3

            angle_degrees = (action_index - shift) * (360 // Action.angles)
            if angle_degrees > 180:
                angle_degrees -= 360
            self.turn.direction = math.radians(angle_degrees)
            # print("direction", self.turn.direction)
        else:
            raise Exception("Action index not implemented {}. (max index = {})".format(action_index, Action.size-1))




class Ship():
    id_max = 1

    def __init__(self, x, y, battleground, network=None):

        # id of the fighter
        self.id = Ship.id_max
        Ship.id_max += 1
        self.time = 0
        # the ship know is own position
        # the ship is not invulnerable. the ship can be hit
        self.body = Circle(x, y, 8)
        # the ship resilience
        self.hull = 1       
        # the ship knows in which battleground he is
        self.battleground = battleground
        # the ship know is speed and direction
        self.speed = 0
        self.max_speed = SHIPS_SPEED
        # between -pi and pi
        self.direction = 0
        # the ship know how much he is on a rampage
        self.score = 0
        self.reward = 0
        # current state of the ship
        self.state = "flying"
        # there is no cooldown on the laser
        self.can_shot = 1
        # the human player contolling it. None if no player controls it
        self.player = None
        # if the ia is in leroy mode time > 0
        self.leroy_time = 0
        # if the ship need to be actualised graphicaly
        self.actualise = False

        # is the ship on top of the map
        self.on_bottom = False

        # the neural network that can controls the ship actions
        self.network = network
        self.obs_vector = np.array([])
        self.act_vector = np.array([])

        # the ship is colorfull
        # self.color = '#AA3300'
        self.color = "Yellow"
        self.laser_color = "Red"


    def is_playable(self):
        return self.state not in ["destroyed", "wreckage"]


    def assign_player(self, player):
        self.player = player
        self.color = "Grey"
        self.laser_color = "Green" # light green

    def unassign_player(self):
        self.player = None
        self.color = "Yellow"
        self.laser_color = "Red"


    def hit(self, object):
        # TODO : damage by objects, armor, shield etc
        self.hull -= 1
        if self.hull <= 0:
            self.explode()

    def shoot(self):
        edge = self.body.edge(self.direction, 2)
        laser = Laser(
            edge[0], edge[1],
            LASER_SPEED, 
            self.direction, self.battleground,
            owner=self, color=self.laser_color
            )
        self.battleground.lasers.append(laser)


    def thrust(self):
        self.body.x += (self.max_speed * math.cos(self.direction))
        self.body.x = min(self.battleground.dim.x - 1, max(0, self.body.x))
        self.body.y += (self.max_speed * math.sin(self.direction))
        self.body.y = min(self.battleground.dim.y - 1, max(0, self.body.y))

    def explode(self):
        # print("baaouummm")
        # dying id bad, remember
        self.reward -= 10
        self.battleground.last_x_time_rewards.append((10, self.battleground.time))
        self.state = "destroyed"


    def possible_actions(self):
        """The ship has a very good computer.
        The ship know what are his possibilities."""
        # TODO filter impossible
        return Action.actions


    def random_play(self, observations):
        label = choice(self.possible_actions())
        action = Action()
        action.label = label
        if Action.actions_values[action.label] == "radians" :
            action.turn.direction = randint(0, Action.angles-1) / Action.angles * 2 * math.pi - math.pi
        return action

    def crazy_turret(self, observations):
        possibles = ["turn", "shoot"]
        weights = [0.5, 0.5]
        label = np.random.choice(possibles, p=weights)
        action = Action()
        action.label = label
        if Action.actions_values[action.label] == "radians" :
            action.turn.direction = randint(0, Action.angles-1) / Action.angles * 2 * math.pi - math.pi
        return action


    def crazy_runner(self, observations):
        possibles = ["turn", "thrust"]
        weights = [0.1, 0.9]
        label = np.random.choice(possibles, p=weights)
        action = Action()
        action.label = label
        if Action.actions_values[action.label] == "radians" :
            action.turn.direction = randint(0, Action.angles-1) / Action.angles * 2 * math.pi - math.pi
        return action


    def network_play(self, observations):
        self.obs_vector = observations.toVector()
        # print("obs_vector shape\n", obs_vector.shape)
        # print("obs_vector\n", obs_vector)
        
        # reward the networks with the last reward get
        self.network.update_network_with_reward(self.reward)
        # print(self.network.weights_layers[-1])
        
        # renforcement method
        # print("layers", self.network.layers)
        self.act_vector = self.network.take_action(self.obs_vector)
        # print(self.act_vector)
        
        # print("act_vector shape\n", self.act_vector.shape)
        # print("act_vector\n", self.act_vector)
        action = Action()
        action.fromVector(self.act_vector)
        # print("I want to ", action.label)
        # if action.label == "turn":
        #     print(action.turn.direction)
        return action


    def read_keys(self):
        action = Action()

        # if the player move the mouse we do not always have to consider it's a turn
        # if the player follows networks rules he turn only if the angle differ enough
        if "turn" in self.player.actions_set:
            if self.body.distance(self.player.cursor) < self.body.radius:
                self.player.actions_set.remove("turn")
            else:
                angle = self.body.angle_with(self.player.cursor)
                if PLAYER_FOLLOWS_NETWORK_RULES :
                    restricted_angle = Action.restrict_angle(angle)
                    if restricted_angle == self.direction:
                        # turning is futile
                        self.player.actions_set.remove("turn")
                    else:
                        # we aplly restriction
                        angle = restricted_angle

        
        if self.player.actions_set:
            action.label = choice(list(self.player.actions_set))
        else:
            action.label = "wait"
        
        if action.label == "turn":
            action.turn.direction = angle
            # print("turning", action.turn.direction)

        self.player.clear_keys()
        return action

    # TODO for player : shoot the nearest valid angle instead of the next one

    def go_bottom_reward(self):
        # print(self.body.y)
        # print(self.battleground.dim.y // 2)
        reach_top = not self.on_bottom and self.body.y == self.battleground.dim.y-1
        self.on_bottom = self.body.y == self.battleground.dim.y-1
        return int(reach_top)


    def get_action(self, observations):
        """The ship is smart.
        The ship know what to do to be the best."""

        if self.player:
            action = self.read_keys()
            self.obs_vector = observations.toVector()
            self.act_vector = action.toVector()
        else:
            if NETWORK and self.network:
                # in leroy mode ship feels reckless
                if self.leroy_time > 0:
                    # action = self.random_play(observations)
                    action = self.crazy_runner(observations)
                else:
                    action = self.network_play(observations)
            else:
                # action = self.random_play(observations)
                action = self.crazy_turret(observations)
                # action = self.crazy_runner(observations)

        # vector = action.toVector()
        # print("vector", vector)
        # action.fromVector(vector)
        # print("action", action.label)
        
        if self.reward != 0:
            print(self.reward)
        self.score += self.reward
        self.reward = 0

        return action


    def leroy_jenkins(self):
        self.leroy_time = randint(10, 100)
        self.color = "Red"
        self.laser_color = "Grey"
        print("Leeeeeeeeeeeroy Jeeeeeenkins !")
        self.actualise = True

    def back_to_normal(self):
        self.color = "Yellow"
        self.laser_color = "Red"
        self.actualise = True


    def move(self):
        """The ship can think. The ship can act. The ship is."""
        self.time += 1
        self.actualise = False

        if self.leroy_time == 1:
            self.back_to_normal()
        if self.leroy_time > 0:
            self.leroy_time -= 1

        # it is a complete information game so we can compute observations once per frame
        # we use observations of the battleground of this frame
        observations = self.battleground.observations.analyse_ship(self)


        # there is a chance that the ia enter in leroy mode
        # the ia goes mad for some time, acting randomly
        # added to allow the ships to explore the possible actions and not stay passive
        if not self.player and self.leroy_time == 0 and random() < LEROY_RATE:
            self.leroy_jenkins()

        action = self.get_action(observations)

        # training reward depending on position
        self.reward = self.go_bottom_reward()

        if action.label == "shoot":
            self.shoot()
        elif action.label == "thrust":
            self.thrust()
        elif action.label == "turn":
            self.direction = action.turn.direction
            # print("turn ", self.direction)



# coordinates conventions
#   | 0 | 1 | 2 | 3 | 4 -> x
# 0 |
# 1 |          pi/2
# 2 |           |
# 3 |   (-)pi _   _ 0
# 4 |           |
# y            -pi/2

class Laser():
    id_max = 1

    def __init__(self, x, y, speed, direction, battleground, owner=None, color="Red"):
        # print("piouu")
        self.id = Laser.id_max
        Laser.id_max += 1
        self.body = Circle(x, y, radius=2)  
        self.speed = speed
        self.direction = direction
        self.battleground = battleground
        self.color = color
        self.time = 0
        self.owner = owner
        self.state = "flying"

    def move(self):
        self.time += 1
        # MAYBE if we keep it int() it will be cheaper for computing
        self.body.x += self.speed * math.cos(self.direction)
        self.body.y += self.speed * math.sin(self.direction)
        # laser potential collision with ships
        explode = False
        for ship in self.battleground.ships:
            # colliding a ship
            if self.body.collide(ship.body):
                self.owner.reward += 1
                self.owner.battleground.last_x_time_rewards.append((1, self.battleground.time))
                ship.hit(self)
                explode = True
        if explode or self.battleground.outside(self.body.x, self.body.y):
            self.explode()

    def explode(self):
        # print("prrrr", self.id)
        self.state = "destroyed"

    def __str__(self):
        return str(self.id)
        # return functools.reduce(lambda x : x + self. , self.)

    def __repr__(self):
        return str(self.id)



class Player():
    max_id = 1
    keysets = {
        "mouse" : {
            "<Button-1>" : "shoot", 
            "a" : "thrust", 
            "<Motion>" : "turn", 
        }
    }

    def __init__(self, keyset, master, carte):
        # TODO automatized named tuple
        self.id = Player.max_id
        Player.max_id += 1
        self.keyset = Player.keysets[keyset]
        # for event, tAction in self.keyset.items():
            # carte.bind(event, tAction[1])
        carte.bind("<Button-1>", lambda e : self.press_shoot(e))
        carte.bind("<ButtonRelease-1>", lambda e : self.unpress_shoot(e))
        master.bind("a", lambda e : self.request_thrust(e))
        master.bind("<Motion>", lambda e : self.request_turn(e))
        self.shoot = False
        self.clear_keys()


    def press_shoot(self, event):
        self.actions_set.add("shoot")
        # print(self.actions_set)
        self.shoot = True
        # print("request shoot")

    def unpress_shoot(self, event):
        # print(self.actions_set)
        if "shoot" in self.actions_set:
            self.actions_set.remove("shoot")
        self.shoot = False
        # print("request cease fire")

    def request_thrust(self, event):
        self.actions_set.add("thrust")
        self.thrust = True
        # print("request thrust")

    def request_turn(self, event):
        self.actions_set.add("turn")
        self.turn = True
        self.cursor = Couple(event.x, event.y)
        # print("request turn", self.cursor)

    def clear_keys(self):
        self.actions_set = set()
        if self.shoot:
            self.actions_set.add("shoot")
        self.thrust = False
        self.turn = False
        self.cursor = None



class Battleground():
    # TODO parameter to pass nb of IA of feach type
    def __init__(self, largeur=400, hauteur=400, ship_number=2, networks=[]):
        """Create a battleground with ships
        networks must be a list of size ship_number containing neural networks
        if not set the networks will be generated randomly"""
        self.background = '#000000'
        self.ships_number = ship_number
        self.time = 0
        self.dim = Couple(largeur, hauteur)
        self.ships = []
        self.lasers = []
        self.observations = None
        self.networks = networks

        # list of all the rewards get and there obtention time
        # ex : (1, 120)
        self.last_x_time_rewards = []
        # expiration time of rewards in this list
        self.reward_list_len = 100

        # for the moment all the ships share the same network.
        # it's faster to train
        if self.networks == []:
            network = Renforcement_learning_neural_network(
                        [Observation.size, 56, Action.size], 
                        history_size=ANTICIPATION
                        )
            print("layers", network.layers)
            # represent an unique string that change with the network caracterisis
            # can be used to check if a network is compatible with a observation/solution vector
            print(network.network_caracterisis_string)
            network.network_caracterisis_string = "{}a_{}p_{}l".format(SHOOTING_ANGLES, MAP_PRECISION , "-".join(str(x) for x in network.layers))
            print(network.network_caracterisis_string)
            # we force a bit the shoot and thrust actions
            # because spinning ships aren't fun
            network.biases_layers[-1][1][0] = 12 * abs(network.biases_layers[-1][1][0]) + 0.2
            network.biases_layers[-1][2][0] = 12 * abs(network.biases_layers[-1][2][0]) + 0.2
            # print(network.biases_layers[-1])
            for i in range(len(network.biases_layers[-1])):
                network.biases_layers[-1][i][0] = network.biases_layers[-1][i][0] / 8
            self.networks = [network]

        # if not networks:
        #     networks = []
        #     for i in range(self.ships_number):
        #         network = Renforcement_learning_neural_network(
        #                     [Observation.size, 9, Action.size], 
        #                     history_size=ANTICIPATION)
        #         # we force a bit the shoot and thrust actions
        #         # because spinning ships aren't fun
        #         network.biases_layers[-1][1][0] = 3 * abs(network.biases_layers[-1][1][0])
        #         network.biases_layers[-1][2][0] = 3 * abs(network.biases_layers[-1][2][0])
        #         for i in range(len(network.biases_layers[-1])):
        #             network.biases_layers[-1][i][0] = network.biases_layers[-1][i][0] / 4
        #         networks.append(network)

        # copy by reference so at the end
        # when the ships will be deleted the networks will stay
        # self.networks = networks
        
        if len(self.networks) == 1:
            for x in range(self.ships_number):
                self.ships.append(Ship(randint(0, largeur), randint(0, hauteur), self, self.networks[0]))
        else:
            for x in range(self.ships_number):
                self.ships.append(Ship(randint(0, largeur), randint(0, hauteur), self, self.networks[x]))

        # print("there")
        # self.time_list = [self.time]
        # # list of average rewards
        # self.acc_rewards = [0]
        # # Creation of the graph of avg reward
        # fig, ax = plt.subplots()
        # # Size of the graph
        # plt.axis([0, 1, 0, 50])
        # # real time modification handeling
        # plt.ion()
        # plt.title("Summed rewards of last frames")
        # plt.xlabel("time")
        # plt.ylabel("avg rewards last {} frames".format(self.reward_list_len))
        # plt.xlim(32, 212)
        # plt.grid(True)
        # self.points, = ax.plot(self.time_list, self.acc_rewards, marker='o', linestyle='-')



    def set_ia(self, network):
        for ship in self.ships:
            ship.network = network


    def outside(self, x, y):
        return (x < 0) or (y < 0) or (x >= self.dim.x) or (y >= self.dim.y)


    # def plot_last_time_rewards(self):
    #     # temps.append(self.time)
    #     self.acc_rewards.append(sum(map(lambda x:x[0], self.last_x_time_rewards)))
    #     self.acc_rewards.append(self.time)

    #     # On recentre le graphique
    #     plt.axis([0, self.time_list[-1] - self.time_list[0] + 1, 0, max(self.acc_rewards)+1])
        
    #     # on place le nouveau point
    #     # plt.scatter(self.time, data)
    #     self.points.set_data(self.time_list, self.acc_rewards)

    #     # if the oldest recorded reward passed his expiration date we remove it
    #     if self.last_x_time_rewards and self.time - self.last_x_time_rewards[0][0] > self.reward_list_len :
    #         self.last_x_time_rewards.pop()


    def frame(self):
        # self.plot_last_time_rewards()
        self.time += 1
        self.observations = Observation(self)
        # TODO compute simultaneously so first ships don't have avantage on last ones
        for laser in self.lasers:
            laser.move()
        for ship in self.ships:
            ship.move()



    def run(self):
        while 1:
            self.frame()



class Ofighters(MapMenuStruct):
    fps_manager = fps_manager()

    def __init__(self, size=20, largeur=800, hauteur=800):
        # we call the basic graphic interface for the game
        super().__init__(largeur, hauteur)

        self.session_name = ""
        # if we record the player ingame
        self.recording = False
        # self.training_mode = False
        self.records = {"obs" : [], "sol" : []}
        # contains all the objects we will need to delete
        self.todelete = {}

        self.battleground = Battleground(self.dim.x, self.dim.y, SHIPS_NUMBER)
        self.switch_session("default")

        Images = namedtuple("Images", ['ships', 'lasers', ])
        self.images = Images([], [], )
        self.threads = {}

        # main window
        self.master = tk.Tk()

        # fill the main window
        self.create_main_window()

        # linking Ofighters functions to the main graphical structure
        self.link_functionnalities()

        # self.create_widgets()
        self.player_ship = True
        self.transfer_player_ship()

        # other windows
        self.threads["run"] = spawnthread(self.run)
        
        # run the main thread/loop
        self.master.mainloop()



    def restart(self):
        self.save_records(os.path.join(RECORD_FOLDER, "ofighter_record_" + now()))
        self.records = {"obs" : [], "sol" : []}
        networks = self.clear_battleground()
        # lobotomising the ships brains so they don't remember
        # the atrocities they have just seen
        for network in networks:
            network.clear_history()
        # regenerate a battleground
        self.battleground = Battleground(self.dim.x, self.dim.y, SHIPS_NUMBER, networks)
        # we put back the player ship on the field if required
        if self.ihm["string_transfer_player"].get() == "yes" and self.ihm["string_training_mode"].get() == "no":
            self.transfer_player_ship()
        self.temps = 0


    def clear_battleground(self):

        # remove image, reference to the laser and the corresponding object from the battleground
        while self.images.lasers != []:
            self.ihm["carte"].delete(self.images.lasers[0])
            del self.images.lasers[0]
        self.battleground.lasers = []
        
        # remove image, reference to the ship and the corresponding object from the battleground
        while self.images.ships != []:
            self.ihm["carte"].delete(self.images.ships[0])
            del self.images.ships[0]
        self.battleground.ships = []
        
        # the work is done
        for key, value in self.todelete.items():
            self.todelete[key] = []

        return self.battleground.networks


    def link_functionnalities(self):
        self.ihm["check_recording"].configure(command=self.swap_recording)
        self.ihm["train"].configure(command=self.analyse_records)
        self.ihm["save_ia"].configure(command=self.save_ia)
        self.ihm["check_transfer_player"].configure(command=self.transfer_player)
        self.ihm["switch_session"].configure(command=self.create_switch_session)
    

    def create_switch_session(self):
        Alert("New session", "Create", callback=lambda x : self.switch_session(x))


    def switch_session(self, name):
        path = os.path.join(NETWORKS_FOLDER, name)
        self.session_name = name
        os.makedirs(path, exist_ok=True)
        network = self.load_ia()
        # TODO when they will be more than one IA
        # if there is already an existing network we load it
        if network:
            self.battleground.set_ia(network)
            print("IA loaded")
        else:
            # else we create one
            if self.battleground.networks:
                self.save_ia()


    def transfer_player_ship(self):
        print("transfer")
        i = 0
        assigned = False
        while not assigned and i < len(self.battleground.ships):
            if self.battleground.ships[i].is_playable():
                player1 = Player("mouse", self.master, self.ihm["carte"])
                self.battleground.ships[i].assign_player(player1)
                if len(self.images.ships) > i:
                    self.ihm["carte"].itemconfig(self.images.ships[i], fill=self.battleground.ships[i].color)
                assigned = True
        if not assigned:
            print("Impossible to assign a ship to player. None of them is playable.")


    def untransfer_player_ship(self):
        print("untransfer")
        for i, ship in enumerate(self.battleground.ships):
            if ship.player:
                print("unassign")
                ship.unassign_player()
                if len(self.images.ships) > i:
                    self.ihm["carte"].itemconfig(self.images.ships[i], fill=ship.color)


    # TODO pause
    # TODO click on ship to save IA (Save this IA)
    def save_ia(self):
        """There must be an IA to save in the battleground"""
        file = os.path.join(NETWORKS_FOLDER, self.session_name, "ofighter_network_" + now())
        self.battleground.networks[0].save(file)
        print("saving as ", file)


    def load_ia(self):
        """Return None if there is no network available."""
        network = None
        path = os.path.join(NETWORKS_FOLDER, self.session_name)
        files = glob.glob(os.path.join(path, "ofighter_network_*"))
        if files:
            # sorting by modification date : the more recent the firsts
            files.sort(key=os.path.getmtime, reverse=True)
            # for file in files:
            #     print("file", os.path.getmtime(file))
            network = Renforcement_learning_neural_network.load(files[0])
            print("loading", files[0])
        return network


    def save_all_ias(self):
        # TODO
        pass
        # path = os.path.join(NETWORKS_FOLDER, self.session_name, "ofighter_network_" + now())
        # print("saving as ", path)
        # nn.save(path)


    def swap_recording(self):
        if self.ihm["string_recording"].get() == "yes":
            self.recording = True
        elif self.ihm["string_recording"].get() == "no":
            self.save_records(os.path.join(RECORD_FOLDER, "ofighter_record_" + now()))
            self.records = {"obs" : [], "sol" : []}
            self.recording = False


    def transfer_player(self):
        if self.ihm["string_transfer_player"].get() == "yes":
            if self.ihm["string_training_mode"].get() == "no":
                self.transfer_player_ship()
        elif self.ihm["string_transfer_player"].get() == "no":
            if self.ihm["string_training_mode"].get() == "no":
                self.untransfer_player_ship()

    # TODO max size on obs/act vectors un renforcement nn
    # TODO multiplayer ? :)

    def save_records(self, name):
        if self.records["obs"]:
            if not name.endswith(".orec"):
                name += ".orec"
            with open(name, "wb") as file:
                pickle.dump(self.records, file)


    def load_records(self, paths):
        records = {"obs" : [], "sol" : []}
        if not isinstance(paths, list):
            paths = [paths]
        for path in paths:
            if not path.endswith(".orec"):
                path += ".orec"
            with open(path , "rb") as file:
                content = pickle.load(file)
                if isinstance(content, dict):
                    # print(content)
                    records["obs"].extend(content["obs"])
                    records["sol"].extend(content["sol"])
                else:
                    raise Exception("Bad file content : dict expected")
        return records


    def swap_training_mode(self):
        # TODO replace all the things at the right place (use grid ?)
        if self.ihm["string_training_mode"].get() == "yes":
            if self.ihm["string_transfer_player"].get() == "yes":
                # the player do not need to play while the training mode is on
                self.untransfer_player_ship()
            self.expand_map()
        elif self.ihm["string_training_mode"].get() == "no":
            self.hide_map()
            if self.ihm["string_transfer_player"].get() == "yes":
                # replace the player on the map if it was before
                self.transfer_player_ship()


    def analyse_records(self):
        if "loading" in self.threads:
            print("Training already in progress !")
        else:
            # TODO be able to stop while in progress
            # TODO progress bar
            files = glob.glob(os.path.join(RECORD_FOLDER, "ofighter_record_*.orec"))
            # print("files len", len(files))
            # print("files", files)
            self.records = self.load_records(files)
            print("records len", len(self.records))
            print("analysing...")
            # make the loading bar appear
            self.ihm["progress"]["value"] = 0
            self.loading_i = 0
            self.ihm["progress"].grid(row=self.ihm["grid_raw_progress"], column=0)
            # the loading bar will load independantly
            self.read_loading()
            self.threads["loading"] = spawnthread(self.train_networks)


    def train_networks(self):
        # update the neural network for all ships IAs
        for i, ship in enumerate(self.battleground.ships):
            ship.network.update_network(self.records["obs"], self.records["sol"])
            # print("{} of {}".format(i+1, len(self.battleground.ships)))
            self.loading_i = i
            sleep(0.2)
            # print("job done at", self.loading_i, "/", len(self.battleground.ships))
        # make the loading bar disappear
        self.loading_i = len(self.battleground.ships)
        self.ihm["progress"].grid_forget()
        self.threads["loading"].stop()
        del self.threads["loading"]


    def read_loading(self):
        print(self.battleground.ships[self.loading_i].network.progression)
        if self.loading_i == len(self.battleground.ships):
            self.ihm["progress"]["value"] = 100
        else:
            self.ihm["progress"]["value"] = round(
                100 / len(self.battleground.ships) * 
                (self.loading_i + self.battleground.ships[self.loading_i].network.progression)
                )

        print(self.ihm["progress"]["value"])

        if self.loading_i < len(self.battleground.ships)-1:
            self.master.after(100, self.read_loading)



    # TODO not adapted to launch multiple instances of Ofighters
    # musn't use decorators here
    @fps(fps_manager)
    def frame(self):

        for i, ship in enumerate(self.battleground.ships):

            if ship.actualise:
                self.ihm["carte"].itemconfig(self.images.ships[i], fill=self.battleground.ships[i].color)

            if self.recording and ship.player and ship.obs_vector.size > 0 and ship.act_vector.size > 0:
                self.records["obs"].append(ship.obs_vector)
                self.records["sol"].append(ship.act_vector)
            # print("len records", len(self.records))

            if ship.state == "destroyed":
                # explosion animation
                # self.todelete.append(lambda : self.ihm["carte"].delete(self.images.ships[i]) )
                # self.todelete["images"].append(self.images.ships[i])
                # self.todelete["objects"].append(self.battleground.ships[i])
                # self.todelete["ships"].append(self.images.ships[i])
                self.todelete["ships"].append((self.images.ships[i], self.battleground.ships[i]))
                if ship.player:
                    self.ihm["check_transfer_player"].deselect()
            elif ship.time == 0:
                # we create images for new ships
                self.images.ships.append(
                    self.ihm["carte"].create_oval(
                        ship.body.x - ship.body.radius, ship.body.y - ship.body.radius, 
                        ship.body.x + ship.body.radius, ship.body.y + ship.body.radius, 
                        fill=ship.color, outline="Black", width="1"
                    )
                )
            else:
                # and move already existing image of existing ships
                self.ihm["carte"].coords(
                    self.images.ships[i], 
                    ship.body.x - ship.body.radius, ship.body.y - ship.body.radius, 
                    ship.body.x + ship.body.radius, ship.body.y + ship.body.radius
                )


        for i, laser in enumerate(self.battleground.lasers):
            if laser.state == "destroyed":
                # explosion animation
                # self.todelete.append(lambda : self.ihm["carte"].delete(self.images.lasers[i]) )
                # self.todelete["images"].append(self.images.lasers[i])
                # self.todelete["objects"].append(self.battleground.lasers[i])
                # self.todelete["lasers"].append(self.images.lasers[i])
                self.todelete["lasers"].append((self.images.lasers[i], self.battleground.lasers[i]))
            elif laser.time == 0:
                # we create images for new lasers
                self.images.lasers.append(
                    self.ihm["carte"].create_oval(
                        laser.body.x - laser.body.radius, laser.body.y - laser.body.radius, 
                        laser.body.x + laser.body.radius, laser.body.y + laser.body.radius, 
                        fill=laser.color
                    )
                )
            else:
                # and move already existing image of existing lasers
                self.ihm["carte"].coords(
                    self.images.lasers[i], 
                    laser.body.x - laser.body.radius, laser.body.y - laser.body.radius, 
                    laser.body.x + laser.body.radius, laser.body.y + laser.body.radius
                )
        # print("{} - {}".format(len(self.battleground.lasers), len(self.images.lasers)))


        # self.ihm["carte"].coords(self.images.mass_center[0], self.battleground.center_of_mass.x-ship.body.radius, self.battleground.center_of_mass.y-ship.body.radius, self.battleground.center_of_mass.x+ship.body.radius, self.battleground.center_of_mass.y+ship.body.radius)

        # self.ihm["carte"].coords(self.images.best_pos[0], self.battleground.best_pos.x-ship.body.radius, self.battleground.best_pos.y-ship.body.radius, self.battleground.best_pos.x+ship.body.radius, self.battleground.best_pos.y+ship.body.radius)

        self.clear_wreckage()

        if not self.training_mode:
            self.sleep_time = 1 / self.ihm["vitesse"].get()

        sleep(self.sleep_time)

        self.battleground.frame()



    def clear_wreckage(self):
        # remove image, reference to the laser and the corresponding object from the battleground
        for image, obj in self.todelete["lasers"]:
            self.ihm["carte"].delete(image)
            self.images.lasers.remove(image)
            self.battleground.lasers.remove(obj)

        # remove image, reference to the ship and the corresponding object from the battleground
        for image, obj in self.todelete["ships"]:
            self.ihm["carte"].delete(image)
            self.images.ships.remove(image)
            self.battleground.ships.remove(obj)
        
        # the work is done
        for key, value in self.todelete.items():
            self.todelete[key] = []
        

    def quit(self):
        self.quitter = True
        self.save_ia()


    def run(self):
        # images contains tkinter graphical objects
        # lasers contains (index, object) indexes of ships images 
        # and corresponding object in the battleground
        # lasers contains indexes of ships images
        self.todelete = {"images" : [], "lasers" : [], "ships" : []}

        while not self.quitter:

            self.temps += 1
            self.ihm["temps"]["text"] = "Temps : "+str(self.temps)

            self.frame()

            if self.fps_manager.active:
                self.ihm["fps"]["text"] = "FPS " + str(self.fps_manager.fps)

            if self.continuous_training and self.temps > MAX_TIME:
                self.restart()

        self.master.destroy()



# TODO benchmark desactivate print
# def blockPrint():
#     sys.stdout = open(os.devnull, 'w')
# def enablePrint():
#     sys.stdout = sys.__stdout__



# Battleground().run()

Ofighters()




size observations 49
size actions 7
layers [49, 56, 7]
49-56-7
4a_4p_49-56-7l
loading ofighter_networks/default/ofighter_network_19-05-20-00-05-09.dat
IA loaded
transfer
resuming...
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
1
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
untransfer
unassign
Leeeeeeeeeeeroy Jeeeeeenkins !
1
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
1
-10
-10
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
1
Leeeeeeeeeeeroy Jeeeeeenkins !
1
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
1
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
1
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
Leeeeeeeeeeeroy Jeeeeeenkins !
L

<__main__.Ofighters at 0x7f0eed6846d8>

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/home/mondher/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/mondher/Documents/Thibault/GPU Tests/python/machine_learning/thread_manager.py", line 25, in run
    self.qe.put(self.fcn())
  File "<ipython-input-1-2b07133dc23d>", line 1197, in run
    self.frame()
  File "/home/mondher/Documents/Thibault/GPU Tests/python/machine_learning/fps_manager.py", line 37, in with_fps
    result = method(*args, **keyargs)
  File "<ipython-input-1-2b07133dc23d>", line 1116, in frame
    ship.body.x + ship.body.radius, ship.body.y + ship.body.radius
  File "/home/mondher/anaconda3/lib/python3.6/tkinter/__init__.py", line 2469, in coords
    self.tk.call((self._w, 'coords') + args))]
_tkinter.TclError: invalid command name ".!canvas"

