# YoloSAM: Medical Scar Detection & Segmentation Tutorial

Welcome to YoloSAM! This notebook will guide you through:
- Installation and setup
- Data preparation
- Training YOLO and SAM models
- Making predictions
- Evaluating results


In [None]:
# Install YoloSAM
!git clone https://github.com/Danialmoa/YoloSAM
%cd YoloSAM
!pip install -e . 

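# Data preparation
Organize the data in the layout below (this tutorial reads it from `../sample_data`). Each mask has the same filename as its corresponding image: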

```
data/
├── train/
│   ├── images/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   └── masks/
│       ├── image1.jpg
│       ├── image2.jpg
│       └── ...
└── val/
    ├── images/
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    └── masks/
        ├── image1.jpg
        ├── image2.jpg
        └── ...
```

In [None]:
# Fine-tune the detector: train a YOLO model (YOLO11n) on the scar dataset.
from utils.mask_to_yolo import MaskToYOLOConverter
from utils.config import YOLOConfig
from scripts.train_yolo import YOLOTrainer

# Locations used by this cell, relative to the notebook's working directory.
DATA_ROOT = "../sample_data"
CHECKPOINT_DIR = "../checkpoints"

# Step 1: convert the masks into YOLO label files, written into the same
# folder structure as the dataset. Class id 0 is the single 'scar' class.
converter = MaskToYOLOConverter(class_id=0)
converter.convert_dataset_inplace(
    base_path=DATA_ROOT,
    min_area=5,               # minimum area threshold (adjust as needed)
    splits=["train", "val"],
)

# Step 2: describe the training run.
config = YOLOConfig(
    # --- model ---
    model_type="yolo11n",
    device="cpu",
    pretrained_path=CHECKPOINT_DIR,
    # --- data ---
    dataset_path=f"{DATA_ROOT}/train",
    val_dataset_path=f"{DATA_ROOT}/val",
    class_names=["scar"],
    # --- optimisation ---
    epochs=10,
    batch_size=16,
    image_size=640,
    # NOTE(review): patience > epochs, so early stopping presumably never
    # triggers in this short demo run — confirm how YOLOConfig uses it.
    patience=50,
    # --- augmentation (tuned for medical scars) ---
    mosaic=0.9,
    mixup=0.1,
    copy_paste=0.4,
    degrees=15.0,
    hsv_v=0.3,
    # --- detection thresholds ---
    iou_threshold=0.2,
    conf_threshold=0.15,
    max_detections=2,
    # --- output naming ---
    project_name="yolo_scar_detection",
    experiment_name="enhanced_scar_detection",
    # --- Weights & Biases ("disabled" here; set to "online" to log) ---
    wandb_project="YOLO-scar-detection",
    wandb_name="scar_detection_v1",
    wandb_mode="disabled",
)

# Step 3: train.
trainer = YOLOTrainer(config)
results = trainer.train()


## Fine-tune the SAM model
There are two ways to fine-tune SAM:

1. Using prompts produced by the YOLO model
2. Using prompts (points and boxes) derived from the ground-truth mask


In [None]:
# Fine-tune SAM using prompts produced by a YOLO model.

from scripts.train_sam import TrainSAM
from utils.dataset import SAMDataset
from utils.config import SAMFinetuneConfig, SAMDatasetConfig

finetune_config = SAMFinetuneConfig(
    device='cpu',
    wandb_project='SAM_finetune',
    wandb_name='test_run',
    model_type='vit_b',
    sam_path='../checkpoints/sam_vit_b_01ec64.pth',
    num_epochs=1,
    batch_size=2,
    learning_rate=1e-5,
    weight_decay=1e-4,
    lambda_bce=0.2,
    lambda_kl=0.2,
    sigma=1,
    wandb_mode='disabled',
    num_workers=0
)

train_dataset_config = SAMDatasetConfig(
    dataset_path='../sample_data/train/',
    remove_nonscar=True,
    sample_size=2,
    point_prompt=False,          # -> If True, random points are generated based on the mask
    box_prompt=False,            # -> If True, a box prompt is generated based on the mask
    enable_direction_aug=False,  # -> If True, direction augmentation is enabled
    enable_size_aug=False,       # -> If True, size augmentation is enabled
    yolo_prompt=True,            # -> If True, prompts come from the YOLO model at yolo_model_path
    # NOTE(review): this is presumably the base yolo11n checkpoint, not the weights
    # fine-tuned in the YOLO training cell (the inference cell loads a runs/.../best.pt).
    # Confirm which one is intended.
    yolo_model_path='../checkpoints/yolo11n.pt',  # -> Path to the YOLO model
    yolo_conf_threshold=0.25,    # -> Confidence threshold for YOLO
    yolo_iou_threshold=0.45,     # -> IoU threshold for YOLO
    yolo_imgsz=640,              # -> Image size for YOLO
    image_size=1024,
    train=True
)

val_dataset_config = SAMDatasetConfig(
    dataset_path='../sample_data/val/',
    remove_nonscar=True,
    sample_size=2,
    point_prompt=False,
    box_prompt=False,
    # Disable augmentation explicitly so validation does not depend on the
    # config defaults; matches the training config above and the validation
    # config of the ground-truth-prompt cell.
    enable_direction_aug=False,
    enable_size_aug=False,
    yolo_prompt=True,
    yolo_model_path='../checkpoints/yolo11n.pt',
    yolo_conf_threshold=0.25,
    yolo_iou_threshold=0.45,
    yolo_imgsz=640,
    image_size=1024,
    train=False
)

train_dataset = SAMDataset(train_dataset_config)
val_dataset = SAMDataset(val_dataset_config)

trainer = TrainSAM(finetune_config, train_dataset, val_dataset)
trainer.train(finetune_config.num_epochs)

In [None]:
# Fine-tune SAM using point and box prompts derived from the ground-truth mask.

from scripts.train_sam import TrainSAM
from utils.dataset import SAMDataset
from utils.config import SAMFinetuneConfig, SAMDatasetConfig

finetune_config = SAMFinetuneConfig(
    device='cpu',
    wandb_project='SAM_finetune',
    wandb_name='test_run',
    model_type='vit_b',
    sam_path='../checkpoints/sam_vit_b_01ec64.pth',
    num_epochs=1,
    batch_size=2,
    learning_rate=1e-5,
    weight_decay=1e-4,
    lambda_bce=0.2,
    lambda_kl=0.2,
    sigma=1,
    wandb_mode='disabled',
    num_workers=0
)

train_dataset_config = SAMDatasetConfig(
    dataset_path='../sample_data/train/',
    remove_nonscar=True,             # -> If True, remove non-scar images
    sample_size=2,
    point_prompt=True,               # -> If True, random points are generated based on the mask
    point_prompt_types=['positive'], # -> Types of points to generate (negative / positive)
    num_points=3,                    # -> Number of points to generate
    box_prompt=True,                 # -> If True, a box prompt is generated based on the mask
    enable_direction_aug=True,       # -> If True, direction augmentation is enabled
    enable_size_aug=True,            # -> If True, size augmentation is enabled
    yolo_prompt=False,               # -> If True, prompts come from a YOLO model instead
    image_size=1024,
    train=True
)

val_dataset_config = SAMDatasetConfig(
    dataset_path='../sample_data/val/',
    remove_nonscar=True,
    sample_size=2,
    point_prompt=True,
    point_prompt_types=['positive'],
    num_points=3,
    box_prompt=True,
    enable_direction_aug=False,
    enable_size_aug=False,
    # Set explicitly, as in the training config above, so validation uses the
    # same ground-truth prompt source regardless of the config default.
    yolo_prompt=False,
    image_size=1024,
    train=False
)

train_dataset = SAMDataset(train_dataset_config)
val_dataset = SAMDataset(val_dataset_config)

trainer = TrainSAM(finetune_config, train_dataset, val_dataset)
trainer.train(finetune_config.num_epochs)

# Inference
Run the YOLO + SAM pipeline on a single validation image and plot the result.

In [None]:
from scripts.inference import YoloSAMInference
from utils.config import YoloSAMInferenceConfig
import matplotlib.pyplot as plt

# Build the YOLO + SAM inference pipeline from the two checkpoints below.
config = YoloSAMInferenceConfig(
    # NOTE(review): the "2" suffix in this run directory likely depends on how many
    # times YOLO training has been run; point it at the best.pt from your own run.
    yolo_checkpoint_path="../runs/yolo_scar_detection2/weights/best.pt",
    # NOTE(review): same file used as the starting checkpoint in the SAM
    # fine-tuning cells; swap in fine-tuned SAM weights if those were saved.
    sam_checkpoint_path="../checkpoints/sam_vit_b_01ec64.pth",
    device="cpu"
)

inference_pipeline = YoloSAMInference(config=config)

image_path = "../sample_data/val/images/Case_P004_slice_01.png"

results = inference_pipeline.predict(image_path)
output_image = inference_pipeline.visualize_results(results)

# Explicit figure/axes interface instead of the pyplot state machine.
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(output_image)
ax.set_title("YoloSAM Inference Results")
ax.axis('off')
plt.show()