In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np
import time
import pygame

# Compute device shared by every model and tensor in this cell: the GPU when
# CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTMSnakeNet(nn.Module):
    """Single-layer LSTM followed by a linear head that outputs one value per action.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Width of the LSTM hidden and cell states.
        output_size: Number of values produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Run the sequence ``x`` (batch, seq, input_size) through the LSTM.

        Returns:
            A pair ``(out, hidden)`` where ``out`` is the linear head applied to
            the last time step, shape (batch, output_size), and ``hidden`` is the
            updated ``(h, c)`` state tuple.
        """
        out, hidden = self.lstm(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed ``(h, c)`` state, each of shape (1, batch_size, hidden_size).

        The tensors are allocated directly with the dtype and device of the
        model's own weights, so they always match the model even if it was moved
        after construction.
        """
        reference = self.fc.weight
        shape = (1, batch_size, self.hidden_size)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

class RNNSnakeNet(nn.Module):
    """Single-layer vanilla RNN followed by a linear head that outputs one value per action.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Width of the RNN hidden state.
        output_size: Number of values produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Run the sequence ``x`` (batch, seq, input_size) through the RNN.

        Returns:
            A pair ``(out, hidden)`` where ``out`` is the linear head applied to
            the last time step, shape (batch, output_size), and ``hidden`` is the
            updated hidden state.
        """
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed hidden state of shape (1, batch_size, hidden_size).

        The tensor is allocated directly with the dtype and device of the model's
        own weights, so it always matches the model even if it was moved after
        construction.
        """
        return self.fc.weight.new_zeros((1, batch_size, self.hidden_size))

class GRUSnakeNet(nn.Module):
    """Single-layer GRU followed by a linear head that outputs one value per action.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Width of the GRU hidden state.
        output_size: Number of values produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Run the sequence ``x`` (batch, seq, input_size) through the GRU.

        Returns:
            A pair ``(out, hidden)`` where ``out`` is the linear head applied to
            the last time step, shape (batch, output_size), and ``hidden`` is the
            updated hidden state.
        """
        out, hidden = self.gru(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed hidden state of shape (1, batch_size, hidden_size).

        The tensor is allocated directly with the dtype and device of the model's
        own weights, so it always matches the model even if it was moved after
        construction.
        """
        return self.fc.weight.new_zeros((1, batch_size, self.hidden_size))

class DQNAgent:
    """Epsilon-greedy Q-learning agent backed by a recurrent network.

    Args:
        state_size: Length of the flat state vector fed to the network.
        action_size: Number of discrete actions.
        hidden_size: Width of the recurrent hidden state.
        model_type: ``'RNN'`` or ``'GRU'``; any other value selects the LSTM.
    """

    def __init__(self, state_size, action_size, hidden_size, model_type='LSTM'):
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.memory = deque(maxlen=2000)  # replay buffer; oldest transitions drop out first
        self.gamma = 0.95                 # discount factor
        self.epsilon = 1.0                # current exploration rate
        self.epsilon_decay = 0.995        # multiplicative decay applied once per replay() call
        self.epsilon_min = 0.01
        self.learning_rate = 0.001

        if model_type == 'RNN':
            self.model = RNNSnakeNet(state_size, hidden_size, action_size).to(device)
        elif model_type == 'GRU':
            self.model = GRUSnakeNet(state_size, hidden_size, action_size).to(device)
        else:  # Default to LSTM
            self.model = LSTMSnakeNet(state_size, hidden_size, action_size).to(device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def _to_input(self, state):
        """Turn a flat state vector into a (1, 1, state_size) float tensor on the model's device."""
        model_device = next(self.model.parameters()).device
        return torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(model_device)

    def act(self, state, hidden):
        """Choose an action epsilon-greedily.

        Returns:
            ``(action, hidden)``. On a random (exploration) move the hidden state
            is returned unchanged; on a greedy move it is the state produced by
            the network.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size), hidden
        state = self._to_input(state)
        # Inference only: without no_grad the returned hidden state would keep a
        # reference to the autograd graph, which then grows with every step of
        # the episode.
        with torch.no_grad():
            action_values, hidden = self.model(state, hidden)
        return torch.argmax(action_values[0]).item(), hidden

    def replay(self, batch_size):
        """Train on ``batch_size`` random transitions, one optimiser step each.

        Every transition is evaluated from a freshly zeroed hidden state. Does
        nothing when the buffer holds fewer than ``batch_size`` transitions.
        Epsilon is decayed once per call that actually trains.
        """
        if len(self.memory) < batch_size:
            return
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            state = self._to_input(state)
            # The regression target is a constant, so build it without autograd.
            with torch.no_grad():
                target = reward
                if not done:
                    next_state = self._to_input(next_state)
                    next_q = self.model(next_state, self.model.init_hidden(1))[0][0]
                    target = reward + self.gamma * torch.max(next_q).item()
                # Start from the current predictions so only the taken action
                # contributes to the loss.
                target_f = self.model(state, self.model.init_hidden(1))[0].clone()
                target_f[0][action] = target
            self.optimizer.zero_grad()
            prediction = self.model(state, self.model.init_hidden(1))[0]
            loss = self.criterion(prediction, target_f)
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

class SnakeGameAI:
    """Headless-capable Snake environment on a pixel grid.

    Positions are pixel coordinates that move in steps of ``snake_size``. The
    observation is the float32 vector
    ``[head_x, head_y, food_x, food_y, dx, dy]``.

    Args:
        width: Board width in pixels.
        height: Board height in pixels.
        snake_size: Side length of one snake segment / grid cell in pixels.
    """

    def __init__(self, width=640, height=480, snake_size=20):
        self.width = width
        self.height = height
        self.snake_size = snake_size
        self.reset()

    def _random_cell(self, extent):
        """Return a random coordinate along an axis of length ``extent``, snapped to a multiple of ``snake_size``."""
        return round(random.randrange(0, extent - self.snake_size) / self.snake_size) * self.snake_size

    def _place_food(self):
        """Move the food to a new random grid position (x is drawn before y)."""
        self.foodx = self._random_cell(self.width)
        self.foody = self._random_cell(self.height)

    def reset(self):
        """Start a new game with a stationary length-1 snake in the centre; return the initial state."""
        self.x1 = self.width / 2
        self.y1 = self.height / 2
        self.x1_change = 0
        self.y1_change = 0
        self.snake_List = []
        self.Length_of_snake = 1
        self._place_food()
        self.score = 0
        return self.get_state()

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0 = left, 1 = right, 2 = up, 3 = down. Any other value keeps
                the current direction.

        Returns:
            ``(state, reward, done)``. Reward is 10 for eating food, -10 when the
            game ends (wall or self collision), otherwise 0.
        """
        if action == 0:  # LEFT
            self.x1_change = -self.snake_size
            self.y1_change = 0
        elif action == 1:  # RIGHT
            self.x1_change = self.snake_size
            self.y1_change = 0
        elif action == 2:  # UP
            self.y1_change = -self.snake_size
            self.x1_change = 0
        elif action == 3:  # DOWN
            self.y1_change = self.snake_size
            self.x1_change = 0

        self.x1 += self.x1_change
        self.y1 += self.y1_change

        # Leaving the board ends the game.
        done = self.x1 >= self.width or self.x1 < 0 or self.y1 >= self.height or self.y1 < 0
        snake_Head = [self.x1, self.y1]
        self.snake_List.append(snake_Head)
        if len(self.snake_List) > self.Length_of_snake:
            del self.snake_List[0]

        # Running into any body segment also ends the game.
        if snake_Head in self.snake_List[:-1]:
            done = True

        reward = 0
        if self.x1 == self.foodx and self.y1 == self.foody:
            self._place_food()
            self.Length_of_snake += 1
            reward = 10
            self.score += 1

        # The terminal penalty overrides any food reward earned on the same move.
        if done:
            reward = -10

        state = self.get_state()
        return state, reward, done

    def get_state(self):
        """Return ``[head_x, head_y, food_x, food_y, dx, dy]`` as a float32 array."""
        state = [
            self.x1, self.y1, self.foodx, self.foody,
            self.x1_change, self.y1_change,
        ]
        return np.array(state, dtype=np.float32)

    def render(self):
        """Draw the snake (green) and the food (white) and update the display."""
        # Reuse the existing window; only (re)create it when there is none yet
        # or its size does not match the board, instead of on every frame.
        window = pygame.display.get_surface()
        if window is None or window.get_size() != (self.width, self.height):
            window = pygame.display.set_mode((self.width, self.height))
        window.fill((0, 0, 0))
        for segment in self.snake_List:
            pygame.draw.rect(window, (0, 255, 0), [segment[0], segment[1], self.snake_size, self.snake_size])
        pygame.draw.rect(window, (255, 255, 255), [self.foodx, self.foody, self.snake_size, self.snake_size])
        pygame.display.update()

def train_until_achieve_score(model_type, target_score=10, load_path=None, saved_episode = 0):
    """Train a recurrent DQN agent on Snake until one episode reaches ``target_score``.

    A checkpoint is written every 250 episodes and once more when training ends.

    Args:
        model_type: ``'RNN'``, ``'GRU'`` or ``'LSTM'``; also used in checkpoint file names.
        target_score: Game score that stops training.
        load_path: Optional path to a state-dict checkpoint to start from.
        saved_episode: Episode number to continue counting from.
    """
    state_size = 6
    hidden_size = 64
    action_size = 4
    batch_size = 32

    agent = DQNAgent(state_size, action_size, hidden_size, model_type=model_type)

    if load_path:
        # map_location lets a checkpoint that was saved on a GPU be restored on
        # a CPU-only machine (and vice versa).
        agent.model.load_state_dict(torch.load(load_path, map_location=device))
        print(f"Loaded model from {load_path}")

    game = SnakeGameAI()
    episode = saved_episode

    start_time = time.time()

    while True:
        episode += 1
        state = game.reset()
        hidden = agent.model.init_hidden(1)
        # Each episode is capped at 500 moves; episodes that hit the cap end
        # silently (nothing is printed for them).
        for time_step in range(500):
            action, hidden = agent.act(state, hidden)
            next_state, reward, done = game.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"Episode: {episode}, Score: {game.score}, Epsilon: {agent.epsilon:.2}")
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        if episode % 250 == 0:
            torch.save(agent.model.state_dict(), f"{model_type.lower()}_snake_model_{episode}.pth")
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Saved model at episode {episode}")
            print(f"Training time for {model_type} after {episode} episodes: {elapsed_time:.2f} seconds")
            # The timer restarts here, so reported times cover the span since the last save.
            start_time = time.time()
        if game.score >= target_score:
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Model {model_type} achieved a score of {target_score} in episode {episode}")
            print(f"Training time for {model_type}: {elapsed_time:.2f} seconds")
            break

    torch.save(agent.model.state_dict(), f"{model_type.lower()}_snake_model_final.pth")

if __name__ == "__main__":
    # Continue the RNN run from the episode-8000 checkpoint. The checkpoint is
    # only requested when CUDA is available; otherwise training starts from
    # fresh weights while the episode counter still begins at 8000.
    # NOTE(review): confirm that coupling the checkpoint to CUDA availability is intended.
    load_path = f"{'RNN'.lower()}_snake_model_8000.pth" if torch.cuda.is_available() else None
    train_until_achieve_score('RNN', load_path=load_path, saved_episode = 8000)
    # GRU and LSTM are trained from scratch. The previous code computed a
    # "<type>_snake_model_final.pth" path here but never passed it on, so the
    # dead assignment has been dropped.
    for model_type in ['GRU', 'LSTM']:
        train_until_achieve_score(model_type)


pygame 2.5.2 (SDL 2.28.3, Python 3.10.13)
Hello from the pygame community. https://www.pygame.org/contribute.html
Loaded model from rnn_snake_model_8000.pth
Episode: 8001, Score: 0, Epsilon: 0.91
Episode: 8002, Score: 0, Epsilon: 0.61
Episode: 8003, Score: 0, Epsilon: 0.27
Episode: 8004, Score: 0, Epsilon: 0.18
Episode: 8008, Score: 1, Epsilon: 0.01
Episode: 8009, Score: 0, Epsilon: 0.01
Episode: 8010, Score: 0, Epsilon: 0.01
Episode: 8011, Score: 0, Epsilon: 0.01
Episode: 8012, Score: 0, Epsilon: 0.01
Episode: 8013, Score: 0, Epsilon: 0.01
Episode: 8014, Score: 0, Epsilon: 0.01
Episode: 8015, Score: 0, Epsilon: 0.01
Episode: 8016, Score: 0, Epsilon: 0.01
Episode: 8017, Score: 0, Epsilon: 0.01
Episode: 8018, Score: 0, Epsilon: 0.01
Episode: 8020, Score: 0, Epsilon: 0.01
Episode: 8021, Score: 0, Epsilon: 0.01
Episode: 8022, Score: 0, Epsilon: 0.01
Episode: 8023, Score: 0, Epsilon: 0.01
Episode: 8024, Score: 0, Epsilon: 0.01
Episode: 8025, Score: 0, Epsilon: 0.01
Episode: 8026, Score: 0,

KeyboardInterrupt: 

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np
import time
import pygame

# Compute device used when loading and running the model in this cell: the GPU
# when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTMSnakeNet(nn.Module):
    """LSTM-based Q-network: one LSTM layer plus a linear layer over its last output.

    Args:
        input_size: Feature count per time step.
        hidden_size: Size of the LSTM hidden and cell states.
        output_size: Size of the linear layer's output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Return ``(values, hidden)`` for a batch-first input sequence ``x``.

        ``values`` comes from the final time step only and has shape
        (batch, output_size); ``hidden`` is the new ``(h, c)`` tuple.
        """
        out, hidden = self.lstm(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return zero ``(h, c)`` tensors of shape (1, batch_size, hidden_size).

        Allocation follows the dtype and device of the model's weights rather
        than a module-level global, so the state always matches the model.
        """
        reference = self.fc.weight
        shape = (1, batch_size, self.hidden_size)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

class RNNSnakeNet(nn.Module):
    """Vanilla-RNN Q-network: one RNN layer plus a linear layer over its last output.

    Args:
        input_size: Feature count per time step.
        hidden_size: Size of the RNN hidden state.
        output_size: Size of the linear layer's output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Return ``(values, hidden)`` for a batch-first input sequence ``x``.

        ``values`` comes from the final time step only and has shape
        (batch, output_size); ``hidden`` is the new hidden state.
        """
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zero hidden state of shape (1, batch_size, hidden_size).

        Allocation follows the dtype and device of the model's weights rather
        than a module-level global, so the state always matches the model.
        """
        return self.fc.weight.new_zeros((1, batch_size, self.hidden_size))

class GRUSnakeNet(nn.Module):
    """GRU-based Q-network: one GRU layer plus a linear layer over its last output.

    Args:
        input_size: Feature count per time step.
        hidden_size: Size of the GRU hidden state.
        output_size: Size of the linear layer's output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Return ``(values, hidden)`` for a batch-first input sequence ``x``.

        ``values`` comes from the final time step only and has shape
        (batch, output_size); ``hidden`` is the new hidden state.
        """
        out, hidden = self.gru(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zero hidden state of shape (1, batch_size, hidden_size).

        Allocation follows the dtype and device of the model's weights rather
        than a module-level global, so the state always matches the model.
        """
        return self.fc.weight.new_zeros((1, batch_size, self.hidden_size))

class DQNAgent:
    """Epsilon-greedy Q-learning agent whose Q-function is a recurrent network.

    Args:
        state_size: Length of the flat state vector.
        action_size: Number of discrete actions.
        hidden_size: Width of the recurrent hidden state.
        model_type: ``'RNN'`` or ``'GRU'``; anything else selects the LSTM.
    """

    def __init__(self, state_size, action_size, hidden_size, model_type='LSTM'):
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.memory = deque(maxlen=2000)  # bounded replay buffer
        self.gamma = 0.95                 # discount factor
        self.epsilon = 1.0                # exploration rate, decayed in replay()
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.learning_rate = 0.001

        if model_type == 'RNN':
            self.model = RNNSnakeNet(state_size, hidden_size, action_size).to(device)
        elif model_type == 'GRU':
            self.model = GRUSnakeNet(state_size, hidden_size, action_size).to(device)
        else:  # Default to LSTM
            self.model = LSTMSnakeNet(state_size, hidden_size, action_size).to(device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def _to_input(self, state):
        """Shape a flat state vector as a (1, 1, state_size) float tensor on the model's device."""
        model_device = next(self.model.parameters()).device
        return torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(model_device)

    def act(self, state, hidden):
        """Pick an action: random with probability epsilon, otherwise the arg-max of the network output.

        Returns:
            ``(action, hidden)``; ``hidden`` is passed through untouched on a
            random move and replaced by the network's new state on a greedy move.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size), hidden
        state = self._to_input(state)
        # no_grad keeps the returned hidden state free of autograd history, so
        # carrying it from step to step does not accumulate a graph.
        with torch.no_grad():
            action_values, hidden = self.model(state, hidden)
        return torch.argmax(action_values[0]).item(), hidden

    def replay(self, batch_size):
        """Fit the network on ``batch_size`` sampled transitions, one at a time.

        Each transition starts from a zero hidden state. The call is a no-op when
        fewer than ``batch_size`` transitions are stored; otherwise epsilon is
        decayed once at the end.
        """
        if len(self.memory) < batch_size:
            return
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            state = self._to_input(state)
            # Targets are constants for the regression, so no autograd is needed here.
            with torch.no_grad():
                target = reward
                if not done:
                    next_state = self._to_input(next_state)
                    next_q = self.model(next_state, self.model.init_hidden(1))[0][0]
                    target = reward + self.gamma * torch.max(next_q).item()
                # Copy the current predictions and overwrite only the taken action.
                target_f = self.model(state, self.model.init_hidden(1))[0].clone()
                target_f[0][action] = target
            self.optimizer.zero_grad()
            prediction = self.model(state, self.model.init_hidden(1))[0]
            loss = self.criterion(prediction, target_f)
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

class SnakeGameAI:
    """Snake environment on a pixel grid, used here to replay a trained model.

    Positions are pixel coordinates that change in steps of ``snake_size``. The
    observation is the float32 vector
    ``[head_x, head_y, food_x, food_y, dx, dy]``.

    Args:
        width: Board width in pixels.
        height: Board height in pixels.
        snake_size: Side length of one snake segment / grid cell in pixels.
    """

    def __init__(self, width=640, height=480, snake_size=20):
        self.width = width
        self.height = height
        self.snake_size = snake_size
        self.reset()

    def _random_cell(self, extent):
        """Return a random coordinate along an axis of length ``extent``, snapped to a multiple of ``snake_size``."""
        return round(random.randrange(0, extent - self.snake_size) / self.snake_size) * self.snake_size

    def _place_food(self):
        """Move the food to a new random grid position (x is drawn before y)."""
        self.foodx = self._random_cell(self.width)
        self.foody = self._random_cell(self.height)

    def reset(self):
        """Begin a new game with a stationary length-1 snake in the centre; return the initial state."""
        self.x1 = self.width / 2
        self.y1 = self.height / 2
        self.x1_change = 0
        self.y1_change = 0
        self.snake_List = []
        self.Length_of_snake = 1
        self._place_food()
        self.score = 0
        return self.get_state()

    def step(self, action):
        """Apply one move and return ``(state, reward, done)``.

        Args:
            action: 0 = left, 1 = right, 2 = up, 3 = down. Any other value keeps
                the current direction.

        Reward is 10 for eating food, -10 when the game ends (wall or self
        collision), otherwise 0.
        """
        if action == 0:  # LEFT
            self.x1_change = -self.snake_size
            self.y1_change = 0
        elif action == 1:  # RIGHT
            self.x1_change = self.snake_size
            self.y1_change = 0
        elif action == 2:  # UP
            self.y1_change = -self.snake_size
            self.x1_change = 0
        elif action == 3:  # DOWN
            self.y1_change = self.snake_size
            self.x1_change = 0

        self.x1 += self.x1_change
        self.y1 += self.y1_change

        # Leaving the board ends the game.
        done = self.x1 >= self.width or self.x1 < 0 or self.y1 >= self.height or self.y1 < 0
        snake_Head = [self.x1, self.y1]
        self.snake_List.append(snake_Head)
        if len(self.snake_List) > self.Length_of_snake:
            del self.snake_List[0]

        # Running into any body segment also ends the game.
        if snake_Head in self.snake_List[:-1]:
            done = True

        reward = 0
        if self.x1 == self.foodx and self.y1 == self.foody:
            self._place_food()
            self.Length_of_snake += 1
            reward = 10
            self.score += 1

        # The terminal penalty overrides any food reward earned on the same move.
        if done:
            reward = -10

        state = self.get_state()
        return state, reward, done

    def get_state(self):
        """Return ``[head_x, head_y, food_x, food_y, dx, dy]`` as a float32 array."""
        state = [
            self.x1, self.y1, self.foodx, self.foody,
            self.x1_change, self.y1_change,
        ]
        return np.array(state, dtype=np.float32)

    def render(self):
        """Draw the snake (green) and the food (white) onto the window.

        This version does not update the display itself; the caller is expected
        to flip/update it after rendering.
        """
        # Reuse the existing window; only (re)create it when there is none yet
        # or its size does not match the board, instead of on every frame.
        window = pygame.display.get_surface()
        if window is None or window.get_size() != (self.width, self.height):
            window = pygame.display.set_mode((self.width, self.height))
        window.fill((0, 0, 0))
        for segment in self.snake_List:
            pygame.draw.rect(window, (0, 255, 0), [segment[0], segment[1], self.snake_size, self.snake_size])
        pygame.draw.rect(window, (255, 255, 255), [self.foodx, self.foody, self.snake_size, self.snake_size])

def load_model(model_type, model_path, state_size, hidden_size, action_size):
    """Build a network of the requested type, load its weights and switch it to eval mode.

    Args:
        model_type: ``'RNN'`` or ``'GRU'``; any other value selects the LSTM.
        model_path: Path to a state-dict checkpoint.
        state_size: Input feature count the network was built with.
        hidden_size: Hidden width the network was built with.
        action_size: Number of outputs the network was built with.

    Returns:
        The loaded model, placed on ``device`` and in evaluation mode.
    """
    if model_type == 'RNN':
        model = RNNSnakeNet(state_size, hidden_size, action_size).to(device)
    elif model_type == 'GRU':
        model = GRUSnakeNet(state_size, hidden_size, action_size).to(device)
    else:  # Default to LSTM
        model = LSTMSnakeNet(state_size, hidden_size, action_size).to(device)
    # map_location lets a checkpoint that was saved on a GPU be restored on a
    # CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

def draw_text(surface, text, font, color, x, y):
    """Render ``text`` with ``font`` in ``color`` and blit it onto ``surface`` centred at (x, y)."""
    rendered = font.render(text, True, color)
    surface.blit(rendered, rendered.get_rect(center=(x, y)))

def wait_for_click():
    """Block until the user clicks or closes the window.

    Returns:
        True after a mouse button press; False if the window was closed, in
        which case pygame has already been shut down.
    """
    # One clock for the whole wait so tick() paces the loop at 30 iterations per second.
    clock = pygame.time.Clock()
    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return False
            if event.type == pygame.MOUSEBUTTONDOWN:
                return True
        pygame.display.flip()
        clock.tick(30)

def play_game(model, game, display=True):
    """Let ``model`` play one greedy game of Snake in a pygame window.

    Shows a "Click to Start" screen first. Returns early (without printing a
    score) if the window is closed before or during the game.

    Args:
        model: Recurrent network exposing ``init_hidden`` and a
            ``model(state, hidden)`` call.
        game: Environment providing ``reset``, ``step``, ``render``, ``width``,
            ``height`` and ``score``.
        display: When True, draw every frame and limit the loop to 10 frames per second.
    """
    pygame.init()
    screen = pygame.display.set_mode((game.width, game.height))
    pygame.display.set_caption("Snake Game AI")
    title_font = pygame.font.SysFont(None, 55)

    # Start screen.
    screen.fill((0, 0, 0))
    draw_text(screen, "Click to Start", title_font, (255, 255, 255), game.width // 2, game.height // 2)
    pygame.display.flip()

    if not wait_for_click():
        return

    observation = game.reset()
    hidden = model.init_hidden(1)
    frame_clock = pygame.time.Clock()
    finished = False

    while not finished:
        # Closing the window aborts the game immediately.
        if any(event.type == pygame.QUIT for event in pygame.event.get()):
            pygame.quit()
            return

        model_input = torch.FloatTensor(observation).unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values, hidden = model(model_input, hidden)
        chosen_action = torch.argmax(q_values[0]).item()

        observation, _, finished = game.step(chosen_action)

        if display:
            game.render()
            pygame.display.flip()
            frame_clock.tick(10)

    print(f"Final Score: {game.score}")
    pygame.quit()

if __name__ == "__main__":
    # Network dimensions passed to load_model; presumably they must equal the
    # ones used when the checkpoint was trained.
    state_size = 6
    hidden_size = 64
    action_size = 4

    model_type = 'RNN'  # Ensure this matches the model type used during training
    model_path = "rnn/rnn_snake_model_10500.pth"
    model = load_model(
        model_type=model_type,
        model_path=model_path,
        state_size=state_size,
        hidden_size=hidden_size,
        action_size=action_size,
    )

    game = SnakeGameAI()

    # Play five demonstration games back to back.
    for i in range(5):
        play_game(model, game)


Final Score: 0
Final Score: 0
Final Score: 0
Final Score: 0


In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np
import time
import pygame
import matplotlib.pyplot as plt
import csv
import os

# Compute device for the residual CNN agent defined below: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection.

    The skip path is the identity unless the block changes the stride or the
    channel count, in which case it is a 1x1 convolution followed by batch norm.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution (and of the projection, if any).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # An empty Sequential acts as the identity skip path.
        projection = []
        if not (stride == 1 and in_channels == out_channels):
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = torch.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        merged = main + self.shortcut(x)
        return torch.relu(merged)

class ResidualCNN(nn.Module):
    """Small residual CNN that maps a state vector to one value per action.

    The input is treated as a one-row image: ``forward`` receives a tensor of
    shape (batch, input_dim, width) and inserts a height dimension of 1. The
    first fully connected layer expects 128 * 1 * 3 features, which corresponds
    to a width of 6 after the stride-2 stage.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output values.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.fc1 = nn.Linear(128 * 1 * 3, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, output_dim)

    def _make_layer(self, in_channels, out_channels, stride):
        """Stack two residual blocks; only the first one applies ``stride`` and the channel change."""
        blocks = [
            ResidualBlock(in_channels, out_channels, stride),
            ResidualBlock(out_channels, out_channels),
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the network output of shape (batch, output_dim)."""
        features = x.unsqueeze(2)  # Add a dummy height dimension
        features = torch.relu(self.bn1(self.conv1(features)))
        features = self.layer2(self.layer1(features))
        flat = features.view(features.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

class DQNAgent:
    """Epsilon-greedy Q-learning agent whose Q-function is a residual CNN.

    Args:
        input_dim: Number of input channels of the network.
        action_dim: Number of discrete actions.
    """

    def __init__(self, input_dim, action_dim):
        self.input_dim = input_dim
        self.action_dim = action_dim
        self.memory = deque(maxlen=2000)  # bounded replay buffer
        self.gamma = 0.95                 # discount factor
        self.epsilon = 1.0                # exploration rate, decayed in replay()
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.learning_rate = 0.001
        self.model = ResidualCNN(input_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def _to_input(self, state):
        """Shape a flat state vector as a (1, 1, len(state)) float tensor on the model's device.

        The model itself adds the extra height dimension.
        """
        model_device = next(self.model.parameters()).device
        return torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(model_device)

    def act(self, state):
        """Return a random action with probability epsilon, otherwise the arg-max of the network output."""
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_dim)
        with torch.no_grad():
            action_values = self.model(self._to_input(state))
        return torch.argmax(action_values[0]).item()

    def replay(self, batch_size):
        """Fit the network on ``batch_size`` sampled transitions, one optimiser step each.

        The call is a no-op when fewer than ``batch_size`` transitions are
        stored; otherwise epsilon is decayed once at the end.
        """
        if len(self.memory) < batch_size:
            return
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            state = self._to_input(state)
            # Targets are constants for the regression, so no autograd is needed here.
            with torch.no_grad():
                target = reward
                if not done:
                    next_q = self.model(self._to_input(next_state))[0]
                    target = reward + self.gamma * torch.max(next_q).item()
                # Copy the current predictions and overwrite only the taken action.
                target_f = self.model(state)[0].clone()
                target_f[action] = target
            self.optimizer.zero_grad()
            prediction = self.model(state)[0]
            loss = self.criterion(prediction, target_f)
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

class SnakeGameAI:
    """Snake environment on a pixel grid with a length-scaled food reward.

    Positions are pixel coordinates that change in steps of ``snake_size``. The
    observation is the float32 vector
    ``[head_x, head_y, food_x, food_y, dx, dy]``.

    Args:
        width: Board width in pixels.
        height: Board height in pixels.
        snake_size: Side length of one snake segment / grid cell in pixels.
    """

    def __init__(self, width=640, height=480, snake_size=20):
        self.width = width
        self.height = height
        self.snake_size = snake_size
        self.reset()

    def _random_cell(self, extent):
        """Return a random coordinate along an axis of length ``extent``, snapped to a multiple of ``snake_size``."""
        return round(random.randrange(0, extent - self.snake_size) / self.snake_size) * self.snake_size

    def _place_food(self):
        """Move the food to a new random grid position (x is drawn before y)."""
        self.foodx = self._random_cell(self.width)
        self.foody = self._random_cell(self.height)

    def reset(self):
        """Begin a new game with a stationary length-1 snake in the centre; return the initial state."""
        self.x1 = self.width / 2
        self.y1 = self.height / 2
        self.x1_change = 0
        self.y1_change = 0
        self.snake_List = []
        self.Length_of_snake = 1
        self._place_food()
        self.score = 0
        return self.get_state()

    def step(self, action):
        """Apply one move and return ``(state, reward, done)``.

        Args:
            action: 0 = left, 1 = right, 2 = up, 3 = down. Any other value keeps
                the current direction.

        Reward is ``10 * Length_of_snake`` (using the length after growing) for
        eating food and 0 otherwise; a move that ends the game always yields 0.
        """
        if action == 0:  # LEFT
            self.x1_change = -self.snake_size
            self.y1_change = 0
        elif action == 1:  # RIGHT
            self.x1_change = self.snake_size
            self.y1_change = 0
        elif action == 2:  # UP
            self.y1_change = -self.snake_size
            self.x1_change = 0
        elif action == 3:  # DOWN
            self.y1_change = self.snake_size
            self.x1_change = 0

        self.x1 += self.x1_change
        self.y1 += self.y1_change

        # Leaving the board ends the game.
        done = self.x1 >= self.width or self.x1 < 0 or self.y1 >= self.height or self.y1 < 0
        snake_Head = [self.x1, self.y1]
        self.snake_List.append(snake_Head)
        if len(self.snake_List) > self.Length_of_snake:
            del self.snake_List[0]

        # Running into any body segment also ends the game.
        if snake_Head in self.snake_List[:-1]:
            done = True

        reward = 0
        if self.x1 == self.foodx and self.y1 == self.foody:
            self._place_food()
            self.Length_of_snake += 1
            reward = 10 * self.Length_of_snake
            self.score += 1

        # A terminal move carries no reward, even if food was eaten on it.
        if done:
            reward = 0

        state = self.get_state()
        return state, reward, done

    def get_state(self):
        """Return ``[head_x, head_y, food_x, food_y, dx, dy]`` as a float32 array."""
        state = [
            self.x1, self.y1, self.foodx, self.foody,
            self.x1_change, self.y1_change,
        ]
        return np.array(state, dtype=np.float32)

    def render(self):
        """Draw the snake (green) and the food (white) and update the display."""
        # Reuse the existing window; only (re)create it when there is none yet
        # or its size does not match the board, instead of on every frame.
        window = pygame.display.get_surface()
        if window is None or window.get_size() != (self.width, self.height):
            window = pygame.display.set_mode((self.width, self.height))
        window.fill((0, 0, 0))
        for segment in self.snake_List:
            pygame.draw.rect(window, (0, 255, 0), [segment[0], segment[1], self.snake_size, self.snake_size])
        pygame.draw.rect(window, (255, 255, 255), [self.foodx, self.foody, self.snake_size, self.snake_size])
        pygame.display.update()

def draw_text(surface, text, font, color, x, y):
    """Draw ``text`` onto ``surface`` so that its bounding box is centred on (x, y)."""
    label = font.render(text, True, color)
    placement = label.get_rect(center=(x, y))
    surface.blit(label, placement)

def wait_for_click():
    """Wait for a mouse click or for the window to be closed.

    Returns:
        True after a mouse button press; False if the window was closed, in
        which case pygame has already been shut down.
    """
    # One clock for the whole wait so tick() paces the loop at 30 iterations per second.
    clock = pygame.time.Clock()
    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return False
            if event.type == pygame.MOUSEBUTTONDOWN:
                return True
        pygame.display.flip()
        clock.tick(30)

def play_game(model, game, display=True):
    """Play one game of at most 500 moves in a pygame window and print the result.

    Shows a "Click to Start" screen first. Returns early (without printing) if
    the window is closed before or during the game.

    Args:
        model: Object whose ``act(state)`` returns the next action. When this is
            an epsilon-greedy agent, its ``epsilon`` decides how many moves are
            random, so set it to 0 beforehand for a purely greedy game.
        game: Environment providing ``reset``, ``step``, ``render``, ``width``,
            ``height`` and ``score``.
        display: When True, draw every frame.
    """
    pygame.init()  # Initialize Pygame
    window = pygame.display.set_mode((game.width, game.height))  # Create the window
    pygame.display.set_caption("Snake Game AI")
    font = pygame.font.SysFont(None, 55)

    # Start screen
    window.fill((0, 0, 0))
    draw_text(window, "Click to Start", font, (255, 255, 255), game.width // 2, game.height // 2)
    pygame.display.update()
    if not wait_for_click():
        return

    state = game.reset()
    total_reward = 0
    clock = pygame.time.Clock()  # single clock so tick() really limits the frame rate
    for _ in range(500):
        # Drain the event queue every frame so the window stays responsive and
        # can be closed while the game is running.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return
        action = model.act(state)
        next_state, reward, done = game.step(action)
        state = next_state
        total_reward += reward
        if display:
            game.render()
        if done:
            break
        clock.tick(10)
    print(f"Score: {game.score}, Total Reward: {total_reward}")

def _append_training_rows(csv_path, rows):
    """Append ``rows`` (lists of episode, score, duration) to the CSV file at ``csv_path``."""
    with open(csv_path, mode='a', newline='') as file:
        csv.writer(file).writerows(rows)


def train_until_achieve_score(target_score=10, load_path=None, episode_st=0):
    """Train the residual-CNN agent on Snake until one episode reaches ``target_score``.

    Previous results are read from the (read-only) input CSV if it exists. A
    training log is kept in ``/kaggle/working/training_data.csv``: it is created
    with a header (plus the loaded history) when missing, and the episodes
    finished since the last flush are appended every 100 episodes and at the end.
    A checkpoint is written every 100 episodes and once more when training ends;
    a progress plot is refreshed after every episode.

    Args:
        target_score: Game score that stops training.
        load_path: Optional path to a state-dict checkpoint to start from.
        episode_st: Episode number to continue counting from.
    """
    input_dim = 1
    action_size = 4
    batch_size = 32
    history_csv = "/kaggle/input/model-esentials/training_data.csv"
    log_csv = "/kaggle/working/training_data.csv"

    agent = DQNAgent(input_dim, action_size)
    if load_path:
        # map_location lets a checkpoint that was saved on a GPU be restored on
        # a CPU-only machine (and vice versa).
        agent.model.load_state_dict(torch.load(load_path, map_location=device))
        print(f"Loaded model from {load_path}")

    game = SnakeGameAI()
    episode = episode_st
    start_time = time.time()
    last_save_time = start_time

    scores = []
    episode_durations = []
    history_rows = []   # rows loaded from the input CSV
    pending_rows = []   # rows produced in this run that are not on disk yet

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    if os.path.isfile(history_csv):
        with open(history_csv, mode='r', newline='') as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header; tolerate an empty file
            for row in reader:
                if not row:
                    continue
                logged_episode, score, duration = int(row[0]), int(row[1]), float(row[2])
                scores.append(score)
                episode_durations.append(duration)
                history_rows.append([logged_episode, score, duration])
                # Continue counting after the latest logged episode.
                episode = max(episode, logged_episode)

    # The input directory is not writable, so the log lives in the working
    # directory. Seed a new log with the header and any loaded history.
    if not os.path.isfile(log_csv):
        with open(log_csv, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Episode", "Score", "Episode Time"])
            writer.writerows(history_rows)

    while True:
        episode += 1
        state = game.reset()
        episode_start = time.time()
        # Each episode is capped at 500 moves; episodes that hit the cap are not logged.
        for time_step in range(500):
            action = agent.act(state)
            next_state, reward, done = game.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                episode_time = time.time() - episode_start
                scores.append(game.score)
                episode_durations.append(episode_time)
                pending_rows.append([episode, game.score, episode_time])
                print(f"Episode: {episode}, Score: {game.score}, Episode time: {episode_time:.2f} seconds")
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)

        if episode % 100 == 0:
            torch.save(agent.model.state_dict(), f"/kaggle/working/residual_cnn_snake_model_{episode}.pth")
            now = time.time()
            elapsed_time = now - last_save_time
            last_save_time = now
            print(f"Saving after Episode: {episode}, training time since last save: {elapsed_time:.2f} seconds")

            # Write only the rows that are new since the previous flush, each
            # with its real episode number.
            _append_training_rows(log_csv, pending_rows)
            pending_rows = []

        if game.score >= target_score:
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Model achieved a score of {target_score} in episode {episode}")
            print(f"Training time: {elapsed_time:.2f} seconds")
            break

        ax1.clear()
        ax1.plot(scores)
        ax1.set_xlabel('Episode')
        ax1.set_ylabel('Score')
        ax1.set_title('Training Progress')

        ax2.clear()
        ax2.plot(episode_durations)
        ax2.set_xlabel('Episode')
        ax2.set_ylabel('Duration (seconds)')
        ax2.set_title('Episode Duration')

        fig.savefig("/kaggle/working/training_progress.png")

    # Do not lose the episodes finished after the last periodic flush.
    if pending_rows:
        _append_training_rows(log_csv, pending_rows)
    torch.save(agent.model.state_dict(), f"/kaggle/working/residual_cnn_snake_model_final.pth")
    plt.close(fig)


In [15]:
import torch
import pygame

state_size = 6
input_dim = 1
action_size = 4

agent = DQNAgent(input_dim, action_size)
game = SnakeGameAI()

model_path = "rcnn_old/residual_cnn_snake_model_11000.pth"
# Map the checkpoint onto the device the model was created on, so this cell also
# runs on machines without CUDA.
agent.model.load_state_dict(torch.load(model_path, map_location=device))
agent.model.eval()
# A freshly constructed agent has epsilon = 1.0, which makes act() choose every
# move at random and ignore the loaded weights. Turn exploration off for evaluation.
agent.epsilon = 0.0

play_game(agent, game, display=True)

pygame.quit()


Score: 0, Total Reward: 0
