## Setup & Installation

**Note**: Ultralytics is already listed in `requirements.txt`, so no extra installation is needed.

SAM weights auto-download the first time a model is loaded:
- **SAM 2.1 Base** (`sam2.1_b.pt`): ~154MB - Recommended, used in this notebook
- **SAM Base** (`sam_b.pt`, original SAM): ~375MB
- **SAM Large** (`sam_l.pt`, original SAM): ~1.2GB - Higher quality
- **SAM 3** (`sam3.pt`): Requires manual download from HuggingFace

For this experiment, we use SAM 2.1 Base, which auto-downloads.

In [1]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from IPython.display import display  # used by the two-stage test cell

sys.path.insert(0, os.path.abspath('../'))

from src.segmentation import (
    load_sam_model,
    segment_image,
    get_largest_mask,
    process_dataset,
    create_segmented_dataset,
    verify_dataset_structure
)
from src.segmentation.background_removal import (
    process_image_with_sam,
    visualize_background_removal
)
from src.dataset.loaders import get_dataloaders
from src.utils.metrics import calculate_metrics

SEED = 21
DEVICE = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 32
IMG_SIZE = 224

# Directories (the working directory is the project root, as the output below shows)
NOTEBOOK_DIR = Path(os.getcwd())
RESULTS_DIR = NOTEBOOK_DIR / 'notebooks' / 'results' / 'sam_background_removal'
DATA_DIR = NOTEBOOK_DIR / 'data'
MODELS_DIR = NOTEBOOK_DIR / 'models'
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR.mkdir(parents=True, exist_ok=True)

torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Device: {DEVICE}")
print(f"Results directory: {RESULTS_DIR}")
print(f"Data directory: {DATA_DIR}")
print("Using Ultralytics SAM - weights will auto-download if needed")

Device: mps
Results directory: /Users/stahlma/Desktop/01_Studium/11_Thesis/soybean/thesis_poc/notebooks/results/sam_background_removal
Data directory: /Users/stahlma/Desktop/01_Studium/11_Thesis/soybean/thesis_poc/data
Using Ultralytics SAM - weights will auto-download if needed


## Step 1: Load SAM Model

Load SAM 2.1 Base (`sam2.1_b.pt`) through Ultralytics. The weights auto-download on first use (~154MB, per the download log below).

In [6]:
print("Loading SAM model using Ultralytics...")
print("This will auto-download weights if not cached (~154MB).\n")

# Load SAM 2.1 Base - weights auto-download on first use
model = load_sam_model('sam2.1_b.pt')

print("\n✅ SAM model loaded successfully!")
print("   Model: SAM 2.1 Base")
print("   Backend: Ultralytics")
print("   Ready for segmentation")

Loading SAM model using Ultralytics...
This will auto-download weights if not cached (~154MB).

Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_b.pt to 'sam2.1_b.pt': 100% ━━━━━━━━━━━━ 154.4MB 6.6MB/s 23.5s 23.5s<0.1s
✅ Loaded SAM model: sam2.1_b.pt

✅ SAM model loaded successfully!
   Model: SAM 2.1 Base
   Backend: Ultralytics
   Ready for segmentation


## Step 2: Test SAM on Sample Images

Visualize segmentation on a few sample images to verify it's working correctly.
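
The single-stage test prompts SAM with one foreground point, per the "center point heuristic" comment in the cell. As a minimal sketch, assuming the point is the geometric image center and that `mask_coverage` is the mask area divided by the image's total pixel count, the prompt and the reported statistics can be computed as follows. `center_point_prompt` and `mask_stats` are hypothetical helpers, not functions from `src.segmentation`; the `[[x, y]]` points with label `1` (foreground) follow the point-prompt layout accepted by Ultralytics' `SAM` predictor.

```python
import numpy as np


def center_point_prompt(width: int, height: int):
    """Hypothetical helper: one positive point at the image center.

    Returns (points, labels); label 1 marks a foreground point.
    """
    return [[width // 2, height // 2]], [1]


def mask_stats(mask: np.ndarray) -> dict:
    """Area (pixel count) and coverage (area / total pixels) of an (H, W) mask."""
    mask = mask.astype(bool)
    area = int(mask.sum())
    return {"mask_area": area, "mask_coverage": area / mask.size}


# Toy example: a 4x6 image whose left half is foreground
toy = np.zeros((4, 6), dtype=bool)
toy[:, :3] = True
points, labels = center_point_prompt(width=6, height=4)
print(points, labels)    # [[3, 2]] [1]
print(mask_stats(toy))   # {'mask_area': 12, 'mask_coverage': 0.5}
```

A single center point works when one leaf fills the middle of the frame; if the center falls on background or between leaves, SAM can return a small or empty mask.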

In [7]:
# Get sample images from MH
mh_rust_dir = DATA_DIR / 'MH-SoyaHealthVision' / 'Soyabean_Leaf_Image_Dataset' / 'Soyabean_Rust'
sample_images = list(mh_rust_dir.glob('*.[jJ][pP][gG]'))[:3]

print(f"Testing SAM on {len(sample_images)} sample images...\n")

for i, img_path in enumerate(sample_images):
    print(f"Processing: {img_path.name}")
    
    # Process with SAM (using center point heuristic)
    cleaned_image, mask, metadata = process_image_with_sam(
        str(img_path),
        model,
        background_color=(0, 0, 0)
    )
    
    print(f"  Masks found: {metadata['num_masks_found']}")
    print(f"  Mask coverage: {metadata['mask_coverage']:.1%}")
    print(f"  Mask area: {metadata['mask_area']} pixels\n")

    print(metadata)
    
    # Visualize
    image = Image.open(img_path).convert('RGB')
    vis = visualize_background_removal(
        image, mask, cleaned_image,
        save_path=RESULTS_DIR / f'sample_{i+1}_comparison.png'
    )
    if vis is not None:
        plt.show()

print("\n📊 Sample visualizations saved!")

Testing SAM on 3 sample images...

Processing: 20241006_181712.jpg



image 1/1 /Users/stahlma/Desktop/01_Studium/11_Thesis/soybean/thesis_poc/data/MH-SoyaHealthVision/Soyabean_Leaf_Image_Dataset/Soyabean_Rust/20241006_181712.jpg: 1024x1024 1 0, 1250.0ms
Speed: 27.9ms preprocess, 1250.0ms inference, 8.5ms postprocess per image at shape (1, 3, 1024, 1024)
  Masks found: 1
  Mask coverage: 41.5%
  Mask area: 3879069 pixels

{'num_masks_found': 1, 'num_masks_merged': 1, 'mask_area': 3879069, 'mask_coverage': 0.4153563027671939}
Processing: 20241006_084656.jpg

image 1/1 /Users/stahlma/Desktop/01_Studium/11_Thesis/soybean/thesis_poc/data/MH-SoyaHealthVision/Soyabean_Leaf_Image_Dataset/Soyabean_Rust/20241006_084656.jpg: 1024x1024 1 0, 1141.6ms
Speed: 7.9ms preprocess, 1141.6ms inference, 17.0ms postprocess per image at shape (1, 3, 1024, 1024)
  Masks found: 1
  Mask coverage: 20.8%
  Mask area: 1938807 pixels

{'num_masks_found': 1, 'num_masks_merged': 1, 'mask_area': 1938807, 'mask_coverage': 0.20760025338532387}
Processing: 20240928_170413.jpg

image 1/1 /

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


ValueError: zero-size array to reduction operation minimum which has no identity
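
The third image fails with `ValueError: zero-size array to reduction operation minimum which has no identity`. NumPy raises exactly this error when `.min()` is called on an empty array, and the two warning lines above it come from `np.mean` on an empty slice, so a likely cause is that SAM returned an empty mask and a later step reduced over its (empty) pixel coordinates, for instance to compute a bounding box. Where exactly that happens inside `process_image_with_sam` is not visible here. A minimal guard, using the hypothetical helper `safe_mask_bbox`, is:

```python
import numpy as np


def safe_mask_bbox(mask):
    """Return (x_min, y_min, x_max, y_max) of a boolean mask, or None if it is empty.

    Calling ys.min() on an empty coordinate array is what raises
    "zero-size array to reduction operation minimum which has no identity".
    """
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return None  # the caller decides the fallback (skip image, keep original, ...)
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


empty = np.zeros((5, 5), dtype=bool)
leaf = empty.copy()
leaf[1:3, 2:5] = True
print(safe_mask_bbox(empty))  # None
print(safe_mask_bbox(leaf))   # (2, 1, 4, 2)
```

In the test loop, a `None` result could be logged and the image skipped so that one empty mask does not abort the whole cell. For the full dataset run, skipped images would also be missing from `MH_Segmented`, which matters when the two test sets are compared later.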

## 🔥 Two-Stage Detection: YOLO + SAM

Now let's test the **two-stage approach** (YOLO detection + SAM segmentation):

1. **YOLO** detects all leaves in the image
2. **SAM** segments each detected leaf
3. All masks are merged for complete foreground extraction

This approach is intended to handle:
- ✅ Multiple leaves in one image
- ✅ Partially visible leaves at edges
- ✅ Diseased leaves with color variations
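
The merge step can be expressed as a pixel-wise logical OR over the per-leaf masks. A minimal sketch, assuming each detected leaf yields one boolean (H, W) mask at full image resolution; `merge_leaf_masks` is hypothetical and not the `src.segmentation` implementation. With Ultralytics, one plausible way to obtain the per-leaf masks is to pass the YOLO boxes to SAM as box prompts (its `SAM` predictor accepts a `bboxes` argument), but how this notebook's pipeline does it is not shown here.

```python
import numpy as np


def merge_leaf_masks(masks, image_shape):
    """OR-merge per-leaf boolean masks into one foreground mask.

    masks: iterable of (H, W) arrays, one per detected leaf (may be empty).
    image_shape: (H, W), used for the all-background result when nothing was detected.
    Returns (merged_mask, num_masks_merged).
    """
    merged = np.zeros(image_shape, dtype=bool)
    count = 0
    for m in masks:
        merged |= m.astype(bool)  # pixels covered by overlapping leaves count once
        count += 1
    return merged, count


h, w = 4, 4
leaf_a = np.zeros((h, w), dtype=bool)
leaf_a[:2, :2] = True      # 4 px
leaf_b = np.zeros((h, w), dtype=bool)
leaf_b[1:3, 1:3] = True    # 4 px, 1 of them shared with leaf_a
merged, n = merge_leaf_masks([leaf_a, leaf_b], (h, w))
print(n, int(merged.sum()))  # 2 7
```

With no detections the function returns an all-false mask instead of raising, so the caller can fall back to the single-stage center-point prompt.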

In [None]:
# Reuse the first three Rust images (the same samples as the single-stage test)
mh_rust_dir = DATA_DIR / 'MH-SoyaHealthVision' / 'Soyabean_Leaf_Image_Dataset' / 'Soyabean_Rust'
sample_images = list(mh_rust_dir.glob('*.[jJ][pP][gG]'))[:3]

print(f"Testing Two-Stage Detection on {len(sample_images)} images...\n")
print("=" * 70)

for i, img_path in enumerate(sample_images):
    print(f"\n📸 Image {i+1}: {img_path.name}")
    print("-" * 70)
    
    # Two-stage detection: YOLO + SAM
    cleaned_image, mask, metadata = process_image_with_sam(
        str(img_path),
        model,
        background_color=(0, 0, 0),
        use_yolo_detection=True  # Enable two-stage detection
    )
    
    print(f"  🔍 Stage 1 (YOLO): {metadata.get('num_leaves_detected', 'N/A')} leaves detected")
    print(f"  🎯 Stage 2 (SAM): {metadata['num_masks_found']} masks found")
    print(f"  🔗 Stage 3 (Merge): {metadata['num_masks_merged']} masks merged")
    print(f"  📊 Coverage: {metadata['mask_coverage']:.1%}")
    print(f"  📐 Mask area: {metadata['mask_area']:,} pixels")
    print(f"  🛠️  Method: {metadata.get('method', 'default')}")
    
    # Visualize
    image = Image.open(img_path).convert('RGB')
    vis = visualize_background_removal(
        image, mask, cleaned_image,
        save_path=RESULTS_DIR / f'two_stage_sample_{i+1}.png'
    )
    if vis is not None:
        display(vis)
        plt.close()

print("\n" + "=" * 70)
print("✅ Two-stage detection test complete!")
print(f"💾 Results saved to: {RESULTS_DIR}")
print("=" * 70)

### 📊 Comparison: Single-Stage vs Two-Stage

Let's compare both approaches on the same image:
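
Coverage alone does not show whether the two methods select the same pixels. One addition that can help, assuming `mask_single` and `mask_two` from the cell below are arrays of the same shape, is the intersection-over-union (IoU) between them. `mask_iou` is a hypothetical helper; any nonzero value is treated as foreground.

```python
import numpy as np


def mask_iou(a, b):
    """Intersection-over-union of two masks (1.0 if both are empty)."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(a, b).sum() / union)


single = np.zeros((4, 4), dtype=bool)
single[:, :2] = True          # 8 px
two = np.zeros((4, 4), dtype=bool)
two[:, 1:3] = True            # 8 px, 4 of them shared
print(mask_iou(single, two))  # 0.3333333333333333
```

A line such as `print(f"IoU: {mask_iou(mask_single, mask_two):.3f}")` could be added to the comparison cell; a low IoU with similar coverage would mean the two methods keep different regions of the image.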

In [None]:
# Pick one test image
test_img = sample_images[0]
print(f"Comparing approaches on: {test_img.name}\n")

# 1. Single-stage (center point)
print("1️⃣ Single-Stage SAM (center point)")
print("-" * 50)
cleaned_single, mask_single, meta_single = process_image_with_sam(
    str(test_img),
    model,
    use_yolo_detection=False
)
print(f"  • Masks found: {meta_single['num_masks_found']}")
print(f"  • Masks merged: {meta_single['num_masks_merged']}")
print(f"  • Coverage: {meta_single['mask_coverage']:.1%}\n")

# 2. Two-stage (YOLO + SAM)
print("2️⃣ Two-Stage YOLO + SAM")
print("-" * 50)
cleaned_two, mask_two, meta_two = process_image_with_sam(
    str(test_img),
    model,
    use_yolo_detection=True
)
print(f"  • Leaves detected: {meta_two.get('num_leaves_detected', 'N/A')}")
print(f"  • Masks found: {meta_two['num_masks_found']}")
print(f"  • Masks merged: {meta_two['num_masks_merged']}")
print(f"  • Coverage: {meta_two['mask_coverage']:.1%}\n")

# Visualize comparison
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
original = Image.open(test_img).convert('RGB')

# Row 1: Single-stage
axes[0, 0].imshow(original)
axes[0, 0].set_title('Original Image', fontsize=12, fontweight='bold')
axes[0, 0].axis('off')

axes[0, 1].imshow(mask_single, cmap='gray')
axes[0, 1].set_title(f'Single-Stage Mask\n({meta_single["num_masks_merged"]} masks)', 
                     fontsize=12, fontweight='bold')
axes[0, 1].axis('off')

axes[0, 2].imshow(cleaned_single)
axes[0, 2].set_title(f'Single-Stage Result\n({meta_single["mask_coverage"]*100:.1f}% coverage)', 
                     fontsize=12, fontweight='bold')
axes[0, 2].axis('off')

# Row 2: Two-stage
axes[1, 0].imshow(original)
axes[1, 0].set_title('Original Image', fontsize=12, fontweight='bold')
axes[1, 0].axis('off')

axes[1, 1].imshow(mask_two, cmap='gray')
axes[1, 1].set_title(f'Two-Stage Mask\n({meta_two["num_masks_merged"]} masks, ' +
                     f'{meta_two.get("num_leaves_detected", 0)} leaves)', 
                     fontsize=12, fontweight='bold')
axes[1, 1].axis('off')

axes[1, 2].imshow(cleaned_two)
axes[1, 2].set_title(f'Two-Stage Result\n({meta_two["mask_coverage"]*100:.1f}% coverage)', 
                     fontsize=12, fontweight='bold')
axes[1, 2].axis('off')

# Add row labels
fig.text(0.02, 0.75, '1. Single-Stage', rotation=90, fontsize=14, 
         fontweight='bold', va='center')
fig.text(0.02, 0.25, '2. Two-Stage', rotation=90, fontsize=14, 
         fontweight='bold', va='center')

plt.tight_layout(rect=[0.03, 0, 1, 1])
plt.savefig(RESULTS_DIR / 'single_vs_two_stage_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Comparison saved to: {RESULTS_DIR / 'single_vs_two_stage_comparison.png'}")

## Step 3: Create Segmented MH Dataset

Process the entire MH dataset and create `MH_Segmented` with backgrounds removed.
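
Each output image keeps the leaf pixels and paints everything else with `background_color` (black here). A minimal sketch of that compositing step for one image, assuming an RGB `uint8` array and a boolean mask of the same height and width; `apply_background` is hypothetical and may differ from what `create_segmented_dataset` does internally (for example in how it resizes masks or saves files).

```python
import numpy as np


def apply_background(image, mask, background_color=(0, 0, 0)):
    """Keep pixels where mask is True; paint the rest with background_color.

    image: (H, W, 3) uint8 RGB array; mask: (H, W) boolean array.
    """
    bg = np.array(background_color, dtype=image.dtype)
    # mask[..., None] broadcasts the (H, W) mask across the 3 color channels
    return np.where(mask[..., None], image, bg)


img = np.full((2, 2, 3), 200, dtype=np.uint8)
m = np.array([[True, False], [False, True]])
out = apply_background(img, m)
print(out[0, 0].tolist(), out[0, 1].tolist())  # [200, 200, 200] [0, 0, 0]
```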

In [None]:
# Define paths
mh_source = DATA_DIR / 'MH-SoyaHealthVision' / 'Soyabean_Leaf_Image_Dataset'
mh_output = DATA_DIR / 'MH_Segmented'

# Class folders to process
class_folders = [
    'Healthy_Soyabean',
    'Soyabean_Rust',
    'Soyabean_Frog_Leaf_Eye'
]

print("="*70)
print("CREATING SEGMENTED MH DATASET")
print("="*70)
print(f"Source: {mh_source}")
print(f"Output: {mh_output}")
print(f"Classes: {class_folders}")
print("="*70)
print("\n⚠️ This will take several minutes...\n")

# Process dataset
stats = create_segmented_dataset(
    dataset_name='MH',
    data_root=str(mh_source),
    output_root=str(mh_output),
    model=model,
    background_color=(0, 0, 0),
    class_folders=class_folders
)

print("\n✅ Segmented dataset created!")

## Step 4: Verify Segmented Dataset
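
Beyond the structural check, it is worth confirming that every source image has a segmented counterpart, since images where segmentation failed would otherwise silently shrink a class. A sketch assuming one sub-folder per class, JPEG/PNG files, and output files that keep the source file's stem (the extension may differ); `image_stems` and `missing_per_class` are hypothetical helpers, separate from `verify_dataset_structure`.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def image_stems(class_dir: Path) -> set:
    """File stems of the images in one class folder (suffix matched case-insensitively)."""
    return {p.stem for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES}


def missing_per_class(source_root, output_root, class_folders):
    """Map each class to the sorted source stems that have no output image."""
    report = {}
    for name in class_folders:
        src = image_stems(Path(source_root) / name)
        out_dir = Path(output_root) / name
        out = image_stems(out_dir) if out_dir.is_dir() else set()
        report[name] = sorted(src - out)
    return report
```

In this notebook it could be called as `missing_per_class(mh_source, mh_output, class_folders)`; an empty list for every class means no image was dropped.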

In [None]:
# Verify structure
print("Verifying segmented dataset structure...\n")
is_valid = verify_dataset_structure(str(mh_output))

if is_valid:
    print("\n✅ Dataset structure is valid!")
else:
    print("\n❌ Dataset structure is invalid!")

In [None]:
# Visualize some comparisons
from src.segmentation.batch_processing import compare_original_vs_segmented

print("\nCreating comparison visualizations...\n")

for class_name in class_folders:
    print(f"Comparing: {class_name}")
    compare_original_vs_segmented(
        original_dir=str(mh_source / class_name),
        segmented_dir=str(mh_output / class_name),
        num_samples=3,
        save_path=RESULTS_DIR / f'{class_name}_comparison.png'
    )

## Step 5: Load Baseline ResNet50

Load the ResNet50 baseline trained on the ASDID source domain (produced by `01_cnn_baseline.ipynb`).
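
The loading cell accepts either a bare `state_dict` or a dict with a `model_state_dict` key. A slightly more defensive sketch also strips the `module.` prefix that `torch.nn.DataParallel` adds to parameter names; whether `best_resnet50.pth` was saved that way is not known here, so the prefix handling is an assumption. `extract_state_dict` is hypothetical; its result would be passed to `model_resnet.load_state_dict(...)`.

```python
def extract_state_dict(checkpoint, key="model_state_dict", prefix="module."):
    """Return a plain {parameter_name: tensor} dict from a loaded checkpoint.

    Handles {key: state_dict, ...} wrappers and DataParallel-style name prefixes.
    """
    if isinstance(checkpoint, dict) and key in checkpoint:
        checkpoint = checkpoint[key]
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in checkpoint.items()
    }


# Integers stand in for tensors in this toy example
wrapped = {"epoch": 12, "model_state_dict": {"module.fc.weight": 1, "module.fc.bias": 2}}
print(extract_state_dict(wrapped))  # {'fc.weight': 1, 'fc.bias': 2}
```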

In [None]:
# Path to the ASDID-trained ResNet50 (saved under notebooks/results, like RESULTS_DIR)
pretrained_model_path = NOTEBOOK_DIR / 'notebooks' / 'results' / 'best_resnet50.pth'

if not pretrained_model_path.exists():
    raise FileNotFoundError(
        f"Pre-trained ResNet50 not found at {pretrained_model_path}. "
        "Please run notebook 01_cnn_baseline.ipynb first."
    )

print(f"✅ Found pre-trained model: {pretrained_model_path}")

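# Build the architecture only: every weight, including the 3-class head,
# is overwritten by the checkpoint below, so ImageNet weights are not needed
model_resnet = models.resnet50(weights=None)
model_resnet.fc = nn.Linear(model_resnet.fc.in_features, 3)
model_resnet = model_resnet.to(DEVICE)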

# Load pre-trained weights
checkpoint = torch.load(pretrained_model_path, map_location=DEVICE)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model_resnet.load_state_dict(checkpoint['model_state_dict'])
else:
    model_resnet.load_state_dict(checkpoint)

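# 25.6M is the count with the 1000-class ImageNet head; the 3-class head is smaller
n_params = sum(p.numel() for p in model_resnet.parameters())
print("✅ Loaded source-trained ResNet50")
print(f"Model: ResNet50 ({n_params / 1e6:.1f}M params)")
print("Source domain: ASDID")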

## Step 6: Evaluate on Original MH (Baseline)
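
`calculate_metrics` comes from `src.utils.metrics`, and its averaging mode is not shown here. For reference, a sketch of macro-averaged precision, recall, and F1 computed directly from a confusion matrix (rows = true class, columns = predicted class, the layout returned by `sklearn.metrics.confusion_matrix`). If `calculate_metrics` uses weighted or micro averaging instead, its numbers will differ on an imbalanced test set. `macro_scores` is hypothetical.

```python
def macro_scores(cm):
    """Accuracy and macro precision/recall/F1 from a square confusion matrix.

    cm[i][j] = number of samples with true class i predicted as class j.
    Per-class scores with a zero denominator count as 0.
    """
    k = len(cm)
    total = sum(sum(row) for row in cm)
    precisions, recalls, f1s = [], [], []
    for c in range(k):
        tp = cm[c][c]
        pred_c = sum(cm[r][c] for r in range(k))  # column sum: predicted as c
        true_c = sum(cm[c])                        # row sum: truly c
        p = tp / pred_c if pred_c else 0.0
        r = tp / true_c if true_c else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    accuracy = sum(cm[c][c] for c in range(k)) / total if total else 0.0
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / k,
        "recall": sum(recalls) / k,
        "f1": sum(f1s) / k,  # mean of per-class F1, not F1 of mean P and R
    }


cm = [[8, 2, 0],
      [1, 9, 0],
      [0, 5, 5]]
print(macro_scores(cm))
```

Running `macro_scores(original_cm.tolist())` after the evaluation cell and comparing it with `original_metrics` would show which averaging `calculate_metrics` uses.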

In [None]:
# Standard evaluation transforms
eval_transforms = T.Compose([
    T.Resize(int(IMG_SIZE * 1.14)),  # shorter side resized to 255 px before the 224 px crop
    T.CenterCrop(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load original MH dataset
print("Loading original MH dataset...")
_, _, mh_test_loader, mh_dataset, _, _, _ = get_dataloaders(
    dataset_name='MH',
    data_root=str(mh_source),
    batch_size=BATCH_SIZE,
    train_transform=eval_transforms,
    test_transform=eval_transforms,
    seed=SEED
)

print(f"MH Test Set: {len(mh_test_loader.dataset)} images")
print(f"Classes: {list(mh_dataset.class_to_idx.keys())}")

In [None]:
def evaluate_model(model, test_loader, device=DEVICE):
    """Evaluate model and return detailed results"""
    model.eval()
    all_labels = []
    all_preds = []
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    
    # Compute metrics
    from sklearn.metrics import confusion_matrix
    metrics = calculate_metrics(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    
    return metrics, all_labels, all_preds, cm

print("="*70)
print("Evaluating on ORIGINAL MH (No Background Removal)")
print("="*70)

original_metrics, original_labels, original_preds, original_cm = evaluate_model(
    model_resnet, mh_test_loader
)

print(f"\nOriginal MH Performance:")
print(f"  Accuracy:  {original_metrics['accuracy']:.4f}")
print(f"  Precision: {original_metrics['precision']:.4f}")
print(f"  Recall:    {original_metrics['recall']:.4f}")
print(f"  F1 Score:  {original_metrics['f1']:.4f}")

print(f"\nConfusion Matrix:")
print(original_cm)

# Store baseline for comparison
baseline_f1 = original_metrics['f1']

## Step 7: Evaluate on Segmented MH
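
The segmented images go through the same `eval_transforms`. One consequence to keep in mind: after `T.ToTensor()` and `T.Normalize` with the ImageNet statistics, a pure-black background pixel does not map to zero but to a fixed, strongly negative value in every channel, so the masked region becomes a constant input pattern rather than "no signal". The arithmetic is `(0 - mean) / std` per channel; `normalized_value` is a hypothetical helper that reuses the mean and std from `eval_transforms`.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # same values as in eval_transforms
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalized_value(rgb_uint8):
    """Value one RGB pixel takes after ToTensor() and Normalize(mean, std)."""
    return tuple(
        (c / 255.0 - m) / s  # ToTensor scales to [0, 1]; Normalize standardizes
        for c, m, s in zip(rgb_uint8, IMAGENET_MEAN, IMAGENET_STD)
    )


black = normalized_value((0, 0, 0))
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
```

By contrast, a background color close to the ImageNet mean, roughly (124, 116, 104), normalizes to approximately zero; whether that changes the classifier's behavior would have to be tested.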

In [None]:
# Load segmented MH dataset
print("Loading segmented MH dataset...")
_, _, mh_seg_test_loader, mh_seg_dataset, _, _, _ = get_dataloaders(
    dataset_name='MH',
    data_root=str(mh_output),
    batch_size=BATCH_SIZE,
    train_transform=eval_transforms,
    test_transform=eval_transforms,
    seed=SEED
)

print(f"MH Segmented Test Set: {len(mh_seg_test_loader.dataset)} images")
print(f"Classes: {list(mh_seg_dataset.class_to_idx.keys())}")

In [None]:
print("="*70)
print("Evaluating on SEGMENTED MH (Background Removed)")
print("="*70)

segmented_metrics, segmented_labels, segmented_preds, segmented_cm = evaluate_model(
    model_resnet, mh_seg_test_loader
)

print(f"\nSegmented MH Performance:")
print(f"  Accuracy:  {segmented_metrics['accuracy']:.4f}")
print(f"  Precision: {segmented_metrics['precision']:.4f}")
print(f"  Recall:    {segmented_metrics['recall']:.4f}")
print(f"  F1 Score:  {segmented_metrics['f1']:.4f}")

print(f"\nConfusion Matrix:")
print(segmented_cm)

# Compare to baseline
f1_improvement = segmented_metrics['f1'] - baseline_f1
print(f"\nImprovement over baseline: {f1_improvement:+.4f} ({f1_improvement/baseline_f1*100:+.1f}%)")

## Step 8: Comparison & Analysis

In [None]:
import seaborn as sns

# Create summary table
summary_data = {
    'Method': ['Original MH', 'Segmented MH (SAM)'],
    'F1 Score': [original_metrics['f1'], segmented_metrics['f1']],
    'Accuracy': [original_metrics['accuracy'], segmented_metrics['accuracy']],
    'Precision': [original_metrics['precision'], segmented_metrics['precision']],
    'Recall': [original_metrics['recall'], segmented_metrics['recall']],
    'F1 Improvement': [0.0, f1_improvement]
}

df_summary = pd.DataFrame(summary_data)

print("\n" + "="*80)
print("SAM BACKGROUND REMOVAL - FINAL RESULTS")
print("="*80)
print(df_summary.to_string(index=False))
print("="*80)

# Save summary
df_summary.to_csv(RESULTS_DIR / 'sam_segmentation_summary.csv', index=False)
print(f"\n💾 Summary saved to: {RESULTS_DIR / 'sam_segmentation_summary.csv'}")

In [None]:
# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. F1 Score Comparison
methods = ['Original\nMH', 'Segmented\nMH (SAM)']
f1_scores = [original_metrics['f1'], segmented_metrics['f1']]
colors = ['#6C757D', '#28A745' if f1_improvement > 0 else '#DC3545']

bars = axes[0].bar(methods, f1_scores, color=colors, alpha=0.8, edgecolor='black', linewidth=2)
axes[0].set_ylabel('F1 Score', fontsize=12, fontweight='bold')
axes[0].set_title('F1 Score Comparison', fontsize=13, fontweight='bold')
axes[0].set_ylim([0, 1])
axes[0].grid(axis='y', alpha=0.3)

# Add values on bars
for bar, score in zip(bars, f1_scores):
    height = bar.get_height()
    axes[0].text(bar.get_x() + bar.get_width()/2., height + 0.02,
                f'{score:.4f}', ha='center', fontweight='bold', fontsize=11)

# 2. Confusion Matrix - Original
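# Class names in label-index order, taken from the dataset instead of hard-coded
idx_to_class = {idx: name for name, idx in mh_dataset.class_to_idx.items()}
class_names = [idx_to_class[i] for i in range(len(idx_to_class))]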
sns.heatmap(original_cm, annot=True, fmt='d', cmap='Blues', ax=axes[1],
           xticklabels=class_names, yticklabels=class_names, cbar=True)
axes[1].set_title(f'Original MH\nF1: {original_metrics["f1"]:.4f}', fontsize=12, fontweight='bold')
axes[1].set_ylabel('True Label')
axes[1].set_xlabel('Predicted Label')

# 3. Confusion Matrix - Segmented
sns.heatmap(segmented_cm, annot=True, fmt='d', cmap='Greens', ax=axes[2],
           xticklabels=class_names, yticklabels=class_names, cbar=True)
axes[2].set_title(f'Segmented MH (SAM)\nF1: {segmented_metrics["f1"]:.4f}', fontsize=12, fontweight='bold')
axes[2].set_ylabel('True Label')
axes[2].set_xlabel('Predicted Label')

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'sam_segmentation_results.png', dpi=300, bbox_inches='tight')
plt.show()

## Step 9: Analysis & Conclusions

In [None]:
print("\n" + "="*80)
print("KEY FINDINGS")
print("="*80)

print("\n📊 Performance Comparison:")
print(f"   Original MH:    F1 = {original_metrics['f1']:.4f}")
print(f"   Segmented MH:   F1 = {segmented_metrics['f1']:.4f}")
print(f"   Improvement:    {f1_improvement:+.4f} ({f1_improvement/baseline_f1*100:+.1f}%)")

print("\n🎯 Hypothesis Testing:")
if f1_improvement > 0.10:
    print("   ✅ HYPOTHESIS STRONGLY CONFIRMED")
    print("   → Background is a MAJOR source of domain shift")
    print("   → SAM-based segmentation provides dramatic improvement")
    print("   → Background removal should be a standard preprocessing step")
elif f1_improvement > 0.05:
    print("   ✅ HYPOTHESIS CONFIRMED")
    print("   → Background contributes significantly to domain shift")
    print("   → SAM segmentation provides measurable benefit")
    print("   → Consider combining with other techniques")
elif f1_improvement > 0:
    print("   ⚠️ HYPOTHESIS PARTIALLY CONFIRMED")
    print("   → Background has modest impact on domain shift")
    print("   → Other factors (lesion appearance, lighting) may dominate")
    print("   → Consider hybrid approaches")
else:
    print("   ❌ HYPOTHESIS NOT CONFIRMED")
    print("   → Background is NOT the primary confounder")
    print("   → Domain shift is driven by other factors")
    print("   → Focus on feature-level or semantic adaptation")

# Analyze confusion matrix changes
print("\n📈 Confusion Matrix Analysis:")
for i in range(len(class_names)):
    orig_correct = original_cm[i, i]
    seg_correct = segmented_cm[i, i]
    orig_total = original_cm[i, :].sum()
    seg_total = segmented_cm[i, :].sum()
    
    orig_acc = orig_correct / orig_total if orig_total > 0 else 0
    seg_acc = seg_correct / seg_total if seg_total > 0 else 0
    
    improvement = seg_acc - orig_acc
    
    print(f"   {class_names[i]}:")
    print(f"      Original: {orig_acc:.1%} ({orig_correct}/{orig_total})")
    print(f"      Segmented: {seg_acc:.1%} ({seg_correct}/{seg_total})")
    print(f"      Change: {improvement:+.1%}")

print("\n💡 Thesis Implications:")
if f1_improvement > 0.05:
    print("   • SAM-based preprocessing is effective for cross-domain transfer")
    print("   • Background removal should be standard in deployment pipelines")
    print("   • Consider combining with few-shot learning or input alignment")
    print(f"   • Trade-off: {f1_improvement:.4f} F1 gain at the cost of one SAM pass per image")
else:
    print("   • Background is not the primary domain shift factor")
    print("   • Focus on disease-specific features and lesion appearance")
    print("   • Consider fine-tuning with labeled target data")

print("\n🔬 Next Steps:")
if f1_improvement > 0.05:
    print("   1. Combine SAM segmentation with few-shot learning (notebook 09)")
    print("   2. Test SAM + Input alignment (notebook 08) hybrid")
    print("   3. Deploy SAM preprocessing in production pipeline")
    print("   4. Test on other domain pairs (UAV data, different locations)")
else:
    print("   1. Analyze failure cases: which images still fail after segmentation?")
    print("   2. Investigate lesion appearance differences (texture, color)")
    print("   3. Consider few-shot learning as primary solution")
    print("   4. Test domain-specific fine-tuning")

print("\n" + "="*80)

## Optional: Failure Case Analysis

Analyze images where segmentation didn't help or made things worse.
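
The counts `len(better_indices)` and `len(worse_indices)` are the discordant pairs of a paired comparison on the same test images, so they can be fed to an exact McNemar test to gauge whether the net change is larger than chance. This is an added analysis, not part of the original pipeline; it assumes both prediction lists are aligned image by image (same test images, same order), and `mcnemar_exact` is a hypothetical helper. statsmodels offers an equivalent `mcnemar` function in `statsmodels.stats.contingency_tables`.

```python
from math import comb


def mcnemar_exact(b: int, c: int) -> float:
    """Two-sided exact McNemar p-value from the discordant counts.

    b: images only the segmented pipeline classifies correctly (improved).
    c: images only the original pipeline classifies correctly (worsened).
    Under H0 each discordant image is a fair coin flip, so min(b, c) ~ Binomial(b + c, 0.5).
    """
    n = b + c
    if n == 0:
        return 1.0
    k = min(b, c)
    tail = sum(comb(n, i) for i in range(k + 1)) / 2 ** n
    return min(1.0, 2 * tail)


print(mcnemar_exact(10, 2))
```

Applied here, `mcnemar_exact(len(better_indices), len(worse_indices))` gives a p-value for the "net improvement" printed at the end of the cell.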

In [None]:
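# Index-wise comparison is only meaningful if both test loaders yield the same
# images in the same order (same split seed, no shuffling, no images dropped).
# Matching labels is a necessary (not sufficient) check for that.
assert len(original_labels) == len(segmented_labels), "Test sets differ in size"
assert np.array_equal(original_labels, segmented_labels), "Test sets are not aligned"

# Find cases where prediction changed from correct to incorrect
orig_correct = np.array(original_labels) == np.array(original_preds)
seg_correct = np.array(segmented_labels) == np.array(segmented_preds)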

# Cases that got worse
worse_indices = np.where(orig_correct & ~seg_correct)[0]
print(f"\nFound {len(worse_indices)} cases where segmentation made predictions worse")

# Cases that improved
better_indices = np.where(~orig_correct & seg_correct)[0]
print(f"Found {len(better_indices)} cases where segmentation improved predictions")

print(f"\nNet improvement: {len(better_indices) - len(worse_indices)} images")