In [3]:
import sys
sys.path.append('./gym_dagsched/data_generation/tpch/')

import torch

from gym_dagsched.envs.dagsched_env import DagSchedEnv
from gym_dagsched.policies.heuristics import fcfs, max_children, srt, lrt
from gym_dagsched.policies.decima_agent import ActorNetwork
from gym_dagsched.utils.metrics import avg_job_duration, makespan
from gym_dagsched.data_generation.random_datagen import RandomDataGen
from gym_dagsched.data_generation.tpch_datagen import TPCHDataGen
from gym_dagsched.reinforce.reinforce_utils import sample_action

# Load the trained Decima actor network.
# The constructor arguments must match those used when policy.pt was saved,
# otherwise load_state_dict reports missing/mismatched parameters.
# The network is never moved to a GPU in this notebook, so map the checkpoint's
# tensors to the CPU; without map_location a checkpoint saved from a GPU run
# fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (recent PyTorch versions also accept weights_only=True to restrict this).
policy = ActorNetwork(5, 8, 10)
policy.load_state_dict(torch.load('policy.pt', map_location='cpu'))
policy.eval()

def decima(env):
    """Pick the next scheduling action for ``env`` with the loaded Decima policy.

    Has the same calling convention as the heuristics in ``heurs``: takes the
    environment and returns a ``(next_op, prlvl)`` pair that is passed straight
    to ``env.step``.

    Returns ``(None, 0)`` when there is nothing to decide, which happens when
    ``env._observe()`` yields ``None`` or there are no active jobs.
    """
    obs = env._observe()
    if obs is None or env.n_active_jobs == 0:
        return None, 0

    dag_batch, op_msk, prlvl_msk = obs
    # Inference only: the log-prob/entropy outputs of sample_action are
    # discarded, so no gradients are needed. Skipping autograd avoids building
    # (and keeping alive) a computation graph on every scheduling step.
    with torch.no_grad():
        ops_probs, prlvl_probs = policy(dag_batch, op_msk, prlvl_msk)
        next_op, prlvl, _, _ = sample_action(env, ops_probs, prlvl_probs)
    return next_op, prlvl

# Policies under comparison. Each one is called as heur(env) and must return
# a (next_op, n_workers) pair for env.step (see the rollout cell).
heurs = [srt, decima]

# Workload generator; uncomment the TPCHDataGen line to switch generators.
datagen = RandomDataGen(
    max_ops=8,
    max_tasks=4,
    mean_task_duration=2000.,
    n_worker_types=1,
)
# datagen = TPCHDataGen()

# One environment per policy, in the same order as `heurs`.
envs = [DagSchedEnv() for _heur in heurs]

In [6]:
# Generate one workload (job-arrival timeline + worker pool) and roll out
# every policy on it, each in its own environment, until the episode ends.
initial_timeline = datagen.initial_timeline(
    n_job_arrivals=100, n_init_jobs=0, mjit=1000.)
workers = datagen.workers(n_workers=10)

# NOTE(review): the same initial_timeline / workers objects are handed to every
# env. If env.reset or env.step mutate them, later policies would see a
# different workload than earlier ones -- confirm that reset copies its inputs.
for env, heur in zip(envs, heurs):
    env.reset(initial_timeline, workers)
    while True:
        next_op, n_workers = heur(env)
        _, _, done = env.step(next_op, n_workers)
        if done:
            break

In [7]:
# Average job duration for each policy, divided by 1000 and truncated to an int
# (presumably ms -> s; TODO confirm the simulator's time unit).
# Each line is labelled with the policy's name so results can be told apart.
for heur, env in zip(heurs, envs):
    name = getattr(heur, '__name__', repr(heur))
    print(f'{name}: {int(avg_job_duration(env) / 1000)}')

79
162


In [8]:
# Makespan for each policy, divided by 1000 and truncated to an int
# (presumably ms -> s; TODO confirm the simulator's time unit), labelled with
# the policy's name so results can be told apart.
# NOTE(review): in the recorded run this cell failed with
#   AttributeError: 'int' object has no attribute 't_completed'
# and the only attribute access happens inside makespan(). The metric appears
# to receive ints where it expects objects with a t_completed attribute --
# investigate makespan() before trusting these numbers.
for heur, env in zip(heurs, envs):
    name = getattr(heur, '__name__', repr(heur))
    print(f'{name}: {int(makespan(env) / 1000)}')

AttributeError: 'int' object has no attribute 't_completed'

In [42]:
# Inspect the workload seen by the first env. For each job, print its number
# of ops, then a list of (n_tasks, int(task_duration[0])) per op, then a blank
# line. task_duration[0] is presumably the duration on worker type 0 -- confirm.
for job in envs[0].jobs:
    job_ops = job.ops
    print(len(job_ops))
    print([(op.n_tasks, int(op.task_duration[0])) for op in job_ops])
    print()

14
[(2, 1759), (2, 910), (200, 37), (2, 1793), (200, 42), (58, 3711), (200, 404), (13, 2511), (200, 175), (2, 2166), (200, 155), (200, 37), (200, 32), (6, 153)]

14
[(2, 363), (2, 811), (9, 3938), (295, 2868), (200, 1102), (200, 44), (200, 216), (46, 1880), (200, 410), (66, 2342), (200, 431), (200, 49), (200, 50), (176, 74)]

8
[(474, 2453), (200, 401), (106, 2488), (474, 2222), (200, 1389), (15, 1823), (200, 109), (1, 194)]

9
[(2, 1914), (5, 2673), (200, 369), (2, 2825), (200, 209), (200, 155), (200, 72), (200, 81), (200, 110)]

9
[(106, 2881), (15, 3224), (200, 299), (2, 123), (200, 104), (474, 2306), (200, 633), (200, 174), (1, 183)]

9
[(2, 2282), (18, 2876), (200, 381), (4, 3944), (200, 239), (200, 169), (200, 80), (200, 87), (200, 102)]

6
[(593, 2437), (133, 2243), (200, 657), (200, 40), (200, 31), (3, 163)]

9
[(7, 3378), (2, 2310), (200, 208), (2, 1299), (200, 74), (29, 3343), (200, 235), (200, 92), (1, 195)]

6
[(2, 2705), (7, 3253), (200, 259), (29, 3408), (200, 234), (1, 2

In [47]:
# NOTE(review): this is the same code as the In [42] cell. Their recorded
# outputs differ, so envs[0].jobs evidently changed between the two runs.
# Consider keeping only one of these cells.
# Prints, per job: op count, then (n_tasks, int(task_duration[0])) per op,
# then a blank line.
for job in envs[0].jobs:
    print(len(job.ops))
    per_op_summary = [(op.n_tasks, int(op.task_duration[0])) for op in job.ops]
    print(per_op_summary)
    print()

9
[(46, 1285), (139, 3087), (196, 307), (10, 6066), (49, 2954), (199, 1215), (147, 3625), (20, 7318), (68, 3198)]

16
[(110, 100), (93, 1252), (141, 153), (51, 13739), (199, 20972), (49, 390), (197, 197), (41, 10314), (26, 883), (55, 5575), (5, 2022), (185, 7883), (37, 1885), (54, 6193), (95, 6077), (53, 4306)]

12
[(84, 746), (28, 100), (142, 5588), (62, 7881), (165, 1666), (180, 2362), (170, 1195), (163, 2307), (109, 3959), (62, 1703), (109, 5454), (179, 128)]

13
[(198, 162), (146, 1019), (157, 6824), (184, 100), (78, 3803), (155, 2055), (89, 105), (32, 10576), (124, 2723), (61, 959), (170, 3873), (50, 1812), (109, 322)]

14
[(21, 8117), (69, 100), (134, 5452), (187, 18887), (67, 9657), (166, 238), (64, 1982), (149, 1365), (67, 5955), (145, 1612), (160, 13349), (194, 10071), (197, 1412), (7, 5864)]

17
[(135, 6180), (77, 9560), (84, 2800), (5, 11180), (91, 17040), (12, 7159), (178, 1524), (73, 6858), (35, 558), (89, 15908), (175, 26773), (36, 15993), (23, 5494), (157, 3061), (191, 6