# In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp
import optuna

# --- Import from your project's src files ---
# This assumes your notebook is in the 'notebooks' folder
import sys
sys.path.append('..')
from src import config
from src.dataset import HubmapWsiDataset # Using the openslide dataset
from src.engine import train_epoch, validate_epoch

# --- Optuna Objective Function ---
def objective(trial, seed=42):
    """Train a U-Net for one quick epoch with sampled hyperparameters; return val Dice.

    Args:
        trial: Optuna trial used to sample ``lr`` (log-uniform in
            [1e-5, 1e-3]) and ``dice_weight`` (uniform in [0.5, 0.9]).
        seed: Seed passed to ``torch.manual_seed`` at the start of every
            trial. The train/val split is already fixed by
            ``random_state=42``; seeding torch as well makes decoder
            initialization and DataLoader shuffle order identical across
            trials, so score differences reflect the hyperparameters rather
            than run-to-run noise. (cuDNN kernels may still be
            nondeterministic on GPU.) Optuna calls ``objective(trial)``, so
            the default is used unless overridden via a wrapper.

    Returns:
        The average validation Dice returned by ``validate_epoch``; this is
        the value the Optuna study optimizes.
    """
    # Hyperparameters to tune
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    dice_weight = trial.suggest_float("dice_weight", 0.5, 0.9)

    # torch.manual_seed seeds the RNG on all devices; model construction and
    # the shuffling DataLoader below both draw from it.
    torch.manual_seed(seed)

    # Data & model setup (fixed split so every trial sees the same data)
    df = pd.read_csv(config.METADATA_FILE)
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    train_dataset = HubmapWsiDataset(train_df, augmentations=None)  # No augs for quick tuning
    val_dataset = HubmapWsiDataset(val_df, augmentations=None)
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = smp.Unet(config.MODEL_ENCODER, encoder_weights="imagenet", in_channels=3, classes=1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Weighted blend of Dice and Focal loss; dice_weight is a tuned hyperparameter.
    dice_loss = smp.losses.DiceLoss(mode='binary')
    focal_loss = smp.losses.FocalLoss(mode='binary')

    def combined_loss(y_pred, y_true):
        return dice_weight * dice_loss(y_pred, y_true) + (1 - dice_weight) * focal_loss(y_pred, y_true)

    # Simplified training loop: a single epoch per trial to keep tuning fast.
    train_epoch(train_loader, model, optimizer, combined_loss, device)
    _, avg_val_dice = validate_epoch(val_loader, model, combined_loss, device)

    return avg_val_dice

# --- Run the Study ---
# Maximize the validation Dice returned by `objective` over 20 Optuna trials,
# then report the best trial's score and its sampled hyperparameters.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)

print("\n--- Optimization Complete ---")
print("Best trial:")
trial = study.best_trial  # keep the original module-level name in case later cells use it
print(f"  Value (Dice Score): {trial.value:.4f}")
print("  Params: ")
param_lines = (f"    {name}: {setting}" for name, setting in trial.params.items())
for line in param_lines:
    print(line)