In [1]:
import sys
sys.path.append('./gym_dagsched/data_generation/tpch/')

import torch
from torch_geometric.data import Batch
from torch_geometric.utils.convert import from_networkx

from gym_dagsched.envs.dagsched_sim import DagSchedSim
from gym_dagsched.policies.decima_agent import ActorNetwork
from gym_dagsched.utils.metrics import avg_job_duration, makespan
from gym_dagsched.data_generation.random_datagen import RandomDataGen
from gym_dagsched.data_generation.tpch_datagen import TPCHDataGen
from visualization import make_gantt


# Job/worker generator used by the experiments below. Alternatives that were
# previously left commented out here: max_tasks=200, or a TPC-H based
# workload via `TPCHDataGen()`.
datagen = RandomDataGen(n_worker_types=1, mean_task_duration=2000., max_tasks=4, max_ops=20)

# Scheduling simulator; run_episode resets it at the start of every episode.
sim = DagSchedSim()

# Policy network. NOTE(review): the meaning of the positional args (5, 8) is
# defined by ActorNetwork's constructor, not visible here — confirm there.
policy = ActorNetwork(5, 8)


In [5]:
def find_op(op_idx, jobs=None):
    """Map a flat operation index to the corresponding op object.

    Operations are indexed over the concatenation of every job's ``ops`` list,
    in the order the jobs appear in ``jobs``. This walks that concatenation
    and returns the op at position ``op_idx``.

    Args:
        op_idx: Zero-based index into the flattened list of all ops.
        jobs: Sequence of jobs (each with an ``ops`` list) to search.
            Defaults to ``sim.jobs`` (the notebook's global simulator), which
            preserves the original single-argument behaviour.

    Returns:
        The op at flat position ``op_idx``.

    Raises:
        IndexError: if ``op_idx`` is negative or not smaller than the total
            number of ops. (Previously an out-of-range index surfaced as an
            opaque UnboundLocalError, and a negative one silently indexed the
            first job from its end.)
    """
    if jobs is None:
        jobs = sim.jobs
    if op_idx < 0:
        raise IndexError(f'op index {op_idx} is negative')

    offset = op_idx
    for job in jobs:
        n_ops = len(job.ops)
        if offset < n_ops:
            return job.ops[offset]
        # Skip past this job's ops and keep searching in the next job.
        offset -= n_ops

    # After the loop, op_idx - offset is the total number of ops seen.
    raise IndexError(
        f'op index {op_idx} out of range for {op_idx - offset} ops')

def sample_action(ops, prlvl):
    """Draw one scheduling action from the policy's output distributions.

    Args:
        ops: Probabilities over the flattened list of all ops (as indexed by
            ``find_op``).
        prlvl: Per-job parallelism probabilities; the row used is
            ``prlvl[next_op.job_id]`` for the sampled op.

    Returns:
        ``(next_op_idx, next_op, n_workers)``: the sampled flat op index, the
        resolved op object, and the sampled parallelism-level index.
        NOTE(review): ``n_workers`` is a 0-based category index, so 0 can be
        sampled — confirm this is what ``sim.step`` expects.
    """
    op_dist = torch.distributions.Categorical(probs=ops)
    next_op_idx = op_dist.sample().item()
    next_op = find_op(next_op_idx)

    worker_dist = torch.distributions.Categorical(probs=prlvl[next_op.job_id])
    n_workers = worker_dist.sample().item()

    return next_op_idx, next_op, n_workers


def run_episode(ep_length, initial_timeline, workers):
    """Roll out the current policy in the global simulator for ``ep_length`` decisions.

    At each decision step:
    - every job currently in ``sim.jobs`` is encoded as a torch_geometric graph;
    - ops not in ``sim.frontier_ops`` are masked out;
    - ``policy`` is queried;
    - the sampled ``(op, n_workers)`` action is applied via ``sim.step``.

    Steps where ``sim.jobs`` is empty advance the simulator with a no-op
    action. They do not count toward ``ep_length``, and their reward is not
    recorded.

    Args:
        ep_length: Number of scheduling decisions to collect.
        initial_timeline: Passed through to ``sim.reset``.
        workers: Passed through to ``sim.reset``.

    Returns:
        ``(actions, obsns, rewards)``, three lists of length ``ep_length``:
        - ``actions[k]`` is ``(flat_op_idx, n_workers)``;
        - ``obsns[k]`` is the exact policy input
          ``(dag_batch, n_workers_total, op_msk, prlvl_msk)``;
        - ``rewards[k]`` is the reward returned by ``sim.step``.
    """
    sim.reset(initial_timeline, workers)
    actions, obsns, rewards = [], [], []

    while len(actions) < ep_length:
        dags = []
        op_msk = []
        for job in sim.jobs:
            job.update_feature_vectors(sim.workers)
            dags.append(from_networkx(job.dag))
            # 1 marks ops that are currently schedulable (on the frontier).
            op_msk.extend(1 if op in sim.frontier_ops else 0 for op in job.ops)

        if not dags:
            # No job in the system yet: advance time with an empty action.
            # NOTE(review): this loops forever if no job ever appears — TODO
            # confirm the simulator guarantees further arrivals.
            sim.step(None, 0)
            continue

        dag_batch = Batch.from_data_list(dags)
        op_msk = torch.tensor(op_msk)
        prlvl_msk = torch.ones((len(dags), len(sim.workers)))
        obs = (dag_batch, len(sim.workers), op_msk, prlvl_msk)
        obsns.append(obs)

        # The rollout only samples actions. The training loop recomputes
        # log-probabilities from the stored observations, so building an
        # autograd graph here would just waste memory and time.
        with torch.no_grad():
            ops, prlvl = policy(*obs)
            next_op_idx, next_op, n_workers = sample_action(ops, prlvl)
        actions.append((next_op_idx, n_workers))

        _, reward = sim.step(next_op, n_workers)
        rewards.append(reward)

    return actions, obsns, rewards

In [6]:
import warnings
from itertools import accumulate

import numpy as np

N_SEQUENCES = 1
N_EP_PER_SEQ = 1
mean_ep_length = 20
optim = torch.optim.Adam(policy.parameters(), lr=.005)

# The baseline at step k is the mean return-to-go over the N_EP_PER_SEQ
# episodes rolled out on the same job sequence. With a single episode it
# equals that episode's own return, so every advantage is 0 and no learning
# signal reaches the policy.
if N_EP_PER_SEQ < 2:
    warnings.warn(
        'N_EP_PER_SEQ < 2: baseline equals the return, so all advantages '
        'are zero and the policy receives no learning signal.')

for _ in range(N_SEQUENCES):
    # Geometric episode length with mean `mean_ep_length`, which grows by 3
    # after every sequence (see the end of this loop).
    ep_length = np.random.geometric(1 / mean_ep_length)

    # One job-arrival timeline and worker pool shared by every episode below.
    # NOTE(review): assumes sim.reset does not mutate `initial_timeline`;
    # otherwise later episodes would see a consumed timeline — TODO confirm.
    initial_timeline = datagen.initial_timeline(
        n_job_arrivals=20, n_init_jobs=0, mjit=1000.)
    workers = datagen.workers(n_workers=5)

    episodes = []
    for _ in range(N_EP_PER_SEQ):
        actions, obsns, rewards = run_episode(ep_length, initial_timeline, workers)
        # Reward-to-go: returns[i] == sum(rewards[i:]), computed in one O(n) pass.
        returns = list(accumulate(reversed(rewards)))[::-1]
        episodes += [(actions, obsns, returns)]

    # Only iterate over steps that every episode actually produced.
    n_steps = min((len(ep_actions) for ep_actions, _, _ in episodes), default=0)
    for k in range(n_steps):
        baseline = np.mean([ep_returns[k] for _, _, ep_returns in episodes])
        for ep_actions, ep_obsns, ep_returns in episodes:
            next_op_idx, n_workers = ep_actions[k]
            dag_batch, len_workers, op_msk, prlvl_msk = ep_obsns[k]
            advantage = ep_returns[k] - baseline

            # Re-evaluate the policy on the stored observation to obtain
            # differentiable log-probabilities of the recorded action.
            ops, prlvl = policy(dag_batch, len_workers, op_msk, prlvl_msk)
            next_op_lgp = torch.distributions.Categorical(probs=ops).log_prob(
                torch.tensor(next_op_idx))

            # NOTE(review): find_op maps the flat index using the *current*
            # sim.jobs, i.e. the state after the last episode finished. That
            # is not necessarily the job list seen when this action was
            # taken. It only matches if sim.jobs is append-only — TODO
            # confirm, or record job_id during the rollout.
            next_op = find_op(next_op_idx)
            n_workers_lgp = torch.distributions.Categorical(
                probs=prlvl[next_op.job_id]).log_prob(torch.tensor(n_workers))

            loss = -(next_op_lgp + n_workers_lgp) * advantage

            optim.zero_grad()
            loss.backward()
            optim.step()

    mean_ep_length += 3

91.60151029415431
91.60151029415431
91.60151029415431
284.4567625895606
512.4832229539913
1479.8348003238366
2469.542329704819
2487.211906754563
2599.5098880106207
2932.2596829992403
4195.065597488427
4960.611085366908
5140.7898871052785
5520.64946434274
6644.868459636693
6779.74795637066
8330.492973746606
8342.748809291334
8608.751630616796
9129.228533501575
9645.379506806634
9645.379506806634
9670.458909713563
13152.992847729982
13498.227033722464
13586.187348616
14439.245603634014
14442.447643649008
14535.072268430125
14976.154764451943
14988.984906760736
15476.765221956904
16140.446344821014
16433.77667710228
17068.461175721084
18749.203524225133
18788.11831486273
18973.129714328195
19367.57347106927
21094.397737874988
21094.397737874988
21577.082285168355
24182.854852698358
24182.854852698358
24182.854852698358
24778.09315658224
24778.09315658224
24778.09315658224
25075.303763122265
27025.86741360771
27341.054332917716
27977.076839537727
29227.059543987194
30104.6469579167
30176.1

In [9]:
# Reward-to-go sequence of the first episode: the last field of each
# (actions, obsns, returns) episode tuple.
episodes[0][-1]

[-44.63120231528518,
 -44.63120231528518,
 -44.63120231528518,
 -44.611916790055645,
 -44.56631149798275,
 -44.276106024771806,
 -43.8802230130194,
 -43.87138822449454,
 -43.804009435740895,
 -43.60435955874773,
 -42.84667601005421,
 -42.310794168539275,
 -42.16665112714857,
 -41.8627634653586,
 -40.85096636959405,
 -40.716086872860075,
 -39.16534185548413,
 -39.15308601993941,
 -38.88708319861394,
 -38.31455860544068,
 -37.69517743747461,
 -37.69517743747461,
 -37.66508215398629,
 -33.13778803456495,
 -32.654460174175476,
 -32.531315733324526,
 -31.251728350797507,
 -31.246605086773513,
 -31.098405687123723,
 -30.348565443886635,
 -30.325471187730813,
 -29.398688588858086,
 -28.07132634312987,
 -27.484665678567335,
 -26.21529668132973,
 -22.853811984321634,
 -22.775982403046438,
 -22.40595960411551,
 -21.61707209063336,
 -18.163423557021925,
 -18.163423557021925,
 -17.19805446243519,
 -11.986509327375186,
 -11.986509327375186,
 -11.986509327375186,
 -10.796032719607423,
 -10.796032719