# QML: distributed training

The Ray Train framework lets you run distributed training algorithms.
Quantum Serverless provides a `QiskitTorchTrainer` that lets you train quantum neural networks (QNNs) in a distributed fashion. Every worker gets its own QPUs (and sessions) to train on.

Let's look at an example program that shows how to use the trainer.


First, we need to create a `train.py` file and start filling it in.
Let's import everything we need:
```python
# train.py
import ray
import torch
from qiskit_ibm_runtime import QiskitRuntimeService
from ray import train
from ray.air import session
from torch import nn

from quantum_serverless import QuantumServerless, Program
from quantum_serverless.library.train.trainer import (
    QiskitScalingConfig,
    QiskitTorchTrainer,
    get_runtime_sessions,
    assign_backends,
    QiskitTrainerException,
)
```

Then we define our training loop:
```python
# train.py

def loop(config):
    """Test training loop."""
    runtime_session_1, runtime_session_2 = get_runtime_sessions(config)
    
    print("Session for worker:", [runtime_session_1, runtime_session_2])
    print("Available backends for worker", [runtime_session_1.backend(), runtime_session_2.backend()])

    # get data, create QNN and run training loop
    dataset_shard = session.get_dataset_shard("train")
    ...
```

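`get_runtime_sessions` returns the Qiskit Runtime sessions that we can use in our algorithms. The loop above unpacks two of them, which matches the two QPU entries (`qpu1` and `qpu2`) requested in the scaling config.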

Let's finish by creating a serverless context and running `QiskitTorchTrainer`:

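```python
# train.py
# (continued; the indentation reflects that this fragment sits inside the script's entry point)
    # `runtime_service` is assumed to be a QiskitRuntimeService instance created earlier in the script
    serverless = QuantumServerless()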

    with serverless:
        train_dataset = ray.data.from_items(
            [{"x": x, "y": 2 * x + 1} for x in range(200)]
        )
        scaling_config = QiskitScalingConfig(
            num_workers=3,
            resource_filtering={
                "qpu1": {"name": "ibmq_qasm_simulator", "simulator": True},
                "qpu2": {"name": "ibmq_qasm_simulator", "simulator": True},
            },
            allow_overbooking=True
        )
        trainer = QiskitTorchTrainer(
            train_loop_per_worker=loop,
            qiskit_runtime_service_account=runtime_service.active_account(),
            scaling_config=scaling_config,
            datasets={"train": train_dataset},
            ...
        )
        result = trainer.fit()
```

As you can see, we use `QiskitScalingConfig` to define the number of workers and the QPU requirements for each worker.
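
To make the effect of `num_workers=3` on the data more concrete, here is a minimal pure-Python sketch of what it means for each worker to receive its own shard of the 200-item dataset. It assumes a contiguous, near-even split; Ray decides the actual partitioning, and `split_into_shards` is a hypothetical helper written only for this illustration.

```python
def split_into_shards(items, num_workers):
    """Split items into contiguous shards whose sizes differ by at most one.

    Hypothetical helper for illustration only; Ray performs the real split.
    """
    base, extra = divmod(len(items), num_workers)
    shards, start = [], 0
    for worker in range(num_workers):
        # The first `extra` workers take one additional item each.
        size = base + (1 if worker < extra else 0)
        shards.append(items[start:start + size])
        start += size
    return shards


# Same toy dataset as in the tutorial: y = 2x + 1 for x in 0..199.
dataset = [{"x": x, "y": 2 * x + 1} for x in range(200)]
shards = split_into_shards(dataset, num_workers=3)
shard_sizes = [len(shard) for shard in shards]
print(shard_sizes)
```

Inside the training loop, the call to `session.get_dataset_shard("train")` plays the role of picking one element of `shards` for the current worker.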

The `allow_overbooking` flag allows us to use the same backend for different workers. Here the backend is a simulator, which is why we can use it in parallel. In real cases you would want different QPUs for different workers, as that unlocks true parallelism.
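
The idea behind overbooking can be sketched in a few lines of plain Python. This is not the library's implementation: `sketch_assignment` and `BackendConflictError` are hypothetical names, and the rule shown (refuse to hand out a backend name twice unless overbooking is allowed) is an assumption about what the flag controls.

```python
class BackendConflictError(Exception):
    """Raised when a backend would be shared but overbooking is disabled."""


def sketch_assignment(num_workers, resource_filtering, allow_overbooking):
    """Assign one backend name per resource slot to every worker.

    Hypothetical illustration only, not the quantum_serverless code.
    """
    in_use = set()
    assignment = {}
    for worker in range(num_workers):
        assignment[worker] = {}
        for slot, spec in resource_filtering.items():
            name = spec["name"]
            if name in in_use and not allow_overbooking:
                raise BackendConflictError(f"backend {name!r} is already reserved")
            in_use.add(name)
            assignment[worker][slot] = name
    return assignment


# Same filters as in the tutorial: both slots point at the same simulator.
filters = {
    "qpu1": {"name": "ibmq_qasm_simulator", "simulator": True},
    "qpu2": {"name": "ibmq_qasm_simulator", "simulator": True},
}

shared = sketch_assignment(3, filters, allow_overbooking=True)

# Without overbooking, the second request for the same backend is refused.
try:
    sketch_assignment(3, filters, allow_overbooking=False)
    conflict_detected = False
except BackendConflictError:
    conflict_detected = True
```

With overbooking enabled, all three workers end up with the simulator in both slots; with it disabled, the sketch refuses as soon as the same backend is requested a second time.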

The full version of the program can be found [here](./source_files/train.py).

Finally, let's run our program.

```python
from quantum_serverless import QuantumServerless, Program
```

```python
serverless = QuantumServerless({
    "providers": [{
        "name": "docker",
        "compute_resource": {
            "name": "docker",
            "host": "localhost",
        }
    }]
})
serverless
```

```
<QuantumServerless | providers [local, docker]>
```

```python
program = Program(
    name="train_qnn",
    entrypoint="train.py",
    arguments={
        "channel": "<CHANNEL>",
        "token": "<TOKEN>"
    },
    working_dir="./source_files",
    description="Train QNN on distributed resources."
)

job = serverless.run_program(program)
job
```

```python
job.status()
```

```
<JobStatus.SUCCEEDED: 'SUCCEEDED'>
```
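
If the job has not finished yet, the status call can be repeated until a final state is reached. The following is a minimal polling sketch under stated assumptions: `wait_for_job` is a hypothetical helper, the set of terminal state names is assumed, and a fake status source stands in for the real job so that the snippet runs on its own.

```python
import time

# Assumed terminal state names; check the JobStatus enum of your installed version.
TERMINAL_STATES = {"SUCCEEDED", "FAILED", "STOPPED"}


def wait_for_job(get_status_name, poll_interval=1.0, timeout=60.0):
    """Call get_status_name() repeatedly until it returns a terminal state."""
    deadline = time.monotonic() + timeout
    while True:
        status = get_status_name()
        if status in TERMINAL_STATES:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job still in state {status!r} after {timeout} s")
        time.sleep(poll_interval)


# Stand-in for a real job: reports PENDING, then RUNNING, then SUCCEEDED.
fake_statuses = iter(["PENDING", "RUNNING", "SUCCEEDED"])
final_status = wait_for_job(lambda: next(fake_statuses), poll_interval=0.01)
print(final_status)
```

With a real job, the callable could be something like `lambda: job.status().name`, assuming `job.status()` returns an enum member as the output above suggests.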

Let's check the logs that we are interested in:

```python
print("\n".join([log.split("0m ")[-1] for log in job.logs().split("\n")[2:-1] if "worker" in log]))
```

```
Session for worker: [<qiskit_ibm_runtime.session.Session object at 0x7fa19709dfd0>, <qiskit_ibm_runtime.session.Session object at 0x7fa22dadb350>]
Available backends for worker ['ibmq_qasm_simulator', 'ibmq_qasm_simulator']
Session for worker: [<qiskit_ibm_runtime.session.Session object at 0x7faa0cd00f10>, <qiskit_ibm_runtime.session.Session object at 0x7faa0afded90>]
Available backends for worker ['ibmq_qasm_simulator', 'ibmq_qasm_simulator']
Session for worker: [<qiskit_ibm_runtime.session.Session object at 0x7f16524acb90>, <qiskit_ibm_runtime.session.Session object at 0x7f16524aca10>]
Available backends for worker ['ibmq_qasm_simulator', 'ibmq_qasm_simulator']
```
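
The log-filtering one-liner is dense, so here is the same logic unpacked into a small function and run on a made-up log string. `worker_log_lines` is a hypothetical helper, and the sample log format (two header lines, then lines prefixed with a colored process tag ending in the ANSI reset sequence `\x1b[0m`) is an assumption about what `job.logs()` returns.

```python
def worker_log_lines(raw_logs, marker="worker", prefix_end="0m "):
    """Return log lines that mention `marker`, with the colored prefix removed."""
    # Skip the first two lines and the last one, as the one-liner does.
    lines = raw_logs.split("\n")[2:-1]
    # Keep only the text after the last occurrence of the prefix terminator.
    return [line.split(prefix_end)[-1] for line in lines if marker in line]


# Made-up sample in the assumed format.
sample_logs = "\n".join([
    "header line 1",
    "header line 2",
    "\x1b[36m(pid=1)\x1b[0m Session for worker: [a, b]",
    "\x1b[36m(pid=1)\x1b[0m unrelated line",
    "\x1b[36m(pid=1)\x1b[0m Available backends for worker ['sim', 'sim']",
    "",
])

filtered = worker_log_lines(sample_logs)
print("\n".join(filtered))
```

With a real job, the equivalent call would be `worker_log_lines(job.logs())`.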
