# In [None]:
import pygame
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import matplotlib.animation as animation
from IPython.display import HTML, display
import ipywidgets as widgets

# 1. Neural Network
class DQN(nn.Module):
    """Fully connected Q-network: state -> 64 -> 64 -> one value per action."""

    def __init__(self, state_dim, action_dim):
        """Create the three linear layers.

        Args:
            state_dim: Number of input features.
            action_dim: Number of outputs (one Q-value per action).
        """
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_dim)

    def forward(self, x):
        """Return the Q-values for ``x``; ReLU follows each hidden layer."""
        activations = x
        for hidden_layer in (self.fc1, self.fc2):
            activations = F.relu(hidden_layer(activations))
        # The output layer stays linear so Q-values are unbounded.
        return self.fc3(activations)

# 2. Enhanced Agent
class AnimalSafeAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    The policy network is trained on random minibatches drawn from
    ``memory``; the target network is overwritten with the policy weights
    every ``update_target_freq`` optimisation steps.
    """

    def __init__(self, action_dim, state_dim):
        """Build both networks, the optimiser and the hyper-parameters.

        Args:
            action_dim: Number of discrete actions.
            state_dim: Length of the state vector fed to the network.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1e-4)  # Lower learning rate
        self.memory = deque(maxlen=10000)
        self.batch_size = 128  # Increased batch size
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.998  # Applied once per optimisation step in update()
        self.update_target_freq = 100
        self.steps = 0
        self.episode_rewards = []
        self.episode_accuracies = []

    def _as_tensor(self, values, dtype):
        """Convert a sequence/array to a tensor of ``dtype`` on the agent's device."""
        return torch.as_tensor(np.asarray(values), dtype=dtype, device=self.device)

    def act(self, state):
        """Pick an action index for ``state``.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest predicted Q-value.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        state = self._as_tensor(state, torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state)
        return q_values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """Run one optimisation step on a random minibatch.

        Does nothing until the buffer holds at least ``batch_size``
        transitions. Each performed step also decays ``epsilon`` and, every
        ``update_target_freq`` steps, syncs the target network.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._as_tensor(states, torch.float32)
        actions = self._as_tensor(actions, torch.long)
        rewards = self._as_tensor(rewards, torch.float32)
        next_states = self._as_tensor(next_states, torch.float32)
        dones = self._as_tensor(dones, torch.float32)

        # squeeze(1) removes only the gather column. A bare squeeze() would
        # also drop the batch axis when batch_size == 1 and make mse_loss
        # broadcast a 0-d prediction against a 1-d target.
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0]

        # Terminal transitions (done == 1) do not bootstrap from next_q.
        target_q = rewards + (1 - dones) * self.gamma * next_q

        loss = F.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        self.steps += 1
        if self.steps % self.update_target_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def log_episode(self, reward, accuracy):
        """Record the total reward and accuracy of a finished episode."""
        self.episode_rewards.append(reward)
        self.episode_accuracies.append(accuracy)

# 3. Enhanced Environment
class EnhancedDrivingEnv:
    def __init__(self):
        """Initialise pygame, the scene geometry, the action table and the car sprite.

        The first scenario is set up via ``reset()`` during construction.
        """
        pygame.init()

        # Off-screen canvas; render() converts it to an array.
        self.width, self.height = 800, 500
        self.screen = pygame.Surface((self.width, self.height))
        self.road_width = 300
        self.car_width, self.car_height = 40, 70
        self.animal_size = 35

        # Episode bookkeeping; reset() advances `episode` from 0 to 1.
        self.episode = 0
        self.max_episodes = 5
        self.episode_length = 350
        self.current_step = 0
        self.reset()

        # Actions: [hard left, left, straight, right, hard right]
        # Each entry is a (steer, speed) pair.
        hard_left, left = (-0.8, 0.5), (-0.4, 0.5)
        straight = (0.0, 0.8)  # normal speed
        right, hard_right = (0.4, 0.5), (0.8, 0.5)
        self.action_space = [hard_left, left, straight, right, hard_right]
        self.action_dim = len(self.action_space)

        # Pre-render car with black wheels
        car_w, car_h = self.car_width, self.car_height
        sprite = pygame.Surface((car_w, car_h), pygame.SRCALPHA)
        pygame.draw.rect(sprite, (0, 100, 255), (5, 0, car_w - 10, car_h - 20))
        for wheel_x in (0, car_w - 15):
            pygame.draw.ellipse(sprite, (0, 0, 0), (wheel_x, car_h - 25, 15, 20))
        # Semi-transparent windscreen drawn last so it sits on top of the body.
        pygame.draw.rect(sprite, (200, 200, 250, 150), (10, 10, car_w - 20, 20))
        self.car_img = sprite

    def _create_animal(self, color):
        """Draw an animal sprite (body, head, four legs) tinted with ``color``.

        Args:
            color: RGB triple; the head uses the same colour darkened by 30
                per channel.

        Returns:
            A square ``pygame.Surface`` of side ``self.animal_size``.
        """
        size = self.animal_size
        red, green, blue = color[0], color[1], color[2]
        body_color = tuple(min(255, channel) for channel in (red, green, blue))
        head_color = tuple(max(0, channel - 30) for channel in (red, green, blue))

        sprite = pygame.Surface((size, size), pygame.SRCALPHA)
        pygame.draw.ellipse(sprite, body_color, (5, 5, size - 10, size - 10))
        pygame.draw.circle(sprite, head_color, (size - 10, size // 2), 8)

        # Four 5x10 legs hanging from the bottom edge of the sprite.
        leg_width, leg_height = 5, 10
        for leg_x in (10, 25, size - 15, size - 30):
            pygame.draw.rect(sprite, body_color, (leg_x, size - 10, leg_width, leg_height))
        return sprite

    def _create_plant(self):
        """Return a 20x40 transparent surface showing a stem with two leaves."""
        stem_green = (34, 139, 34)
        leaf_green = (50, 205, 50)

        plant = pygame.Surface((20, 40), pygame.SRCALPHA)
        pygame.draw.rect(plant, stem_green, (8, 15, 4, 25))
        for leaf_box in ((0, 10, 16, 12), (10, 0, 12, 16)):
            pygame.draw.ellipse(plant, leaf_green, leaf_box)
        return plant

    def reset(self, episode=None):
        if episode is not None:
            self.episode = episode
        else:
            self.episode = (self.episode % self.max_episodes) + 1

        self.current_step = 0
        self.car_pos = [self.width//2, self.height-100]
        self.car_angle = 0
        self.speed = 2.0
        self.animals = []
        self.plants = []
        self.animal_hits = 0
        self.total_animals = 0
        self.hit_animals = set()
        self.episode_reward = 0

        if self.episode == 1:
            for _ in range(3):
                self._add_animal(crossing_prob=0.3)
            self.episode_name = "Few animals"
        elif self.episode == 2:
            for _ in range(8):
                self._add_animal(crossing_prob=0.8)
            self.episode_name = "Many animals"
        elif self.episode == 3:
            for i in range(3):
                x = self.width//2 - self.road_width//4 + (i * self.road_width//3)
                self._add_animal_at_position(x, -100, crossing_prob=0.7)
            self.episode_name = "Animals in groups"
        elif self.episode == 4:
            for _ in range(5):
                animal = self._add_animal(crossing_prob=0.6)
                if animal['crossing']:
                    animal['speed'] = 1.5
            self.episode_name = "Fast moving animals"
        elif self.episode == 5:
            for i in range(6):
                crossing = i % 2 == 0
                self._add_animal(crossing_prob=1.0 if crossing else 0.0)
            self.episode_name = "Mixed scenario"

        for y in range(-100, self.height+100, 60):
            self.plants.append([self.width//2 - self.road_width//2 - 30, y])
            self.plants.append([self.width//2 + self.road_width//2 + 10, y])

        return self._get_state()

    def _add_animal(self, crossing_prob=0.5):
        animal_id = random.getrandbits(64)
        x = random.randint(self.width//2 - self.road_width//2 + 40,
                          self.width//2 + self.road_width//2 - 40)
        y = random.randint(-300, -50)
        crossing = random.random() < crossing_prob
        color = (
            random.randint(100, 150),
            random.randint(50, 100),
            random.randint(0, 50)
        )
        animal = {
            'id': animal_id,
            'x': x,
            'y': y,
            'crossing': crossing,
            'color': color,
            'img': self._create_animal(color),
            'speed': 1.0,
            'direction': random.choice([-1, 1])
        }
        if crossing:
            self.total_animals += 1
        self.animals.append(animal)
        return animal

    def _add_animal_at_position(self, x, y, crossing_prob=0.5):
        animal_id = random.getrandbits(64)
        crossing = random.random() < crossing_prob
        color = (
            random.randint(100, 150),
            random.randint(50, 100),
            random.randint(0, 50)
        )
        animal = {
            'id': animal_id,
            'x': x,
            'y': y,
            'crossing': crossing,
            'color': color,
            'img': self._create_animal(color),
            'speed': 1.0,
            'direction': random.choice([-1, 1])
        }
        if crossing:
            self.total_animals += 1
        self.animals.append(animal)
        return animal

    def _get_state(self):
        car_x_offset = (self.car_pos[0] - self.width//2) / (self.road_width//2)

        crossing_animals = [a for a in self.animals if a['crossing']]

        distances = []
        min_dist = self.width
        min_speed = 1.0
        for animal in crossing_animals:
            dist = abs(animal['x'] - self.car_pos[0]) + abs(animal['y'] - self.car_pos[1])
            distances.append((dist, animal))
            if dist < min_dist:
                min_dist = dist
                min_speed = animal.get('speed', 1.0)

        distances.sort(key=lambda x: x[0])
        nearest_animals = distances[:2]

        state = [
            car_x_offset,
            self.speed/4.0,
            1.0, 0.0, 0.0,
            1.0, 0.0, 0.0,
            min_dist / self.width,  # Closest animal distance
            min_speed / 2.0         # Closest animal speed
        ]

        if nearest_animals:
            animal1 = nearest_animals[0][1]
            state[2:5] = [
                (animal1['x'] - self.car_pos[0]) / self.width,
                (self.car_pos[1] - animal1['y']) / self.height,
                1.0
            ]

            if len(nearest_animals) > 1:
                animal2 = nearest_animals[1][1]
                state[5:8] = [
                    (animal2['x'] - self.car_pos[0]) / self.width,
                    (self.car_pos[1] - animal2['y']) / self.height,
                    1.0
                ]

        return np.array(state)

    def step(self, action_idx):
        """Advance the simulation by one tick.

        Args:
            action_idx: Index into ``self.action_space``.

        Returns:
            ``(state, reward, done, info)``. ``info`` is empty until the
            episode ends, when it carries ``'accuracy'`` (percentage of
            crossing animals not hit) and ``'reward'`` (episode total).
        """
        self.current_step += 1

        # NOTE(review): the second component of the action is never applied,
        # so self.speed is not modified anywhere in this method.
        steer, _speed_component = self.action_space[action_idx]

        # Move the car: sideways according to steering, and steadily upwards.
        self.car_angle = steer * 0.8
        self.car_pos[0] += self.car_angle * (self.speed + 0.5)
        self.car_pos[1] -= self.speed * 0.5

        # Road boundaries for animals
        left_limit = self.width // 2 - self.road_width // 2 + 20
        right_limit = self.width // 2 + self.road_width // 2 - 20

        for creature in self.animals:
            if creature['crossing']:
                # Crossing animals walk sideways and turn around at the limits.
                moved_x = creature['x'] + creature['direction'] * 3 * creature.get('speed', 1.0)
                if moved_x < left_limit:
                    moved_x = left_limit
                    creature['direction'] *= -1
                elif moved_x > right_limit:
                    moved_x = right_limit
                    creature['direction'] *= -1
                creature['x'] = moved_x

            # Every animal drifts down the screen.
            creature['y'] += self.speed * 0.7 * creature.get('speed', 1.0)

        # 5% chance per tick of a new animal appearing.
        if random.random() < 0.05:
            self._add_animal()

        # Remove off-screen animals
        bottom_cutoff = self.height + 50
        self.animals = [a for a in self.animals if a['y'] <= bottom_cutoff]

        # Car hit box, shrunk by 5 px on every side.
        car_rect = pygame.Rect(
            self.car_pos[0] - self.car_width // 2 + 5,
            self.car_pos[1] - self.car_height // 2 + 5,
            self.car_width - 10,
            self.car_height - 10,
        )

        animal_hit = False
        half_animal = self.animal_size // 2
        # Iterate over a snapshot because hit animals are removed (and
        # replacements spawned) inside the loop.
        crossing_snapshot = [a for a in self.animals if a['crossing']]
        for creature in crossing_snapshot:
            creature_rect = pygame.Rect(
                creature['x'] - half_animal,
                creature['y'] - half_animal,
                self.animal_size,
                self.animal_size,
            )
            if not car_rect.colliderect(creature_rect):
                continue

            if creature['id'] not in self.hit_animals:
                animal_hit = True
                self.animal_hits += 1
                self.hit_animals.add(creature['id'])
            self.animals.remove(creature)
            self._add_animal()

        # Enhanced Reward System
        reward = 0.05  # Base reward for surviving

        if animal_hit:
            reward -= 10  # Strong penalty for hitting animals
        else:
            animal_nearby = any(
                a['crossing'] and (self.car_pos[1] - a['y']) < 150 for a in self.animals
            )
            if animal_nearby:
                # Reward slowing down near animals
                # NOTE(review): speed starts at 2.0 in reset() and is not
                # changed here, so this bonus looks unreachable - confirm.
                reward += 0.1 * (self.speed < 1.5)
            # Small reward for every frame without hits
            reward += 0.02

        self.episode_reward += reward

        left_the_road = abs(self.car_pos[0] - self.width // 2) > self.road_width // 2 + 10
        out_of_time = self.current_step >= self.episode_length
        done = left_the_road or out_of_time

        info = {}
        if done:
            # Calculate accuracy when episode ends
            saved_animals = max(0, self.total_animals - self.animal_hits)
            if self.total_animals > 0:
                accuracy = (saved_animals / self.total_animals) * 100
            else:
                accuracy = 100
            info = {'accuracy': accuracy, 'reward': self.episode_reward}

        return self._get_state(), reward, done, info

    def render(self):
        """Draw the current scene and return it as an image array.

        Draw order: ground, road, lane lines, plants, crossing animals (with a
        yellow highlight box), car, status text and - when a crossing animal
        is within 150 px ahead-distance - a warning label.

        Returns:
            A numpy copy of the screen pixels with the first two axes swapped
            relative to pygame's surfarray layout, so it can be passed to
            ``imshow``.
        """
        self.screen.fill((139, 69, 19))

        pygame.draw.rect(self.screen, (50, 50, 50),
                        (self.width//2 - self.road_width//2, 0,
                         self.road_width, self.height))

        for i in range(-1, 2):
            x = self.width//2 + i * self.road_width//3
            pygame.draw.line(self.screen, (255, 255, 255),
                           (x, 0), (x, self.height), 2)

        # Every plant uses the same sprite, so build it once per frame rather
        # than once per plant.
        plant_img = self._create_plant()
        for plant in self.plants:
            plant_y = plant[1] % (self.height + 200) - 100
            if plant_y < self.height:
                self.screen.blit(plant_img, (plant[0], plant_y))

        half_animal = self.animal_size // 2
        for animal in [a for a in self.animals if a['crossing']]:
            pygame.draw.rect(self.screen, (255, 255, 0),
                           (animal['x'] - half_animal - 5,
                            animal['y'] - half_animal - 5,
                            self.animal_size + 10,
                            self.animal_size + 10), 2)
            self.screen.blit(animal['img'],
                           (animal['x'] - half_animal,
                            animal['y'] - half_animal))

        rotated_car = pygame.transform.rotate(self.car_img, -self.car_angle*8)
        car_rect = rotated_car.get_rect(center=self.car_pos)
        self.screen.blit(rotated_car, car_rect)

        # Constructing a SysFont is costly; create it on the first frame and
        # reuse it afterwards.
        font = getattr(self, '_hud_font', None)
        if font is None:
            font = self._hud_font = pygame.font.SysFont('Arial', 20)

        saved_animals = max(0, self.total_animals - self.animal_hits)
        accuracy = (saved_animals / self.total_animals) * 100 if self.total_animals > 0 else 100

        info_text = (f"Ep {self.episode}: {self.episode_name} | "
                    f"Step: {self.current_step}/{self.episode_length} | "
                    f"Animals: {self.total_animals} | "
                    f"Hit: {self.animal_hits} | "
                    f"Accuracy: {accuracy:.1f}%")

        text_surface = font.render(info_text, True, (255, 255, 255))
        self.screen.blit(text_surface, (10, 10))

        if any(a['crossing'] and (self.car_pos[1] - a['y']) < 150 for a in self.animals):
            warning = font.render("ANIMAL CROSSING!", True, (255, 50, 50))
            self.screen.blit(warning, (self.width//2 - 80, 20))

        # np.array() copies the pixels so the surface view is released at once.
        return np.transpose(np.array(pygame.surfarray.pixels3d(self.screen)), (1, 0, 2))

# 4. Interactive Simulation with Performance Tracking
class InteractiveSimulation:
    """Notebook front-end that animates the environment while the agent learns.

    The top axes show the rendered scene, the bottom axes plot per-episode
    accuracy (left scale) and total reward (right scale). Four buttons switch
    scenario, restart the current one, or run ten training episodes without
    rendering.
    """

    def __init__(self):
        """Create the environment, agent, figure, buttons and first animation."""
        self.env = EnhancedDrivingEnv()
        self.agent = AnimalSafeAgent(self.env.action_dim, self.env._get_state().shape[0])
        self.fig, (self.ax1, self.ax2) = plt.subplots(2, 1, figsize=(10, 10))
        self.img = self.ax1.imshow(self.env.render())
        self.ax1.axis('off')

        # Performance plots
        self.ax2.set_xlabel('Episode')
        self.ax2.set_ylabel('Accuracy (%)', color='tab:blue')
        self.accuracy_line, = self.ax2.plot([], [], 'b-', label='Accuracy')
        self.ax2.tick_params(axis='y', labelcolor='tab:blue')

        self.ax2_r = self.ax2.twinx()
        self.ax2_r.set_ylabel('Reward', color='tab:red')
        self.reward_line, = self.ax2_r.plot([], [], 'r-', label='Reward')
        self.ax2_r.tick_params(axis='y', labelcolor='tab:red')

        self.ax2.legend(loc='upper left')
        self.ax2_r.legend(loc='upper right')

        self.episode_count = 0
        self.accuracies = []
        self.rewards = []

        self.prev_button = widgets.Button(description="Previous Episode")
        self.next_button = widgets.Button(description="Next Episode")
        self.reset_button = widgets.Button(description="Reset Current")
        self.train_button = widgets.Button(description="Train 10 Episodes")

        self.prev_button.on_click(self.prev_episode)
        self.next_button.on_click(self.next_episode)
        self.reset_button.on_click(self.reset_current)
        self.train_button.on_click(self.train_episodes)

        display(widgets.HBox([self.prev_button, self.reset_button, self.next_button, self.train_button]))

        self.ani = self._make_animation()
        # Close this specific figure rather than whichever one is current;
        # presumably this keeps the notebook from also showing a static copy.
        plt.close(self.fig)

    def _make_animation(self):
        """Return a new FuncAnimation that calls ``update`` once per frame."""
        return animation.FuncAnimation(
            self.fig, self.update,
            frames=self.env.episode_length,
            interval=50,
            blit=False
        )

    def _restart_animation(self):
        """Stop the running animation, build a new one and display it as JS HTML."""
        # event_source can be None once an animation has been stopped, so
        # guard the call instead of assuming a live timer.
        event_source = getattr(self.ani, 'event_source', None)
        if event_source is not None:
            event_source.stop()
        self.ani = self._make_animation()
        display(HTML(self.ani.to_jshtml()))

    def _record_episode(self, info):
        """Store a finished episode's metrics and return ``(accuracy, reward)``."""
        accuracy = info.get('accuracy', 0)
        total_reward = info.get('reward', 0)
        self.accuracies.append(accuracy)
        self.rewards.append(total_reward)
        self.episode_count += 1
        return accuracy, total_reward

    def _refresh_plots(self):
        """Push the recorded metrics into both curves and rescale their axes."""
        episodes = range(self.episode_count)
        self.accuracy_line.set_data(episodes, self.accuracies)
        self.reward_line.set_data(episodes, self.rewards)
        for axes in (self.ax2, self.ax2_r):
            axes.relim()
            axes.autoscale_view()

    def update(self, frame):
        """Animation callback: one environment step plus one learning step.

        Args:
            frame: Frame index supplied by FuncAnimation (unused).

        Returns:
            The artists that changed.
        """
        state = self.env._get_state()
        action = self.agent.act(state)
        next_state, reward, done, info = self.env.step(action)

        self.agent.remember(state, action, reward, next_state, done)
        self.agent.update()

        if done:
            self._record_episode(info)
            self._refresh_plots()
            # Move on to the next scenario so the animation keeps running.
            self.env.reset()

        self.img.set_array(self.env.render())
        return [self.img, self.accuracy_line, self.reward_line]

    def prev_episode(self, b):
        """Button handler: load the previous scenario (wrapping) and replay."""
        new_episode = self.env.episode - 1 if self.env.episode > 1 else self.env.max_episodes
        self.env.reset(new_episode)
        self._restart_animation()

    def next_episode(self, b):
        """Button handler: load the next scenario (wrapping) and replay."""
        new_episode = self.env.episode + 1 if self.env.episode < self.env.max_episodes else 1
        self.env.reset(new_episode)
        self._restart_animation()

    def reset_current(self, b):
        """Button handler: restart the current scenario and replay."""
        self.env.reset(self.env.episode)
        self._restart_animation()

    def train_episodes(self, b):
        """Button handler: run ten full episodes without rendering, then update the plots."""
        for _ in range(10):
            state = self.env.reset()
            done = False
            while not done:
                action = self.agent.act(state)
                next_state, reward, done, info = self.env.step(action)
                self.agent.remember(state, action, reward, next_state, done)
                self.agent.update()
                state = next_state

            accuracy, total_reward = self._record_episode(info)
            print(f"Episode {self.episode_count}: Accuracy = {accuracy:.1f}%, Reward = {total_reward:.1f}")

        # Update plots after training
        self._refresh_plots()
        # Redraw our own figure. plt.draw() acts on the current figure, and
        # since ours was closed in __init__ it would create a new empty one.
        self.fig.canvas.draw_idle()

# Run simulation
# Constructing InteractiveSimulation builds the environment, agent and figure
# and displays the control buttons. The instance is bound to the module-level
# name `sim`, which also keeps a reference to its animation object.
print("Starting enhanced simulation with animal avoidance training...")
sim = InteractiveSimulation()