In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import ConcatDataset, DataLoader, SubsetRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from skopt import gp_minimize, load
from skopt.space import Real
from skopt.callbacks import CheckpointSaver
from sklearn.model_selection import KFold
import time
import re
import wandb

wandb.login()

from tqdm import tqdm


In [None]:

# ==========================================
# Checkpoint Configuration Variables
# ==========================================
CHECKPOINT_BASE_NAME = '3d_cv_optimization'
USE_CHECKPOINT = False   # Set to True to resume from a checkpoint, False to start new
DESIRED_CHECKPOINT_ID = None  # Set to None for latest, or an integer for a specific checkpoint ID
USE_NARROWED_SPACE = True   # Set to True to use narrowed_search_space from the analysis cell (5b)

# Local checkpoint directory
DRIVE_DIR = r"c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\BO_Checkpoints"
os.makedirs(DRIVE_DIR, exist_ok=True)

# ==========================================
# 1. Data Preprocessing & Loading
# ==========================================
# Local dataset path; images are read from its "Training" and "Testing" ImageFolder sub-directories.
dataset_path = r"c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\dataset"
print(f"Contents of {dataset_path}: {os.listdir(dataset_path)}")

IMAGE_SIZE = (128, 128)  # (H, W) every image is resized to, for both the stats and the final transform

# --- Compute (or load cached) dataset-specific normalization statistics ---
import json
NORM_STATS_FILE = os.path.join(DRIVE_DIR, "dataset_norm_stats.json")


def _compute_norm_stats(root, image_size):
    """Return per-channel ``(mean, std)`` lists over every image in root/Training and root/Testing.

    Images get the same Resize + ToTensor preprocessing as the training transform
    (without Normalize). The statistics cover the pooled Training+Testing images,
    which is the same pool the cross-validation below draws from.

    The variance uses Var[x] = E[x^2] - E[x]^2. Sums are accumulated in float64,
    and the variance is clamped at 0, so rounding cannot make the std NaN.

    Raises:
        ValueError: if no images are found.
    """
    stats_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    pooled = ConcatDataset([
        ImageFolder(os.path.join(root, split), transform=stats_transform)
        for split in ("Training", "Testing")
    ])
    loader = DataLoader(pooled, batch_size=256, shuffle=False, num_workers=0)

    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    n_pixels = 0
    for imgs, _ in tqdm(loader, desc="Norm stats", leave=False):
        b, _, h, w = imgs.shape
        n_pixels += b * h * w
        # Reduce each batch straight into float64 so the running totals keep full precision.
        channel_sum += imgs.sum(dim=(0, 2, 3), dtype=torch.float64)
        channel_sq_sum += imgs.square().sum(dim=(0, 2, 3), dtype=torch.float64)

    if n_pixels == 0:
        raise ValueError(f"No images found under {root!r} to compute normalization statistics.")

    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).clamp_min(0.0).sqrt()
    return mean.tolist(), std.tolist()


if os.path.exists(NORM_STATS_FILE):
    with open(NORM_STATS_FILE, "r") as f:
        _stats = json.load(f)
    DATASET_MEAN = _stats["mean"]
    DATASET_STD  = _stats["std"]
    print(f"Loaded cached normalization stats from {NORM_STATS_FILE}")
else:
    print("Computing dataset-specific normalization statistics (first run)...")
    DATASET_MEAN, DATASET_STD = _compute_norm_stats(dataset_path, IMAGE_SIZE)

    # Save for future runs
    with open(NORM_STATS_FILE, "w") as f:
        json.dump({"mean": DATASET_MEAN, "std": DATASET_STD}, f, indent=2)
    print(f"Saved normalization stats to {NORM_STATS_FILE}")

print(f"Dataset mean: {DATASET_MEAN}")
print(f"Dataset std:  {DATASET_STD}")

# --- Final transform with computed statistics ---
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=DATASET_MEAN, std=DATASET_STD)
])

# Training and Testing folders are pooled: the CV below splits this combined set.
training_dataset = ImageFolder(os.path.join(dataset_path, "Training"), transform=transform)
testing_dataset  = ImageFolder(os.path.join(dataset_path, "Testing"),  transform=transform)
dataset = ConcatDataset([training_dataset, testing_dataset])
print(f"Total dataset size: {len(dataset)} images")

# ==========================================
# 2. Model Definition — SimpleTumorCNN
# ==========================================
class SimpleTumorCNN(nn.Module):
    """
    Small custom CNN for 4-way image classification (24,068 parameters with defaults).

    Feature extractor: three identical stages (Conv3x3 -> BatchNorm -> ReLU -> MaxPool2)
    widening 3 -> 16 -> 32 -> 64 channels, followed by global average pooling.
    Head: Flatten -> Dropout(dropout_rate) -> Linear(64, num_classes).
    """

    def __init__(self, num_classes=4, dropout_rate=0.1):
        super().__init__()
        stages = []
        channels_in = 3
        for channels_out in (16, 32, 64):
            stages.extend([
                nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
            channels_in = channels_out
        # Collapse the remaining spatial grid to 1x1 so the head is input-size agnostic
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(channels_in, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

# Sanity-check the model size once; the throwaway instance is deleted right after.
_tmp_model = SimpleTumorCNN(num_classes=4, dropout_rate=0.1)
_param_count = sum(map(torch.numel, _tmp_model.parameters()))
print(f"SimpleTumorCNN parameter count: {_param_count:,}")
del _tmp_model

# Shared loss for training and validation (default 'mean' reduction over each batch)
criterion = nn.CrossEntropyLoss()

# ==========================================
# 3. Training Params & BO Configuration
# ==========================================
SEED = 42
N_FOLDS = 3        # 3-Fold Cross-Validation
EPOCHS = 50        # Epochs per trial per fold
BATCH_SIZE = 32
NUM_WORKERS = 3
CALLS = 18         # New BO trials (matches grid search budget; warm-start points are free)

# 3D search space, one skopt Real per (name, low, high, prior) row
search_space = [
    Real(low, high, prior=prior, name=dim_name)
    for dim_name, low, high, prior in (
        ('learning_rate', 1e-4, 1e-1, 'log-uniform'),
        ('weight_decay',  1e-5, 1e-2, 'log-uniform'),
        ('dropout',       0.0,  0.5,  'uniform'),
    )
]

# Mutable module state read by train_model: trial counter and the WandB group ID
# (the main BO block resets both before optimizing).
current_call = 0
checkpoint_id_for_this_run = 0

def get_checkpoint_id(base_name, find_latest=False, *, directory=None, ext='.pkl'):
    """
    Find an existing checkpoint ID, or pick a new one, for files named ``{base_name}_{id}{ext}``.

    Args:
        base_name: filename prefix (matched literally; it is regex-escaped).
        find_latest: if True, return the largest existing ID, or None when there is none.
            If False, return the smallest non-negative integer that is not already used.
            IDs freed by deleted checkpoints are therefore reused.
        directory: folder to scan; defaults to DRIVE_DIR.
        ext: file extension including the dot; defaults to '.pkl' (skopt checkpoints).

    Returns:
        int or None (None only when ``find_latest`` is True and nothing matches).
    """
    if directory is None:
        directory = DRIVE_DIR
    pattern = re.compile(rf'^{re.escape(base_name)}_(\d+){re.escape(ext)}$')
    existing_ids = {
        int(match.group(1))
        for match in map(pattern.match, os.listdir(directory))
        if match
    }

    if find_latest:
        return max(existing_ids, default=None)
    # Among 0..len(existing_ids), at least one ID is always free.
    return next(i for i in range(len(existing_ids) + 1) if i not in existing_ids)


# ==========================================
# 4. Objective Function (3-Fold CV)
# ==========================================
def _make_fold_loader(indices):
    """Return a DataLoader over the global ``dataset`` that visits ``indices`` in random order."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=SubsetRandomSampler(indices),
        num_workers=NUM_WORKERS,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=NUM_WORKERS > 0,
    )


def _train_one_epoch(model, optimizer, loader, device, desc):
    """Train ``model`` for one pass over ``loader`` with the global ``criterion``.

    Returns:
        (mean of the per-batch training losses, % of loop time spent waiting on data)
    """
    model.train()
    running_loss = 0.0
    data_time = 0.0
    compute_time = 0.0
    data_pct = 0.0  # stays defined even if the loader yields no batches

    pbar = tqdm(loader, desc=desc, leave=False)
    end = time.time()
    for inputs, labels in pbar:
        data_time += time.time() - end

        comp_start = time.time()
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        compute_time += time.time() - comp_start

        total_time = data_time + compute_time
        data_pct = 100 * data_time / total_time if total_time > 0 else 0

        # Rough whole-epoch duration estimate from tqdm's progress counters.
        stats = pbar.format_dict
        elapsed = stats.get('elapsed', 0)
        n_done = stats.get('n', 0)
        remaining = ((stats.get('total') or 1) - n_done) * elapsed / max(n_done, 1)
        et_min, et_sec = divmod(int(elapsed + remaining), 60)

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'epoch_est': f'{et_min:02d}:{et_sec:02d}',
            'data%': f'{data_pct:.0f}%'
        })
        end = time.time()

    return running_loss / max(len(loader), 1), data_pct


def _evaluate_fold(model, loader, device):
    """Return ``(per-sample mean loss, accuracy in %)`` of ``model`` on ``loader``.

    ``criterion`` averages over the batch, so each batch loss is re-weighted by its size.
    Otherwise the short final batch, whose contents change with the random sampler
    order, would be over-weighted and add noise to the BO objective.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            n = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n
    # KFold guarantees a non-empty validation split, so total > 0 here.
    return loss_sum / total, 100 * correct / total


def train_model(params):
    """
    Objective function for Bayesian Optimization.

    Trains a fresh SimpleTumorCNN (AdamW, EPOCHS epochs) on each of the N_FOLDS KFold
    splits of the global ``dataset``, logs to one WandB run per call, and returns the
    mean over folds of the per-sample validation cross-entropy.

    Args:
        params: ``[learning_rate, weight_decay, dropout]`` in ``search_space`` order.

    Returns:
        float: mean CV validation loss (lower is better).

    NOTE: validation loss is averaged per sample, not per batch. Values differ
    slightly from checkpoints recorded with the earlier per-batch average.
    """
    global current_call
    current_call += 1

    learning_rate, weight_decay, dropout = params[0], params[1], params[2]

    # Clear GPU memory from previous trial
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Initialize WandB for this trial
    run = wandb.init(
        entity="2121jmmn-danmarks-tekniske-universitet-dtu",
        project="3d_cv_simpleTumorCNN",
        group=f"{CHECKPOINT_BASE_NAME}_{checkpoint_id_for_this_run}",
        name=f"trial_{current_call}",
        reinit=True,
        resume="never",
        config={
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "dropout": dropout,
            "batch_size": BATCH_SIZE,
            "epochs": EPOCHS,
            "n_folds": N_FOLDS,
            "optimizer": "AdamW",
            "trial": current_call,
        }
    )

    try:
        print(f"\n{'='*60}")
        print(f"  Trial {current_call}/{CALLS}")
        print(f"  lr={learning_rate:.6f}  wd={weight_decay:.6f}  dropout={dropout:.4f}")
        print(f"{'='*60}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # --- K-Fold Cross-Validation (fixed SEED, so folds are identical across trials) ---
        kfold = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
        fold_losses = []
        fold_accuracies = []

        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(range(len(dataset)))):
            fold_no = fold_idx + 1
            print(f"\n  --- Fold {fold_no}/{N_FOLDS} ---")

            train_loader = _make_fold_loader(train_idx)
            val_loader = _make_fold_loader(val_idx)

            # Fresh model & optimizer per fold
            model = SimpleTumorCNN(num_classes=4, dropout_rate=dropout).to(device)
            optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

            for epoch in range(EPOCHS):
                avg_train_loss, data_pct = _train_one_epoch(
                    model, optimizer, train_loader, device,
                    desc=f"  Fold {fold_no} Epoch {epoch+1}/{EPOCHS}",
                )
                wandb.log({
                    "fold": fold_no,
                    "epoch": epoch + 1,
                    "train_loss": avg_train_loss,
                    "data_loading_pct": data_pct,
                })

            avg_fold_val_loss, fold_accuracy = _evaluate_fold(model, val_loader, device)
            fold_losses.append(avg_fold_val_loss)
            fold_accuracies.append(fold_accuracy)

            wandb.log({
                "fold": fold_no,
                "fold_val_loss": avg_fold_val_loss,
                "fold_val_accuracy": fold_accuracy,
            })
            print(f"  Fold {fold_no} — Val Loss: {avg_fold_val_loss:.4f}, Accuracy: {fold_accuracy:.2f}%")

            # Release per-fold resources (loader workers, GPU memory) before the next fold
            del model, optimizer, train_loader, val_loader
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # --- Average across folds ---
        mean_val_loss = float(np.mean(fold_losses))
        mean_accuracy = float(np.mean(fold_accuracies))

        wandb.log({
            "mean_cv_val_loss": mean_val_loss,
            "mean_cv_val_accuracy": mean_accuracy,
        })

        print(f"\n  Trial {current_call} finished — Mean CV Loss: {mean_val_loss:.4f}, Mean Accuracy: {mean_accuracy:.2f}%")
    finally:
        # Close the WandB run even if training raises (e.g. CUDA out-of-memory).
        run.finish()

    return mean_val_loss


# ==========================================
# 5. Checkpoint Logic & Bayesian Optimization
# ==========================================
def _bo_checkpoint_path(checkpoint_id):
    """Return the skopt checkpoint path for ``checkpoint_id`` inside DRIVE_DIR."""
    return os.path.join(DRIVE_DIR, f'{CHECKPOINT_BASE_NAME}_{checkpoint_id}.pkl')


def _load_bo_checkpoint(path):
    """Load previously evaluated points from a skopt checkpoint (as written by CheckpointSaver).

    Returns:
        ``(x0, y0)`` as plain lists, or None if the file cannot be read or holds no
        evaluations. A warning is printed, and the caller then starts a new run.
    """
    try:
        res_loaded = load(path)
        x_loaded = [list(xi) for xi in res_loaded.x_iters]  # Ensure list of lists
        y_loaded = list(res_loaded.func_vals)                # Ensure plain list
    except Exception as e:  # best effort: any unreadable checkpoint means "start new"
        print(f"WARNING: Could not load checkpoint {path}: {e}. Starting new.")
        return None
    if not y_loaded:
        print(f"WARNING: Checkpoint {path} contains no evaluations. Starting new.")
        return None
    return x_loaded, y_loaded


if __name__ == '__main__':
    from types import SimpleNamespace

    # Random points wanted before the GP takes over when searching the original space
    INITIAL_RANDOM_POINTS = 20

    x0 = None
    y0 = None
    res = None  # explicit sentinel; a notebook namespace may hold a stale `res`
    current_call = 0
    checkpoint_id_for_this_run = None
    checkpoint_file = None

    # --- Decide which checkpoint (if any) to resume ---
    if USE_CHECKPOINT:
        if DESIRED_CHECKPOINT_ID is not None:
            _candidate_id = DESIRED_CHECKPOINT_ID
        else:
            _candidate_id = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=True)
            if _candidate_id is None:
                print("No existing checkpoints found. Starting new optimization.")

        if _candidate_id is not None:
            _candidate_file = _bo_checkpoint_path(_candidate_id)
            if not os.path.exists(_candidate_file):
                print(f"ERROR: Checkpoint file {_candidate_file} not found. Starting new optimization.")
            else:
                print(f"Attempting to load checkpoint from {_candidate_file}...")
                _loaded = _load_bo_checkpoint(_candidate_file)
                if _loaded is not None:
                    x0, y0 = _loaded
                    current_call = len(x0)
                    checkpoint_id_for_this_run = _candidate_id
                    checkpoint_file = _candidate_file
                    print(f"Resuming from {current_call} previous calls from ID {checkpoint_id_for_this_run}.")
                    print(f"  Best loss so far: {min(y0):.4f}")
    else:
        print("USE_CHECKPOINT is False. Starting a brand new optimization.")

    # Anything that was not resumed gets a fresh, unused checkpoint ID.
    if checkpoint_file is None:
        checkpoint_id_for_this_run = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=False)
        checkpoint_file = _bo_checkpoint_path(checkpoint_id_for_this_run)
        print(f"Starting new optimization with checkpoint ID {checkpoint_id_for_this_run}.")

    # --- Choose the search space (the narrowed one is produced by cell 5b, if it has run) ---
    _narrowed_space = globals().get('narrowed_search_space')
    _use_narrowed = USE_NARROWED_SPACE and _narrowed_space is not None
    if _use_narrowed:
        _active_space = _narrowed_space
        _warm_x = globals().get('narrowed_x0') or []
        _warm_y = globals().get('narrowed_y0') or []
        if x0 is None:
            if _warm_x:
                # Warm-start points are prior knowledge for the GP, not new evaluations,
                # so they do not count against CALLS.
                x0, y0 = list(_warm_x), list(_warm_y)
        elif (_warm_x and len(x0) >= len(_warm_x)
              and all(np.allclose(a, b) for a, b in zip(x0, _warm_x))):
            # Resuming a narrowed run: CheckpointSaver stored the warm-start points first,
            # so discount them to keep them free across resumes.
            current_call -= len(_warm_x)
        # NOTE(review): resuming a checkpoint whose points lie outside the narrowed
        # bounds makes skopt reject x0; start a fresh run in that case.
    else:
        _active_space = search_space
        if USE_NARROWED_SPACE:
            print("WARNING: USE_NARROWED_SPACE=True but narrowed_search_space not found. Using original.")

    checkpoint_callback = CheckpointSaver(checkpoint_file)
    remaining_calls = max(0, CALLS - current_call)

    print(f"Starting optimization with {remaining_calls} remaining calls (Total CALLS: {CALLS})...")
    start_time = time.time()

    if remaining_calls > 0:
        _n_prior = len(x0) if x0 else 0
        if _use_narrowed:
            print(f"Using NARROWED search space with {_n_prior} warm-start points.")
            print(f"Will run {remaining_calls} NEW trials (warm-start points are free).")
        else:
            print("Using ORIGINAL search space.")

        # Narrowed space with prior points: the GP can start right away.
        # Otherwise top up to INITIAL_RANDOM_POINTS random points. skopt raises
        # ValueError when n_initial_points > n_calls, so cap at remaining_calls.
        if _use_narrowed and _n_prior > 0:
            required_random = 0
        else:
            required_random = max(0, INITIAL_RANDOM_POINTS - _n_prior)
        required_random = min(required_random, remaining_calls)

        res = gp_minimize(
            train_model,
            _active_space,                     # 3D: [lr, weight_decay, dropout]
            acq_func="EI",                     # Expected Improvement
            xi=0.01,                           # Small xi favours exploitation
            n_calls=remaining_calls,
            n_initial_points=required_random,
            noise="gaussian",
            random_state=SEED,
            callback=[checkpoint_callback],
            x0=x0,
            y0=y0,
        )
    else:
        print(f"All {CALLS} calls already completed based on loaded checkpoint.")
        if x0 is not None and y0 is not None:
            best_idx = int(np.argmin(y0))
            # Minimal stand-in exposing the same .x / .fun fields as an skopt result
            res = SimpleNamespace(x=list(x0[best_idx][:3]), fun=y0[best_idx])
            print(f"Best from checkpoint — LR: {res.x[0]:.6f}, WD: {res.x[1]:.6f}, "
                  f"Dropout: {res.x[2]:.4f}, Loss: {res.fun:.4f}")
        else:
            print("No results to display as no checkpoint was loaded and no new calls were made.")

    end_time = time.time()
    print(f"\nOptimization finished in {(end_time - start_time)/60:.2f} minutes.")
    if res is not None:
        print(f"Best LR: {res.x[0]:.6f}, Best Weight Decay: {res.x[1]:.6f}, "
              f"Best Dropout: {res.x[2]:.4f}, Best Loss: {res.fun:.4f}")

In [None]:

# ==========================================
# 5b. Analyse BO Checkpoint → Narrowed Search Space
# ==========================================
# Load a BO checkpoint, take its first N trials, find the top X by mean CV
# validation loss, compute their min/max per dimension, and define a narrowed
# search space with a configurable margin.
# The results are stored in `narrowed_search_space`, `narrowed_x0`,
# `narrowed_y0` for use by the BO cell (set USE_NARROWED_SPACE = True).

# --- Configuration ---
ANALYSE_CHECKPOINT_ID = None   # None = latest BO checkpoint; set an int to pin the analysed run
                               # (after a narrowed run, "latest" is that run, not the original)
ANALYSE_FIRST_N_TRIALS = 36    # Only look at the first N trials from the BO checkpoint
TOP_X = 10                     # Number of best trials to base the narrowed space on
MARGIN_PCT = 5                 # % margin added above max and below min of each dimension

# Original search space bounds (for clamping)
ORIG_BOUNDS = {
    'learning_rate': (1e-4, 1e-1),
    'weight_decay':  (1e-5, 1e-2),
    'dropout':       (0.0,  0.5),
}


def _apply_margin(lo, hi, orig_lo, orig_hi, is_log=False, margin_frac=0.05):
    """Expand [lo, hi] by ``margin_frac`` of its range on each side, clamped to [orig_lo, orig_hi].

    For log-scaled dimensions the expansion happens in log10 space. The range used to
    size the margin is floored at 0.1 decades (log) or 0.02 (linear). That way
    ``lo == hi`` still yields a non-empty interval when ``margin_frac > 0``. The
    floor only sizes the margin; it is NOT a minimum width for the result.
    """
    if is_log:
        log_lo, log_hi = np.log10(lo), np.log10(hi)
        log_range = max(log_hi - log_lo, 0.1)
        new_lo = max(10 ** (log_lo - margin_frac * log_range), orig_lo)
        new_hi = min(10 ** (log_hi + margin_frac * log_range), orig_hi)
    else:
        lin_range = max(hi - lo, 0.02)
        new_lo = max(lo - margin_frac * lin_range, orig_lo)
        new_hi = min(hi + margin_frac * lin_range, orig_hi)
    return new_lo, new_hi


def _pct_of_original(new_lo, new_hi, orig_lo, orig_hi, is_log=False):
    """Return what % of the original range [orig_lo, orig_hi] the range [new_lo, new_hi] covers."""
    if is_log:
        orig_range = np.log10(orig_hi) - np.log10(orig_lo)
        new_range  = np.log10(new_hi)  - np.log10(new_lo)
    else:
        orig_range = orig_hi - orig_lo
        new_range  = new_hi  - new_lo
    return 100.0 * new_range / orig_range if orig_range > 0 else 0.0


# --- Locate and load the BO checkpoint ---
if ANALYSE_CHECKPOINT_ID is not None:
    _analyse_ckpt_id = ANALYSE_CHECKPOINT_ID
else:
    _analyse_ckpt_id = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=True)

_analyse_ckpt_file = None
if _analyse_ckpt_id is not None:
    _analyse_ckpt_file = os.path.join(DRIVE_DIR, f"{CHECKPOINT_BASE_NAME}_{_analyse_ckpt_id}.pkl")
    print(f"Loading BO checkpoint: {_analyse_ckpt_file}")
else:
    print("ERROR: No BO checkpoint found. Run the BO cell first.")

_n = 0  # number of trials to analyse; 0 means "nothing usable"
if _analyse_ckpt_file and os.path.exists(_analyse_ckpt_file):
    _res = load(_analyse_ckpt_file)
    _all_x = [list(xi) for xi in _res.x_iters]
    _all_y = list(_res.func_vals)
    print(f"Total trials in checkpoint: {len(_all_x)}")
    _n = min(ANALYSE_FIRST_N_TRIALS, len(_all_x))
    if _n == 0:
        print("ERROR: Checkpoint contains no trials to analyse.")
elif _analyse_ckpt_file:
    print(f"ERROR: Checkpoint file {_analyse_ckpt_file} not found.")

if _n > 0:
    # Restrict to first N trials
    _x_subset = _all_x[:_n]
    _y_subset = _all_y[:_n]
    print(f"Analysing first {_n} trials.")

    # Rank by loss (trial numbers are 1-based) and take top X
    _ranked = sorted(zip(_y_subset, _x_subset, range(1, _n + 1)), key=lambda t: t[0])
    _top = _ranked[:TOP_X]

    print(f"\n{'='*74}")
    print(f"  Top {len(_top)} trials (out of first {_n}) by mean CV val loss")
    print(f"{'='*74}")
    print(f"  {'Rank':<5} {'Trial#':<8} {'Loss':<10} {'LR':<14} {'WD':<14} {'Dropout':<10}")
    print(f"  {'-'*5} {'-'*8} {'-'*10} {'-'*14} {'-'*14} {'-'*10}")
    for i, (loss, params, trial_num) in enumerate(_top):
        print(f"  {i+1:<5} {trial_num:<8} {loss:<10.4f} {params[0]:<14.6e} {params[1]:<14.6e} {params[2]:<10.4f}")

    # Extract per-dimension min/max from top X
    _top_params = [p for _, p, _ in _top]
    _top_lrs  = [p[0] for p in _top_params]
    _top_wds  = [p[1] for p in _top_params]
    _top_dos  = [p[2] for p in _top_params]

    _raw_bounds = {
        'learning_rate': (min(_top_lrs), max(_top_lrs)),
        'weight_decay':  (min(_top_wds), max(_top_wds)),
        'dropout':       (min(_top_dos), max(_top_dos)),
    }

    print(f"\n  Raw bounds from top {len(_top)}:")
    for dim, (lo, hi) in _raw_bounds.items():
        scale = "log" if dim != "dropout" else "lin"
        print(f"    {dim:<16} [{lo:.6e}, {hi:.6e}]  ({scale})")

    # Apply margin — for log-uniform dims, margin is applied in log-space
    margin_frac = MARGIN_PCT / 100.0
    _narrowed_lr = _apply_margin(*_raw_bounds['learning_rate'], *ORIG_BOUNDS['learning_rate'],
                                 is_log=True, margin_frac=margin_frac)
    _narrowed_wd = _apply_margin(*_raw_bounds['weight_decay'], *ORIG_BOUNDS['weight_decay'],
                                 is_log=True, margin_frac=margin_frac)
    _narrowed_do = _apply_margin(*_raw_bounds['dropout'], *ORIG_BOUNDS['dropout'],
                                 is_log=False, margin_frac=margin_frac)

    # --- Compare narrowed vs original bounds ---
    _pct_lr = _pct_of_original(*_narrowed_lr, *ORIG_BOUNDS['learning_rate'], is_log=True)
    _pct_wd = _pct_of_original(*_narrowed_wd, *ORIG_BOUNDS['weight_decay'],  is_log=True)
    _pct_do = _pct_of_original(*_narrowed_do, *ORIG_BOUNDS['dropout'],       is_log=False)

    print(f"\n  Narrowed bounds (with {MARGIN_PCT}% margin, clamped to original space):")
    print(f"  {'Dimension':<16} {'Narrowed Range':<36} {'Original Range':<36} {'% of Original':<14}")
    print(f"  {'-'*16} {'-'*36} {'-'*36} {'-'*14}")
    print(f"  {'learning_rate':<16} [{_narrowed_lr[0]:.6e}, {_narrowed_lr[1]:.6e}]  (log)   "
          f"[{ORIG_BOUNDS['learning_rate'][0]:.6e}, {ORIG_BOUNDS['learning_rate'][1]:.6e}]  (log)   "
          f"{_pct_lr:>6.1f}%")
    print(f"  {'weight_decay':<16} [{_narrowed_wd[0]:.6e}, {_narrowed_wd[1]:.6e}]  (log)   "
          f"[{ORIG_BOUNDS['weight_decay'][0]:.6e}, {ORIG_BOUNDS['weight_decay'][1]:.6e}]  (log)   "
          f"{_pct_wd:>6.1f}%")
    print(f"  {'dropout':<16} [{_narrowed_do[0]:.6e}, {_narrowed_do[1]:.6e}]  (lin)   "
          f"[{ORIG_BOUNDS['dropout'][0]:.6e}, {ORIG_BOUNDS['dropout'][1]:.6e}]  (lin)   "
          f"{_pct_do:>6.1f}%")

    # Product of the three per-dimension fractions, expressed as a percentage
    _total_vol_pct = _pct_lr * _pct_wd * _pct_do / (100 * 100)
    print(f"\n  Combined volume: {_total_vol_pct:.1f}% of original search space")

    # --- Build the narrowed search space (same format as the original) ---
    narrowed_search_space = [
        Real(_narrowed_lr[0], _narrowed_lr[1], prior='log-uniform', name='learning_rate'),
        Real(_narrowed_wd[0], _narrowed_wd[1], prior='log-uniform', name='weight_decay'),
        Real(_narrowed_do[0], _narrowed_do[1], prior='uniform',     name='dropout'),
    ]

    # Collect evaluated points that fall within the narrowed bounds (for warm-starting BO)
    narrowed_x0 = []
    narrowed_y0 = []
    for xi, yi in zip(_x_subset, _y_subset):
        lr_ok = _narrowed_lr[0] <= xi[0] <= _narrowed_lr[1]
        wd_ok = _narrowed_wd[0] <= xi[1] <= _narrowed_wd[1]
        do_ok = _narrowed_do[0] <= xi[2] <= _narrowed_do[1]
        if lr_ok and wd_ok and do_ok:
            narrowed_x0.append(xi)
            narrowed_y0.append(yi)

    print(f"\n  Warm-start points inside narrowed space: {len(narrowed_x0)} / {_n}")
    if narrowed_x0:
        _best_ws_idx = int(np.argmin(narrowed_y0))
        print(f"  Best warm-start point: loss={narrowed_y0[_best_ws_idx]:.4f}")

    print(f"\n  To use in the BO cell, set USE_NARROWED_SPACE = True")
    print(f"  Variables available: narrowed_search_space, narrowed_x0, narrowed_y0")

    # Cleanup (the _narrowed_* bounds are kept for the visualisation cell)
    del _res, _all_x, _all_y, _x_subset, _y_subset, _ranked, _top, _top_params
    del _top_lrs, _top_wds, _top_dos, _raw_bounds

else:
    narrowed_search_space = None
    narrowed_x0 = None
    narrowed_y0 = None


In [None]:

# ==========================================
# 5c. 3D Visualisation of BO Checkpoint
# ==========================================
# Interactive 3D scatter plot: log10(LR) vs log10(WD) vs loss, coloured by loss.
# Dropout is encoded as marker size (larger = higher dropout).
# The top-X trials (TOP_X from the analysis cell, else 5) are highlighted in red,
# and the narrowed LR/WD bounds are drawn as an orange box when available.

import plotly.graph_objects as go

# --- Load the checkpoint the analysis cell used; otherwise fall back to the latest ---
_viz_ckpt_file = globals().get('_analyse_ckpt_file')
if not (_viz_ckpt_file and os.path.exists(_viz_ckpt_file)):
    _viz_ckpt_id = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=True)
    if _viz_ckpt_id is None:
        raise RuntimeError("No BO checkpoint found. Run the BO cell first.")
    _viz_ckpt_file = os.path.join(DRIVE_DIR, f"{CHECKPOINT_BASE_NAME}_{_viz_ckpt_id}.pkl")

_viz_res = load(_viz_ckpt_file)
_viz_x = [list(xi) for xi in _viz_res.x_iters]
_viz_y = list(_viz_res.func_vals)
if not _viz_x:
    raise RuntimeError(f"BO checkpoint {_viz_ckpt_file} contains no trials.")

# Restrict to the analysed subset if ANALYSE_FIRST_N_TRIALS is defined
_viz_n = min(globals().get('ANALYSE_FIRST_N_TRIALS', len(_viz_x)), len(_viz_x))
_viz_x = _viz_x[:_viz_n]
_viz_y = _viz_y[:_viz_n]

# Each point is [lr, wd, dropout] (the 3D search-space order)
_viz_params = np.asarray(_viz_x, dtype=float)
_viz_lr = _viz_params[:, 0]
_viz_wd = _viz_params[:, 1]
_viz_do = _viz_params[:, 2]
_viz_loss = np.asarray(_viz_y, dtype=float)

# Scale dropout → marker size (min 4, max 18)
_do_min, _do_max = _viz_do.min(), _viz_do.max()
if _do_max > _do_min:
    _viz_sizes = 4 + 14 * (_viz_do - _do_min) / (_do_max - _do_min)
else:
    _viz_sizes = np.full_like(_viz_do, 10.0)

# Boolean mask of the top-X (lowest-loss) trials
_top_x_count = globals().get('TOP_X', 5)
_sorted_idx = np.argsort(_viz_loss)
_is_top = np.zeros(len(_viz_loss), dtype=bool)
_is_top[_sorted_idx[:_top_x_count]] = True

# Hover text (trial numbers are 1-based)
_hover = [
    f"Trial {i+1}<br>LR: {_viz_lr[i]:.6e}<br>WD: {_viz_wd[i]:.6e}<br>"
    f"Dropout: {_viz_do[i]:.4f}<br>Loss: {_viz_loss[i]:.4f}"
    for i in range(len(_viz_loss))
]

# --- Build figure ---
fig = go.Figure()

# All non-top trials
fig.add_trace(go.Scatter3d(
    x=np.log10(_viz_lr[~_is_top]),
    y=np.log10(_viz_wd[~_is_top]),
    z=_viz_loss[~_is_top],
    mode='markers',
    marker=dict(
        size=_viz_sizes[~_is_top],
        color=_viz_loss[~_is_top],
        colorscale='Viridis',
        colorbar=dict(title='Loss', x=1.05),
        opacity=0.6,
        line=dict(width=0.5, color='white'),
    ),
    text=[h for h, t in zip(_hover, _is_top) if not t],
    hoverinfo='text',
    name='Trials',
))

# Top X trials (highlighted)
fig.add_trace(go.Scatter3d(
    x=np.log10(_viz_lr[_is_top]),
    y=np.log10(_viz_wd[_is_top]),
    z=_viz_loss[_is_top],
    mode='markers',
    marker=dict(
        size=_viz_sizes[_is_top] + 4,
        color='red',
        opacity=0.9,
        symbol='diamond',
        line=dict(width=1, color='darkred'),
    ),
    text=[h for h, t in zip(_hover, _is_top) if t],
    hoverinfo='text',
    name=f'Top {_top_x_count}',
))

# If a narrowed space exists, draw its LR/WD bounds as a box spanning the loss range
_viz_narrowed = globals().get('narrowed_search_space')
if _viz_narrowed is not None:
    _nb_lr = [np.log10(_viz_narrowed[0].low), np.log10(_viz_narrowed[0].high)]
    _nb_wd = [np.log10(_viz_narrowed[1].low), np.log10(_viz_narrowed[1].high)]
    _nb_z_lo = float(_viz_loss.min()) - 0.01
    _nb_z_hi = float(_viz_loss.max()) + 0.01

    def _box_edges(x_lo, x_hi, y_lo, y_hi, z_lo, z_hi):
        """Return x/y/z coordinate lists (None-separated) tracing the 12 edges of a box."""
        edges_x, edges_y, edges_z = [], [], []
        for (xa, ya, za), (xb, yb, zb) in [
            ((x_lo, y_lo, z_lo), (x_hi, y_lo, z_lo)), ((x_lo, y_hi, z_lo), (x_hi, y_hi, z_lo)),
            ((x_lo, y_lo, z_hi), (x_hi, y_lo, z_hi)), ((x_lo, y_hi, z_hi), (x_hi, y_hi, z_hi)),
            ((x_lo, y_lo, z_lo), (x_lo, y_hi, z_lo)), ((x_hi, y_lo, z_lo), (x_hi, y_hi, z_lo)),
            ((x_lo, y_lo, z_hi), (x_lo, y_hi, z_hi)), ((x_hi, y_lo, z_hi), (x_hi, y_hi, z_hi)),
            ((x_lo, y_lo, z_lo), (x_lo, y_lo, z_hi)), ((x_hi, y_lo, z_lo), (x_hi, y_lo, z_hi)),
            ((x_lo, y_hi, z_lo), (x_lo, y_hi, z_hi)), ((x_hi, y_hi, z_lo), (x_hi, y_hi, z_hi)),
        ]:
            edges_x += [xa, xb, None]
            edges_y += [ya, yb, None]
            edges_z += [za, zb, None]
        return edges_x, edges_y, edges_z

    bx, by, bz = _box_edges(_nb_lr[0], _nb_lr[1], _nb_wd[0], _nb_wd[1], _nb_z_lo, _nb_z_hi)
    fig.add_trace(go.Scatter3d(
        x=bx, y=by, z=bz,
        mode='lines',
        line=dict(color='orange', width=3),
        name='Narrowed bounds',
        hoverinfo='skip',
    ))

fig.update_layout(
    title=f'BO Trials (first {_viz_n}) — Marker size ∝ Dropout',
    scene=dict(
        xaxis_title='log₁₀(Learning Rate)',
        yaxis_title='log₁₀(Weight Decay)',
        zaxis_title='Mean CV Val Loss',
    ),
    width=900,
    height=700,
    legend=dict(x=0.02, y=0.98),
    margin=dict(l=0, r=0, b=0, t=40),
)

fig.show()

# Cleanup
del _viz_res, _viz_x, _viz_y, _viz_params, _viz_lr, _viz_wd, _viz_do, _viz_loss
del _viz_sizes, _sorted_idx, _is_top, _hover


In [None]:

# ==========================================
# 6. Grid Search Baseline (3×3×2 = 18 trials)
# ==========================================
# This cell runs a full-factorial grid search over the same 3D
# hyperparameter space used by BO, to serve as a comparison baseline.
# LR and WD are log-spaced; dropout is linearly spaced.
# Results are checkpointed to a JSON file after each trial for resume support.
# All trials are grouped together in WandB under "grid_search_{id}".

import itertools
import json as _json  # the checkpoint helpers in this cell refer to the module as `_json`

# --- Grid Search Configuration ---
GRID_CHECKPOINT_BASE_NAME = 'grid_search'
GRID_USE_CHECKPOINT = True  # True: resume the latest grid checkpoint; False: start a fresh one

# Candidate values per axis: LR and WD log-spaced, dropout linearly spaced.
lr_grid = np.logspace(np.log10(1e-4), np.log10(1e-1), 3).tolist()
wd_grid = np.logspace(np.log10(1e-5), np.log10(1e-2), 3).tolist()
dropout_grid = np.linspace(0.0, 0.5, 2).tolist()

# Full factorial in deterministic itertools.product order
# (LR varies slowest, dropout fastest); each point is a [lr, wd, dropout] list.
grid_combinations = [list(point) for point in itertools.product(lr_grid, wd_grid, dropout_grid)]
GRID_TOTAL = len(grid_combinations)  # 3 * 3 * 2 = 18

print(f"Grid Search: {len(lr_grid)} LR × {len(wd_grid)} WD × {len(dropout_grid)} Dropout = {GRID_TOTAL} trials")
print(f"  LR grid  (log): {[f'{v:.6f}' for v in lr_grid]}")
print(f"  WD grid  (log): {[f'{v:.6f}' for v in wd_grid]}")
print(f"  Dropout  (lin): {[f'{v:.4f}' for v in dropout_grid]}")


# --- JSON checkpoint helpers ---
def _get_grid_checkpoint_id(base_name, find_latest=False):
    """Inspect DRIVE_DIR for grid-search checkpoints named ``{base_name}_{N}.json``.

    Args:
        base_name: Checkpoint filename prefix (e.g. ``'grid_search'``).
        find_latest: If True, return the largest existing ID, or ``None`` when no
            matching checkpoint exists. If False, return the smallest non-negative
            ID that no existing checkpoint uses, so a new run never overwrites one.

    Returns:
        An ``int`` checkpoint ID, or ``None`` (only when ``find_latest`` is True
        and nothing matches).
    """
    pattern = re.compile(rf'^{re.escape(base_name)}_(\d+)\.json$')
    # A set, not a list: zero-padded names such as "grid_search_01.json" and
    # "grid_search_1.json" parse to the same ID. With duplicates, a positional
    # scan of the sorted list would hand out an ID that is already taken.
    existing_ids = {
        int(match.group(1))
        for match in map(pattern.match, os.listdir(DRIVE_DIR))
        if match
    }
    if find_latest:
        return max(existing_ids, default=None)
    new_id = 0
    while new_id in existing_ids:
        new_id += 1
    return new_id


def _load_grid_checkpoint(filepath):
    """Return the list stored under ``"results"`` in a grid-search JSON checkpoint.

    Yields an empty list when *filepath* does not exist, or when the stored
    object has no ``"results"`` key.
    """
    if not os.path.exists(filepath):
        return []
    with open(filepath, "r") as fh:
        payload = _json.load(fh)
    return payload.get("results", [])


def _save_grid_checkpoint(filepath, results, grid):
    """Atomically write ``{"grid": grid, "results": results}`` as a JSON checkpoint.

    The payload is first written to ``filepath + ".tmp"`` and then moved over
    *filepath* with ``os.replace``. Opening *filepath* directly with ``"w"`` would
    truncate the only checkpoint up front. A serialization error (e.g. a
    non-JSON-serializable value in *results*) or an interruption mid-write would
    then destroy every recorded trial. Here the previous checkpoint survives.

    Raises:
        Whatever ``json.dump`` / file I/O raises. The temp file is removed first.
    """
    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            _json.dump({"grid": grid, "results": results}, f, indent=2)
        os.replace(tmp_path, filepath)
    except BaseException:
        # Don't leave a half-written temp file behind; propagate the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _params_match(a, b, tol=1e-10):
    """Return True if param lists *a* and *b* match element-wise within ``tol``.

    Elements are compared with an absolute tolerance. Lists of different length
    never match; ``zip`` alone would silently ignore the extra entries.
    """
    if len(a) != len(b):
        return False
    return all(abs(x - y) < tol for x, y in zip(a, b))


# --- Resolve checkpoint ---
import math  # used by _grid_loss_key for NaN-safe ranking


def _grid_loss_key(result):
    """Ranking key for a grid result dict: its ``"mean_val_loss"``, with NaN ranked worst.

    Every comparison with NaN is False, so plain ``min``/``sorted`` over raw losses
    can report a diverged (NaN) trial as the best one. Mapping NaN to +inf keeps it
    at the bottom of the ranking.
    """
    loss = result["mean_val_loss"]
    return math.inf if math.isnan(loss) else loss


grid_checkpoint_id = None
grid_checkpoint_file = None
grid_completed_results = []

if GRID_USE_CHECKPOINT:
    latest_id = _get_grid_checkpoint_id(GRID_CHECKPOINT_BASE_NAME, find_latest=True)
    if latest_id is not None:
        grid_checkpoint_id = latest_id
        grid_checkpoint_file = os.path.join(DRIVE_DIR, f"{GRID_CHECKPOINT_BASE_NAME}_{grid_checkpoint_id}.json")
        grid_completed_results = _load_grid_checkpoint(grid_checkpoint_file)
        print(f"Loaded grid checkpoint ID {grid_checkpoint_id} with {len(grid_completed_results)} completed trials.")
        if grid_completed_results:
            best_prev = min(grid_completed_results, key=_grid_loss_key)
            print(f"  Best so far: loss={best_prev['mean_val_loss']:.4f} "
                  f"(lr={best_prev['params'][0]:.6f}, wd={best_prev['params'][1]:.6f}, do={best_prev['params'][2]:.4f})")
    else:
        print("No existing grid search checkpoints found. Starting new.")

if grid_checkpoint_id is None:
    # Fresh run (or GRID_USE_CHECKPOINT is False): ask for a new checkpoint ID.
    grid_checkpoint_id = _get_grid_checkpoint_id(GRID_CHECKPOINT_BASE_NAME, find_latest=False)
    grid_checkpoint_file = os.path.join(DRIVE_DIR, f"{GRID_CHECKPOINT_BASE_NAME}_{grid_checkpoint_id}.json")
    print(f"New grid search will use checkpoint ID {grid_checkpoint_id}.")

# Params of already-completed trials, rounded so float noise can't defeat the lookup
_completed_set = {tuple(round(v, 10) for v in r["params"]) for r in grid_completed_results}

# --- Override globals so train_model() logs to a "grid_search_X" WandB group ---
# NOTE(review): train_model is defined outside this chunk; these are presumably the
# globals it reads. They are restored in the `finally` below.
_saved_CHECKPOINT_BASE_NAME = CHECKPOINT_BASE_NAME
_saved_checkpoint_id = checkpoint_id_for_this_run
_saved_current_call = current_call
_saved_CALLS = CALLS

CHECKPOINT_BASE_NAME = GRID_CHECKPOINT_BASE_NAME
checkpoint_id_for_this_run = grid_checkpoint_id
current_call = len(grid_completed_results)
CALLS = GRID_TOTAL

try:
    # --- Main grid search loop ---
    remaining_grid = [
        combo for combo in grid_combinations
        if tuple(round(v, 10) for v in combo) not in _completed_set
    ]

    print(f"\n{'='*60}")
    print(f"  Grid Search: {len(remaining_grid)} remaining / {GRID_TOTAL} total trials")
    print(f"{'='*60}")

    grid_start_time = time.time()
    best_grid_loss = min((_grid_loss_key(r) for r in grid_completed_results), default=float("inf"))

    for params in remaining_grid:
        lr, wd, do = params
        trial_num = current_call + 1
        print(f"\n>>> Grid trial {trial_num}/{GRID_TOTAL}  "
              f"[lr={lr:.6f}, wd={wd:.6f}, dropout={do:.4f}]")

        # float() keeps the value JSON-serializable in case train_model returns a
        # NumPy/torch scalar (its return type is not visible here).
        mean_val_loss = float(train_model(params))

        # Record result
        result_entry = {
            "params": [lr, wd, do],
            "mean_val_loss": mean_val_loss,
            "trial": trial_num,
        }
        grid_completed_results.append(result_entry)

        # Track best (a NaN loss never compares smaller, so it is never "best")
        if mean_val_loss < best_grid_loss:
            best_grid_loss = mean_val_loss
            print(f"  *** New best grid loss: {best_grid_loss:.4f}")

        # Save checkpoint after every trial so an interrupted search can resume
        _save_grid_checkpoint(grid_checkpoint_file, grid_completed_results, grid_combinations)
        print(f"  Checkpoint saved ({len(grid_completed_results)}/{GRID_TOTAL} done).")

    grid_end_time = time.time()
finally:
    # --- Restore globals ---
    # Done in `finally` so an exception or KeyboardInterrupt mid-search cannot leave
    # the BO globals pointing at the grid search. Otherwise a re-run of this cell
    # would also capture the clobbered values as the "saved" originals.
    CHECKPOINT_BASE_NAME = _saved_CHECKPOINT_BASE_NAME
    checkpoint_id_for_this_run = _saved_checkpoint_id
    current_call = _saved_current_call
    CALLS = _saved_CALLS

# --- Summary ---
print(f"\n{'='*60}")
print(f"  Grid Search Complete!")
print(f"  Total time: {(grid_end_time - grid_start_time)/60:.2f} minutes")
print(f"  Trials run this session: {len(remaining_grid)}")
print(f"  Total completed: {len(grid_completed_results)}/{GRID_TOTAL}")
print(f"{'='*60}")

if grid_completed_results:
    best = min(grid_completed_results, key=_grid_loss_key)
    print(f"\n  Best Grid Search Result:")
    print(f"    Learning Rate: {best['params'][0]:.6f}")
    print(f"    Weight Decay:  {best['params'][1]:.6f}")
    print(f"    Dropout:       {best['params'][2]:.4f}")
    print(f"    Mean CV Loss:  {best['mean_val_loss']:.4f}")

    # Show all results sorted by loss (NaN losses listed last)
    print(f"\n  All results (sorted by loss):")
    for i, r in enumerate(sorted(grid_completed_results, key=_grid_loss_key), start=1):
        print(f"    {i:2d}. loss={r['mean_val_loss']:.4f}  "
              f"lr={r['params'][0]:.6f}  wd={r['params'][1]:.6f}  do={r['params'][2]:.4f}")
