In [2]:
from dataclasses import dataclass, field
from vi import Agent, Config, Simulation, Vector2, HeadlessSimulation
import pygame as pg
import os
import numpy as np
import random
from multiprocessing import Pool
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import math

pygame 2.6.1 (SDL 2.28.4, Python 3.12.3)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Output folder for numbered PNG frames. It is created up front so that
# save_frame() never has to check for its existence.
frame_dir = "frames"
frame_count = 0
os.makedirs(frame_dir, exist_ok=True)


def save_frame(screen):
    """Save ``screen`` as the next numbered PNG inside ``frame_dir``.

    The file name is ``frame_NNNNN.png`` where ``NNNNN`` is the current value
    of the module-level ``frame_count``, zero-padded to five digits. The
    counter is advanced only after the image has been written.

    Args:
        screen: The surface handed to ``pg.image.save``.
    """
    global frame_count
    filename = f"frame_{frame_count:05d}.png"
    target = os.path.join(frame_dir, filename)
    pg.image.save(screen, target)
    frame_count += 1

In [4]:
class RecordingSimulation(Simulation):
    """Windowed simulation that counts rendered frames.

    ``after_update`` draws the agents, refreshes the display, throttles to the
    configured FPS limit, records the measured FPS and bumps ``frame_count``.
    Writing each frame to disk is currently disabled (see the commented-out
    ``pg.image.save`` call below).
    """

    def __init__(self, config):
        """Create the simulation and make sure the ``frames`` folder exists.

        Args:
            config: Configuration object forwarded to ``Simulation.__init__``.
        """
        super().__init__(config)
        self.frame_count = 0
        os.makedirs("frames", exist_ok=True)

    def after_update(self) -> None:
        """Render the current state and advance the frame counter."""
        # Draw everything to the screen
        self._all.draw(self._screen)

        if self.config.visualise_chunks:
            # Writing `self.__visualise_chunks()` here would be name-mangled to
            # `_RecordingSimulation__visualise_chunks`, which this class never
            # defines, so the call would raise AttributeError. Address the
            # parent's private helper through its own mangled name instead.
            # NOTE(review): assumes `vi.Simulation` defines the private
            # `__visualise_chunks` method this override was modelled on; confirm.
            self._Simulation__visualise_chunks()

        # Frame capture is disabled; uncomment to save every frame as an image.
        # pg.image.save(self._screen, f"frames/frame_{self.frame_count:05d}.png")

        # Update the screen with the new image
        pg.display.flip()

        self._clock.tick(self.config.fps_limit)

        current_fps = self._clock.get_fps()
        if current_fps > 0:
            self._metrics.fps._push(current_fps)

            if self.config.print_fps:
                print(f"FPS: {current_fps:.1f}")  # noqa: T201

        self.frame_count += 1


In [5]:
@dataclass
class PPConfig(Config):
    """Simulation configuration extended with predator-prey parameters.

    Every parameter is an annotated dataclass field, so each one can be
    overridden per run (e.g. ``PPConfig(alpha=0.01)``) and shows up in the
    generated ``__init__``, ``__repr__`` and ``__eq__``. Without the
    annotations ``@dataclass`` would treat them as plain class attributes.
    """

    # Desired initial populations.
    initial_prey: int = 100
    initial_predators: int = 50

    # Parameters of the Lotka-Volterra equations, used as per-update
    # probabilities by the agents.
    alpha: float = 0.0055  # chance that a prey reproduces on an update
    beta: float = 0.015    # chance that a hunting predator kills its nearest prey
    delta: float = 0.8     # chance that a predator reproduces after a kill
    gamma: float = 0.005   # chance that a predator dies on an update

# Shared configuration used by the simulation cells below.
# duration is 2 * 60 * 60 = 7200 (unit is defined by vi.Config; presumably
# simulation ticks, confirm against the library documentation).
config = PPConfig(
    image_rotation=True,
    movement_speed=3.0,
    radius=25,
    duration=2 * 60 * 60,
    fps_limit=0,
)
    

In [6]:
class Prey(Agent):
    """Prey agent: wanders randomly, flees from predators in proximity and
    reproduces with probability ``config.alpha`` on every update."""

    def on_spawn(self):
        """Set up the behaviour state used by ``change_position``."""
        self.flee_strength = 1.5
        self.join_strength = 1.0
        self.state = "Wander"
        # Keep the initial heading for 50 updates before picking a new one.
        self.wander_direction_cooldown = 50
        return super().on_spawn()

    def change_position(self):
        """Advance the prey by one update.

        Order of operations: boundary handling, reproduction roll, state
        selection (``"Flee"`` when any predator is in proximity, otherwise
        ``"Wander"``), velocity update, position update.
        """
        self.there_is_no_escape()

        # Reproduction roll comes first, exactly one random draw per update.
        if random.random() < self.config.alpha:
            self.reproduce()

        # (agent, distance) pairs of everything currently in proximity.
        nearby = list(self.in_proximity_accuracy())
        threats = [(other, gap) for other, gap in nearby if isinstance(other, Predator)]
        self.state = "Flee" if threats else "Wander"

        speed = self.config.movement_speed

        if threats:
            # Force a fresh wander heading as soon as the danger is gone.
            self.wander_direction_cooldown = 0

            escape = Vector2(0, 0)
            for hunter, gap in threats:
                if gap == 0:
                    # Avoid dividing by zero when positions coincide.
                    gap = 0.001
                away = self.pos - hunter.pos
                if gap > 0:
                    away = away / gap
                # Closer predators push harder (strength / distance).
                escape += away * (self.flee_strength / gap)

            self.move = escape
            if escape.length() > 0:
                self.move = escape.normalize() * speed

        elif self.wander_direction_cooldown != 0:
            # Still committed to the current heading.
            self.wander_direction_cooldown -= 1

        else:
            # Cooldown expired: pick a uniformly random new heading.
            self.wander_direction_cooldown = 50
            angle = random.uniform(0, 2 * math.pi)
            heading = Vector2(1, 0).rotate_rad(angle)
            self.move = heading
            if heading.length() > 0:
                self.move = heading.normalize() * speed

        self.pos += self.move

In [7]:
class Predator(Agent):
    """Predator agent: wanders randomly, chases the nearest prey in proximity,
    may kill it, may reproduce after a kill and may die on any update."""

    def on_spawn(self):
        """Set up the behaviour state used by ``change_position``."""
        self.random_move_strength = 1.0
        self.predator_speed_boost = 1.2
        self.state = "Wander"
        # Zero cooldown: a wander heading is chosen on the first wander update.
        self.wander_direction_cooldown = 0
        return super().on_spawn()

    def change_position(self):
        """Advance the predator by one update.

        Order of operations: boundary handling, death roll (``config.gamma``),
        state selection (``"Hunt"`` when any prey is in proximity, otherwise
        ``"Wander"``), velocity update, position update. While hunting, the
        nearest prey is killed with probability ``config.beta``; a kill is
        followed by a reproduction roll with probability ``config.delta``.
        """
        self.there_is_no_escape()

        # Death roll: a dead predator does nothing else this update.
        if random.random() < self.config.gamma:
            self.kill()
            return

        # (agent, distance) pairs of everything currently in proximity.
        nearby = list(self.in_proximity_accuracy())
        targets = [(other, gap) for other, gap in nearby if isinstance(other, Prey)]
        self.state = "Hunt" if targets else "Wander"

        if targets:
            # Force a fresh wander heading once the hunt is over.
            self.wander_direction_cooldown = 0

            victim, gap = min(targets, key=lambda pair: pair[1])
            if gap == 0:
                # Avoid dividing by zero when positions coincide.
                gap = 0.001

            toward = victim.pos - self.pos
            if gap > 0:
                toward = toward / gap

            pull = Vector2(0, 0)
            pull += toward * (1.0 / gap)

            # Kill roll first; liveness is only checked when the roll succeeds.
            if random.random() < self.config.beta and victim.is_alive():
                victim.kill()
                if random.random() < self.config.delta:
                    self.reproduce()

            self.move = pull
            if pull.length() > 0:
                self.move = (
                    pull.normalize()
                    * self.config.movement_speed
                    * self.predator_speed_boost
                )

        elif self.wander_direction_cooldown != 0:
            # Still committed to the current heading.
            self.wander_direction_cooldown -= 1

        else:
            # Cooldown expired: pick a uniformly random new heading.
            self.wander_direction_cooldown = 50
            angle = random.uniform(0, 2 * math.pi)
            heading = Vector2(1, 0).rotate_rad(angle)
            self.move = heading
            if heading.length() > 0:
                self.move = (
                    heading.normalize()
                    * self.config.movement_speed
                    * self.predator_speed_boost
                )

        self.pos += self.move

In [None]:
# Windowed run: spawn both populations, then start the main loop.
simulation = RecordingSimulation(config)
simulation = simulation.batch_spawn_agents(
    config.initial_predators, Predator, images=["images/Predator.png"]
)
simulation = simulation.batch_spawn_agents(
    config.initial_prey, Prey, images=["images/Prey.png"]
)
simulation.run()

: 

In [8]:
class HeadlessSimulationRefined(HeadlessSimulation):
    """Headless simulation that samples the prey and predator population
    sizes every 10th update and returns the samples from ``run``."""

    def __init__(self, config):
        """Create the simulation with an empty sample list.

        Args:
            config: Configuration object forwarded to
                ``HeadlessSimulation.__init__``.
        """
        super().__init__(config)
        self.frame_count = 0
        # One (num_prey, num_predator) tuple per sampled update.
        self.prey_predator__per_frame = []

    def after_update(self) -> None:
        """Record the population sizes on every 10th update.

        Every third sample (update numbers divisible by 30) is also printed
        as a progress indicator.
        """
        tick = self.frame_count

        if tick % 10 == 0:
            prey_total = sum(1 for member in self._all if isinstance(member, Prey))
            predator_total = sum(
                1 for member in self._all if isinstance(member, Predator)
            )
            self.prey_predator__per_frame.append((prey_total, predator_total))

            if tick % 30 == 0:
                print(prey_total, predator_total)

        self.frame_count += 1

    def run(self):
        """Tick the simulation until ``_running`` becomes false.

        Returns:
            The list of ``(num_prey, num_predator)`` samples collected by
            ``after_update``.
        """
        self._running = True
        while self._running:
            self.tick()
        return self.prey_predator__per_frame

In [9]:
def run_simulation(config: "PPConfig") -> list[tuple[int, int]]:
    """Run one headless predator-prey simulation and return its samples.

    Args:
        config: Simulation settings. ``initial_predators`` and
            ``initial_prey`` determine how many agents of each kind are
            spawned, so a ``PPConfig`` (not a bare ``Config``) is required.

    Returns:
        The value returned by ``HeadlessSimulationRefined.run``: a list of
        ``(num_prey, num_predator)`` tuples, one per recorded sample. This is
        a plain list, not a polars DataFrame.
    """
    simulation = (
        HeadlessSimulationRefined(config)
        .batch_spawn_agents(config.initial_predators, Predator, images=["images/Predator.png"])
        .batch_spawn_agents(config.initial_prey, Prey, images=["images/Prey.png"])
    )
    return simulation.run()

In [15]:
# Run one headless simulation and collect the sampled population counts.
agents_per_frame = run_simulation(config)

# One row per recorded sample; "Time" is the sample index.
df = pd.DataFrame(agents_per_frame, columns=["Prey", "Predator"])
df["Time"] = range(len(df))

# Long format (one row per species per sample) for seaborn's hue grouping.
df_melted = pd.melt(
    df,
    id_vars="Time",
    value_vars=["Prey", "Predator"],
    var_name="Species",
    value_name="Population",
)

sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    data=df_melted, x="Time", y="Population", hue="Species", marker="o", ax=ax
)
ax.set_title("Predator-Prey Populations Over Time")
ax.set_xlabel("Time Step")
ax.set_ylabel("Population")
ax.legend(title="Species")
fig.tight_layout()
plt.show()

101 50
105 54
108 54
112 57
111 66
127 67
130 66
134 73
141 77
144 81
147 79
155 83
149 96
157 97
148 102
153 100
139 112
127 118
130 122
122 120
129 117
127 114
122 118
109 124
107 121
92 133
71 134
50 141
48 129
40 125
33 120
26 112
22 104
16 102
21 83
18 77
20 72
16 66
17 55
12 48
9 49
9 49
9 45
5 42
2 35
1 32
1 24
2 20
2 17
2 15
2 9
3 9
2 8
3 8
3 8
3 6
3 5
5 5
8 4
12 3
13 3
15 3
16 3
16 3
17 3
18 2
22 2
24 2
25 3
27 3
29 2
37 1
44 1
50 1
60 1
75 2
88 2
105 2
125 2
150 3
179 3
208 2
243 1
294 1
342 1
414 1
481 1
579 2
681 4
805 7
927 12
1097 14
1302 13
1486 18
1744 23
2034 25
2418 25
2846 30
3341 34
3907 43
4539 52


KeyboardInterrupt: 

In [None]:
# Second run: same pipeline, additionally reporting how many samples came back.
agents_per_frame = run_simulation(config)
print(len(agents_per_frame))

# One row per recorded sample; "Time" is the sample index.
df = pd.DataFrame(agents_per_frame, columns=["Prey", "Predator"])
df["Time"] = range(len(df))

# Long format (one row per species per sample) for seaborn's hue grouping.
df_melted = pd.melt(
    df,
    id_vars="Time",
    value_vars=["Prey", "Predator"],
    var_name="Species",
    value_name="Population",
)

# Plot both population curves on a single set of axes.
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    data=df_melted, x="Time", y="Population", hue="Species", marker="o", ax=ax
)
ax.set_title("Predator-Prey Populations Over Time")
ax.set_xlabel("Time Step")
ax.set_ylabel("Population")
ax.legend(title="Species")
fig.tight_layout()
plt.show()

100 30
95 31
87 31
82 32
78 30
73 25
74 27
70 32
73 26
71 29
72 30
69 24
69 27
77 22
81 26
68 30
66 26
66 24
67 26
72 26
77 20
82 18
90 16
86 18
85 23
86 25
84 23
85 25
89 23
79 26
75 29
70 24
66 20
69 18
74 16
88 10
96 14
108 17
114 21
114 18
126 19
125 21
124 22
123 20
121 25
130 18
143 17
144 17
165 15
168 18
187 20
183 25
182 29
174 35
173 40
163 46
158 45
140 58
129 64
101 68
73 71
54 71
40 68
29 59
23 54
18 43
16 35
15 27
11 22
8 18
8 14
9 13
11 11
12 9
13 8
10 6
10 6
14 5
19 3
24 2
24 2
26 2
32 1
35 1
44 0
49 0
56 0
60 0
68 0
79 0
92 0
103 0
120 0
131 0
155 0
173 0
200 0
221 0
256 0
287 0
324 0
360 0
415 0
478 0
546 0
636 0
730 0
851 0
995 0
1146 0
1327 0
1491 0


KeyboardInterrupt: 

: 