### Notebook for initial model training

Observation layers (channel index → meaning):

    0: borders
    1: apples
    2: heads_other
    3: bodies_other
    4: tails_other
    5: head_self
    6: body_self
    7: tail_self

In [21]:
# !git clone https://github.com/1eeGit/marlenv.git
# !cd marlenv
# !git checkout 1eeGit-patch-1
# !pip install -e .

### install additional packages
# !pip install pygame==2.6.0
# !pip install matplotlib
# !pip install torch

### Run the marlenv test suite (sanity check of the environment install)

In [None]:
!pytest

platform linux -- Python 3.12.5, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/unix
plugins: anyio-4.4.0
collected 5 items                                                              [0m[1m

marlenv/tests/test_snake.py [32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m                                        [100%][0m



In [None]:
### test game window
# !python video_recorder_and_input_handling.py 

In [None]:
### reload the project if needed
# import importlib
# import marlenv
# importlib.reload(marlenv)

# DQN Model

## prepare env

In [37]:

import gym
import marlenv
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

## check gpu availability
# Displays True/False only; train_dqn selects its own device, so nothing below depends on this cell.
torch.cuda.is_available()


True

In [26]:
# Per-event rewards handed to the Snake env (see `reward_func=` in the next cell).
# NOTE(review): 'time' is presumably granted every step, so a positive value rewards
# mere survival — confirm against marlenv's reward handling.
custom_reward_dict = dict(
    fruit=1.0,
    kill=1.5,
    lose=-10.0,
    time=0.1,
    win=10.0,
)

In [27]:
### create the environment

env = gym.make(
    'Snake-v1',
    height=20,       # Height of the grid map
    width=20,        # Width of the grid map
    num_snakes=4,    # Number of snakes to spawn on grid
    snake_length=3,  # Initial length of the snake at spawn time
    vision_range=5,  # Vision range (both width height), map returned if None
    frame_stack=1,   # Number of observations to stack on return
    reward_func=custom_reward_dict
)

## define CNN

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNQNetwork(nn.Module):
    """Convolutional Q-network mapping a stacked multi-snake observation to Q-values.

    ``input_shape`` is interpreted as the raw observation shape ``(num_snakes, H, W, C)``
    (this notebook printed ``(4, 11, 11, 8)``): the network expects
    ``num_snakes * C`` input channels on an ``H x W`` grid.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()
        in_channels = input_shape[0] * input_shape[-1]
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        # Probe the conv stack once with a dummy input to size the first linear layer.
        conv_out_size = self._get_conv_output((in_channels, *input_shape[1:3]))

        self.fc1 = nn.Linear(conv_out_size, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def _get_conv_output(self, shape):
        """Return the flattened size of the conv stack's output for one ``shape`` input."""
        with torch.no_grad():
            o = torch.zeros(1, *shape)
            o = self.conv3(self.conv2(self.conv1(o)))
        return int(o.numel())

    def forward(self, x, device=None):
        """Compute Q-values.

        Args:
            x: array-like or tensor shaped ``(B, in_channels, H, W)``, an unbatched
                ``(in_channels, H, W)``, or ``(B, d1, ..., dk, H, W)`` whose middle
                dims multiply to ``in_channels`` (e.g. ``(B, 1, in_channels, H, W)``).
            device: where to place ``x``; defaults to the device holding the
                network's weights, so callers need not pass it after ``.to(...)``.

        Returns:
            Tensor of shape ``(B, num_actions)``.
        """
        if device is None:
            device = self.conv1.weight.device
        x = torch.as_tensor(x, dtype=torch.float32, device=device)
        if x.dim() == 3:
            x = x.unsqueeze(0)  # single unbatched observation
        # Merge every dimension between batch and the two spatial dims into channels.
        x = x.reshape(x.size(0), -1, *x.shape[-2:])
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


## define training function

In [45]:
def train_dqn(env, q_network, replay_buffer, optimizer, num_episodes=500, batch_size=64, gamma=0.99, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q_network.to(device)
    epsilon = epsilon_start

    for episode in range(num_episodes):
        state = env.reset()
        print("Original state shape:", state.shape)

        # Reshape the state to combine the first and last dimensions into the channels dimension
        state = state.reshape(state.shape[0] * state.shape[3], state.shape[1], state.shape[2])

        # Add the batch dimension
        state = state.reshape(1, *state.shape)

        # Print the new shape
        print("New state shape:", state.shape)

        done = False
        total_reward = 0

        while not done:
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    action = q_network(state).argmax().item()

            next_state, reward, done, _ = env.step(action)
            
            # Reshape next_state for CNN input
            next_state = next_state.transpose(2, 0, 1)
            next_state = next_state.reshape(1, *next_state.shape)

            total_reward += reward

            # Store transition in replay buffer
            replay_buffer.put((state, action, reward, next_state, done))

            state = next_state

        # Training logic here (sampling from replay buffer, etc.)


            # Training the Q-Network
            if len(replay_buffer) > batch_size:
                s_batch, a_batch, r_batch, s_prime_batch, done_batch = replay_buffer.sample(batch_size)

                s_batch = torch.FloatTensor(s_batch).to(device)
                a_batch = torch.LongTensor(a_batch).to(device)
                r_batch = torch.FloatTensor(r_batch).to(device)
                s_prime_batch = torch.FloatTensor(s_prime_batch).to(device)
                done_batch = torch.FloatTensor(done_batch).to(device)

                q_values = q_network(s_batch, device).gather(1, a_batch.unsqueeze(1)).squeeze(1)
                next_q_values = q_network(s_prime_batch, device).max(1)[0]
                target_q_values = r_batch + gamma * next_q_values * (1 - done_batch)

                loss = F.mse_loss(q_values, target_q_values)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        epsilon = max(epsilon * epsilon_decay, epsilon_end)
        print(f"Episode {episode + 1}: Total Reward: {total_reward}")



In [40]:
import collections
import random
import numpy as np

class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(s, a, r, s_prime, done)`` transitions."""

    def __init__(self, buffer_limit):
        # deque(maxlen=...) drops the oldest transition once the limit is reached.
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done)`` tuple."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct transitions uniformly at random.

        Returns:
            Five numpy arrays ``(s, a, r, s_prime, done)``, each stacked along a new
            leading axis of length ``n``.

        Raises:
            ValueError: if ``n`` exceeds the number of stored transitions (from ``random.sample``).
        """
        mini_batch = random.sample(self.buffer, n)
        # Transpose the list of tuples into per-field columns; keep 5 empty columns when n == 0.
        columns = list(zip(*mini_batch)) or [()] * 5
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = columns
        return np.array(s_lst), np.array(a_lst), \
            np.array(r_lst), np.array(s_prime_lst), \
            np.array(done_mask_lst)

    def size(self):
        """Number of stored transitions."""
        return len(self.buffer)

    def __len__(self):
        # Lets callers use len(buffer), as the training loop in this notebook does.
        return len(self.buffer)


### training loop

In [46]:
# Seed every RNG the run touches so results are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# CNNQNetwork reads this as (num_snakes, H, W, C); env.reset() printed (4, 11, 11, 8).
# NOTE(review): confirm observation_space.shape matches the reset() output.
input_shape = env.observation_space.shape
# NOTE(review): for a multi-snake env the action space may be per-snake — confirm `.n` exists.
num_actions = env.action_space.n

q_network = CNNQNetwork(input_shape, num_actions)
replay_buffer = ReplayBuffer(buffer_limit=10000)
optimizer = optim.Adam(q_network.parameters(), lr=0.001)

train_dqn(env, q_network, replay_buffer, optimizer)

Original state shape: (4, 11, 11, 8)
New state shape: (1, 32, 11, 11)


AssertionError: 

## evaluate

In [None]:
# Greedy rollout with the trained network: no exploration, no gradient tracking.
eval_device = next(q_network.parameters()).device  # wherever training left the weights
state = env.reset()
done = False
with torch.no_grad():
    while not done:
        # Same channels-first layout used for training: (N, H, W, C) -> (1, N*C, H, W).
        obs = np.asarray(state, dtype=np.float32)
        obs = obs.transpose(0, 3, 1, 2).reshape(1, -1, *obs.shape[1:3])
        action = q_network(obs, eval_device).argmax().item()
        state, reward, done, _ = env.step(action)
        env.render()  # Assuming you want to see the game
