<a target="_blank" href="https://colab.research.google.com/github/VectorInstitute/fed-rag/blob/main/docs/notebooks/rag_benchmarking_hf_mmlu.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

_(NOTE: if running on Colab, you will need to supply a WandB API Key in addition to your HFToken. Also, you'll need to change the runtime to a T4.)_

In [1]:
!uv pip install bitsandbytes -q

In [1]:
!uv pip install docker -q

# Basic Federated Fine-tuning of RAG Systems

## Knowledge Store

In [1]:
import docker
import os
import time

client = docker.from_env()
image_name = "vectorinstitute/qdrant-atlas-dec-wiki-2021:latest"

# first see if we need to pull the docker image
try:
    client.images.get(image_name)
    print(f"Image '{image_name}' already exists locally")
except docker.errors.ImageNotFound:
    print(f"Image '{image_name}' not found locally. Pulling...")
    # Pull with progress information
    for line in client.api.pull(image_name, stream=True, decode=True):
        if "progress" in line:
            print(f"\r{line['status']}: {line['progress']}", end="")
        elif "status" in line:
            print(f"\r{line['status']}", end="")
    print("\nPull complete!")

# run the Qdrant container
container = client.containers.run(
    "vectorinstitute/qdrant-atlas-dec-wiki-2021:latest",
    detach=True,  # -d flag
    name="tiny-wiki-dec2021-ks",  # --name
    ports={"6333/tcp": 6333, "6334/tcp": 6334},  # -p 6333:6333  # -p 6334:6334
    volumes={
        "qdrant_data": {  # -v qdrant_data:/qdrant_storage
            "bind": "/qdrant_storage",
            "mode": "rw",
        }
    },
    environment={"SAMPLE_SIZE": "tiny"},  # -e SAMPLE_SIZE=tiny
    device_requests=[
        docker.types.DeviceRequest(
            count=-1, capabilities=[["gpu"]]
        )  # --gpus all
    ],
    remove=False,  # Don't auto-remove when stopped
)

print(f"Container started with ID: {container.id}")

# wait a moment for the container to initialize
time.sleep(3)

# Check container status
container.reload()  # Refresh container data
print(f"Container status: {container.status}")
print(f"Container logs:")
print(container.logs().decode("utf-8"))

Image 'vectorinstitute/qdrant-atlas-dec-wiki-2021:latest' already exists locally
Container started with ID: 3bbf28156093168dc6ad130fa9833c0a23804841b34897c91a476596c5daef20
Container status: running
Container logs:
Starting Qdrant Atlas Knowledge Store container
Running database initialization check...
Using tiny sample mode...
Creating tiny sample file for testing...
Using tiny sample file: tiny-sample.jsonl
Verifying sample file creation...
✅ Sample file successfully created at: /app/data/atlas/enwiki-dec2021/tiny-sample.jsonl
File details:
-rw-r--r-- 1 root root 6785 Jun  4 15:06 /app/data/atlas/enwiki-dec2021/tiny-sample.jsonl
File content (first 3 lines):
{"id": "140", "title": "History of marine biology", "section": "James Cook", "text": " James Cook is well known for his voyages of exploration for the British Navy in which he mapped out a significant amount of the world's uncharted waters. Cook's explorations took him around the world twice and led to countless descriptions of p

In [4]:
container.health

'starting'

In [2]:
print(f"Container status: {container.status}")

Container status: running


In [5]:
from fed_rag.knowledge_stores import QdrantKnowledgeStore

In [8]:
store = QdrantKnowledgeStore(
    collection_name="nthakur.dragon-plus-context-encoder"
)

In [9]:
store.count

ResponseHandlingException: [Errno 111] Connection refused

In [2]:
import time

GIST_URL = f"https://gist.githubusercontent.com/nerdai/33e8445ab8b96783f34a7e0464e0b0f0/raw?fresh={int(time.time())}"

In [3]:
import requests

response = requests.get(GIST_URL)
rag_code = response.text

In [4]:
from IPython.display import Code, display

display(Code(rag_code, language="python"))

In [5]:
# load the vars from the gist
setup_only_code = rag_code.replace(
    'if __name__ == "__main__":', "if False:  # Disabled for notebook"
)

before_exec = set(globals().keys())
exec(setup_only_code)

config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [8]:
knowledge_store.count

13

### Centralized Training

In [None]:
result = manager.train()

In [None]:
result.loss

In [None]:
# After exec - find what was added
after_exec = set(globals().keys())
new_vars = after_exec - before_exec

# Delete everything that was added
for var_name in new_vars:
    del globals()[var_name]

print(f"🗑️ Deleted {len(new_vars)} variables: {new_vars}")

In [None]:
import torch
import gc

torch.cuda.empty_cache()

# Force garbage collection
gc.collect()

# More aggressive cleanup
torch.cuda.synchronize()
torch.cuda.empty_cache()
gc.collect()

### Federated Learning

In [None]:
# write gist to script
with open("rag_federated_learning.py", "w") as f:
    f.write(rag_code)

In [None]:
server_command = "python rag_federated_learning.py --component server"
client_command = "python rag_federated_learning.py --component {client_name}"

In [None]:
from fed_rag.utils.cookbook import ProcessMonitor

monitor = ProcessMonitor()

In [None]:
monitor.start_process("server", server_command)

In [None]:
monitor.start_process(
    "client_1", client_command.format(client_name="client_1")
)
monitor.start_process(
    "client_2", client_command.format(client_name="client_2")
)

In [None]:
print(monitor.get_logs("server"))

In [None]:
print(monitor.get_logs("client_1"))

In [None]:
print(monitor.get_logs("client_2"))

In [None]:
# monitor.stop_all()

In [None]:
monitor.list_processes()

### Cleanup

In [9]:
# stop and remove container
container.stop()
container.remove()