In [11]:
import pygame
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import os

In [12]:
class Linear_QNet(nn.Module):
    """Small feed-forward network: one hidden ReLU layer followed by a linear output layer."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two fully connected layers.

        Args:
            input_size: number of features in one input vector.
            hidden_size: number of units in the hidden layer.
            output_size: number of values produced per input vector.
        """
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable.
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return the output-layer values for ``x``; ReLU is applied to the hidden layer only."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

In [13]:
class Agent:
    """Deep Q-learning agent with experience replay and a separate target network.

    The online network (``DQ_Network``) is trained on random minibatches drawn
    from ``memory``; the target network (``Target_Network``) supplies the
    bootstrap value for the next state and is only updated when
    ``train_Target_Network`` is called.
    """

    def __init__(self):
        """Create the networks, optimizer, loss, replay buffer and hyper-parameters."""
        # 5 state features in, 3 outputs (one per entry of the one-hot action).
        self.DQ_Network = Linear_QNet(5,128,3)
        self.Target_Network = Linear_QNet(5,128,3)
        self.lr = 0.001
        self.optimizer = optim.Adam(self.DQ_Network.parameters(), lr = self.lr)
        self.loss = nn.MSELoss()
        self.memory_size = 100000
        # Bounded buffer: the oldest transitions are dropped once it is full.
        self.memory = deque(maxlen=self.memory_size)
        self.batch_size = 100
        self.gamma = 0.9
        self.n_games = 0
        self.epsilon = 100
        # Start with the target network equal to the online network instead of
        # bootstrapping from an unrelated random initialisation.
        self.train_Target_Network()

    def remember(self,state,action,reward,next_state,done):
        """Append one ``(state, action, reward, next_state, done)`` transition to the replay buffer."""
        self.memory.append((state,action,reward,next_state,done))

    def train_DQ_Network(self):
        """Run one optimisation step of the online network on a random minibatch.

        Samples up to ``batch_size`` transitions from ``memory``. For each one the
        target for the taken action is ``reward`` when ``done`` is true, otherwise
        ``reward + gamma * max(Target_Network(next_state))``; the outputs for the
        other actions are left equal to the current prediction so they contribute
        no loss. Does nothing when the replay buffer is empty.
        """
        train_size = min(len(self.memory),self.batch_size)
        if train_size == 0:
            # An empty batch cannot be unpacked into the five transition fields.
            return

        batch = random.sample(self.memory,train_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack through NumPy first; building a tensor directly from a tuple of
        # arrays is very slow.
        states = torch.tensor(np.array(states),dtype=torch.float)
        actions = torch.tensor(np.array(actions),dtype=torch.long)
        rewards = torch.tensor(rewards,dtype=torch.float)
        next_states = torch.tensor(np.array(next_states),dtype=torch.float)
        # 1.0 where the episode continues, 0.0 on terminal transitions.
        not_done = torch.tensor([0.0 if d else 1.0 for d in dones],dtype=torch.float)

        pred = self.DQ_Network(states)

        # Targets are constants with respect to the optimisation step.
        with torch.no_grad():
            max_next_pred = self.Target_Network(next_states).max(dim=1).values
            q_new = rewards + self.gamma * max_next_pred * not_done
            targets = pred.detach().clone()
            # Actions are stored one-hot; argmax recovers the index that was taken.
            taken = actions.argmax(dim=1)
            targets[torch.arange(train_size), taken] = q_new

        self.optimizer.zero_grad()
        loss = self.loss(pred,targets)
        loss.backward()
        self.optimizer.step()

    def train_Target_Network(self):
        """Copy the online network's weights into the target network."""
        self.Target_Network.load_state_dict(self.DQ_Network.state_dict())

    def get_action(self,state):
        """Choose an action for ``state`` with an epsilon-greedy policy.

        ``epsilon`` is recomputed as ``max(0, 100 - n_games)`` and used as the
        percentage chance of picking a uniformly random action; otherwise the
        action with the largest online-network output is chosen.

        Returns:
            A 3-element one-hot list identifying the chosen action.
        """
        action = [0,0,0]
        self.epsilon = max(0,100-self.n_games)
        if random.randint(1,100) <= self.epsilon:
            # Exploration
            idx = random.randint(0,2)
        else:
            # Exploitation; no gradients are needed for action selection.
            state = torch.tensor(state,dtype=torch.float)
            with torch.no_grad():
                pred = self.DQ_Network(state)
            idx = torch.argmax(pred).item()

        action[idx] = 1
        return action

In [14]:
def save_model(model,hs):
    """Save ``model``'s state dict to ``./models/DQL_model_<hs>.pth``.

    Args:
        model: object exposing ``state_dict()`` (a ``torch.nn.Module``).
        hs: value embedded in the file name via ``str(hs)``.
    """
    model_path = './models'
    # exist_ok makes directory creation a single call with no gap between
    # "does it exist?" and "create it".
    os.makedirs(model_path, exist_ok=True)

    file_name = os.path.join(model_path, "DQL_model_"+str(hs)+".pth")
    torch.save(model.state_dict(), file_name)

In [15]:
# Start pygame and open the game window.
pygame.init()

width = 500
height = 600
window = pygame.display.set_mode((width, height))
pygame.display.set_caption("Catch the Eggs!!")

# RGB colour palette used when drawing.
black, white = (0, 0, 0), (255, 255, 255)
red, green, blue = (255, 0, 0), (0, 255, 0), (0, 0, 255)
yellow = (225, 225, 0)

# Basket: size, horizontal step per move, and starting position
# (horizontally centred, bottom edge 10 px above the bottom of the window).
basket_width, basket_height = 30, 50
basket_speed = 10
basket_x = width // 2 - basket_width // 2
basket_y = height - basket_height - 10

# Eggs: size and vertical distance travelled per step.
egg_width, egg_height = 23, 35
egg_speed = 5

def create_egg():
    """Return a new egg rectangle at the top of the window.

    The x coordinate is a random integer between 10 and
    ``width - egg_width - 10`` (both inclusive); y is always 0.
    """
    spawn_x = random.randint(10, width - egg_width - 10)
    return pygame.Rect(spawn_x, 0, egg_width, egg_height)

# Mutable game state shared (as module globals) by the helper functions below.
basket = pygame.Rect(basket_x, basket_y, basket_width, basket_height)
eggs = [create_egg()]

# Per-episode counters plus the best score seen so far.
score = total_reward = highest_score = 0

# One-hot action: index 0 moves the basket left, index 2 moves it right,
# index 1 leaves it where it is.
action = [0, 1, 0]

# Fonts for the in-game score and the final "Game Over" screen.
font = pygame.font.Font(None, 36)
over_font = pygame.font.Font(None, 72)

def reset_game():
    """Start a fresh episode.

    Zeroes ``score`` and ``total_reward``, puts a new basket at its starting
    position and replaces the egg list with a single newly created egg.
    """
    global basket, eggs, score, total_reward
    score = total_reward = 0
    start_rect = (basket_x, basket_y, basket_width, basket_height)
    basket = pygame.Rect(*start_rect)
    eggs = [create_egg()]

def get_state(action):
    """Build the 5-element integer state vector.

    Layout: the three entries of ``action`` followed by two flags comparing the
    first egg with the basket — ``egg.left <= basket.left`` and
    ``egg.left > basket.left`` — each stored as 0 or 1.

    Args:
        action: indexable with at least three entries.

    Returns:
        ``numpy.ndarray`` of dtype ``int`` with 5 elements.
    """
    first_egg = eggs[0]
    egg_left_or_aligned = first_egg.left <= basket.left
    egg_to_the_right = first_egg.left > basket.left
    features = [action[0], action[1], action[2], egg_left_or_aligned, egg_to_the_right]
    return np.array(features, dtype=int)

def perform_action(action):
    """Advance the game by one step and report the outcome.

    Moves the basket according to the one-hot ``action`` (index 0 = left,
    index 2 = right, anything else = stay), keeping it fully inside the window,
    then moves every egg down by ``egg_speed``.

    Args:
        action: indexable with at least three entries.

    Returns:
        ``(reward, game_over)``: reward is 10 when an egg was caught, -10 when
        an egg fell below the window (which also sets ``game_over`` to True),
        and 0 otherwise.
    """
    global score

    # Clamp instead of "move if not yet past the edge": the old test let the
    # basket overshoot the window by up to basket_speed pixels.
    if action[0]==1:
        basket.x = max(0, basket.x - basket_speed)
    if action[2]==1:
        basket.x = min(width - basket.width, basket.x + basket_speed)

    reward = 0
    game_over = False

    for index, egg in enumerate(eggs):
        egg.y += egg_speed
        if basket.colliderect(egg):
            # Replace the egg that was actually caught, not always slot 0.
            eggs[index] = create_egg()
            score += 1
            reward = 10
        elif egg.y > height:
            # Only an egg that was not caught can end the game.
            reward = -10
            game_over = True

    return reward, game_over
    
agent = Agent()
clock = pygame.time.Clock()
running  = True

# Number of finished games between target-network refreshes.
target_sync_every = 10
# Begin with the target network equal to the online network.
agent.train_Target_Network()

try:
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        # --- one environment step -------------------------------------------
        state = get_state(action)
        action = agent.get_action(state)
        reward, game_over = perform_action(action)
        total_reward+=reward

        # Read the follow-up state before any reset so the stored transition
        # never mixes two different episodes.
        next_state = get_state(action)
        agent.remember(state,action,reward,next_state,game_over)
        agent.train_DQ_Network()

        if game_over:
            agent.n_games+=1
            if score > highest_score:
                highest_score = score
                save_model(agent.DQ_Network,highest_score)

            print("In Game:",agent.n_games,", Score is:",score,", Total Reward is:", total_reward,", Highest Score till now:",highest_score)

            # Refresh the target network once per `target_sync_every` finished
            # games. Checking the counter on every frame would re-sync it on
            # every single step of games 0, 10, 20, ...
            if agent.n_games % target_sync_every == 0:
                agent.train_Target_Network()

            reset_game()

        # Stop training once a single game reaches 100 caught eggs.
        if score >= 100:
            highest_score = score
            save_model(agent.DQ_Network,highest_score)
            running = False

        # --- rendering --------------------------------------------------------
        window.fill(black)
        pygame.draw.rect(window, yellow, basket)

        for egg in eggs:
            pygame.draw.ellipse(window, white, egg)

        score_text = font.render(f"Score: {score}", True, green)
        window.blit(score_text, (width-150, 10))

        pygame.display.flip()
        clock.tick(30)

    # Final screen, shown for five seconds after the loop ends normally.
    window.fill(black)
    game_over_text = over_font.render('Game Over', True, red)
    final_score_text = over_font.render(f'Final Score: {score}', True, blue)
    window.blit(game_over_text, (width//2-game_over_text.get_width()//2, height//2-game_over_text.get_height()//2-50))
    window.blit(final_score_text, (width//2-final_score_text.get_width()//2, height//2-final_score_text.get_height()//2))
    pygame.display.flip()
    pygame.time.wait(5000)
finally:
    # Always release the window, even if the loop raised, so the notebook
    # kernel is not left with an unresponsive pygame window.
    pygame.quit()

In Game: 1 , Score is: 1 , Total Reward is: 0 , Highest Score till now: 1
In Game: 2 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 1
In Game: 3 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 1
In Game: 4 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 1
In Game: 5 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 1
In Game: 6 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 1
In Game: 7 , Score is: 1 , Total Reward is: 0 , Highest Score till now: 1
In Game: 8 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 1
In Game: 9 , Score is: 2 , Total Reward is: 10 , Highest Score till now: 2
In Game: 10 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 2
In Game: 11 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 2
In Game: 12 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 2
In Game: 13 , Score is: 0 , Total Reward is: -10 , Highest Score till now: 2
In Game: 14 ,