In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import gymnasium as gym
from collections import deque
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class FootballEnv(gym.Env):
    def __init__(self, grid_rows=10, grid_cols=10):
        super(FootballEnv, self).__init__()
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols

        self.action_space = gym.spaces.Discrete(10)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(7,), dtype=np.float32)

        # Initialize field layout
        self.layout = np.zeros((grid_rows, grid_cols), dtype=str)
        self.layout[:, :] = "."
        self.layout[self.grid_rows//2, self.grid_cols//2] = "C"
        self.layout[self.grid_rows//2-8 : self.grid_rows//2+9, -6:-1] = "D"
        self.layout[self.grid_rows//2-8 : self.grid_rows//2+9, 0:5] = "d"
        self.layout[:, self.grid_cols//2] = "M"
        self.layout[:, -1] = "O"
        self.layout[:, 0] = "O"
        self.layout[0, :] = "O"
        self.layout[-1, :] = "O"
        self.layout[self.grid_rows//2-4 : self.grid_rows//2+5, -1] = "G"
        self.layout[self.grid_rows//2-4 : self.grid_rows//2+5, 0] = "g"

        # Initialize positions
        self.player_pos = (self.grid_rows//2, 5)
        self.ball_pos = (self.grid_rows//2, self.grid_cols//2)
        self.has_ball = False
        self.episode_steps = 0

        # Reward parameters
        self.goal_reward = 50
        self.step_penalty = -0.001
        self.ball_possession_bonus = 0.007
        self.near_ball_bonus = 0.00001
        self.near_goal_bonus = 0.00002

    def _get_state(self):
        # Enhanced state representation including has_ball flag
        return np.array([
            self.player_pos[0] / (self.grid_rows - 1),
            self.player_pos[1] / (self.grid_cols - 1),
            self.ball_pos[0] / (self.grid_rows - 1),
            self.ball_pos[1] / (self.grid_cols - 1),
            float(self.has_ball),  # Add has_ball as explicit state feature
            self.grid_rows // 2 / (self.grid_rows - 1),  # Goal Y position
            (self.grid_cols - 1) / (self.grid_cols - 1)  # Goal X position
        ], dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        
        # Start with player on left side of field
        player_row = random.randint(self.grid_rows//2-3, self.grid_rows//2+3)
        player_col = random.randint(1, self.grid_cols//4)
        self.player_pos = (player_row, player_col)
        
        # Place ball near player to start
        ball_row = random.randint(max(1, player_row-3), min(self.grid_rows-2, player_row+3))
        ball_col = random.randint(max(1, player_col-3), min(self.grid_cols//3, player_col+3))
        self.ball_pos = (ball_row, ball_col)
        
        self.has_ball = (self.player_pos == self.ball_pos)
        self.episode_steps = 0
        return self._get_state(), {}

    def step(self, action):
        self.episode_steps += 1
        reward = self.step_penalty
        done = False

        if action < 8:  # Movement actions
            dx, dy = [(0, -1), (1, 0), (0, 1), (-1, 0), (-1, -1), (-1, 1), (1, -1), (1, 1)][action]
            new_pos = (self.player_pos[0] + dx, self.player_pos[1] + dy)

            if 0 <= new_pos[0] < self.grid_rows and 0 <= new_pos[1] < self.grid_cols and self.layout[new_pos] != "O":
                self.player_pos = new_pos
                if self.has_ball:
                    self.ball_pos = new_pos
            else:
                reward -= 20  # Massive penalty
                done = True  # End episode if the player goes out
                return self._get_state(), reward, done, True, {}

        elif action == 8 and self.has_ball:  # Long shot
            goal_y_center = self.grid_rows // 2
            # Better aim toward goal
            target_y = min(max(goal_y_center + random.randint(-2, 2), 0), self.grid_rows-1)
            new_ball_col = min(self.ball_pos[1] + 10, self.grid_cols - 1)
            self.ball_pos = (target_y, new_ball_col)
            self.has_ball = False

        elif action == 9 and self.has_ball:  # Short pass
            new_ball_col = min(self.ball_pos[1] + 5, self.grid_cols - 1)
            self.ball_pos = (self.ball_pos[0], new_ball_col)
            self.has_ball = False

        # Check if player gets the ball
        self.has_ball = self.player_pos == self.ball_pos
        
        # Reward shaping
        if self.has_ball:
            reward += self.ball_possession_bonus

        if self.has_ball and self.player_pos[1] > self.grid_cols//2:
            reward += 0.001
        
        # Distance-based rewards
        dist_to_ball = np.sqrt((self.player_pos[0] - self.ball_pos[0])**2 + 
                              (self.player_pos[1] - self.ball_pos[1])**2)
        if dist_to_ball < 5 and not self.has_ball:
            reward += self.near_ball_bonus
        
        # Reward for moving toward goal with ball
        if self.has_ball:
            # Calculate distance to goal
            dist_to_goal = self.grid_cols - 1 - self.player_pos[1]
            if dist_to_goal < 10:
                reward += self.near_goal_bonus * (1 - dist_to_goal/10.0)
        
        # Goal reward
        if self.layout[self.ball_pos[0], self.ball_pos[1]] == 'G':
            reward += self.goal_reward
            done = True

        truncated = self.episode_steps >= 3000  # Shorter episodes
        return self._get_state(), reward, done, truncated, {}

    def render(self):
        grid = np.full((self.grid_rows, self.grid_cols), '-')
        
        # Draw field elements
        for i in range(self.grid_rows):
            for j in range(self.grid_cols):
                if self.layout[i,j] == 'O':
                    grid[i,j] = '#'
                elif self.layout[i,j] == 'G':
                    grid[i,j] = '|'
                elif self.layout[i,j] == 'M':
                    grid[i,j] = '.'
        
        # Draw player and ball
        grid[self.player_pos[0], self.player_pos[1]] = 'P'
        if not self.has_ball:
            grid[self.ball_pos[0], self.ball_pos[1]] = 'o'
        
        # Print the grid
        print('-' * (self.grid_cols + 2))
        for row in grid:
            print('|' + ''.join(row) + '|')
        print('-' * (self.grid_cols + 2))
        print(f"Has ball: {self.has_ball}, Steps: {self.episode_steps}")


class DQN(nn.Module):
    def __init__(self, input_dim, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, action_size)
        
        # Initialize weights with better defaults
        for layer in [self.fc1, self.fc2, self.fc3, self.fc4]:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = 0.99  # Discount factor
        self.epsilon = 1.0  # Exploration rate
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995
        self.learning_rate = 0.0003
        self.memory = deque(maxlen=100000)
        self.batch_size = 256
        self.target_update_freq = 5  # Update target network every N episodes
        
        self.device = device
        self.model = DQN(state_size, action_size).to(self.device)
        self.target_model = DQN(state_size, action_size).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        
        self.rewards_history = []
        self.episode_count = 0

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state, evaluate=False):
        if not evaluate and np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.model(state)
        return torch.argmax(q_values).item()

    def replay(self):
        if len(self.memory) < self.batch_size:
            return 0
            
        # Sample minibatch from memory
        minibatch = random.sample(self.memory, self.batch_size)
        
        states = torch.FloatTensor([experience[0] for experience in minibatch]).to(self.device)
        actions = torch.LongTensor([experience[1] for experience in minibatch]).to(self.device)
        rewards = torch.FloatTensor([experience[2] for experience in minibatch]).to(self.device)
        next_states = torch.FloatTensor([experience[3] for experience in minibatch]).to(self.device)
        dones = torch.FloatTensor([experience[4] for experience in minibatch]).to(self.device)
        
        # Current Q values
        curr_q_values = self.model(states).gather(1, actions.unsqueeze(1))
        
        # Target Q values
        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(1)[0]
        
        # Compute target
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        
        # Compute loss
        loss = F.mse_loss(curr_q_values.squeeze(), target_q_values)
        
        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  
        self.optimizer.step()
        
        return loss.item()

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def decay_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def save(self, filepath):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'episode_count': self.episode_count,
            'rewards_history': self.rewards_history
        }, filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath):
        checkpoint = torch.load(filepath)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.target_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
        self.episode_count = checkpoint['episode_count']
        self.rewards_history = checkpoint['rewards_history']
        print(f"Model loaded from {filepath}")

    def train(self, env, episodes, max_steps=2000, save_freq=50, render_freq=20):
        for episode in range(episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False
            truncated = False
            losses = []
            
            for step in range(max_steps):
                action = self.act(state)
                next_state, reward, done, truncated, _ = env.step(action)
                total_reward += reward
                
                # Store experience in memory
                self.remember(state, action, reward, next_state, done or truncated)
                
                # Train model with replay
                if len(self.memory) >= self.batch_size:
                    loss = self.replay()
                    losses.append(loss)
                
                state = next_state
                
                if done or truncated:
                    break
                    
            # Update target network periodically
            if episode % self.target_update_freq == 0:
                self.update_target_model()
                
            # Decay exploration rate
            self.decay_epsilon()
            
            # Record stats
            self.episode_count += 1
            self.rewards_history.append(total_reward)
            
            # Print episode statistics
            avg_loss = np.mean(losses) if losses else 0
            # print(f"Episode {episode}: Reward = {total_reward:.1f}, Steps = {step+1}, Epsilon = {self.epsilon:.3f}, Avg Loss = {avg_loss:.5f}")
            
            # Save the model periodically
            if episode > 0 and episode % save_freq == 0:
                self.save(f"dqn_football_ep{episode}.pth")

            # if episode % 10 == 0:
            #     avg_reward = np.mean(self.rewards_history[-10:])
            #     print(f"Last 10 episodes average reward: {avg_reward:.2f}")
                
            # Render occasionally to see progress
            # if episode % render_freq == 0:
            #     print(f"\n--- Episode {episode} Rendering ---")
            #     test_env = FootballEnv()
            #     self.evaluate(test_env, render=True)
                
    def evaluate(self, env, episodes=1, render=True):
        total_rewards = []
        
        for episode in range(episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False
            truncated = False
            
            while not done and not truncated:
                action = self.act(state, evaluate=True)  # No exploration
                next_state, reward, done, truncated, _ = env.step(action)
                total_reward += reward
                
                if render:
                    print(action)
                    env.render()
                    time.sleep(0.5)  # Pause to make rendering visible
                    
                state = next_state
                
            total_rewards.append(total_reward)
            print(f"Evaluation episode {episode}: Reward = {total_reward}")
            
        return np.mean(total_rewards)

In [3]:
# Create environment and agent
env = FootballEnv()
agent = DQNAgent(state_size=7, action_size=8)  # Updated state size
# Train the agent
agent.train(env, episodes=3000, save_freq=500)
print("Done")
# Save the final model
agent.save("dqn_football_final.pth")

  and should_run_async(code)
  states = torch.FloatTensor([experience[0] for experience in minibatch]).to(self.device)


Model saved to dqn_football_ep500.pth
Model saved to dqn_football_ep1000.pth
Model saved to dqn_football_ep1500.pth
Model saved to dqn_football_ep2000.pth
Model saved to dqn_football_ep2500.pth
Done
Model saved to dqn_football_final.pth


In [13]:
agent.evaluate(env, episodes=1, render=True)

5
------------
|##########|
|-----.---||
|-oP--.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-########||
------------
Has ball: False, Steps: 1
4
------------
|##########|
|-P---.---||
|-o---.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-########||
------------
Has ball: False, Steps: 2
2
------------
|##########|
|--P--.---||
|-o---.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-########||
------------
Has ball: False, Steps: 3
6
------------
|##########|
|-----.---||
|-P---.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-########||
------------
Has ball: True, Steps: 4
2
------------
|##########|
|-----.---||
|--P--.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-----.---||
|-########||
------------
Has ball: True, Steps: 5
2
------------
|##########|
|-----.---||
|---P-.---||
|-----.---||
|-----.---||
|-

50.055138

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import gymnasium as gym
from collections import deque
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Player:
    def __init__(self, team, tackle_probability=0.5):
        """
        Initialize a player with team assignment and tackle capability
        
        Args:
            team: Either 'home' or 'away'
            tackle_probability: Chance of successful tackle when attempting
        """
        self.team = team
        self.tackle_probability = tackle_probability
        self.position = None
        self.has_ball = False
        
    def reset_position(self, row, col):
        """Set player position on the field"""
        self.position = (row, col)
        
    def move(self, dx, dy, field_layout, grid_rows, grid_cols):
        """
        Move player by the given delta if the move is valid
        
        Returns:
            bool: Whether the move was successful
        """
        new_pos = (self.position[0] + dx, self.position[1] + dy)
        
        if (0 <= new_pos[0] < grid_rows and 
            0 <= new_pos[1] < grid_cols and 
            field_layout[new_pos] != "O"):
            self.position = new_pos
            return True
        return False
        
    def attempt_tackle(self):
        """
        Attempt to tackle the opponent and get the ball
        
        Returns:
            bool: Whether the tackle was successful
        """
        return random.random() < self.tackle_probability


class FootballEnv(gym.Env):
    def __init__(self, grid_rows=10, grid_cols=10):
        super(FootballEnv, self).__init__()
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols

        # Action space now includes tackle
        self.action_space = gym.spaces.Discrete(11)  # 8 movement directions + long shot + short pass + tackle
        
        # Observation space includes opponent position
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.float32)

        # Initialize field layout
        self.layout = np.zeros((grid_rows, grid_cols), dtype=str)
        self.layout[:, :] = "."
        self.layout[self.grid_rows//2, self.grid_cols//2] = "C"
        self.layout[self.grid_rows//2-8 : self.grid_rows//2+9, -6:-1] = "D"
        self.layout[self.grid_rows//2-8 : self.grid_rows//2+9, 0:5] = "d"
        self.layout[:, self.grid_cols//2] = "M"
        self.layout[:, -1] = "O"
        self.layout[:, 0] = "O"
        self.layout[0, :] = "O"
        self.layout[-1, :] = "O"
        self.layout[self.grid_rows//2-4 : self.grid_rows//2+5, -1] = "G"
        self.layout[self.grid_rows//2-4 : self.grid_rows//2+5, 0] = "g"

        # Initialize players
        self.player = Player(team='home')
        self.opponent = Player(team='away')
        self.ball_pos = (self.grid_rows//2, self.grid_cols//2)
        self.episode_steps = 0
        
        # Score tracking
        self.home_score = 0
        self.away_score = 0

        # Reward parameters
        self.goal_reward = 50
        self.goal_against_penalty = -30
        self.step_penalty = -0.001
        self.ball_possession_bonus = 0.007
        self.near_ball_bonus = 0.00001
        self.near_goal_bonus = 0.00002
        self.successful_tackle_bonus = 0.05
        self.unsuccessful_tackle_penalty = -0.02

    def _get_state(self):
        # Enhanced state representation including opponent position
        return np.array([
            self.player.position[0] / (self.grid_rows - 1),
            self.player.position[1] / (self.grid_cols - 1),
            self.opponent.position[0] / (self.grid_rows - 1),
            self.opponent.position[1] / (self.grid_cols - 1),
            self.ball_pos[0] / (self.grid_rows - 1),
            self.ball_pos[1] / (self.grid_cols - 1),
            float(self.player.has_ball),  # Add has_ball as explicit state feature
            float(self.opponent.has_ball),  # Add opponent has_ball
            self.grid_rows // 2 / (self.grid_rows - 1)  # Goal Y position
        ], dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        
        # Start with home player on left side of field
        player_row = random.randint(self.grid_rows//2-3, self.grid_rows//2+3)
        player_col = random.randint(1, self.grid_cols//4)
        self.player.reset_position(player_row, player_col)
        
        # Start with away player on right side of field
        opponent_row = random.randint(self.grid_rows//2-3, self.grid_rows//2+3)
        opponent_col = random.randint(self.grid_cols*3//4, self.grid_cols-2)
        self.opponent.reset_position(opponent_row, opponent_col)
        
        # Place ball in center
        self.ball_pos = (self.grid_rows//2, self.grid_cols//2)
        
        # Randomize initial ball possession
        if random.random() < 0.5:
            self.player.has_ball = True
            self.opponent.has_ball = False
            self.ball_pos = self.player.position
        else:
            self.opponent.has_ball = True
            self.player.has_ball = False
            self.ball_pos = self.opponent.position
        
        self.episode_steps = 0
        self.home_score = 0
        self.away_score = 0
        
        return self._get_state(), {}

    def _opponent_action(self):
        """Simple AI for the opponent"""
        
        # If opponent has the ball, move toward home goal
        if self.opponent.has_ball:
            # Calculate possible moves toward home goal (left side)
            possible_moves = []
            
            # Try to move toward home goal (left side)
            if self.opponent.position[1] > 0:
                possible_moves.append((-1, -1))  # Diagonal up-left
                possible_moves.append((0, -1))   # Left
                possible_moves.append((1, -1))   # Diagonal down-left
            
            # Try vertical moves if in good position horizontally
            if self.opponent.position[1] < self.grid_cols // 2:
                if self.opponent.position[0] > self.grid_rows // 2:
                    possible_moves.append((-1, 0))  # Up
                else:
                    possible_moves.append((1, 0))   # Down
            
            # If in shooting range, attempt shot
            if self.opponent.position[1] < 3 and random.random() < 0.7:
                # Attempt shot at goal
                target_y = min(max(self.grid_rows // 2 + random.randint(-2, 2), 0), self.grid_rows-1)
                self.ball_pos = (target_y, 0)
                self.opponent.has_ball = False
                return
            
            # If no moves available or random choice, try random move
            if not possible_moves or random.random() < 0.1:
                possible_moves = [(-1, -1), (-1, 0), (-1, 1), (0, -1), 
                                  (0, 1), (1, -1), (1, 0), (1, 1)]
            
            # Try moves until a valid one is found
            random.shuffle(possible_moves)
            for dx, dy in possible_moves:
                if self.opponent.move(dx, dy, self.layout, self.grid_rows, self.grid_cols):
                    self.ball_pos = self.opponent.position  # Ball moves with opponent
                    break
                    
        # If opponent doesn't have the ball, try to get it
        else:
            # If close to player with ball, attempt tackle
            if (self.player.has_ball and 
                abs(self.opponent.position[0] - self.player.position[0]) <= 1 and
                abs(self.opponent.position[1] - self.player.position[1]) <= 1):
                
                if self.opponent.attempt_tackle():
                    # Successful tackle
                    self.opponent.has_ball = True
                    self.player.has_ball = False
                    self.ball_pos = self.opponent.position
                    return
                    
            # Move toward the ball
            dx = 0 if self.opponent.position[0] == self.ball_pos[0] else (
                  1 if self.opponent.position[0] < self.ball_pos[0] else -1)
            dy = 0 if self.opponent.position[1] == self.ball_pos[1] else (
                  1 if self.opponent.position[1] < self.ball_pos[1] else -1)
            
            # Try to move toward ball
            if dx != 0 or dy != 0:
                self.opponent.move(dx, dy, self.layout, self.grid_rows, self.grid_cols)
                
            # Check if opponent got the ball
            if self.opponent.position == self.ball_pos and not self.player.has_ball:
                self.opponent.has_ball = True

    def step(self, action):
        self.episode_steps += 1
        reward = self.step_penalty
        done = False
        info = {'scored': False, 'conceded': False}

        # Process player action
        if action < 8:  # Movement actions
            dx, dy = [(0, -1), (1, 0), (0, 1), (-1, 0), (-1, -1), (-1, 1), (1, -1), (1, 1)][action]
            
            if self.player.move(dx, dy, self.layout, self.grid_rows, self.grid_cols):
                if self.player.has_ball:
                    self.ball_pos = self.player.position
            else:
                reward -= 0.1  # Small penalty for invalid move attempt
        
        elif action == 8 and self.player.has_ball:  # Long shot
            goal_y_center = self.grid_rows // 2
            # Better aim toward goal
            target_y = min(max(goal_y_center + random.randint(-2, 2), 0), self.grid_rows-1)
            new_ball_col = min(self.ball_pos[1] + 10, self.grid_cols - 1)
            self.ball_pos = (target_y, new_ball_col)
            self.player.has_ball = False

        elif action == 9 and self.player.has_ball:  # Short pass
            new_ball_col = min(self.ball_pos[1] + 5, self.grid_cols - 1)
            self.ball_pos = (self.ball_pos[0], new_ball_col)
            self.player.has_ball = False
            
        elif action == 10:  # Tackle
            # Check if player is near opponent with ball
            if (self.opponent.has_ball and 
                abs(self.player.position[0] - self.opponent.position[0]) <= 1 and
                abs(self.player.position[1] - self.opponent.position[1]) <= 1):
                
                if self.player.attempt_tackle():
                    # Successful tackle
                    self.player.has_ball = True
                    self.opponent.has_ball = False
                    self.ball_pos = self.player.position
                    reward += self.successful_tackle_bonus
                else:
                    # Failed tackle
                    reward += self.unsuccessful_tackle_penalty

        # Check if player gets the ball (if it's free)
        if self.player.position == self.ball_pos and not self.opponent.has_ball:
            self.player.has_ball = True
        
        # Process opponent action
        self._opponent_action()
        
        # Reward shaping
        if self.player.has_ball:
            reward += self.ball_possession_bonus

        if self.player.has_ball and self.player.position[1] > self.grid_cols//2:
            reward += 0.001
        
        # Distance-based rewards
        if not self.player.has_ball and not self.opponent.has_ball:
            dist_to_ball = np.sqrt((self.player.position[0] - self.ball_pos[0])**2 + 
                                  (self.player.position[1] - self.ball_pos[1])**2)
            if dist_to_ball < 5:
                reward += self.near_ball_bonus
        
        # Reward for moving toward goal with ball
        if self.player.has_ball:
            # Calculate distance to goal
            dist_to_goal = self.grid_cols - 1 - self.player.position[1]
            if dist_to_goal < 10:
                reward += self.near_goal_bonus * (1 - dist_to_goal/10.0)
        
        # Check for goals
        # Home team scores (player)
        if self.layout[self.ball_pos[0], self.ball_pos[1]] == 'G':
            reward += self.goal_reward
            self.home_score += 1
            info['scored'] = True
            # Reset positions but keep score
            self.reset()
            
        # Away team scores (opponent)
        elif self.layout[self.ball_pos[0], self.ball_pos[1]] == 'g':
            reward += self.goal_against_penalty
            self.away_score += 1
            info['conceded'] = True
            # Reset positions but keep score
            self.reset()
            
        # End game if either team scores 3 goals
        if self.home_score >= 3 or self.away_score >= 3:
            done = True

        truncated = self.episode_steps >= 2000  # Shorter episodes
        
        # Add score to info
        info['home_score'] = self.home_score
        info['away_score'] = self.away_score
        
        return self._get_state(), reward, done, truncated, info

    def render(self):
        grid = np.full((self.grid_rows, self.grid_cols), '-')
        
        # Draw field elements
        for i in range(self.grid_rows):
            for j in range(self.grid_cols):
                if self.layout[i,j] == 'O':
                    grid[i,j] = '#'
                elif self.layout[i,j] == 'G':
                    grid[i,j] = '|'
                elif self.layout[i,j] == 'g':
                    grid[i,j] = '|'
                elif self.layout[i,j] == 'M':
                    grid[i,j] = '.'
        
        # Draw player, opponent and ball
        grid[self.player.position[0], self.player.position[1]] = 'P'
        grid[self.opponent.position[0], self.opponent.position[1]] = 'E'
        
        # Draw ball if not possessed
        if not self.player.has_ball and not self.opponent.has_ball:
            grid[self.ball_pos[0], self.ball_pos[1]] = 'o'
        
        # Print the grid
        print('-' * (self.grid_cols + 2))
        for row in grid:
            print('|' + ''.join(row) + '|')
        print('-' * (self.grid_cols + 2))
        print(f"Score: Home {self.home_score} - {self.away_score} Away")
        print(f"Player has ball: {self.player.has_ball}, Opponent has ball: {self.opponent.has_ball}")
        print(f"Steps: {self.episode_steps}")


class DQN(nn.Module):
    def __init__(self, input_dim, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, action_size)
        
        # Initialize weights with better defaults
        for layer in [self.fc1, self.fc2, self.fc3, self.fc4]:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = 0.99  # Discount factor
        self.epsilon = 1.0  # Exploration rate
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995
        self.learning_rate = 0.0003
        self.memory = deque(maxlen=100000)
        self.batch_size = 256
        self.target_update_freq = 5  # Update target network every N episodes
        
        self.device = device
        self.model = DQN(state_size, action_size).to(self.device)
        self.target_model = DQN(state_size, action_size).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        
        self.rewards_history = []
        self.episode_count = 0
        self.goals_scored = 0
        self.goals_conceded = 0

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state, evaluate=False):
        if not evaluate and np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.model(state)
        return torch.argmax(q_values).item()

    def replay(self):
        if len(self.memory) < self.batch_size:
            return 0
            
        # Sample minibatch from memory
        minibatch = random.sample(self.memory, self.batch_size)
        
        states = torch.FloatTensor([experience[0] for experience in minibatch]).to(self.device)
        actions = torch.LongTensor([experience[1] for experience in minibatch]).to(self.device)
        rewards = torch.FloatTensor([experience[2] for experience in minibatch]).to(self.device)
        next_states = torch.FloatTensor([experience[3] for experience in minibatch]).to(self.device)
        dones = torch.FloatTensor([experience[4] for experience in minibatch]).to(self.device)
        
        # Current Q values
        curr_q_values = self.model(states).gather(1, actions.unsqueeze(1))
        
        # Target Q values
        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(1)[0]
        
        # Compute target
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        
        # Compute loss
        loss = F.mse_loss(curr_q_values.squeeze(), target_q_values)
        
        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  
        self.optimizer.step()
        
        return loss.item()

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def decay_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def save(self, filepath):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'episode_count': self.episode_count,
            'rewards_history': self.rewards_history,
            'goals_scored': self.goals_scored,
            'goals_conceded': self.goals_conceded
        }, filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath):
        checkpoint = torch.load(filepath)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.target_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
        self.episode_count = checkpoint['episode_count']
        self.rewards_history = checkpoint['rewards_history']
        self.goals_scored = checkpoint.get('goals_scored', 0)
        self.goals_conceded = checkpoint.get('goals_conceded', 0)
        print(f"Model loaded from {filepath}")

    def train(self, env, episodes, max_steps=2000, save_freq=50, render_freq=20):
        for episode in range(episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False
            truncated = False
            losses = []
            
            for step in range(max_steps):
                action = self.act(state)
                next_state, reward, done, truncated, info = env.step(action)
                total_reward += reward
                
                # Store experience in memory
                self.remember(state, action, reward, next_state, done or truncated)
                
                # Track goals
                if info.get('scored', False):
                    self.goals_scored += 1
                if info.get('conceded', False):
                    self.goals_conceded += 1
                
                # Train model with replay
                if len(self.memory) >= self.batch_size:
                    loss = self.replay()
                    losses.append(loss)
                
                state = next_state
                
                if done or truncated:
                    break
                    
            # Update target network periodically
            if episode % self.target_update_freq == 0:
                self.update_target_model()
                
            # Decay exploration rate
            self.decay_epsilon()
            
            # Record stats
            self.episode_count += 1
            self.rewards_history.append(total_reward)
            
            # Print episode statistics
            avg_loss = np.mean(losses) if losses else 0
            if episode % 10 == 0:
                print(f"Episode {episode}: Reward = {total_reward:.1f}, Steps = {step+1}, Epsilon = {self.epsilon:.3f}, Avg Loss = {avg_loss:.5f}")
                print(f"Goals: {self.goals_scored} scored, {self.goals_conceded} conceded")
            
            # # Save the model periodically
            # if episode > 0 and episode % save_freq == 0:
            #     self.save(f"dqn_football_ep{episode}.pth")

            # if episode % 10 == 0:
            #     avg_reward = np.mean(self.rewards_history[-10:])
            #     print(f"Last 10 episodes average reward: {avg_reward:.2f}")
                
            # # Render occasionally to see progress
            # if episode % render_freq == 0:
            #     print(f"\n--- Episode {episode} Rendering ---")
            #     test_env = FootballEnv()
            #     self.evaluate(test_env, render=True)
                
    def evaluate(self, env, episodes=1, render=True):
        total_rewards = []
        goals_scored = 0
        goals_conceded = 0
        
        for episode in range(episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False
            truncated = False
            
            while not done and not truncated:
                action = self.act(state, evaluate=True)  # No exploration
                next_state, reward, done, truncated, info = env.step(action)
                total_reward += reward
                
                # Track goals during evaluation
                if info.get('scored', False):
                    goals_scored += 1
                if info.get('conceded', False):
                    goals_conceded += 1
                
                if render:
                    print(f"Action: {action}")
                    env.render()
                    time.sleep(0.5)  # Pause to make rendering visible
                    
                state = next_state
                
            total_rewards.append(total_reward)
            print(f"Evaluation episode {episode}: Reward = {total_reward}")
            
        print(f"Evaluation goals: {goals_scored} scored, {goals_conceded} conceded")
        return np.mean(total_rewards)

# Example usage
if __name__ == "__main__":
    # Create environment and agent
    env = FootballEnv()
    agent = DQNAgent(state_size=9, action_size=11)  # Updated state size and action size
    # Train the agent
    agent.train(env, episodes=3000, save_freq=500)
    # Save the final model
    agent.save("dqn_football_final.pth")
    print("Training complete")

Episode 0: Reward = -5695.9, Steps = 2000, Epsilon = 0.995, Avg Loss = 53.67758
Goals: 43 scored, 261 conceded
Episode 10: Reward = -3764.8, Steps = 2000, Epsilon = 0.946, Avg Loss = 275.91888
Goals: 628 scored, 2630 conceded
Episode 20: Reward = -4267.5, Steps = 2000, Epsilon = 0.900, Avg Loss = 384.31072
Goals: 1296 scored, 4994 conceded
Episode 30: Reward = -3210.7, Steps = 2000, Epsilon = 0.856, Avg Loss = 440.78407
Goals: 2017 scored, 7284 conceded
Episode 40: Reward = -2308.1, Steps = 2000, Epsilon = 0.814, Avg Loss = 351.68972
Goals: 2801 scored, 9539 conceded
Episode 50: Reward = -659.3, Steps = 2000, Epsilon = 0.774, Avg Loss = 279.67012
Goals: 3950 scored, 11724 conceded
Episode 60: Reward = -156.1, Steps = 2000, Epsilon = 0.737, Avg Loss = 320.72654
Goals: 5178 scored, 13885 conceded
Episode 70: Reward = 1525.1, Steps = 2000, Epsilon = 0.701, Avg Loss = 360.78508
Goals: 6521 scored, 15968 conceded
Episode 80: Reward = 1586.8, Steps = 2000, Epsilon = 0.666, Avg Loss = 333.314

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import gymnasium as gym
from collections import deque
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Player:
    def __init__(self, steal_prob=0.3):
        self.pos = (0, 0)
        self.has_ball = False
        self.steal_prob = steal_prob
        self.team = None  # 'A' or 'B' for different teams

    def reset(self, pos, team):
        self.pos = pos
        self.has_ball = False
        self.team = team

    def attempt_steal(self, other):
        if random.random() < self.steal_prob:
            if other.has_ball:
                other.has_ball = False
                self.has_ball = True
                return True
        return False

class FootballEnv(gym.Env):
    def __init__(self, grid_rows=10, grid_cols=10):
        super(FootballEnv, self).__init__()
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols

        self.action_space = gym.spaces.Discrete(10)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(9,), dtype=np.float32)

        # Initialize field layout
        self.layout = np.zeros((grid_rows, grid_cols), dtype=str)
        self._initialize_layout()
        
        # Initialize players and ball
        self.player = Player(steal_prob=0.4)
        self.opponent = Player(steal_prob=0.35)
        self.ball_pos = (grid_rows//2, grid_cols//2)
        self.episode_steps = 0

        # Reward parameters
        self.goal_reward = 50
        self.step_penalty = -0.01
        self.ball_possession_bonus = 0.1
        self.near_ball_bonus = 0.001
        self.concede_penalty = -30

    def _initialize_layout(self):
        self.layout[:, :] = "."
        mid_row, mid_col = self.grid_rows//2, self.grid_cols//2
        
        # Goals and center
        self.layout[mid_row-4:mid_row+5, -1] = "G"
        self.layout[mid_row-4:mid_row+5, 0] = "g"
        self.layout[:, mid_col] = "M"
        
        # Field boundaries
        self.layout[[0, -1], :] = "O"
        self.layout[:, [0, -1]] = "O"
        self.layout[0, mid_col-4:mid_col+5] = "."
        self.layout[-1, mid_col-4:mid_col+5] = "."

    def _get_state(self, agent):
        if agent == 'player':
            player = self.player
            opponent = self.opponent
            goal_x = self.grid_cols - 1
        else:
            player = self.opponent
            opponent = self.player
            goal_x = 0

        return np.array([
            player.pos[0] / (self.grid_rows-1),
            player.pos[1] / (self.grid_cols-1),
            -opponent.pos[0]/(self.grid_rows-1),  # Negative for opponent
            -opponent.pos[1]/(self.grid_cols-1),
            self.ball_pos[0]/(self.grid_rows-1),
            self.ball_pos[1]/(self.grid_cols-1),
            float(player.has_ball),
            goal_x/(self.grid_cols-1),
            (self.grid_rows//2)/(self.grid_rows-1)
        ], dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        
        # Initialize positions
        self.player.reset(
            (random.randint(2, self.grid_rows-3), 
            random.randint(2, self.grid_cols//4)),
            'A'
        )
        self.opponent.reset(
            (random.randint(2, self.grid_rows-3), 
            random.randint(3*self.grid_cols//4, self.grid_cols-3)),
            'B'
        )
        
        # Place ball near player
        self.ball_pos = (
            random.randint(max(1, self.player.pos[0]-2), 
            min(self.grid_rows-2, self.player.pos[0]+2)),
            random.randint(max(1, self.player.pos[1]-2),
            min(self.grid_cols//3, self.player.pos[1]+2))
        )
        self.player.has_ball = (self.player.pos == self.ball_pos)
        self.opponent.has_ball = False
        self.episode_steps = 0
        
        return self._get_state('player'), {}

    def _move_agent(self, agent, action):
        dx, dy = 0, 0
        if action < 8:
            moves = [(0,-1),(1,0),(0,1),(-1,0),
                    (-1,-1),(-1,1),(1,-1),(1,1)]
            dx, dy = moves[action]
        
        new_pos = (agent.pos[0]+dx, agent.pos[1]+dy)
        if 0 <= new_pos[0] < self.grid_rows and 0 <= new_pos[1] < self.grid_cols:
            if self.layout[new_pos] != 'O':
                agent.pos = new_pos
                if agent.has_ball:
                    self.ball_pos = new_pos
                return True
        return False

    def step(self, action):
        self.episode_steps += 1
        p_reward = o_reward = self.step_penalty
        done = truncated = False

        # Player action
        player_valid = self._move_agent(self.player, action)
        if not player_valid:
            p_reward -= 2.0

        # Opponent action (random policy for initial training)
        opponent_action = random.randint(0, 9)
        opponent_valid = self._move_agent(self.opponent, opponent_action)
        if not opponent_valid:
            o_reward -= 2.0

        # Ball handling
        ball_carrier = None
        if self.player.has_ball:
            ball_carrier = self.player
        elif self.opponent.has_ball:
            ball_carrier = self.opponent

        # Check ball possession
        if self.player.pos == self.ball_pos and self.opponent.pos == self.ball_pos:
            if self.player.attempt_steal(self.opponent):
                p_reward += 0.5
                o_reward -= 0.5
            elif self.opponent.attempt_steal(self.player):
                o_reward += 0.5
                p_reward -= 0.5
        else:
            self.player.has_ball = (self.player.pos == self.ball_pos)
            self.opponent.has_ball = (self.opponent.pos == self.ball_pos)

        # Shooting/passing
        if action == 8 and self.player.has_ball:  # Long shot
            self.ball_pos = (self.ball_pos[0], min(self.ball_pos[1]+8, self.grid_cols-1))
            self.player.has_ball = False
        elif action == 9 and self.player.has_ball:  # Short pass
            self.ball_pos = (self.ball_pos[0], min(self.ball_pos[1]+4, self.grid_cols-1))
            self.player.has_ball = False

        # Goal scoring
        if self.layout[self.ball_pos] == 'G':
            p_reward += self.goal_reward
            o_reward += self.concede_penalty
            done = True
        elif self.layout[self.ball_pos] == 'g':
            o_reward += self.goal_reward
            p_reward += self.concede_penalty
            done = True

        # Additional rewards
        if self.player.has_ball:
            p_reward += self.ball_possession_bonus
            p_reward += (self.player.pos[1] / self.grid_cols) * 0.01  # Forward progress
        if self.opponent.has_ball:
            o_reward += self.ball_possession_bonus

        # Distance-based rewards
        p_dist = abs(self.player.pos[0]-self.ball_pos[0]) + abs(self.player.pos[1]-self.ball_pos[1])
        o_dist = abs(self.opponent.pos[0]-self.ball_pos[0]) + abs(self.opponent.pos[1]-self.ball_pos[1])
        p_reward += self.near_ball_bonus / (p_dist + 1)
        o_reward += self.near_ball_bonus / (o_dist + 1)

        truncated = self.episode_steps >= 1500
        return (
            self._get_state('player'), 
            {'player_reward': p_reward, 'opponent_reward': o_reward},
            done,
            truncated,
            {}
        )

    def render(self):
        grid = np.full((self.grid_rows, self.grid_cols), ' ')
        
        # Draw field
        for i in range(self.grid_rows):
            for j in range(self.grid_cols):
                if self.layout[i,j] == 'O': grid[i,j] = '#'
                elif self.layout[i,j] in ['G','g']: grid[i,j] = '|'
        
        # Draw actors
        grid[self.ball_pos] = 'o'
        grid[self.player.pos] = 'P' if self.player.has_ball else 'p'
        grid[self.opponent.pos] = 'O' if self.opponent.has_ball else 'o'
        
        # Print
        print('\n' + '='*30)
        for row in grid:
            print('|' + ''.join(row) + '|')
        print(f"Steps: {self.episode_steps}")


class DQN(nn.Module):
    def __init__(self, input_dim, action_size):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_size))
        
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        return self.net(x)


class MultiAgentDQN:
    def __init__(self, state_size, action_size):
        self.agents = {
            'player': DQNAgent(state_size, action_size),
            'opponent': DQNAgent(state_size, action_size)
        }

    def train(self, env, episodes):
        for ep in range(episodes):
            state, _ = env.reset()
            total_rewards = {'player': 0, 'opponent': 0}
            
            while True:
                # Get actions from both agents
                p_action = self.agents['player'].act(state)
                next_state, rewards, done, truncated, _ = env.step(p_action)
                
                # Opponent acts based on its own state
                o_state = env._get_state('opponent')
                o_action = self.agents['opponent'].act(o_state)
                next_o_state = env._get_state('opponent')
                
                # Store experiences
                self.agents['player'].remember(
                    state, p_action, rewards['player_reward'], 
                    next_state, done or truncated
                )
                self.agents['opponent'].remember(
                    o_state, o_action, rewards['opponent_reward'],
                    next_o_state, done or truncated
                )
                
                # Train both agents
                p_loss = self.agents['player'].replay()
                o_loss = self.agents['opponent'].replay()
                
                # Update tracking
                total_rewards['player'] += rewards['player_reward']
                total_rewards['opponent'] += rewards['opponent_reward']
                state = next_state
                
                if done or truncated:
                    break
                
            # Post-episode updates
            for agent in self.agents.values():
                agent.decay_epsilon()
                agent.update_target_model()
                
            if ep % 50 == 0:
                print(f"Episode {ep}: Player R={total_rewards['player']:.1f}, Opponent R={total_rewards['opponent']:.1f}")
                self.agents['player'].save(f"player_ep{ep}.pth")
                self.agents['opponent'].save(f"opponent_ep{ep}.pth")


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=100000)
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.997
        self.batch_size = 128
        self.model = DQN(state_size, action_size).to(device)
        self.target_model = DQN(state_size, action_size).to(device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0002)
        self.update_freq = 5

    def act(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size-1)
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        return self.model(state).argmax().item()

    def remember(self, *args):
        self.memory.append(args)

    def replay(self):
        if len(self.memory) < self.batch_size:
            return 0
            
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        states = torch.FloatTensor(states).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(next_states).to(device)
        dones = torch.FloatTensor(dones).to(device)
        
        curr_q = self.model(states).gather(1, actions.unsqueeze(1))
        next_q = self.target_model(next_states).max(1)[0].detach()
        target = rewards + (1 - dones) * self.gamma * next_q
        
        loss = F.mse_loss(curr_q.squeeze(), target)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()
        
        return loss.item()

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon_min, self.epsilon*self.epsilon_decay)

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def save(self, path):
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        self.model.load_state_dict(torch.load(path))
        self.target_model.load_state_dict(self.model.state_dict())


# Training setup
if __name__ == "__main__":
    env = FootballEnv(grid_rows=12, grid_cols=16)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    
    multi_agent = MultiAgentDQN(state_size, action_size)
    multi_agent.train(env, episodes=1000)

  states = torch.FloatTensor(states).to(device)


Episode 0: Player R=-287.8, Opponent R=-298.5
Episode 50: Player R=-293.6, Opponent R=-322.9


KeyboardInterrupt: 

In [3]:
print("Hello")

Hello


  and should_run_async(code)


In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import gymnasium as gym
from collections import deque
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Player:
    def __init__(self, steal_prob=0.3):
        self.pos = (0, 0)
        self.has_ball = False
        self.steal_prob = steal_prob
        self.team = None  # 'A' or 'B' for different teams

    def reset(self, pos, team):
        self.pos = pos
        self.has_ball = False
        self.team = team

    def attempt_steal(self, other):
        if random.random() < self.steal_prob:
            if other.has_ball:
                other.has_ball = False
                self.has_ball = True
                return True
        return False

class FootballEnv(gym.Env):
    def __init__(self, grid_rows=10, grid_cols=10):
        super(FootballEnv, self).__init__()
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols

        self.action_space = gym.spaces.Discrete(10)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(9,), dtype=np.float32)

        # Initialize field layout
        self.layout = np.zeros((grid_rows, grid_cols), dtype=str)
        self._initialize_layout()
        
        # Initialize players and ball
        self.player = Player(steal_prob=0.4)
        self.opponent = Player(steal_prob=0.35)
        self.ball_pos = (grid_rows//2, grid_cols//2)
        self.episode_steps = 0

        # Reward parameters
        self.goal_reward = 50
        self.step_penalty = -0.01
        self.ball_possession_bonus = 0.1
        self.near_ball_bonus = 0.001
        self.concede_penalty = -30

    def _initialize_layout(self):
        self.layout[:, :] = "."
        mid_row, mid_col = self.grid_rows//2, self.grid_cols//2
        
        # Goals and center
        self.layout[mid_row-4:mid_row+5, -1] = "G"
        self.layout[mid_row-4:mid_row+5, 0] = "g"
        self.layout[:, mid_col] = "M"
        
        # Field boundaries
        self.layout[[0, -1], :] = "O"
        self.layout[:, [0, -1]] = "O"
        self.layout[0, mid_col-4:mid_col+5] = "."
        self.layout[-1, mid_col-4:mid_col+5] = "."

    def _get_state(self, agent):
        if agent == 'player':
            player = self.player
            opponent = self.opponent
            goal_x = self.grid_cols - 1
        else:
            player = self.opponent
            opponent = self.player
            goal_x = 0

        return np.array([
            player.pos[0] / (self.grid_rows-1),
            player.pos[1] / (self.grid_cols-1),
            -opponent.pos[0]/(self.grid_rows-1),  # Negative for opponent
            -opponent.pos[1]/(self.grid_cols-1),
            self.ball_pos[0]/(self.grid_rows-1),
            self.ball_pos[1]/(self.grid_cols-1),
            float(player.has_ball),
            goal_x/(self.grid_cols-1),
            (self.grid_rows//2)/(self.grid_rows-1)
        ], dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        
        # Initialize positions
        self.player.reset(
            (random.randint(2, self.grid_rows-3), 
            random.randint(2, self.grid_cols//4)),
            'A'
        )
        self.opponent.reset(
            (random.randint(2, self.grid_rows-3), 
            random.randint(3*self.grid_cols//4, self.grid_cols-3)),
            'B'
        )
        
        # Place ball near player
        self.ball_pos = (
            random.randint(max(1, self.player.pos[0]-2), 
            min(self.grid_rows-2, self.player.pos[0]+2)),
            random.randint(max(1, self.player.pos[1]-2),
            min(self.grid_cols//3, self.player.pos[1]+2))
        )
        self.player.has_ball = (self.player.pos == self.ball_pos)
        self.opponent.has_ball = False
        self.episode_steps = 0
        
        return self._get_state('player'), {}

    def _move_agent(self, agent, action):
        dx, dy = 0, 0
        if action < 8:
            moves = [(0,-1),(1,0),(0,1),(-1,0),
                    (-1,-1),(-1,1),(1,-1),(1,1)]
            dx, dy = moves[action]
        
        new_pos = (agent.pos[0]+dx, agent.pos[1]+dy)
        if 0 <= new_pos[0] < self.grid_rows and 0 <= new_pos[1] < self.grid_cols:
            if self.layout[new_pos] != 'O':
                agent.pos = new_pos
                if agent.has_ball:
                    self.ball_pos = new_pos
                return True
        return False

    def step(self, action):
        self.episode_steps += 1
        p_reward = o_reward = self.step_penalty
        done = truncated = False

        # Player action
        player_valid = self._move_agent(self.player, action)
        if not player_valid:
            p_reward -= 2.0

        # Opponent action (random policy for initial training)
        opponent_action = random.randint(0, 9)
        opponent_valid = self._move_agent(self.opponent, opponent_action)
        if not opponent_valid:
            o_reward -= 2.0

        # Ball handling
        ball_carrier = None
        if self.player.has_ball:
            ball_carrier = self.player
        elif self.opponent.has_ball:
            ball_carrier = self.opponent

        # Check ball possession
        if self.player.pos == self.ball_pos and self.opponent.pos == self.ball_pos:
            if self.player.attempt_steal(self.opponent):
                p_reward += 0.5
                o_reward -= 0.5
            elif self.opponent.attempt_steal(self.player):
                o_reward += 0.5
                p_reward -= 0.5
        else:
            self.player.has_ball = (self.player.pos == self.ball_pos)
            self.opponent.has_ball = (self.opponent.pos == self.ball_pos)

        # Shooting/passing
        if action == 8 and self.player.has_ball:  # Long shot
            self.ball_pos = (self.ball_pos[0], min(self.ball_pos[1]+8, self.grid_cols-1))
            self.player.has_ball = False
        elif action == 9 and self.player.has_ball:  # Short pass
            self.ball_pos = (self.ball_pos[0], min(self.ball_pos[1]+4, self.grid_cols-1))
            self.player.has_ball = False

        # Goal scoring
        if self.layout[self.ball_pos] == 'G':
            p_reward += self.goal_reward
            o_reward += self.concede_penalty
            done = True
        elif self.layout[self.ball_pos] == 'g':
            o_reward += self.goal_reward
            p_reward += self.concede_penalty
            done = True

        # Additional rewards
        if self.player.has_ball:
            p_reward += self.ball_possession_bonus
            p_reward += (self.player.pos[1] / self.grid_cols) * 0.01  # Forward progress
        if self.opponent.has_ball:
            o_reward += self.ball_possession_bonus

        # Distance-based rewards
        p_dist = abs(self.player.pos[0]-self.ball_pos[0]) + abs(self.player.pos[1]-self.ball_pos[1])
        o_dist = abs(self.opponent.pos[0]-self.ball_pos[0]) + abs(self.opponent.pos[1]-self.ball_pos[1])
        p_reward += self.near_ball_bonus / (p_dist + 1)
        o_reward += self.near_ball_bonus / (o_dist + 1)

        truncated = self.episode_steps >= 1500
        return (
            self._get_state('player'), 
            {'player_reward': p_reward, 'opponent_reward': o_reward},
            done,
            truncated,
            {}
        )

    def render(self):
        grid = np.full((self.grid_rows, self.grid_cols), ' ')
        
        # Draw field
        for i in range(self.grid_rows):
            for j in range(self.grid_cols):
                if self.layout[i,j] == 'O': grid[i,j] = '#'
                elif self.layout[i,j] in ['G','g']: grid[i,j] = '|'
        
        # Draw actors
        grid[self.ball_pos] = 'o'
        grid[self.player.pos] = 'P' if self.player.has_ball else 'p'
        grid[self.opponent.pos] = 'O' if self.opponent.has_ball else 'o'
        
        # Print
        print('\n' + '='*30)
        for row in grid:
            print('|' + ''.join(row) + '|')
        print(f"Steps: {self.episode_steps}")


class DQN(nn.Module):
    def __init__(self, input_dim, action_size):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_size))
        
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        return self.net(x)


class MultiAgentDQN:
    def __init__(self, state_size, action_size):
        self.agents = {
            'player': DQNAgent(state_size, action_size),
            'opponent': DQNAgent(state_size, action_size)
        }
        self.train_scores = {'player': [], 'opponent': []}
        
    def train(self, env, episodes, render_every=100):
        for ep in range(episodes):
            state, _ = env.reset()
            total_rewards = {'player': 0, 'opponent': 0}
            
            while True:
                # Get actions from both agents
                p_action = self.agents['player'].act(state)
                next_state, rewards, done, truncated, _ = env.step(p_action)
                
                # Opponent acts based on its own state
                o_state = env._get_state('opponent')
                o_action = self.agents['opponent'].act(o_state)
                next_o_state = env._get_state('opponent')
                
                # Store experiences
                self.agents['player'].remember(
                    state, p_action, rewards['player_reward'], 
                    next_state, done or truncated
                )
                self.agents['opponent'].remember(
                    o_state, o_action, rewards['opponent_reward'],
                    next_o_state, done or truncated
                )
                
                # Train both agents
                p_loss = self.agents['player'].replay()
                o_loss = self.agents['opponent'].replay()
                
                # Update tracking
                total_rewards['player'] += rewards['player_reward']
                total_rewards['opponent'] += rewards['opponent_reward']
                state = next_state
                
                # Render if requested
                # if ep % render_every == 0:
                #     env.render()
                #     time.sleep(0.1)  # Slow down rendering
                
                if done or truncated:
                    break
                
            # Post-episode updates
            for agent in self.agents.values():
                agent.decay_epsilon()
                agent.update_target_model()
                
            # Track scores
            self.train_scores['player'].append(total_rewards['player'])
            self.train_scores['opponent'].append(total_rewards['opponent'])
                
            if ep % 10 == 0:
                print(f"Episode {ep}: Player R={total_rewards['player']:.1f}, Opponent R={total_rewards['opponent']:.1f}")
                self.agents['player'].save(f"player_ep{ep}.pth")
                self.agents['opponent'].save(f"opponent_ep{ep}.pth")

    def evaluate(self, env, episodes=5, render=True):
        """
        Evaluate trained agents without exploration
        """
        total_rewards = {'player': 0, 'opponent': 0}
        goals_scored = {'player': 0, 'opponent': 0}
        
        for ep in range(episodes):
            state, _ = env.reset()
            ep_rewards = {'player': 0, 'opponent': 0}
            done = truncated = False
            
            print(f"\n==== EVALUATION EPISODE {ep+1} ====")
            
            step = 0
            while not (done or truncated):
                # Get action without exploration
                p_action = self.agents['player'].act_eval(state)
                next_state, rewards, done, truncated, _ = env.step(p_action)
                
                # Update rewards
                ep_rewards['player'] += rewards['player_reward']
                ep_rewards['opponent'] += rewards['opponent_reward']
                state = next_state
                
                # Render
                if render and step % 5 == 0:  # Render every 5 steps to speed up visualization
                    env.render()
                    time.sleep(0.2)  # Slow down rendering
                
                step += 1
            
            # Track goals
            if done:
                if ep_rewards['player'] > 0:
                    goals_scored['player'] += 1
                else:
                    goals_scored['opponent'] += 1
            
            # Final render
            if render:
                env.render()
            
            # Track total rewards
            total_rewards['player'] += ep_rewards['player']
            total_rewards['opponent'] += ep_rewards['opponent']
            
            print(f"Episode {ep+1} Results:")
            print(f"  Player reward: {ep_rewards['player']:.1f}")
            print(f"  Opponent reward: {ep_rewards['opponent']:.1f}")
            print(f"  Steps: {step}")
            
        # Summary
        avg_p_reward = total_rewards['player'] / episodes
        avg_o_reward = total_rewards['opponent'] / episodes
        
        print("\n==== EVALUATION SUMMARY ====")
        print(f"Episodes: {episodes}")
        print(f"Player avg reward: {avg_p_reward:.1f}")
        print(f"Opponent avg reward: {avg_o_reward:.1f}")
        print(f"Player goals: {goals_scored['player']}")
        print(f"Opponent goals: {goals_scored['opponent']}")
        
        return avg_p_reward, avg_o_reward, goals_scored


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=100000)
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.997
        self.batch_size = 128
        self.model = DQN(state_size, action_size).to(device)
        self.target_model = DQN(state_size, action_size).to(device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0002)
        self.update_freq = 5

    def act(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size-1)
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        return self.model(state).argmax().item()
    
    def act_eval(self, state):
        """Act without exploration for evaluation"""
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        return self.model(state).argmax().item()

    def remember(self, *args):
        self.memory.append(args)

    def replay(self):
        if len(self.memory) < self.batch_size:
            return 0
        
        # Sample batch
        batch = random.sample(self.memory, self.batch_size)
        
        # Convert batch of tuples to tuple of lists
        states, actions, rewards, next_states, dones = zip(*batch)
        
        # Convert lists to numpy arrays first (fixing the warning)
        states = np.array(states)
        actions = np.array(actions)
        rewards = np.array(rewards)
        next_states = np.array(next_states)
        dones = np.array(dones)
        
        # Convert numpy arrays to tensors
        states = torch.FloatTensor(states).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(next_states).to(device)
        dones = torch.FloatTensor(dones).to(device)
        
        curr_q = self.model(states).gather(1, actions.unsqueeze(1))
        next_q = self.target_model(next_states).max(1)[0].detach()
        target = rewards + (1 - dones) * self.gamma * next_q
        
        loss = F.mse_loss(curr_q.squeeze(), target)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()
        
        return loss.item()

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon_min, self.epsilon*self.epsilon_decay)

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def save(self, path):
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        self.model.load_state_dict(torch.load(path))
        self.target_model.load_state_dict(self.model.state_dict())


# Training setup with evaluation
if __name__ == "__main__":
    env = FootballEnv(grid_rows=12, grid_cols=16)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    
    multi_agent = MultiAgentDQN(state_size, action_size)
    
    # Training with rendering every 100 episodes
    multi_agent.train(env, episodes=1000, render_every=1001)
    
    # Final evaluation after training
    print("\nRunning final evaluation...")
    multi_agent.evaluate(env, episodes=1, render=True)
    
    # Plot learning curves
    try:
        import matplotlib.pyplot as plt
        
        # Moving average to smooth the curves
        def moving_average(data, window_size=10):
            return np.convolve(data, np.ones(window_size)/window_size, mode='valid')
        
        # Plot rewards
        plt.figure(figsize=(12, 6))
        
        # Get moving averages
        window = 20
        if len(multi_agent.train_scores['player']) > window:
            p_ma = moving_average(multi_agent.train_scores['player'], window)
            o_ma = moving_average(multi_agent.train_scores['opponent'], window)
            
            plt.plot(range(window-1, len(p_ma)+window-1), p_ma, label='Player')
            plt.plot(range(window-1, len(o_ma)+window-1), o_ma, label='Opponent')
            plt.xlabel('Episode')
            plt.ylabel('Reward')
            plt.title('Training Rewards (Moving Average)')
            plt.legend()
            plt.grid(True)
            plt.savefig('football_training_rewards.png')
            plt.show()
            
    except ImportError:
        print("Matplotlib not available. Skipping reward plots.")

Episode 0: Player R=-258.2, Opponent R=-182.8
Episode 10: Player R=-269.6, Opponent R=-335.5
Episode 20: Player R=-270.5, Opponent R=-192.0
Episode 30: Player R=-255.3, Opponent R=-302.7
Episode 40: Player R=-233.0, Opponent R=-230.7
Episode 50: Player R=-179.3, Opponent R=-260.2
Episode 60: Player R=-299.2, Opponent R=-248.7
Episode 70: Player R=-328.1, Opponent R=-312.9
Episode 80: Player R=-363.1, Opponent R=-270.8
Episode 90: Player R=-236.7, Opponent R=-404.9
Episode 100: Player R=-355.9, Opponent R=-310.8
Episode 110: Player R=-280.7, Opponent R=-237.2
Episode 120: Player R=-330.7, Opponent R=-382.8
Episode 130: Player R=-402.1, Opponent R=-280.7
Episode 140: Player R=-269.7, Opponent R=-262.9
Episode 150: Player R=-303.7, Opponent R=-264.8
Episode 160: Player R=-272.6, Opponent R=-262.9
Episode 170: Player R=-440.7, Opponent R=-236.8
Episode 180: Player R=-262.9, Opponent R=-382.8
Episode 190: Player R=-269.0, Opponent R=-268.8
Episode 200: Player R=-275.5, Opponent R=-296.9
Epi

KeyboardInterrupt: 