In [7]:
import random
import numpy as np
from collections import namedtuple

import sys
sys.path.append("..")
from envs.test_env_v2 import TestEnv_v2

import torch
import torch.nn as nn
import torch.optim as optim

In [8]:
# Hidden-layer width handed to Net as `hidden_size` when the network is built.
HIDDEN_SIZE = 128
# NOTE(review): not referenced in the visible cells — presumably the number of
# episodes gathered per training batch; confirm against the training code.
BATCH_SIZE = 64
# NOTE(review): not referenced in the visible cells — presumably the reward
# percentile used to select elite episodes; confirm against the training code.
PERCENTILE = 70

# NOTE(review): not referenced in the visible cells. If this is meant as a
# discount factor, note that it is greater than 1 — confirm this is intended.
GAMMA = 1.001

In [9]:
class Net(nn.Module):
    """Feed-forward network producing one raw score (logit) per action.

    The same two-layer MLP is applied to every row of the observation (the
    linear layers act on the last axis only); the per-row features are then
    flattened and projected to ``n_actions`` outputs.

    Args:
        obs_size: Indexable pair; ``obs_size[0]`` is the number of rows in an
            observation and ``obs_size[1]`` the number of features per row.
        hidden_size: Width of the first hidden layer; the second hidden layer
            is ``int(hidden_size / 2)`` wide.
        n_actions: Number of output scores.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        row_count = obs_size[0]
        feature_count = obs_size[1]
        half_hidden = int(hidden_size / 2)

        # Attribute names and layer positions define the state_dict keys
        # ("net.0.*", "net.2.*", "out.*"), so they must stay as they are for
        # saved checkpoints to keep loading.
        row_layers = [
            nn.Linear(feature_count, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, half_hidden),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*row_layers)
        self.out = nn.Linear(row_count * half_hidden, n_actions)

    def forward(self, x):
        """Return action scores of shape ``(batch, n_actions)``.

        ``x`` is expected as ``(batch, obs_size[0], obs_size[1])``.
        """
        per_row = self.net(x)
        batch = per_row.size(0)
        # Collapse rows and features: (batch, obs_size[0] * hidden_size/2).
        flattened = per_row.view(batch, -1)
        return self.out(flattened)

# Record types for rollouts. Episode carries (reward, steps, info);
# EpisodeStep carries one (observation, action) pair.
Episode = namedtuple('Episode', 'reward steps info')
EpisodeStep = namedtuple('EpisodeStep', 'observation action')

In [11]:
if __name__ == "__main__":
    # Build the environment and size the policy network from its dimensions.
    env = TestEnv_v2()

    obs_size = env.observation_size
    n_actions = env.action_num

    # Use the GPU when one is present; otherwise fall back to the CPU instead
    # of failing inside an unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = Net(obs_size, HIDDEN_SIZE, n_actions)

    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    # map_location lets a checkpoint saved on a GPU be restored on any device.
    state_dict = torch.load('net_params.pkl', map_location=device)

    # A checkpoint saved from an nn.DataParallel wrapper prefixes every key
    # with "module."; strip it so the weights load into the bare network no
    # matter how many GPUs were used when saving or are available now.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    net.load_state_dict(state_dict)
    net = net.to(device)

    # Wrap only after the weights are in place, so loading never depends on
    # whether the wrapper is active.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)


Let's use 4 GPUs!


In [12]:
    # Roll out one episode, sampling each action from the policy's softmax
    # distribution and printing the reward of every step.
    obs = env.reset()
    sm = nn.Softmax(dim=1)
    # Feed observations on whichever device the network's weights live on.
    device = next(net.parameters()).device

    # Inference only: no_grad avoids building an autograd graph per step.
    with torch.no_grad():
        while True:
            # Add a leading batch axis of size 1 for the single observation.
            obs_v = torch.as_tensor(np.asarray(obs, dtype=np.float32)).unsqueeze(0).to(device)
            act_probs_v = sm(net(obs_v))
            act_probs = act_probs_v[0].cpu().numpy()
            action = np.random.choice(len(act_probs), p=act_probs)
            next_obs, reward, is_done, ext_info = env.step(action)
            print(reward)
            if is_done:
                break
            # Advance to the new state; without this the policy would be
            # queried on the initial observation at every step.
            obs = next_obs

-1.105350623105628
-0.48414873747640813
-0.5830951894845301
-0.8127115109311546
-6
-0.35846896657869837
-6
-0.16552945357246848
-0.3687817782917155
-0.06999999999999995
-0.3911521443121589
-0.8005623023850175
-0.5556977595779922
-6
-10
-0.9411163583744574
-0.296141857899217
-0.18027756377319942
-0.4801041553663121
-0.46518813398452025
-0.31764760348537185
-0.5491812087098392
-0.20124611797498107
-0.7334848328356899
-10
-0.6476109943476871
-0.6476109943476871
-0.873212459828649
-0.4949747468305833
-0.5869412236331675
-0.5514526271584895
-0.3008321791298264
-1.0781465577554843
-0.6003332407921453
-0.284253408071038
-6
-0.3780211634287159
-0.3712142238654117
-6
-6
-10
-0.1392838827718412
-0.5787054518492114
-0.7720103626247513
-0.9013878188659974
-0.7134423592694789
-0.4301162633521313
-0.7140028011149535
-6
-0.6453681120105021
-0.6453681120105021
-0.5448853090330111
-0.38639358172723315
-0.09899494936611662
-6
-0.3847076812334269
-0.36124783736376886
-0.761051903617618
-0.423792402008341