In [41]:
from srea import srea
from stn import STN
from dispatcher import Dispatcher
from utils import trunc
from temporal_constraint import TemporalConstraint
from scipy.stats import norm

def live(event, exec_windows, time):
    return time+0.001 >= trunc(exec_windows[event][0],1) and time-0.001 <= trunc(exec_windows[event][1],1)

def check_schedule_consistency(schedule, stn):
    for u, v, tc in stn.stn.edges(data='tc'):
        if not tc.contingent:
            if (schedule[v] - schedule[u] < tc.constraint[0]-.001) or (schedule[v] - schedule[u] > tc.constraint[1]+.001):
                print(f'Invalid: ({u}, {schedule[u]}) and ({v}, {schedule[v]}) violate {tc.constraint}.')
                return False
    print('Valid!')
    return True

def dreamier(pstn : STN, dispatcher : Dispatcher):
    srea_out = srea(pstn)
    guide_stn : STN = srea_out['stnu']
    execution_windows = srea_out['execution_windows']
    alpha = srea_out['alpha']
    nodes = set( guide_stn.stn.nodes )

    schedule = {'START' : 0.0}
    contingent_dispatch_arrivals = {}
    req_enabled = set()

    predecessors = {event : guide_stn.find_predecessors(event) for event in nodes}
    predecessors['START'] = set()

    contingent_map = guide_stn.contingent_map()
    required_events = nodes - set(contingent_map.keys()) - set('START')
    run_srea = False

    k = 0
    acc_k = 0 # CHANGE ADDED
    mAR = -1
    mSC = 0
    mBR = 1 # CHANGE ADDED
    
    dispatcher.start()
    while len(schedule) < guide_stn.stn.number_of_nodes():
        t = trunc(dispatcher.time(),1)
        # print(f"{'Time:':<25}{t}")
        # print(f"{'Dispatched:':<25}{schedule}")
        # print(f"{'Exec windows:':<25}{execution_windows}")
        # print(f"{'(Hidden) arrivals:':<25}{contingent_dispatch_arrivals}")
        # print(t, schedule, execution_windows, contingent_dispatch_arrivals, )
        # print()
        
            
        for con in contingent_dispatch_arrivals:
            end = contingent_dispatch_arrivals[con][1]
            if t < trunc(end,1):
                run_srea = True
                k += 1
            elif t == trunc(end,1):
                schedule[con] = t
                dispatcher.receive(con)
                run_srea = True
                k += 1
        
        for req in required_events:
            if guide_stn.enabled(req, schedule, predecessors):
                req_enabled.add(req)
        
        
        contingent_map = guide_stn.contingent_map()
        total = len(contingent_map)
        print("total ", total)
        if (acc_k/total) <= mBR:
            print("rescheduling")
            if run_srea and (1 - alpha) ** k <= mAR:
                updated_stn = pstn.execution_update(t, schedule, contingent_dispatch_arrivals)
                updated_srea = srea(updated_stn)
                if updated_srea['stnu'] is None:
                    # print("No LP solution, controllability not guaranteed")
                    pass
                else:
                    if abs( updated_srea['alpha'] - alpha ) >= mSC:
                        guide_stn, execution_windows, alpha = updated_srea['stnu'], updated_srea['execution_windows'], updated_srea['alpha']
                        k = 0

                contingent_map = guide_stn.contingent_map()
                
        for req in required_events:
            if live(req, execution_windows, t) and (req in req_enabled):
                schedule[req] = t
                dispatcher.dispatch(req)
                req_enabled.remove(req)
        
        for con in contingent_map:
            if guide_stn.enabled(con, schedule) and con not in contingent_dispatch_arrivals:
                # if con == 'Aet':
                #     contingent_dispatch_arrivals[con] = (t, t + 1.1 )
                # else:
                contingent_dispatch_arrivals[con] = (t, t + max(0.1, pstn.stn.edges[contingent_map[con]]['tc'].sample()) )
                acc_k += 1 # CHANGE ADDED to track how many total uncertain events have been executed
            print(acc_k) # CHANGE ADDED
        
        run_srea = False
        
        dispatcher.sleep(0.10000001)
    
    if len(schedule) == guide_stn.stn.number_of_nodes():
        print("~~~~~~Dispatching complete~~~~~~")
        return check_schedule_consistency(schedule, pstn)
        # print(schedule)
    else:
        print("Execution terminating early, STN inconsistent")
        return False
            
                


Add acc_k as a variable to track the number of uncertain events

In [42]:
stn = STN()
stn.add_edge('START', 'START', TemporalConstraint([0.0, 0.0]))
stn.add_edge('START', 'Ast', TemporalConstraint([0.0, 10.0]))
stn.add_edge('START', 'Aet', TemporalConstraint([0.0, 10.0]))
stn.add_edge('START', 'Bst', TemporalConstraint([0.0, 10.0]))
stn.add_edge('START', 'Bet', TemporalConstraint([0.0, 10.0]))

stn.add_edge('Ast', 'Aet', TemporalConstraint(norm(loc=6, scale=2), contingent=True))
stn.add_edge('Bst', 'Bet', TemporalConstraint(norm(loc=6, scale=1), contingent=True))
stn.add_edge('Aet', 'Bet', TemporalConstraint([-2.0, 2.0]))

In [43]:
dispatcher=Dispatcher()

In [44]:
dreamier(stn, dispatcher)

[33m000.0000[0m: Starting
total  2
rescheduling
[33m000.0003[0m: Dispatched Ast 
[33m000.0004[0m: Dispatched Bst 
1
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2
total  2
rescheduling
2
2


KeyboardInterrupt: 

In [None]:
success = 0
attempts = 500
for i in range(attempts):
    print(i)
    dispatcher = Dispatcher(sim_time=True, quiet=True)
    success += drea(stn, dispatcher)

print('Robustness:', success/attempts * 100)