
# 📘 Validate Saved CNN Checkpoints on Dataset Subsets

This notebook loads your previously trained CNN checkpoints (`.ckpt` saved by `Trainer.save_checkpoint`), 
reconstructs each model architecture from your saved configs (a sidecar `*.arch.json` next to the checkpoint, 
a matching entry in one of your `best_models_*.pkl` files, or, as a last resort, the hyperparameters stored inside the checkpoint itself), and evaluates on selected dataset subsets 
(**garage**, **outdoor**, **indoor**, or explicit collection names).

**What you need in the working directory:**
- Your project files: `data_processing.py`, `model_generation.py`, and `gpu_fucntion.py` (already in place; `gpu_fucntion` is spelled exactly as the file is named, matching the import below)
- Your model checkpoints in a folder (e.g., `model_storage/`)
- Your `best_models_*.pkl` files, if you didn't create `.arch.json` sidecars during training

> Tip: you can create a sidecar JSON config next to each checkpoint from your training loop:
> ```python
> with open(model_save_path.replace(".ckpt", ".arch.json"), "w") as f:
>     json.dump(config_dict["config"], f)
> ```


In [51]:

import os, glob, pickle, json
from typing import Dict, Any, List, Tuple

import numpy as np
import torch
import pandas as pd

# ensure local modules are importable
import sys
sys.path.append("/mnt/data")

from data_processing import get_dataset, combine_arrays, shuffle_array, split_combined_data
from model_generation import GeneratedModel
from gpu_fucntion import LightningWrapper

# Report the PyTorch build and CUDA availability so the saved output records
# the environment this validation run was executed in.
print(f"PyTorch: {torch.__version__}")
print("CUDA available:", torch.cuda.is_available())


PyTorch: 2.7.0
CUDA available: True


## 🔧 Parameters

In [52]:

# Checkpoint selection: set exactly ONE of CKPT / CKPT_DIR and leave the other as "".
# The checkpoint-selection cell aborts if both are set or if both are empty.
CKPT = ""  # e.g., "/mnt/data/model_storage/outdoor_only_run0_depth3_model5.ckpt"
CKPT_DIR = "/home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage"  # folder of .ckpt files, or "" to ignore

# Validation subsets: choose any of ["garage", "outdoor", "indoor", "all"] and/or "collections".
# The first four are keys of SUBSET_MAP; "collections" means "use the COLLECTIONS list below".
SUBSETS = ["garage", "outdoor", "indoor"]

# If you include "collections" in SUBSETS, list the exact collection names here
# (the run aborts if "collections" is requested while this list is empty):
COLLECTIONS = [
    # "reto_grande_outdoor", "equilatero_grande_outdoor"
]

# Database name handed to get_dataset() for every collection that is loaded
DB_NAME = "wifi_fingerprinting_data"

# Where to find best model pickles if you didn't save sidecar .arch.json files.
# The pattern is expanded with recursive globbing, so "**" also matches nested directories.
PKL_GLOB = "/home/admindi/sbenites/WirelessLocation/**/best_models_*.pkl"

# Number of rows sent through the model per forward pass during evaluation
BATCH_SIZE = 4096


## 🗂️ Collections & Subsets

In [53]:

# Every collection name known to this notebook. Each name contains a location
# tag (garage / outdoor / indoor); group_by_location() selects collections by
# substring-matching on that tag.
ALL_COLLECTIONS = [
    "equilatero_grande_garage",
    "equilatero_grande_outdoor",
    "equilatero_medio_garage",
    "equilatero_medio_outdoor",
    "isosceles_grande_indoor",
    "isosceles_grande_outdoor",
    "isosceles_medio_outdoor",
    "obtusangulo_grande_outdoor",
    "obtusangulo_pequeno_outdoor",
    "reto_grande_garage",
    "reto_grande_indoor",
    "reto_grande_outdoor",
    "reto_medio_garage",
    "reto_medio_outdoor",
    "reto_n_quadrado_grande_indoor",
    "reto_n_quadrado_grande_outdoor",
    "reto_n_quadrado_pequeno_outdoor",
    "reto_pequeno_garage",
    "reto_pequeno_outdoor",
]

def group_by_location(collections: List[str], locations: List[str]) -> List[str]:
    """Return the collections whose name contains at least one location tag.

    Matching is a plain substring test. The relative order of ``collections``
    is preserved and each name appears at most once per occurrence in the input.
    """
    matched: List[str] = []
    for collection_name in collections:
        for tag in locations:
            if tag in collection_name:
                matched.append(collection_name)
                break  # one hit is enough; avoid adding the same name twice
    return matched

# Named validation subsets -> the collections each one covers.
# The three location subsets are derived from the tag embedded in the
# collection names; "all" refers to the complete list.
SUBSET_MAP = {
    location: group_by_location(ALL_COLLECTIONS, [location])
    for location in ("garage", "outdoor", "indoor")
}
SUBSET_MAP["all"] = ALL_COLLECTIONS


## 📏 Metrics & Data Loading

In [54]:

def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean squared error, averaged over every element of the arrays."""
    residual = y_true - y_pred
    return float(np.mean(np.square(residual)))

def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Root mean squared error: the square root of ``mse(y_true, y_pred)``."""
    mean_squared = mse(y_true, y_pred)
    return float(np.sqrt(mean_squared))

def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute error, averaged over every element of the arrays."""
    abs_residual = np.absolute(y_true - y_pred)
    return float(np.mean(abs_residual))

def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination pooled over all elements.

    The total sum of squares is taken around the per-column mean of ``y_true``
    (``axis=0``) and then summed over every element. Returns NaN when that
    total is exactly zero (targets are constant per column).
    """
    residual_ss = np.sum((y_true - y_pred) ** 2)
    column_means = np.mean(y_true, axis=0)
    total_ss = np.sum((y_true - column_means) ** 2)
    if total_ss == 0:
        return float("nan")
    return float(1 - residual_ss / total_ss)

def load_val_data(selected_collections: List[str], db_name: str) -> Tuple[np.ndarray, np.ndarray]:
    """Fetch the given collections and return them as one ``(X, y)`` pair.

    Each collection is read with ``get_dataset``; the results are merged with
    ``combine_arrays``, passed through ``shuffle_array`` and finally separated
    by ``split_combined_data``.
    NOTE(review): the array layout and the shuffle's seeding are defined by those
    project helpers and are not visible from this notebook.
    """
    print(f"📡 Loading validation datasets: {selected_collections}")
    per_collection = []
    for collection_name in selected_collections:
        per_collection.append(get_dataset(collection_name, db_name))
    merged = shuffle_array(combine_arrays(per_collection))
    features, targets = split_combined_data(merged)
    return features, targets


## 🔍 Find Architecture Config for a Checkpoint

In [55]:
# 🔁 Patch: also try to read the architecture from checkpoint hyperparameters
import torch

def _maybe_arch_from_hp(hp):
    if not isinstance(hp, dict):
        return None
    # Obvious candidates first
    for cand in ["config", "architecture_config", "arch_config", "model_config", "architecture", "cnn_config"]:
        if cand in hp and isinstance(hp[cand], dict):
            return hp[cand]
    # Heuristic fallback: if hp itself looks like an arch dict, use it
    likely_keys = {"layers","conv","conv_blocks","fc_layers","activation","dropout","hidden_sizes","kernel_size"}
    if any(k in hp for k in likely_keys):
        return hp
    return None

# Re-define find_arch_config_for_ckpt with hyperparams fallback
def find_arch_config_for_ckpt(ckpt_path: str, pkl_glob: str):
    """Locate the architecture config that belongs to a checkpoint.

    Sources are tried in this order and the first hit wins:
      1. a sidecar ``<stem>.arch.json`` in the checkpoint's directory,
      2. an entry in any pickle matched by ``pkl_glob`` whose ``"name"`` equals
         the checkpoint's file stem and that has a ``"config"`` key,
      3. the hyperparameters stored inside the checkpoint file itself.

    Args:
        ckpt_path: Path to the ``.ckpt`` file.
        pkl_glob: Glob pattern for the best-model pickles (expanded recursively).

    Returns:
        The config object exactly as stored in the source that matched.

    Raises:
        RuntimeError: if none of the three sources yields a config.
    """
    import os, glob, json, pickle
    ckpt_stem = os.path.splitext(os.path.basename(ckpt_path))[0]

    # 1) Sidecar JSON
    sidecar = os.path.join(os.path.dirname(ckpt_path), ckpt_stem + ".arch.json")
    if os.path.isfile(sidecar):
        try:
            with open(sidecar, "r") as f:
                cfg = json.load(f)
            print(f"🧩 Found config sidecar: {sidecar}")
            return cfg
        except Exception as e:
            # Best effort: an unreadable sidecar falls through to the other sources
            print(f"⚠️ Failed to read sidecar {sidecar}: {e}")

    # 2) best_models_*.pkl
    # Sorted so that, if several pickles list the same model name, the winner
    # does not depend on filesystem enumeration order.
    for pkl_path in sorted(glob.glob(pkl_glob, recursive=True)):
        try:
            # SECURITY: unpickling runs arbitrary code; pkl_glob must only match
            # files produced by your own training runs.
            with open(pkl_path, "rb") as f:
                entries = list(pickle.load(f))
        except Exception as e:
            print(f"⚠️ Could not read {pkl_path}: {e}")
            continue

        for entry in entries:
            try:
                is_match = entry.get("name") == ckpt_stem and "config" in entry
            except (AttributeError, TypeError):
                # Not a mapping-like record. Skip just this entry instead of
                # abandoning the rest of the file.
                continue
            if is_match:
                print(f"🧩 Found config in {pkl_path}")
                return entry["config"]

    # 3) Fallback: try the checkpoint's hyperparameters
    try:
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which refuses checkpoints containing arbitrary
        # pickled objects; such a refusal is reported by the except below.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        hp = ckpt.get("hyper_parameters") or ckpt.get("hparams") or ckpt.get("hyperparams")
        cfg = _maybe_arch_from_hp(hp)
        if cfg is not None:
            print("🧩 Recovered config from checkpoint hyperparameters")
            return cfg
    except Exception as e:
        print(f"⚠️ Failed to inspect checkpoint hyperparameters: {e}")

    raise RuntimeError(
        "Could not find architecture config for:\n"
        f"  {ckpt_path}\n"
        "Provide a sidecar JSON or ensure the best_models_*.pkl for this run includes this exact model name."
    )


## 🧠 Load Model from Checkpoint & Evaluate

In [56]:

import torch
from typing import Dict, Any

def load_model_from_ckpt(ckpt_path: str, arch_config: Dict[str, Any], input_size: int, output_size: int, device: torch.device) -> LightningWrapper:
    """Rebuild the network described by ``arch_config`` and load checkpoint weights into it.

    A ``GeneratedModel`` is constructed and wrapped in a ``LightningWrapper``
    (with placeholder train/val tensors, since the wrapper's constructor asks
    for data). The checkpoint is read on CPU and its ``"state_dict"`` entry,
    or the whole object when that key is absent, is loaded non-strictly:
    missing or unexpected keys are printed (first five of each) but do not
    raise. The wrapper is then moved to ``device`` and put in eval mode.

    Returns:
        The wrapper, ready for inference.
    """
    def _placeholder_pair():
        # Uninitialised single-row tensors; only used to satisfy the constructor
        return (torch.empty(1, input_size), torch.empty(1, output_size))

    network = GeneratedModel(input_size=input_size, output_size=output_size, architecture_config=arch_config)
    wrapper = LightningWrapper(
        model=network,
        train_data=_placeholder_pair(),
        val_data=_placeholder_pair(),
        learning_rate=arch_config.get("learning_rate", 1e-3),
        weight_decay=arch_config.get("weight_decay", 0.0),
        optimizer_name=arch_config.get("optimizer", "adam"),
    )

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    weights = checkpoint.get("state_dict", checkpoint)
    missing, unexpected = wrapper.load_state_dict(weights, strict=False)
    if missing:
        print(f"⚠️ Missing keys when loading: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"⚠️ Unexpected keys when loading: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    wrapper.to(device)
    wrapper.eval()
    return wrapper

@torch.inference_mode()
def evaluate(wrapper: LightningWrapper, X: np.ndarray, y: np.ndarray, device: torch.device, batch_size: int = 4096):
    """Run ``wrapper.model`` over ``X`` in batches and score the predictions against ``y``.

    Args:
        wrapper: Object exposing the network as ``wrapper.model``; expected to
            already live on ``device``.
        X: Feature rows; converted to a float32 tensor.
        y: Targets; must have exactly the shape of the stacked predictions.
        device: Device each batch is moved to for the forward pass.
        batch_size: Rows per forward pass (must be positive).

    Returns:
        Dict with keys ``"mse"``, ``"rmse"``, ``"mae"`` and ``"r2"``.

    Raises:
        ValueError: if ``batch_size`` is not positive, ``X`` has no rows, or the
            prediction shape differs from ``y``'s shape.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Keep the full matrix where it is and move one batch at a time, so device
    # memory is bounded by batch_size rather than by the dataset size.
    X_t = torch.as_tensor(X, dtype=torch.float32)
    n_rows = X_t.size(0)
    if n_rows == 0:
        raise ValueError("evaluate() received an empty validation set (X has 0 rows)")

    preds = []
    for start in range(0, n_rows, batch_size):
        xb = X_t[start:start + batch_size].to(device)
        yb_pred = wrapper.model(xb)
        preds.append(yb_pred.detach().cpu().numpy())
    y_pred = np.vstack(preds)

    # A shape mismatch such as (N,) vs (N, 1) would broadcast silently inside
    # the metric functions and yield meaningless numbers.
    y_shape = np.shape(y)
    if y_shape != y_pred.shape:
        raise ValueError(f"Target shape {y_shape} does not match prediction shape {y_pred.shape}")

    metrics = {
        "mse": mse(y, y_pred),
        "rmse": rmse(y, y_pred),
        "mae": mae(y, y_pred),
        "r2": r2_score(y, y_pred),
    }
    return metrics


## 📂 Choose Which Checkpoints to Validate

In [57]:

# Resolve the list of checkpoint files to validate.
# Exactly one of CKPT (single file) / CKPT_DIR (folder scan) must be set.
if CKPT and CKPT_DIR:
    raise SystemExit("Please set either CKPT or CKPT_DIR (not both).")
elif not (CKPT or CKPT_DIR):
    raise SystemExit("Please set CKPT or CKPT_DIR to proceed.")
elif CKPT:
    ckpt_paths = [CKPT]
elif not os.path.isdir(CKPT_DIR):
    raise SystemExit(f"CKPT_DIR does not exist: {CKPT_DIR}")
else:
    # Non-recursive scan: only *.ckpt files directly inside CKPT_DIR, sorted by path
    ckpt_paths = sorted(
        os.path.join(CKPT_DIR, fname)
        for fname in os.listdir(CKPT_DIR)
        if fname.endswith(".ckpt")
    )
    if not ckpt_paths:
        raise SystemExit(f"No .ckpt files found in {CKPT_DIR}")

# Show the count and a short preview rather than the full list
print(f"Found {len(ckpt_paths)} checkpoint(s).")
for preview_path in ckpt_paths[:5]:
    print(" -", preview_path)


Found 2240 checkpoint(s).
 - /home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage/all_data_run0_depth0_model0.ckpt
 - /home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage/all_data_run0_depth0_model1.ckpt
 - /home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage/all_data_run0_depth0_model2.ckpt
 - /home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage/all_data_run0_depth0_model3.ckpt
 - /home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage/all_data_run0_depth0_model4.ckpt


## 🧪 Run Validation

In [58]:

# Build subsets selected for this run
subset_to_collections = {}
for sub in SUBSETS:
    if sub == "collections":
        if not COLLECTIONS:
            raise SystemExit('-- You included "collections" but did not specify any COLLECTIONS')
        subset_to_collections[sub] = COLLECTIONS
    elif sub in SUBSET_MAP:
        subset_to_collections[sub] = SUBSET_MAP[sub]
    else:
        # Fail with a readable message instead of a bare KeyError
        raise SystemExit(f'-- Unknown subset {sub!r}; choose from {sorted(SUBSET_MAP)} and/or "collections"')

if not subset_to_collections:
    raise SystemExit("-- SUBSETS is empty; nothing to validate")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# The validation data does not depend on the checkpoint, so load every subset
# exactly once up front instead of once per (checkpoint, subset) pair.
# Trade-off: all selected subsets are held in memory for the whole run.
val_data = {
    subset_name: load_val_data(collections, DB_NAME)
    for subset_name, collections in subset_to_collections.items()
}

results = []
for ckpt_path in ckpt_paths:
    print("\n" + "="*100)
    print("Checkpoint:", ckpt_path)
    arch_config = find_arch_config_for_ckpt(ckpt_path, PKL_GLOB)

    # Rebuild the model only when the input/output sizes differ between subsets
    wrappers_by_shape = {}
    for subset_name, collections in subset_to_collections.items():
        X_val, y_val = val_data[subset_name]
        input_size = X_val.shape[1]
        output_size = y_val.shape[1]

        shape_key = (input_size, output_size)
        if shape_key not in wrappers_by_shape:
            wrappers_by_shape[shape_key] = load_model_from_ckpt(ckpt_path, arch_config, input_size, output_size, device)
        wrapper = wrappers_by_shape[shape_key]
        metrics = evaluate(wrapper, X_val, y_val, device, batch_size=BATCH_SIZE)

        row = {
            "checkpoint": os.path.basename(ckpt_path),
            "subset": subset_name,
            "collections": ",".join(collections) if isinstance(collections, list) else str(collections),
            **metrics
        }
        print(row)
        results.append(row)

# One row per (checkpoint, subset), ordered for easy side-by-side comparison
df = pd.DataFrame(results).sort_values(["checkpoint", "subset"]).reset_index(drop=True)
df


Using device: cuda

Checkpoint: /home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage/all_data_run0_depth0_model0.ckpt


RuntimeError: Could not find architecture config for:
  /home/admindi/sbenites/WirelessLocation/validation/model_per_dataset_validation/model_storage/all_data_run0_depth0_model0.ckpt
Provide a sidecar JSON or ensure the best_models_*.pkl for this run includes this exact model name.