In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random
# import gym 
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Policy net (pi_theta)
class PolicyNet(nn.Module):
    def __init__(self, hidden_dim=512, state_dim = 4, nActions = 20):
        super().__init__()

        self.hidden = nn.Linear(state_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, nActions) # 2 * number + 1

    def forward(self, s):
        outs = self.hidden(s)
        outs = F.relu(outs)
        outs = self.output(outs)
        return outs

# pi_model = PolicyNet(nActions=20).to(device)

# Pick up action (for each step in episode)
def pick_sample(s, pi_model):
    with torch.no_grad():
        #   --> size : (1, 4)
        s_batch = np.expand_dims(s, axis=0)
        s_batch = torch.tensor(s_batch, dtype=torch.float).to(device)
        # Get logits from state
        #   --> size : (1, 2)
        logits = pi_model(s_batch)
        #   --> size : (2)
        logits = logits.squeeze(dim=0)
        # From logits to probabilities
        probs = F.softmax(logits, dim=-1)
        # Pick up action's sample
        #   --> size : (1)
        a = torch.multinomial(probs, num_samples=1)
        #   --> size : ()
        a = a.squeeze(dim=0)
        # Return
        return a.tolist()

In [4]:
class QNet(nn.Module):
    def __init__(self, state_dim = 4, hidden_dim=512, nActions = 20):
        super().__init__()

        self.hidden = nn.Linear(state_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, nActions) # 2 * number + 1

    def forward(self, s):
        outs = self.hidden(s)
        outs = F.relu(outs)
        outs = self.output(outs)
        return outs

In [5]:
alpha = 0.4
# alpha = 0.1


class categorical:
    def __init__(self, s):
        logits = pi_model(s)
        self._prob = F.softmax(logits, dim=-1)
        self._logp = torch.log(self._prob)

    # probability (sum is 1.0) : P
    def prob(self):
        return self._prob

    # log probability : log P()
    def logp(self):
        return self._logp

def optimize_theta(states, alpha):
    # Convert to tensor
    states = torch.tensor(states, dtype=torch.float).to(device)
    # Disable grad in q_origin_model1 before computation
    # (or use q_value.detach() not to include in graph)
    for p in q_origin_model1.parameters():
        p.requires_grad = False
    # Optimize
    opt_pi.zero_grad()
    dist = categorical(states)
    q_value = q_origin_model1(states)
    term1 = dist.prob()
    # alpha = log_alpha.exp().detach()
    term2 = q_value - alpha * dist.logp()
    # print(term1.shape, term2.shape)
    # return
    expectation = term1.unsqueeze(dim=1) @ term2.unsqueeze(dim=2)
    expectation = expectation.squeeze(dim=1)
    (-expectation).sum().backward()
    opt_pi.step()
    # Enable grad again
    for p in q_origin_model1.parameters():
        p.requires_grad = True
        
def optimize_alpha(states):
    states = torch.tensor(states, dtype=torch.float).to(device)
    # Disable grad in q_origin_model1 before computation
    # (or use q_value.detach() not to include in graph)
    for p in pi_model.parameters():
        p.requires_grad = False
    
    dist = categorical(states)
    alpha_optimizer.zero_grad()
    alphaLoss = - (log_alpha.exp() * ((dist.logp() * dist.prob()).sum() - targetEntropy)).mean()
    alphaLoss.backward()
    alpha_optimizer.step()
    alpha = log_alpha.exp().detach()
    for p in pi_model.parameters():
        p.requires_grad = True

In [6]:
gamma = 0.99


def optimize_phi(states, actions, rewards, next_states, dones, alpha, nActions=20):
    #
    # Convert to tensor
    #
    states = torch.tensor(states, dtype=torch.float).to(device)
    actions = torch.tensor(actions, dtype=torch.int64).to(device)
    rewards = torch.tensor(rewards, dtype=torch.float).to(device)
    rewards = rewards.unsqueeze(dim=1)
    next_states = torch.tensor(next_states, dtype=torch.float).to(device)
    dones = torch.tensor(dones, dtype=torch.float).to(device)
    dones = dones.unsqueeze(dim=1)

    #
    # Compute r + gamma * (1 - d) (min Q(s_next,a_next') + alpha * H(P))
    #
    # alpha = log_alpha.exp().detach()
    with torch.no_grad():
        # min Q(s_next,a_next')
        q1_tgt_next = q_target_model1(next_states)
        q2_tgt_next = q_target_model2(next_states)
        dist_next = categorical(next_states)
        q1_target = q1_tgt_next.unsqueeze(dim=1) @ dist_next.prob().unsqueeze(dim=2)
        q1_target = q1_target.squeeze(dim=1)
        q2_target = q2_tgt_next.unsqueeze(dim=1) @ dist_next.prob().unsqueeze(dim=2)
        q2_target = q2_target.squeeze(dim=1)
        q_target_min = torch.minimum(q1_target, q2_target)
        # alpha * H(P)
        h = dist_next.prob().unsqueeze(dim=1) @ dist_next.logp().unsqueeze(dim=2)
        h = h.squeeze(dim=1)
        h = -alpha * h
        # total
        term2 = rewards + gamma * (1.0 - dones) * (q_target_min + h)

    #
    # Optimize critic loss for Q-network1
    #
    opt_q1.zero_grad()
    one_hot_actions = F.one_hot(actions, num_classes=nActions).float()
    q_value1 = q_origin_model1(states)
    term1 = q_value1.unsqueeze(dim=1) @ one_hot_actions.unsqueeze(dim=2)
    term1 = term1.squeeze(dim=1)
    loss_q1 = F.mse_loss(
        term1,
        term2,
        reduction="none")
    loss_q1.sum().backward()
    opt_q1.step()

    #
    # Optimize critic loss for Q-network2
    #
    opt_q2.zero_grad()
    one_hot_actions = F.one_hot(actions, num_classes=nActions).float()
    q_value2 = q_origin_model2(states)
    term1 = q_value2.unsqueeze(dim=1) @ one_hot_actions.unsqueeze(dim=2)
    term1 = term1.squeeze(dim=1)
    loss_q2 = F.mse_loss(
        term1,
        term2,
        reduction="none")
    loss_q2.sum().backward()
    opt_q2.step()

In [7]:
tau = 0.002

def update_target():
    for var, var_target in zip(q_origin_model1.parameters(), q_target_model1.parameters()):
        var_target.data = tau * var.data + (1.0 - tau) * var_target.data
    for var, var_target in zip(q_origin_model2.parameters(), q_target_model2.parameters()):
        var_target.data = tau * var.data + (1.0 - tau) * var_target.data

In [8]:
class replayBuffer:
    def __init__(self, buffer_size: int):
        self.buffer_size = buffer_size
        self.buffer = []
        self._next_idx = 0

    def add(self, item):
        if len(self.buffer) > self._next_idx:
            self.buffer[self._next_idx] = item
        else:
            self.buffer.append(item)
        if self._next_idx == self.buffer_size - 1:
            self._next_idx = 0
        else:
            self._next_idx = self._next_idx + 1

    def sample(self, batch_size):
        indices = [random.randint(0, len(self.buffer) - 1) for _ in range(batch_size)]
        states   = [self.buffer[i][0] for i in indices]
        actions  = [self.buffer[i][1] for i in indices]
        rewards  = [self.buffer[i][2] for i in indices]
        n_states = [self.buffer[i][3] for i in indices]
        dones    = [self.buffer[i][4] for i in indices]
        return states, actions, rewards, n_states, dones

    def length(self):
        return len(self.buffer)

buffer = replayBuffer(20000)

In [13]:
import pickle
with open('inputTestCases/_input2ways_n=4_.pickle', 'rb') as f:
    roadDefs = pickle.load(f) # deserialize using load()


In [46]:
from junctionart.roundabout.encodingGFN.setGenerationEnv2 import SetGenerationEnv2
from tqdm import tqdm
size = 4
nActions = 30


def train(env, nIter = 6000, batch_size = 250, disableBar = False):
    for i in tqdm(range(nIter), disable = disableBar):
        # Run episode till done
        s = torch.ones(1, size)
        done = False
        cum_reward = 0
        while not done:
            a = pick_sample((s/nActions).squeeze().tolist(), pi_model)
            
            done = (a == size)
       
            if done :
                s_next = s
            else:
                s_next = env.update(s, torch.tensor([a]))
            
            if done:
                config = (s_next.squeeze() - 1).long().tolist()
            
                r = 10**env.getProxyReward(config, normalize=True) 
            else:
                r = 0
            buffer.add([(s/nActions).squeeze().tolist(), a, r, (s_next/nActions).squeeze().tolist(), float(done)])
            cum_reward += r
            if buffer.length() >= 2000:
                states, actions, rewards, n_states, dones = buffer.sample(batch_size)
                optimize_theta(states, alpha)
                optimize_alpha(states)
                
                optimize_phi(states, actions, rewards, n_states, dones, alpha, nActions=size + 1)
                update_target()
            s = s_next
            alpha = log_alpha.exp().detach()
        print("Run episode{} with rewards {} s {} ALPHA {}".format(i, cum_reward, s.squeeze().tolist(), alpha), end="\r")
    

In [47]:
roadDefinition = roadDefs[0]
env = SetGenerationEnv2(size, nActions, roadDefinition)

targetEntropy = -nActions
log_alpha = torch.tensor([0.0], requires_grad=True)
alpha = log_alpha.exp().detach()

# models
pi_model = PolicyNet(state_dim=size, nActions=size + 1).to(device)
q_origin_model1 = QNet(state_dim=size, nActions=size + 1).to(device)  # Q_phi1
q_origin_model2 = QNet(state_dim=size, nActions=size + 1).to(device)  # Q_phi2
q_target_model1 = QNet(state_dim=size, nActions=size + 1).to(device)  # Q_phi1'
q_target_model2 = QNet(state_dim=size, nActions=size + 1).to(device)  # Q_phi2'
_ = q_target_model1.requires_grad_(False)  # target model doen't need grad
_ = q_target_model2.requires_grad_(False)  # target model doen't need grad
buffer = replayBuffer(20000)

# optimizers
opt_pi = torch.optim.AdamW(pi_model.parameters(), lr=0.0005)
opt_q1 = torch.optim.AdamW(q_origin_model1.parameters(), lr=0.0005)
opt_q2 = torch.optim.AdamW(q_origin_model2.parameters(), lr=0.0005)
alpha_optimizer = torch.optim.AdamW(params=[log_alpha], lr=0.0005) 
doneTraining = False

train(env, nIter=1000, batch_size=256, disableBar=True)
        


Run episode664 with rewards 3.9810717055349722 s [8.0, 9.0, 7.0, 7.0] ALPHA tensor([0.4032])]))

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0