In [4]:
import os
import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from skopt import gp_minimize, load # Import load for resuming
from skopt.callbacks import CheckpointSaver # Improved saving
import time # To track time
import re # For checkpoint ID parsing
import wandb
# Local WandB authentication (will use cached key or prompt once)
wandb.login()

from tqdm import tqdm





In [6]:


# ==========================================
# Checkpoint Configuration Variables
# ==========================================
# Checkpoint files are named "<CHECKPOINT_BASE_NAME>_<id>.pkl" inside DRIVE_DIR.
CHECKPOINT_BASE_NAME = 'dropout_optimization'
USE_CHECKPOINT = True  # Set to True to resume from a checkpoint, False to start new
DESIRED_CHECKPOINT_ID = None  # Set to None for latest, or an integer for a specific checkpoint ID

# Local checkpoint directory.
# The BO_CHECKPOINT_DIR environment variable overrides the default below, so the
# notebook can run on another machine without editing this cell.
DRIVE_DIR = os.environ.get(
    "BO_CHECKPOINT_DIR",
    r"c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\BO_Checkpoints",
)
os.makedirs(DRIVE_DIR, exist_ok=True)

# ==========================================
# 1. Data Preprocessing & Loading
# ==========================================
# Preprocessing for the pre-trained EfficientNet-B0 backbone.
# Inputs are resized to 224x224, converted to tensors and normalized with the
# ImageNet per-channel statistics, because the pre-trained weights were fitted
# on data normalized this way.
# NOTE: the blue-channel std is 0.225; it was previously mistyped as 0.425.
# Losses stored in checkpoints created before this fix were computed with the
# wrong normalization and are not directly comparable to new trials.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Local dataset path.
# The BRAIN_TUMOR_DATASET_DIR environment variable overrides the default below.
dataset_path = os.environ.get(
    "BRAIN_TUMOR_DATASET_DIR",
    r"c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\dataset",
)

# Verify directory contents from Python's perspective
print(f"Contents of {dataset_path}: {os.listdir(dataset_path)}")

training_dataset = ImageFolder(os.path.join(dataset_path, "Training"), transform=transform)
testing_dataset = ImageFolder(os.path.join(dataset_path, "Testing"), transform=transform)

# ImageFolder derives label indices from the sub-folder names of each root.
# If the two roots do not contain the same class folders, the same integer label
# would mean different classes in the two halves of the concatenated dataset.
if training_dataset.class_to_idx != testing_dataset.class_to_idx:
    raise ValueError(
        "Training and Testing folders define different class mappings: "
        f"{training_dataset.class_to_idx} vs {testing_dataset.class_to_idx}"
    )

# Both folders are pooled; train_model() re-splits the pool into train/validation.
dataset = ConcatDataset([training_dataset, testing_dataset])

# ==========================================
# 2. Model Definition
# ==========================================
class BrainTumorModel(nn.Module):
    """Pre-trained EfficientNet-B0 backbone followed by a small custom classification head.

    Args:
        num_classes: Number of output logits produced by the head.
        dropout: Dropout probability applied to the backbone features before the
            first linear layer (the value tuned by the optimization).
    """

    def __init__(self, num_classes=4, dropout=.1):
        super().__init__()

        # Pre-trained backbone with its stock classifier replaced by a pass-through,
        # so it returns the feature vector instead of ImageNet logits.
        backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        backbone.classifier = nn.Identity()
        self.base_model = backbone

        # Custom head: tunable dropout -> 1280->128 projection -> fixed dropout -> logits.
        head_layers = [
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(1280, 128),  # EfficientNet-B0 outputs 1280 features
            nn.ReLU(),
            nn.Dropout(0.25),      # Fixed dropout for stability
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw class logits for the input batch ``x``."""
        features = self.base_model(x)
        return self.classifier(features)

# Loss shared by the training and validation loops in train_model();
# it is applied to the raw logits returned by the model.
criterion = nn.CrossEntropyLoss()

# ==========================================
# 3. Training Params & Hardware Tuning
# ==========================================
CALLS = 50     # Total BO trials (trials already stored in a loaded checkpoint count toward this)
EPOCHS = 10     # Epochs per trial
BATCH_SIZE = 32  # Batch size for both the training and validation DataLoader
NUM_WORKERS = 2  # DataLoader worker processes per loader
SEED = 123  # Seeds the train/val split generator and gp_minimize's random_state

# Global trial counter: train_model() increments it on every call and uses it
# for the W&B run name and the progress printout.
current_call = 0 # Initialize globally, will be updated if resuming

def get_checkpoint_id(base_name, find_latest=False, directory=None):
    """Pick a checkpoint ID by scanning a directory for ``<base_name>_<id>.pkl`` files.

    Args:
        base_name: Checkpoint file prefix; matched literally (regex-escaped).
        find_latest: If True, return the highest existing ID (for resuming).
            If False, return the smallest non-negative ID not yet in use
            (for starting a new run).
        directory: Directory to scan. Defaults to the module-level ``DRIVE_DIR``.

    Returns:
        int, or None when ``find_latest`` is True and no checkpoint file exists.
    """
    search_dir = DRIVE_DIR if directory is None else directory
    pattern = re.compile(rf'^{re.escape(base_name)}_(\d+)\.pkl$')

    # A set is used because different file names can parse to the same integer
    # (e.g. "_1.pkl" and "_01.pkl"); duplicates would break a positional gap search.
    existing_ids = set()
    for f_name in os.listdir(search_dir):
        match = pattern.match(f_name)
        if match:
            existing_ids.add(int(match.group(1)))

    if find_latest:
        return max(existing_ids) if existing_ids else None

    # Smallest non-negative integer that is not already taken.
    candidate = 0
    while candidate in existing_ids:
        candidate += 1
    return candidate

def _train_one_epoch(model, loader, optimizer, device, epoch):
    """Run one training pass over ``loader``.

    Args:
        model: Network to train (switched to train mode here).
        loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        (avg_train_loss, data_pct): mean of the per-batch losses, and the
        percentage of wall time spent waiting for data in this epoch.
    """
    model.train()
    running_loss = 0.0
    data_time = 0.0
    compute_time = 0.0
    data_pct = 0.0  # bound up front so it is always defined for the caller

    # Progress bar for the batches in the current epoch
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    end = time.time()

    for inputs, labels in pbar:
        data_time += time.time() - end  # Time spent waiting for data

        comp_start = time.time()
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        compute_time += time.time() - comp_start

        # data% diagnostic: <10% = compute-bound (good), >30% = data-starved
        total_time = data_time + compute_time
        data_pct = 100 * data_time / total_time if total_time > 0 else 0

        # Estimate the epoch's total duration from tqdm's own counters:
        # elapsed time plus (batches left) * (average time per finished batch).
        stats = pbar.format_dict
        elapsed = stats.get('elapsed', 0)
        done = stats.get('n', 0)
        total_batches = stats.get('total') or 1
        remaining = (total_batches - done) * elapsed / max(done, 1)
        et_min, et_sec = divmod(int(elapsed + remaining), 60)

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'epoch_est': f'{et_min:02d}:{et_sec:02d}',
            'data%': f'{data_pct:.0f}%'
        })

        end = time.time()

    # Calculate average epoch loss
    avg_train_loss = running_loss / len(loader)
    return avg_train_loss, data_pct


def _evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (avg_val_loss, accuracy): mean of the per-batch losses and the
        percentage of correctly classified samples.
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)  # Raw logits
            val_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Mean over batches (not over samples); kept this way so new objective
    # values stay comparable with values already stored in checkpoints.
    avg_val_loss = val_loss / len(loader)
    accuracy = 100 * correct / total
    return avg_val_loss, accuracy


def train_model(params):
    """
    Objective function for Bayesian Optimization.

    Trains a fresh BrainTumorModel for EPOCHS epochs with the given dropout,
    logs the trial to its own W&B run, and returns the validation loss.

    Args:
        params: Sequence whose first element is the dropout probability.

    Returns:
        Average validation loss of the trained model (the value to minimize).

    Side effects:
        Increments the global ``current_call`` counter and creates and finishes
        one W&B run. If training raises or is interrupted, the run is finished
        with a non-zero exit code and the exception is re-raised.
    """
    global current_call
    current_call += 1

    dropout = float(params[0])

    # Clear GPU memory from previous trial to prevent OOM crashes
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Initialize WandB for this specific trial
    run = wandb.init(
        entity="2121jmmn-danmarks-tekniske-universitet-dtu",
        project="brain-tumor-bo-optimization",
        name=f"trial_{current_call}",
        config={
            "dropout": dropout,
            "batch_size": BATCH_SIZE,
            "epochs": EPOCHS,
            "optimizer": "Adamax",
            # NOTE: despite its name this key holds the trial index; the key is
            # kept unchanged so existing W&B runs stay comparable.
            "k_fold": current_call
        }
    )

    try:
        print(f"\n-------------------- Round {current_call}/{CALLS} ----------------------------")
        print(f"Testing with dropout: {dropout:.4f}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Optional: Print device once to be sure
        print(f"Using device: {device}")

        model = BrainTumorModel(num_classes=4, dropout=dropout).to(device)
        optimizer = optim.Adamax(model.parameters(), lr=0.001)

        # Split dataset for this trial
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        # Fixed seed ensures identical train/val split every trial
        # This isolates dropout as the only variable, making BO's effect clearly visible
        generator = torch.Generator().manual_seed(SEED)
        train_subset, val_subset = random_split(dataset, [train_size, val_size], generator=generator)

        # Number of data-loading worker processes comes from NUM_WORKERS.
        workers = NUM_WORKERS
        # persistent_workers keeps the workers alive between epochs; DataLoader
        # rejects it when there are no workers, so enable it only for workers > 0.
        keep_workers_alive = workers > 0

        train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=workers, persistent_workers=keep_workers_alive)
        val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE,
                                num_workers=workers, persistent_workers=keep_workers_alive)

        for epoch in range(EPOCHS):
            avg_train_loss, data_pct = _train_one_epoch(model, train_loader, optimizer, device, epoch)

            # Log epoch metrics to WandB
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "data_loading_pct": data_pct,
            })

        # Validation (once, after the final epoch)
        avg_val_loss, accuracy = _evaluate(model, val_loader, device)

        # Log final results for this trial
        wandb.log({
            "val_loss": avg_val_loss,
            "val_accuracy": accuracy
        })

        print(f"Trial {current_call} finished. Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.2f}%")

        # Drop the references so the cache flush below can actually release memory
        del model, optimizer, train_loader, val_loader
    except BaseException:
        # Includes KeyboardInterrupt: close the run as failed instead of leaving
        # it open for the next wandb.init() to clean up.
        run.finish(exit_code=1)
        raise
    else:
        # Finish the WandB run so the next trial starts fresh
        run.finish()
    finally:
        # Free GPU memory before next trial
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return avg_val_loss # Return metric to minimize

class MockResult:
    """Minimal stand-in for an optimizer result, exposing only ``x`` and ``fun``.

    Used when every trial is already present in the loaded checkpoint and the
    optimizer is therefore not run.
    """

    def __init__(self, x, fun):
        self.x = x
        self.fun = fun


def _checkpoint_path(checkpoint_id):
    """Return the checkpoint file path for ``checkpoint_id`` inside DRIVE_DIR."""
    return os.path.join(DRIVE_DIR, f'{CHECKPOINT_BASE_NAME}_{checkpoint_id}.pkl')


def _load_history(checkpoint_id, label):
    """Try to read previous trials from a checkpoint file.

    Args:
        checkpoint_id: ID of the checkpoint to read.
        label: "specific" (user-requested ID; a missing file is reported) or
            "latest" (ID discovered by scanning the directory).

    Returns:
        ``(x_iters, func_vals)`` from the stored result, or None if the file is
        missing or cannot be loaded. Failures are printed, never raised.
    """
    checkpoint_file = _checkpoint_path(checkpoint_id)

    if label == "specific" and not os.path.exists(checkpoint_file):
        print(f"ERROR: Specific checkpoint file {checkpoint_file} not found. Starting new optimization.")
        return None

    print(f"Attempting to load {label} checkpoint from {checkpoint_file}...")
    try:
        # SECURITY NOTE: skopt.load unpickles the file; only load checkpoints
        # that were written by this notebook.
        res_loaded = load(checkpoint_file)
        x_iters = res_loaded.x_iters
        func_vals = res_loaded.func_vals
        n_previous = len(x_iters)
    except Exception as e:
        print(f"WARNING: Could not load {label} checkpoint {checkpoint_file} (corrupted or invalid format): {e}. Starting new optimization.")
        return None

    origin = "latest ID" if label == "latest" else "ID"
    print(f"Resuming from {n_previous} previous calls from {origin} {checkpoint_id}.")
    return x_iters, func_vals


if __name__ == '__main__':
    x0 = None
    y0 = None
    res = None  # explicit reset so a stale `res` from an earlier cell run is never reported
    history = None
    checkpoint_id_for_this_run = None

    # --- Decide whether to resume, and from which checkpoint ---
    if USE_CHECKPOINT:
        if DESIRED_CHECKPOINT_ID is not None:  # Specific checkpoint requested
            checkpoint_id_for_this_run = DESIRED_CHECKPOINT_ID
            history = _load_history(checkpoint_id_for_this_run, "specific")
        else:  # Latest checkpoint requested
            checkpoint_id_for_this_run = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=True)
            if checkpoint_id_for_this_run is None:
                print("No existing checkpoints found. Starting new optimization.")
            else:
                history = _load_history(checkpoint_id_for_this_run, "latest")
    else:  # New run requested
        print("USE_CHECKPOINT is False. Starting a brand new optimization.")

    if history is not None and len(history[0]) > 0:
        # Resume: keep writing to the checkpoint file that was loaded.
        x0, y0 = history
    else:
        # Nothing usable was loaded: start fresh under an unused checkpoint ID.
        checkpoint_id_for_this_run = get_checkpoint_id(CHECKPOINT_BASE_NAME, find_latest=False)
        if USE_CHECKPOINT:
            print(f"Starting new optimization with checkpoint ID {checkpoint_id_for_this_run}.")
        else:
            print(f"New optimization will use checkpoint ID {checkpoint_id_for_this_run}.")

    checkpoint_file = _checkpoint_path(checkpoint_id_for_this_run)
    previous_calls = len(x0) if x0 is not None else 0
    current_call = previous_calls  # train_model() continues counting from here

    checkpoint_callback = CheckpointSaver(checkpoint_file)

    remaining_calls = max(0, CALLS - current_call)

    print(f"Starting optimization with {remaining_calls} remaining calls (Total CALLS: {CALLS})...")
    start_time = time.time()

    if remaining_calls > 0:
        # Work out how many random starts are still missing.
        # If enough points have already been evaluated, this becomes 0.
        required_random = max(0, (CALLS // 6) - previous_calls)
        if previous_calls == 0:
            # gp_minimize needs at least one initial point when no x0 is supplied
            # (only matters when CALLS < 6).
            required_random = max(1, required_random)

        # Run Bayesian Optimization
        res = gp_minimize(train_model,
                          [(0.0, 0.5)],       # Search space for dropout
                          acq_func="EI",      # Expected Improvement
                          n_calls=remaining_calls,
                          n_initial_points=required_random,  # replaces the older n_random_starts argument
                          noise="gaussian",   # noisy objective: neural-network training is stochastic
                          random_state=SEED,
                          callback=[checkpoint_callback],
                          x0=x0,
                          y0=y0)              # Pass previous results to resume

    else:
        print(f"All {CALLS} calls already completed based on loaded checkpoint.")
        if x0 is not None and y0 is not None:
            # Build a result-like object so the best trial can still be reported
            best_idx = np.argmin(y0)
            best_dropout = x0[best_idx][0]  # dropout is the only search dimension
            best_loss = y0[best_idx]
            res = MockResult([best_dropout], best_loss)
            print(f"Best Dropout from loaded checkpoint: {res.x[0]:.4f}, Best Loss: {res.fun:.4f}")
        else:
            print("No results to display as no checkpoint was loaded and no new calls were made.")

    end_time = time.time()
    print(f"Optimization finished in {(end_time - start_time)/60:.2f} minutes.")
    if res is not None:  # set either by gp_minimize or by the loaded-checkpoint branch
        print(f"Best Dropout: {res.x[0]:.4f}, Best Loss: {res.fun:.4f}")

Contents of c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\dataset: ['Testing', 'Training']
Attempting to load latest checkpoint from c:\Users\JMN\Documents\Privat\Uddannelse\ActiveML\mini-projekt\BO_Checkpoints/dropout_optimization_0.pkl...
Starting new optimization with checkpoint ID 1.
Starting optimization with 50 remaining calls (Total CALLS: 50)...


0,1
data_loading_pct,▁
epoch,▁
train_loss,▁

0,1
data_loading_pct,2.82973
epoch,1.0
train_loss,0.27664



-------------------- Round 1/50 ----------------------------
Testing with dropout: 0.3565
Using device: cpu


                                                                                                     

KeyboardInterrupt: 