-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_all_slices.py
75 lines (58 loc) · 2.94 KB
/
train_all_slices.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import cvxpy as cp
import matplotlib.pyplot as matplt
from utils import *
from ddpg_alg_spinup import ddpg
from t3_ddpg_alg_spinup import td3
import tensorflow as tf
from env_mra import ResourceEnv
import numpy as np
import time
import pickle
from parameters import *
if __name__ == "__main__":
    # Script entry point: save the slice parameters to disk, train one TD3
    # agent per slice, report the wall-clock training time, then reload and
    # print the saved parameters.
    # NOTE(review): the nesting below was reconstructed from the statements
    # themselves (the listing carried no indentation) -- confirm against the
    # original file.

    # Persist `alpha` and `weight` (names supplied by the star imports above)
    # so they can be read back at the end of the run.
    with open("saved_alpha.pickle", "wb") as alpha_out:
        pickle.dump(alpha, alpha_out)
    with open("saved_weight.pickle", "wb") as weight_out:
        pickle.dump(weight, weight_out)

    # ------------------------------------------------------------------
    # Main training
    # ------------------------------------------------------------------
    start_time = time.time()
    utility = np.zeros(SliceNum)  # one utility value per slice, filled below
    # NOTE(review): `x` is allocated but never read in this script.
    x = np.zeros((UENum, maxTime), dtype=np.float32)

    for slice_idx in range(SliceNum):
        # Fresh keyword dicts for every slice; the logger output directory
        # and experiment name are derived from RESNum and the slice index.
        ac_kwargs = {
            'hidden_sizes': hidden_sizes,
            'activation': tf.nn.relu,
            'output_activation': tf.nn.sigmoid,
        }
        res_tag = str(RESNum)
        slice_tag = str(slice_idx)
        logger_kwargs = {
            'output_dir': res_tag + 'slice' + slice_tag,
            'exp_name': res_tag + 'slice_exp' + slice_tag,
        }

        # Training environment configured with this slice's alpha/weight.
        env = ResourceEnv(
            alpha=alpha[slice_idx],
            weight=weight[slice_idx],
            num_res=RESNum,
            num_user=UENum,
            max_time=maxTime,
            min_reward=minReward,
            rho=rho,
            test_env=False,
        )

        # td3 returns a pair; only the first element is kept as the
        # slice's utility.
        slice_utility, _ = td3(
            env=env,
            ac_kwargs=ac_kwargs,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            pi_lr=pi_lr,
            q_lr=q_lr,
            start_steps=start_steps,
            batch_size=batch_size,
            seed=seed,
            replay_size=replay_size,
            max_ep_len=maxTime,
            logger_kwargs=logger_kwargs,
            fresh_learn_idx=True,
        )
        utility[slice_idx] = slice_utility

        # Disabled alternative: the same call using DDPG instead of TD3.
        # utility[i], _ = ddpg(env=env, ac_kwargs=ac_kwargs,
        #                      steps_per_epoch=steps_per_epoch,
        #                      epochs=epochs, pi_lr=pi_lr, q_lr=q_lr,
        #                      start_steps=start_steps, batch_size=batch_size,
        #                      seed=seed, replay_size=replay_size, max_ep_len=maxTime,
        #                      logger_kwargs=logger_kwargs, fresh_learn_idx=True)

        # NOTE(review): assumed to sit inside the loop because it uses the
        # loop index -- confirm.
        print('slice' + slice_tag + 'training completed.')

    end_time = time.time()
    elapsed = end_time - start_time
    print('Training Time is ' + str(elapsed))

    # ------------------------------------------------------------------
    # Result plotting
    # ------------------------------------------------------------------
    # Read the saved parameters back and echo them.
    with open("saved_alpha.pickle", "rb") as alpha_in:
        load_alpha = pickle.load(alpha_in)
    print(load_alpha)

    with open("saved_weight.pickle", "rb") as weight_in:
        load_weight = pickle.load(weight_in)
    print(load_weight)

    # The plotting calls below are disabled; show() is still invoked.
    # print(weight)
    # matplt.subplot(2, 1, 1)
    # matplt.plot(sum_utility)
    # matplt.subplot(2, 1, 2)
    # matplt.plot(sum_x)
    matplt.show()

    print('done')