In [None]:
import numpy as np

# start: only 1 BS engine
# Assumptions:
# - 1 unique key per job
# Pattern is Linear Ops -> PBSs
# 1 param set for all jobs
# Static scheduler
# No need to schedule Noise management (PBS are inserted already as needed)
# (6. Schoolbook addition/Multiplication/Comparison Patterns)

# - BS engine can only work with one specific parameter set, so keys can be different but their sizes are known in advance
# - and therefore the key loading time is also known in advance
# - BS engines that share jobs must also use the same parameter sets, otherwise it would not make any sense
# - BS engine does not handle scalar MulAdds, as required for scaling and the keyswitch
# we assume that all jobs are dynamic, so they produce in-between results and the PBS needed for noise management were inserted by a compiler

# All PBS functions can be brought to full precision with 1 batch pause (to do the scaling in the background), then 4 PBS in one batch, then 1 more batch pause (to accumulate the result)
# --> we simply model that as one PBS call

# Advanced aspects:
# - BS engine that can handle different parameter sets
# - dynamic noise management: the scheduler needs to add PBS for noise management

# Sentinel marking an unused slot, both in a PBS engine's key storage and in its batch.
key_freespace_indicator = batch_freespace_indicator = -1
# Global simulation tick counter, meant to be advanced by the scheduler.
global_tic_cnt: int = 0

# TODO: model jobs as having single-file-codepieces and vectorized codepieces for batching, what that exactly looks like depends on the use case, good coding will vectorize more
# need to respect flexibilities --> need DAGs
# need to respect fairness
# scheduling is a closed problem, need literature research
# maybe go for multiple PBS engine scheduling? But perhaps too difficult in practice as in-between-results must be shared

class TFHE_sub_job:
    """A part of a job that leads to an in-between result other jobs might wait for.

    The work is described as a sequence of "steps". Entry ``i`` of
    ``parallelizeable_calls_array`` is the number of PBS calls in step ``i``
    that are independent of each other and can therefore be done in parallel
    (i.e. batched together). Steps are served strictly in order: calls of
    step ``i + 1`` are only served once step ``i`` is exhausted.

    Attributes:
        key_idx: index of the key used by this sub-job.
        parallelizeable_calls_array: original per-step PBS call counts (never modified).
        parall_array_temp: remaining per-step PBS call counts, consumed by ``serve_pbs``.
        parall_arr_idx: index of the step currently being served.
        maketime: value of ``age`` when the last call was served; -1 while unfinished.
        age: number of ``inc_age`` calls made while work was still outstanding.
        done: True once every PBS call has been served.
        best_maketime_to_completion: initialised to -1; not updated by this class.
    """

    def __init__(self, parallelizeable_calls_array, key_idx):
        self.key_idx = key_idx
        self.parallelizeable_calls_array = np.array(parallelizeable_calls_array)
        # working copy that serve_pbs decrements; the original counts stay intact
        self.parall_array_temp = np.array(parallelizeable_calls_array)
        self.parall_arr_idx = 0
        self.maketime = -1
        self.age = 0
        self.done = False
        self.best_maketime_to_completion = -1

    def get_best_maketime(self, batchsize):
        """Return the minimal number of batches (time units) for the whole sub-job.

        Each step needs ``ceil(calls / batchsize)`` batches, assuming the engine
        is dedicated to this sub-job.
        """
        return np.ceil(self.parallelizeable_calls_array / batchsize).sum()

    def get_best_maketime_to_completion(self, batchsize):
        """Return the minimal number of batches (time units) for the remaining calls."""
        return np.ceil(self.parall_array_temp / batchsize).sum()

    def inc_age(self):
        """Advance the age by one tick while PBS calls are still outstanding.

        The scheduler must call this every tick until the job is done; once
        nothing remains the age stops growing.
        """
        if self.parall_array_temp.sum() != 0:
            self.age += 1

    def serve_pbs(self, cnt):
        """Serve up to ``cnt`` PBS calls of the current step.

        Calls never spill into the next step, because that step needs the
        results of the current one. When the current step is exhausted, the
        sub-job advances to the next step, or is marked done after the last
        step, and the surplus slots are handed back.

        Args:
            cnt: number of batch slots offered to this sub-job.

        Returns:
            The number of offered slots that were not used.
        """
        if self.done:
            # nothing left to compute: every offered slot stays unused
            return cnt
        if self.parall_array_temp.size == 0:
            # a sub-job without any steps is finished immediately
            self.done = True
            self.maketime = self.age
            return cnt

        cnts_left = self.parall_array_temp[self.parall_arr_idx]
        if cnts_left > cnt:
            self.parall_array_temp[self.parall_arr_idx] -= cnt
            return 0

        # the offered slots finish the current step
        self.parall_array_temp[self.parall_arr_idx] = 0
        unused_cnts = cnt - cnts_left
        if self.parall_arr_idx + 1 < len(self.parallelizeable_calls_array):
            self.parall_arr_idx += 1
        else:
            # list empty = job done
            self.done = True
            self.maketime = self.age
            # sanity check: all earlier steps were zeroed before advancing
            checksum = self.parall_array_temp.sum()
            if checksum != 0:
                print(f"Error: job finished with {checksum} leftover PBS calls - this is not supposed to happen, something went wrong")
        return unused_cnts

class PBS_engine:
    """Model of a single PBS (bootstrapping) engine.

    The engine processes one batch of ``batchsize`` ciphertext slots per
    ``run()`` call and stores up to ``max_num_keys`` keys. Key loads and
    unloads are staged in ``new_key_storage`` and only become visible in
    ``key_storage`` when ``run()`` commits them, which happens once no key
    load is in progress.

    Attributes:
        speed: ``pbs_per_s`` as given; stored but not used by this class.
        batchsize: number of slots per batch.
        key_loading_time: value the load cooldown is set to by ``load_key``.
        max_num_keys: number of key slots on the device.
        key_storage: committed key index per slot (``key_freespace_indicator`` if free).
        batch: key index of the ciphertext in each batch slot (``batch_freespace_indicator`` if free).
        wasted_batchslots: running count of batch slots that were run empty.
        new_key_storage: staged key storage, committed by ``run()``.
        key_loading_cooldown: remaining ``run()`` calls before the staged storage can be committed.
    """

    def __init__(self, pbs_per_s, batchsize, key_loading_time, max_num_keys):
        self.speed = pbs_per_s
        self.batchsize = batchsize
        # NOTE(review): originally documented as "in batches per second", but run()
        # treats it as a number of run() calls to wait - confirm the intended unit
        self.key_loading_time = key_loading_time
        self.max_num_keys = max_num_keys

        # indices of the keys that are present on the device
        self.key_storage = np.full(max_num_keys, key_freespace_indicator, dtype=float)
        self.batch = np.full(batchsize, batch_freespace_indicator, dtype=float)
        self.wasted_batchslots = 0
        # staged storage must start out equal to the committed one; otherwise the
        # first commit in run() would fill every slot with a bogus key index
        self.new_key_storage = self.key_storage.copy()
        self.key_loading_cooldown = 0

    def run(self):
        """Process one batch period.

        Counts the free slots of the current batch as wasted and empties the
        batch. If a key load is in progress, its cooldown is decremented;
        otherwise the staged key storage is committed.

        With ``key_loading_time == T``, a key staged by ``load_key`` is committed
        on the (T+1)-th ``run()`` after the load.
        """
        self.wasted_batchslots += int(np.count_nonzero(self.batch == batch_freespace_indicator))
        # empty batch
        self.batch = np.full(self.batchsize, batch_freespace_indicator, dtype=float)
        # update key storage
        if self.key_loading_cooldown > 0:
            self.key_loading_cooldown -= 1
        else:
            # copy, so that later staging does not leak into key_storage before the next commit
            self.key_storage = self.new_key_storage.copy()
        # TODO: fill batch anew

    def add_to_batch(self, num_ciphertexts, key_idx):
        """Put ``num_ciphertexts`` ciphertexts of key ``key_idx`` into free batch slots.

        Only ciphertexts whose key is committed on the device are accepted, and
        only if enough free slots remain. Otherwise an error is printed and the
        batch is left unchanged.
        """
        if not np.any(self.key_storage == key_idx):
            print(f"Error: could not add batch with key {key_idx} - reason: key not present - load the key first")
            return
        count = max(int(num_ciphertexts), 0)
        free_slots = np.flatnonzero(self.batch == batch_freespace_indicator)
        if len(free_slots) < count:
            print(f"Error: could not add to batch ciphertext of key {key_idx} - reason: batch capacity insufficient")
            return
        # occupy the first `count` free slots with this key
        self.batch[free_slots[:count]] = key_idx

    def unload_key(self, key_idx):
        """Stage the removal of key ``key_idx``; it takes effect on the next commit in ``run()``."""
        slots = np.flatnonzero(self.key_storage == key_idx)
        if len(slots) == 0:
            # this is unexpected
            print(f"Warning: could not unload key {key_idx} - reason: key not present")
            return
        # we expect the key to be present on the device only once
        self.new_key_storage[slots[0]] = key_freespace_indicator

    def load_key(self, key_idx):
        """Stage key ``key_idx`` into a free slot and start the loading cooldown.

        Free slots and presence are checked against the staged storage. A load
        issued before a previous load was committed therefore does not
        overwrite it, and a key cannot be staged twice.
        """
        # are we already loading a key?
        if self.key_loading_cooldown != 0:
            print(f"Error: could not load key {key_idx} - reason: a key is already being loaded")
            return
        free_slots = np.flatnonzero(self.new_key_storage == key_freespace_indicator)
        if len(free_slots) == 0:
            # need to overwrite some key; we trust that the unloading function is called after a job is done
            print(f"Error: could not load key {key_idx} - reason: key storage full - unload a key first")
            return
        if np.any(self.new_key_storage == key_idx):
            print(f"Warning: key {key_idx} - is already present, did not load again")
            return
        self.new_key_storage[free_slots[0]] = key_idx
        self.key_loading_cooldown = self.key_loading_time

# scheduler
# given a list of sub_jobs, sorted by priority metric, and a PBS_engine, give jobs to the engine
# strategy: focus jobs for better latency in a certain priority, only use other jobs to fill up leftover spaces

In [15]:
# Scratch check of the rounding used for batch counts: ceil(calls / batchsize).
test = np.arange(5)
np.ceil(np.true_divide(test, 5))

array([0., 1., 1., 1., 1.])