# Qiskit Machine Learning + Serverless

In [None]:
# Placeholder: substitute the actual training data here.
# NOTE(review): this is handed to Ray's TorchTrainer via `datasets=`, which
# presumably expects a Ray Dataset — confirm when filling it in.
train_dataset = ...

In [None]:
# Classical compute resources requested for the whole session.
num_cpus = 10
num_gpus = 20
ram = 100  # NOTE(review): unit is not stated anywhere (GB?) — confirm

# Kinds of quantum compute resources that are acceptable. A kind can be
# described by filter criteria (minimum qubit count / quantum volume, a
# coupling map) or by an explicit list of device names, and may match actual
# devices or "virtual" devices (partitions of a device).
#
# The specs are frozen dataclasses instead of plain dicts so that they are
# hashable: they are used as dictionary keys (in `qpus` below and in the
# `run_qiskit_remote` target), which a dict cannot be.
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class QPUSpec:
    """Hashable description of an acceptable kind of QPU.

    A field left as ``None`` places no constraint on that criterion.
    All field values must be hashable (use tuples, not lists or sets).
    """

    min_num_qubits: Optional[int] = None
    min_qv: Optional[int] = None
    coupling_map: Any = None  # e.g. a tuple of (qubit, qubit) index pairs
    names: Optional[Tuple[str, ...]] = None


qpu_a = QPUSpec(min_num_qubits=10, min_qv=16)
qpu_b = QPUSpec(min_qv=32, coupling_map=...)  # coupling map still to be specified
qpu_c = QPUSpec(names=('ibm_sherbrooke',))

# Number of QPUs requested per kind; `...` means "not decided yet".
qpus = {qpu_a: 10, qpu_b: 20, qpu_c: ...}

In [None]:
# Request the classical and quantum compute resources specified above.
serverless = QuantumServerless({
    # ... further configuration options, still to be decided ...
    'num_cpus': num_cpus,
    'num_gpus': num_gpus,
    'ram': ram,
    'qpus': qpus,
})

In [None]:
def create_model(sampler):
    """Build the hybrid quantum/classical model.

    A ``SamplerQNN`` constructed with *sampler* is wrapped in a
    ``TorchConnector`` so it can be trained with PyTorch.

    Args:
        sampler: sampler handed through to ``SamplerQNN``.

    Returns:
        The ``TorchConnector`` wrapping the QNN.
    """
    # the remaining SamplerQNN arguments are still to be filled in (the `...`)
    qnn = SamplerQNN(..., sampler)
    return TorchConnector(qnn)

In [None]:
# NOTE(review): the resource keys here ('cpu', 'gpus') differ from the ones used
# in the QuantumServerless request ('num_cpus', 'num_gpus') — settle on one naming.
@run_qiskit_remote(target={'cpu': 2, 'gpus': 4, 'qpus': {qpu_a: 2, qpu_b: 4}, 'ram': 20})
def train_function(target):
    """Train the hybrid model on this worker's shard of the training data.

    Args:
        target: resources assigned to this call; ``target['qpus']`` is indexed
            by the requested QPU specs (``qpu_a``, ``qpu_b``) to obtain the
            QPUs allocated for each.

    NOTE(review): this function is also passed to Ray's ``TorchTrainer``, which
    calls a one-argument training loop with its ``train_loop_config`` rather
    than with a resource target — the two calling conventions still need to be
    reconciled.
    """
    # Ray helpers used below (`session.get_dataset_shard`, `prepare_model`).
    from ray.air import session
    from ray.train.torch import prepare_model

    # quantum resources assigned for each requested QPU kind
    qpus_a = target['qpus'][qpu_a]
    qpus_b = target['qpus'][qpu_b]

    # this worker's fraction of the training data
    dataset_shard = session.get_dataset_shard("train")

    # set up the model; CuttingSampler / ThreadedSampler don't exist yet and only
    # illustrate how this may be used — the exact design and what is passed
    # along are still to be decided
    sampler = CuttingSampler(
        ThreadedSampler(qpus_a),
        ThreadedSampler(qpus_b),
    )
    model = create_model(sampler)
    model = prepare_model(model)

    # run the training epochs on the data shard
    # NOTE(review): num_epochs is not defined anywhere in this notebook yet
    for epoch in range(num_epochs):
        ...  # train on dataset_shard

In [None]:
from ray.air import ScalingConfig
from ray.train.torch import TorchTrainer

# Run `train_function` on Ray Train workers, giving them `train_dataset`
# under the name "train" (presumably split into per-worker shards — confirm).
# NOTE(review): `use_qpu` is not an existing `ScalingConfig` option (Ray only
# offers `use_gpu` plus custom resources via `resources_per_worker`); it is
# part of this proposal.
torch_trainer = TorchTrainer(
    train_loop_per_worker=train_function,
    scaling_config=ScalingConfig(num_workers=5, use_gpu=True, use_qpu=True),
    datasets={"train": train_dataset},
)

results = torch_trainer.fit()