In [113]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from collections import namedtuple, deque
from itertools import count

import random


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [114]:
ACTIONS = torch.tensor([[1, 1], [0, -1]]), torch.tensor([[0, 1], [1, 0]], dtype=torch.float)

def gen_start_config():
    return torch.round(torch.rand((1, 2)) * 100)   # TODO make a better stochastic thing

def terminate(s):
    k_eps = 1e-4    # for float imprecision
    # terminate if one of the coordinates is 0
    return s[0] <= k_eps or s[1] <= k_eps

In [115]:
class Node:
    def __init__(self, value):
        self.value = value
        self.visits = 0
        self.children = []
    
    def UCT_fn(self, child, C):
        return child.value + 2 * C * torch.sqrt(2 * torch.log2(self.visits) / child.visits)
    
        

In [116]:
LR = 1e-4
GAMMA = 0.99
STEP_CAP = 100
TAU = 0.005

target_net = DQN().to(device)
target_net.load_state_dict(target_net.state_dict())
optimizer = optim.AdamW(target_net.parameters(), lr=LR, amsgrad=True)

memory = ReplayMemory(STEP_CAP)

steps_done = 0

def next_states(state):
    return state @ ACTIONS[0], state @ ACTIONS[0]


0
inf
-0.70703125


In [117]:
# TODO implement batch
def optimize_model():
    transitions = memory.sample()
    batch = Transition(*zip(*transitions))
    
    print(batch.next_state)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None and terminate(s) is not True,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if not (s is None or terminate(s))])
    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(torch.cat(batch.state))
    

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(1)
    with torch.no_grad():
        next_state_values[non_final_mask] = torch.max(target_net(non_final_next_states))

    # Compute the expected Q values
    global steps_done
    expected_state_action_values = (next_state_values * GAMMA) + steps_done

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

[322.  73.]
Accuracy: 0.9


In [620]:
if torch.cuda.is_available():
    num_episodes = 600
else:
    num_episodes = 50

def env_step(action, state):
    print("action", action, ACTIONS[action])
    pos = ACTIONS[action].int() @ state.int()   # todo 
    return pos, terminate(pos) 

for i_episode in range(num_episodes):
    print("EPISODE", str(i_episode))
    # Initialize the environment and get its state
    start = gen_start_config()
    if start.size() != (2, 1):
        print("failed", start)
    
    state = start
    reward = 0
    for t in count():
        print("step", t)
        action = select_action(state)
        new_state, done_early = env_step(action, state)
        killed = steps_done > STEP_CAP
        
        # update reward
        reward -= torch.linalg.norm(new_state.float())
        if done_early:
            reward += torch.sqrt(torch.norm(start)) # pretty arbitrary rn
        elif killed:
            print(start)
            print(start.size())
            print(torch.linalg.norm(start))
            reward -= torch.norm(start) # todo change penalty magnitude to decrease as norm of start increases?
        
        done = done_early or killed
        if done:
            print("early", done_early)
            print("killed", killed)
        if done:
            next_state = None
        else:
            next_state = new_state
            
        # Store the transition in memory
        memory.push(state, action, next_state, reward)
        print(memory.sample()[0])

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            break

print('Complete')

EPISODE 0
step 0
action tensor(1) tensor([[0., 1.],
        [1., 0.]])
Transition(state=tensor([[85.],
        [53.]]), action=tensor(1), next_state=tensor([[53],
        [85]], dtype=torch.int32), reward=tensor(-100.1699))
(tensor([[53],
        [85]], dtype=torch.int32),)
step 1
action tensor(1) tensor([[0., 1.],
        [1., 0.]])
Transition(state=tensor([[85.],
        [53.]]), action=tensor(1), next_state=tensor([[53],
        [85]], dtype=torch.int32), reward=tensor(-200.3397))
(tensor([[85],
        [53]], dtype=torch.int32),)
step 2
action tensor(1) tensor([[0., 1.],
        [1., 0.]])
Transition(state=tensor([[53],
        [85]], dtype=torch.int32), action=tensor(1), next_state=tensor([[85],
        [53]], dtype=torch.int32), reward=tensor(-300.5096))
(tensor([[53],
        [85]], dtype=torch.int32),)
step 3
action tensor(1) tensor([[0., 1.],
        [1., 0.]])
Transition(state=tensor([[85],
        [53]], dtype=torch.int32), action=tensor(1), next_state=tensor([[53],
        

RuntimeError: torch.cat(): expected a non-empty list of Tensors