In [133]:
import pygad
import pygad.torchga as torchga
import torch
import torch.nn as nn

from pathlib import Path
from time import sleep
import numpy as np
import math
import random
import arcade
from copy import deepcopy

In [134]:
import os

# Request arcade's headless (window-less) mode for the GA runs.
# NOTE(review): arcade reads ARCADE_HEADLESS when it is first imported, and
# `import arcade` already ran in the cell above, so this assignment likely has
# no effect unless it is moved before that import — confirm.
os.environ["ARCADE_HEADLESS"] = "True"

# Directory holding the simulation assets (e.g. boid.png). Set the
# BOID_ASSET_DIR environment variable rather than editing this local path.
CWD = os.environ.get(
    "BOID_ASSET_DIR",
    "C:\\Users\\Alex\\Desktop\\New Folder\\__UNI\\Project\\main\\main",
)

In [135]:
SCREEN_WIDTH = 1200
SCREEN_HEIGHT = 800

BOID_MOVEMENT_SPEED = 5
BOID_ROTATION_SPEED = 5

NUMBER_OF_BOIDS = 10

# Template steering network: (x, y, heading) -> scores for the two rotation
# directions. Each Boid deepcopies this and loads GA-evolved weights into it.
MODEL = nn.Sequential(
    nn.Linear(in_features=3, out_features=6),
    nn.Sigmoid(),
    nn.Linear(in_features=6, out_features=9),
    nn.Sigmoid(),
    nn.Linear(in_features=9, out_features=4),
    nn.Sigmoid(),
    nn.Linear(in_features=4, out_features=2),
    # NOTE(review): ReLU before Softmax maps every negative logit to 0, so both
    # outputs can tie and torch.argmax then returns index 0. Consider dropping
    # the ReLU (and the Softmax too, since callers only take the argmax).
    nn.ReLU(),
    # Explicit dim: an implicit softmax dim is deprecated and emits a warning.
    # dim=-1 matches the implicit choice for both 1-D and batched 2-D inputs.
    nn.Softmax(dim=-1),
)

In [136]:
class Boid(arcade.Sprite):
    """A sprite steered by its own copy of the shared ``MODEL`` network.

    Each frame the network receives ``[center_x, center_y, radians]`` and the
    index of its largest output picks the sign of the rotation; the boid then
    moves forward at a constant speed and wraps around the window edges.

    Args:
        weights: Flat weight vector for ``MODEL``, converted to a state dict
            with ``torchga.model_weights_as_dict``.
    """

    def __init__(self, weights):
        super(Boid, self).__init__(filename=Path(CWD) / "boid.png", scale=0.5)

        # Spawn anywhere inside the window: x across the width, y across the height.
        self.set_position(
            center_x=random.randint(0, SCREEN_WIDTH),
            center_y=random.randint(0, SCREEN_HEIGHT),
        )
        self.speed = BOID_MOVEMENT_SPEED
        self.rotation_speed = BOID_ROTATION_SPEED

        # Private copy so loading these weights never touches the shared template.
        self.model = deepcopy(MODEL)
        self.model.load_state_dict(
            torchga.model_weights_as_dict(model=self.model, weights_vector=weights)
        )

    def update(self):
        """Rotate according to the network, step forward, then wrap around."""
        # nn.Linear requires a float tensor; a plain Python list is rejected.
        observation = torch.tensor(
            [self.center_x, self.center_y, self.radians], dtype=torch.float32
        )
        # NOTE(review): raw pixel coordinates (up to SCREEN_WIDTH) push the first
        # Sigmoid into saturation; normalising the inputs may help evolution.
        with torch.no_grad():  # inference only: no autograd graph per frame
            rotation_direction = int(torch.argmax(self.model(observation)))

        self.angle += self.rotation_speed if rotation_direction else -self.rotation_speed

        # Advance `speed` pixels along the current heading.
        self.set_position(
            center_x=self.center_x - self.speed * math.sin(self.radians),
            center_y=self.center_y + self.speed * math.cos(self.radians),
        )

        self.wrap_around()

    def wrap_around(self):
        """Shift the sprite to the opposite side once it crosses a window edge."""
        if self.left < 0:
            self.left += SCREEN_WIDTH
        elif self.right > SCREEN_WIDTH - 1:
            self.right -= SCREEN_WIDTH

        if self.bottom < 0:
            self.bottom += SCREEN_HEIGHT
        elif self.top > SCREEN_HEIGHT - 1:
            self.top -= SCREEN_HEIGHT

In [137]:
class Simulation(arcade.Window):
    """Window that runs one episode for a single GA weight vector.

    Every boid shares the same network weights. The episode ends as soon as any
    boid is lost to a collision or once ``TIME_LIMIT`` seconds have elapsed; the
    elapsed time is reported as the fitness.
    """

    # Episode length cap in seconds.
    TIME_LIMIT = 10.0

    def __init__(self, weights):
        super().__init__(width=SCREEN_WIDTH, height=SCREEN_HEIGHT, update_rate=1 / 60)  # type: ignore
        self.weights = weights
        self.boid_list = arcade.SpriteList()
        self.total_time = 0.0
        # Populate the world right away: callers in this notebook construct the
        # window and call run() directly, and on_update() needs the boids.
        self.setup()

    def setup(self):
        """(Re)spawn all boids and reset the episode clock."""
        self.boid_list[:] = [Boid(self.weights) for _ in range(NUMBER_OF_BOIDS)]

        self.total_time = 0.0

        arcade.set_background_color(arcade.color.SKY_BLUE)

    def on_draw(self):
        self.clear()
        self.boid_list.draw()

    def on_update(self, delta_time):
        """Remove colliding boids, advance the rest, and end the episode when done."""
        # Gather every colliding boid first (deduplicated by identity), then
        # remove them, so the list is never mutated while being iterated.
        collided = {
            id(hit): hit
            for boid in self.boid_list
            for hit in self.get_collisions(boid)
        }
        for hit in collided.values():
            self.boid_list.remove(hit)

        self.boid_list.update()
        self.total_time += delta_time

        if len(self.boid_list) < NUMBER_OF_BOIDS or self.total_time > self.TIME_LIMIT:
            self.close()

    def get_collisions(self, boid):
        """Return ``boid`` plus everything it overlaps, or ``[]`` if no overlap."""
        if collisions := arcade.check_for_collision_with_list(boid, self.boid_list):
            return [boid] + collisions
        return []

    def calculate_fitness(self):
        """Fitness is the elapsed episode time in seconds (higher is better)."""
        return self.total_time

In [138]:
class BoidGame:
    """Evaluates one weight vector by running a single Simulation episode."""

    def __init__(self, weights):
        self.weights = weights

    def run(self):
        """Run the simulation and return the fitness it reports.

        No pause after ``window.run()``: a fixed sleep here added 10 idle
        seconds to every fitness evaluation without affecting the result.
        """
        window = Simulation(self.weights)
        # window.run() returns once the window has been closed.
        window.run()
        return window.calculate_fitness()

In [139]:
def fitness_func(solution, sol_idx):
    """PyGAD fitness callback: score one candidate weight vector.

    Args:
        solution: Flat weight vector for ``MODEL`` proposed by PyGAD.
        sol_idx: Index of ``solution`` in the current population (unused).

    Returns:
        The fitness reported by ``BoidGame(solution).run()``.
    """
    # The previous debug print(solution) dumped the full weight vector on every
    # evaluation (population size x generations times) and has been dropped.
    # NOTE(review): newer PyGAD releases call the fitness function as
    # fitness_func(ga_instance, solution, solution_idx); this two-argument form
    # only matches older releases — confirm against the installed version.
    return BoidGame(solution).run()

In [140]:
def callback_generation(ga_instance):
    """PyGAD ``on_generation`` hook: report progress after each generation.

    The best fitness is read from the fitness values PyGAD already computed for
    the last generation. Calling ``best_solution()`` with no arguments would
    re-evaluate the whole population, i.e. re-run every simulation.
    """
    # Fall back to PyGAD's own (recomputing) path if the attribute is missing.
    last_fitness = getattr(ga_instance, "last_generation_fitness", None)
    best_fitness = ga_instance.best_solution(pop_fitness=last_fitness)[1]

    print(f"Generation = {ga_instance.generations_completed}")
    print(f"Fitness    = {best_fitness}")

In [141]:
NGEN = 100             # generations passed to pygad.GA as num_generations
NPARENTS = 50          # parents selected for mating each generation
POPULATION_SIZE = 100  # candidate weight vectors per generation

# Build the starting population: each row is a flat weight vector for MODEL.
torch_ga = torchga.TorchGA(model=deepcopy(MODEL), num_solutions=POPULATION_SIZE)
initial_population = torch_ga.population_weights


In [142]:
# Collect the GA hyper-parameters and hooks in one place, then build the runner.
ga_config = dict(
    num_generations=NGEN,
    num_parents_mating=NPARENTS,
    initial_population=initial_population,
    fitness_func=fitness_func,
    on_generation=callback_generation,
)
ga_instance = pygad.GA(**ga_config)

# Show one starting candidate for reference before evolution begins.
print(initial_population[1])

ga_instance.run()

[-1.18077162 -0.27552464  0.58933137 -1.29739602 -0.16570714  0.87769209
 -0.47986238  0.56149627  0.27195092  0.41815086  0.01618799  0.44941792
  0.05063093 -0.40792489  0.53601505  0.1786442  -0.4986289  -0.73306831
 -0.35927727 -0.04461694  0.44414423  0.0794468  -0.4327605  -0.63059434
  0.29939741 -0.29593855 -0.20519694  0.38038745 -0.18374451 -0.39275489
 -0.47845601 -0.22686973  0.57833066  0.31253339  0.13225776  0.69082967
  0.57517841 -1.26343755 -0.21227191  0.21697605  0.76559965  0.9947129
  0.60992533  0.34082399  0.46101738 -1.03097429 -0.90997623  0.38940748
 -0.74127425 -0.6282234  -0.34324085  0.57041469 -0.42597918 -0.14221925
  0.72815873  0.69775276 -0.20424507  0.51816843 -0.62578945  0.97239633
 -0.65396344  0.20979844 -0.22186823  0.50652736  1.08589547 -0.93960932
  0.65540725  0.5045174  -0.74113634 -0.81182563 -0.66241748 -0.18062812
 -0.85836198 -0.08781821 -0.45421876 -0.02364241  0.92611214  0.16595166
  0.86416428 -0.46235711  0.97673435 -0.37677803  1.

TypeError: forward() takes 2 positional arguments but 4 were given

: 

In [None]:
# Inspect the second individual of the GA's current population.
second_individual = ga_instance.population[1]
second_individual


array([ 0.81783004,  1.83929294,  0.2026002 , -1.24327719, -1.26544833,
        2.80112758, -1.069781  ,  1.9883045 , -2.12510988,  5.01447091,
        1.01614581,  0.57826494, -1.60353428, -1.82088247,  4.55565227,
       -1.22029371, -2.97228826, -0.35079501,  0.6283697 , -3.01034781,
        4.67833827,  0.28363158, -2.35245686, -2.11233219, -1.4163822 ,
       -2.42032166, -0.06623994, -0.44653508,  3.52772199, -1.13495002,
       -0.09817154, -2.66704958, -4.23829941,  4.13196019,  0.27940939,
       -0.7591249 ,  1.25332424,  2.08284837,  1.4885268 , -4.65207987,
        0.45977384, -1.66483089,  0.46726655, -0.26320219,  0.7349008 ,
       -1.00544548, -1.19667168,  2.12794286, -1.1441389 , -2.78302373,
        1.6059592 , -0.55882884,  1.25896305, -1.73177682,  0.7863913 ,
        0.74618165,  1.02214433, -3.13653188,  1.41380567, -0.36514079,
        1.1486488 , -1.67938685, -1.51155437, -3.1733401 , -1.42991739,
       -1.79337443, -1.94973683,  2.76882259,  0.78285875,  0.26