# EfficientRAG: Model Training (Filter & Labeler)

This notebook focuses on **Phase 2: Training**. Having obtained the synthesized data (either via the synthesis pipeline or pre-downloaded), we will now train the two core components of EfficientRAG:

1.  **The Filter:** A coarse-grained selector that removes irrelevant paragraphs from retrieved contexts.
2.  **The Labeler:** A fine-grained token classifier that highlights essential information within the remaining paragraphs.

**Pipeline Overview:**
1.  **Setup:** Install dependencies and download pre-processed datasets.
2.  **Patching:** Fix compatibility issues with newer `transformers` versions (e.g., `evaluation_strategy` deprecation).
3.  **Training:** Run the training scripts for both models using `accelerate`.
4.  **Export:** Zip and download the trained checkpoints.

## 1. Environment Setup

We begin by cloning the repository and installing its requirements. Because `requirements.txt` pins an older `accelerate` (0.29.1), we then upgrade `accelerate` and `peft` and pin `transformers` to 4.45.2 so that the training scripts run with a compatible set of versions.
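`pip` can end up with versions other than the ones requested, for example when another package constrains them. So after the install cells below have run, it can help to confirm the pins before training. The sketch below is not part of the EfficientRAG repo. It reads installed versions with the standard-library `importlib.metadata` and uses a simplified numeric comparison that drops pre-release and local tags. `packaging.version` would be the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(v: str) -> tuple:
    """Turn '4.45.2' into (4, 45, 2), dropping local tags ('+cu121') and suffixes ('rc1')."""
    parts = []
    for piece in v.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_pins(pins: dict, get_version=version) -> dict:
    """Return {package: (installed_version_or_None, ok)}.

    Pins look like {'peft': ('>=', '0.16.0')}; only '==' and '>=' are handled
    (any other operator is treated as '>=').
    """
    results = {}
    for pkg, (op, required) in pins.items():
        try:
            installed = get_version(pkg)
        except PackageNotFoundError:
            results[pkg] = (None, False)
            continue
        have, want = parse_version(installed), parse_version(required)
        # Pad with zeros so that e.g. "1.1" and "1.1.0" compare as equal.
        width = max(len(have), len(want))
        have += (0,) * (width - len(have))
        want += (0,) * (width - len(want))
        ok = have == want if op == "==" else have >= want
        results[pkg] = (installed, ok)
    return results


PINS = {
    "accelerate": (">=", "1.1.0"),
    "peft": (">=", "0.16.0"),
    "transformers": ("==", "4.45.2"),
}

for pkg, (installed, ok) in check_pins(PINS).items():
    print(f"{pkg:<13} {installed or 'not installed':<14} {'OK' if ok else 'MISMATCH'}")
```

If any line reports `MISMATCH`, re-run the upgrade cell and restart the runtime before training.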

In [None]:
! git clone https://github.com/nil-zhuang/efficientrag-official.git

Cloning into 'efficientrag-official'...
remote: Enumerating objects: 122, done.
remote: Counting objects: 100% (122/122), done.
remote: Compressing objects: 100% (111/111), done.
remote: Total 122 (delta 13), reused 108 (delta 7), pack-reused 0 (from 0)
Receiving objects: 100% (122/122), 1.81 MiB | 4.03 MiB/s, done.
Resolving deltas: 100% (13/13), done.


In [None]:
%cd efficientrag-official
! pip install -r requirements.txt

/content/efficientrag-official
Collecting faiss-cpu (from -r requirements.txt (line 1))
  Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting msal (from -r requirements.txt (line 2))
  Downloading msal-1.34.0-py3-none-any.whl.metadata (11 kB)
Collecting accelerate==0.29.1 (from -r requirements.txt (line 14))
  Downloading accelerate-0.29.1-py3-none-any.whl.metadata (18 kB)
Collecting deepspeed>=0.14.1 (from -r requirements.txt (line 16))
  Downloading deepspeed-0.18.2.tar.gz (1.6 MB)
  Preparing metadata (setup.py) ... done
Collecting vllm (from -r requirements.txt (line 17))
  Downloading vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl.metadata (17 kB)
Collecting black (from -r requirements.txt (line 19))
  Downloading black-25.11.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.ma

In [None]:
! pip install -U "accelerate>=1.1.0" "peft>=0.16.0" "transformers==4.45.2"

Collecting accelerate>=1.1.0
  Downloading accelerate-1.11.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers==4.45.2
  Downloading transformers-4.45.2-py3-none-any.whl.metadata (44 kB)
Collecting tokenizers<0.21,>=0.20 (from transformers==4.45.2)
  Downloading tokenizers-0.20.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.45.2-py3-none-any.whl (9.9 MB)
Downloading accelerate-1.11.0-py3-none-any.whl (375 kB)
Downloading tokenizers-0.20.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)

In [None]:
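# Re-enter the repo; the working directory resets to /content if the runtime was restarted after the upgrades
%cd /content/efficientrag-official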

/content/efficientrag-official


In [None]:
import transformers
print(transformers.__version__)

4.45.2


## 2. Data Preparation

Instead of generating the data from scratch (which takes significant time and incurs API costs), we download the pre-synthesized training data. This includes:
* **Filter Data:** `train.jsonl` and `valid.jsonl` for the paragraph filtering task.
* **Labeler Data:** Token-level labels for the fine-grained extraction task.
* **Negative Samples:** Hard negatives extracted from the corpus.
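Once the download cell has run, it can be worth checking that each split parses and contains the fields you expect. The task schema is not documented in this notebook, so the sketch below assumes nothing about field names. It counts the records in each JSONL file and tallies their top-level keys. The glob follows the `data/efficient_rag/<task>/hotpotQA/` layout shown in the download listing; `summarize_jsonl` is a hypothetical helper, not part of the repo.

```python
import json
from collections import Counter
from pathlib import Path


def summarize_jsonl(path, max_records=None):
    """Count records in a JSONL file and tally the top-level keys they use.

    Blank lines are skipped; a malformed line raises ValueError with its line number.
    """
    n = 0
    keys = Counter()
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{lineno}: invalid JSON ({e.msg})") from None
            n += 1
            if isinstance(record, dict):
                keys.update(record.keys())
            if max_records is not None and n >= max_records:
                break
    return n, keys


root = Path("data/efficient_rag")
if root.is_dir():
    for split in sorted(root.glob("*/hotpotQA/*.jsonl")):
        n, keys = summarize_jsonl(split)
        print(f"{split.relative_to(root)}: {n} records, keys={sorted(keys)}")
```

Passing `max_records` limits the scan to the first few records when a file is large.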

In [None]:
# --- EfficientRAG data fetch & install (one cell) ---
import os, io, sys, zipfile, shutil
from pathlib import Path

# --- Config ---
URL = "https://box.nju.edu.cn/f/a86b512077c7489b8da3/?dl=1"
ZIP_NAME = "EfficientRAG.zip"
DATA_DIR = Path("data")          # your repo's data/ directory
SOURCE_PREFIX = "EfficientRAG/"  # inside the zip
OVERWRITE = False                # set True to overwrite existing files

# --- Optional: progress with requests + tqdm (both widely available) ---
try:
    import requests
    from tqdm import tqdm
except Exception as e:
    print("Note: tqdm/requests not found; falling back to stdlib download (no progress bar).")
    requests = None
    tqdm = None

DATA_DIR.mkdir(parents=True, exist_ok=True)
zip_path = Path(ZIP_NAME)

def download(url: str, dest: Path):
    if dest.exists():
        print(f"[skip] {dest.name} already exists. Re-using it.")
        return
    if requests is None:
        # stdlib fallback (no chunked progress)
        import urllib.request
        print(f"[downloading] {url}")
        urllib.request.urlretrieve(url, dest)
        print("[done]")
        return
    # streaming download with progress
    print(f"[downloading] {url}")
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        total = int(r.headers.get("content-length") or 0)
        bar = tqdm(total=total, unit="B", unit_scale=True, desc="Downloading") if tqdm else None
        with open(dest, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    if bar: bar.update(len(chunk))
        if bar: bar.close()
    print("[done]")

def ensure_zip_ok(zpath: Path):
    try:
        with zipfile.ZipFile(zpath, "r") as zf:
            bad = zf.testzip()
            if bad:
                raise zipfile.BadZipFile(f"Corrupted member: {bad}")
    except zipfile.BadZipFile as e:
        raise RuntimeError(f"Zip validation failed: {e}")

def copy_from_zip(zpath: Path, src_prefix: str, dst_dir: Path, overwrite: bool = False):
    """
    Copy only files under src_prefix/ (e.g., EfficientRAG/) to dst_dir (flattening that top folder),
    ignoring MAC-specific junk. Returns (copied, skipped).
    """
    copied = 0
    skipped = 0
    with zipfile.ZipFile(zpath, "r") as zf:
        members = [m for m in zf.namelist() if m.startswith(src_prefix)]
        if not members:
            # If the top-level folder isn't exactly "EfficientRAG/", try to locate it heuristically
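            # Collect top-level folder names from all entries (works even if the zip has no explicit directory entries)
            candidates = sorted({m.split("/")[0] + "/" for m in zf.namelist() if "/" in m})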
            guess = next((c for c in candidates if c.lower().startswith("efficientrag")), None)
            if guess:
                print(f"[info] '{src_prefix}' not found; using '{guess}' instead.")
                members = [m for m in zf.namelist() if m.startswith(guess)]
                src_prefix = guess
            else:
                raise RuntimeError(f"Could not find '{src_prefix}' in the zip. Found top-level dirs: {set(candidates)}")

        for m in members:
            # Skip directory entries and macOS cruft (AppleDouble "._*" files, .DS_Store)
            name = m.rsplit("/", 1)[-1]
            if m.endswith("/") or name.startswith("._") or name == ".DS_Store":
                continue
            rel = m[len(src_prefix):]  # path inside EfficientRAG/
            if not rel:   # safety: in case it's exactly the folder
                continue
            target = dst_dir / rel
            target.parent.mkdir(parents=True, exist_ok=True)

            if target.exists() and not overwrite:
                skipped += 1
                continue

            with zf.open(m, "r") as src, open(target, "wb") as out:
                shutil.copyfileobj(src, out)
            copied += 1
    return copied, skipped

# --- Run ---
download(URL, zip_path)
ensure_zip_ok(zip_path)
copied, skipped = copy_from_zip(zip_path, SOURCE_PREFIX, DATA_DIR, overwrite=OVERWRITE)

print(f"\n[summary]")
print(f"  -> Copied: {copied} file(s)")
print(f"  -> Skipped (already existed): {skipped} file(s)")
print(f"  -> Data root: {DATA_DIR.resolve()}")

# Show a quick peek at what landed (first 20 files)
from itertools import islice
all_files = sorted(p for p in DATA_DIR.rglob("*") if p.is_file())
print("\n[sample files]")
for p in islice(all_files, 20):
    print("  ", p.relative_to(DATA_DIR))

# (Optional) sanity checks for common files
jsonls = list(DATA_DIR.rglob("*.jsonl"))
print(f"\nFound {len(jsonls)} *.jsonl file(s).")

[downloading] https://box.nju.edu.cn/f/a86b512077c7489b8da3/?dl=1


Downloading: 100%|██████████| 472M/472M [00:26<00:00, 17.7MB/s]


[done]

[summary]
  -> Copied: 22 file(s)
  -> Skipped (already existed): 0 file(s)
  -> Data root: /content/efficientrag-official/data

[sample files]
   efficient_rag/filter/2WikiMQA/train.jsonl
   efficient_rag/filter/2WikiMQA/valid.jsonl
   efficient_rag/filter/hotpotQA/train.jsonl
   efficient_rag/filter/hotpotQA/valid.jsonl
   efficient_rag/filter/musique/demo.jsonl
   efficient_rag/filter/musique/train.jsonl
   efficient_rag/filter/musique/valid.jsonl
   efficient_rag/labeler/2WikiMQA/test_demo.jsonl
   efficient_rag/labeler/2WikiMQA/train.jsonl
   efficient_rag/labeler/2WikiMQA/valid.jsonl
   efficient_rag/labeler/hotpotQA/train.jsonl
   efficient_rag/labeler/hotpotQA/valid.jsonl
   efficient_rag/labeler/musique/demo.jsonl
   efficient_rag/labeler/musique/train.jsonl
   efficient_rag/labeler/musique/valid.jsonl
   negative_sampling_extracted/2WikiMQA/train.jsonl
   negative_sampling_extracted/2WikiMQA/valid.jsonl
   negative_sampling_extracted/hotpotQA/train.jsonl
   negative_s

In [None]:
import os, json, random
from pathlib import Path

# Project root (adjust if your %cd differs)
ROOT = Path.cwd()
DATA = ROOT / "data"
DATASET_DIR = DATA / "dataset" / "hotpotQA"
MODEL_CACHE = ROOT / "model_cache"

for p in [DATASET_DIR, MODEL_CACHE]:
    p.mkdir(parents=True, exist_ok=True)

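# Make HF/Transformers use our local cache folder. These variables are read at
# import time, and transformers was already imported above, so in this kernel they
# mainly affect the `!python` subprocesses launched later; the model-loading cell
# below also passes cache_dir explicitly.
os.environ["HF_HOME"] = str(MODEL_CACHE)
os.environ["TRANSFORMERS_CACHE"] = str(MODEL_CACHE)
os.environ["HF_DATASETS_CACHE"] = str(MODEL_CACHE)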

## 3. Model & Cache Configuration

We configure the environment to cache models locally in `model_cache/` (so they are not re-downloaded if the Python session restarts) and authenticate with Hugging Face. Authentication is optional for public base models such as `facebook/contriever-msmarco` and `microsoft/deberta-v3-large`, but recommended.
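A placeholder such as `token = ...` is easy to get wrong: `...` is Python's `Ellipsis` object, which is truthy, so an unedited cell would call `login()` with it instead of falling back to the interactive widget. The sketch below resolves the token from the `HF_TOKEN` environment variable or, inside Colab, from a secret of the same name via `google.colab.userdata`, and returns `None` otherwise. `resolve_hf_token` is a hypothetical helper name.

```python
import os


def resolve_hf_token():
    """Return a Hugging Face token, or None so callers can fall back to notebook_login().

    Looks at the HF_TOKEN environment variable first, then (inside Colab only)
    at a Colab secret named HF_TOKEN via google.colab.userdata.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    try:
        from google.colab import userdata  # only importable inside Colab
        return userdata.get("HF_TOKEN") or None
    except Exception:
        # Not running in Colab, secret not defined, or notebook access not granted.
        return None


print("token found" if resolve_hf_token() else "no token; use notebook_login()")
```

When the result is truthy it can be passed as `login(token=token)`; otherwise `notebook_login()` remains the fallback.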

In [None]:
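import os
# Provide your own Hugging Face access token via the HF_TOKEN environment variable,
# or paste it here as a string. Do not leave a bare `...` placeholder: Ellipsis is
# truthy, so login() would be called with it.
token = os.environ.get("HF_TOKEN")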

if token:
    # Non-interactive login (recommended: set the env var in Colab “Secrets” or in a cell)
    from huggingface_hub import login
    login(token=token, add_to_git_credential=False)
else:
    # Fallback: interactive login widget
    from huggingface_hub import notebook_login
    notebook_login()


In [None]:
import torch
from transformers import AutoTokenizer, AutoModel

# Contriever retriever (MS MARCO variant)
ctr_name = "facebook/contriever-msmarco"
ctr_tok = AutoTokenizer.from_pretrained(ctr_name, cache_dir=str(MODEL_CACHE))
ctr = AutoModel.from_pretrained(ctr_name, cache_dir=str(MODEL_CACHE))

# DeBERTa-v3-large encoder (used by EfficientRAG Labeler/Filter)
deb_name = "microsoft/deberta-v3-large"
deb = AutoModel.from_pretrained(deb_name, torch_dtype="auto", cache_dir=str(MODEL_CACHE))

print("Models cached under:", MODEL_CACHE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



Models cached under: /content/efficientrag-official/model_cache


In [None]:
import os, pathlib
from huggingface_hub import snapshot_download

# 1) Create the local folders the repo expects
pathlib.Path("model_cache/contriever-msmarco").mkdir(parents=True, exist_ok=True)
pathlib.Path("model_cache/deberta-v3-large").mkdir(parents=True, exist_ok=True)

# 2) Download model snapshots into those folders. Recent huggingface_hub versions
#    write real files into local_dir and ignore the deprecated local_dir_use_symlinks
#    flag, so it is omitted. TensorFlow/Flax weights are skipped (training uses PyTorch).
snapshot_download(
    repo_id="facebook/contriever-msmarco",
    local_dir="model_cache/contriever-msmarco",
    ignore_patterns=["*.h5", "*.msgpack"],
)
snapshot_download(
    repo_id="microsoft/deberta-v3-large",
    local_dir="model_cache/deberta-v3-large",
    ignore_patterns=["*.h5", "*.msgpack"],
)

# 3) Export the cache location so the `!python` training scripts inherit it.
#    HF_HOME is preferred; TRANSFORMERS_CACHE is still set for older code.
os.environ["HF_HOME"] = os.path.abspath("model_cache")
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"]

For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.



## 4. Codebase Adaptation (Hot-Fixes)

The official codebase requires several patches to run smoothly in the current Google Colab environment with newer library versions.

**Modifications applied below:**
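1.  **Deprecation Fixes:** Replacing `evaluation_strategy` with `eval_strategy` in the `TrainingArguments` of both training scripts.
2.  **Saving Logic:** Forcing the `Labeler` trainer to checkpoint every epoch (even in `--test` mode) and to save the final model state and tokenizer explicitly at the end of training.
3.  **Validation Data:** Passing `test_mode` and `test_sample_cnt` to the Labeler's validation `build_dataset(...)` call so that the validation split is also limited in `--test` mode.
4.  **Model Architecture:** Patching `DeBERTa` imports to handle the `StableDropout` location change in `transformers>=4.46`.
5.  **Config:** Replacing `MODEL_DICT` in `src/conf/config.py` with an updated set of model identifiers (GPT-3.5, GPT-4, Llama-3, DeepSeek).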
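All of these patches rely on regular expressions, and a regex that no longer matches (for instance after an upstream refactor) fails silently: `re.subn` simply reports zero substitutions. A minimal helper that makes such patches fail loudly and stay idempotent on re-runs could look like the sketch below. `patch_file` and its `already_applied` marker are illustrative names, not part of the repo.

```python
import re
from pathlib import Path


def patch_file(path, pattern, repl, *, min_count=1, flags=re.MULTILINE, already_applied=None):
    """Apply one regex substitution to a file, failing loudly on too few matches.

    If the optional `already_applied` substring is present, the file is left
    untouched and 0 is returned, so re-running the cell is harmless.
    Returns the number of substitutions written.
    """
    path = Path(path)
    text = path.read_text(encoding="utf-8")
    if already_applied is not None and already_applied in text:
        return 0
    new_text, n = re.subn(pattern, repl, text, flags=flags)
    if n < min_count:
        raise RuntimeError(
            f"{path}: expected >= {min_count} match(es) for {pattern!r}, found {n}"
        )
    path.write_text(new_text, encoding="utf-8")
    return n


# Example: the evaluation_strategy -> eval_strategy rename for the Filter script.
target = Path("src/efficient_rag/filter_training.py")
if target.exists():
    n = patch_file(
        target,
        r'\bevaluation_strategy\s*=\s*["\']steps["\']',
        'eval_strategy="steps"',
        already_applied='eval_strategy="steps"',
    )
    print(f"{target}: {n} substitution(s)")
```

Raising instead of printing a count means a patch that stops matching halts the notebook before a long training run, rather than after it.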

In [None]:
import pathlib, re
p = pathlib.Path("src/efficient_rag/filter_training.py")
s = p.read_text(encoding="utf-8")
s = re.sub(r'\bevaluation_strategy\s*=\s*["\']steps["\']', 'eval_strategy="steps"', s)
p.write_text(s, encoding="utf-8")

4174

In [None]:
import pathlib, re

p = pathlib.Path("src/efficient_rag/labeler_training.py")
code = p.read_text(encoding="utf-8")
total_edits = 0

# --- A) Force save_strategy="epoch" unconditionally ---
code_new, n = re.subn(
    r'save_strategy\s*=\s*"(?:epoch)"\s*if\s*not\s*opt\.test\s*else\s*"no"\s*,',
    'save_strategy="epoch",',
    code,
    flags=re.MULTILINE,
)
total_edits += n
code = code_new

if n == 0:
    # Fallback: replace the first save_strategy kwarg (any form) with constant "epoch"
    code_new, n = re.subn(
        r'save_strategy\s*=\s*[^,]+,',
        'save_strategy="epoch",',
        code,
        count=1,
        flags=re.MULTILINE,
    )
    total_edits += n
    code = code_new

# --- B) Ensure a final save after trainer.train() ---
# Runs regardless of edit A (it used to be nested under the fallback above, so it
# was skipped whenever A matched). The save calls use the same indentation as
# trainer.train(). Assumes the script defines `save_dir` and `tokenizer` in that
# scope; skipped if `save_dir` is absent or the patch is already present.
if "force a final save" not in code and "save_dir" in code:
    code_new, n = re.subn(
        r'(^[ \t]*)trainer\.train\(\)[ \t]*$',
        (
            r'\g<1>trainer.train()\n'
            r'\g<1># --- force a final save (always write a clean copy) ---\n'
            r'\g<1>trainer.save_model(save_dir)\n'
            r'\g<1>tokenizer.save_pretrained(save_dir)'
        ),
        code,
        count=1,
        flags=re.MULTILINE,
    )
    total_edits += n
    code = code_new

# --- C) Rewrite the valid_dataset = build_dataset(...) call if it lacks test_mode= ---
pattern = re.compile(
    r'valid_dataset\s*=\s*build_dataset\([^)]*\)',  # grab the whole call up to the next ')'
    flags=re.DOTALL
)

def _valid_replacer(m: re.Match) -> str:
    block = m.group(0)
    # If already patched, keep it as-is (prevents duplicate ",\n)")
    if "test_mode=" in block:
        return block
    return (
        'valid_dataset = build_dataset(\n'
        '        opt.dataset,\n'
        '        "valid",\n'
        '        opt.max_length,\n'
        '        tokenizer,\n'
        '        test_mode=opt.test,\n'
        '        test_sample_cnt=int(opt.test_samples/10),\n'
        '    )'
    )

code_new, n = pattern.subn(_valid_replacer, code, count=1)

# Safety: if some editor previously left a dangling ",\n)" inside this call, fix just that case.
code_new = re.sub(
    r'(valid_dataset\s*=\s*build_dataset\([^)]*),\s*\)',
    r'\1\n    )',
    code_new,
    flags=re.DOTALL
)

code = code_new
total_edits += n
# --- D) Rename the deprecated evaluation_strategy kwarg to eval_strategy ---
code_new, n = re.subn(
    r'\bevaluation_strategy\s*=\s*["\']steps["\']',
    'eval_strategy="steps"',
    code,
    flags=re.MULTILINE,
)
total_edits += n
code = code_new

# Write back
p.write_text(code, encoding="utf-8")
print(f"Patched labeler_training.py. Total edits applied: {total_edits}")


Patched labeler_training.py. Total edits applied: 3


In [None]:
import pathlib, re

p = pathlib.Path("src/efficient_rag/model/model.py")
code = p.read_text()

# 1) Remove StableDropout from the v2 import list (robust to spacing/newlines)
code2 = re.sub(
    r"(?m)^(from\s+transformers\.models\.deberta_v2\.modeling_deberta_v2\s+import\s*\([\s\S]*?\))",
    lambda m: re.sub(r"(?m)^[ \t]*StableDropout,\s*\n", "", m.group(0)),
    code,
    count=1,
)

# 2) Insert the fallback try/except block right after the v2 import,
#    but only if we haven't already added it.
try_block = """\
try:
    # Older layouts where v2 exported it
    from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout  # noqa: F401
except Exception:
    try:
        # Common/current location
        from transformers.models.deberta.modeling_deberta import StableDropout  # noqa: F401
    except Exception:
        # Last resort: use vanilla dropout (slightly different behavior but unblocks you)
        from torch.nn import Dropout as StableDropout  # type: ignore
"""

if "from transformers.models.deberta.modeling_deberta import StableDropout" not in code2 and "Dropout as StableDropout" not in code2:
    code2 = re.sub(
        r"(?m)^(from\s+transformers\.models\.deberta_v2\.modeling_deberta_v2\s+import\s*\([\s\S]*?\))\s*\n",
        r"\1\n" + try_block + "\n",
        code2,
        count=1,
    )

p.write_text(code2)

3793

In [None]:
import pathlib, re
p = pathlib.Path("src/conf/config.py")
s = p.read_text(encoding="utf-8")

s = re.sub(
    r'MODEL_DICT\s*=\s*\{[^}]+\}',
    '''MODEL_DICT = {
    "gpt35": "gpt-35-turbo-1106",
    "gpt4": "gpt-4-0125-preview",
    "llama": "llama-3-70b-gptq-int4",
    "llama-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "deepseek": "deepseek-chat",
}''',
    s,
    flags=re.S,
)
p.write_text(s, encoding="utf-8")

1988

## 5. Experiment Tracking

We use **Weights & Biases (W&B)** to track training loss, accuracy, and F1 scores.
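The training scripts are launched with `!python`, so they inherit this kernel's environment variables, which is a convenient way to steer W&B without editing the scripts. The sketch below sets two variables the `wandb` client reads. The project name is a hypothetical placeholder, and a script that passes an explicit project to `wandb.init()` would override it.

```python
import os

# Read by the wandb client, including inside the `!python` training subprocesses.
# Hypothetical project name; an explicit project in wandb.init() takes precedence.
os.environ.setdefault("WANDB_PROJECT", "efficientrag-training")

# "online" (default) syncs live; "offline" logs locally for a later `wandb sync`;
# "disabled" turns wandb calls into no-ops, which is handy for quick smoke tests.
os.environ.setdefault("WANDB_MODE", "online")

print({k: os.environ[k] for k in ("WANDB_PROJECT", "WANDB_MODE")})
```

`setdefault` leaves any value you exported earlier untouched.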

In [None]:
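import wandb
# Provide your own W&B API key via the WANDB_API_KEY environment variable;
# wandb.login() picks it up automatically and otherwise prompts for it.
# (Passing key=... literally would hand wandb the Ellipsis object, not a key.)
wandb.login()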

## 6. Train the Filter Model

The **Filter** is a binary classifier (based on DeBERTa-v3-large) trained to determine if a retrieved paragraph is relevant to the query.

* **Input:** Question + Paragraph
* **Output:** Relevant / Irrelevant
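Under the description above, turning the classifier's output into a keep/drop decision amounts to a softmax followed by a threshold. The sketch assumes the model emits two logits per (question, paragraph) pair in the order [irrelevant, relevant] and keeps paragraphs with P(relevant) >= 0.5. Both the label order and the threshold are assumptions for illustration, not values read from `filter_training.py`.

```python
import math


def relevance_probs(logits_pairs):
    """Softmax over [irrelevant, relevant] logits; returns P(relevant) per paragraph."""
    probs = []
    for irrelevant, relevant in logits_pairs:
        m = max(irrelevant, relevant)  # subtract the max for numerical stability
        e_irr, e_rel = math.exp(irrelevant - m), math.exp(relevant - m)
        probs.append(e_rel / (e_irr + e_rel))
    return probs


def filter_paragraphs(paragraphs, logits_pairs, threshold=0.5):
    """Keep paragraphs whose P(relevant) meets the (hypothetical) threshold."""
    return [p for p, prob in zip(paragraphs, relevance_probs(logits_pairs)) if prob >= threshold]


paragraphs = ["off-topic paragraph", "paragraph that answers the question"]
print(filter_paragraphs(paragraphs, [(2.0, -1.0), (-1.0, 3.0)]))
```

Raising the threshold trades recall for a smaller context passed on to the Labeler.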

In [None]:
# helps allocator fragmentation (harmless to try)
! export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

! python src/efficient_rag/filter_training.py \
  --dataset hotpotQA \
  --save_path saved_models/filter \
  --batch_size 64

## 7. Train the Labeler Model

The **Labeler** is a token classification model (also DeBERTa-based). For paragraphs deemed relevant by the Filter, the Labeler identifies the exact tokens that answer the sub-question or provide the necessary context for the next hop.

* **Input:** Question + Paragraph
* **Output:** Token-level binary masks (Keep / Discard)
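Because the Labeler outputs one Keep/Discard label per token, recovering the highlighted text means stitching the kept tokens back together. DeBERTa-v3 uses a SentencePiece tokenizer, in which a leading `▁` (U+2581) marks the start of a word. The helper below is a hypothetical sketch, not the repo's decoding code, and it handles subword fragments naively: a kept fragment whose word start was discarded is glued onto the previous kept piece.

```python
def apply_keep_mask(tokens, mask, word_marker="\u2581"):
    """Rebuild the text of tokens labeled Keep (1), dropping Discard (0) tokens.

    Assumes SentencePiece-style tokens where `word_marker` ("▁", U+2581) starts
    a new word; pieces without it are appended to whatever was kept before them.
    """
    if len(tokens) != len(mask):
        raise ValueError("tokens and mask must have the same length")
    kept = "".join(tok for tok, keep in zip(tokens, mask) if keep)
    # Turn word markers into spaces, then collapse whitespace.
    return " ".join(kept.replace(word_marker, " ").split())


tokens = ["\u2581Paris", "\u2581is", "\u2581the", "\u2581capital", "\u2581of", "\u2581France", "."]
mask = [1, 0, 0, 1, 0, 1, 0]
print(apply_keep_mask(tokens, mask))  # Paris capital France
```

In practice, a tokenizer's own `convert_tokens_to_string` would handle detokenization more faithfully than this string replacement.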

In [None]:
! python src/efficient_rag/labeler_training.py \
    --dataset hotpotQA \
    --tags 2 \
    --test \
    --test_samples 50000 \
    --lr 1e-5 \
    --warmup_steps 100

## 8. Export Trained Models

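Training is complete. The cells below zip a local checkpoint directory written by the training scripts under `saved_models/` and trigger a browser download, so you can use the checkpoints for inference or evaluation in the next stage. Update `RUN_DIR` to the timestamped run and checkpoint folder produced by your own run.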
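Hard-coding a timestamped run folder means editing the export cells after every run. As an alternative, the most recent checkpoint can be located automatically. The sketch assumes the layout visible in the export paths (`saved_models/<model dir>/<name>_<YYYYMMDD_HHMMSS>/checkpoint-<step>`), so run folders sort chronologically by name. It picks the latest checkpoint, which is not necessarily the best one by validation score. `latest_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path


def latest_checkpoint(run_root):
    """Return the newest checkpoint-<step> directory under run_root, or None.

    Ranks by (run folder name, step). Run folders named like
    filter_20251025_233133 sort chronologically as strings, so this picks the
    highest step of the most recent run.
    """
    candidates = []
    for d in Path(run_root).rglob("checkpoint-*"):
        m = re.fullmatch(r"checkpoint-(\d+)", d.name)
        if m and d.is_dir():
            candidates.append((d.parent.name, int(m.group(1)), str(d)))
    if not candidates:
        return None
    return Path(max(candidates)[2])


# Usage: RUN_DIR = latest_checkpoint("/content/efficientrag-official/saved_models/filter")
```

The returned path can be passed to `shutil.make_archive` exactly like the hard-coded `RUN_DIR`.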

In [None]:
# 1) Point to your Labeler checkpoint folder (timestamped; adjust to your own run)
RUN_DIR = "/content/efficientrag-official/saved_models/labeler_two/labeler_20251112_002639/checkpoint-3126"

# 2) Zip it to /content/labeler_checkpoint.zip (make_archive appends ".zip")
import shutil
zip_path = "/content/labeler_checkpoint"
shutil.make_archive(zip_path, "zip", RUN_DIR)

# 3) Trigger a browser download (Colab only)
from google.colab import files
files.download(zip_path + ".zip")


In [None]:
# 1) Point to your Filter checkpoint folder (timestamped; adjust to your own run)
RUN_DIR = "/content/efficientrag-official/saved_models/filter/filter_20251025_233133/checkpoint-2270"

# 2) Zip it to /content/filter_checkpoint.zip (make_archive appends ".zip")
import shutil
zip_path = "/content/filter_checkpoint"
shutil.make_archive(zip_path, "zip", RUN_DIR)

# 3) Trigger a browser download (Colab only)
from google.colab import files
files.download(zip_path + ".zip")
