<a target="_blank" href="https://colab.research.google.com/github/VectorInstitute/fed-rag/blob/main/docs/notebooks/basic_fl.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

__IMPORTANT NOTE__: Because this notebook requires running a Docker image, it cannot be run from within Google Colab.

# Basic Federated Fine-tuning of RAG Systems

In this notebook, we demonstrate how to perform federated RAG fine-tuning with FedRAG. Specifically, we'll apply federated learning to fine-tune the generator of a RAG system using a federated setting that comprises two clients.

__HARDWARE REQUIREMENTS:__ This notebook requires a machine with at least two GPUs, each with at least 12GB of memory.
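A quick way to confirm that the machine meets these requirements is sketched below. This is an illustrative check, not part of FedRAG. The helper names `meets_gpu_requirements` and `detect_gpu_memory_gb` are hypothetical, and the check assumes PyTorch is importable in this environment. Memory is measured in decimal gigabytes (10^9 bytes) from `torch.cuda.get_device_properties(i).total_memory`, so that a nominal 12GB card, whose driver-reported capacity can be slightly under 12 GiB, still passes.

```python
def meets_gpu_requirements(gpu_mem_gb, min_gpus=2, min_mem_gb=12.0):
    """Return True if at least `min_gpus` GPUs each have >= `min_mem_gb` GB."""
    eligible = [mem for mem in gpu_mem_gb if mem >= min_mem_gb]
    return len(eligible) >= min_gpus


def detect_gpu_memory_gb():
    """Return the total memory (decimal GB) of each visible CUDA device, or []."""
    try:
        import torch  # assumption: PyTorch is installed in this environment
    except ImportError:
        return []
    if not torch.cuda.is_available():
        return []
    return [
        torch.cuda.get_device_properties(i).total_memory / 1e9
        for i in range(torch.cuda.device_count())
    ]


if __name__ == "__main__":
    mem = detect_gpu_memory_gb()
    print(f"Detected GPU memory (GB): {[round(m, 1) for m in mem]}")
    print("Meets requirements:", meets_gpu_requirements(mem))
```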

### Install dependencies

In [None]:
!uv pip install "fed-rag[huggingface,qdrant]" docker -q

## Setup

Running this notebook requires two high-level steps.

1. Running the Qdrant knowledge store service via Docker
2. Downloading the associated example Python script, which defines the `RAGSystem` as well as the `FLTask` which we use to launch the federated learning task.

## Running the Qdrant knowledge store service

We have prepared a Qdrant knowledge store service that comes pre-populated with knowledge artifacts from the December 2021 Wikipedia dump released with Atlas (Izacard et al., "Few-shot learning with retrieval augmented language models," arXiv:2208.03299, 2022).

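The cell below pulls the image if it is not already available locally and then runs it as a container on the host machine. Setting `SAMPLE_SIZE=tiny` makes the container load only a small test sample of passages rather than the full dump.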

In [1]:
import docker
import time

client = docker.from_env()
image_name = "vectorinstitute/qdrant-atlas-dec-wiki-2021:latest"

# first see if we need to pull the docker image
try:
    client.images.get(image_name)
    print(f"Image '{image_name}' already exists locally")
except docker.errors.ImageNotFound:
    print(f"Image '{image_name}' not found locally. Pulling...")
    # Pull with progress information
    for line in client.api.pull(image_name, stream=True, decode=True):
        if "progress" in line:
            print(f"\r{line['status']}: {line['progress']}", end="")
        elif "status" in line:
            print(f"\r{line['status']}", end="")
    print("\nPull complete!")

# run the Qdrant container
container = client.containers.run(
    image_name,
    detach=True,  # -d flag
    name="tiny-wiki-dec2021-ks",  # --name
    ports={"6333/tcp": 6333, "6334/tcp": 6334},  # -p 6333:6333  # -p 6334:6334
    volumes={
        "qdrant_data": {  # -v qdrant_data:/qdrant_storage
            "bind": "/qdrant_storage",
            "mode": "rw",
        }
    },
    environment={"SAMPLE_SIZE": "tiny"},  # -e SAMPLE_SIZE=tiny
    device_requests=[
        docker.types.DeviceRequest(
            count=-1, capabilities=[["gpu"]]
        )  # --gpus all
    ],
    remove=False,  # Don't auto-remove when stopped
)

print(f"Container started with ID: {container.id}")

# wait a moment for the container to initialize
time.sleep(3)

# Check container status
container.reload()  # Refresh container data
print(f"Container status: {container.status}")
print("Container logs:")
print(container.logs().decode("utf-8"))

Image 'vectorinstitute/qdrant-atlas-dec-wiki-2021:latest' already exists locally
Container started with ID: 8e58b42f14a508109055e20dc5f0d066fce8b4775f4b3a9b98b758239ce19b6e
Container status: running
Container logs:
Starting Qdrant Atlas Knowledge Store container
Running database initialization check...
Using tiny sample mode...
Creating tiny sample file for testing...
Using tiny sample file: tiny-sample.jsonl
Verifying sample file creation...
✅ Sample file successfully created at: /app/data/atlas/enwiki-dec2021/tiny-sample.jsonl
File details:
-rw-r--r-- 1 root root 6785 Jun  8 03:35 /app/data/atlas/enwiki-dec2021/tiny-sample.jsonl
File content (first 3 lines):
{"id": "140", "title": "History of marine biology", "section": "James Cook", "text": " James Cook is well known for his voyages of exploration for the British Navy in which he mapped out a significant amount of the world's uncharted waters. Cook's explorations took him around the world twice and led to countless descriptions of p

### Check if the service is ready

To check whether the knowledge store service is ready, we create a `QdrantKnowledgeStore` with the name of the pre-populated collection and check that the collection exists. If it does, `count` returns the number of items in the collection and we can continue with the rest of the notebook.

In [2]:
from fed_rag.knowledge_stores import QdrantKnowledgeStore

In [3]:
ks = QdrantKnowledgeStore(
    collection_name="nthakur.dragon-plus-context-encoder"
)

In [4]:
# If the collection exists, this should return an int.
# Otherwise, it will raise an error
ks.count

13
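Because the container prepares its data after starting, the `count` check may raise if it runs too soon. A small polling helper can retry the check until it succeeds or a timeout elapses. The `wait_until_ready` function below is hypothetical, not part of FedRAG; it relies only on the behaviour described in the comment above, namely that accessing `ks.count` raises an error while the collection does not exist.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def wait_until_ready(
    check: Callable[[], T],
    timeout: float = 60.0,
    interval: float = 2.0,
) -> T:
    """Call `check` until it returns without raising; re-raise after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            return check()
        except Exception:
            # out of time: propagate the most recent error to the caller
            if time.monotonic() >= deadline:
                raise
            time.sleep(interval)


# usage (assumes `ks` from the cells above):
# n_items = wait_until_ready(lambda: ks.count, timeout=120)
```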

## Download the example Python script which builds the RAG System and FL Task

The script is in the `example_scripts` directory of the main fed-rag GitHub repository:

<https://github.com/VectorInstitute/fed-rag/blob/main/example_scripts/cookbook_script-basic_fl.py>

The cells below download the script's text and display it here for convenience; the next section then writes it to a local file that we can execute.

In [5]:
SCRIPT_URL = "https://raw.githubusercontent.com/VectorInstitute/fed-rag/refs/heads/main/example_scripts/cookbook_script-basic_fl.py"

In [6]:
import requests

response = requests.get(SCRIPT_URL, timeout=30)
response.raise_for_status()  # fail early on a bad URL or network error
rag_code = response.text

In [7]:
from IPython.display import Code, display

display(Code(rag_code, language="python"))

## Federated fine-tuning

The script displayed above defines the `RAGSystem` and the generator training task that we will federate next. To do this, we will:

1. Write the script text to a file
2. Launch the server and the two clients, each in its own subprocess

In [8]:
# write the script's code to a Python file on disk
with open("rag_federated_learning.py", "w") as f:
    f.write(rag_code)
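The download and write steps can also be folded into one helper. The sketch below is a hypothetical, standard-library-only variant that uses `urllib.request` instead of `requests`. `urlopen` raises `urllib.error.HTTPError` on 4xx/5xx responses, and the file is written only after the download has succeeded.

```python
import urllib.request
from pathlib import Path


def fetch_script(url: str, dest: str, timeout: float = 30.0) -> str:
    """Download the text at `url`, write it to `dest`, and return the text."""
    # urlopen raises urllib.error.HTTPError for 4xx/5xx responses
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        text = resp.read().decode("utf-8")
    Path(dest).write_text(text, encoding="utf-8")
    return text


# usage, with SCRIPT_URL defined earlier in the notebook:
# rag_code = fetch_script(SCRIPT_URL, "rag_federated_learning.py")
```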

With the script written to disk, we can run it to launch the FL server and clients. We will use a notebook utility class called `ProcessMonitor` to do so.

In [9]:
from fed_rag.utils.notebook import ProcessMonitor

monitor = ProcessMonitor()
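For intuition about what a helper like `ProcessMonitor` has to handle, here is a minimal, hypothetical sketch built on `subprocess.Popen`. It is not FedRAG's implementation and omits the live display; it only starts named shell commands, waits for them, and stops them (POSIX only). `shell=True` is used because the client command defined in the next cell relies on shell syntax (`export ... &&`). Output goes to per-process log files because an undrained pipe could fill up and stall a client the server is waiting on.

```python
import os
import signal
import subprocess


class MiniProcessMonitor:
    """Start, wait on, and stop named shell commands (POSIX-only, illustrative)."""

    def __init__(self, log_dir="."):
        self.log_dir = log_dir
        self.processes = {}

    def start_process(self, name, command):
        log = open(os.path.join(self.log_dir, f"{name}.log"), "w")
        proc = subprocess.Popen(
            command,
            shell=True,  # needed for "export X=... && python ..." commands
            stdout=log,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # own process group, so stop_all reaches children
        )
        log.close()  # the child keeps its own copy of the file descriptor
        self.processes[name] = proc
        print(f"Started {name} (PID: {proc.pid})")
        return proc

    def wait_all(self, names):
        """Block until every named process exits; return {name: returncode}."""
        return {name: self.processes[name].wait() for name in names}

    def stop_all(self):
        for name, proc in self.processes.items():
            if proc.poll() is None:  # still running
                try:
                    # signal the whole group: the shell and its python child
                    os.killpg(proc.pid, signal.SIGTERM)
                except ProcessLookupError:
                    pass  # exited between poll() and killpg()
                proc.wait(timeout=10)
            print(f"Stopped {name}")
```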

In [10]:
# launch server command
server_command = "python rag_federated_learning.py --component server"

# launch client command template
# each client is pinned to one of the two GPUs via CUDA_VISIBLE_DEVICES
client_command = "export CUDA_VISIBLE_DEVICES={client_id} && python rag_federated_learning.py --component client_{client_id}"
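The `export CUDA_VISIBLE_DEVICES={client_id}` prefix restricts each client to a single physical GPU, which the client process then sees as device 0. The hypothetical helpers below generalize this to more clients than GPUs by assigning devices round-robin, and pass the variable through an environment dictionary (suitable for `subprocess.Popen(..., env=...)`) instead of shell syntax. Clients that share a GPU would also have to fit in that GPU's memory, so the 12GB-per-GPU guidance at the top of this notebook would likely no longer suffice.

```python
import os


def client_env(client_id: int, num_gpus: int, base_env=None) -> dict:
    """Return an environment that exposes exactly one GPU to the given client."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    # copy so the caller's environment is never mutated
    env = dict(os.environ if base_env is None else base_env)
    # round-robin: client i gets physical GPU (i mod num_gpus)
    env["CUDA_VISIBLE_DEVICES"] = str(client_id % num_gpus)
    return env


def client_argv(client_id: int) -> list:
    """Mirror the command template above, without the shell-specific `export`."""
    return ["python", "rag_federated_learning.py", "--component", f"client_{client_id}"]


for cid in range(2):
    gpu = client_env(cid, num_gpus=2, base_env={})["CUDA_VISIBLE_DEVICES"]
    print(cid, gpu, client_argv(cid))
```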

In [11]:
# start server process
monitor.start_process("server", server_command)

# give the server time to start up
time.sleep(2)

✅ Started server (PID: 85559)
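The fixed `time.sleep(2)` assumes the server is up within two seconds. A sturdier alternative is to poll until the server's port accepts TCP connections, as in the hypothetical `wait_for_port` helper below. The server address is an assumption: it is configured inside the downloaded script, and `127.0.0.1:8080` in the usage line is only Flower's common default, so confirm it against the script before relying on it.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 30.0, interval: float = 0.5) -> bool:
    """Return True once host:port accepts a TCP connection, False after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # refused or timed out: the server is not listening yet
            time.sleep(interval)
    return False


# usage (assumed address; check the script's server configuration):
# assert wait_for_port("127.0.0.1", 8080), "FL server did not start in time"
```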


In [12]:
# start client processes
monitor.start_process(
    name="client_0", command=client_command.format(client_id="0")
)
monitor.start_process(
    name="client_1", command=client_command.format(client_id="1")
)

✅ Started client_0 (PID: 85585)
✅ Started client_1 (PID: 85588)


In [13]:
# this cell blocks until all subprocesses finish or the kernel is interrupted
monitor.monitor_live(["server", "client_0", "client_1"])

🖥️  PROCESS MONITOR

server 🔴 STOPPED
------------------------------
[23:36:29] INFO :      Evaluation returned no results (`None`)
[23:36:29] INFO :
[23:36:29] INFO :      [ROUND 1]
[23:36:31] INFO :      configure_fit: strategy sampled 2 clients (out of 2)
[23:37:20] INFO :      aggregate_fit: received 2 results and 0 failures
[23:37:26] INFO :      configure_evaluate: strategy sampled 2 clients (out of 2)
[23:37:41] INFO :      aggregate_evaluate: received 2 results and 0 failures
[23:37:41] INFO :
[23:37:41] INFO :      [SUMMARY]
[23:37:41] INFO :      Run finished 1 round(s) in 72.02s
[23:37:41] INFO :      	History (loss, distributed):
[23:37:41] INFO :      		round 1: 0.41999998688697815
[23:37:41] INFO :

client_0 🔴 STOPPED
------------------------------
[23:37:09] 
[23:37:09] 
[23:37:09] 100%|██████████| 3/3 [00:22<00:00,  2.18s/it]
[23:37:09] 100%|██████████| 3
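In the server log, `aggregate_fit` and `aggregate_evaluate` each combine the two clients' results into one. Assuming the script uses Flower's FedAvg strategy (the log does not say which strategy is configured), both steps are averages weighted by each client's number of examples: over model parameters for `aggregate_fit`, and over evaluation losses for the `History (loss, distributed)` value. The sketch below shows that arithmetic on plain Python lists; the helper names are hypothetical and the numbers in the usage lines are made up for illustration, not taken from the run above.

```python
def weighted_average(values, num_examples):
    """Average `values`, weighting each client's value by its number of examples."""
    total = sum(num_examples)
    if total == 0:
        raise ValueError("total number of examples must be positive")
    return sum(n * v for n, v in zip(num_examples, values)) / total


def fedavg_parameters(client_params, num_examples):
    """Element-wise weighted average of each client's flattened parameter list."""
    return [
        weighted_average(list(column), num_examples)
        for column in zip(*client_params)
    ]


# illustrative numbers only
losses, counts = [0.5, 0.3], [30, 10]
print(weighted_average(losses, counts))
print(fedavg_parameters([[1.0, 2.0], [3.0, 4.0]], [1, 1]))
```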

### Cleanup

In [14]:
monitor.stop_all()

🛑 Stopped server
🛑 Stopped client_0
🛑 Stopped client_1
🛑 All processes stopped


In [15]:
# stop and remove container
container.stop()
container.remove()
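The cleanup cells above only run if execution reaches them; if an earlier cell fails, the subprocesses and the container keep running. One way to make teardown more robust is a single function that attempts every step even when an earlier one fails, as in the hypothetical `teardown` below. It is duck-typed, so it works with the `monitor` and `container` objects from this notebook. Note also that removing the container does not delete the named `qdrant_data` volume; with the Docker SDK it can be removed via `client.volumes.get("qdrant_data").remove()` if you no longer need the data.

```python
def teardown(monitor=None, container=None, volume=None):
    """Run every cleanup step, collecting errors instead of stopping at the first one."""
    steps = []
    if monitor is not None:
        steps.append(("stop processes", monitor.stop_all))
    if container is not None:
        steps.append(("stop container", container.stop))
        steps.append(("remove container", container.remove))
    if volume is not None:
        # the volume can only be removed once no container is using it
        steps.append(("remove volume", volume.remove))

    errors = []
    for label, step in steps:
        try:
            step()
        except Exception as exc:  # keep going so later steps still run
            errors.append((label, exc))
            print(f"Cleanup step '{label}' failed: {exc}")
    return errors


# usage with the objects created earlier in this notebook:
# errors = teardown(monitor, container, client.volumes.get("qdrant_data"))
```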