In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from models.ann import DeepQNet
from gridworld.agent import Agent
from gridworld.world import World

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device: ', device)

# Seed the RNGs this notebook draws from (numpy picks actions and samples replay
# batches; torch presumably initialises the network weights) so that
# Restart & Run All reproduces the same run.
# NOTE(review): if Agent/World draw from Python's `random` module, seed it as well.
SEED: int = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

num_episodes_per_epoch: int = 10   # random-action episodes added to memory each epoch
num_epochs: int = 1000
num_batches_per_epoch: int = 100   # optimizer steps per epoch

batch_size: int = 64
episode_duration: int = 1000       # steps per simulated episode

agent = Agent()

# Network input width: the agent's visual field followed by a one-hot action.
state_action_shape = len(agent.actions) + agent.visual_field.shape[0]
# One output per state-action row (presumably the predicted reward) — TODO confirm DeepQNet's (in, out) signature.
policy = DeepQNet(state_action_shape, 1).to(device)

class ReplayMemory:
    """Buffer of (observation, reward) experiences used to train the policy network.

    Experiences are stored as rows of the 2-D tensor ``observations``. The scalar
    reward for each row sits at the same index of the 1-D tensor ``rewards``. In this
    notebook each observation row is a visual-field reading followed by a one-hot action.
    """

    def __init__(
        self,
        feature_dim: "int | None" = None,
        storage_device: "torch.device | str | None" = None,
        capacity: "int | None" = None,
    ) -> None:
        """Create an empty memory.

        Args:
            feature_dim: Width of each observation row. Defaults to the notebook-level
                ``state_action_shape``.
            storage_device: Device the tensors are stored on. Defaults to the
                notebook-level ``device``.
            capacity: Maximum number of experiences retained. When exceeded, the oldest
                rows are discarded first. ``None`` (default) keeps every experience.

        Raises:
            ValueError: if ``capacity`` is given but not positive.
        """
        if capacity is not None and capacity <= 0:
            raise ValueError("capacity must be a positive integer or None")
        self.feature_dim = state_action_shape if feature_dim is None else feature_dim
        self.device = device if storage_device is None else storage_device
        self.capacity = capacity
        self.observations = torch.zeros((0, self.feature_dim), device=self.device)
        self.rewards = torch.zeros(0, device=self.device)

    def add_experience(self, observation: torch.Tensor, reward: torch.Tensor) -> None:
        """Append a batch of experiences, moving them to this memory's device.

        Args:
            observation: 2-D tensor of shape (n, feature_dim).
            reward: 1-D tensor of shape (n,), aligned row-for-row with ``observation``.

        Raises:
            ValueError: if ``observation`` and ``reward`` have different numbers of
                rows. Concatenating them anyway would misalign rewards and observations.
        """
        if observation.shape[0] != reward.shape[0]:
            raise ValueError(
                f"observation has {observation.shape[0]} rows but reward has {reward.shape[0]}"
            )
        self.observations = torch.cat((self.observations, observation.to(self.device)))
        self.rewards = torch.cat((self.rewards, reward.to(self.device)))
        if self.capacity is not None and self.rewards.shape[0] > self.capacity:
            # FIFO eviction: keep only the newest `capacity` experiences.
            self.observations = self.observations[-self.capacity:]
            self.rewards = self.rewards[-self.capacity:]

    def make_batch(self, batch_size: int) -> "ReplayMemory":
        """Sample ``batch_size`` experiences uniformly at random, with replacement.

        Uses numpy's global RNG, so ``np.random.seed`` makes the sampling reproducible.

        Returns:
            A new ``ReplayMemory`` holding only the sampled rows, on the same device.

        Raises:
            ValueError: if the memory is empty.
        """
        num_stored = len(self.observations)
        if num_stored == 0:
            raise ValueError("cannot sample a batch from an empty ReplayMemory")
        indices = np.random.randint(0, num_stored, batch_size)
        batch = ReplayMemory(self.feature_dim, self.device)
        batch.observations = self.observations[indices]
        batch.rewards = self.rewards[indices]
        return batch
        
# Replay buffer shared with the training loop, which fills it with simulated
# episodes and draws training batches from it.
memories: ReplayMemory = ReplayMemory()

def simulate_episode(agent: Agent) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Roll out one episode of uniformly random actions in a freshly created World.

    At each step the agent's visual field is recorded together with a one-hot encoding
    of the randomly chosen action, and the value returned by ``agent.step`` is stored
    as that step's reward. This function creates a new World but does not reset the agent.

    Args:
        agent: The agent that observes the world and takes the actions.

    Returns:
        ``(observations, rewards, duration)``:

        - ``observations`` has shape (duration, state_action_shape). The first
          ``agent.visual_field.shape[0]`` columns hold the visual field and the
          remaining columns hold the one-hot action.
        - ``rewards`` has shape (duration,).
        - ``duration`` is the number of steps recorded. Episodes never end early,
          so it equals ``episode_duration``.
    """
    world = World()
    num_actions = len(agent.actions)
    vision_size = agent.visual_field.shape[0]
    observations = torch.zeros((episode_duration, state_action_shape))
    rewards = torch.zeros(episode_duration)

    steps_taken = 0
    for t in range(episode_duration):
        action: int = np.random.randint(0, num_actions)
        observations[t, :vision_size] = agent.update_visual_field(world)
        # Rows start as zeros, so setting a single entry writes the one-hot action.
        observations[t, vision_size + action] = 1.0
        rewards[t] = agent.step(action, world)
        steps_taken = t + 1

    # Slice by the number of steps taken. Slicing by the last loop index would
    # drop the final recorded step.
    return observations[:steps_taken], rewards[:steps_taken], steps_taken

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(policy.parameters(), lr=0.01)
loss_history = []  # mean batch MSE of each epoch

for epoch in tqdm(range(num_epochs), desc="Epochs"):
    # Gather this epoch's episodes, then append them to the replay memory in one call.
    # Each append concatenates onto the whole (ever-growing) buffer, so appending once
    # per epoch copies it once instead of `num_episodes_per_epoch` times.
    episode_observations = []
    episode_rewards = []
    for _ in range(num_episodes_per_epoch):
        observations, rewards, duration = simulate_episode(agent)
        episode_observations.append(observations)
        episode_rewards.append(rewards)
    if episode_observations:
        memories.add_experience(torch.cat(episode_observations), torch.cat(episode_rewards))

    total_loss = 0.0
    for _ in range(num_batches_per_epoch):
        batch = memories.make_batch(batch_size)
        optimizer.zero_grad()
        predicted_rewards = policy(batch.observations)
        # Targets are reshaped to (batch_size, 1) to match the network's single output
        # column (presumed from DeepQNet(..., 1)) — TODO confirm output shape.
        loss = loss_fn(predicted_rewards, batch.rewards.unsqueeze(1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    loss_history.append(total_loss / num_batches_per_epoch)

Device:  cpu


Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

: 

In [None]:
import datetime

# Checkpoint the trained policy's weights to the current working directory. The
# time-stamped name gives runs saved at different seconds their own files.
timestamp = datetime.datetime.strftime(datetime.datetime.now(), "%Y-%m-%d_%H-%M-%S")
filename = f"policy_{timestamp}.pth"
torch.save(obj=policy.state_dict(), f=filename)

In [None]:
# Training curve: each point is the mean batch MSE over one epoch.
fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set(title='Mean training loss per epoch', xlabel='Epoch', ylabel='Loss [MSE]')
if loss_history:
    print('Final loss:', loss_history[-1])
else:
    # Guards against training being interrupted before the first epoch finished.
    print('No completed epochs: loss_history is empty.')
plt.show()

NameError: name 'plt' is not defined