In [None]:
#!/usr/bin/env python3
"""
NAS for INS calibration (PyTorch-only, CSV or .pth inputs)

Features:
- Train on one dataset and test on another (CSV or .pth supported).
- Multi-output regression (predict bias, scale factors, etc.).
- Simple evolutionary / random-search over MLP architectures.
- Prints training / validation / final test metrics (MSE, MAE, R2).
- Saves best model and scalers to a .pth checkpoint.

Usage examples:
# CSV inputs:
python nas_ins_calibration_full.py \
  --train_csv ./data/train_ins.csv --test_csv ./data/test_ins.csv \
  --target_cols bias_x,bias_y,scale_x,scale_y \
  --input_cols ax,ay,az,gx,gy,gz,temperature \
  --pop_size 6 --generations 4 --search_epochs 6 --final_epochs 40 \
  --device cuda

# .pth inputs (saved as {'X': np_or_tensor, 'Y': ...} or tuple (X,Y)):
python nas_ins_calibration_full.py \
  --train_pth ./data/train_ins.pth --test_pth ./data/test_ins.pth \
  --target_cols bias_x,bias_y --device cuda
"""
import argparse
import random
import os
from copy import deepcopy
from typing import List, Tuple, Any

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# --------------------- Utilities ---------------------
def set_seed(seed: int = 42):
    """Seed every random number generator this script relies on.

    The same value is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and all CUDA
    generators.

    Args:
        seed: Seed value passed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

class Scaler:
    """Column-wise standardizer: subtract the mean, divide by the std.

    ``mean`` and ``std`` are ``None`` until :meth:`fit` is called. Afterwards
    they hold float64 statistics computed along axis 0 with ``keepdims=True``.
    """

    def __init__(self, eps: float = 1e-12):
        """Create an unfitted scaler.

        Args:
            eps: Standard deviations below this value are replaced by 1.0.
        """
        self.mean = None
        self.std = None
        self.eps = eps

    def fit(self, arr: np.ndarray):
        """Compute and store per-column mean and standard deviation of ``arr``."""
        data = np.asarray(arr, dtype=np.float64)
        spread = data.std(axis=0, keepdims=True)
        self.mean = data.mean(axis=0, keepdims=True)
        # A (near-)constant column would otherwise cause a division by ~0 in
        # transform(); using 1.0 leaves such columns merely mean-centred.
        self.std = np.where(spread < self.eps, 1.0, spread)

    def transform(self, arr: np.ndarray) -> np.ndarray:
        """Return ``arr`` standardized with the fitted statistics (float64)."""
        data = np.asarray(arr, dtype=np.float64)
        centred = data - self.mean
        return centred / self.std

    def inverse_transform(self, scaled: np.ndarray) -> np.ndarray:
        """Map standardized values back to the original scale (float64)."""
        data = np.asarray(scaled, dtype=np.float64)
        return data * self.std + self.mean

# --------------------- Metrics ---------------------
def mse_per_output(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return the mean squared error, averaged over axis 0 (one value per output column)."""
    residual = y_true - y_pred
    return np.mean(np.square(residual), axis=0)

def mae_per_output(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return the mean absolute error, averaged over axis 0 (one value per output column)."""
    residual = y_true - y_pred
    return np.mean(np.absolute(residual), axis=0)

def r2_per_output(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return the coefficient of determination (R^2) for each output column.

    For a column whose true values are constant the total sum of squares is
    zero and R^2 is mathematically undefined. Instead of the scale-dependent
    value ``1 - ss_res``, such a column scores 1.0 when the predictions are
    exact and 0.0 otherwise. Any remaining non-finite value (e.g. caused by
    NaN inputs) is reported as 0.0.

    Args:
        y_true: Ground-truth values; sums are taken along axis 0.
        y_pred: Predicted values, same shape as ``y_true``.

    Returns:
        One R^2 value per output column.
    """
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    ss_tot = np.sum((y_true - np.mean(y_true, axis=0, keepdims=True)) ** 2, axis=0)
    constant_target = ss_tot == 0
    with np.errstate(divide='ignore', invalid='ignore'):
        # The 1.0 placeholder only avoids a division by zero; those entries
        # are overwritten just below.
        r2 = 1.0 - ss_res / np.where(constant_target, 1.0, ss_tot)
    degenerate_score = np.where(ss_res == 0, 1.0, 0.0)
    r2 = np.where(constant_target, degenerate_score, r2)
    return np.nan_to_num(r2, nan=0.0, posinf=0.0, neginf=0.0)

def print_metrics(prefix: str, y_true: np.ndarray, y_pred: np.ndarray, target_names: List[str]):
    """Print per-target MSE / MAE / R2 followed by their averages.

    Args:
        prefix: Label printed in the header line (e.g. ``"TEST"``).
        y_true: Ground-truth values.
        y_pred: Predicted values, same shape as ``y_true``.
        target_names: One display name per output; entry ``i`` labels
            element ``i`` of each per-output metric array.
    """
    per_target_mse = mse_per_output(y_true, y_pred)
    per_target_mae = mae_per_output(y_true, y_pred)
    per_target_r2 = r2_per_output(y_true, y_pred)

    print(f"\n{prefix} metrics (per-target):")
    for idx, label in enumerate(target_names):
        row = f"  {label:20s}  MSE={per_target_mse[idx]:.6e}  MAE={per_target_mae[idx]:.6e}  R2={per_target_r2[idx]:.4f}"
        print(row)

    avg_mse = per_target_mse.mean()
    avg_mae = per_target_mae.mean()
    avg_r2 = per_target_r2.mean()
    print(f"  -> Average: MSE={avg_mse:.6e}  MAE={avg_mae:.6e}  R2={avg_r2:.4f}\n")

# --------------------- Data loading helpers ---------------------
def load_csv_as_arrays(csv_path: str, input_cols: List[str], target_cols: List[str], sep: str = ','):
    """Read a CSV file and split it into float64 feature and target matrices.

    Args:
        csv_path: Path of the CSV file to read.
        input_cols: Feature column names, or ``None`` to use every column
            that is not listed in ``target_cols``.
        target_cols: Target column names.
        sep: Field separator passed to ``pandas.read_csv``.

    Returns:
        Tuple ``(X, Y, input_cols, target_cols)`` where ``input_cols`` is the
        column list actually used for ``X``.
    """
    frame = pd.read_csv(csv_path, sep=sep)

    def as_float_matrix(columns):
        return frame[columns].values.astype(np.float64)

    if input_cols is None:
        # No explicit feature list: everything that is not a target is an input.
        input_cols = [column for column in frame.columns if column not in target_cols]

    features = as_float_matrix(input_cols)
    targets = as_float_matrix(target_cols)
    return features, targets, input_cols, target_cols

def load_pth_as_arrays(pth_path: str):
    """Load features ``X`` and targets ``Y`` from a ``torch.save``-d file.

    Accepted layouts:
      * a dict holding X under one of ``X/x/inputs/features/data`` and Y under
        one of ``Y/y/targets/labels``;
      * a dict with a ``'dataset'`` entry that is a list/tuple ``(X, Y, ...)``;
      * a list/tuple ``(X, Y, ...)`` saved directly.

    Tensors are detached and moved to the CPU before conversion. 1-D arrays
    are returned as a single column so callers can always read ``shape[1]``
    (a single target saved as a flat vector becomes ``(N, 1)``).

    Args:
        pth_path: Path of the file to load.

    Returns:
        Tuple ``(X, Y)`` of float64 NumPy arrays.

    Raises:
        ValueError: If X or Y cannot be located in the loaded object.
    """
    # NOTE(review): torch.load unpickles the file, so only load files from a
    # trusted source. Depending on the installed PyTorch version, files that
    # contain NumPy arrays may need an explicit `weights_only` setting --
    # confirm against the torch.load documentation for the deployed version.
    data = torch.load(pth_path, map_location='cpu')
    X = None
    Y = None
    if isinstance(data, dict):
        X = next((data[k] for k in ('X', 'x', 'inputs', 'features', 'data') if k in data), None)
        Y = next((data[k] for k in ('Y', 'y', 'targets', 'labels') if k in data), None)
        # fallback: maybe a tuple saved under the 'dataset' key
        if X is None and 'dataset' in data and isinstance(data['dataset'], (list, tuple)) and len(data['dataset']) >= 2:
            X, Y = data['dataset'][0], data['dataset'][1]
    elif isinstance(data, (list, tuple)) and len(data) >= 2:
        # tuple/list saved directly
        X, Y = data[0], data[1]
    if X is None or Y is None:
        raise ValueError(f"Could not find X and Y in pth file: {pth_path}. Expected keys: X/Y, inputs/targets, or saved tuple (X,Y).")
    # detach() first: .numpy() refuses tensors that still require grad.
    if torch.is_tensor(X):
        X = X.detach().cpu().numpy()
    if torch.is_tensor(Y):
        Y = Y.detach().cpu().numpy()
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    # Promote flat vectors to one-column matrices (N,) -> (N, 1).
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    if Y.ndim == 1:
        Y = Y.reshape(-1, 1)
    return X, Y

def make_dataloader(X: np.ndarray, Y: np.ndarray, batch_size: int, shuffle: bool):
    """Wrap NumPy feature/target arrays in a float32 ``DataLoader``.

    Args:
        X: Feature array.
        Y: Target array with the same first dimension as ``X``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples each epoch.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of ``(X, Y)`` pairs.
    """
    features, targets = (torch.from_numpy(array).to(torch.float32) for array in (X, Y))
    return DataLoader(TensorDataset(features, targets), batch_size=batch_size, shuffle=shuffle)

# --------------------- Search-space & model ---------------------
class ArchSpec:
    """Hyper-parameters describing one candidate MLP architecture.

    Attributes:
        num_layers: Number of hidden layers.
        hidden_dims: Width of each hidden layer.
        activation: Activation name ('relu', 'selu', 'tanh' or 'gelu').
        dropout: Dropout probability, stored as a float.
        use_bn: Whether batch normalization is used.
    """

    def __init__(self, num_layers: int, hidden_dims: List[int], activation: str, dropout: float, use_bn: bool):
        self.num_layers = num_layers
        self.hidden_dims = hidden_dims
        self.activation = activation
        self.dropout = float(dropout)
        self.use_bn = use_bn

    def mutate(self):
        """Return a randomly perturbed deep copy; ``self`` is left unchanged.

        Each change is applied independently with a fixed probability:
        depth +/-1 clamped to [1, 6] (0.4), one layer width rescaled and
        clamped to [8, 2048] (0.6), new activation (0.2), dropout jitter of
        +/-0.15 clamped to [0, 0.6] (0.3), batch-norm toggle (0.2). Finally
        ``hidden_dims`` is padded with its last width or truncated so that it
        holds exactly ``num_layers`` entries.
        """
        child = deepcopy(self)

        # Depth: add or remove one layer.
        if random.random() < 0.4:
            step = random.choice([-1, 1])
            child.num_layers = max(1, min(6, child.num_layers + step))

        # Width: rescale one randomly chosen layer.
        if random.random() < 0.6:
            idx = random.randrange(len(child.hidden_dims))
            if random.random() < 0.8:
                factor = random.choice([1, 2])
            else:
                factor = random.choice([0.5, 1])
            scaled = int(child.hidden_dims[idx] * factor)
            child.hidden_dims[idx] = int(max(8, min(2048, scaled)))

        # Activation function.
        if random.random() < 0.2:
            child.activation = random.choice(['relu', 'selu', 'tanh', 'gelu'])

        # Dropout probability.
        if random.random() < 0.3:
            jittered = child.dropout + random.uniform(-0.15, 0.15)
            child.dropout = float(max(0.0, min(0.6, jittered)))

        # Batch normalization on/off.
        if random.random() < 0.2:
            child.use_bn = not child.use_bn

        # Reconcile the width list with the (possibly changed) depth.
        shortfall = child.num_layers - len(child.hidden_dims)
        if shortfall > 0:
            child.hidden_dims.extend([child.hidden_dims[-1]] * shortfall)
        child.hidden_dims = child.hidden_dims[:child.num_layers]
        return child

    def __repr__(self):
        fields = (
            f"layers={self.num_layers}",
            f"dims={self.hidden_dims}",
            f"act={self.activation}",
            f"drop={self.dropout:.2f}",
            f"bn={self.use_bn}",
        )
        return "ArchSpec(" + ", ".join(fields) + ")"

class MLP(nn.Module):
    """Fully connected regressor assembled from an ``ArchSpec``.

    Every hidden layer is ``Linear -> [BatchNorm1d] -> activation -> [Dropout]``;
    a final ``Linear`` maps the last hidden width to ``output_dim``.
    """

    def __init__(self, input_dim: int, output_dim: int, spec: ArchSpec):
        """Build the network.

        Args:
            input_dim: Number of input features.
            output_dim: Number of regression outputs.
            spec: Architecture description; reads ``num_layers``,
                ``hidden_dims``, ``use_bn``, ``activation`` and ``dropout``.
        """
        super().__init__()
        # widths[k] feeds hidden layer k; widths[-1] feeds the output layer.
        widths = [input_dim] + [spec.hidden_dims[k] for k in range(spec.num_layers)]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            if spec.use_bn:
                modules.append(nn.BatchNorm1d(fan_out))
            modules.append(self._get_activation(spec.activation))
            if spec.dropout > 0:
                modules.append(nn.Dropout(spec.dropout))
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    @staticmethod
    def _get_activation(name: str):
        """Return a new activation module for ``name``; unknown names fall back to ReLU."""
        factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': nn.Tanh,
            'gelu': nn.GELU,
        }
        build = factories.get(name, factories['relu'])
        return build()

    def forward(self, x):
        """Apply the network to a batch ``x``."""
        return self.net(x)

# --------------------- Training/Eval loops ---------------------
def train_one_epoch(model: nn.Module, loader: DataLoader, criterion, optimizer, device: torch.device):
    """Train ``model`` for one pass over ``loader``.

    Batches holding a single sample are skipped when the model contains a
    BatchNorm layer: in training mode BatchNorm cannot compute batch
    statistics from one sample and raises, which would otherwise abort the
    run whenever the dataset size leaves a trailing batch of size 1.

    Args:
        model: Network to optimize (switched to train mode).
        loader: Yields ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(pred, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        Sample-weighted mean loss over the batches that were processed
        (0.0 if none were).
    """
    model.train()
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    has_batchnorm = any(isinstance(module, batchnorm_types) for module in model.modules())
    running_loss = 0.0
    n = 0
    for xb, yb in loader:
        bs = xb.shape[0]
        if has_batchnorm and bs < 2:
            # BatchNorm needs more than one value per channel while training.
            continue
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller last batch does not skew the mean.
        running_loss += loss.item() * bs
        n += bs
    return running_loss / max(1, n)

def evaluate_model(model: nn.Module, loader: DataLoader, device: torch.device):
    """Run ``model`` in eval mode over ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate (switched to eval mode).
        loader: Yields ``(inputs, targets)`` batches; targets are returned
            as-is and never passed to the model.
        device: Device the inputs are moved to.

    Returns:
        Tuple ``(trues, preds)`` of row-stacked NumPy arrays, or two empty
        arrays of shape ``(0,)`` when the loader yields no batches.
    """
    with torch.no_grad():
        model.eval()
        predicted_chunks, target_chunks = [], []
        for features, targets in loader:
            outputs = model(features.to(device))
            predicted_chunks.append(outputs.cpu().numpy())
            target_chunks.append(targets.numpy())
        if not predicted_chunks:
            return np.zeros((0,)), np.zeros((0,))
        return np.vstack(target_chunks), np.vstack(predicted_chunks)

# --------------------- NAS search ---------------------
def random_arch_spec():
    """Sample a random ``ArchSpec`` from the initial search space.

    Depth is drawn from 1-4, each hidden width from {32, 64, 128, 256, 512},
    the activation from relu/selu/tanh/gelu, dropout from {0.0, 0.1, 0.2, 0.3}
    and batch normalization is switched on or off at random.
    """
    depth_options = (1, 2, 3, 4)
    width_options = (32, 64, 128, 256, 512)
    activation_options = ('relu', 'selu', 'tanh', 'gelu')
    dropout_options = (0.0, 0.1, 0.2, 0.3)

    # The draws below are made in a fixed order so a given seed always
    # produces the same specification.
    depth = random.choice(depth_options)
    widths = [random.choice(width_options) for _ in range(depth)]
    act = random.choice(activation_options)
    drop = random.choice(dropout_options)
    batch_norm = random.choice((True, False))
    return ArchSpec(depth, widths, act, drop, batch_norm)

def search_architectures(train_loader: DataLoader, val_loader: DataLoader, input_dim: int, output_dim: int,
                         device: torch.device, pop_size: int = 6, generations: int = 3, search_epochs: int = 5,
                         lr: float = 1e-3):
    """Evolutionary search over MLP architectures, scored by validation MSE.

    Each generation trains every candidate from scratch for ``search_epochs``
    epochs, keeps the better half (at least one) and refills the population
    with mutated copies of the survivors.

    Non-finite validation scores (e.g. NaN from a diverged run) are treated
    as ``inf`` so they rank last and can never become the best candidate. If
    no candidate ever reaches a finite score, or nothing is evaluated at all,
    a fallback specification is returned instead of ``None``.

    Args:
        train_loader: Batches used to train each candidate.
        val_loader: Batches used to score each candidate.
        input_dim: Number of input features.
        output_dim: Number of regression outputs.
        device: Device models are trained on.
        pop_size: Number of candidates per generation.
        generations: Number of generations to run.
        search_epochs: Training epochs per candidate.
        lr: Adam learning rate.

    Returns:
        The ``ArchSpec`` with the lowest average validation MSE seen, or a
        fallback specification when no candidate could be scored.
    """
    population = [random_arch_spec() for _ in range(pop_size)]
    best_spec = None
    best_score = float('inf')  # lower val MSE better
    criterion = nn.MSELoss()

    def validation_score(model):
        """Average per-output validation MSE; ``inf`` when empty or non-finite."""
        trues, preds = evaluate_model(model, val_loader, device)
        if not trues.size:
            return float('inf')
        score = float(np.mean(mse_per_output(trues, preds)))
        return score if np.isfinite(score) else float('inf')

    for gen in range(generations):
        print(f"\n=== Generation {gen+1}/{generations} ===")
        scored = []
        for i, spec in enumerate(population):
            print(f" Candidate {i+1}/{len(population)}: {spec}")
            model = MLP(input_dim, output_dim, spec).to(device)
            optimizer = optim.Adam(model.parameters(), lr=lr)
            val_mse_avg = None
            for ep in range(search_epochs):
                tr_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
                val_mse_avg = validation_score(model)
                print(f"  ep {ep+1}/{search_epochs}  train_loss={tr_loss:.6e}  val_mse_avg={val_mse_avg:.6e}")
            if val_mse_avg is None:
                # No training epoch ran: score the untrained model once.
                val_mse_avg = validation_score(model)
            scored.append((val_mse_avg, spec))
            if val_mse_avg < best_score:
                best_score = val_mse_avg
                best_spec = deepcopy(spec)
                print(f"  -> New best (val_mse_avg={best_score:.6e}): {best_spec}")
        scored.sort(key=lambda x: x[0])
        keep = [s for _, s in scored[:max(1, len(scored)//2)]]
        new_pop = deepcopy(keep)
        while len(new_pop) < pop_size:
            parent = random.choice(keep)
            new_pop.append(parent.mutate())
        population = new_pop

    if best_spec is None:
        # population[0] is the top-ranked survivor of the last generation
        # (or the first random candidate if no generation ran).
        best_spec = deepcopy(population[0]) if population else random_arch_spec()
        print(f"\nWarning: no candidate reached a finite validation MSE; falling back to {best_spec}")

    print(f"\nSearch finished. Best spec: {best_spec} (val_mse_avg={best_score:.6e})")
    return best_spec

# --------------------- Main ---------------------
def _split_names(text: str) -> List[str]:
    """Split a comma-separated CLI value into stripped, non-empty names."""
    return [c.strip() for c in text.split(',') if c.strip()]


def _as_2d(arr: np.ndarray) -> np.ndarray:
    """Return ``arr`` unchanged unless it is 1-D, in which case make it one column."""
    arr = np.asarray(arr)
    return arr.reshape(-1, 1) if arr.ndim == 1 else arr


def _predict_original_scale(model, loader, y_scaler, device):
    """Evaluate ``model`` on ``loader`` and undo the target scaling of its outputs.

    Returns ``(trues, preds)`` where ``trues`` are the targets yielded by the
    loader and ``preds`` are the model outputs mapped through
    ``y_scaler.inverse_transform``.
    """
    trues, preds_scaled = evaluate_model(model, loader, device)
    return trues, y_scaler.inverse_transform(preds_scaled)


def main():
    """Command-line entry point: load data, search, retrain, save and evaluate.

    Steps:
      1. Parse arguments and seed the random number generators.
      2. Load train and test data from .pth or CSV and validate their shapes.
      3. Split the training data into train/validation parts, fit scalers on
         the train part and run the architecture search.
      4. Re-fit the scalers on the full training data and retrain the best
         architecture on it.
      5. Save a checkpoint (weights, spec, scaler statistics, column names).
      6. Report test metrics in the original target scale.

    Raises:
        ValueError: On missing inputs, inconsistent array shapes, an empty
            test set, or a train split that would be empty.
        RuntimeError: If the search returns no architecture.
    """
    parser = argparse.ArgumentParser(description="NAS for INS calibration (PyTorch only). Supports CSV or .pth inputs.")
    parser.add_argument('--train_csv', type=str, default=None, help='path to train CSV (optional)')
    parser.add_argument('--test_csv', type=str, default=None, help='path to test CSV (optional)')
    parser.add_argument('--train_pth', type=str, default=None, help='path to train .pth (optional; overrides CSV if given)')
    parser.add_argument('--test_pth', type=str, default=None, help='path to test .pth (optional; overrides CSV if given)')
    parser.add_argument('--input_cols', type=str, default=None, help='comma-separated input columns (optional)')
    parser.add_argument('--target_cols', type=str, required=True, help='comma-separated target columns (e.g., bias_x,bias_y)')
    parser.add_argument('--sep', type=str, default=',', help='CSV separator')
    parser.add_argument('--pop_size', type=int, default=6)
    parser.add_argument('--generations', type=int, default=3)
    parser.add_argument('--search_epochs', type=int, default=6)
    parser.add_argument('--final_epochs', type=int, default=40)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--val_fraction', type=float, default=0.1)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--save', type=str, default='nas_ins_best.pth')
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device(args.device)
    target_cols = _split_names(args.target_cols)

    # Load train data (.pth takes precedence over CSV)
    if args.train_pth:
        X_train_all, Y_train_all = load_pth_as_arrays(args.train_pth)
        input_cols = None
    elif args.train_csv:
        requested_cols = None if args.input_cols is None else _split_names(args.input_cols)
        # The loader returns the column list it actually used (inferred when
        # none was requested); reuse it for the test CSV.
        X_train_all, Y_train_all, input_cols, _ = load_csv_as_arrays(args.train_csv, requested_cols, target_cols, sep=args.sep)
    else:
        raise ValueError("Provide either --train_pth or --train_csv")

    # Load test data
    if args.test_pth:
        X_test, Y_test = load_pth_as_arrays(args.test_pth)
    elif args.test_csv:
        X_test, Y_test, _, _ = load_csv_as_arrays(args.test_csv, input_cols, target_cols, sep=args.sep)
    else:
        raise ValueError("Provide either --test_pth or --test_csv")

    # A single feature/target stored as a flat vector becomes one column.
    X_train_all, Y_train_all = _as_2d(X_train_all), _as_2d(Y_train_all)
    X_test, Y_test = _as_2d(X_test), _as_2d(Y_test)

    if X_train_all.shape[0] != Y_train_all.shape[0]:
        raise ValueError("Train X and Y first dimensions disagree")
    if X_test.shape[0] != Y_test.shape[0]:
        raise ValueError("Test X and Y first dimensions disagree")
    if X_test.shape[0] == 0:
        raise ValueError("Test set is empty")

    input_dim = X_train_all.shape[1]
    output_dim = Y_train_all.shape[1]
    # Without these checks a one-column test array would silently broadcast
    # against the training statistics inside the scalers.
    if X_test.shape[1] != input_dim:
        raise ValueError(f"Train and test feature counts differ: {input_dim} vs {X_test.shape[1]}")
    if Y_test.shape[1] != output_dim:
        raise ValueError(f"Train and test target counts differ: {output_dim} vs {Y_test.shape[1]}")

    print(f"Input dim: {input_dim}, Output dim: {output_dim}")
    print(f"Input cols (if known): {input_cols}")
    print(f"Target cols: {args.target_cols}")

    # With .pth inputs the names are only labels and may not match the data;
    # detect that now instead of failing when metrics are printed after training.
    if len(target_cols) == output_dim:
        metric_names = target_cols
    else:
        print(f"Warning: --target_cols lists {len(target_cols)} names but the data has {output_dim} outputs; "
              f"using generic names for reporting.")
        metric_names = [f"target_{i}" for i in range(output_dim)]

    # prepare train/val split for search
    n_total = X_train_all.shape[0]
    n_val = max(1, int(args.val_fraction * n_total))
    n_train = n_total - n_val
    if n_train < 1:
        raise ValueError(f"Not enough training samples ({n_total}) for val_fraction={args.val_fraction}")
    perm = np.random.permutation(n_total)
    X_sh = X_train_all[perm]
    Y_sh = Y_train_all[perm]
    X_train, Y_train = X_sh[:n_train], Y_sh[:n_train]
    X_val, Y_val = X_sh[n_train:], Y_sh[n_train:]

    # Fit scalers on the train (search) split only, so validation scores are
    # not influenced by validation statistics.
    x_scaler = Scaler()
    x_scaler.fit(X_train)
    y_scaler = Scaler()
    y_scaler.fit(Y_train)

    train_loader = make_dataloader(x_scaler.transform(X_train).astype(np.float32),
                                   y_scaler.transform(Y_train).astype(np.float32),
                                   batch_size=args.batch_size, shuffle=True)
    val_loader = make_dataloader(x_scaler.transform(X_val).astype(np.float32),
                                 y_scaler.transform(Y_val).astype(np.float32),
                                 batch_size=args.batch_size, shuffle=False)

    # Search
    best_spec = search_architectures(train_loader, val_loader, input_dim, output_dim, device,
                                     pop_size=args.pop_size, generations=args.generations,
                                     search_epochs=args.search_epochs, lr=args.lr)
    if best_spec is None:
        raise RuntimeError("Architecture search did not return a usable architecture")

    # Retrain best on full training data (train + val), with scalers re-fitted on it.
    X_full = np.vstack([X_train, X_val])
    Y_full = np.vstack([Y_train, Y_val])
    x_scaler_full = Scaler()
    x_scaler_full.fit(X_full)
    y_scaler_full = Scaler()
    y_scaler_full.fit(Y_full)

    full_loader = make_dataloader(x_scaler_full.transform(X_full).astype(np.float32),
                                  y_scaler_full.transform(Y_full).astype(np.float32),
                                  batch_size=args.batch_size, shuffle=True)
    # Test inputs must go through the SAME scaler the final model is trained
    # with (x_scaler_full); targets stay in original units for reporting.
    test_loader = make_dataloader(x_scaler_full.transform(X_test).astype(np.float32),
                                  Y_test.astype(np.float32),
                                  batch_size=args.batch_size, shuffle=False)

    model = MLP(input_dim, output_dim, best_spec).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Halve the learning rate every max(10, final_epochs // 3) epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(10, args.final_epochs//3), gamma=0.5)

    print("\nRetraining best architecture on FULL training set...")
    for ep in range(1, args.final_epochs + 1):
        tr_loss = train_one_epoch(model, full_loader, criterion, optimizer, device)
        scheduler.step()
        if ep % 5 == 0 or ep == 1 or ep == args.final_epochs:
            # Progress monitoring only: test MSE in original target units.
            _, preds_unscaled = _predict_original_scale(model, test_loader, y_scaler_full, device)
            val_mse = float(np.mean(mse_per_output(Y_test, preds_unscaled)))
            print(f" Final Ep {ep}/{args.final_epochs}  train_loss(scaled)={tr_loss:.6e}  test_val_mse={val_mse:.6e}")

    # Save checkpoint (model state + spec + scalers + columns)
    save_obj = {
        'model_state_dict': model.state_dict(),
        'spec': best_spec.__dict__,
        'x_scaler_mean': x_scaler_full.mean,
        'x_scaler_std': x_scaler_full.std,
        'y_scaler_mean': y_scaler_full.mean,
        'y_scaler_std': y_scaler_full.std,
        'input_cols': input_cols,
        'target_cols': target_cols
    }
    torch.save(save_obj, args.save)
    print(f"\nSaved checkpoint to {args.save}")

    # Final evaluation on test dataset (original scale)
    print("\nFinal evaluation on TEST dataset (original scale):")
    trues, preds = _predict_original_scale(model, test_loader, y_scaler_full, device)
    print_metrics("TEST", trues, preds, metric_names)

    # show first few predictions
    n_show = min(8, preds.shape[0])
    print(f"First {n_show} predictions vs truth (per-target):")
    for i in range(n_show):
        pred_str = ", ".join(f"{v:.6e}" for v in preds[i])
        true_str = ", ".join(f"{v:.6e}" for v in trues[i])
        print(f"  row {i+1:03d}: pred = [{pred_str}]  |  true = [{true_str}]")

    print("\nDone.")

# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
