In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
from tqdm import tqdm
import pickle
import os
import random
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))
import random

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics
import matplotlib.pyplot as plt
from torch.distributions import Categorical

In [14]:
class RandomAgent:
    def __init__(self, action_dim):
        self.action_dim = action_dim

    def choose_action(self, state):
        """Randomly choose an action."""
        return random.randint(0, self.action_dim - 1), None

    def train(self, env, episodes=1000, save_plots=False, plots_path='random_training_plots.png'):
        self.episode_rewards = []

        for episode in tqdm(range(episodes), desc="Training", unit="episode"):
            state = env.reset()
            total_reward = 0
            done = False

            while not done:
                action, _ = self.choose_action(state)
                next_state, reward, done, _ = env.step(action)
                total_reward += reward
                state = next_state

            self.episode_rewards.append(total_reward)

        if save_plots:
            self.save_plots(self.episode_rewards, plots_path)

    def save_plots(self, plots_path):
        """Save the rewards plot for the random agent."""
        plots_dir = os.path.dirname(plots_path)
        os.makedirs(plots_dir, exist_ok=True)

        plt.plot(self.episode_rewards)
        plt.title("Episode Rewards")
        plt.xlabel("Episode")
        plt.ylabel("Total Reward")
        plt.tight_layout()
        plt.savefig(plots_path)
        plt.close()


In [15]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
if isinstance(env, FullObsSnakeEnv):
    state_dim = env.observation_space.shape[0] * env.observation_space.shape[1] * env.observation_space.shape[2]
else:
    state_dim = env.observation_space.shape[0]

action_dim = env.action_space.n
agent = RandomAgent(action_dim)

In [16]:
num_episodes = 10000
agent.train(env, episodes=num_episodes)

Training:   0%|          | 0/10000 [00:00<?, ?episode/s]

Training: 100%|██████████| 10000/10000 [00:00<00:00, 17128.08episode/s]


In [17]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

agent_name = f'random_agent_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'random')
os.makedirs(model_weights_dir, exist_ok=True)
agent_path = os.path.join(model_weights_dir, agent_name)

In [18]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'random')
os.makedirs(model_metrics_dir, exist_ok=True)

train_metrics_name = f'random_train_metrics_{environment}_{num_episodes}_{grid_size}.png'
train_metrics_path = os.path.join(model_metrics_dir, train_metrics_name)
agent.save_plots(train_metrics_path)

num_simulations = 100
sim_metrics_name = f'random_sim_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.json'
sim_metrics_path = os.path.join(model_metrics_dir, sim_metrics_name)
compute_metrics(agent, env, sim_metrics_path, num_simulations=num_simulations)

100%|██████████| 100/100 [00:00<00:00, 4440.77it/s]

Snake length: 1, Episode reward: -72
Snake length: 1, Episode reward: -76
Snake length: 1, Episode reward: -72
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -82
Snake length: 1, Episode reward: -72
Snake length: 1, Episode reward: -73
Snake length: 1, Episode reward: -77
Snake length: 1, Episode reward: -85
Snake length: 1, Episode reward: -75
Snake length: 1, Episode reward: -80
Snake length: 1, Episode reward: -75
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -72
Snake length: 1, Episode reward: -82
Snake length: 1, Episode reward: -74
Snake length: 1, Episode reward: -82
Snake length: 1, Episode reward: -75
Snake length: 1, Episode reward: -75
Snake length: 1, Episode reward: -65
Snake length: 1, Episode reward: -75
Snake length: 1, Episode reward: -81
Snake length: 1, Episode reward: -74
Snake length: 3, Episode reward: 77
Snake length: 2, Episode reward: 10
Snake length: 1, Episode reward: -62
Sna




{'snake_lengths': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  3,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  2],
 'episode_rewards': [-72,
  -76,
  -72,
  -78,
  -82,
  -72,
  -73,
  -77,
  -85,
  -75,
  -80,
  -75,
  -78,
  -78,
  -72,
  -82,
  -74,
  -82,
  -75,
  -75,
  -65,
  -75,
  -81,
  -74,
  77,
  10,
  -62,
  -68,
  -75,
  -81,
  -72,
  -79,
  -77,
  13,
  -87,
  -81,
  -83,
  -76,
  -82,
  -92,
  7,
  -75,
  -64,
  -77,
  -80,
  -100,
  -74,
  -64,
  -78,
  -63,
  -75,
  -75,
  -76,
  -87,
  -78,
  -74,
  -77,
  -80,
  -76,
  -73,
  -90,
  -74,
  -75,
  -74,
  -89,
  -66,
  -72

In [27]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
while not done:
    action, _ = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")

Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 76
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -75
