In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTMSnakeNet(nn.Module):
    """Single-layer LSTM followed by a linear head.

    The head is applied to the LSTM output of the last time step only, so one
    vector of ``output_size`` values is produced per sequence.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden and cell states.
        output_size: Number of values produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the sequence through the LSTM and score its last time step.

        Args:
            x: Batch-first input, ``(batch, seq_len, input_size)``.
            hidden: ``(h_0, c_0)`` tuple as returned by :meth:`init_hidden` or
                by a previous call. ``None`` lets ``nn.LSTM`` start from zeros.

        Returns:
            ``(out, hidden)`` where ``out`` is ``(batch, output_size)`` and
            ``hidden`` is the updated ``(h_n, c_n)`` tuple.
        """
        out, hidden = self.lstm(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed ``(h_0, c_0)`` pair for ``batch_size`` sequences.

        The tensors are created directly on the device and with the dtype of
        this module's own weights, so they always match the module even if it
        was moved or cast after construction.
        """
        ref = self.fc.weight
        shape = (1, batch_size, self.hidden_size)
        return (torch.zeros(shape, device=ref.device, dtype=ref.dtype),
                torch.zeros(shape, device=ref.device, dtype=ref.dtype))

class RNNSnakeNet(nn.Module):
    """Single-layer Elman RNN followed by a linear head.

    The head is applied to the RNN output of the last time step only, so one
    vector of ``output_size`` values is produced per sequence.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the RNN hidden state.
        output_size: Number of values produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the sequence through the RNN and score its last time step.

        Args:
            x: Batch-first input, ``(batch, seq_len, input_size)``.
            hidden: Hidden state as returned by :meth:`init_hidden` or by a
                previous call. ``None`` lets ``nn.RNN`` start from zeros.

        Returns:
            ``(out, hidden)`` where ``out`` is ``(batch, output_size)`` and
            ``hidden`` is the updated hidden state.
        """
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed hidden state of shape ``(1, batch_size, hidden_size)``.

        The tensor is created directly on the device and with the dtype of this
        module's own weights, so it always matches the module even if it was
        moved or cast after construction.
        """
        ref = self.fc.weight
        return torch.zeros(1, batch_size, self.hidden_size,
                           device=ref.device, dtype=ref.dtype)

class GRUSnakeNet(nn.Module):
    """Single-layer GRU followed by a linear head.

    The head is applied to the GRU output of the last time step only, so one
    vector of ``output_size`` values is produced per sequence.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        output_size: Number of values produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the sequence through the GRU and score its last time step.

        Args:
            x: Batch-first input, ``(batch, seq_len, input_size)``.
            hidden: Hidden state as returned by :meth:`init_hidden` or by a
                previous call. ``None`` lets ``nn.GRU`` start from zeros.

        Returns:
            ``(out, hidden)`` where ``out`` is ``(batch, output_size)`` and
            ``hidden`` is the updated hidden state.
        """
        out, hidden = self.gru(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed hidden state of shape ``(1, batch_size, hidden_size)``.

        The tensor is created directly on the device and with the dtype of this
        module's own weights, so it always matches the module even if it was
        moved or cast after construction.
        """
        ref = self.fc.weight
        return torch.zeros(1, batch_size, self.hidden_size,
                           device=ref.device, dtype=ref.dtype)


In [4]:
class DQNAgent:
    """Epsilon-greedy Q-learning agent backed by a recurrent Q-network.

    Transitions are stored in a bounded replay memory and replayed one sample
    at a time; each replayed sample is fed to the network as a length-1
    sequence starting from a fresh hidden state.

    Args:
        state_size: Number of features in one observation.
        action_size: Number of discrete actions.
        hidden_size: Width of the recurrent hidden state.
        model_type: ``'RNN'`` or ``'GRU'``; any other value selects the LSTM.
    """

    def __init__(self, state_size, action_size, hidden_size, model_type='LSTM'):
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.memory = deque(maxlen=2000)  # oldest transitions are dropped first
        self.gamma = 0.95            # discount applied to the bootstrapped value
        self.epsilon = 1.0           # current exploration probability
        self.epsilon_decay = 0.995   # multiplicative decay applied per replay()
        self.epsilon_min = 0.01      # decay stops once epsilon is at/below this
        self.learning_rate = 0.001

        if model_type == 'RNN':
            self.model = RNNSnakeNet(state_size, hidden_size, action_size).to(device)
        elif model_type == 'GRU':
            self.model = GRUSnakeNet(state_size, hidden_size, action_size).to(device)
        else:  # Default to LSTM
            self.model = LSTMSnakeNet(state_size, hidden_size, action_size).to(device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def _to_input(self, state):
        """Convert one observation into a ``(1, 1, features)`` tensor on the model's device."""
        model_device = next(self.model.parameters()).device
        return torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(model_device)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state, hidden):
        """Choose an action epsilon-greedily.

        Args:
            state: Current observation.
            hidden: Recurrent state carried over from the previous step.

        Returns:
            ``(action, hidden)``. On an exploration step the action is uniform
            random and ``hidden`` is returned unchanged (the network is not
            evaluated); otherwise the action is the arg-max of the network
            output and ``hidden`` is the state returned by the network.
        """
        # Uses the stdlib RNG that this cell already imports, so the class does
        # not depend on numpy being imported by some other cell.
        if random.random() <= self.epsilon:
            return random.randrange(self.action_size), hidden
        # Action selection never needs gradients; without no_grad the hidden
        # state carried between steps would keep an ever-growing autograd graph.
        with torch.no_grad():
            action_values, hidden = self.model(self._to_input(state), hidden)
        return torch.argmax(action_values[0]).item(), hidden

    def replay(self, batch_size):
        """Fit the network on ``batch_size`` random transitions, then decay epsilon.

        Each sampled transition triggers its own optimizer step. If the memory
        holds fewer than ``batch_size`` transitions the call does nothing
        (epsilon is left untouched).
        """
        if len(self.memory) < batch_size:
            return

        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                # Bootstrapped target; computed without gradients because it is
                # treated as a constant in the loss.
                with torch.no_grad():
                    next_q, _ = self.model(self._to_input(next_state),
                                           self.model.init_hidden(1))
                target = reward + self.gamma * next_q[0].max().item()

            q_values, _ = self.model(self._to_input(state), self.model.init_hidden(1))
            # Regress only the taken action towards the target: every other
            # entry of the target equals the prediction and contributes no loss.
            target_q = q_values.detach().clone()
            target_q[0][action] = target

            self.optimizer.zero_grad()
            loss = self.criterion(q_values, target_q)
            loss.backward()
            self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


In [3]:
import pygame
import random
import numpy as np

class SnakeGameAI:
    """Snake on a pixel grid with a ``reset``/``step`` interface for agents.

    Coordinates are in pixels; the snake and the food move in increments of
    ``snake_size``. Actions are ``0`` = left, ``1`` = right, ``2`` = up,
    ``3`` = down; any other value keeps the current direction.

    Args:
        width: Board width in pixels.
        height: Board height in pixels.
        snake_size: Side length of one grid cell in pixels.
    """

    def __init__(self, width=640, height=480, snake_size=20):
        self.width = width
        self.height = height
        self.snake_size = snake_size
        self._window = None  # display surface, created on the first render()
        self.reset()

    def reset(self):
        """Start a new game and return the initial observation."""
        self.x1 = self.width / 2
        self.y1 = self.height / 2
        self.x1_change = 0
        self.y1_change = 0
        self.snake_List = []
        self.Length_of_snake = 1
        self.score = 0
        self._place_food()
        return self.get_state()

    def _place_food(self):
        """Move the food to a uniformly random grid cell not covered by the snake."""
        cell = self.snake_size
        occupied = {(segment[0], segment[1]) for segment in self.snake_List}
        occupied.add((self.x1, self.y1))
        free_cells = [
            (x, y)
            for x in range(0, self.width - cell + 1, cell)
            for y in range(0, self.height - cell + 1, cell)
            if (x, y) not in occupied
        ]
        if free_cells:
            self.foodx, self.foody = random.choice(free_cells)
        else:
            # No free cell left: fall back to any grid cell rather than failing.
            self.foodx = random.randrange(0, self.width - cell + 1, cell)
            self.foody = random.randrange(0, self.height - cell + 1, cell)

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: ``0`` left, ``1`` right, ``2`` up, ``3`` down; anything
                else keeps the current direction.

        Returns:
            ``(state, reward, done)``. ``reward`` is ``-10`` when the move ends
            the game (wall or self collision), ``10`` when food is eaten and
            ``0`` otherwise.
        """
        if action == 0:  # LEFT
            self.x1_change = -self.snake_size
            self.y1_change = 0
        elif action == 1:  # RIGHT
            self.x1_change = self.snake_size
            self.y1_change = 0
        elif action == 2:  # UP
            self.y1_change = -self.snake_size
            self.x1_change = 0
        elif action == 3:  # DOWN
            self.y1_change = self.snake_size
            self.x1_change = 0

        self.x1 += self.x1_change
        self.y1 += self.y1_change

        hit_wall = not (0 <= self.x1 < self.width and 0 <= self.y1 < self.height)

        snake_Head = [self.x1, self.y1]
        self.snake_List.append(snake_Head)
        if len(self.snake_List) > self.Length_of_snake:
            del self.snake_List[0]

        # The head is the last entry; any earlier entry at the same cell is a bite.
        hit_self = snake_Head in self.snake_List[:-1]
        done = hit_wall or hit_self

        reward = 0
        if done:
            # A fatal move never scores.
            reward = -10
        elif self.x1 == self.foodx and self.y1 == self.foody:
            self.Length_of_snake += 1
            self.score += 1
            reward = 10
            self._place_food()

        state = self.get_state()
        return state, reward, done

    def get_state(self):
        """Return ``[head_x, head_y, food_x, food_y, dx, dy]`` as a float32 array."""
        state = [
            self.x1, self.y1, self.foodx, self.foody,
            self.x1_change, self.y1_change,
        ]
        return np.array(state, dtype=np.float32)

    def render(self):
        """Draw the current frame: snake in green, food in white, on black."""
        # Create the window once (or again if the display was shut down)
        # instead of re-creating it on every frame.
        if self._window is None or not pygame.display.get_init():
            self._window = pygame.display.set_mode((self.width, self.height))
        # Let pygame service the OS event queue so the window is not reported
        # as unresponsive when the caller does no event handling of its own.
        pygame.event.pump()

        window = self._window
        window.fill((0, 0, 0))
        for segment in self.snake_List:
            pygame.draw.rect(window, (0, 255, 0), [segment[0], segment[1], self.snake_size, self.snake_size])
        pygame.draw.rect(window, (255, 255, 255), [self.foodx, self.foody, self.snake_size, self.snake_size])
        pygame.display.update()


pygame 2.5.2 (SDL 2.28.3, Python 3.10.13)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [1]:
def train_until_achieve_score(model_type, target_score=10, max_episodes=None, max_steps=500):
    """Train a fresh agent until one episode reaches ``target_score``.

    After training stops, the network weights are written to
    ``"<model_type lower-cased>_snake_model.pth"`` in the working directory.

    Args:
        model_type: Architecture name forwarded to ``DQNAgent``.
        target_score: Game score that ends training when reached in an episode.
        max_episodes: Optional cap on the number of episodes. ``None`` (the
            default) trains until the target is reached, however long that
            takes. When the cap is hit first, the weights trained so far are
            still saved.
        max_steps: Maximum number of moves per episode.
    """
    state_size = 6
    hidden_size = 64
    action_size = 4
    batch_size = 32

    agent = DQNAgent(state_size, action_size, hidden_size, model_type=model_type)
    game = SnakeGameAI()
    episode = 0
    reached_target = False

    while max_episodes is None or episode < max_episodes:
        episode += 1
        state = game.reset()
        hidden = agent.model.init_hidden(1)
        for _step in range(max_steps):
            action, hidden = agent.act(state, hidden)
            next_state, reward, done = game.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"Episode: {episode}, Score: {game.score}, Epsilon: {agent.epsilon:.2}")
                break
            # Learn only once the memory holds more than one batch of transitions.
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        if game.score >= target_score:
            print(f"Model {model_type} achieved a score of {target_score} in episode {episode}")
            reached_target = True
            break

    if not reached_target:
        print(f"Model {model_type} stopped after {episode} episodes without reaching a score of {target_score}")

    torch.save(agent.model.state_dict(), f"{model_type.lower()}_snake_model.pth")

if __name__ == "__main__":
    # Train the three recurrent architectures one after another.
    for model_type in ('RNN', 'GRU', 'LSTM'):
        train_until_achieve_score(model_type)


In [7]:
def load_model(model_type, model_path, state_size, hidden_size, action_size):
    """Build a network of the requested type and load saved weights into it.

    Args:
        model_type: ``'RNN'`` or ``'GRU'``; any other value selects the LSTM.
        model_path: Path to a ``state_dict`` checkpoint written by ``torch.save``.
        state_size: Input feature count the checkpoint was trained with.
        hidden_size: Hidden width the checkpoint was trained with.
        action_size: Number of outputs the checkpoint was trained with.

    Returns:
        The network on ``device``, switched to evaluation mode.
    """
    if model_type == 'RNN':
        model = RNNSnakeNet(state_size, hidden_size, action_size).to(device)
    elif model_type == 'GRU':
        model = GRUSnakeNet(state_size, hidden_size, action_size).to(device)
    else:
        model = LSTMSnakeNet(state_size, hidden_size, action_size).to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa) instead of failing while deserializing.
    # NOTE(security): torch.load unpickles the file; only load checkpoints you
    # trust, or pass weights_only=True on PyTorch versions that support it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def play_game(model, game, display=True, max_steps=None):
    """Let ``model`` play one game greedily and print the final score.

    Args:
        model: Network exposing ``init_hidden(batch_size)`` and returning
            ``(action_values, hidden)`` when called with ``(state, hidden)``.
        game: Environment exposing ``reset()``, ``step(action)``, ``render()``
            and a ``score`` attribute.
        display: When true, render every frame at up to 10 frames per second
            and stop early if the window is closed. When false, pygame is not
            touched at all.
        max_steps: Optional cap on the number of moves, so a policy that never
            dies cannot loop forever. ``None`` (the default) means no cap.
    """
    state = game.reset()
    hidden = model.init_hidden(1)

    # Feed inputs to wherever the model actually lives.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    clock = pygame.time.Clock() if display else None
    done = False
    steps = 0

    while not done and (max_steps is None or steps < max_steps):
        state_t = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(model_device)
        with torch.no_grad():
            action_values, hidden = model(state_t, hidden)
        action = torch.argmax(action_values[0]).item()

        state, _, done = game.step(action)
        steps += 1

        if display:
            game.render()
            # Drain the event queue after rendering (the display must exist
            # first) so the window stays responsive and can be closed.
            if pygame.display.get_init() and any(
                event.type == pygame.QUIT for event in pygame.event.get()
            ):
                break
            clock.tick(10)

    print(f"Final Score: {game.score}")

if __name__ == "__main__":
    model_type = 'RNN'
    model_path = "rnn/rnn_snake_model_10500.pth"

    # These sizes must match the ones the checkpoint was trained with.
    state_size = 6
    hidden_size = 64
    action_size = 4
    model = load_model(model_type, model_path, state_size, hidden_size, action_size)

    game = SnakeGameAI()

    try:
        play_game(model, game)
    finally:
        # Close the pygame window even if playback is interrupted; otherwise
        # it stays open for as long as the kernel is alive.
        pygame.quit()


Final Score: 0
