
# 🧪 Validate Models from W&B CSV (Reconstruct + Evaluate on Collections)

This notebook:
1. Reads your **W&B CSV export** with run metadata/hparams.
2. Reconstructs each model's **architecture config** from the CSV (JSON-like columns or flattened keys).
3. Locates the corresponding **`.ckpt`** file: from a checkpoint-path column if the CSV has one, otherwise by searching the configured roots for `<name>.ckpt`.
4. Writes a **sidecar** `<name>.arch.json` next to each checkpoint (future-proof).
5. **Validates** each reconstructed model over your chosen collections/subsets.

> Works with what you already have: the CSV, the checkpoints, your project code, and access to the MongoDB collections used for validation.


## 🔧 Parameters

In [None]:

# ---- Inputs -----------------------------------------------------------------
# W&B CSV export with one row per run (name, config / hparams columns, ...).
CSV_PATH = "/mnt/data/wandb_export_2025-08-22T14_23_07.979+01_00.csv"  # change if needed

# Directories searched recursively for "<name>.ckpt" files; extend as needed.
SEARCH_ROOTS = ["/home/admindi/sbenites/WirelessLocation", "/mnt/data"]

# When True, a checkpoint-path column in the CSV (e.g. 'ckpt_path') is consulted
# before falling back to the filesystem search.
TRUST_CKPT_PATH_COLUMN = True

# ---- Validation targets ------------------------------------------------------
# Any of "garage", "outdoor", "indoor", "all", or "collections" (the latter uses
# the explicit COLLECTIONS list below).
SUBSETS = ["garage", "outdoor", "indoor"]
# Explicit collection names, read only when "collections" is in SUBSETS,
# e.g. ["reto_grande_outdoor", "equilatero_grande_outdoor"].
COLLECTIONS = []

# ---- MongoDB / data loading ----------------------------------------------------
DB_NAME = "wifi_fingerprinting_data"
BATCH_SIZE = 4096  # inference mini-batch size

# ---- Import roots ----------------------------------------------------------------
# Added to sys.path so the notebook can import the project's modules.
PROJECT_PATHS = ["/home/admindi/sbenites/WirelessLocation", "/mnt/data"]


## 📦 Imports & Environment

In [None]:

import os, glob, json, ast, math, re
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch

import sys

# Make the project's own modules importable. Roots are appended (not prepended),
# so anything already on sys.path keeps precedence; duplicates are skipped.
for p in PROJECT_PATHS:
    if p in sys.path:
        continue
    sys.path.append(p)

# Project-local modules.
# NOTE(review): `gpu_fucntion` looks misspelled, but it must match the project's
# actual file name; confirm before renaming.
from data_processing import (
    get_dataset,
    combine_arrays,
    shuffle_array,
    split_combined_data,
)
from model_generation import GeneratedModel
from gpu_fucntion import LightningWrapper

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


## 🗂️ Collections & Subsets

In [None]:

# Every validation collection; each name contains its location tag
# (garage / outdoor / indoor), which group_by_location matches on.
ALL_COLLECTIONS = [
    "equilatero_grande_garage",
    "equilatero_grande_outdoor",
    "equilatero_medio_garage",
    "equilatero_medio_outdoor",
    "isosceles_grande_indoor",
    "isosceles_grande_outdoor",
    "isosceles_medio_outdoor",
    "obtusangulo_grande_outdoor",
    "obtusangulo_pequeno_outdoor",
    "reto_grande_garage",
    "reto_grande_indoor",
    "reto_grande_outdoor",
    "reto_medio_garage",
    "reto_medio_outdoor",
    "reto_n_quadrado_grande_indoor",
    "reto_n_quadrado_grande_outdoor",
    "reto_n_quadrado_pequeno_outdoor",
    "reto_pequeno_garage",
    "reto_pequeno_outdoor",
]


def group_by_location(collections: List[str], locations: List[str]) -> List[str]:
    """Return the collections whose name contains any of ``locations``.

    Matching is a plain substring test; input order is preserved.
    """
    selected = []
    for collection in collections:
        for location in locations:
            if location in collection:
                selected.append(collection)
                break
    return selected


# Subset name -> list of collections. "all" is the ALL_COLLECTIONS list itself.
SUBSET_MAP = {
    location: group_by_location(ALL_COLLECTIONS, [location])
    for location in ("garage", "outdoor", "indoor")
}
SUBSET_MAP["all"] = ALL_COLLECTIONS


## 🧩 Parse Architecture Config From CSV Rows

In [None]:

def parse_jsonish(text: str) -> Optional[Any]:
    """Best-effort parse of a JSON- or Python-literal-like string.

    The stripped text is tried in up to four variants, in order:
      1. as-is;
      2. with single quotes turned into double quotes;
      3. with the words None/True/False mapped to null/true/false;
      4. variant 3 with single quotes turned into double quotes.

    For each variant, ``json.loads`` is tried first and its result is returned
    whatever its type (dict, list, number, string, or None for ``null``).
    ``ast.literal_eval`` is tried next, and its result is accepted only if it
    is a dict. Callers that need a dict must check the type themselves.

    Returns None if ``text`` is not a non-blank string or no variant parses.
    """
    if not isinstance(text, str) or not text.strip():
        return None
    s = text.strip()

    # Map Python keywords to JSON ones (whole words only, via \b).
    s_keywords = s
    for py_word, json_word in (("None", "null"), ("True", "true"), ("False", "false")):
        s_keywords = re.sub(r"\b" + py_word + r"\b", json_word, s_keywords)

    candidates = [s, s.replace("'", '"'), s_keywords, s_keywords.replace("'", '"')]
    # dict.fromkeys drops duplicate variants while keeping order, so an
    # already-failed string is not parsed twice.
    for candidate in dict.fromkeys(candidates):
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            pass
        try:
            obj = ast.literal_eval(candidate)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            continue
        if isinstance(obj, dict):
            return obj
    return None

# Leading column tokens that mark flattened config fields, e.g. "config.layers.0.filters".
_FLAT_CONFIG_PREFIXES = (
    "config", "architecture_config", "arch_config", "model_config", "architecture", "cnn_config",
)
# A key token with optional index suffixes, e.g. "layers", "layers[0]", "grid[1][2]".
_FLAT_KEY_RE = re.compile(r"([^\[\]]*)((?:\[[0-9]+\])*)")
_FLAT_INDEX_RE = re.compile(r"\[([0-9]+)\]")
_FLAT_DIGITS_RE = re.compile(r"[0-9]+")


def _nested_slot(container: Any, key: Any) -> Any:
    """Return ``container[key]``, or None when a dict has no such key.

    A list is first padded with ``{}`` placeholders so that index ``key`` exists.
    Raises TypeError if a non-integer key is used on a list level.
    """
    if isinstance(container, list):
        if not isinstance(key, int):
            raise TypeError(f"cannot use key {key!r} on a list level of the config")
        while len(container) <= key:
            container.append({})
        return container[key]
    return container.get(key)


def set_in_nested(d: Dict[str, Any], path_parts: List[Any], value: Any):
    """Store ``value`` in ``d`` at ``path_parts``, creating intermediate containers.

    Integer parts index lists, which are padded with ``{}`` placeholders up to the
    index. Other parts are dict keys. When the next part is an int, the current
    level becomes ``[]`` if it is missing, a non-container, or an empty dict. A
    non-empty dict is kept and indexed with the int key, so its contents are not
    lost. Intermediate non-container values are replaced by ``{}``. An empty path
    is a no-op.

    Raises:
        TypeError: if a non-integer part is applied to a list level.
    """
    if not path_parts:
        return
    cur = d
    for key, next_key in zip(path_parts[:-1], path_parts[1:]):
        child = _nested_slot(cur, key)
        if isinstance(next_key, int):
            if not isinstance(child, list) and not (isinstance(child, dict) and child):
                child = []
        elif not isinstance(child, (dict, list)):
            child = {}
        # Attach the (possibly new) child to its parent before descending, so
        # everything written below it stays reachable from ``d``.
        cur[key] = child
        cur = child
    last = path_parts[-1]
    _nested_slot(cur, last)  # pads a list level so the final index exists
    cur[last] = value


def _coerce_flat_value(val: Any) -> Any:
    """Convert one non-missing CSV cell into a plain Python value.

    - NumPy scalars (as produced by pandas) become Python scalars, because
      ``json.dump`` cannot serialise e.g. ``np.int64``.
    - Integer-valued floats become ints. pandas promotes int columns that contain
      NaN to float, which would otherwise turn layer sizes like 64 into 64.0.
    - Strings: "true"/"false" (any case) become bool; numeric text becomes
      int/float (integer-valued floats become int). Anything else goes through
      ``parse_jsonish`` for list/dict literals and is kept verbatim if that fails.
    """
    if isinstance(val, np.generic):
        val = val.item()
    if isinstance(val, float):
        return int(val) if val.is_integer() else val
    if not isinstance(val, str):
        return val
    if val.lower() in ("true", "false"):
        return val.lower() == "true"
    try:
        if "." in val:
            number = float(val)
            return int(number) if number.is_integer() else number
        return int(val)
    except ValueError:
        parsed = parse_jsonish(val)
        return parsed if parsed is not None else val


def _split_flat_key(column: str) -> List[Any]:
    """Split a flattened column name into path parts, dropping the config prefix.

    Examples: "config.layers.0.filters" -> ["layers", 0, "filters"] and
    "config.layers[0].filters" -> ["layers", 0, "filters"].
    """
    parts: List[Any] = []
    for token in column.split("."):
        if _FLAT_DIGITS_RE.fullmatch(token):
            parts.append(int(token))  # dotted numeric segment -> list index
            continue
        match = _FLAT_KEY_RE.fullmatch(token)
        if match is None:
            parts.append(token)
            continue
        name, indices = match.groups()
        if name:
            parts.append(name)
        parts.extend(int(i) for i in _FLAT_INDEX_RE.findall(indices))
    if parts and parts[0] in _FLAT_CONFIG_PREFIXES:
        parts = parts[1:]
    return parts


def parse_flattened_to_config(row: pd.Series) -> Optional[Dict[str, Any]]:
    """Reconstruct a nested config dict from columns like 'config.layers.0.filters' = 64.

    Only columns whose name starts with one of the known prefixes followed by a
    dot are used. Missing (NaN) cells are skipped. Numeric path segments and
    ``[i]`` suffixes become list indices.

    Returns the config dict, or None if no usable column was found.
    """
    prefixes = tuple(p + "." for p in _FLAT_CONFIG_PREFIXES)
    cfg: Dict[str, Any] = {}
    for column in row.index:
        if not (isinstance(column, str) and column.startswith(prefixes)):
            continue
        val = row[column]
        if pd.isna(val):
            continue
        parts = _split_flat_key(column)
        if not parts:
            continue
        set_in_nested(cfg, parts, _coerce_flat_value(val))
    return cfg or None

def row_to_arch_config(row: pd.Series) -> Optional[Dict[str, Any]]:
    """Reconstruct the architecture config for one W&B CSV row.

    Strategy (first hit wins):
      1. A whole-config column (e.g. 'architecture_config', 'config', 'hparams')
         holding a JSON/Python-dict string that parses to a non-empty dict.
      2. Flattened prefixed columns such as 'config.layers.0.filters'.

    Returns None when neither strategy yields a non-empty dict.
    """
    direct_cols = (
        "architecture_config", "arch_config", "config", "model_config",
        "architecture", "cnn_config", "config_json", "hparams", "hyperparameters",
    )
    for col in direct_cols:
        if col not in row:
            continue
        raw = row[col]
        if not isinstance(raw, str):
            continue
        cfg = parse_jsonish(raw)
        # Require a non-empty dict, like step 2 does: an empty "{}" cell must not
        # mask a config that is present in flattened columns.
        if isinstance(cfg, dict) and cfg:
            return cfg

    flat = parse_flattened_to_config(row)
    return flat if isinstance(flat, dict) and flat else None

def pick_model_name(row: pd.Series) -> Optional[str]:
    """Return the checkpoint stem (model name) for a CSV row, or None.

    Name-like columns are checked in priority order. The first non-blank string
    that yields a non-empty stem wins. Any directory part is dropped and a
    trailing ".ckpt" is removed. Other dots are kept, so run names such as
    "lr0.001_bs64" survive intact (os.path.splitext would cut them to "lr0").
    """
    ckpt_suffix = ".ckpt"
    for key in ("name", "model_name", "ckpt_stem", "run_name", "id", "slug"):
        if key not in row:
            continue
        raw = row[key]
        if not isinstance(raw, str) or not raw.strip():
            continue
        stem = os.path.basename(raw.strip())
        if stem.endswith(ckpt_suffix):
            stem = stem[: -len(ckpt_suffix)]
        if stem:
            return stem
    return None


## 📥 Load CSV and Build (name → config, ckpt) list

In [None]:

df_raw = pd.read_csv(CSV_PATH)
print("CSV rows:", len(df_raw))
print("CSV columns:", list(df_raw.columns)[:30], "...")

entries = []  # list of dicts: {name, arch_config, ckpt_path}

# CSV columns that may hold an explicit checkpoint path.
CKPT_PATH_COLUMNS = ("ckpt", "ckpt_path", "checkpoint", "checkpoint_path")

# root -> {checkpoint stem -> [paths]}. Filled lazily, so each search root is
# walked once per cell execution instead of once per CSV row.
_ckpt_index: Dict[str, Dict[str, List[str]]] = {}


def _ckpts_under(root: str) -> Dict[str, List[str]]:
    """Map each stem to every '<stem>.ckpt' path found recursively under ``root``."""
    if root not in _ckpt_index:
        index: Dict[str, List[str]] = {}
        # Escape the root so glob metacharacters in directory names are literal.
        pattern = os.path.join(glob.escape(root), "**", "*.ckpt")
        for path in glob.glob(pattern, recursive=True):
            stem = os.path.basename(path)[: -len(".ckpt")]
            index.setdefault(stem, []).append(path)
        _ckpt_index[root] = index
    return _ckpt_index[root]


def find_ckpt_for_name(name: str) -> Optional[str]:
    """Locate '<name>.ckpt' for a run, or return None.

    1. If TRUST_CKPT_PATH_COLUMN is set, check the CSV's checkpoint-path columns.
       Rows are restricted to those whose 'name' equals ``name`` when that column
       exists. The first existing file whose basename is exactly '<name>.ckpt'
       is accepted.
    2. Otherwise search SEARCH_ROOTS in order. In the first root with any match,
       return the most recently modified file.
    """
    target = f"{name}.ckpt"
    if TRUST_CKPT_PATH_COLUMN:
        if "name" in df_raw.columns:
            rows = df_raw[df_raw["name"].astype(str) == name]
        else:
            rows = df_raw
        for col in CKPT_PATH_COLUMNS:
            if col not in rows.columns:
                continue
            for candidate in rows[col].dropna().unique():
                candidate = str(candidate)
                # Exact basename match: a suffix test would also accept e.g. 'xname.ckpt'.
                if os.path.basename(candidate) == target and os.path.isfile(candidate):
                    return candidate
    for root in SEARCH_ROOTS:
        found = _ckpts_under(root).get(name)
        if found:
            return max(found, key=os.path.getmtime)
    return None


for _, row in df_raw.iterrows():
    name = pick_model_name(row)
    if not name:
        continue
    cfg = row_to_arch_config(row)
    if not isinstance(cfg, dict):
        continue
    entries.append({"name": name, "arch_config": cfg, "ckpt_path": find_ckpt_for_name(name)})

len(entries), entries[:3]


## 🧷 Write Sidecar `.arch.json` and Report Coverage

In [None]:

def _sidecar_path(ckpt_path: str) -> str:
    """Turn '<dir>/<stem>.ckpt' into '<dir>/<stem>.arch.json'.

    Only the final extension is replaced. A bare str.replace would also rewrite
    any '.ckpt' inside directory names.
    """
    if ckpt_path.endswith(".ckpt"):
        return ckpt_path[: -len(".ckpt")] + ".arch.json"
    return ckpt_path + ".arch.json"


def _json_default(obj):
    """Let json serialise NumPy scalars (np.int64, np.bool_, ...) as Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


written = 0
missing_ckpt = []
for e in entries:
    name, cfg, ckpt = e["name"], e["arch_config"], e["ckpt_path"]
    if not ckpt:
        missing_ckpt.append(name)
        continue
    sidecar = _sidecar_path(ckpt)
    try:
        # Serialise before opening the file, so a bad value cannot leave a
        # truncated sidecar on disk.
        payload = json.dumps(cfg, default=_json_default)
        with open(sidecar, "w", encoding="utf-8") as f:
            f.write(payload)
    except Exception as ex:  # best effort: report and carry on with other models
        print(f"⚠️ Failed sidecar for {name}: {ex}")
        continue
    written += 1
    e["sidecar"] = sidecar

print(f"Sidecars written: {written}")
if missing_ckpt:
    print(f"❗ {len(missing_ckpt)} models missing .ckpt file (showing up to 20):")
    for n in missing_ckpt[:20]:
        print(" -", n)

# Keep only entries that have ckpt and sidecar for validation
entries_ready = [e for e in entries if e.get("ckpt_path") and e.get("sidecar")]
print(f"\nEntries ready for validation: {len(entries_ready)}")


## 📏 Metrics & Data Loading

In [None]:

def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean squared error over every element (all rows and output columns)."""
    err = y_true - y_pred
    return float(np.mean(err ** 2))


def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Square root of :func:`mse`."""
    return float(math.sqrt(mse(y_true, y_pred)))


def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute error over every element."""
    abs_err = np.abs(y_true - y_pred)
    return float(np.mean(abs_err))


def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination pooled over all output columns.

    The residual and total sums of squares are each summed over every element,
    with the total measured against the per-column mean of ``y_true``. Returns
    NaN when the total sum of squares is exactly zero (constant targets).
    """
    residual_ss = np.sum((y_true - y_pred) ** 2)
    centred = y_true - np.mean(y_true, axis=0)
    total_ss = np.sum(centred ** 2)
    if total_ss == 0:
        return float("nan")
    return float(1 - residual_ss / total_ss)

def load_val_data(selected_collections: List[str], db_name: str) -> Tuple[np.ndarray, np.ndarray]:
    """Fetch, merge and shuffle the given collections, then split into (X, y).

    Each collection is fetched with ``get_dataset(name, db_name)``. The results
    are merged with ``combine_arrays``, shuffled with ``shuffle_array`` and split
    with ``split_combined_data``; see those project helpers for array layouts.
    """
    print(f"📡 Loading validation datasets: {selected_collections}")
    per_collection = []
    for collection_name in selected_collections:
        per_collection.append(get_dataset(collection_name, db_name))
    X, y = split_combined_data(shuffle_array(combine_arrays(per_collection)))
    return X, y


## 🧠 Load Model & Validate

In [None]:

def load_model_from_arch_and_ckpt(ckpt_path: str, arch_config: Dict[str, Any], input_size: int, output_size: int, device: torch.device) -> LightningWrapper:
    """Rebuild a model from ``arch_config`` and load its weights from ``ckpt_path``.

    The network is a ``GeneratedModel`` wrapped in ``LightningWrapper``. The
    checkpoint's "state_dict" entry is used if present; otherwise the loaded
    object is treated as the state dict. Loading is non-strict, and missing or
    unexpected keys (first five of each) are printed rather than raised. The
    wrapper is moved to ``device`` and put in eval mode.
    """
    def _placeholder_data():
        # One-sample dummy tensors: the wrapper's constructor asks for data,
        # but nothing is trained here.
        return (torch.empty(1, input_size), torch.empty(1, output_size))

    wrapper = LightningWrapper(
        model=GeneratedModel(input_size=input_size, output_size=output_size, architecture_config=arch_config),
        train_data=_placeholder_data(),
        val_data=_placeholder_data(),
        learning_rate=arch_config.get("learning_rate", 1e-3),
        weight_decay=arch_config.get("weight_decay", 0.0),
        optimizer_name=arch_config.get("optimizer", "adam"),
    )

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    missing, unexpected = wrapper.load_state_dict(state_dict, strict=False)
    for label, keys in (("Missing", missing), ("Unexpected", unexpected)):
        if not keys:
            continue
        ellipsis = "..." if len(keys) > 5 else ""
        print(f"⚠️ {label} keys when loading: {keys[:5]}{ellipsis}")

    wrapper.to(device)
    wrapper.eval()
    return wrapper

@torch.inference_mode()
def evaluate(wrapper: LightningWrapper, X: np.ndarray, y: np.ndarray, device: torch.device, batch_size: int = 4096):
    """Predict with ``wrapper.model`` over ``X`` in mini-batches and score against ``y``.

    ``X`` is converted to float32 on the CPU. Only the current mini-batch is
    copied to ``device``, so a large validation set does not have to fit in GPU
    memory at once.

    Returns:
        dict with float values for "mse", "rmse", "mae" and "r2".

    Raises:
        ValueError: if ``batch_size`` < 1, ``X`` has no rows, or the predictions
            have a different number of elements than ``y``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    X_cpu = torch.as_tensor(X, dtype=torch.float32)
    n_rows = X_cpu.size(0)
    if n_rows == 0:
        raise ValueError("evaluate() received an empty X; nothing to score")

    preds = []
    for start in range(0, n_rows, batch_size):
        xb = X_cpu[start:start + batch_size].to(device)
        preds.append(wrapper.model(xb).detach().cpu().numpy())
    # concatenate (unlike vstack) keeps 1-D model outputs 1-D.
    y_pred = np.concatenate(preds, axis=0)

    y_true = np.asarray(y)
    if y_pred.shape != y_true.shape:
        if y_pred.size != y_true.size:
            raise ValueError(
                f"Prediction shape {y_pred.shape} cannot be aligned with target shape {y_true.shape}"
            )
        # e.g. (N, 1) predictions vs (N,) targets: without this, NumPy would
        # broadcast the metric arithmetic to an (N, N) matrix.
        y_pred = y_pred.reshape(y_true.shape)

    return {
        "mse": mse(y_true, y_pred),
        "rmse": rmse(y_true, y_pred),
        "mae": mae(y_true, y_pred),
        "r2": r2_score(y_true, y_pred),
    }


## 🚀 Run Validation

In [None]:

# Resolve each requested subset name to its list of collections.
subset_to_collections = {}
for sub in SUBSETS:
    if sub == "collections":
        if not COLLECTIONS:
            raise SystemExit('You included "collections" but COLLECTIONS is empty.')
        subset_to_collections[sub] = COLLECTIONS
    elif sub in SUBSET_MAP:
        subset_to_collections[sub] = SUBSET_MAP[sub]
    else:
        raise SystemExit(f"Unknown subset {sub!r}; expected one of {sorted(SUBSET_MAP)} or 'collections'.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load every subset once, up front. The data does not depend on the model, so
# loading inside the model loop would re-query MongoDB once per model.
# Trade-off: all requested subsets are held in memory at the same time.
val_data = {
    subset_name: load_val_data(subset_collections, DB_NAME)
    for subset_name, subset_collections in subset_to_collections.items()
}

results = []
failed = []
for e in entries_ready:
    name, ckpt = e["name"], e["ckpt_path"]
    arch_config = e["arch_config"]
    print("\n" + "=" * 100)
    print("Model:", name)
    # One wrapper per (input_size, output_size): the checkpoint is loaded once per
    # model unless subsets differ in feature/target width.
    wrappers = {}
    try:
        for subset_name, subset_collections in subset_to_collections.items():
            X_val, y_val = val_data[subset_name]
            dims = (X_val.shape[1], y_val.shape[1])
            if dims not in wrappers:
                wrappers[dims] = load_model_from_arch_and_ckpt(ckpt, arch_config, dims[0], dims[1], device)
            metrics = evaluate(wrappers[dims], X_val, y_val, device, batch_size=BATCH_SIZE)
            row = {
                "name": name,
                "ckpt": ckpt,
                "subset": subset_name,
                "collections": ",".join(subset_collections),
                **metrics,
            }
            print(row)
            results.append(row)
    except Exception as ex:
        # Best effort across many runs: one config that cannot be rebuilt or
        # loaded should not abort the whole sweep.
        print(f"❌ Validation failed for {name}: {type(ex).__name__}: {ex}")
        failed.append((name, repr(ex)))
    finally:
        wrappers.clear()
        if device.type == "cuda":
            torch.cuda.empty_cache()

if failed:
    print(f"\n❗ {len(failed)} model(s) failed validation:")
    for failed_name, err in failed:
        print(" -", failed_name, "→", err)

if results:
    df_results = pd.DataFrame(results).sort_values(["name", "subset"]).reset_index(drop=True)
else:
    # sort_values on an empty frame would raise KeyError (no such columns).
    df_results = pd.DataFrame(columns=["name", "ckpt", "subset", "collections", "mse", "rmse", "mae", "r2"])
df_results


## 💾 Save Results

In [None]:

# Persist the per-(model, subset) metrics table without the DataFrame index.
out_csv = "/mnt/data/validation_results_from_wandb.csv"
df_results.to_csv(out_csv, index=False)
print(f"Saved: {out_csv}")
