In [None]:
#| default_exp medSAM finetune

# Fine-tune MedSAM
> Fine-tune MedSAM on your own images and segmentation masks

In [42]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [43]:
#| export
# cv_tools
from cv_tools.core import *
from cv_tools.imports import *
from cv_tools.data_processing.smb_tools import *


In [44]:
#| export

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, AdamW


In [45]:
#| export
from transformers import SamModel, SamProcessor
import monai
from monai.losses import DiceCELoss, DiceLoss, FocalLoss

In [46]:
#| export
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union, Any
from dataclasses import dataclass, field
import logging
from tqdm.auto import tqdm
import json
import warnings
from statistics import mean
from fastcore.basics import *

In [47]:
#| export
import albumentations as A
from albumentations.pytorch import ToTensorV2


In [48]:
#| export
@dataclass
class SAMFineTuneConfig:
    """
    Configuration for SAM fine-tuning.

    Creating an instance has side effects (see `__post_init__`): `loss_type` is
    validated, `device="auto"` is resolved, `output_dir` is created and, if the
    root logger has no handlers yet, logging is configured to write to
    `output_dir/training.log` and to a stream handler.
    """
    
    # Model configuration
    model_name: str = "facebook/sam-vit-base"  # id passed to SamModel/SamProcessor.from_pretrained
    freeze_encoder: bool = True
    freeze_prompt_encoder: bool = True
    
    # Training configuration
    batch_size: int = 4
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    num_epochs: int = 100
    warmup_steps: int = 100
    gradient_clip_norm: float = 1.0
    
    # Loss configuration
    loss_type: str = "dice_ce"  # "dice_ce", "dice", "focal", "bce"
    dice_weight: float = 1.0
    ce_weight: float = 1.0
    focal_alpha: float = 0.25
    focal_gamma: float = 2.0
    
    # Data configuration
    image_size: int = 1024
    mask_threshold: float = 0.5  # probability threshold used to binarize predictions
    bbox_perturbation_range: int = 20  # max random margin (pixels) added to training box prompts
    
    # Augmentation configuration
    use_augmentation: bool = True
    augmentation_prob: float = 0.5
    
    # Training configuration
    device: str = "auto"  # "auto", "cuda", "cpu"
    mixed_precision: bool = True
    save_best_model: bool = True
    early_stopping_patience: int = 10
    
    # Logging and checkpointing
    log_interval: int = 10
    save_interval: int = 50
    output_dir: str = "./sam_finetune_output"
    experiment_name: str = "sam_finetune"
    
    def __post_init__(self):
        """Validate settings, resolve the device, create `output_dir` and set up logging.

        Raises ValueError if `loss_type` is not one of the supported loss names.
        """
        supported_losses = ("dice_ce", "dice", "focal", "bce")
        if self.loss_type not in supported_losses:
            raise ValueError(
                f"Unsupported loss type: {self.loss_type}; expected one of {supported_losses}"
            )

        if self.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # Create output directory
        output_dir = Path(self.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        
        # logging.basicConfig is a no-op once the root logger has handlers, but its
        # FileHandler argument would still be constructed (opening training.log)
        # and then leaked, so only build the handlers when they will be installed.
        if not logging.getLogger().handlers:
            logging.basicConfig(
                level=logging.INFO,
                format='%(asctime)s - %(levelname)s - %(message)s',
                handlers=[
                    logging.FileHandler(output_dir / "training.log"),
                    logging.StreamHandler()
                ]
            )


In [11]:
#| export
class BoundingBoxGenerator:
    """
    Turn a binary segmentation mask into a SAM box prompt `[x_min, y_min, x_max, y_max]`.

    Boxes narrower or shorter than `min_box_size` are widened around their centre
    and shifted back inside the image when they touch a border. Optionally the
    box is enlarged by a random margin of up to `perturbation_range` pixels per
    side, clipped to the mask bounds.
    """
    
    def __init__(
            self, 
            perturbation_range: int = 20, # max perturbation range (pixels added per side)
            min_box_size: int = 10, # min box size (pixels, per axis)
            ):
        self.perturbation_range = perturbation_range
        self.min_box_size = min_box_size

    def _enforce_min_size(
            self,
            lo: int, # lower coordinate of the span
            hi: int, # upper coordinate of the span
            limit: int, # extent of the image along this axis; the span stays in [0, limit]
            ) -> Tuple[int, int]: # (lo, hi), at least `min_box_size` apart when the image allows it
        """Widen [lo, hi] to `min_box_size` around its centre, keeping it inside [0, limit]."""
        if hi - lo >= self.min_box_size:
            return lo, hi
        center = (lo + hi) // 2
        lo = max(0, center - self.min_box_size // 2)
        # Derive the far edge from `lo`, then `lo` back from the clipped far edge:
        # a box touching either border is pushed inward instead of being clipped
        # below the minimum size.
        hi = min(limit, lo + self.min_box_size)
        lo = max(0, hi - self.min_box_size)
        return lo, hi
    
    def get_bounding_box(
            self, 
            mask: np.ndarray,  # binary mask (2D; pixels > 0 are foreground)
            add_perturbation: bool = True, # whether to add random perturbation to bbox
            ) -> List[int]: # bounding box coordinates [x_min, y_min, x_max, y_max] as Python ints
        """
        Generate a bounding box from a 2D segmentation mask.

        Raises TypeError if `mask` is not a numpy array, and ValueError if it is
        not 2D or has no positive pixels. Perturbation draws from numpy's global
        RNG (`np.random.seed` makes it reproducible).
        """
        if not isinstance(mask, np.ndarray):
            raise TypeError(f"Expected numpy array, got {type(mask)}")
        
        if mask.ndim != 2:
            raise ValueError(f"Expected 2D mask, got {mask.ndim}D")
        
        # Find non-zero pixels
        y_indices, x_indices = np.where(mask > 0)
        
        if len(y_indices) == 0:
            raise ValueError("Empty mask - no positive pixels found")
        
        H, W = mask.shape
        # Cast to Python ints so the box matches its List[int] contract (and is
        # JSON-serializable) instead of carrying numpy integer scalars.
        x_min, x_max = self._enforce_min_size(int(x_indices.min()), int(x_indices.max()), W)
        y_min, y_max = self._enforce_min_size(int(y_indices.min()), int(y_indices.max()), H)
        
        # Add perturbation if requested (draw order kept: x_min, x_max, y_min, y_max)
        if add_perturbation and self.perturbation_range > 0:
            x_min = max(0, x_min - np.random.randint(0, self.perturbation_range + 1))
            x_max = min(W, x_max + np.random.randint(0, self.perturbation_range + 1))
            y_min = max(0, y_min - np.random.randint(0, self.perturbation_range + 1))
            y_max = min(H, y_max + np.random.randint(0, self.perturbation_range + 1))
        
        return [int(x_min), int(y_min), int(x_max), int(y_max)]


In [12]:
#| export
class SAMDataset(Dataset):
    """
    Dataset of SAM training samples: processor outputs plus a "ground_truth_mask" tensor.

    `data_source` may be:
    - an indexable, sized collection of records (e.g. a HuggingFace dataset
      whose items carry "image" and "label")
    - a directory containing `images/` and `masks/` sub-folders

    `transform` and random box perturbation are only applied when `is_training`.
    """
    
    def __init__(
        self,
        data_source: Union[str, Path, Any],
        processor: SamProcessor,
        config: SAMFineTuneConfig,
        bbox_generator: Optional[BoundingBoxGenerator] = None,
        transform: Optional[A.Compose] = None,
        is_training: bool = True
    ):
        if bbox_generator is None:
            # Default box prompts take their jitter range from the config.
            bbox_generator = BoundingBoxGenerator(
                perturbation_range=config.bbox_perturbation_range
            )

        self.processor, self.config = processor, config
        self.bbox_generator = bbox_generator
        self.transform, self.is_training = transform, is_training

        # Materialize the records up front (see `_load_data`).
        self.data = self._load_data(data_source)

        split_name = 'training' if is_training else 'validation'
        logging.info(f"Loaded {len(self.data)} samples for {split_name}")


In [15]:
#| export
@patch
def _load_data(
    self: SAMDataset, 
    data_source: Union[str, Path, Any]
    ) -> List[Dict[str, Any]]:
    """
    Materialize `data_source` into a list of sample records.

    - `str` / `Path`: treated as a directory and delegated to
      `_load_from_directory` (records with "image_path" / "mask_path").
    - any other object with `__getitem__` and `__len__` (e.g. a HuggingFace
      dataset): every item is read eagerly into a list.

    Raises ValueError for any other type.
    """
    # Paths must be checked first: `str` itself defines __getitem__/__len__, so
    # testing the sequence protocol first turns a directory string into a list
    # of single characters (which then all fall back to dummy samples).
    if isinstance(data_source, (str, Path)):
        return self._load_from_directory(Path(data_source))

    if hasattr(data_source, '__getitem__') and hasattr(data_source, '__len__'):
        # HuggingFace dataset or similar
        return [data_source[i] for i in range(len(data_source))]

    raise ValueError(f"Unsupported data source type: {type(data_source)}")


In [None]:
#| export
@patch
def _load_from_directory(
    self: SAMDataset,
    data_dir: Path # path to data directory
    ) -> List[Dict[str, Any]]:
    """
    Pair every image in `data_dir/images` with a mask in `data_dir/masks`.

    A mask matches when it has the image's exact file name, or the image's stem
    with one of .png/.jpg/.jpeg/.tif/.tiff. Sub-directories and hidden files
    (e.g. `.DS_Store`) are ignored. Images without a mask are skipped with a
    warning.

    Returns {"image_path": str, "mask_path": str} records sorted by image path.
    Raises FileNotFoundError if either sub-directory is missing.
    """
    images_dir = data_dir / "images"
    masks_dir = data_dir / "masks"
        
    if not images_dir.exists():
        raise FileNotFoundError(f"Images directory not found: {images_dir}")
    if not masks_dir.exists():
        raise FileNotFoundError(f"Masks directory not found: {masks_dir}")

    # Path.glob("*") also yields dot-files and sub-directories; those would later
    # fail in Image.open and be silently replaced by dummy samples.
    image_files = sorted(
        p for p in images_dir.iterdir()
        if p.is_file() and not p.name.startswith('.')
    )
    data = []
        
    for img_path in image_files:
        # Look for corresponding mask
        mask_path = masks_dir / img_path.name
            
        # Try different extensions if exact match not found
        if not mask_path.is_file():
            for ext in ['.png', '.jpg', '.jpeg', '.tif', '.tiff']:
                mask_path = masks_dir / f"{img_path.stem}{ext}"
                if mask_path.is_file():
                    break
            else:
                logging.warning(f"No mask found for image: {img_path}")
                continue
            
        data.append({
            "image_path": str(img_path),
            "mask_path": str(mask_path)
        })
        
    return data
   

In [18]:
#| export
@patch
def __len__(self: SAMDataset) -> int:
    """Number of sample records held by the dataset."""
    sample_count = len(self.data)
    return sample_count


In [19]:
#| export
@patch
def _get_dummy_sample(
    self: SAMDataset,
    ) -> Dict[str, torch.Tensor]:
    """
    Build a placeholder sample used when a real item cannot be processed.

    It holds a black 256x256 RGB image run through the processor with the box
    [0, 0, 10, 10], plus an all-zero 256x256 float32 "ground_truth_mask".
    """
    side = 256
    placeholder_image = Image.new('RGB', (side, side), color='black')
    encoded = self.processor(
        placeholder_image,
        input_boxes=[[[0, 0, 10, 10]]],
        return_tensors="pt"
    )

    # Drop the leading batch axis the processor adds.
    sample = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
    sample["ground_truth_mask"] = torch.zeros((side, side), dtype=torch.float32)
    return sample

In [21]:
#| export
def _mask_to_low_res(
    mask: np.ndarray, # 2D mask with the same (H, W) as the input image
    size: int = 256,  # side of the square low-resolution mask predicted by SAM
    ) -> np.ndarray:  # float32 array of shape (size, size)
    """
    Bring a full-resolution mask onto the grid of SAM's low-resolution mask logits.

    Intended to mirror the geometry SamProcessor applies to the image: the
    longest side is scaled to the target, the aspect ratio is kept, and zero
    padding is added at the bottom/right. This follows the transformers
    SamImageProcessor; re-check it if the processor settings change.
    Nearest-neighbour resampling keeps label values intact. Masks that are
    already `size` x `size` are only cast to float32.
    """
    mask = np.array(mask, dtype=np.float32)
    h, w = mask.shape
    if (h, w) == (size, size):
        return mask
    scale = size / max(h, w)
    new_h = max(1, int(h * scale + 0.5))
    new_w = max(1, int(w * scale + 0.5))
    resized = Image.fromarray(mask).resize((new_w, new_h), resample=Image.NEAREST)
    low_res = np.zeros((size, size), dtype=np.float32)
    low_res[:new_h, :new_w] = np.asarray(resized, dtype=np.float32)
    return low_res


@patch
def __getitem__(
    self: SAMDataset, 
    idx: int
    ) -> Dict[str, torch.Tensor]:
    """
    Return one sample: processor outputs (batch axis removed) plus "ground_truth_mask".

    Steps:
    - Load the image and mask. The mask is binarized when its max exceeds 1.
    - Apply `transform` in training mode.
    - Derive a box prompt, which is perturbed in training mode, from the
      full-resolution mask.
    - Resize the mask to the 256x256 low-res grid so it matches the model's
      masks and batches collate.

    Any exception is logged and replaced by `_get_dummy_sample()`.
    """
    try:
        item = self.data[idx]
            
        # Load image and mask
        if "image" in item and "label" in item:
            # HuggingFace dataset format
            image = item["image"]
            if hasattr(image, 'convert'):
                image = image.convert('RGB')
            mask = np.array(item["label"])
        else:
            # File path format
            image = Image.open(item["image_path"]).convert('RGB')
            mask = np.array(Image.open(item["mask_path"]))
            
        # Convert to numpy for processing
        image_np = np.array(image)
            
        # Ensure mask is binary
        if mask.max() > 1:
            mask = (mask > 0).astype(np.uint8)
            
        # Apply augmentations if training
        if self.is_training and self.transform is not None:
            augmented = self.transform(image=image_np, mask=mask)
            image_np = augmented['image']
            mask = augmented['mask']
            
        # Convert back to PIL for processor
        if isinstance(image_np, np.ndarray):
            image = Image.fromarray(image_np)
            
        # Box is computed in original image coordinates; the processor rescales it.
        bbox = self.bbox_generator.get_bounding_box(
            mask, add_perturbation=self.is_training
        )
            
        # Prepare inputs for model
        inputs = self.processor(
            image, 
            input_boxes=[[bbox]], 
            return_tensors="pt"
        )
            
        # Remove batch dimension added by processor
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            
        # The trainer's loss compares this mask with the model's pred_masks,
        # which are low resolution. 256 assumes SAM's default 1024 input and
        # matches `_get_dummy_sample`; TODO confirm for other checkpoints.
        # Without this, any non-256 mask breaks the loss or batch collation.
        inputs["ground_truth_mask"] = torch.from_numpy(_mask_to_low_res(mask))
            
        return inputs
            
    except Exception as e:
        logging.error(f"Error processing item {idx}: {str(e)}")
        # Return a dummy sample to prevent training interruption
        return self._get_dummy_sample()


In [None]:
#| export
class AugmentationFactory:
    """Build albumentations pipelines for medical image / mask pairs."""
    
    @staticmethod
    def create_training_transform(config: SAMFineTuneConfig) -> A.Compose:
        """
        Training pipeline: geometric, intensity and noise/blur augmentations when
        `config.use_augmentation` is set, otherwise an empty (identity) pipeline.
        """
        if not config.use_augmentation:
            return A.Compose([])

        # Geometry: flips, 90-degree rotations and a mild shift/scale/rotate;
        # border_mode=0 is OpenCV's constant border.
        geometric = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.3),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.1,
                scale_limit=0.1,
                rotate_limit=15,
                p=0.5,
                border_mode=0
            ),
        ]

        # Intensity: brightness/contrast, gamma and local contrast equalization.
        intensity = [
            A.RandomBrightnessContrast(
                brightness_limit=0.2,
                contrast_limit=0.2,
                p=0.5
            ),
            A.RandomGamma(gamma_limit=(80, 120), p=0.3),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.3),
        ]

        # Acquisition artefacts: noise and blur.
        # NOTE(review): `var_limit` is the older GaussNoise argument; newer
        # albumentations releases use `std_range` — confirm for the installed version.
        artefacts = [
            A.GaussNoise(var_limit=(10, 50), p=0.2),
            A.GaussianBlur(blur_limit=(3, 5), p=0.2),
        ]

        return A.Compose(geometric + intensity + artefacts)
    
    @staticmethod
    def create_validation_transform() -> A.Compose:
        """Identity pipeline for validation (no augmentation)."""
        return A.Compose([])


In [22]:
#| export
class LossFactory:
    """Factory mapping `config.loss_type` to a segmentation loss module."""
    
    @staticmethod
    def create_loss(config: SAMFineTuneConfig) -> nn.Module:
        """
        Create the loss selected by `config.loss_type`.

        Every returned loss takes mask logits (the sigmoid is applied inside the
        loss) and a float target of the same shape:
        - "dice_ce": MONAI DiceCELoss. `dice_weight` and `ce_weight` scale the
          Dice and cross-entropy terms.
        - "dice": MONAI DiceLoss.
        - "focal": MONAI FocalLoss with `focal_alpha` / `focal_gamma`.
        - "bce": torch BCEWithLogitsLoss.

        Raises ValueError for any other `loss_type`.
        """
        if config.loss_type == "dice_ce":
            # DiceCELoss has no `dice_weight` argument, and its `ce_weight`/`weight`
            # argument is a per-class weight tensor, not a scalar term weight.
            # The scalar weights of the two terms are `lambda_dice` / `lambda_ce`.
            return DiceCELoss(
                sigmoid=True,  # logits in, sigmoid applied inside the loss
                squared_pred=True,  # square predictions in the Dice denominator
                reduction='mean',
                lambda_dice=config.dice_weight,
                lambda_ce=config.ce_weight,
            )
        
        elif config.loss_type == "dice":
            return DiceLoss(
                sigmoid=True,
                squared_pred=True,
                reduction='mean'
            )
        
        elif config.loss_type == "focal":
            return FocalLoss(
                alpha=config.focal_alpha, # Weight for positive class
                gamma=config.focal_gamma, # Focusing parameter
                reduction='mean'
            )
        
        elif config.loss_type == "bce":
            return nn.BCEWithLogitsLoss(reduction='mean')
        
        else:
            raise ValueError(f"Unsupported loss type: {config.loss_type}")


In [23]:
#| export

class MetricsCalculator:
    """
    Segmentation metrics pooled over a whole batch.

    Predictions are probabilities binarized at `threshold` (default 0.5). All
    pixels of all samples are pooled into a single score. `pred` and `target`
    must contain the same number of elements. Both are flattened before
    comparison, so layouts such as (B, 1, H, W) vs (B, H, W) line up
    element-wise instead of being broadcast against each other (which would
    compare every prediction with every target in the batch).
    """

    @staticmethod
    def _flatten_pair(
        pred: torch.Tensor,
        target: torch.Tensor,
        threshold: float,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Binarize `pred` at `threshold`; return it and `target` as equal-length 1D float tensors."""
        if pred.numel() != target.numel():
            raise ValueError(
                f"pred and target must have the same number of elements, "
                f"got shapes {tuple(pred.shape)} and {tuple(target.shape)}"
            )
        return (pred.reshape(-1) > threshold).float(), target.reshape(-1).float()
    
    @staticmethod
    def dice_coefficient(
        pred: torch.Tensor, 
        target: torch.Tensor, 
        smooth: float = 1e-6, # small constant to avoid division by zero
        threshold: float = 0.5, # probability above which a pixel counts as foreground
        ) -> float:
        """Calculate the Dice coefficient 2|P∩T| / (|P| + |T|)."""
        pred_bin, target_flat = MetricsCalculator._flatten_pair(pred, target, threshold)
        intersection = (pred_bin * target_flat).sum()
        union = pred_bin.sum() + target_flat.sum()
        return ((2.0 * intersection + smooth) / (union + smooth)).item()
    
    @staticmethod
    def iou_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        smooth: float = 1e-6, # small constant to avoid division by zero
        threshold: float = 0.5, # probability above which a pixel counts as foreground
        ) -> float:
        """Calculate the IoU (Jaccard) score |P∩T| / |P∪T|."""
        pred_bin, target_flat = MetricsCalculator._flatten_pair(pred, target, threshold)
        intersection = (pred_bin * target_flat).sum()
        union = pred_bin.sum() + target_flat.sum() - intersection
        return ((intersection + smooth) / (union + smooth)).item()
    
    @staticmethod
    def pixel_accuracy(
        pred: torch.Tensor,
        target: torch.Tensor,
        threshold: float = 0.5, # probability above which a pixel counts as foreground
        ) -> float:
        """Calculate the fraction of pixels whose binarized prediction equals the target."""
        pred_bin, target_flat = MetricsCalculator._flatten_pair(pred, target, threshold)
        correct = (pred_bin == target_flat).float().sum()
        return (correct / target_flat.numel()).item()


In [24]:
#| export
class SAMTrainer:
    """
    SAM trainer with comprehensive features:
    - Mixed precision training
    - Gradient clipping
    - Learning rate scheduling
    - Early stopping
    - Model checkpointing
    - Comprehensive logging
    """
    
    def __init__(self, config: SAMFineTuneConfig):
        """Load processor and model from `config.model_name` and set up the optimizer, scheduler, loss and training state."""
        self.config = config
        self.device = torch.device(config.device)
        
        # Initialize model and processor
        self.processor = SamProcessor.from_pretrained(config.model_name)
        self.model = self._setup_model()
        
        # Initialize training components
        self.optimizer = self._setup_optimizer()
        self.scheduler = self._setup_scheduler()
        self.loss_fn = LossFactory.create_loss(config)
        # Gradient scaling is a CUDA fp16 mechanism. With a CPU device (even
        # when CUDA is available) an enabled GradScaler would try to unscale
        # CPU gradients, so keep the plain fp32 update path there.
        use_grad_scaler = config.mixed_precision and self.device.type == "cuda"
        self.scaler = torch.cuda.amp.GradScaler() if use_grad_scaler else None
        
        # Training state
        self.current_epoch = 0
        self.best_loss = float('inf')
        self.patience_counter = 0
        self.training_history = []
        
        # Metrics
        self.metrics_calc = MetricsCalculator()
        
        logging.info(f"SAMTrainer initialized with device: {self.device}")
        logging.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")


In [25]:
#| export
@patch
def _setup_model(
    self: SAMTrainer,
    ) -> SamModel:
    """Load pretrained SAM, freeze the sub-modules selected in the config, and move it to the trainer's device."""
    model = SamModel.from_pretrained(self.config.model_name)

    freeze_plan = (
        (self.config.freeze_encoder, model.vision_encoder, "Vision encoder frozen"),
        (self.config.freeze_prompt_encoder, model.prompt_encoder, "Prompt encoder frozen"),
    )
    for should_freeze, submodule, message in freeze_plan:
        if not should_freeze:
            continue
        for weight in submodule.parameters():
            weight.requires_grad_(False)
        logging.info(message)

    # With both freezes on, only the mask decoder remains trainable.
    n_trainable = sum(w.numel() for w in model.parameters() if w.requires_grad)
    logging.info(f"Trainable parameters: {n_trainable:,}")

    return model.to(self.device)


In [26]:
#| export
@patch
def _setup_optimizer(
    self: SAMTrainer,
    ) -> torch.optim.Optimizer:
    """Build an AdamW optimizer over only the parameters that still require gradients."""
    cfg = self.config
    params_to_train = list(filter(lambda w: w.requires_grad, self.model.parameters()))
    return AdamW(
        params_to_train,
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8,
    )


In [27]:
#| export
@patch
def _setup_scheduler(
    self: SAMTrainer,
    ) -> torch.optim.lr_scheduler._LRScheduler:
    """Cosine-anneal the LR over `num_epochs` scheduler steps, ending at 1% of the initial rate."""
    floor_lr = self.config.learning_rate * 0.01
    cycle_length = self.config.num_epochs
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        self.optimizer,
        T_max=cycle_length,
        eta_min=floor_lr,
    )


In [28]:
#| export
@patch
def _train_one_batch(
    self: SAMTrainer,
    batch: Dict[str, torch.Tensor],
    batch_idx: int
    ) -> Dict[str, float]:
    """
    Run one optimisation step on `batch` and return its loss and metrics.

    `batch` must provide "pixel_values", "input_boxes" and "ground_truth_mask".
    Tensors are moved to `self.device`. Autocast is used when
    `config.mixed_precision` is set, and the GradScaler when one exists.
    Gradients are clipped to `gradient_clip_norm`. `batch_idx` is currently
    unused.

    Returns {'loss', 'dice', 'iou', 'accuracy'} as Python floats.
    """
    # Move batch to device
    batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v 
             for k, v in batch.items()}

    # Add the channel axis once so the loss AND the metrics see the same layout.
    # The metrics multiply pred * target, and with the un-expanded mask that
    # broadcasts every prediction against every target in the batch.
    target_masks = batch["ground_truth_mask"].unsqueeze(1)
    
    # Forward pass with mixed precision
    with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
        outputs = self.model(
            pixel_values=batch["pixel_values"],
            input_boxes=batch["input_boxes"],
            multimask_output=False
        )
        predicted_masks = outputs.pred_masks.squeeze(1)
        loss = self.loss_fn(predicted_masks, target_masks)
    
    # Clear gradients from previous iteration
    self.optimizer.zero_grad()
    
    if self.scaler is not None:
        # Scale loss to prevent gradient underflow in fp16
        self.scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to true gradient norms
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), 
            self.config.gradient_clip_norm
        )
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        # Standard fp32 training without gradient scaling
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), 
            self.config.gradient_clip_norm
        )
        self.optimizer.step()
    
    # Metrics on sigmoid probabilities, without building a graph
    with torch.no_grad():
        pred_sigmoid = torch.sigmoid(predicted_masks.float())
        dice = self.metrics_calc.dice_coefficient(pred_sigmoid, target_masks)
        iou = self.metrics_calc.iou_score(pred_sigmoid, target_masks)
        accuracy = self.metrics_calc.pixel_accuracy(pred_sigmoid, target_masks)
    
    return {
        'loss': loss.item(),
        'dice': dice,
        'iou': iou,
        'accuracy': accuracy
    }


In [29]:
#| export
@patch
def train_epoch(
    self: SAMTrainer,
    train_loader: DataLoader
    ) -> Dict[str, float]:
    """
    Make one training pass over `train_loader`.

    Returns the mean 'loss', 'dice', 'iou' and 'accuracy' over the batches,
    plus the learning rate of the first parameter group as 'lr'.
    """
    self.model.train()
    history = {name: [] for name in ('loss', 'dice', 'iou', 'accuracy')}

    bar = tqdm(train_loader, desc=f"Epoch {self.current_epoch}")
    for step, batch in enumerate(bar):
        stats = self._train_one_batch(batch, step)
        for name, values in history.items():
            values.append(stats[name])

        bar.set_postfix({
            'loss': f"{stats['loss']:.4f}",
            'dice': f"{stats['dice']:.4f}",
            'lr': f"{self.optimizer.param_groups[0]['lr']:.2e}"
        })

    summary = {name: mean(values) for name, values in history.items()}
    summary['lr'] = self.optimizer.param_groups[0]['lr']
    return summary


In [31]:
#| export
@patch
def validate(
    self: SAMTrainer,
    val_loader: DataLoader
    ) -> Dict[str, float]:
    """
    Evaluate the model on `val_loader` without gradient tracking.

    Returns mean 'val_loss', 'val_dice', 'val_iou' and 'val_accuracy' over the
    batches (statistics.mean raises StatisticsError on an empty loader).
    """
    self.model.eval()
    val_losses = []
    val_metrics = {'dice': [], 'iou': [], 'accuracy': []}
        
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v 
                     for k, v in batch.items()}
                
            outputs = self.model(
                pixel_values=batch["pixel_values"],
                input_boxes=batch["input_boxes"],
                multimask_output=False
            )
                
            predicted_masks = outputs.pred_masks.squeeze(1)
            # Same layout for the loss and the metrics; otherwise the metrics
            # broadcast every prediction against every target in the batch.
            target_masks = batch["ground_truth_mask"].unsqueeze(1)
                
            loss = self.loss_fn(predicted_masks, target_masks)
                
            pred_sigmoid = torch.sigmoid(predicted_masks)
            dice = self.metrics_calc.dice_coefficient(pred_sigmoid, target_masks)
            iou = self.metrics_calc.iou_score(pred_sigmoid, target_masks)
            accuracy = self.metrics_calc.pixel_accuracy(pred_sigmoid, target_masks)
                
            val_losses.append(loss.item())
            val_metrics['dice'].append(dice)
            val_metrics['iou'].append(iou)
            val_metrics['accuracy'].append(accuracy)
        
    return {
        'val_loss': mean(val_losses),
        'val_dice': mean(val_metrics['dice']),
        'val_iou': mean(val_metrics['iou']),
        'val_accuracy': mean(val_metrics['accuracy'])
    }
	


In [32]:
#| export
@patch
def save_checkpoint(
    self: SAMTrainer,
    epoch: int,
    is_best: bool = False
    ):
    """
    Write a checkpoint to `output_dir/checkpoint_epoch_{epoch}.pt` and, when
    `is_best`, also to `output_dir/best_model.pt`.

    The checkpoint holds the model/optimizer/scheduler state dicts, `epoch`,
    `best_loss`, `training_history` and the config. The config is stored as a
    plain dict (via `dataclasses.asdict`) rather than as a `SAMFineTuneConfig`
    object. Pickling the project class makes the file unloadable with
    `torch.load(..., weights_only=True)`, the default from PyTorch 2.6, unless
    that class is allow-listed.
    """
    from dataclasses import asdict

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler_state_dict': self.scheduler.state_dict(),
        'best_loss': self.best_loss,
        'config': asdict(self.config),
        'training_history': self.training_history
    }

    output_dir = Path(self.config.output_dir)
    # Recreate the directory in case it was removed after the config was built.
    output_dir.mkdir(parents=True, exist_ok=True)
        
    # Save regular checkpoint
    torch.save(checkpoint, output_dir / f"checkpoint_epoch_{epoch}.pt")
        
    # Save best model
    if is_best:
        torch.save(checkpoint, output_dir / "best_model.pt")
        logging.info(f"New best model saved with loss: {self.best_loss:.6f}")
 

In [33]:
#| export
@patch
def train(
    self: SAMTrainer,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader] = None
    ):
    """
    Train for up to `config.num_epochs` epochs, validating after each epoch when
    `val_loader` is given, and step the LR scheduler once per epoch.

    The monitored loss is the validation loss when available, otherwise the
    training loss. Checkpoints are written:
    - every `save_interval` epochs (if `save_interval` > 0);
    - whenever the monitored loss improves, if `save_best_model` is set;
    - once more at the end.
    Training stops early after `early_stopping_patience` epochs without
    improvement.

    Returns `self.training_history` (one merged train/val stats dict per epoch).
    """
    logging.info("Starting training...")
        
    for epoch in range(self.config.num_epochs):
        self.current_epoch = epoch
            
        # Training
        train_stats = self.train_epoch(train_loader)
            
        # Validation
        val_stats = {}
        if val_loader is not None:
            val_stats = self.validate(val_loader)
            
        # Learning rate scheduling (per epoch)
        self.scheduler.step()
            
        # Combine stats
        epoch_stats = {**train_stats, **val_stats}
        self.training_history.append(epoch_stats)
            
        # Logging
        log_msg = f"Epoch {epoch}: "
        log_msg += f"Loss: {train_stats['loss']:.6f}, "
        log_msg += f"Dice: {train_stats['dice']:.4f}, "
        log_msg += f"IoU: {train_stats['iou']:.4f}"
            
        if val_stats:
            log_msg += f", Val Loss: {val_stats['val_loss']:.6f}"
            log_msg += f", Val Dice: {val_stats['val_dice']:.4f}"
            
        logging.info(log_msg)
            
        # Model checkpointing and early stopping
        current_loss = val_stats.get('val_loss', train_stats['loss'])
        improved = current_loss < self.best_loss
            
        if improved:
            self.best_loss = current_loss
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        # best_model.pt is only written when the config asks for it.
        save_as_best = improved and self.config.save_best_model
        # A non-positive save_interval disables periodic checkpoints instead of
        # raising ZeroDivisionError in the modulo.
        periodic = self.config.save_interval > 0 and epoch % self.config.save_interval == 0
        if periodic or save_as_best:
            self.save_checkpoint(epoch, save_as_best)
            
        # Early stopping (`epoch` is 0-based, so epoch + 1 epochs have run)
        if self.patience_counter >= self.config.early_stopping_patience:
            logging.info(f"Early stopping triggered after {epoch + 1} epochs")
            break
        
    logging.info("Training completed!")
        
    # Save final model
    self.save_checkpoint(self.current_epoch, False)
        
    return self.training_history


In [34]:
#| export
class SAMInference:
    """Inference wrapper around a fine-tuned SAM checkpoint."""
    
    def __init__(self, model_path: str, config: SAMFineTuneConfig):
        """Load the processor for `config.model_name` and the fine-tuned weights from `model_path`."""
        self.config = config # configuration object
        self.device = torch.device(config.device) # device to use for inference
        
        # Load model and processor
        self.processor = SamProcessor.from_pretrained(config.model_name) # processor for image preprocessing
        self.model = self._load_model(model_path) # load model from checkpoint
        self.bbox_generator = BoundingBoxGenerator(perturbation_range=0)  # No perturbation for inference
        
        logging.info(f"SAMInference initialized with model from: {model_path}")


    def _load_model(
            self, 
            model_path: str) -> SamModel:
        """
        Build `config.model_name` and load fine-tuned weights into it (eval mode, on `self.device`).

        Accepts either a training checkpoint, i.e. a dict with a
        'model_state_dict' entry, or a bare model state dict.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. From PyTorch 2.6, torch.load defaults to
        # weights_only=True, which rejects checkpoints that pickled custom
        # objects (e.g. a config dataclass). Confirm against the installed
        # torch version.
        checkpoint = torch.load(
            model_path, 
            map_location=self.device)

        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        
        # Initialize model
        model = SamModel.from_pretrained(
            self.config.model_name
            )
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        
        return model
    


In [35]:
#| export
@patch
def predict_from_bbox(
    self: SAMInference, 
    image: Union[Image.Image, np.ndarray], # input image (PIL Image or numpy array)
    bbox: List[int], # bounding box [x_min, y_min, x_max, y_max] in input-image pixels
    return_logits: bool = False, # return raw logits instead of a thresholded {0, 1} mask
    upscale_to_input: bool = False, # resize the output to the input image size via SamProcessor.post_process_masks
) -> np.ndarray: # 2D mask (or logit map)
    """
    Predict a segmentation mask for `image`, prompted by `bbox`.

    By default the model's `pred_masks` are returned at the resolution the model
    produces them, without resizing to the input image. With
    `upscale_to_input=True` they are first mapped back to the input image size.
    Unless `return_logits` is set, sigmoid probabilities are thresholded at
    `config.mask_threshold` and returned as a float array of 0s and 1s.

    Registered on `SAMInference` with `@patch`, so `predict_from_mask` and
    `predict_batch` can call it as `self.predict_from_bbox(...)`.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
        
    # Prepare inputs
    inputs = self.processor(
        image,
        input_boxes=[[bbox]],
        return_tensors="pt"
    ).to(self.device)
        
    with torch.no_grad():
        outputs = self.model(**inputs, multimask_output=False)

        if upscale_to_input:
            # Undo the processor's resize/pad so the logits line up with `image`.
            # binarize=False keeps logits so the threshold below still applies.
            predicted_masks = self.processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
                binarize=False,
            )[0]
        else:
            predicted_masks = outputs.pred_masks.squeeze(1)
            
        if return_logits:
            return predicted_masks.cpu().numpy().squeeze()

        # Apply sigmoid and threshold
        mask_prob = torch.sigmoid(predicted_masks)
        mask = (mask_prob > self.config.mask_threshold).float()
        return mask.cpu().numpy().squeeze()
    



In [36]:
#| export
@patch
def predict_from_mask(
    self: SAMInference,
    image: Union[Image.Image, np.ndarray], # input image (PIL Image or numpy array)
    rough_mask: np.ndarray, # rough segmentation mask (H, W); pixels > 0 are foreground
    return_logits: bool = False # whether to return raw logits instead of a thresholded mask
) -> np.ndarray: # refined segmentation mask as numpy array
    """
    Refine a rough mask by prompting the model with its bounding box.

    The box is derived without perturbation by `self.bbox_generator` and passed
    to `predict_from_bbox`. Invalid masks raise whatever `get_bounding_box`
    raises: TypeError for non-arrays, ValueError for non-2D or empty masks.

    Registered on `SAMInference` with `@patch` so it is callable as a method.
    """
    bbox = self.bbox_generator.get_bounding_box(rough_mask, add_perturbation=False)
    return self.predict_from_bbox(image, bbox, return_logits)


In [38]:
#| export
@patch
def predict_batch(
    self: SAMInference,
    images: List[Union[Image.Image, np.ndarray]], # list of input images (PIL Image or numpy array)
    bboxes: List[List[int]], # one bounding box [x_min, y_min, x_max, y_max] per image
    batch_size: int = 4 # chunk size for iterating over the inputs
    ) -> List[np.ndarray]: # one segmentation mask per image, in input order
    """
    Predict one mask per (image, bbox) pair.

    Each pair is still run through `predict_from_bbox` individually;
    `batch_size` only chunks the iteration.

    Raises ValueError if `images` and `bboxes` differ in length, since zip would
    otherwise silently drop the unmatched items. Also raises ValueError if
    `batch_size` < 1.
    """
    if len(images) != len(bboxes):
        raise ValueError(
            f"images and bboxes must have the same length, got {len(images)} and {len(bboxes)}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results = []
    for start in range(0, len(images), batch_size):
        stop = start + batch_size
        chunk = zip(images[start:stop], bboxes[start:stop])
        results.extend(self.predict_from_bbox(img, bbox) for img, bbox in chunk)
    return results



In [39]:
#| export
class VisualizationUtils:
    """Matplotlib helpers for inspecting images, masks, box prompts and training curves."""

    @staticmethod
    def _as_display_array(image: Union[Image.Image, np.ndarray]) -> np.ndarray:
        """Convert `image` to a numpy array that `imshow` accepts.

        PIL images go through `np.array`. A trailing singleton channel axis (H, W, 1) is
        dropped because `imshow` rejects that shape.
        """
        array = np.array(image) if isinstance(image, Image.Image) else np.asarray(image)
        if array.ndim == 3 and array.shape[-1] == 1:
            array = array[..., 0]
        return array

    @staticmethod
    def _show_image(image: np.ndarray, ax: plt.Axes):
        """Draw `image` on `ax`, using a grayscale colormap for 2-D (single-channel) arrays."""
        # Without an explicit cmap, matplotlib renders 2-D arrays with viridis.
        ax.imshow(image, cmap="gray" if image.ndim == 2 else None)

    @staticmethod
    def show_mask(mask: np.ndarray, ax: plt.Axes, random_color: bool = False, alpha: float = 0.6):
        """Overlay `mask` on `ax` as a coloured RGBA image.

        Mask values multiply an RGBA colour (including its alpha), so 0/1 (or bool) masks
        give a solid tint and [0, 1] values give a proportional tint. Only the last two axes
        of `mask` are used as (H, W).
        """
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
        else:
            color = np.array([30/255, 144/255, 255/255, alpha])

        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

    @staticmethod
    def show_bbox(
        bbox: List[int], # bounding box [x_min, y_min, x_max, y_max]
        ax: plt.Axes, # matplotlib axes object
        color: str = 'red', # color of bounding box
        linewidth: int = 2 # width of bounding box
    ):
        """Draw an unfilled rectangle for `bbox` on `ax`."""
        x_min, y_min, x_max, y_max = bbox
        rect = plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            fill=False,
            color=color,
            linewidth=linewidth
        )
        ax.add_patch(rect)

    @staticmethod
    def compare_predictions(
        image: Union[Image.Image, np.ndarray], # input image (PIL Image or numpy array)
        ground_truth: np.ndarray, # ground truth mask (H, W)
        prediction: np.ndarray, # predicted mask (H, W)
        bbox: Optional[List[int]] = None, # bounding box [x_min, y_min, x_max, y_max]
        title: str = "Comparison"
    ):
        """Show image (+ box), ground-truth overlay and prediction overlay side by side.

        Grayscale inputs, 2-D or (H, W, 1), are displayed with a gray colormap.
        """
        image = VisualizationUtils._as_display_array(image)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image (+ prompt box)
        VisualizationUtils._show_image(image, axes[0])
        if bbox is not None:
            VisualizationUtils.show_bbox(bbox, axes[0])
        axes[0].set_title("Original Image + BBox")
        axes[0].axis("off")

        # Ground truth
        VisualizationUtils._show_image(image, axes[1])
        VisualizationUtils.show_mask(ground_truth, axes[1], alpha=0.7)
        axes[1].set_title("Ground Truth")
        axes[1].axis("off")

        # Prediction
        VisualizationUtils._show_image(image, axes[2])
        VisualizationUtils.show_mask(prediction, axes[2], random_color=True, alpha=0.7)
        axes[2].set_title("Prediction")
        axes[2].axis("off")

        plt.suptitle(title)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def _series(history: List[Dict[str, float]], key: str) -> Tuple[List[int], List[float]]:
        """Return (epoch_indices, values) for the history entries that hold a non-None `key`."""
        points = [(i, h[key]) for i, h in enumerate(history) if h.get(key) is not None]
        return [i for i, _ in points], [v for _, v in points]

    @staticmethod
    def _plot_metric(ax: plt.Axes, history: List[Dict[str, float]], key: str, label: str, title: str):
        """Plot the train curve `key` and, if present, the validation curve `val_{key}` on `ax`."""
        train_epochs, train_values = VisualizationUtils._series(history, key)
        if train_values:
            ax.plot(train_epochs, train_values, label=f'Train {label}')

        val_epochs, val_values = VisualizationUtils._series(history, f'val_{key}')
        if val_values:
            ax.plot(val_epochs, val_values, label=f'Val {label}')

        ax.set_title(title)
        if train_values or val_values:
            ax.legend()
        ax.grid(True)

    @staticmethod
    def plot_training_history(
        history: List[Dict[str, float]] # training history, one dict per epoch
    ):
        """Plot loss, Dice, IoU (train and optional `val_*`) and learning rate per epoch.

        Entries missing a key are skipped for that curve instead of raising `KeyError`.
        Does nothing for an empty history.
        """
        if not history:
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        VisualizationUtils._plot_metric(axes[0, 0], history, 'loss', 'Loss', 'Loss')
        VisualizationUtils._plot_metric(axes[0, 1], history, 'dice', 'Dice', 'Dice Score')
        VisualizationUtils._plot_metric(axes[1, 0], history, 'iou', 'IoU', 'IoU Score')

        # Learning rate (log scale)
        lr_epochs, lr_values = VisualizationUtils._series(history, 'lr')
        if lr_values:
            axes[1, 1].plot(lr_epochs, lr_values)
        axes[1, 1].set_title('Learning Rate')
        axes[1, 1].set_yscale('log')
        axes[1, 1].grid(True)

        plt.tight_layout()
        plt.show()


In [40]:
#| export
def create_sam_finetune_pipeline(
    train_data_source: Union[str, Path, Any], # training data source
    val_data_source: Optional[Union[str, Path, Any]] = None, # validation data source
    config: Optional[SAMFineTuneConfig] = None, # configuration object
    **config_kwargs
) -> Tuple[SAMTrainer, DataLoader, Optional[DataLoader]]:
    """
    Create a complete SAM fine-tuning pipeline with sensible defaults.
    
    This is the main entry point for users who want to quickly set up training.
    
    Args:
        train_data_source: Training data (HuggingFace dataset, directory path, etc.)
        val_data_source: Validation data (optional)
        config: Configuration object (optional, will create default if None)
        **config_kwargs: Additional configuration parameters. If `config` is None they are
            passed to `SAMFineTuneConfig` (unknown keys raise `TypeError`). Otherwise they
            are set on `config` in place, and unknown keys are ignored with a warning.
        
    Returns:
        Tuple of (trainer, train_loader, val_loader); val_loader is None without validation data.
        
    Example:
        ```python
        # Using HuggingFace dataset
        from datasets import load_dataset
        dataset = load_dataset("nielsr/breast-cancer", split="train")
        
        trainer, train_loader, val_loader = create_sam_finetune_pipeline(
            train_data_source=dataset,
            batch_size=4,
            num_epochs=50,
            learning_rate=1e-5
        )
        
        # Start training
        history = trainer.train(train_loader, val_loader)
        ```
        
        ```python
        # Using directory structure
        trainer, train_loader, val_loader = create_sam_finetune_pipeline(
            train_data_source="./data/train",
            val_data_source="./data/val",
            use_augmentation=True,
            mixed_precision=True
        )
        
        history = trainer.train(train_loader, val_loader)
        ```
    """
    
    # Create configuration
    if config is None:
        config = SAMFineTuneConfig(**config_kwargs)
    else:
        # Apply overrides in place; report typos instead of dropping them silently.
        unknown_keys = sorted(key for key in config_kwargs if not hasattr(config, key))
        if unknown_keys:
            warnings.warn(
                f"Ignoring unknown configuration option(s): {', '.join(unknown_keys)}",
                stacklevel=2,
            )
        for key, value in config_kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
    
    # Initialize trainer (this also initializes processor)
    trainer = SAMTrainer(config)
    
    # Create augmentation transforms
    train_transform = AugmentationFactory.create_training_transform(config)
    val_transform = AugmentationFactory.create_validation_transform()
    
    # Create datasets
    train_dataset = SAMDataset(
        data_source=train_data_source,
        processor=trainer.processor,
        config=config,
        transform=train_transform,
        is_training=True
    )
    
    val_dataset = None
    if val_data_source is not None:
        val_dataset = SAMDataset(
            data_source=val_data_source,
            processor=trainer.processor,
            config=config,
            transform=val_transform,
            is_training=False
        )
    
    # Loader settings shared by train and validation.
    # num_workers can be overridden with a `num_workers` config attribute (default 4).
    num_workers = getattr(config, "num_workers", 4)
    # Pin host memory for any CUDA device ("cuda", "cuda:0", torch.device("cuda:1"), ...).
    pin_memory = str(config.device).startswith("cuda")
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Only drop the last partial batch when at least one full batch exists; otherwise a
        # dataset smaller than batch_size would yield zero training batches per epoch.
        drop_last=len(train_dataset) >= config.batch_size
    )
    
    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False
        )
    
    logging.info("Pipeline created successfully!")
    logging.info(f"Training samples: {len(train_dataset)}")
    # `is not None`: an empty validation dataset is falsy but should still be reported.
    if val_dataset is not None:
        logging.info(f"Validation samples: {len(val_dataset)}")
    
    return trainer, train_loader, val_loader



In [None]:
#| export
def quick_train_sam(
    train_data_source: Union[str, Path, Any],
    val_data_source: Optional[Union[str, Path, Any]] = None,
    output_dir: str = "./sam_finetune_output",
    num_epochs: int = 50,
    batch_size: int = 4,
    learning_rate: float = 1e-5,
    **kwargs
) -> Dict[str, Any]:
    """
    Fine-tune SAM in one call: build the config and pipeline, train, then plot the curves.

    Args:
        train_data_source: Training data source
        val_data_source: Validation data source (optional)
        output_dir: Directory where the trainer writes its outputs
        num_epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Learning rate
        **kwargs: Any other `SAMFineTuneConfig` fields

    Returns:
        Dict with keys 'history', 'config', 'best_model_path', 'final_model_path', 'trainer'.
        Both paths are derived from `config.output_dir`: `best_model.pt` and
        `checkpoint_epoch_{trainer.current_epoch}.pt`. Their existence is not checked here.

    Example:
        ```python
        # Train SAM on your data with one line
        results = quick_train_sam(
            train_data_source="./my_data/train",
            val_data_source="./my_data/val",
            num_epochs=100,
            batch_size=8
        )
        
        # Access training history
        history = results['history']
        best_model_path = results['best_model_path']
        ```
    """
    # Explicit arguments first, then any extra config fields
    explicit_settings = dict(
        output_dir=output_dir,
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )
    config = SAMFineTuneConfig(**explicit_settings, **kwargs)

    trainer, train_loader, val_loader = create_sam_finetune_pipeline(
        train_data_source=train_data_source,
        val_data_source=val_data_source,
        config=config,
    )

    history = trainer.train(train_loader, val_loader)
    VisualizationUtils.plot_training_history(history)

    run_dir = Path(config.output_dir)
    return dict(
        history=history,
        config=config,
        best_model_path=str(run_dir / "best_model.pt"),
        final_model_path=str(run_dir / f"checkpoint_epoch_{trainer.current_epoch}.pt"),
        trainer=trainer,
    )


In [None]:
# Example 1: Quick training with HuggingFace dataset
from datasets import load_dataset

# Load a sample dataset
dataset = load_dataset("nielsr/breast-cancer", split="train")

# Hold out a small, seeded validation split so the trainer receives a val_loader.
# Without validation data, 'val_dice' can never appear and the final print would always show 'N/A'.
splits = dataset.train_test_split(test_size=0.1, seed=42)

# Quick training with minimal setup
results = quick_train_sam(
    train_data_source=splits["train"],
    val_data_source=splits["test"],
    output_dir="./sam_breast_cancer",
    num_epochs=10,  # Reduced for demo
    batch_size=2,   # Reduced for demo
    learning_rate=1e-5,
    use_augmentation=True,
    mixed_precision=True
)

history = results['history']
final_val_dice = history[-1].get('val_dice', 'N/A') if history else 'N/A'

print(f"Training completed! Best model saved at: {results['best_model_path']}")
print(f"Final validation dice score: {final_val_dice}")


In [None]:
# Example 2: Advanced training with custom configuration
config = SAMFineTuneConfig(
    model_name="facebook/sam-vit-base",
    batch_size=4,
    learning_rate=1e-5,
    num_epochs=100,
    loss_type="dice_ce",
    use_augmentation=True,
    mixed_precision=True,
    early_stopping_patience=15,
    output_dir="./advanced_sam_training"
)

# Seeded train/validation split of the dataset loaded in Example 1.
# NOTE(review): presumably early stopping and best-model selection only act when a
# val_loader is provided; confirm in SAMTrainer.train.
advanced_splits = dataset.train_test_split(test_size=0.1, seed=42)

# Create pipeline with custom config
trainer, train_loader, val_loader = create_sam_finetune_pipeline(
    train_data_source=advanced_splits["train"],
    val_data_source=advanced_splits["test"],
    config=config
)

# Train with custom monitoring
history = trainer.train(train_loader, val_loader)

# Plot results
VisualizationUtils.plot_training_history(history)


In [None]:
# Example 3: Using your own data from directories
# Assuming you have data organized as:
# my_data/
# ├── train/
# │   ├── images/
# │   │   ├── img1.jpg
# │   │   └── img2.jpg
# │   └── masks/
# │       ├── img1.png
# │       └── img2.png
# └── val/
#     ├── images/
#     └── masks/

# Set to True and adjust the paths below to train on your own data.
# (A flag is used instead of a string literal holding code; Jupyter would echo such a
# literal as the cell's output.)
RUN_DIRECTORY_EXAMPLE = False

if RUN_DIRECTORY_EXAMPLE:
    results = quick_train_sam(
        train_data_source="./my_data/train",
        val_data_source="./my_data/val",
        output_dir="./my_sam_model",
        num_epochs=50,
        batch_size=4,
        learning_rate=1e-5
    )


In [None]:
# Example 4: Inference with trained model
# Load trained model for inference. A separate name keeps the training `config`
# from Example 2 intact.
inference_config = SAMFineTuneConfig(device="cuda" if torch.cuda.is_available() else "cpu")
inference_model = SAMInference(
    model_path="./sam_breast_cancer/best_model.pt",
    config=inference_config
)

# Test inference on a sample
sample_idx = 0
sample = dataset[sample_idx]
image = sample["image"]
ground_truth = np.array(sample["label"])

# Generate bounding box from ground truth (in practice, you'd provide this)
bbox_gen = BoundingBoxGenerator(perturbation_range=0)
bbox = bbox_gen.get_bounding_box(ground_truth, add_perturbation=False)

# Predict
prediction = inference_model.predict_from_bbox(image, bbox)

# The predicted mask may come back at the model's lower output resolution.
# Resize it (nearest-neighbour keeps it binary) to the ground-truth size so the overlays line up.
if prediction.shape != ground_truth.shape[:2]:
    prediction = np.array(
        Image.fromarray(prediction.astype(np.uint8)).resize(
            (ground_truth.shape[1], ground_truth.shape[0]), resample=Image.NEAREST
        )
    )

# Visualize results
VisualizationUtils.compare_predictions(
    image=image,
    ground_truth=ground_truth,
    prediction=prediction,
    bbox=bbox,
    title="SAM Fine-tuning Results"
)


In [None]:
#| hide
# Export all `#| export` cells of this notebook into the library module.
# NOTE: this call currently fails with
#   TypeError: ExportModuleProc._default_exp_() takes 3 positional arguments but 4 were given
# The cause is the `#| default_exp medSAM finetune` directive at the top of the notebook. Its
# module name contains a space, so nbdev passes two arguments. Use a single module name,
# e.g. `#| default_exp medSAM_finetune`.
import nbdev; nbdev.nbdev_export('22_medSAM_finetune.ipynb')


In [41]:
#| hide
# NOTE(review): this cell duplicates the preceding export cell; keep only one of them.
# Its saved TypeError output comes from the `#| default_exp medSAM finetune` directive, whose
# module name contains a space (nbdev receives two arguments instead of one).
import nbdev; nbdev.nbdev_export('22_medSAM_finetune.ipynb')

TypeError: ExportModuleProc._default_exp_() takes 3 positional arguments but 4 were given