In [5]:
import random
import numpy as np
from collections import deque
from itertools import count

from simple_custom_taxi_env import SimpleTaxiEnv
from utils import select_action, get_state_tensor
from TaxiMemory import TaxiMemory

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

# Pick the compute device for the whole notebook.
# The second GPU (index 1) is preferred when it exists; on a single-GPU host
# "cuda:1" is an invalid device ordinal, so fall back to "cuda:0", and to the
# CPU when CUDA is not available at all.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [7]:
import torch.nn as nn


class PolicyNet(nn.Module):
    """Feed-forward policy network that outputs action probabilities.

    The body is an MLP with hidden widths 512-512-256-256-128 and ReLU
    activations. Dropout (p=0.2) follows every hidden activation except the
    one after the second 256-unit layer. The final linear layer produces one
    logit per action.

    Args:
        n_observations: Size of the input feature vector.
        n_actions: Number of discrete actions (size of the output).
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        # NOTE: the layer order/indices define the state_dict keys
        # ("layers.0.weight", ...); keep them stable so saved checkpoints load.
        self.layers = nn.Sequential(
            nn.Linear(n_observations, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        """Return a probability distribution over actions.

        Args:
            x: Tensor whose last dimension has size ``n_observations``; both a
                batched ``(batch, n_observations)`` tensor and a single
                ``(n_observations,)`` vector are accepted.

        Returns:
            Tensor of the same leading shape with a last dimension of size
            ``n_actions`` that sums to 1.
        """
        # Normalise over the last (action) axis. For 2-D input this is the same
        # axis as dim=1, but it also works for unbatched 1-D input.
        return F.softmax(self.layers(x), dim=-1)

In [8]:
class PGAgent():
    """Policy-gradient agent: owns a policy network and samples actions from it."""

    def __init__(self, n_observations, n_actions, device) -> None:
        """Build the policy network and move it to ``device``.

        Args:
            n_observations: Size of the state vector fed to the network.
            n_actions: Number of discrete actions.
            device: Torch device the network parameters are placed on.
        """
        self.policy_net = PolicyNet(n_observations, n_actions).to(device)

    def get_action(self, state: torch.Tensor) -> tuple[int, torch.Tensor]:
        """Sample one action from the current policy.

        Args:
            state: Input tensor for the policy network. The sampled action must
                contain exactly one element because ``.item()`` is called on it.

        Returns:
            A pair ``(action, log_prob)``: the sampled action as a Python int
            and its log-probability as a tensor. The log-probability is
            computed from the network output, so it can be used in a loss.
        """
        probs = self.policy_net(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

In [19]:
# --- Training schedule -------------------------------------------------------
num_episodes = 300  # number of episodes the training loop runs
log_freq = 20       # print the running average score every `log_freq` episodes
save_freq = 250     # write a checkpoint every `save_freq` episodes

# --- Environment -------------------------------------------------------------
MAX_FUEL = 50       # passed to SimpleTaxiEnv as `fuel_limit`
partial_prob = 0    # probability that an episode uses `partial=True` and the
                    # "before pick" start instead of the "after pick" start

# --- Optimisation ------------------------------------------------------------
LR = 1e-4           # AdamW learning rate
gamma = 0.99        # discount factor used when accumulating returns
# float32 machine epsilon; keeps the return normalisation from dividing by zero.
EPS = float(np.finfo(np.float32).eps)

In [20]:
# Environment dimensions the network is built for.
# NOTE(review): presumably 23 is the length of the vector produced by
# get_state_tensor and 6 the size of the env's action space; confirm.
n_observations = 23
n_actions = 6

# Policy network, warm-started from an existing checkpoint.
# "pg1000.pth" must already exist on disk before this cell runs.
agent = PGAgent(n_observations, n_actions, device=device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written on one device (e.g. a GPU) still loads on a host without it.
agent.policy_net.load_state_dict(torch.load("pg1000.pth", map_location=device))
optimizer = optim.AdamW(agent.policy_net.parameters(), lr=LR, amsgrad=True)

# Per-episode memory; it is reset at the start of every episode in the loop.
taxi_memory = TaxiMemory()

In [21]:
def env_jump_before_pick(env, taxi_memory, k=1):
    """Move the taxi to a random station other than the passenger's.

    ``k`` stations are drawn at random from the stations that are not the
    passenger's; the taxi is placed on one of them and ``taxi_memory`` is
    updated for all of them. With ``k == 0`` nothing is changed.

    Args:
        env: Environment exposing ``stations``, ``passenger_loc``,
            ``destination``, ``taxi_pos`` and ``get_state()``.
        taxi_memory: Object with index-addressable ``visit_mask`` and
            ``destination_mask``.
        k: Number of stations to draw, 0-3.

    Returns:
        ``env.get_state()`` after the modifications.
    """
    if k == 0:
        return env.get_state()

    destination_idx = env.stations.index(env.destination)
    passenger_idx = env.stations.index(env.passenger_loc)

    # Candidate pool: the four station indices minus the passenger's station.
    candidates = list(range(4))
    candidates.remove(passenger_idx)
    chosen = random.sample(candidates, k)

    # Drop the taxi on one of the chosen stations.
    env.taxi_pos = env.stations[random.choice(chosen)]

    # NOTE(review): presumably a 0 in visit_mask marks a station as already
    # visited; confirm against TaxiMemory.
    for station_idx in chosen:
        taxi_memory.visit_mask[station_idx] = 0
    # Flag the destination only when its station is among the chosen ones.
    if destination_idx in chosen:
        taxi_memory.destination_mask[destination_idx] = 1
    return env.get_state()

def env_jump_after_pick(env, taxi_memory, k=1):
    """Start the episode from the passenger's station with action 4 applied.

    The taxi is moved onto the passenger's station and ``env.step(4)`` is
    executed, after which ``taxi_memory.passenger_picked_up`` is set. The
    passenger's station plus ``k - 1`` other randomly drawn stations are then
    recorded in ``taxi_memory``. With ``k == 0`` nothing is changed.

    Args:
        env: Environment exposing ``stations``, ``passenger_loc``,
            ``destination``, ``taxi_pos``, ``step()`` and ``get_state()``.
        taxi_memory: Object with ``passenger_picked_up`` and index-addressable
            ``visit_mask`` and ``destination_mask``.
        k: Number of stations to record (the passenger's station included).
            0 leaves the environment untouched.

    Returns:
        ``env.get_state()`` after the modifications.
    """
    if k == 0:
        # Return right away: falling through would call
        # random.sample(..., -1), which raises ValueError.
        return env.get_state()

    dest_idx = env.stations.index(env.destination)
    passenger_idx = env.stations.index(env.passenger_loc)

    # Draw k-1 stations from the non-passenger ones, then add the passenger's
    # station, which is always part of the recorded set.
    locations_idx = list(range(4))
    locations_idx.remove(passenger_idx)
    locations_idx = random.sample(locations_idx, k - 1)
    locations_idx.append(passenger_idx)

    # Put the taxi on the passenger and execute action 4.
    # NOTE(review): action 4 is presumably "pick up"; confirm against the env.
    env.taxi_pos = env.passenger_loc
    env.step(4)

    taxi_memory.passenger_picked_up = True
    for idx in locations_idx:
        taxi_memory.visit_mask[idx] = 0
    # Flag the destination only when its station is among the recorded ones.
    if dest_idx in locations_idx:
        taxi_memory.destination_mask[dest_idx] = 1
    return env.get_state()


In [22]:
# REINFORCE training loop: one policy-gradient update per episode.
episodice_rewards = []            # total (undiscounted) reward of every episode
scores_deque = deque(maxlen=100)  # sliding window used for the logged average

for i_episode in range(num_episodes):
    # --- Episode setup -------------------------------------------------------
    use_partial = random.random() < partial_prob
    env = SimpleTaxiEnv(fuel_limit=MAX_FUEL, partial=use_partial)
    state, info = env.reset()
    taxi_memory.reset(state)
    # Special starting position: partial episodes start somewhere before the
    # pickup, the others start right after it.
    if use_partial:
        twist_level = np.random.randint(0, 4)
        state = env_jump_before_pick(env, taxi_memory, twist_level)  # special setting
    else:
        twist_level = np.random.randint(1, 4)
        state = env_jump_after_pick(env, taxi_memory, twist_level)  # special setting
    state = get_state_tensor(state, taxi_memory, device)
    log_probs = []
    rewards = []
    terminated = False

    # --- Rollout -------------------------------------------------------------
    for step in count():
        action, log_prob = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        next_state = get_state_tensor(next_state, taxi_memory, device, action)
        log_probs.append(log_prob)
        rewards.append(reward)

        if done:
            # `done` before the last step allowed by the fuel budget counts as
            # a real termination rather than running out of fuel.
            terminated = step < MAX_FUEL - 1
            break
        state = next_state
    total_reward = sum(rewards)
    episodice_rewards.append(total_reward)
    scores_deque.append(total_reward)

    # --- Discounted returns --------------------------------------------------
    # Episodes that did not terminate start the backward accumulation from
    # 0.1 * MAX_FUEL instead of 0.
    returns = []
    G = 0 if terminated else 0.1 * MAX_FUEL
    for reward in reversed(rewards):
        G = reward + gamma * G
        returns.append(G)
    returns = torch.tensor(returns[::-1])
    # The standard deviation of a single value is NaN, which would turn the
    # loss (and, after the optimizer step, the weights) into NaN. Normalise
    # only when the episode has more than one step; otherwise use the raw return.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + EPS)

    # --- Policy-gradient update ----------------------------------------------
    policy_loss = torch.cat(
        [-log_prob * disc_return for log_prob, disc_return in zip(log_probs, returns)]
    ).sum()

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    # --- Logging / checkpointing ---------------------------------------------
    if i_episode % log_freq == 0:
        print(f'Episode {i_episode}\t, Average Score: {np.mean(scores_deque):.2f}')
    if i_episode % save_freq == 0:
        torch.save(agent.policy_net.state_dict(), f"pg_ckpt-{i_episode}.pth")

Episode 0	, Average Score: -4.90
Episode 20	, Average Score: -7.04
Episode 40	, Average Score: -6.97
Episode 60	, Average Score: -6.95
Episode 80	, Average Score: -6.63
Episode 100	, Average Score: -6.65
Episode 120	, Average Score: -6.40
Episode 140	, Average Score: -6.25
Episode 160	, Average Score: -6.30
Episode 180	, Average Score: -6.40
Episode 200	, Average Score: -6.55
Episode 220	, Average Score: -6.45
Episode 240	, Average Score: -6.85
Episode 260	, Average Score: -6.75
Episode 280	, Average Score: -7.40


In [23]:
# Persist the trained weights. This overwrites "pg1000.pth", the same file the
# network was initialised from, so re-running the notebook continues training
# from this result.
torch.save(agent.policy_net.state_dict(), "pg1000.pth")