In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import cv2
import time
import keyboard
from collections import deque
from image_processing.screen_capture import screen_capture
from image_processing.player_hp_detector import L_player_HP, R_player_HP
from characters.ken import Ken

In [None]:
# Define the DQN model architecture - using convolutional layers for image input
class DQN(nn.Module):
    """Convolutional Q-network: maps an image observation to one Q-value per action.

    Layout: conv(8x8, stride 4) -> conv(4x4, stride 2) -> conv(3x3, stride 1),
    each followed by ReLU, then a 512-unit ReLU dense layer and a linear output
    layer with `num_actions` units.

    Args:
        input_channels: channels per observation (1 for a grayscale frame).
        input_height: observation height in pixels.
        input_width: observation width in pixels.
        num_actions: size of the discrete action space (width of the output).
    """

    def __init__(self, input_channels, input_height, input_width, num_actions):
        super().__init__()
        # Feature extractor. The attribute names become the state_dict keys that
        # checkpoints are saved/loaded with, so they must stay as they are.
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        flat_features = self._get_conv_output_size(input_channels, input_height, input_width)

        # Q-value head (created after the convs, as before, so seeded init is unchanged).
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def _conv_features(self, x):
        """Apply the three ReLU-activated conv layers and return the 4-D feature map."""
        for layer in (self.conv1, self.conv2, self.conv3):
            x = F.relu(layer(x))
        return x

    def _get_conv_output_size(self, input_channels, height, width):
        """Number of flattened features the conv stack yields for one (C, H, W) input."""
        probe = torch.zeros(1, input_channels, height, width)
        return int(np.prod(self._conv_features(probe).size()))

    def forward(self, x):
        """Return raw Q-values of shape (batch, num_actions).

        `x` is laid out as (batch, channels, height, width). The output layer has no
        activation because Q-values are unbounded.
        """
        flat = self._conv_features(x).view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:

# Experience Replay Buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Backed by a deque with `maxlen=capacity`, so once full every push evicts the
    oldest transition.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition (dropping the oldest one if the buffer is full)."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones): states and next_states
            are stacked with `np.array`; actions, rewards and dones are tuples.

        Raises:
            ValueError: from `random.sample` when batch_size exceeds len(self).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = list(zip(*picked))
        return np.array(states), actions, rewards, np.array(next_states), dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [None]:

# DQN Agent
class DQNAgent:
    """Deep Q-learning agent: policy network, target network and experience replay.

    Args:
        input_channels, input_height, input_width: observation shape (C, H, W)
            passed to `DQN`.
        num_actions: number of discrete actions.
        lr: Adam learning rate.
        gamma: discount factor for future rewards.
        epsilon: initial exploration rate for epsilon-greedy action selection.
        epsilon_min: epsilon is no longer decayed once it is at or below this.
        epsilon_decay: multiplicative decay applied after every optimisation step
            (every `train` call that actually trains), not once per episode.
            NOTE(review): with 0.995 epsilon reaches 0.1 after ~460 optimisation
            steps, i.e. very early in training -- consider decaying per episode.
        buffer_size: replay buffer capacity.
        batch_size: minibatch size sampled from the replay buffer.
    """

    def __init__(self, input_channels, input_height, input_width, num_actions,
                 lr=0.0001, gamma=0.99, epsilon=1.0, epsilon_min=0.1,
                 epsilon_decay=0.995, buffer_size=10000, batch_size=32):
        self.num_actions = num_actions
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # exploration rate
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size

        # Pick the device first so the networks already live there when the
        # optimizer captures their parameters.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Q networks: the target net starts as an exact copy of the policy net.
        self.policy_net = DQN(input_channels, input_height, input_width, num_actions).to(self.device)
        self.target_net = DQN(input_channels, input_height, input_width, num_actions).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(buffer_size)

    def select_action(self, state):
        """Epsilon-greedy action: random with probability epsilon, else argmax Q(state, .).

        `state` is a single observation (no batch axis); one is added here.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.num_actions)

        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.policy_net(state)
        return q_values.max(1)[1].item()

    def train(self):
        """Run one optimisation step on a random minibatch from the replay buffer.

        Does nothing until the buffer holds at least `batch_size` transitions.
        Minimises the smooth-L1 (Huber) loss between Q(s, a) from the policy net
        and r + gamma * max_a' Q_target(s', a') * (1 - done), clamps every gradient
        element to [-1, 1], steps the optimizer, then decays epsilon.
        """
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)

        # Convert to PyTorch tensors
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Q(s_t, a_t) for the actions actually taken. squeeze(1), not squeeze(),
        # so a batch of size 1 keeps its batch axis and matches the target shape.
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped target; no gradient flows through the target network.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        expected_q_values = rewards + (self.gamma * next_q_values * (1 - dones))

        loss = F.smooth_l1_loss(q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        # Element-wise gradient clamping. Parameters that received no gradient
        # (grad is None) are skipped instead of raising AttributeError.
        for param in self.policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        # Update epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path):
        """Write both networks, the optimizer state and epsilon to `path`."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }, path)

    def load(self, path):
        """Restore a checkpoint written by `save`.

        `map_location` remaps stored tensors onto this agent's device, so a
        checkpoint saved on a GPU machine also loads on a CPU-only one.

        Raises:
            FileNotFoundError: if `path` does not exist.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epsilon = checkpoint['epsilon']

In [None]:

# Street Fighter Environment wrapper
class StreetFighterEnv:
    """Gym-style environment that plays Street Fighter via key presses and screen capture.

    Observations are single grayscale frames resized to 84x84 with a leading
    channel axis. Rewards come from changes in both players' HP bars, with the
    agent assumed to control the left-hand player.

    Args:
        screen_region: region passed to `screen_capture` for the observation frame.
        attack_mode: forwarded to `Ken(attack_mode=...)`; also selects which
            defense/burst inputs are used ('modern' vs. anything else).
        hp_region: region passed to `screen_capture` for the HP-bar strip read by
            `L_player_HP` / `R_player_HP`.
        max_steps: number of steps after which an episode is reported as done.
    """

    def __init__(self, screen_region=(0, 0, 1280, 720), attack_mode='modern',
                 hp_region=(115, 85, 1165, 100), max_steps=1000):
        self.screen_region = screen_region
        self.hp_region = hp_region
        self.character = Ken(attack_mode=attack_mode)
        # NOTE(review): assumes the HP detectors report full health as 100 -- confirm.
        self.last_left_hp = 100
        self.last_right_hp = 100
        self.step_count = 0
        self.max_steps = max_steps  # Maximum steps per episode

        # Which movement inputs are currently held down (released by _stop_movement).
        self.moving_left = False
        self.moving_right = False
        self.crouching = False

        # Action index (what the agent outputs) -> callable performing that action.
        self.actions = [
            self._no_op,          # 0: Do nothing
            self.character.light,  # 1: Light attack
            self.character.medium, # 2: Medium attack
            self.character.heavy,  # 3: Heavy attack
            self.character.hadouken, # 4: Hadouken
            self.character.shoryuken, # 5: Shoryuken
            self.character.dragonlash_kick, # 6: Dragonlash kick
            self._defense,         # 7: Defense/parry
            self._burst,           # 8: Power burst
            self._move_left,       # 9: Move left
            self._move_right,      # 10: Move right
            self._crouch,          # 11: Crouch
            self._stop_movement    # 12: Stop all movement
        ]

    def _no_op(self):
        """Do nothing for 0.1 s."""
        time.sleep(0.1)

    def _defense(self):
        """Parry: drive_parry in 'modern' mode, parry otherwise."""
        if self.character.attack_mode == 'modern':
            self.character.impl.drive_parry()
        else:
            self.character.impl.parry()

    def _burst(self):
        """Burst: drive_impact in 'modern' mode, burst otherwise."""
        if self.character.attack_mode == 'modern':
            self.character.impl.drive_impact()
        else:
            self.character.impl.burst()

    def _move_left(self):
        """Release any held movement, hold left, and wait 0.2 s (left stays held)."""
        self._stop_movement()
        self.character.impl.hold_left()
        self.moving_left = True
        time.sleep(0.2)  # Brief movement

    def _move_right(self):
        """Release any held movement, hold right, and wait 0.2 s (right stays held)."""
        self._stop_movement()
        self.character.impl.hold_right()
        self.moving_right = True
        time.sleep(0.2)  # Brief movement

    def _crouch(self):
        """Release any held movement, hold crouch, and wait 0.2 s (crouch stays held)."""
        self._stop_movement()
        self.character.impl.hold_crouch()
        self.crouching = True
        time.sleep(0.2)  # Brief crouch

    def _stop_movement(self):
        """Release every movement input currently recorded as held."""
        if self.moving_left:
            self.character.impl.release_left()
            self.moving_left = False

        if self.moving_right:
            self.character.impl.release_right()
            self.moving_right = False

        if self.crouching:
            self.character.impl.release_crouch()
            self.crouching = False

    def reset(self):
        """Release held keys, wait 2 s, reset counters and return the first observation.

        TODO: this does not restart the match itself; it assumes the game has
        already been reset by the time the 2 s wait is over.
        """
        self._stop_movement()  # Make sure to release all movement keys
        time.sleep(2)
        self.step_count = 0
        self.last_left_hp = 100
        self.last_right_hp = 100
        return self._get_state()

    def _get_state(self):
        """Capture the screen and return a (1, 84, 84) float32 array (pixels / 255)."""
        screen = screen_capture(self.screen_region)

        # NOTE(review): assumes screen_capture returns a 3-channel BGR image --
        # a BGRA capture would need cv2.COLOR_BGRA2GRAY instead. Confirm.
        gray = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY)

        # Only a single frame is used; stacking several frames would give the
        # network motion information.
        resized = cv2.resize(gray, (84, 84))

        # float32 instead of numpy's default float64 halves the memory each stored
        # transition takes in the replay buffer; the network consumes float32 anyway.
        normalized = resized.astype(np.float32) / 255.0

        # Add channel dimension -> (channels, height, width)
        return np.expand_dims(normalized, axis=0)

    def step(self, action_idx):
        """Execute one action and observe the result.

        Out-of-range action indices perform no action.

        Returns:
            (next_state, reward, done, info) in the Gym convention; info is always {}.
        """
        if 0 <= action_idx < len(self.actions):
            self.actions[action_idx]()

        # Small delay to let the game process the action
        time.sleep(0.05)

        next_state = self._get_state()
        # Must run before _is_done: it refreshes the cached HP values _is_done reads.
        reward = self._calculate_reward()

        self.step_count += 1
        done = self._is_done()

        return next_state, reward, done, {}

    def _read_hp(self):
        """Capture the HP-bar strip once and return (left_hp, right_hp)."""
        hp_area = screen_capture(self.hp_region)
        return L_player_HP(hp_area), R_player_HP(hp_area)

    def _calculate_reward(self):
        """Reward for the HP change since the previous call; updates the cached HP.

        +1.0 per HP point the opponent (right player) lost, -0.8 per HP point the
        agent (left player) lost, and a constant -0.1 per step to encourage action.
        """
        current_left_hp, current_right_hp = self._read_hp()

        left_hp_change = current_left_hp - self.last_left_hp
        right_hp_change = current_right_hp - self.last_right_hp

        # Assuming player is on the left
        reward = 0

        # Reward for dealing damage to opponent
        if right_hp_change < 0:
            reward += abs(right_hp_change) * 1.0

        # Penalty for taking damage
        if left_hp_change < 0:
            reward -= abs(left_hp_change) * 0.8

        # Small penalty for time passing (encourages faster action)
        reward -= 0.1

        self.last_left_hp = current_left_hp
        self.last_right_hp = current_right_hp

        return reward

    def _is_done(self):
        """True once either player's HP is <= 0 or the step limit is reached.

        Uses the HP values cached by `_calculate_reward` in the same `step` rather
        than grabbing the HP bar a second time, so reward and `done` are derived
        from the same frame.
        """
        return (self.last_left_hp <= 0 or self.last_right_hp <= 0
                or self.step_count >= self.max_steps)

    def close(self):
        """Release all held movement keys."""
        self._stop_movement()

In [None]:

# Training function
def train_sf_agent(episodes=1000, target_update=10, save_interval=50,
                   checkpoint_prefix="street_fighter_dqn"):
    """Train a DQN agent on the live game and return it.

    Args:
        episodes: number of episodes to run.
        target_update: copy the policy weights into the target net every this many episodes.
        save_interval: write checkpoints every this many episodes.
        checkpoint_prefix: checkpoints go to f"{checkpoint_prefix}_episode_{n}.pth"
            and f"{checkpoint_prefix}_latest.pth"; training resumes from the latter
            when it can be loaded.

    Returns:
        The agent. Pressing 'q' stops training early and returns the agent as-is;
        progress since the last checkpoint is not written to disk in that case.
    """
    env = StreetFighterEnv(screen_region=(0, 0, 1280, 720), attack_mode='modern')
    input_channels = 1  # Grayscale image has 1 channel
    input_height = 84
    input_width = 84
    num_actions = len(env.actions)

    agent = DQNAgent(
        input_channels=input_channels,
        input_height=input_height,
        input_width=input_width,
        num_actions=num_actions,
        lr=0.0001,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.1,
        epsilon_decay=0.995,
        buffer_size=10000,
        batch_size=32
    )

    latest_path = f"{checkpoint_prefix}_latest.pth"
    try:
        agent.load(latest_path)
        print("Loaded pretrained model")
    except FileNotFoundError:
        print("Training new model from scratch")
    except Exception as exc:
        # Corrupt or incompatible checkpoint (e.g. a different number of actions).
        # Keep the best-effort behaviour, but report why it was skipped instead of
        # swallowing everything (a bare except also caught KeyboardInterrupt).
        print(f"Could not load {latest_path} ({exc!r}); training new model from scratch")

    try:
        for episode in range(episodes):
            state = env.reset()
            total_reward = 0
            done = False

            while not done:
                # Manual interrupt; the finally clause below releases all keys.
                if keyboard.is_pressed('q'):
                    print("\nTraining manually interrupted")
                    return agent

                action = agent.select_action(state)
                next_state, reward, done, _ = env.step(action)

                agent.memory.push(state, action, reward, next_state, done)
                agent.train()

                state = next_state
                total_reward += reward

            # Update target network periodically
            if episode % target_update == 0:
                agent.update_target_network()

            print(f"Episode: {episode}, Total Reward: {total_reward:.2f}, Epsilon: {agent.epsilon:.2f}")

            if episode % save_interval == 0:
                agent.save(f"{checkpoint_prefix}_episode_{episode}.pth")
                agent.save(latest_path)  # Always save latest

    finally:
        # Make sure to release all keys when training ends (normally or not)
        env.close()

    return agent

In [None]:

# Test the trained agent
def test_agent(agent, episodes=5):
    env = StreetFighterEnv(screen_region=(0, 0, 1280, 720), attack_mode='modern')
    
    try:
        for episode in range(episodes):
            state = env.reset()
            total_reward = 0
            done = False
            
            while not done:
                # Always choose best action during testing (no exploration)
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
                with torch.no_grad():
                    q_values = agent.policy_net(state_tensor)
                action = q_values.max(1)[1].item()
                
                next_state, reward, done, _ = env.step(action)
                state = next_state
                total_reward += reward
                
                # Check for manual interrupt
                if keyboard.is_pressed('q'):
                    print("\nTesting manually interrupted")
                    return
            
            print(f"Test Episode {episode}, Total Reward: {total_reward:.2f}")
    
    finally:
        # Make sure to release all keys when testing ends
        env.close()

In [None]:

if __name__ == "__main__":
    # Set to True to evaluate the freshly trained agent greedily right after training.
    RUN_TEST_AFTER_TRAINING = False

    print("Starting DQN training for Street Fighter...")
    print("Press 'q' at any time to stop training")
    # Presumably a grace period to focus the game window before inputs start.
    time.sleep(3)

    agent = train_sf_agent(episodes=500, target_update=10, save_interval=50)

    if RUN_TEST_AFTER_TRAINING:
        print("\nTesting trained agent...")
        test_agent(agent, episodes=5)