In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import subprocess
import numpy as np
import itertools
from itertools import count

In [8]:
class Policy(nn.Module):
    def __init__(self, state_space, action_space):
        super(Policy, self).__init__()
        self.state_space = state_space 
        self.action_space = action_space
        self.l1 = nn.Sequential(nn.Linear(self.state_space, 6*self.state_space), nn.ReLU())
        self.l2 = nn.Sequential(nn.Linear(6*self.state_space, 10*self.action_space), nn.ReLU())
        self.l3 = nn.Linear(10*self.action_space, self.action_space)
        
    def forward(self, x):    
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        return x

In [9]:
class ERenv():
    import platform
    import numpy as np
    def __init__(self, exe, init_m = 100, nrange=(100,200), mc=10):
        """
            nrange is half-open [a,b)
        """
        self.exe    = exe
        self.mc     = mc
        self.nrange = (100,200)
        self.init_m = init_m
        self.state  = [np.random.randint(*self.nrange), init_m]
        # self.pyver  = platform.python_version()[:3] == 3.6
        
    def step(self, p):
        if True: 
            tmp = subprocess.run(['./main', str(self.state[0]), str(10) ,str(p)],stdout=subprocess.PIPE) ## 3.6
        else:
            tmp = subprocess.run(['./main', str(self.state[0]), str(10) ,str(p)],capture_output=True) ## 3.7
        res = float(tmp.stdout.split()[0])
        self.state[0] = np.random.randint(*self.nrange)
        self.state[1] -= 1
        return self.state.copy(), float(tmp.stdout.split()[0]), self.state[1]==0
    
    def reset(self):
        self.state[0] = np.random.randint(*self.nrange)
        self.state[1] = self.init_m
        return self.state.copy()
    
    def run(self, state, p):
        tmp = subprocess.run(['./main', str(state[0]), str(10) ,str(p)],stdout=subprocess.PIPE) ## 3.6
        return float(tmp.stdout.split()[0])

In [10]:
def simu(env, policy_net, batch_size=8, gamma=0.99):
    state_pool, action_pool, reward_pool = [], [], []
    for b in range(batch_size):
        done  = False
        state = env.reset()
        cur_reward_pool = []
        for t in itertools.count():
            if done: break
            state_pool.append(state)
            state  = torch.tensor(state).float()
            action = policy_net(state)
            action_pool.append(action)
            state, reward, done, *_ = env.step(nn.Sigmoid()(action).item())
            cur_reward_pool.append(reward)
        
        # discount reward
        running_mean = 0
        for i in reversed(range(len(cur_reward_pool))):
            running_mean = running_mean*gamma + cur_reward_pool[i]
            cur_reward_pool[i] = running_mean
        reward_pool.extend(cur_reward_pool)
    
    # normalize reward
    avg, std = np.mean(reward_pool), np.std(reward_pool)
    reward_pool = list(map(lambda x: (x-avg)/std, reward_pool))
    
    return state_pool, action_pool, reward_pool


In [11]:
def update_policy(policy_net, optim, state_pool, action_pool, reward_pool):
    optim.zero_grad()
    for s, a, r in zip(state_pool, action_pool, reward_pool):
        m    = torch.distributions.Normal(nn.Sigmoid()(a), torch.tensor([0.0001]))
        loss = -m.log_prob(a)*r
        loss.backward()
    optim.step()

In [None]:
env = ERenv('./main', init_m=10)
policy_net = Policy(2, 1)
optim = torch.optim.SGD(policy_net.parameters(), lr=0.001)
for i in range(100):
    print(i)
    tmp_act = policy_net(torch.tensor([155, 5]).float()).sigmoid().item()
    print( tmp_act, env.run([155,5], tmp_act) )
    update_policy(policy_net, optim, *simu(env, policy_net, 4))

0
9.330861212220043e-05 0.0103226
1
0.0 0.00645161
2
nan 0.00645161
3
nan 0.00645161
4
nan 0.00645161
5
nan 0.00645161
6
nan 0.00645161
7
nan 0.00645161
8
nan 0.00645161
9
nan 0.00645161
10
nan 0.00645161
11
nan 0.00645161
12
nan 0.00645161
13
nan 0.00645161
14
nan 0.00645161
15
nan 0.00645161
16
nan 0.00645161
17
nan 0.00645161
18
nan 0.00645161
19
nan 0.00645161
20
nan 0.00645161
21
nan 0.00645161
22
nan 0.00645161
23
nan 0.00645161
24
nan 0.00645161
25
nan 0.00645161
26
nan 0.00645161
27
nan 0.00645161
28
nan 0.00645161
29
nan 0.00645161
30
nan 0.00645161
31
nan 0.00645161
32
nan 0.00645161
33
nan 0.00645161
34
nan 0.00645161
35
nan 0.00645161
36
nan 0.00645161
37
nan 0.00645161
38
nan 0.00645161
39
nan 0.00645161
40
nan 0.00645161
41
nan 0.00645161
42
nan 0.00645161
43
nan 0.00645161
44
nan 0.00645161
45
nan 0.00645161
46
nan 0.00645161
47
nan 0.00645161
48
nan 0.00645161
49
nan 0.00645161
50
nan 0.00645161
51
nan 0.00645161
52
nan 0.00645161
53
nan 0.00645161
54
nan 0.00645161
55
