In [None]:
from ipycanvas import Canvas, hold_canvas
from IPython.display import display
import numpy as np
import math
from scipy.ndimage import gaussian_filter
import random



def softmax(values):
    """Return the softmax of ``values`` as a list of floats summing to 1.

    The largest input is subtracted from every value before exponentiating,
    which keeps ``math.exp`` from overflowing on large inputs without
    changing the result.
    """
    peak = max(values)
    weights = []
    for value in values:
        weights.append(math.exp(value - peak))
    total = sum(weights)
    return [weight / total for weight in weights]

class Ant:
    """A single ant: grid position, heading, and cargo/life state."""

    def __init__(self, pos=None):
        """Create an ant at ``pos`` (row, col), defaulting to ``(0, 0)``.

        Args:
            pos: Any array-like of two numbers, or ``None`` for the origin.
                It is always copied into a fresh float32 array, so the
                caller's object is never aliased by the ant.
        """
        # ``None`` instead of an array default: default values are evaluated
        # once and shared by every call, which is a trap for mutable objects.
        if pos is None:
            pos = (0, 0)
        self.pos = np.array(pos, dtype=np.float32)
        # Heading in radians, drawn uniformly from [0, 2*pi).
        self.direction = np.random.rand() * 2 * np.pi
        self.carrying_food = False
        self.last_choice = 0   # NOTE(review): not read anywhere visible here
        self.last_feeding = 0  # step_count at which this ant last picked up food
        self.is_alive = True

class Simulation:
    """Grid-based ant-colony simulation drawn onto an ipycanvas ``Canvas``.

    Cells are indexed ``[row, col]``: an ant's ``pos[0]`` is its row and
    ``pos[1]`` its column, and coordinates wrap around the grid edges.
    Typical use: construct, mark cells in ``food_map`` / ``nest_map``, call
    ``ready()``, then alternate ``step()`` and ``draw()``.
    """

    def __init__(self, shape, n_ants):
        """Create an empty ``shape`` (rows, cols) world with ``n_ants`` ants.

        Also creates a canvas of matching size and shows it with IPython's
        ``display``.
        """
        self.shape = shape
        self.n_ants = n_ants
        self.ants = [Ant() for _ in range(n_ants)]

        # Non-zero cells mark food (each unit is one pickup) and the nest.
        self.food_map = np.zeros((shape[0], shape[1]), dtype=np.uint8)
        self.nest_map = np.zeros((shape[0], shape[1]), dtype=np.uint8)

        # Two pheromone layers:
        #   Layer 0 "To Home": emitted by nest cells and dropped by ants that
        #                      are NOT carrying food; followed by ants that ARE.
        #   Layer 1 "To Food": emitted by food cells and dropped by ants that
        #                      ARE carrying food; followed by ants searching.
        self.pheromone = np.zeros((shape[0], shape[1], 2), dtype=np.float32)

        self.canvas = Canvas(width=shape[1], height=shape[0])
        display(self.canvas)

        self.step_count = 0
        self.food_in_home = 0  # food units delivered to the nest so far
        self.killed = 0        # dead ants not yet respawned by add_new_ants

    def _random_nest_position(self):
        """Return a uniformly random nest cell as a float32 (row, col) array."""
        idx = np.random.randint(len(self.nest_positions))
        return self.nest_positions[idx].astype(np.float32)

    def ready(self):
        """Cache the nest cells and place every ant on a random one of them.

        Call after ``nest_map`` has been filled in and before ``step()``.

        Raises:
            ValueError: if there are ants but ``nest_map`` has no cells set.
        """
        self.nest_positions = np.argwhere(self.nest_map > 0)
        if self.ants and len(self.nest_positions) == 0:
            raise ValueError(
                "nest_map has no nest cells; mark the nest before calling ready()"
            )
        for ant in self.ants:
            ant.pos = self._random_nest_position()

    def remove_dead_ants(self):
        """Drop dead ants and add them to the ``killed`` tally for respawning."""
        survivors = [ant for ant in self.ants if ant.is_alive]
        # Count what was actually removed rather than trusting n_ants.
        self.killed += len(self.ants) - len(survivors)
        self.ants = survivors
        self.n_ants = len(survivors)

    def add_new_ants(self):
        """Every 10th step, respawn up to 5 killed ants on random nest cells."""
        # Batched so respawning does not happen (and cost time) every step.
        if self.killed > 0 and self.step_count % 10 == 0:
            count = min(self.killed, 5)
            for _ in range(count):
                new_ant = Ant()
                new_ant.pos = self._random_nest_position()
                new_ant.last_feeding = self.step_count
                self.ants.append(new_ant)
            self.killed -= count
            # Keep the counter equal to the list length. Adding len(self.ants)
            # here would inflate it, and remove_dead_ants would then report
            # phantom deaths.
            self.n_ants = len(self.ants)

    def step(self):
        """Advance the simulation by one tick.

        Emits pheromone from nest and food cells, diffuses/decays/caps the
        pheromone field, moves every ant, then prunes dead ants and respawns.
        Requires ``ready()`` to have been called (it sets ``nest_positions``).
        """
        # Sources: nest cells emit "To Home", food cells emit "To Food".
        # argwhere yields each cell once, so fancy-index += is safe here.
        nest = self.nest_positions
        self.pheromone[nest[:, 0], nest[:, 1], 0] += 0.5
        food = np.argwhere(self.food_map > 0)
        self.pheromone[food[:, 0], food[:, 1], 1] += 0.5

        # Diffuse each layer. Blurring widens trails beyond one cell, which
        # otherwise tends to trap ants in loops. Then decay and cap in place.
        decay = 0.99
        for layer in range(2):
            self.pheromone[:, :, layer] = gaussian_filter(
                self.pheromone[:, :, layer], sigma=0.6
            )
        self.pheromone *= decay
        np.clip(self.pheromone, 0, 50, out=self.pheromone)

        for ant in self.ants:
            self.update_ant(ant)

        self.remove_dead_ants()
        self.add_new_ants()
        self.step_count += 1

    def inside(self, pos):
        """Return True if ``pos`` (row, col) lies within the grid bounds."""
        return 0 <= pos[0] < self.shape[0] and 0 <= pos[1] < self.shape[1]

    def update_ant(self, ant):
        """Advance one ant by a single tick.

        The tick runs in this order:

        1. Optionally starve the ant (disabled).
        2. Lay pheromone and handle food pickup / nest drop-off on its cell.
        3. Sample the followed pheromone layer at three look-ahead sensors.
        4. Choose forward/left/right stochastically and add heading noise.
        5. Move ``STEP_SIZE`` cells, wrapping around the grid.
        """
        # --- TUNING ---
        SENSE_DISTANCE = 10     # cells ahead at which the sensors sample
        SENSE_ANGLE = 0.5       # radians (~29 deg) between front and side sensors
        TURN_ANGLE = 0.4        # radians turned on a left/right choice
        STEP_SIZE = 1.5         # cells moved per tick
        RANDOM_WIGGLE = 0.15    # max heading noise per tick; breaks perfect cycles

        # Starvation is switched off. When enabled, an ant that has not picked
        # up food for MAX_STEPS_WITHOUT_FEEDING ticks dies with 2% chance per
        # tick, and add_new_ants later respawns it.
        STARVATION_ENABLED = False
        MAX_STEPS_WITHOUT_FEEDING = 800

        if (STARVATION_ENABLED
                and self.step_count - ant.last_feeding > MAX_STEPS_WITHOUT_FEEDING):
            if random.random() < 0.02:
                ant.is_alive = False
                return

        # --- FOOD / NEST LOGIC ---
        # Wrap again: rounding a coordinate to float32 can leave it exactly
        # equal to the grid size.
        rx = int(ant.pos[0]) % self.shape[0]
        ry = int(ant.pos[1]) % self.shape[1]

        if ant.carrying_food:
            # Lay the "To Food" trail (layer 1) for searching ants to follow.
            self.pheromone[rx, ry, 1] += 2.0
            if self.nest_map[rx, ry] > 0:
                # Deliver the food and turn around to head back out.
                ant.carrying_food = False
                ant.direction += np.pi
                self.food_in_home += 1
        else:
            # Lay the "To Home" trail (layer 0) for loaded ants to follow.
            self.pheromone[rx, ry, 0] += 2.0
            if self.food_map[rx, ry] > 0:
                # Take one unit of food (cell is known to be > 0) and turn around.
                ant.carrying_food = True
                ant.direction += np.pi
                self.food_map[rx, ry] -= 1
                ant.last_feeding = self.step_count

        # --- SENSING ---
        # Loaded ants follow "To Home" (0); searching ants follow "To Food" (1).
        target_layer = 0 if ant.carrying_food else 1

        def get_sensor(angle_offset):
            """Pheromone on target_layer SENSE_DISTANCE cells ahead at an offset."""
            angle = ant.direction + angle_offset
            pos = ant.pos + np.array([np.cos(angle), np.sin(angle)]) * SENSE_DISTANCE
            x = int(pos[0]) % self.shape[0]
            y = int(pos[1]) % self.shape[1]
            return self.pheromone[x, y, target_layer]

        front = get_sensor(0)
        left = get_sensor(SENSE_ANGLE)
        right = get_sensor(-SENSE_ANGLE)

        # --- DECISION --- (0 = forward, 1 = left, 2 = right)
        if front + left + right < 0.1:
            # Almost no signal: mostly keep going straight, occasionally turn.
            probs = [0.8, 0.1, 0.1]
        else:
            # Softmax over the sensor readings; scaling by 4 sharpens contrast.
            probs_from_data = softmax([v * 4.0 for v in (front, left, right)])
            # Multiply by a forward-biased prior and renormalize.
            prior_probs = [0.6, 0.2, 0.2]
            probs = [d * p for d, p in zip(probs_from_data, prior_probs)]
            total = sum(probs)
            probs = [p / total for p in probs]

        choice = np.random.choice([0, 1, 2], p=probs)

        if choice == 1:
            ant.direction += TURN_ANGLE
        elif choice == 2:
            ant.direction -= TURN_ANGLE

        ant.direction += np.random.uniform(-RANDOM_WIGGLE, RANDOM_WIGGLE)

        # --- MOVE --- (wrap around the grid edges)
        offset = np.array([np.cos(ant.direction), np.sin(ant.direction)]) * STEP_SIZE
        new_pos = ant.pos + offset
        ant.pos[0] = new_pos[0] % self.shape[0]
        ant.pos[1] = new_pos[1] % self.shape[1]

    def _canvas_xy(self, pos):
        """Map a grid (row, col) position to integer canvas (x, y) pixels.

        Canvas x is the column and y is the row, matching the row-major
        background image passed to ``put_image_data``.
        """
        return int(pos[1]) % self.shape[1], int(pos[0]) % self.shape[0]

    def _render_image(self):
        """Build the (rows, cols, 3) uint8 RGB background image.

        Green shows the "To Home" layer (0), blue shows the "To Food"
        layer (1), food cells are yellow, and nest cells are white.
        """
        img = np.zeros((self.shape[0], self.shape[1], 3), dtype=np.uint8)
        # Pheromone values are scaled by 50 so weak trails are visible.
        img[:, :, 1] = np.clip(self.pheromone[:, :, 0] * 50, 0, 255)
        img[:, :, 2] = np.clip(self.pheromone[:, :, 1] * 50, 0, 255)
        img[self.food_map > 0] = [255, 255, 0]
        img[self.nest_map > 0] = [255, 255, 255]
        return img

    def draw(self):
        """Render pheromones, food, nest, ants and the delivery counter."""
        with hold_canvas(self.canvas):
            self.canvas.put_image_data(self._render_image(), 0, 0)

            # Ants are 2x2 squares: pure red when carrying food, orange-red otherwise.
            for ant in self.ants:
                x, y = self._canvas_xy(ant.pos)
                color = (255, 0, 0) if ant.carrying_food else (244, 55, 0)
                self.canvas.fill_style = f'rgb{color}'
                self.canvas.fill_rect(x, y, 2, 2)

            self.canvas.fill_style = 'white'
            self.canvas.font = '16px sans-serif'
            self.canvas.fill_text(f'Food Home: {self.food_in_home}', 10, 20)

# --- SETUP ---
BLOB_RADIUS = 40  # radius (in cells) of the circular food and nest blobs


def stamp_disk(grid, center, radius):
    """Set every cell of ``grid`` within ``radius`` of ``center`` (row, col) to 1.

    The mask is built over the whole grid, so a disk that overlaps the edge is
    clipped rather than raising a shape-mismatch error. The disk is symmetric
    about its centre.
    """
    rows, cols = np.ogrid[:grid.shape[0], :grid.shape[1]]
    in_disk = (rows - center[0]) ** 2 + (cols - center[1]) ** 2 <= radius ** 2
    grid[in_disk] = 1


n_ants = 1000
sim = Simulation((600, 600), n_ants=n_ants)

# Circular blobs: food in the middle of the grid, nest toward the top-left.
stamp_disk(sim.food_map, (300, 300), BLOB_RADIUS)
stamp_disk(sim.nest_map, (100, 100), BLOB_RADIUS)

sim.ready()

# Runs until the kernel is interrupted. The interrupt is caught so that
# stopping the cell leaves the last frame on screen without a traceback.
try:
    while True:
        sim.step()
        sim.draw()
except KeyboardInterrupt:
    pass


Canvas(height=600, width=600)