# In [11]:
from CombinerJob import CombinerJob
from qiskit.providers.fake_provider import GenericBackendV2
from qiskit.providers import BackendV2
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from vm_executable import *
from qiskit_ibm_runtime import SamplerV2 as Sampler


from qiskit.transpiler import PassManager
from qiskit.transpiler.passes import GateDirection
from qiskit.circuit.library import *
from qiskit.converters import circuit_to_dag
from qiskit.transpiler import TransformationPass
import numpy as np
import random

# Translation pass applied (after GateDirection) to the combined circuit
class GateDirectionTranslator(TransformationPass):
    """Rewrite ``sdg``, ``s`` and ``h`` operations as ``rz``/``sx`` sequences.

    Each matching node is replaced by a one-qubit circuit:

    * ``sdg`` -> ``rz(-pi/2)``
    * ``s``   -> ``rz(pi/2)``
    * ``h``   -> ``rz(pi/2)``, ``sx``, ``rz(pi/2)``

    NOTE(review): the replacements presumably match the original gates only
    up to a global phase -- confirm that is acceptable for callers.
    """

    def run(self, dag):
        """Substitute every matching op node of ``dag`` in place and return ``dag``."""
        quarter_turn = np.pi/2
        # Gate name -> ordered (circuit method, rotation angle or None) steps.
        recipes = {
            'sdg': (('rz', -quarter_turn),),
            's': (('rz', quarter_turn),),
            'h': (('rz', quarter_turn), ('sx', None), ('rz', quarter_turn)),
        }

        for node in dag.op_nodes():
            steps = recipes.get(node.op.name)
            if steps is None:
                # Not one of the gates this pass rewrites.
                continue

            # A fresh replacement circuit is built for every node.
            replacement = QuantumCircuit(1)
            for method, angle in steps:
                if method == 'rz':
                    replacement.rz(angle, 0)
                else:
                    replacement.sx(0)
            dag.substitute_node_with_dag(node, circuit_to_dag(replacement))

        return dag

class HypervisorBackend(BackendV2):

    def __init__(self, backend, vms, hc, vc, **fields):
        """Wrap ``backend`` and remember the layout of its virtual machines.

        Args:
            backend: Underlying backend; circuits are submitted to it through a
                ``Sampler`` and its target / coupling map drive gate-direction
                correction.
            vms: Grid indexed ``[row][col]`` holding the qubits of each VM.
            hc: Grid indexed ``[row][col]`` holding the qubits of the connector
                between a VM and its right-hand neighbour.
            vc: Grid indexed ``[row][col]`` holding the qubits of the connector
                between a VM and the VM below it.
            **fields: Forwarded unchanged to ``BackendV2.__init__``.
        """
        super().__init__(**fields)
        self.vms, self.hc, self.vc = vms, hc, vc
        self.backend = backend
        self.sampler = Sampler(mode=backend)

        # Fix two-qubit gate directions first, then rewrite s/sdg/h nodes.
        direction_pass = GateDirection(backend.coupling_map, backend.target)
        self.translate = PassManager([direction_pass, GateDirectionTranslator()])

    @property
    def target(self):
        """Target of the wrapped backend."""
        wrapped = self.backend
        return wrapped.target

    @property
    def max_circuits(self):
        """``max_circuits`` value reported by the wrapped backend."""
        wrapped = self.backend
        return wrapped.max_circuits

    def run(self, executables, selection = None, time_sched = False, intra_vm_sched = False, noise_aware = False, **kwargs) -> CombinerJob:
        """Combine the scheduled executables into one circuit and submit it.

        Args:
            executables: List of executables. Every executable that ends up in
                the selection is REMOVED from this list before the job is
                submitted.
            selection: Optional pre-computed schedule, an iterable of
                ``(exe_indices, r, c, n, m, v)`` tuples: indices into
                ``executables``, top-left grid cell ``(r, c)``, region size
                ``n x m`` and the circuit version ``v``. When ``None`` it is
                produced by :meth:`schedule`.
            time_sched, intra_vm_sched, noise_aware: Forwarded to
                :meth:`schedule` when ``selection`` is ``None``.
            **kwargs: Accepted but not used.

        Returns:
            A ``CombinerJob`` wrapping the sampler job, the per-region qubit
            mappings and the classical-bit counts of the scheduled executables.
        """
        # Identity test: `== None` would dispatch to a custom __eq__ (for
        # example an array-like selection) instead of meaning "not supplied".
        if selection is None:
            selection = self.schedule(executables, time_sched, intra_vm_sched, noise_aware)

        mappings = []
        clbit_cnt = []
        compiled_circuits = []
        for exe_ids, r, c, n, m, v in selection:
            mappings.append(self.get_mapping(r, c, n, m))

            if len(exe_ids) > 1:
                # Several executables share one VM: merge their half-size
                # circuits onto the two fixed qubit partitions of a VM.
                sharing = [executables[j] for j in exe_ids]
                compiled_circuits.append(self.combine_internal(sharing, [[0, 1, 2], [4, 5, 6]]))
                clbit_cnt.extend(executables[j].clbits for j in exe_ids)
            else:
                exe = executables[exe_ids[0]]
                compiled_circuits.append(exe.qc[v])
                clbit_cnt.append(exe.clbits)

        combined_circ = self.combine(compiled_circuits, mappings, self.backend.num_qubits, 'vm')
        direction_corrected_circ = self.translate.run(combined_circ)

        # Append a one-bit 'dummy' register and an X on qubit 0 conditioned on
        # it being 1. NOTE(review): presumably this forces the circuit to be
        # handled as a dynamic circuit -- confirm the intent.
        dummy_creg = ClassicalRegister(1, 'dummy')
        direction_corrected_circ.add_register(dummy_creg)
        with direction_corrected_circ.if_test((dummy_creg, 1)):
            direction_corrected_circ.x(0)

        # Pop from the highest index down so earlier pops do not shift the
        # positions that are still to be removed.
        scheduled = sorted((j for entry in selection for j in entry[0]), reverse=True)
        for idx in scheduled:
            executables.pop(idx)

        return CombinerJob(self.sampler.run([direction_corrected_circ]), mappings, clbit_cnt, backend=self)
    

    def dryrun(self, executables, selection = None, time_sched = False, intra_vm_sched = False, noise_aware = False, **kwargs):
        """Build the combined circuit exactly as :meth:`run` does, without submitting it.

        Unlike :meth:`run`, nothing is removed from ``executables`` and no
        sampler job is created.

        Args:
            executables: List of executables to schedule from.
            selection: Optional pre-computed schedule, an iterable of
                ``(exe_indices, r, c, n, m, v)`` tuples. When ``None`` it is
                produced by :meth:`schedule`.
            time_sched, intra_vm_sched, noise_aware: Forwarded to
                :meth:`schedule` when ``selection`` is ``None``.
            **kwargs: Accepted but not used.

        Returns:
            The direction-corrected combined ``QuantumCircuit``.
        """
        # Identity test: `== None` would dispatch to a custom __eq__ (for
        # example an array-like selection) instead of meaning "not supplied".
        if selection is None:
            selection = self.schedule(executables, time_sched, intra_vm_sched, noise_aware)

        mappings = []
        clbit_cnt = []
        compiled_circuits = []
        for exe_ids, r, c, n, m, v in selection:
            mappings.append(self.get_mapping(r, c, n, m))

            if len(exe_ids) > 1:
                # Several executables share one VM: merge their half-size
                # circuits onto the two fixed qubit partitions of a VM.
                sharing = [executables[j] for j in exe_ids]
                compiled_circuits.append(self.combine_internal(sharing, [[0, 1, 2], [4, 5, 6]]))
                clbit_cnt.extend(executables[j].clbits for j in exe_ids)
            else:
                exe = executables[exe_ids[0]]
                compiled_circuits.append(exe.qc[v])
                clbit_cnt.append(exe.clbits)

        combined_circ = self.combine(compiled_circuits, mappings, self.backend.num_qubits, 'vm')
        direction_corrected_circ = self.translate.run(combined_circ)

        # Append a one-bit 'dummy' register and an X on qubit 0 conditioned on
        # it being 1. NOTE(review): presumably this forces the circuit to be
        # handled as a dynamic circuit -- confirm the intent.
        dummy_creg = ClassicalRegister(1, 'dummy')
        direction_corrected_circ.add_register(dummy_creg)
        with direction_corrected_circ.if_test((dummy_creg, 1)):
            direction_corrected_circ.x(0)

        return direction_corrected_circ


    def get_mapping(self, r, c, n, m): 
        """Return the qubits covered by the ``n x m`` region whose top-left cell is ``(r, c)``.

        The result lists, in order: the qubits of every VM in the region, the
        qubits of every horizontal connector between two VMs of the region, and
        the qubits of every vertical connector between two VMs of the region
        that were not already contributed by a horizontal connector.
        """
        region_rows = range(r, r+n)
        region_cols = range(c, c+m)

        mapping = []
        for row in region_rows:
            for col in region_cols:
                mapping.extend(self.vms[row][col])

        # The last column has no right-hand neighbour inside the region.
        horizontal_qubits = set()
        for row in region_rows:
            for col in range(c, c+m-1):
                connector = self.hc[row][col]
                mapping.extend(connector)
                horizontal_qubits.update(connector)

        # The last row has no neighbour below it inside the region.
        for row in range(r, r+n-1):
            for col in region_cols:
                mapping.extend(q for q in self.vc[row][col] if q not in horizontal_qubits)

        return mapping

    def get_qvm_ranking(self):
        """Rank the nine 1x1 VM regions of the 3x3 grid by mean link error.

        For each grid cell the fixed 7-qubit VM coupling pattern is mapped onto
        the cell's physical qubits; the score is the mean ``'ecr'`` gate error
        over those mapped links that exist in the backend coupling map.

        NOTE(review): only two-qubit link error contributes to the score;
        readout error is not taken into account.

        Returns:
            A list of the nine cell indices (``row*3 + col``) ordered from the
            lowest score to the highest.
        """
        def mean_link_error(tgt_cm, coupling_map, props) -> float:
            # Links missing from the physical coupling map are skipped.
            link_err = [
                props.gate_error('ecr', (q1, q2))
                for q1, q2 in tgt_cm
                if (q1, q2) in coupling_map
            ]
            return np.mean(link_err)

        # Directed edges of a single VM, in VM-local qubit indices.
        vm_coupling_map = [[1, 0], [0, 1], [1, 2], [2, 1], [1, 3], [3, 1], [3, 5], [5, 3], [4, 5], [5, 4], [5, 6], [6, 5]]

        # Fetched once: neither changes while the nine cells are scored.
        coupling_map = self.backend.coupling_map
        props = self.backend.properties()

        scores = []
        for i in range(3):
            for j in range(3):
                mapping = self.get_mapping(i, j, 1, 1)
                cm = [(mapping[q1], mapping[q2]) for q1, q2 in vm_coupling_map]
                scores.append((mean_link_error(cm, coupling_map, props), i*3+j))

        # Tuples sort by score first, cell index second.
        scores.sort()
        return [index for _, index in scores]


    def schedule(self, executables, time_sched = False, intra_vm_sched = False, noise_aware = False):        
        """Pick executables for one batch and place them on the 3x3 VM grid.

        A first pass walks ``executables`` in order and places each one on a
        free rectangular region. Optionally, :meth:`intra_schedule` then packs
        additional executables into VMs that are already in use, and a second
        pass (time scheduling) stacks further executables on top of regions
        that finish early.

        Args:
            executables: Sequence of executables. The code reads their
                ``versions``, ``dimensions[v]`` (``(rows, cols)``) and
                ``qc[v]`` attributes.
            time_sched: Enable the second, depth-aware pass. Ignored when
                ``noise_aware`` is true.
            intra_vm_sched: Call :meth:`intra_schedule` after each pass.
            noise_aware: Mark the three worst-ranked VMs as "bad" and keep
                sensitive executables away from them.

        Returns:
            A list of ``(exe_indices, r, c, n, m, v)`` tuples: the list of
            executable indices, the top-left cell, the region size and the
            chosen circuit version.
        """

        def fit1(i, j, n, m, region_status) -> bool:
            # True when no cell of the n x m block at (i, j) has been used yet.
            for a in range(n):
                for b in range(m):
                    if region_status[i+a][j+b] >= 1:
                        return False
            return True


        def fit(region_status, exe, bad_qvm_mark):
            # Returns (row, col, version) of a placement, or (None, None, None).
            if is_sensitive(exe):
                # Sensitive: first free placement that touches no bad VM.
                for v in range(exe.versions):
                    n, m = exe.dimensions[v][0], exe.dimensions[v][1]
                    for i in range(3-n+1):
                        for j in range(3-m+1):

                            if fit1(i, j, n, m, region_status) and fit_bad_cnt(i, j, n, m, bad_qvm_mark) == 0:
                                return i, j, v
                return None, None, None
            else:
                # Not sensitive: free placement covering the most bad VMs, so
                # that good VMs stay available.
                max_bad_qvm_used = -1
                ret_i, ret_j, ret_v = None, None, None
                for v in range(exe.versions):
                    n, m = exe.dimensions[v][0], exe.dimensions[v][1]
                    for i in range(3-n+1):
                        for j in range(3-m+1):

                            bad_qvm_used = fit_bad_cnt(i, j, n, m, bad_qvm_mark)
                            if fit1(i, j, n, m, region_status) and bad_qvm_used > max_bad_qvm_used:
                                max_bad_qvm_used = bad_qvm_used
                                ret_i, ret_j, ret_v = i, j, v

                                # Every cell is bad: cannot do better.
                                if max_bad_qvm_used == n*m:
                                    return ret_i, ret_j, ret_v
                                
                return ret_i, ret_j, ret_v


        def fit_bad_cnt(i, j, n, m, bad_qvm_mark) -> int:
            # Number of bad-marked cells inside the n x m block at (i, j).
            ret = 0
            for a in range(n):
                for b in range(m):
                    ret += bad_qvm_mark[i+a][j+b]
            return ret


        def timefit(n, m, region_status, region_height, circ_depth, cur_volume, cur_max_height, max_reuse):
            # First block whose cells are all below max_reuse and where the
            # circuit, stacked on the tallest cell, stays strictly below
            # cur_max_height. cur_volume is accepted but not used.
            for i in range(3-n+1):
                for j in range(3-m+1):
                    if max_reuse_check(i, j, n, m, region_status, max_reuse) == False:
                        continue
                    region_max_height = max_pool(i, j, n, m, region_height)
                    if region_max_height + circ_depth < cur_max_height:
                        return i, j
            return None, None

        def max_reuse_check(i, j, n, m, region_status, max_reuse) -> bool:
            # True when every cell of the block has been used fewer than
            # max_reuse times.
            for a in range(n):
                for b in range(m):
                    if region_status[i+a][j+b] >= max_reuse:
                        return False
            return True


        def update_region_status(i, j, n, m, region_status, region_height, circ_depth, cur_max_height):
            # Records one more use of every cell of the block, raises all of
            # them to (tallest cell + circ_depth) and returns the new overall
            # maximum height.
            pooled_height = max_pool(i, j, n, m, region_height)

            for a in range(n):
                for b in range(m):
                    region_status[i+a][j+b] += 1
                    region_height[i+a][j+b] = pooled_height + circ_depth

            if cur_max_height < pooled_height + circ_depth:
                print('exceeding max height') 
            return max(cur_max_height, pooled_height + circ_depth)

        # Maximum number of executables stacked on one cell by time scheduling.
        MAX_REUSE = 2

        def mark_bad_qvm(n):
            # 3x3 grid with 1 on the n cells that come last in the ranking
            # (i.e. the highest scores) and 0 elsewhere.
            mark = [[0]*3, [0]*3, [0]*3]
            ranking = self.get_qvm_ranking()

            for i in range(1, n+1, 1):
                qvm_index = ranking[-i]
                r, c = qvm_index//3, qvm_index%3
                mark[r][c] = 1

            return mark

        def is_sensitive(exe) -> bool:
            # Without noise awareness every executable counts as sensitive.
            if noise_aware == False:
                return True

            # Otherwise sensitive when any version has fewer than 340
            # operations in total.
            sensitivity_threshold = 340
            for qc in exe.qc:
                op_cnt = sum(qc.count_ops().values())
                if op_cnt < sensitivity_threshold:
                    return True
            return False


        # region_status: how many executables were placed on each cell.
        # region_height: accumulated circuit depth on each cell.
        region_status = [[0]*3, [0]*3, [0]*3]
        region_height = [[0]*3, [0]*3, [0]*3] 
        remaining_region = len(self.vms) * len(self.vms[0])
        selection = []
        selected = set() 


        bad_qvm_mark = [[0]*3, [0]*3, [0]*3]

        if noise_aware == True:
            good_qvm_cnt = 6
            bad_qvm_cnt = 3
            bad_qvm_mark = mark_bad_qvm(bad_qvm_cnt)

        # First pass: spatial placement, in the order the executables arrive.
        for i in range(len(executables)):
            if remaining_region == 0:
                break
            # Cheap rejection based on the area of version 0 only.
            if remaining_region < executables[i].dimensions[0][0] * executables[i].dimensions[0][1]:
                continue
            r, c, v = fit(region_status, executables[i], bad_qvm_mark)
            # Compared against None because row 0 is a valid placement.
            if r != None:

                n, m = executables[i].dimensions[v][0], executables[i].dimensions[v][1]
                selection.append(([i], r, c, n, m, v))
                selected.add(i)
                
                for a in range(n):
                    for b in range(m):
                        region_status[r+a][c+b] = 1
                        region_height[r+a][c+b] += executables[i].qc[v].depth()

                remaining_region -= n*m



        # Intra-VM packing is always invoked with its own time scheduling off.
        if intra_vm_sched:
            self.intra_schedule(executables, selection, selected, region_height, time_sched = False)



        # Time scheduling is skipped entirely in noise-aware mode.
        if noise_aware or not time_sched:
            return selection

        
        # Not worth stacking when the cells are already balanced to within 50.
        max_height = max(max(i) for i in region_height)
        min_height = min(min(i) for i in region_height)
        if max_height - min_height < 50:
            return selection



        util_volume = sum(sum(i) for i in region_height)
        

        # Cells already at the maximum height are marked fully used so that
        # nothing is stacked on them; the others contribute their remaining
        # reuse budget.
        remaining_reuse = 0
        for i in range(len(region_height)):
            for j in range(len(region_height[i])):
                if region_height[i][j] < max_height:
                    remaining_reuse += MAX_REUSE - region_status[i][j]
                else:
                    region_status[i][j] = MAX_REUSE
                    

        # Second pass: stack not-yet-selected executables on shorter regions.
        selection2 = []
        for i in range(len(executables)):
            if i in selected:
                continue

            for j in range(executables[i].versions):
                qc, n, m = executables[i].qc[j], executables[i].dimensions[j][0], executables[i].dimensions[j][1]
                # NOTE(review): the extra 50 added to the depth is presumably
                # the cost of resetting reused qubits -- confirm.
                r, c = timefit(n, m, region_status, region_height, qc.depth()+50, util_volume, max_height, MAX_REUSE) 
                if r != None:
                    update_params = (r, c, n, m, region_status, region_height, qc.depth()+50, max_height)
                    selection2.append(([i], r, c, n, m, j))
                    selected.add(i)
                    new_max_height = update_region_status(*update_params)
                    # timefit only accepts placements below max_height, so the
                    # overall maximum must be unchanged.
                    assert(new_max_height == max_height)
                    remaining_reuse -= n*m
                    break
            if remaining_reuse == 0:
                break

        if intra_vm_sched:
            self.intra_schedule(executables, selection2, selected, region_height, time_sched = False)

        return selection+selection2
    

    def intra_schedule(self, executables, selection: list, selected: {int}, region_height, time_sched = False):
        """Pack extra executables into VMs that are already part of ``selection``.

        Only executables whose ``half_qc`` is not ``None`` take part. The
        method works purely by side effect and returns ``None``:

        * ``selection[k][0]`` gains the indices of the executables added to
          that entry,
        * ``selected`` gains the same indices,
        * ``region_height`` is raised for the top-left cell of an entry that
          received a deeper half circuit.

        Args:
            executables: Sequence of executables (reads ``half_qc`` and
                ``qc[version]``).
            selection: Schedule entries ``(exe_indices, r, c, n, m, v)``;
                ``exe_indices`` must be a mutable list.
            selected: Set of executable indices that are already scheduled.
            region_height: Per-cell depth grid, updated in place.
            time_sched: When true, run a second phase that stacks further half
                circuits on the shallower partition of a shared VM.
        """

        QVM_MAX_ALLOWED_PERCENTAGE = 1

        QVM_INTERNAL_MAX_PARTITIONS = 2

        # Maximum number of half circuits stacked on one partition.
        QVM_INTERNAL_PARTITION_MAX_REUSE = 2

        def timefit_internal(qvm_status, circ_depth, max_reuse):
            # First (entry, partition) where adding circ_depth does not exceed
            # the deepest partition of that entry and the partition still has
            # reuse budget; (None, None) when there is none.
            for i, qvm in enumerate(qvm_status):
                max_depth = max(part[0] for part in qvm)
                for j, part in enumerate(qvm):

                    if part[0] + circ_depth <= max_depth and part[1] < max_reuse:
                        return i, j
            return None, None

        def all_part_usedup(qvm, qvm_status, max_reuse) -> bool:
            # True when every partition of the entry reached max_reuse.
            for part in qvm_status[qvm]:
                if part[1] < max_reuse:
                    return False
            return True
        
        # remaining_partition_cnt[k]: free partitions left in selection[k].
        # qvm_status[k]: list of [depth, use_count] pairs, one per partition.
        remaining_reusable_qvm = 0
        remaining_partition_cnt = []
        qvm_status = []  
        max_height = max(max(i) for i in region_height)


        # An entry whose (first) executable has a half circuit offers exactly
        # one additional partition.
        for i in selection:
            exe_index = i[0][0] 
            exe_ver = i[5]
            exe = executables[exe_index]

            if exe.half_qc != None:
                remaining_partition_cnt.append(1)
                qvm_status.append([[exe.half_qc.depth(), 1]])
                remaining_reusable_qvm += 1
            else:
                remaining_partition_cnt.append(0)
                qvm_status.append([[exe.qc[exe_ver].depth(), 1]])


        # Phase 1: give each unscheduled half-circuit executable the first
        # entry that still has a free partition.
        for i, exe in enumerate(executables):
            if i in selected or exe.half_qc == None:
                continue

            for j in range(len(selection)):
                if remaining_partition_cnt[j] > 0:
                    selection[j][0].append(i)
                    selected.add(i)
                    remaining_partition_cnt[j] -= 1
                    remaining_reusable_qvm -= 1

                    # Only the top-left cell of the entry is updated here.
                    y, x = selection[j][1], selection[j][2]
                    region_height[y][x] = max(region_height[y][x], exe.half_qc.depth())

                    qvm_status[j].append([exe.half_qc.depth(), 1])
                    break

            if remaining_reusable_qvm == 0:
                break


        if not time_sched:
            return

        # Phase 2: among shared entries, count those whose partitions differ
        # in depth by more than 50 and lock every partition that is already
        # the deepest of its entry.
        remaining_reusable_qvm = 0
        for i in range(len(selection)):

            if len(selection[i][0]) == 1:
                continue
            

            max_height = max(part[0] for part in qvm_status[i])
            min_height = min(part[0] for part in qvm_status[i])
            if max_height - min_height > 50:
                remaining_reusable_qvm += 1
            

            for part in qvm_status[i]:
                if part[0] == max_height:
                    part[1] = QVM_INTERNAL_PARTITION_MAX_REUSE

        # Stack remaining half circuits on partitions that can absorb them.
        for i, exe in enumerate(executables):
            if i in selected or exe.half_qc == None:
                continue
            circ_depth = exe.half_qc.depth()
            qvm, part = timefit_internal(qvm_status, circ_depth, QVM_INTERNAL_PARTITION_MAX_REUSE)
            if qvm != None:

                qvm_status[qvm][part][0] += circ_depth
                qvm_status[qvm][part][1] += 1
                selection[qvm][0].append(i)
                selected.add(i)


                max_height = max(part[0] for part in qvm_status[qvm])
                min_height = min(part[0] for part in qvm_status[qvm])

                # The entry stops being reusable once it is balanced to within
                # 50 or all of its partitions are used up.
                if max_height - min_height <= 50 or all_part_usedup(qvm, qvm_status, QVM_INTERNAL_PARTITION_MAX_REUSE):
                    remaining_reusable_qvm -= 1
                if remaining_reusable_qvm == 0:
                    break


    def combine(self, vcs, mappings, num_qubits, clreg_prefix: str) -> QuantumCircuit:
        """Compose several circuits onto one ``num_qubits``-wide circuit, keeping their registers.

        Every classical register of circuit ``k`` is recreated with the name
        ``<clreg_prefix><k>_<original name>``. Circuit ``k`` is placed on the
        qubits ``mappings[k]`` and on the next ``num_clbits`` consecutive
        classical bits. Qubits that an earlier circuit already occupied are
        reset first, followed by a barrier across ``mappings[k]``.

        Args:
            vcs: Circuits to combine, in placement order.
            mappings: ``mappings[k]`` is the list of qubit indices for ``vcs[k]``.
            num_qubits: Size of the single quantum register ``'q'``.
            clreg_prefix: Prefix for the renamed classical registers.

        Returns:
            The combined ``QuantumCircuit``.
        """
        assert(len(vcs) == len(mappings))

        renamed_cregs = [
            ClassicalRegister(creg.size, clreg_prefix+f'{idx}_'+creg.name)
            for idx, vc in enumerate(vcs)
            for creg in vc.cregs
        ]
        res = QuantumCircuit(QuantumRegister(num_qubits, 'q'), *renamed_cregs)

        occupied = [False]*num_qubits
        next_clbit = 0
        for idx, vc in enumerate(vcs):
            target_qubits = mappings[idx]

            # Qubits left in an unknown state by an earlier circuit.
            dirty = [q for q in target_qubits if occupied[q]]
            for q in dirty:
                res.reset(q)
            if dirty:
                res.barrier(target_qubits)

            width = vc.num_clbits
            clbit_slots = list(range(next_clbit, next_clbit+width))
            res.compose(vc, qubits = target_qubits, clbits = clbit_slots, inplace = True)

            for q in target_qubits:
                occupied[q] = True
            next_clbit += width

        return res


    def combine1(self, vcs, mappings, num_qubits) -> QuantumCircuit:
        """Compose several circuits onto one circuit with a flat block of classical bits.

        Works like :meth:`combine`, except that the result is created with
        ``num_qubits`` qubits and ``sum(num_clbits)`` anonymous classical bits
        instead of one renamed register per source register.

        Args:
            vcs: Circuits to combine, in placement order.
            mappings: ``mappings[k]`` is the list of qubit indices for ``vcs[k]``.
            num_qubits: Number of qubits of the combined circuit.

        Returns:
            The combined ``QuantumCircuit``.
        """
        assert(len(vcs) == len(mappings))

        total_clbits = sum(vc.num_clbits for vc in vcs)
        res = QuantumCircuit(num_qubits, total_clbits) 

        occupied = [False]*num_qubits
        next_clbit = 0
        for idx, vc in enumerate(vcs):
            target_qubits = mappings[idx]

            # Reset every qubit an earlier circuit has already used.
            dirty = [q for q in target_qubits if occupied[q]]
            for q in dirty:
                res.reset(q)
            if dirty:
                res.barrier(target_qubits)

            width = vc.num_clbits
            clbit_slots = list(range(next_clbit, next_clbit+width))
            res.compose(vc, qubits = target_qubits, clbits = clbit_slots, inplace = True)
            next_clbit += width

            for q in target_qubits:
                occupied[q] = True

        return res

    def combine_internal(self, exes, partition_mapping) -> QuantumCircuit:
        """Merge the half circuits of ``exes`` onto the partitions of a single 7-qubit VM.

        The first ``len(partition_mapping)`` circuits each open a partition of
        their own. Every later circuit goes to the partition with the smallest
        accumulated depth (the earliest one on ties).

        Args:
            exes: Executables whose ``half_qc`` circuits are merged.
            partition_mapping: One list of VM-local qubit indices per partition.

        Returns:
            The circuit built by :meth:`combine` on 7 qubits with the classical
            register prefix ``'circ'``.
        """
        vcs = [exe.half_qc for exe in exes]

        members = []  # circuit indices assigned to each partition
        depths = []   # accumulated depth of each partition
        for circ_num, vc in enumerate(vcs):
            depth = vc.depth()
            if len(members) < len(partition_mapping):
                members.append([circ_num])
                depths.append(depth)
                continue

            # Strict comparison keeps the earliest partition on ties.
            target_part = 0
            for part_idx, part_depth in enumerate(depths):
                if part_depth < depths[target_part]:
                    target_part = part_idx
            members[target_part].append(circ_num)
            depths[target_part] += depth

        mappings = [None]*len(exes)
        for part_idx, circ_nums in enumerate(members):
            for circ_num in circ_nums:
                mappings[circ_num] = partition_mapping[part_idx]

        return self.combine(vcs, mappings, 7, 'circ')


    @classmethod
    def _default_options(cls):
        """Return the default options of this backend: there are none, so ``None``."""
        return None



def vmbackend(num_qubits, basis_gates, coupling_map):
    """Create a ``GenericBackendV2`` with the given size, basis gates and coupling map."""
    backend_options = {'basis_gates': basis_gates, 'coupling_map': coupling_map}
    return GenericBackendV2(num_qubits, **backend_options)


def elastic_vm(num_qubits: int, basis_gates, hc, vc, shared_up: dict, shared_down: dict,
                vm_coupling_map, allowed_dimensions):
    """Build the smallest multi-VM backend that offers at least ``num_qubits`` qubits.

    The allowed dimensions are tried in order of increasing area. For the first
    ``(n, m)`` that is large enough, a ``GenericBackendV2`` covering ``n x m``
    VMs is created. If ``(m, n)`` is also allowed, differs from ``(n, m)`` and
    is large enough, the rotated layout is created as well. The search stops
    after that first match.

    The qubit count of an ``n x m`` layout is the VM qubits, plus ``n*(m-1)``
    horizontal connectors, plus ``(n-1)*m`` vertical connectors, minus
    ``(n-1)*(m-1)`` times the number of entries in ``shared_up`` and
    ``shared_down``.

    Args:
        num_qubits: Minimum number of qubits required.
        basis_gates: Basis gates of the generated backends.
        hc: Horizontal connector edges; negative indices denote connector
            qubits.
        vc: Vertical connector edges; negative indices denote connector
            qubits.
        shared_up, shared_down: Vertical connector qubits that coincide with
            qubits of a horizontal connector (see ``check_shared``).
        vm_coupling_map: Directed edges of a single VM.
        allowed_dimensions: Iterable of ``(rows, cols)`` pairs. It is not
            modified.

    Returns:
        A list of ``(backend, rows, cols)`` tuples: empty when no dimension
        fits, otherwise one entry plus optionally the rotated variant.
    """
    # sorted() instead of list.sort(): the caller's list must keep its order.
    candidates = sorted(allowed_dimensions, key=lambda d: d[0]*d[1])
    single_vm_size = max(max(i) for i in vm_coupling_map)+1
    hc_num_qubit = -min(min(i) for i in hc)
    vc_num_qubit = -min(min(i) for i in vc)
    hv_shared_num_qubit = len(shared_up) + len(shared_down)

    ret = []
    for n, m in candidates:
        elastic_vm_size = n*m*single_vm_size + n*(m-1)*hc_num_qubit + (n-1)*m*vc_num_qubit - (n-1)*(m-1)*hv_shared_num_qubit
        if elastic_vm_size < num_qubits:
            continue

        combined_coupling_map = combine_coupling_map(vm_coupling_map, hc, vc, shared_up, shared_down, n, m)
        combined_vm = GenericBackendV2(elastic_vm_size, basis_gates = basis_gates, coupling_map = combined_coupling_map, control_flow = True)
        ret.append((combined_vm, n, m))

        # Same area, rows and columns swapped; the connector counts differ, so
        # the size has to be recomputed.
        rotated_vm_size = n*m*single_vm_size + m*(n-1)*hc_num_qubit + (m-1)*n*vc_num_qubit - (n-1)*(m-1)*hv_shared_num_qubit
        if n != m and (m, n) in allowed_dimensions and rotated_vm_size >= num_qubits:
            combined_coupling_map = combine_coupling_map(vm_coupling_map, hc, vc, shared_up, shared_down, m, n)
            combined_vm = GenericBackendV2(rotated_vm_size, basis_gates = basis_gates, coupling_map = combined_coupling_map, control_flow = True)
            ret.append((combined_vm, m, n))

        # Only the first dimension that fits is used.
        break
    return ret



def combine_coupling_map(vm_coupling_map, hc, vc, shared_up: dict, shared_down: dict, n, m):
    """Build the coupling map of an ``n x m`` grid of identical VMs.

    Qubits are numbered in three groups: first the qubits of every VM in
    row-major order, then the qubits of every horizontal connector, then the
    vertical connector qubits that are not shared with a horizontal connector.

    Args:
        vm_coupling_map: Directed edges of a single VM (local indices >= 0).
        hc: Edges ``(l, r)`` of a horizontal connector. A non-negative ``l``
            is a qubit of the left VM, a non-negative ``r`` a qubit of the
            right VM; negative values denote connector qubits.
        vc: Edges ``(u, d)`` of a vertical connector, with the same convention
            for the upper / lower VM. It is not modified.
        shared_up, shared_down: Passed to ``check_shared`` to resolve vertical
            connector qubits that coincide with horizontal connector qubits.
        n, m: Number of VM rows and columns.

    Returns:
        A list of directed edges as ``(source, target)`` tuples.
    """
    single_vm_size = max(max(i) for i in vm_coupling_map)+1
    hc_num_qubit = -min(min(i) for i in hc)

    # Work on a normalised private copy so the caller's list is left alone:
    # edges sorted by their smallest endpoint, and an edge between two
    # connector qubits stored with the smaller index first.
    vc_edges = []
    for a, b in sorted(vc, key=min):
        if a < 0 and b < 0 and a > b:
            a, b = b, a
        vc_edges.append((a, b))

    ret = []
    edge_set = set()
    offset = 0

    # Edges inside every VM, shifted by the VM's first qubit.
    for i in range(n):
        for j in range(m):
            for k in vm_coupling_map:
                edge = (k[0] + offset, k[1] + offset)
                ret.append(edge)
                edge_set.add(edge)
            offset += single_vm_size

    # Horizontal connectors between (i, j) and (i, j+1). `offset` now points
    # at the first qubit of the connector being emitted.
    for i in range(n):
        for j in range(m-1):
            vm_offset_l = (i*m+j)*single_vm_size
            vm_offset_r = vm_offset_l + single_vm_size

            for l, r in hc:
                if l < 0:
                    l = l+hc_num_qubit+offset
                else:
                    l = l+vm_offset_l
                if r < 0:
                    r = r+hc_num_qubit+offset
                else:
                    r = r+vm_offset_r

                ret.append((l, r))
                ret.append((r, l))
                edge_set.add((l, r))
                edge_set.add((r, l))
            offset += hc_num_qubit

    # Vertical connectors between (i, j) and (i+1, j).
    for i in range(n-1):
        for j in range(m):
            vm_offset_u = (i*m+j)*single_vm_size
            vm_offset_d = vm_offset_u + m*single_vm_size
            # Connector-local index -> global qubit allocated for this connector.
            unshared_bit = {}
            for u, d in vc_edges:

                if u < 0:
                    u = check_shared(u, i, j, n, m, shared_up, shared_down, single_vm_size, hc_num_qubit)
                    # Still negative: not shared, so it needs its own qubit.
                    if u < 0:
                        if u in unshared_bit:
                            u = unshared_bit[u]
                        else:
                            unshared_bit[u] = offset
                            u = offset
                            offset += 1
                else:
                    u = u+vm_offset_u

                if d < 0:
                    d = check_shared(d, i, j, n, m, shared_up, shared_down, single_vm_size, hc_num_qubit)
                    if d < 0:
                        if d in unshared_bit:
                            d = unshared_bit[d]
                        else:
                            unshared_bit[d] = offset
                            d = offset
                            offset += 1
                else:
                    d = d+vm_offset_d

                # Shared qubits can reproduce an edge that already exists.
                if (u, d) not in edge_set:
                    edge_set.add((u, d))
                    edge_set.add((d, u))
                    ret.append((u, d))
                    ret.append((d, u))

    return ret


def hc_offset(single_vm_size: int, n: int, m: int, r: int, c: int, hc_num_qubit: int) -> int:
    """Return the index of the first qubit of the horizontal connector at row ``r``, column ``c``.

    Horizontal connectors are numbered after all ``n*m`` VMs, row by row, with
    ``m-1`` connectors per row and ``hc_num_qubit`` qubits per connector.
    """
    vm_qubit_total = n*m*single_vm_size
    connector_rank = r*(m-1)+c
    return vm_qubit_total + connector_rank*hc_num_qubit

def check_shared(q: int, i: int, j: int, n: int, m: int, shared_up, shared_down, single_vm_size, hc_num_qubit) -> int:
    """Resolve vertical-connector qubit ``q`` to a horizontal-connector qubit when it is shared.

    Sharing is only considered when there is a column to the left
    (``j >= 1``). A qubit listed in ``shared_up`` maps into the horizontal
    connector at ``(i, j-1)``; one listed in ``shared_down`` maps into the
    connector at ``(i+1, j-1)``. In every other case ``q`` is returned
    unchanged.
    """
    if j-1 < 0:
        # First column: no horizontal connector on the left to share with.
        return q

    left_col = j-1
    upper_base = hc_offset(single_vm_size, n, m, i, left_col, hc_num_qubit)
    lower_base = hc_offset(single_vm_size, n, m, i+1, left_col, hc_num_qubit)

    if q in shared_up:
        return shared_up[q] + hc_num_qubit + upper_base
    if q in shared_down:
        return shared_down[q] + hc_num_qubit + lower_base
    return q
    
def max_pool(i, j, n, m, region_height) -> float:
    """Return the largest height in the ``n x m`` block at ``(i, j)``, never less than 0."""
    block = [region_height[i+a][j+b] for a in range(n) for b in range(m)]
    # The leading 0 is the floor and also the result for an empty block.
    return max([0, *block])

