Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
adityasidharta
committed
Dec 9, 2018
1 parent
b161120
commit e4371db
Showing
15 changed files
with
174 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,5 +100,8 @@ venv.bak/ | |
# mkdocs documentation | ||
/site | ||
|
||
# idea | ||
.idea/ | ||
|
||
# mypy | ||
.mypy_cache/ |
File renamed without changes.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
class Agent(object):
    """Wires together the collaborators of an RL training run.

    Pure container: it stores the pieces (learner, memory, policy,
    value function, environments, config) and exposes ``train_agent``
    as the — not yet implemented — training entry point.
    """

    def __init__(self, learner, memory, policy, value_function, envs, config):
        (self.learner, self.memory, self.policy,
         self.value_function, self.envs, self.config) = (
            learner, memory, policy, value_function, envs, config)

    def train_agent(self, n_iteration):
        """Run ``n_iteration`` training iterations (not implemented yet)."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch

# Default DQN hyperparameters; callers can override them via Config below.
BATCH_SIZE = 128      # transitions per optimization step
GAMMA = 0.999         # discount factor for future rewards
EPS_START = 0.9       # initial epsilon for epsilon-greedy exploration
EPS_END = 0.05        # epsilon floor after decay
EPS_DECAY = 200       # decay constant for epsilon
TARGET_UPDATE = 10    # interval between target-network syncs
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Config():
    """Bundles the DQN training hyperparameters into one object.

    Mirrors the module-level defaults so a caller can pass a single
    config instead of seven loose values.
    """

    def __init__(self, batch_size, gamma, eps_start, eps_end, eps_decay, target_update, device):
        (self.batch_size, self.gamma) = (batch_size, gamma)
        (self.eps_start, self.eps_end, self.eps_decay) = (eps_start, eps_end, eps_decay)
        (self.target_update, self.device) = (target_update, device)
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,44 @@ | ||
import torch


# TODO finish Q-Function
class Optimize_Molel(object):
    """One DQN optimization step with a frozen target network.

    Holds an online net (``new_qnet``, being trained) and a target net
    (``old_qnet``, frozen in eval mode) that scores next states when
    building the Bellman target.

    NOTE(review): the class name looks like a typo for ``Optimize_Model``;
    kept unchanged so existing callers keep working.
    """

    def __init__(self, old_qnet, new_qnet, optimizer):
        self.new_qnet = new_qnet
        self.old_qnet = old_qnet
        # Bug fix: the target net must be synchronized from the *online*
        # net's weights; the original loaded old_qnet's own state dict,
        # which is a no-op.
        self.old_qnet.load_state_dict(self.new_qnet.state_dict())
        self.old_qnet = self.old_qnet.eval()
        self.optimizer = optimizer

    def calc_q(self, state_tensor):
        """Return the target net's Q-values for ``state_tensor`` as a numpy array."""
        with torch.no_grad():
            return self.old_qnet(state_tensor).cpu().numpy()

    def optimize_new_qnet(self, memory, config):
        """Build the Bellman targets for one sampled batch.

        Does nothing while ``memory`` holds fewer than a batch of
        transitions; otherwise computes ``exp_qa`` and then raises
        NotImplementedError because the loss/optimizer step is still TODO.
        """
        torch_device = config.device
        gamma = config.gamma
        batch_size = config.batch_size

        if len(memory) < batch_size:
            # Not enough transitions collected yet to form a batch.
            pass
        else:
            state_tensor, action_tensor, reward_tensor, next_state_tensor, finish_tensor = memory.sample(
                batch_size, return_tensor=True, torch_device=torch_device
            )

            cur_q = self.new_qnet(state_tensor)
            cur_qa = cur_q.gather(1, action_tensor)

            with torch.no_grad():
                # Bug fix: nonzero(finish_tensor) selects the *finished*
                # transitions, yet the original fed them to the target net as
                # the "unfinished" ones. A terminal next state has value 0,
                # so score only the unfinished rows and scatter the results
                # back, keeping next_qa aligned with reward_tensor.
                finish_flat = finish_tensor.view(-1)
                unfinished_index = torch.nonzero(finish_flat == 0).view(-1)
                next_qa = torch.zeros(batch_size, device=torch_device)
                if unfinished_index.numel() > 0:
                    # Bug fix: original referenced self.old_net, which is
                    # never defined (the attribute is old_qnet).
                    next_q = self.old_qnet(next_state_tensor[unfinished_index, :])
                    next_qa[unfinished_index] = next_q.max(1)[0]

            exp_qa = reward_tensor + (gamma * next_qa)
            # LOSS FUNCTION
            raise NotImplementedError()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from src.memory import Memory | ||
import random | ||
import torch | ||
|
||
|
||
def create_transitions(state=None,
                       action=None,
                       reward=None,
                       next_state=None,
                       finish=None,
                       n_state=None):
    """Build one (state, action, reward, next_state, finish) transition.

    Any field left as None is filled with a random value; ``n_state``
    defaults to 4 and sizes the two state vectors.
    """
    if n_state is None:
        n_state = 4
    if state is None:
        state = [random.random() for _ in range(n_state)]
    if action is None:
        action = random.sample([0, 1], 1)[0]
    if reward is None:
        reward = random.random()
    if next_state is None:
        next_state = [random.random() for _ in range(n_state)]
    if finish is None:
        finish = random.sample([False, True], 1)[0]
    return state, action, reward, next_state, finish
|
||
|
||
def test_init():
    """A fresh Memory starts empty, at position 0, with the given sizes."""
    memory = Memory(10, 4)
    assert memory.capacity == 10
    assert memory.n_state == 4
    assert memory.position == 0
    assert len(memory) == 0
    assert not memory.is_memory_full()
|
||
|
||
def test_len():
    """len() reports how many transitions have been stored so far."""
    memory = Memory(10, 4)
    for _ in range(5):
        memory.save(*create_transitions())
    assert len(memory) == 5
|
||
|
||
def test_is_memory_full():
    """Memory reports full once capacity is reached, and stays full after wrap-around."""
    memory = Memory(10, 4)
    assert not memory.is_memory_full()
    # Save 5 transitions at a time: half-full, full, then overwritten-but-still-full.
    for expected_full in (False, True, True):
        for _ in range(5):
            memory.save(*create_transitions())
        assert memory.is_memory_full() == expected_full
|
||
|
||
def test_update_postiion():
    """The write position advances per save and wraps at capacity.

    NOTE(review): "postiion" is a typo for "position" in the test name;
    kept unchanged so the test ID stays stable.
    """
    memory = Memory(10, 4)
    assert memory.position == 0
    # 5 saves -> 5; 5 more -> wraps to 0; 5 more -> back to 5.
    for expected_position in (5, 0, 5):
        for _ in range(5):
            memory.save(*create_transitions())
        assert memory.position == expected_position
|
||
|
||
def test_sample():
    """sample() returns aligned batches, as arrays or as torch tensors."""
    memory = Memory(10, 4)
    for idx in range(5):
        # reward == every state component, so sampled rows can be cross-checked.
        state, action, reward, next_state, finish = create_transitions(
            state=[idx, idx, idx, idx],
            reward=idx,
            next_state=[idx, idx, idx, idx]
        )
        memory.save(state, action, reward, next_state, finish)
    sampled_state, sampled_action, sampled_reward, sampled_next_state, sampled_finish = memory.sample(3)
    assert sampled_state.shape == (3, 4)
    assert sampled_next_state.shape == (3, 4)
    # Bug fix: `assert array == array` raises ValueError for multi-element
    # arrays (truth value is ambiguous); reduce with .all() instead.
    assert (sampled_state[:, 0] == sampled_reward).all()

    tensor_state, tensor_action, tensor_reward, tensor_next_state, tensor_finish = memory.sample(3, return_tensor=True)
    assert tensor_state.shape == torch.Size([3, 4])
    assert tensor_next_state.shape == torch.Size([3, 4])

    # Bug fix: the original asserted 'torch.cuda.*Tensor' type strings, which
    # fails on any CPU-only machine. Check dtypes instead — they are
    # device-independent while pinning the same float32/int32 contract.
    assert tensor_state.dtype == torch.float32
    assert tensor_action.dtype == torch.int32
    assert tensor_reward.dtype == torch.float32
    assert tensor_next_state.dtype == torch.float32
    assert tensor_finish.dtype == torch.int32