In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import time
import math
import pygame
from collections import deque

# Initialize device properly
# device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
device = 'cpu'
print(f"Using device: {device}")

num_episodes = 3000
d_state = 4
action_size = 4
discount_rate = 0.995
learning_rate = 3e-3
eps_start = 1
eps_end = 0.1
eps_decay = 500
time_step_reward = -0.5
dropout = 0.5

# Epsilon decay function
def epsilon_by_episode(episode):
    return eps_end + (eps_start - eps_end) * math.exp(-1. * episode / eps_decay)

# Neural Network for Q-Learning
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, 384),
            nn.ReLU(),
            nn.Linear(384, output_size)
        )
        self.to(device)  # Move the model to the specified device
        self.dropout = nn.Dropout(p=dropout)
        # Initialize weights using Kaiming He initialization for ReLU
        nn.init.kaiming_uniform_(self.net[0].weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.net[2].weight, nonlinearity='relu')

    def forward(self, x):
        # print(f'input shape: {x.shape}') # [1, 16]
        out = F.softmax(self.net(x), dim=1)
        # print(f'output shape: {out.shape}') # [1, 4]
        return self.dropout(out)

# class DQN(nn.Module):
#     def __init__(self, input_size, output_size, seq_length):
#         super(DQN, self).__init__()
#         self.seq_length = seq_length
#         self.lstm = nn.LSTM(input_size, 128, batch_first=True)  # LSTM layer
#         self.fc1 = nn.Linear(128, 384)  # Adjusted for LSTM output
#         self.fc2 = nn.Linear(384, output_size)
#         self.dropout = nn.Dropout(p=dropout)
#         self.to(device)
        
#         # Initialize weights
#         nn.init.kaiming_uniform_(self.fc1.weight, nonlinearity='relu')
#         nn.init.kaiming_uniform_(self.fc2.weight, nonlinearity='relu')

#     def forward(self, x):
#         x = x.view(-1, self.seq_length, d_state ** 2)  # Reshape to [batch, seq_length, input_size]
#         lstm_out, _ = self.lstm(x)
#         x = F.relu(self.fc1(lstm_out[:, -1, :]))  # Use the last output of the sequence
#         x = self.fc2(x)
#         return self.dropout(F.softmax(x, dim=1))


class GridGame:
    def __init__(self, model):
        self.state_size = d_state ** 2
        self.action_size = action_size
        self.model = model
        self.reset()
        self.optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
        self.criterion = nn.MSELoss()

    def reset(self):
        self.player_pos = (random.randint(0, d_state - 1), random.randint(0, d_state - 1))
        self.goal_pos = (random.randint(0, d_state - 1), random.randint(0, d_state - 1))
        while self.goal_pos == self.player_pos:
            self.goal_pos = (random.randint(0, d_state - 1), random.randint(0, d_state - 1))
        self.done = False
        self.state = self.get_state()

    def get_state(self):
        state = torch.zeros((d_state, d_state), device=device)
        state[self.player_pos[0], self.player_pos[1]] = 1
        state[self.goal_pos[0], self.goal_pos[1]] = 2
        return state.flatten().unsqueeze(0)

    def calculate_distance(self):
        # Convert the differences to tensors before calculating the distance
        a = torch.tensor((self.player_pos[0] - self.goal_pos[0])**2, device=device, dtype=torch.float)
        b = torch.tensor((self.player_pos[1] - self.goal_pos[1])**2, device=device, dtype=torch.float)
        return torch.sqrt(a + b)

    def step(self, action):
        moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # Up, Down, Left, Right
        move = moves[action]
        prev_distance = self.calculate_distance()
        self.player_pos = ((self.player_pos[0] + move[0]) % d_state, (self.player_pos[1] + move[1]) % d_state)
        new_distance = self.calculate_distance()
        # reward = time_step_reward  # Penalize each step to encourage efficiency
        # delta_distance = prev_distance - new_distance
        # if delta_distance > 0:
        #     reward += delta_distance/d_state
        # else:
        #     reward -= delta_distance/d_state
        # if self.player_pos == self.goal_pos:
        # #     reward += 100  # Large reward for reaching the goal
        #     self.done = True
        reward = time_step_reward  # Base penalty for each step
        if self.player_pos == self.goal_pos:
            # reward += 100  # Large reward for reaching the goal
            self.done = True
        else:
            delta_distance = prev_distance - new_distance
            if delta_distance > 0:
                reward += 1  # Encourage moving closer to the goal
            else:
                reward -= 0.5  # Penalize moving away from the goal

        new_state = self.get_state()
        return new_state, reward, self.done

    def train_step(self, state, action, reward, next_state, done):
        action = action.view(1, -1)
        reward = torch.tensor([reward], device=device, dtype=torch.float)
        done = torch.tensor([done], device=device, dtype=torch.float)
        
        state_action_values = self.model(state).gather(1, action)
        next_state_values = self.model(next_state).max(1)[0].detach()
        # print((next_state_values * discount_rate) * (1 - done))
        # time.sleep(0.2)
        expected_state_action_values = (next_state_values * discount_rate) * (1 - done) + torch.sigmoid(torch.tensor(reward))*2
        
        loss = self.criterion(state_action_values, expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

def select_action(state, policy_net, episode):
    eps_threshold = epsilon_by_episode(episode)
    if random.random() > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(action_size)]], device=device, dtype=torch.long)

# N = 4  # Sequence length
policy_net = DQN(d_state ** 2, action_size)
game = GridGame(policy_net)


start_time = time.time()
policy_net.train()

try:
    for episode in range(num_episodes):
        game.reset()
        total_reward = 0
        while not game.done:
            state = game.state
            action = select_action(state, policy_net, episode)
            next_state, reward, done = game.step(action.item())
            game.train_step(state, action, reward, next_state, done)
            game.state = next_state
            total_reward += reward
                
        if episode % 100 == 0:
            print(f"Episode {episode}: Total Reward: {total_reward}, Epsilon: {epsilon_by_episode(episode)}")

except KeyboardInterrupt:
    print("Training stopped")

finally:
    print(f'Training took {time.time() - start_time} seconds')
    torch.save(policy_net.state_dict(), 'weights/model-v0.pth')
    print('Model saved')


Using device: cpu
Episode 0: Total Reward: 0.0, Epsilon: 1.0


  expected_state_action_values = (next_state_values * discount_rate) * (1 - done) + torch.sigmoid(torch.tensor(reward))*2


Episode 100: Total Reward: -1.5, Epsilon: 0.8368576777701836
Episode 200: Total Reward: 0.0, Epsilon: 0.7032880414320754
Episode 300: Total Reward: -5.0, Epsilon: 0.5939304724846237
Episode 400: Total Reward: -0.5, Epsilon: 0.5043960677054994
Episode 500: Total Reward: -3.0, Epsilon: 0.43109149705429817
Episode 600: Total Reward: -0.5, Epsilon: 0.37107479072098193
Episode 700: Total Reward: -0.5, Epsilon: 0.32193726754744584
Episode 800: Total Reward: -0.5, Epsilon: 0.28170686619518986
Episode 900: Total Reward: -20.0, Epsilon: 0.2487689993994279
Episode 1000: Total Reward: -13.5, Epsilon: 0.22180175491295145
Episode 1100: Total Reward: -6.5, Epsilon: 0.1997228425261005
Episode 1200: Total Reward: -2.5, Epsilon: 0.1816461579604713
Episode 1300: Total Reward: -1.0, Epsilon: 0.16684622039290048
Episode 1400: Total Reward: -2.5, Epsilon: 0.15472905636269618
Episode 1500: Total Reward: -12.0, Epsilon: 0.14480836153107757
Episode 1600: Total Reward: -15.0, Epsilon: 0.13668598358052958
Episo

In [5]:
# run a game with the current weights. use a pygame window to visualize the game
# game.reset()
# while not game.done:
#     state = game.state
#     human_readable_state = state.view(d_state, d_state).cpu().numpy()
#     print(human_readable_state)
#     action = select_action(state, policy_net, 0, device)
#     next_state, reward, done = game.step(action.item())
#     game.state = next_state
#     print(reward)
#     time.sleep(0.5)
# print("Game over")

policy_net.eval()
# give me a pygame window to visualize the game
episode_moves = []
for x in range(100):
    game.reset()
    screen = pygame.display.set_mode((d_state*100, d_state*100))
    total_reward = 0
    episode_moves.append(0)
    while not game.done:

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()
        state = game.state
        human_readable_state = state.view(d_state, d_state).cpu().numpy()
        # print(human_readable_state)
        action = select_action(state, policy_net, 0)
        # action = select_action(state, policy_net, episode, game.player_pos)
        next_state, reward, done = game.step(action.item())
        game.state = next_state
        total_reward += reward
        episode_moves[x] += 1
        # print(reward)
        time.sleep(0.05)
        screen.fill((0, 0, 0))
        for i in range(d_state):
            for j in range(d_state):
                if human_readable_state[i][j] == 1:
                    pygame.draw.rect(screen, (0, 0, 255), (i*100, j*100, 100, 100))
                elif human_readable_state[i][j] == 2:
                    pygame.draw.rect(screen, (0, 255, 0), (i*100, j*100, 100, 100))
        pygame.display.update()
    print(f"Total Reward: {total_reward}")
    # print("Game over")
pygame.quit()
print(f"Average moves: {sum(episode_moves)/len(episode_moves)}")

Total Reward: 0.0
Total Reward: -2.5
Total Reward: -2.0
Total Reward: -2.0
Total Reward: -3.5
Total Reward: 0.0
Total Reward: -4.5
Total Reward: -6.5


KeyboardInterrupt: 