In [7]:
import json
import os
import re
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from sklearn.model_selection import KFold
from skopt import gp_minimize, load
from skopt.callbacks import CheckpointSaver
from skopt.space import Real
from torch.utils.data import ConcatDataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm

wandb.login()




In [None]:

# ==========================================
# Checkpoint Configuration Variables
# ==========================================
CHECKPOINT_BASE_NAME = '3d_cv_optimization'
USE_CHECKPOINT = True   # Set to True to resume from a checkpoint, False to start new
DESIRED_CHECKPOINT_ID = None  # Set to None for latest, or an integer for a specific checkpoint ID

# Local checkpoint directory
DRIVE_DIR = r"c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\BO_Checkpoints"
os.makedirs(DRIVE_DIR, exist_ok=True)

# ==========================================
# 1. Data Preprocessing & Loading
# ==========================================
# Local dataset path
dataset_path = r"c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\dataset"
print(f"Contents of {dataset_path}: {os.listdir(dataset_path)}")

# Every image is resized to this (height, width) before it reaches the model.
# Kept in one place so the statistics pass and the final transform cannot drift apart.
IMAGE_SIZE = (128, 128)

# --- Compute (or load cached) dataset-specific normalization statistics ---
NORM_STATS_FILE = os.path.join(DRIVE_DIR, "dataset_norm_stats.json")


def _compute_channel_stats(loader):
    """Return per-channel ``(mean, std)`` lists over every image yielded by ``loader``.

    ``loader`` must yield ``(images, labels)`` batches where ``images`` is a
    3-channel tensor of shape ``(batch, 3, height, width)``.

    The sums are accumulated in float64: the variance is obtained as
    ``E[x^2] - E[x]^2``, which subtracts two nearly equal numbers, so float32
    accumulation can lose enough precision to make it slightly negative.
    The variance is additionally clamped at zero so ``sqrt`` never returns NaN.

    Raises:
        ValueError: if the loader yields no pixels at all.
    """
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    n_pixels = 0
    for imgs, _labels in tqdm(loader, desc="Norm stats", leave=False):
        imgs = imgs.double()
        b, _c, h, w = imgs.shape
        n_pixels += b * h * w
        channel_sum += imgs.sum(dim=[0, 2, 3])
        channel_sq_sum += (imgs ** 2).sum(dim=[0, 2, 3])

    if n_pixels == 0:
        raise ValueError("Cannot compute normalization statistics: the dataset is empty.")

    mean = channel_sum / n_pixels
    variance = (channel_sq_sum / n_pixels - mean ** 2).clamp(min=0.0)
    return mean.tolist(), variance.sqrt().tolist()


if os.path.exists(NORM_STATS_FILE):
    with open(NORM_STATS_FILE, "r") as f:
        _stats = json.load(f)
    DATASET_MEAN = _stats["mean"]
    DATASET_STD  = _stats["std"]
    print(f"Loaded cached normalization stats from {NORM_STATS_FILE}")
else:
    print("Computing dataset-specific normalization statistics (first run)...")
    # Un-normalized view of the data: resize + scale to [0, 1] only
    _tmp_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ])
    _tmp_train = ImageFolder(os.path.join(dataset_path, "Training"), transform=_tmp_transform)
    _tmp_test  = ImageFolder(os.path.join(dataset_path, "Testing"),  transform=_tmp_transform)
    _tmp_all   = ConcatDataset([_tmp_train, _tmp_test])
    _tmp_loader = DataLoader(_tmp_all, batch_size=256, shuffle=False, num_workers=0)

    DATASET_MEAN, DATASET_STD = _compute_channel_stats(_tmp_loader)
    del _tmp_transform, _tmp_train, _tmp_test, _tmp_all, _tmp_loader

    # Save for future runs
    with open(NORM_STATS_FILE, "w") as f:
        json.dump({"mean": DATASET_MEAN, "std": DATASET_STD}, f, indent=2)
    print(f"Saved normalization stats to {NORM_STATS_FILE}")

print(f"Dataset mean: {DATASET_MEAN}")
print(f"Dataset std:  {DATASET_STD}")

# --- Final transform with computed statistics ---
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=DATASET_MEAN, std=DATASET_STD)
])

# Training and Testing folders are merged; the K-fold split later decides
# which samples are used for validation.
training_dataset = ImageFolder(os.path.join(dataset_path, "Training"), transform=transform)
testing_dataset  = ImageFolder(os.path.join(dataset_path, "Testing"),  transform=transform)
dataset = ConcatDataset([training_dataset, testing_dataset])
print(f"Total dataset size: {len(dataset)} images")

# ==========================================
# 2. Model Definition — SimpleTumorCNN
# ==========================================
class SimpleTumorCNN(nn.Module):
    """
    Small convolutional classifier (roughly 24k trainable parameters).

    The feature extractor stacks three identical stages
    (3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool) that widen the channels
    3 -> 16 -> 32 -> 64, followed by global average pooling. The head is
    Flatten -> Dropout -> Linear, producing one logit per class.

    Args:
        num_classes: number of output logits.
        dropout_rate: drop probability applied right before the linear layer.
    """

    def __init__(self, num_classes=4, dropout_rate=0.1):
        super().__init__()

        # Channel width entering/leaving each conv stage
        widths = (3, 16, 32, 64)

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
        # Global pooling collapses the spatial grid to 1x1
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))

        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(widths[-1], num_classes),
        )

    def forward(self, x):
        """Map an image batch to class logits."""
        return self.classifier(self.features(x))

# Sanity check: instantiate a throwaway network and report its size
_probe_net = SimpleTumorCNN(num_classes=4, dropout_rate=0.1)
_param_count = sum(map(torch.numel, _probe_net.parameters()))
print(f"SimpleTumorCNN parameter count: {_param_count:,}")
del _probe_net

# Loss shared by training and validation
criterion = nn.CrossEntropyLoss()

# ==========================================
# 3. Training Params & BO Configuration
# ==========================================
CALLS = 100        # Number of Bayesian-optimization trials in total
EPOCHS = 50        # Training epochs for every fold of every trial
BATCH_SIZE = 32
NUM_WORKERS = 3
N_FOLDS = 3        # K in K-fold cross-validation
SEED = 42

# 3D Search Space — the order here is the order of `params` in the objective
search_space = [
    Real(low, high, prior=prior, name=name)
    for name, low, high, prior in (
        ('learning_rate', 1e-4, 1e-1, 'log-uniform'),
        ('weight_decay',  1e-5, 1e-2, 'log-uniform'),
        ('dropout',       0.0,  0.5,  'uniform'),
    )
]

# Module-level state shared with the objective: trial counter and WandB group id
current_call = 0
checkpoint_id_for_this_run = 0  # Overwritten by the main block before optimization starts

def get_checkpoint_id(base_name, find_latest=False, directory=None):
    """
    Pick a checkpoint ID by scanning a directory for ``<base_name>_<id>.pkl`` files.

    Args:
        base_name: checkpoint file prefix; matched literally (regex-escaped).
        find_latest: if True, return the highest existing ID (for resuming);
            if False, return the smallest non-negative ID not yet in use
            (for starting a new run).
        directory: folder to scan. Defaults to the module-level ``DRIVE_DIR``.

    Returns:
        int, or ``None`` when ``find_latest`` is True and no checkpoint exists.
    """
    if directory is None:
        directory = DRIVE_DIR

    pattern = re.compile(rf'^{re.escape(base_name)}_(\d+)\.pkl$')

    # A set, not a list: two file names can map to the same integer
    # (e.g. "_1.pkl" and "_01.pkl"), and duplicates would break the gap search.
    existing_ids = set()
    for f_name in os.listdir(directory):
        match = pattern.match(f_name)
        if match:
            existing_ids.add(int(match.group(1)))

    if find_latest:
        return max(existing_ids) if existing_ids else None

    # Smallest free ID: fills gaps left by deleted checkpoints first
    candidate = 0
    while candidate in existing_ids:
        candidate += 1
    return candidate


# ==========================================
# 4. Objective Function (3-Fold CV)
# ==========================================
def _evaluate_fold(model, loader, device):
    """Return ``(mean_loss, accuracy_percent)`` of ``model`` over ``loader``.

    The loss is weighted by batch size, so a smaller final batch does not
    count as much as a full one (``criterion`` returns a per-batch mean).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            n_batch = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * n_batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n_batch

    if total == 0:
        raise ValueError("Validation loader yielded no samples.")
    return loss_sum / total, 100 * correct / total


def train_model(params):
    """
    Objective function for Bayesian Optimization.

    Trains a fresh SimpleTumorCNN on every fold of an ``N_FOLDS``-fold split of
    ``dataset`` for ``EPOCHS`` epochs, validates once at the end of each fold,
    and returns the validation loss averaged over folds.

    Args:
        params: sequence ``[learning_rate, weight_decay, dropout]``, in the
            same order as ``search_space``.

    Returns:
        float: mean validation loss across folds (the value BO minimizes).

    Side effects:
        Increments the global ``current_call`` and logs one WandB run per
        call. The run is always closed: normally on success, and marked as
        failed (non-zero exit code) if training raises or is interrupted.
    """
    global current_call, checkpoint_id_for_this_run
    current_call += 1

    learning_rate = params[0]
    weight_decay  = params[1]
    dropout       = params[2]

    # Clear GPU memory from previous trial
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Initialize WandB for this trial
    run = wandb.init(
        entity="2121jmmn-danmarks-tekniske-universitet-dtu",
        project="3d_cv_simpleTumorCNN",
        group=f"{CHECKPOINT_BASE_NAME}_{checkpoint_id_for_this_run}",
        name=f"trial_{current_call}",
        reinit=True,
        resume="never",
        config={
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "dropout": dropout,
            "batch_size": BATCH_SIZE,
            "epochs": EPOCHS,
            "n_folds": N_FOLDS,
            "optimizer": "AdamW",
            "trial": current_call,
        }
    )

    try:
        print(f"\n{'='*60}")
        print(f"  Trial {current_call}/{CALLS}")
        print(f"  lr={learning_rate:.6f}  wd={weight_decay:.6f}  dropout={dropout:.4f}")
        print(f"{'='*60}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # --- K-Fold Cross-Validation (same split for every trial via SEED) ---
        kfold = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
        fold_losses = []
        fold_accuracies = []

        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(range(len(dataset)))):
            print(f"\n  --- Fold {fold_idx + 1}/{N_FOLDS} ---")

            workers = NUM_WORKERS
            # persistent_workers is only valid with worker processes; DataLoader
            # raises ValueError if it is True while num_workers == 0.
            loader_kwargs = dict(batch_size=BATCH_SIZE,
                                 num_workers=workers,
                                 persistent_workers=workers > 0)
            train_loader = DataLoader(dataset, sampler=SubsetRandomSampler(train_idx),
                                      **loader_kwargs)
            val_loader   = DataLoader(dataset, sampler=SubsetRandomSampler(val_idx),
                                      **loader_kwargs)

            # Fresh model & optimizer per fold
            model = SimpleTumorCNN(num_classes=4, dropout_rate=dropout).to(device)
            optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                    weight_decay=weight_decay)

            # --- Training loop ---
            for epoch in range(EPOCHS):
                model.train()
                loss_sum = 0.0
                seen = 0
                data_time = 0.0
                compute_time = 0.0
                data_pct = 0.0  # defined up front so logging works even with no batches

                pbar = tqdm(train_loader,
                            desc=f"  Fold {fold_idx+1} Epoch {epoch+1}/{EPOCHS}",
                            leave=False)
                end = time.time()

                for inputs, labels in pbar:
                    # Time spent waiting for the loader since the last batch finished
                    data_time += time.time() - end

                    comp_start = time.time()
                    inputs, labels = inputs.to(device), labels.to(device)

                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    batch_loss = loss.item()
                    n_batch = labels.size(0)
                    loss_sum += batch_loss * n_batch
                    seen += n_batch
                    compute_time += time.time() - comp_start

                    total_time = data_time + compute_time
                    data_pct = 100 * data_time / total_time if total_time > 0 else 0

                    # Extrapolate the whole-epoch duration from the progress so far
                    progress = pbar.format_dict
                    elapsed = progress.get('elapsed', 0)
                    done = progress.get('n', 0)
                    remaining = ((progress.get('total') or 1) - done) * elapsed / max(done, 1)
                    et_min, et_sec = divmod(int(elapsed + remaining), 60)

                    pbar.set_postfix({
                        'loss': f'{batch_loss:.4f}',
                        'epoch_est': f'{et_min:02d}:{et_sec:02d}',
                        'data%': f'{data_pct:.0f}%'
                    })
                    end = time.time()

                # Sample-weighted mean, robust to a smaller last batch
                avg_train_loss = loss_sum / max(seen, 1)
                wandb.log({
                    "fold": fold_idx + 1,
                    "epoch": epoch + 1,
                    "train_loss": avg_train_loss,
                    "data_loading_pct": data_pct,
                })

            # --- Validation for this fold ---
            avg_fold_val_loss, fold_accuracy = _evaluate_fold(model, val_loader, device)
            fold_losses.append(avg_fold_val_loss)
            fold_accuracies.append(fold_accuracy)

            wandb.log({
                "fold": fold_idx + 1,
                "fold_val_loss": avg_fold_val_loss,
                "fold_val_accuracy": fold_accuracy,
            })
            print(f"  Fold {fold_idx+1} — Val Loss: {avg_fold_val_loss:.4f}, Accuracy: {fold_accuracy:.2f}%")

            # Cleanup per fold
            del model, optimizer, train_loader, val_loader
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # --- Average across folds ---
        mean_val_loss = float(np.mean(fold_losses))
        mean_accuracy = float(np.mean(fold_accuracies))

        wandb.log({
            "mean_cv_val_loss": mean_val_loss,
            "mean_cv_val_accuracy": mean_accuracy,
        })

        print(f"\n  Trial {current_call} finished — Mean CV Loss: {mean_val_loss:.4f}, Mean Accuracy: {mean_accuracy:.2f}%")
    except BaseException:
        # Includes KeyboardInterrupt: close the run as failed so the next
        # trial (or a re-run of the cell) does not inherit a dangling run.
        run.finish(exit_code=1)
        raise

    run.finish()
    return mean_val_loss


# ==========================================
# 5. Checkpoint Logic & Bayesian Optimization
# ==========================================
# Total number of random exploration points (counting points restored from a
# checkpoint) before the GP surrogate starts proposing candidates.
N_INITIAL_RANDOM_POINTS = 20


class MockResult:
    """Minimal stand-in for an optimizer result, exposing only ``x`` and ``fun``."""

    def __init__(self, x, fun):
        self.x = x
        self.fun = fun


def _checkpoint_path(checkpoint_id):
    """Return the checkpoint file path inside ``DRIVE_DIR`` for ``checkpoint_id``."""
    return os.path.join(DRIVE_DIR, f'{CHECKPOINT_BASE_NAME}_{checkpoint_id}.pkl')


def _new_checkpoint(announcement="Starting new optimization with checkpoint ID {}."):
    """Allocate an unused checkpoint ID; returns ``(id, file, None, None)``."""
    new_id = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=False)
    print(announcement.format(new_id))
    return new_id, _checkpoint_path(new_id), None, None


def _load_history(checkpoint_file):
    """Return ``(x0, y0)`` as plain lists from a saved optimization result.

    NOTE: skopt checkpoints are pickle-based; only load files you created yourself.

    Raises:
        ValueError: if the checkpoint holds no evaluations or is inconsistent.
        Exception: whatever ``skopt.load`` raises for unreadable files.
    """
    res_loaded = load(checkpoint_file)
    x0 = [list(xi) for xi in res_loaded.x_iters]  # Ensure list of lists
    y0 = list(res_loaded.func_vals)                # Ensure plain list
    if not x0 or len(x0) != len(y0):
        raise ValueError(f"checkpoint has {len(x0)} points and {len(y0)} values")
    return x0, y0


def _try_resume(checkpoint_id, latest):
    """Load checkpoint ``checkpoint_id``; fall back to a brand-new one on any failure.

    The history is returned all-or-nothing, so a failed load can never leave
    partially restored state behind.
    """
    checkpoint_file = _checkpoint_path(checkpoint_id)
    kind = "latest" if latest else "specific"
    print(f"Attempting to load {kind} checkpoint from {checkpoint_file}...")
    try:
        x0, y0 = _load_history(checkpoint_file)
    except Exception as e:
        # Deliberately broad: a corrupt/incompatible checkpoint must not block a new run
        print(f"WARNING: Could not load checkpoint {checkpoint_file}: {e}. Starting new.")
        return _new_checkpoint()

    origin = "latest ID" if latest else "ID"
    print(f"Resuming from {len(x0)} previous calls from {origin} {checkpoint_id}.")
    print(f"  Best loss so far: {min(y0):.4f}")
    return checkpoint_id, checkpoint_file, x0, y0


def _resolve_checkpoint():
    """Decide which checkpoint to use from the configuration flags.

    Returns:
        ``(checkpoint_id, checkpoint_file, x0, y0)`` where ``x0``/``y0`` are the
        restored evaluations, or both ``None`` when starting from scratch.
    """
    if not USE_CHECKPOINT:
        print("USE_CHECKPOINT is False. Starting a brand new optimization.")
        return _new_checkpoint("New optimization will use checkpoint ID {}.")

    if DESIRED_CHECKPOINT_ID is not None:
        wanted_file = _checkpoint_path(DESIRED_CHECKPOINT_ID)
        if os.path.exists(wanted_file):
            return _try_resume(DESIRED_CHECKPOINT_ID, latest=False)
        print(f"ERROR: Checkpoint file {wanted_file} not found. Starting new optimization.")
        return _new_checkpoint()

    latest_id = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=True)
    if latest_id is None:
        print("No existing checkpoints found. Starting new optimization.")
        return _new_checkpoint()
    return _try_resume(latest_id, latest=True)


if __name__ == '__main__':
    # These two assignments set the module-level state that train_model reads
    # for trial numbering and WandB grouping.
    checkpoint_id_for_this_run, checkpoint_file, x0, y0 = _resolve_checkpoint()
    current_call = len(x0) if x0 is not None else 0
    res = None

    checkpoint_callback = CheckpointSaver(checkpoint_file)
    remaining_calls = max(0, CALLS - current_call)

    print(f"Starting optimization with {remaining_calls} remaining calls (Total CALLS: {CALLS})...")
    start_time = time.time()

    if remaining_calls > 0:
        # Resume-aware initial random points: already-evaluated points count
        # towards the budget, and the request is capped because skopt requires
        # n_calls >= n_initial_points.
        required_random = max(0, N_INITIAL_RANDOM_POINTS - current_call)
        required_random = min(required_random, remaining_calls)

        res = gp_minimize(
            train_model,
            search_space,                      # 3D: [lr, weight_decay, dropout]
            acq_func="EI",                     # Expected Improvement
            xi=0.1,                            # Exploration bias
            n_calls=remaining_calls,
            n_initial_points=required_random,
            noise="gaussian",
            random_state=SEED,
            callback=[checkpoint_callback],
            x0=x0,
            y0=y0,
        )
    else:
        print(f"All {CALLS} calls already completed based on loaded checkpoint.")
        if x0 is not None and y0 is not None:
            # Nothing left to run: rebuild the best point from the restored history
            best_idx = int(np.argmin(y0))
            res = MockResult(list(x0[best_idx][:3]), y0[best_idx])
            print(f"Best from checkpoint — LR: {res.x[0]:.6f}, WD: {res.x[1]:.6f}, "
                  f"Dropout: {res.x[2]:.4f}, Loss: {res.fun:.4f}")
        else:
            print("No results to display as no checkpoint was loaded and no new calls were made.")

    end_time = time.time()
    print(f"\nOptimization finished in {(end_time - start_time)/60:.2f} minutes.")
    if res is not None:
        print(f"Best LR: {res.x[0]:.6f}, Best Weight Decay: {res.x[1]:.6f}, "
              f"Best Dropout: {res.x[2]:.4f}, Best Loss: {res.fun:.4f}")



Contents of c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\dataset: ['Testing', 'Training']
Loaded cached normalization stats from c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\BO_Checkpoints\dataset_norm_stats.json
Dataset mean: [0.18654859066009521, 0.18655261397361755, 0.18659797310829163]
Dataset std:  [0.19559581577777863, 0.19559480249881744, 0.1956312358379364]
Total dataset size: 7200 images
SimpleTumorCNN parameter count: 24,068
Attempting to load latest checkpoint from c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\BO_Checkpoints/3d_cv_optimization_0.pkl...
Resuming from 26 previous calls from latest ID 0.
  Best loss so far: 0.3018
Starting optimization with 74 remaining calls (Total CALLS: 100)...


0,1
data_loading_pct,▁
epoch,▁
fold,▁
train_loss,▁

0,1
data_loading_pct,43.35639
epoch,1.0
fold,1.0
train_loss,0.90513



  Trial 27/100
  lr=0.002074  wd=0.000205  dropout=0.0901
Using device: cpu

  --- Fold 1/3 ---


  Fold 1 Epoch 22/50:  63%|██████▎   | 95/150 [00:08<00:04, 12.28it/s, loss=0.4309, epoch_est=00:12, data%=5%] 