# WavesFM end-to-end finetuning demo

This notebook walks through:
1. Clone the WavesFM repo
2. Install dependencies
3. Download raw data
4. Preprocess to a .h5 cache
5. Download a pretrained checkpoint
6. Finetune on a task
7. Evaluate with a confusion matrix (classification) or an error density plot (positioning)

You can switch tasks using the config cell below.


## 0. Configuration
Adjust the task and knobs below. For positioning tasks, the evaluation cell will switch to error density plots automatically.


In [None]:
from pathlib import Path

# ---- task + paths ----
# TASK is passed to the finetuning CLI as --task and selects which
# download / preprocess branch the later cells take.
TASK = "deepmimo-los"  # examples: "deepmimo-los", "deepmimo-beam", "rml", "pos"
# Raw downloads go under RAW_ROOT; preprocessed .h5 caches go under CACHE_ROOT.
DATA_ROOT = Path("data")
RAW_ROOT = DATA_ROOT / "raw"
CACHE_ROOT = DATA_ROOT / "cache"
# OUTPUT_DIR is handed to the training CLI as --output-dir; the evaluation cell
# looks for "best.pth" inside it.
OUTPUT_DIR = Path("runs/demo")
# Local directory the pretrained checkpoint is downloaded into.
CHECKPOINT_DIR = Path("checkpoints")

# ---- pipeline toggles ----
# Each flag gates one later cell, so individual stages can be skipped on re-runs.
DOWNLOAD_RAW = True         # clone raw data when the target directory is missing
PREPROCESS = True           # build the .h5 cache unless it already exists
DOWNLOAD_PRETRAINED = True  # fetch a checkpoint from Hugging Face
RUN_TRAINING = True         # False only prints the training command

# ---- dataset-specific knobs ----
# Forwarded verbatim to the preprocessing script as --scenario-idxs (kept as a string).
DEEP_MIMO_SCENARIO_IDXS = "0"  # small subset for a quick run
DEEP_MIMO_N_BEAMS = 16         # used for deepmimo-beam
POS_SCENE = "outdoor"          # "indoor" or "outdoor"
# RML_VERSION is passed as --version and is also part of the RML cache file name.
RML_VERSION = "2022"
RML_DATA_FILE = RAW_ROOT / "RML22.01A"  # update if needed

# ---- training knobs ----
# These map one-to-one onto CLI flags of main_finetune.py (--model, --epochs,
# --batch-size, --val-split, --num-workers); BATCH_SIZE and NUM_WORKERS are
# reused by the evaluation DataLoader.
MODEL_NAME = "vit_multi_small"
EPOCHS = 1
BATCH_SIZE = 16
VAL_SPLIT = 0.2
NUM_WORKERS = 4
USE_STRATIFIED_SPLIT = True  # adds --stratified-split
FULL_FINETUNE = False  # set True to train all weights (adds --sl-baseline)


def resolve_cache_path():
    """Return the .h5 cache file for the configured TASK, located under CACHE_ROOT.

    All DeepMIMO tasks share one cache; the positioning and RML caches are
    keyed by POS_SCENE and RML_VERSION respectively; any other task falls
    back to a file named after the task itself.
    """
    if TASK.startswith("deepmimo"):
        filename = "deepmimo.h5"
    elif TASK == "pos":
        filename = f"pos_{POS_SCENE}.h5"
    elif TASK == "rml":
        filename = f"rml{RML_VERSION}.h5"
    else:
        filename = f"{TASK}.h5"
    return CACHE_ROOT / filename


# Resolve the cache location once; the preprocessing, finetuning and
# evaluation cells all read CACHE_PATH.
CACHE_PATH = resolve_cache_path()
print("Task:", TASK)
print("Cache path:", CACHE_PATH)


## 1. Clone the repo
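If this notebook is running from inside a local clone (the working directory contains `main_finetune.py`), that clone is reused; otherwise the repo is cloned into `./wavesfm` unless that directory already exists.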


In [None]:
import sys
import subprocess
from pathlib import Path

REPO_URL = "https://github.com/AhmedTarek62/wavesfm.git"
# Presence of this script is what identifies a directory as a WavesFM checkout.
ENTRY_SCRIPT = "main_finetune.py"
REPO_DIR = Path.cwd()

if not (REPO_DIR / ENTRY_SCRIPT).exists():
    # Not running from inside a clone: use (or create) ./wavesfm instead.
    REPO_DIR = Path.cwd() / "wavesfm"
    if not REPO_DIR.exists():
        subprocess.run(["git", "clone", REPO_URL, str(REPO_DIR)], check=True)
    # A pre-existing ./wavesfm that is not a real checkout (e.g. an interrupted
    # clone) would otherwise be accepted here and only fail in later cells.
    if not (REPO_DIR / ENTRY_SCRIPT).exists():
        raise FileNotFoundError(
            f"{REPO_DIR} exists but does not contain {ENTRY_SCRIPT}; "
            "remove or rename that directory and re-run this cell."
        )

# Make the repo importable so the evaluation cell can import its modules.
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))

print("Using repo:", REPO_DIR)


## 2. Install dependencies
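Install the pinned dependencies from the repo's `requirements.txt`, plus `matplotlib` and `huggingface_hub` for this notebook. `torch`/`torchvision` are installed only when `torch` is not already importable; if you have a GPU, replace the torch install line with the CUDA build from pytorch.org.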


In [None]:
import importlib.util
import subprocess
import sys


def pip_install(args):
    """Run ``pip install`` with the given argument list in the current interpreter.

    Uses ``sys.executable -m pip`` so packages land in the environment this
    notebook is running in. Raises ``subprocess.CalledProcessError`` when pip
    exits with a non-zero status.
    """
    pip_prefix = [sys.executable, "-m", "pip", "install"]
    subprocess.run(pip_prefix + args, check=True)


# Upgrade pip itself before installing anything else.
pip_install(["-U", "pip"])

# Only install torch when it is not importable, so an existing (possibly
# GPU-enabled) build is left untouched.
torch_missing = importlib.util.find_spec("torch") is None
if torch_missing:
    print("Installing torch/torchvision. For GPU builds, replace this line with the command from pytorch.org.")
    pip_install(["torch", "torchvision"])

# Repo requirements first, then the extras this notebook uses directly.
for extra_args in (
    ["-r", str(REPO_DIR / "requirements.txt")],
    ["matplotlib", "huggingface_hub"],
):
    pip_install(extra_args)


## 3. Download raw data
Most datasets require agreeing to their terms. The helper below covers a few tasks and prints guidance for others.


In [None]:
import subprocess
from pathlib import Path

# Create the raw-data and cache directories up front.
RAW_ROOT.mkdir(parents=True, exist_ok=True)
CACHE_ROOT.mkdir(parents=True, exist_ok=True)

# Defaults; each branch below overrides only what applies to its task.
is_deepmimo = TASK.startswith("deepmimo")
DEEPMIMO_DIR = RAW_ROOT / "deepmimo_scenarios" if is_deepmimo else None
POS_DIR = RAW_ROOT / "pos" if TASK == "pos" else None
RAW_READY = False

if is_deepmimo:
    if DOWNLOAD_RAW and not DEEPMIMO_DIR.exists():
        print("Cloning DeepMIMO scenarios via git-lfs (this can be large).")
        lwm_clone = ["git", "clone", "https://huggingface.co/datasets/wi-lab/lwm", str(DEEPMIMO_DIR)]
        # `git lfs install` is best-effort (check=False); the clone itself must succeed.
        subprocess.run(["git", "lfs", "install"], check=False)
        subprocess.run(lwm_clone, check=True)
    RAW_READY = DEEPMIMO_DIR.exists()
elif TASK == "rml":
    RML_REPO = RAW_ROOT / "RML22"
    if DOWNLOAD_RAW and not RML_REPO.exists():
        rml_clone = ["git", "clone", "https://github.com/venkateshsathya/RML22", str(RML_REPO)]
        subprocess.run(rml_clone, check=True)
    # Readiness depends on the data file, not on the cloned repo.
    RAW_READY = RML_DATA_FILE.exists()
    if not RAW_READY:
        print("RML data file not found at:", RML_DATA_FILE)
        print("Download RML22.01A (or RML2016.10a_dict.pkl) and set RML_DATA_FILE.")
elif TASK == "pos":
    # No automatic download for this task: only check that the folder is present.
    RAW_READY = POS_DIR.exists()
    if not RAW_READY:
        print("Download the POS dataset from IEEE Dataport and place the .h5 files under:", POS_DIR)
else:
    print("No download helper for task:", TASK)
    print("See dataset docs in the WavesFM site and update RAW_ROOT accordingly.")

print("Raw data ready:", RAW_READY)


## 4. Preprocess raw data into a .h5 cache
This converts raw files into a single cache that WavesFM can load quickly.


In [None]:
import sys
import subprocess

if not PREPROCESS:
    print("Skipping preprocessing.")
elif CACHE_PATH.exists():
    print("Using existing cache at:", CACHE_PATH)
else:
    # Validate the inputs and assemble the task-specific command first; the
    # preprocessing script is then launched once, after the branch.
    if TASK.startswith("deepmimo"):
        if not (DEEPMIMO_DIR and DEEPMIMO_DIR.exists()):
            raise FileNotFoundError("DeepMIMO scenarios not found. Run the download cell first.")
        script = REPO_DIR / "preprocessing/preprocess_deepmimo.py"
        n_beams = str(DEEP_MIMO_N_BEAMS)
        cmd = [sys.executable, str(script)]
        cmd += ["--output", str(CACHE_PATH)]
        cmd += ["--dataset-folder", str(DEEPMIMO_DIR)]
        cmd += ["--scenario-idxs", DEEP_MIMO_SCENARIO_IDXS]
        cmd += ["--n-beams", n_beams]
        cmd += ["--n-beams-list", n_beams]
    elif TASK == "rml":
        if not RML_DATA_FILE.exists():
            raise FileNotFoundError(f"Missing RML data file at {RML_DATA_FILE}")
        script = REPO_DIR / "preprocessing/preprocess_rml.py"
        cmd = [sys.executable, str(script)]
        cmd += ["--data-file", str(RML_DATA_FILE)]
        cmd += ["--version", RML_VERSION]
        cmd += ["--output", str(CACHE_PATH)]
    elif TASK == "pos":
        # Recomputed here rather than taken from the download cell's value.
        POS_DIR = RAW_ROOT / "pos"
        if not POS_DIR.exists():
            raise FileNotFoundError(f"Missing POS directory at {POS_DIR}")
        script = REPO_DIR / "preprocessing/preprocess_nr_positioning.py"
        cmd = [sys.executable, str(script)]
        cmd += ["--data-path", str(POS_DIR)]
        cmd += ["--output", str(CACHE_PATH)]
        cmd += ["--scene", POS_SCENE]
    else:
        raise ValueError(f"No preprocessing recipe for task {TASK}")

    subprocess.run(cmd, check=True)

print("Cache ready:", CACHE_PATH.exists())


## 5. Download a pretrained checkpoint
WavesFM weights are hosted on Hugging Face. If you do not want to use a pretrained checkpoint, set `DOWNLOAD_PRETRAINED = False`.


In [None]:
from pathlib import Path

HF_REPO = "ahmedaboulfo/wavesfm"
HF_FILE = None  # set to a specific checkpoint filename if you want to pin it

# Stays None when the download is skipped; the finetune cell then omits --finetune.
PRETRAINED_PATH = None
if not DOWNLOAD_PRETRAINED:
    print("Skipping pretrained download.")
else:
    # Imported lazily so huggingface_hub is only required when actually downloading.
    from huggingface_hub import list_repo_files, hf_hub_download

    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    if HF_FILE is None:
        # No file pinned: take the first .pth file the repo listing returns.
        pth_files = [name for name in list_repo_files(HF_REPO) if name.endswith(".pth")]
        if not pth_files:
            raise RuntimeError("No .pth checkpoints found in the Hugging Face repo.")
        HF_FILE = pth_files[0]
        print("Using checkpoint:", HF_FILE)

    local_file = hf_hub_download(repo_id=HF_REPO, filename=HF_FILE, local_dir=str(CHECKPOINT_DIR))
    PRETRAINED_PATH = Path(local_file)
    print("Downloaded to:", PRETRAINED_PATH)


## 6. Finetune
This runs the WavesFM CLI on your cache. Increase epochs and batch size for real training.


In [None]:
import sys
import subprocess
import torch

import shlex

# Training reads the preprocessed cache; fail early with a pointer to the fix.
if not CACHE_PATH.exists():
    raise FileNotFoundError(f"Missing cache at {CACHE_PATH}. Run preprocessing first.")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Base command: the finetuning script is run with the same interpreter as the
# notebook, and every value is stringified because argv entries must be strings.
train_cmd = [
    sys.executable, str(REPO_DIR / "main_finetune.py"),
    "--task", TASK,
    "--train-data", str(CACHE_PATH),
    "--val-split", str(VAL_SPLIT),
    "--output-dir", str(OUTPUT_DIR),
    "--model", MODEL_NAME,
    "--epochs", str(EPOCHS),
    "--batch-size", str(BATCH_SIZE),
    "--num-workers", str(NUM_WORKERS),
    "--device", device,
]

# Optional flags driven by the configuration cell.
if USE_STRATIFIED_SPLIT:
    train_cmd.append("--stratified-split")
if FULL_FINETUNE:
    train_cmd.append("--sl-baseline")
if PRETRAINED_PATH:
    train_cmd += ["--finetune", str(PRETRAINED_PATH)]
if TASK == "deepmimo-beam":
    train_cmd += ["--deepmimo-n-beams", str(DEEP_MIMO_N_BEAMS)]

if RUN_TRAINING:
    subprocess.run(train_cmd, check=True)
else:
    print("Skipping training. Command:")
    # Quote each argument so the printed command can be pasted into a shell
    # even when a path contains spaces or other special characters.
    print(" ".join(shlex.quote(part) for part in train_cmd))


## 7. Evaluate and visualize
Classification tasks produce a confusion matrix. Positioning tasks produce an error density plot. Any other task type falls back to printing the mean absolute error (MAE).


In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from types import SimpleNamespace

from data import build_datasets
from main_finetune import build_model


def get_base_dataset(ds):
    """Follow ``.dataset`` links from *ds* and return the innermost object.

    Wrapper objects that expose a ``dataset`` attribute are peeled off one
    level at a time until an object without that attribute is reached; an
    object with no ``dataset`` attribute is returned unchanged.
    """
    missing = object()  # sentinel: distinguishes "no attribute" from a None value
    current = ds
    while True:
        inner = getattr(current, "dataset", missing)
        if inner is missing:
            return current
        current = inner


device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the train/validation split from the cache with the notebook's settings.
# NOTE(review): this only reproduces the split used during training if
# main_finetune.py uses the same seed and split options -- confirm.
train_ds, val_ds, task_info = build_datasets(
    TASK,
    CACHE_PATH,
    val_path=None,
    val_split=VAL_SPLIT,
    stratified_split=USE_STRATIFIED_SPLIT,
    seed=42,
    deepmimo_n_beams=DEEP_MIMO_N_BEAMS if TASK == "deepmimo-beam" else None,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    pin_memory=(device == "cuda"),
)

# Hand-built stand-in for the argument object that build_model receives.
# NOTE(review): values presumably mirror the main_finetune.py CLI defaults -- confirm.
args = SimpleNamespace(
    model=MODEL_NAME,
    global_pool="token",
    vis_img_size=224,
    vis_patch=16,
    iq_segment_len=16,
    iq_downsample=None,
    iq_target_len=256,
    use_conditional_ln=False,
    lora=False,
    lora_rank=8,
    lora_alpha=1.0,
)

model = build_model(args, task_info)


def _load_weights(module, path):
    """Load a state dict from *path* into *module* and report key mismatches.

    The checkpoint may be a bare state dict or a dict holding it under
    "model". Loading is non-strict, so mismatched keys are tolerated; their
    counts are printed so a partial load does not go unnoticed.
    """
    ckpt = torch.load(path, map_location="cpu")
    state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    result = module.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(f"  {len(result.missing_keys)} model keys missing from checkpoint (left at init).")
    if result.unexpected_keys:
        print(f"  {len(result.unexpected_keys)} checkpoint keys unused by the model.")


# Prefer the finetuned weights, then the pretrained ones, else random init.
ckpt_path = OUTPUT_DIR / "best.pth"
if ckpt_path.exists():
    _load_weights(model, ckpt_path)
    print("Loaded finetuned checkpoint:", ckpt_path)
elif PRETRAINED_PATH:
    _load_weights(model, PRETRAINED_PATH)
    print("Loaded pretrained checkpoint:", PRETRAINED_PATH)
else:
    print("No checkpoint found; evaluating random init.")

model.to(device)
model.eval()

if task_info.target_type == "classification":
    # Confusion matrix indexed as cm[true, predicted].
    num_classes = task_info.num_outputs
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            samples, targets = batch[0], batch[1]
            outputs = model(samples.to(device, non_blocking=True))
            preds = outputs.argmax(dim=1).cpu().numpy()
            # Flatten so targets shaped (B, 1) compare element-wise with preds
            # instead of broadcasting to a (B, B) matrix.
            t = targets.cpu().numpy().astype(int).reshape(-1)
            # Unbuffered add: repeated (true, pred) pairs within a batch all count.
            np.add.at(cm, (t, preds), 1)
            correct += int((preds == t).sum())
            total += len(t)

    acc = correct / max(1, total)  # max() guards against an empty validation set
    print(f"Val accuracy: {acc:.4f}")

    labels = getattr(get_base_dataset(val_ds), "labels", None)

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, cmap="Blues")
    ax.set_title("Confusion matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    # Compare against None explicitly: truth-testing an array/tensor of labels
    # raises. Tick labels are drawn only when there is exactly one name per
    # class and few enough classes to stay readable.
    if labels is not None and len(labels) == num_classes and num_classes <= 30:
        ax.set_xticks(range(num_classes))
        ax.set_yticks(range(num_classes))
        ax.set_xticklabels(labels, rotation=90, fontsize=8)
        ax.set_yticklabels(labels, fontsize=8)

    fig.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

elif task_info.target_type == "position":
    coord_min = task_info.coord_min.to(device)
    coord_max = task_info.coord_max.to(device)

    def denorm(x):
        """Linearly map values from [-1, 1] onto [coord_min, coord_max]."""
        # NOTE(review): assumes predictions/targets are normalized to [-1, 1] -- confirm.
        return (x + 1) * 0.5 * (coord_max - coord_min) + coord_min

    errors = []
    with torch.no_grad():
        for batch in val_loader:
            samples, targets = batch[0], batch[1]
            preds = model(samples.to(device, non_blocking=True))
            pred_m = denorm(preds)
            true_m = denorm(targets.to(device, non_blocking=True))
            # Euclidean distance between predicted and true coordinates, per sample.
            dist = torch.linalg.norm(pred_m - true_m, dim=-1)
            errors.append(dist.cpu().numpy())

    errors = np.concatenate(errors)
    print(f"Mean error (m): {errors.mean():.3f}")
    print(f"Median error (m): {np.median(errors):.3f}")

    import matplotlib.pyplot as plt

    # The KDE curve is best-effort (scipy may be missing, or the estimate may
    # fail on degenerate data). Only the estimation is guarded, so plotting
    # problems are not misreported as KDE failures.
    try:
        from scipy.stats import gaussian_kde

        x = np.linspace(0, np.percentile(errors, 99), 200)
        y = gaussian_kde(errors)(x)
    except Exception as exc:
        print("KDE failed:", exc)
        x = y = None

    fig, ax = plt.subplots(figsize=(6, 4))
    if y is not None:
        ax.plot(x, y, label="KDE")
    ax.hist(errors, bins=40, density=True, alpha=0.3, label="Histogram")
    ax.set_xlabel("Positioning error (m)")
    ax.set_ylabel("Density")
    ax.set_title("Positioning error density")
    ax.legend()
    plt.tight_layout()
    plt.show()

else:
    # Fallback for any other target type: mean absolute error over all outputs.
    errors = []
    with torch.no_grad():
        for batch in val_loader:
            samples, targets = batch[0], batch[1]
            preds = model(samples.to(device, non_blocking=True))
            err = (preds.squeeze() - targets.to(device, non_blocking=True).squeeze()).abs()
            # Flatten: a final batch of size 1 squeezes to a 0-d tensor, which
            # np.concatenate rejects. The overall mean is unaffected.
            errors.append(err.reshape(-1).cpu().numpy())

    errors = np.concatenate(errors)
    print(f"MAE: {errors.mean():.4f}")


## Next steps
- Increase `EPOCHS` and `BATCH_SIZE` for real runs.
- Swap `TASK` to other datasets and update the download/preprocess settings.
- Use `OUTPUT_DIR` to track checkpoints and logs.
