In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import albumentations as A


In [65]:
%%bash

# Entries of the working directory that must survive the cleanup.
filelist=("sam_vit_b_01ec64.pth" "sample_data")

# Build one `-e NAME` argument per protected entry. Names are matched as fixed
# strings against the whole line (-F -x), so dots in file names are literal and
# no regex alternation has to be assembled by hand. (The previous printf-based
# pattern had no `|` separator, matched nothing, and so deleted everything.)
keep_args=()
for name in "${filelist[@]}"; do
    keep_args+=(-e "$name")
done

# Remove every other non-hidden entry. `-d '\n'` keeps names with spaces intact,
# `-r` skips rm entirely when nothing is left to delete, `--` guards against
# names starting with a dash.
ls | grep -Fxv "${keep_args[@]}" | xargs -r -d '\n' rm -rf --

In [66]:
!ls -R

.:


In [67]:
%%bash

# Stop at the first failing command so a failed clone is never reported as "Done".
set -euo pipefail

# Variables
REPO_URL="https://github.com/LIMAMMohamedlimam/sammed-lite.git"
CLONE_DIR="temp_repo"
TARGET_DIR="./"

# A leftover clone directory from an interrupted run would make `git clone` fail.
rm -rf "$CLONE_DIR"
git clone "$REPO_URL" "$CLONE_DIR"

# Create target directory if it doesn't exist
mkdir -p "$TARGET_DIR"

# Copy all contents (including hidden files)
cp -r "$CLONE_DIR"/. "$TARGET_DIR"/

# Delete cloned repo directory
rm -rf "$CLONE_DIR"

echo "Done: copied repo content into $TARGET_DIR"

Done: copied repo content into ./


Cloning into 'temp_repo'...


In [68]:
!ls -R

.:
lite-sammed2d.py  SAMMed2D-lite.ipynb  segment_anything

./segment_anything:
automatic_mask_generator.py  __init__.py  predictor.py
build_sam.py		     modeling	  utils

./segment_anything/modeling:
common.py	  __init__.py	   prompt_encoder.py  transformer.py
image_encoder.py  mask_decoder.py  sam.py

./segment_anything/utils:
amg.py	__init__.py  onnx.py  transforms.py


In [69]:
from segment_anything.modeling import (
    Sam,
    ImageEncoderViT,
    MaskDecoder,
    PromptEncoder,
    TwoWayTransformer,
)

In [70]:
# Select the compute device: the GPU when CUDA is usable, the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


## Adapter layer implementation


### Custom Transformer Block (Adapter_layer injected)

## Medical Image Dataset

In [None]:
import json 
import os 
import random
from skimage.measure import label, regionprops

def train_transforms(img_size, ori_h, ori_w):
    """
    Build the albumentations pipeline that brings a sample to img_size x img_size.
    Args:
        img_size: Target side length of the square output.
        ori_h: Height of the source image.
        ori_w: Width of the source image.
    Returns:
        An A.Compose that zero-pads the sample when both source dimensions are
        smaller than img_size, resizes it with nearest-neighbour interpolation
        otherwise, and finally applies ToTensorV2.
    """
    fits_inside_target = ori_h < img_size and ori_w < img_size
    if fits_inside_target:
        spatial_step = A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0))
    else:
        spatial_step = A.Resize(int(img_size), int(img_size), interpolation=cv2.INTER_NEAREST)
    return A.Compose([spatial_step, ToTensorV2(p=1.0)], p=1.)


def get_boxes_from_mask(mask, box_num=1, std = 0.1, max_pixel = 5):
    """
    Derive randomly shifted bounding-box prompts from a binary mask.
    Args:
        mask: Mask, can be a torch.Tensor or a numpy array of binary mask.
        box_num: Number of bounding boxes, default is 1.
        std: Standard deviation of the noise, default is 0.1.
        max_pixel: Maximum noise pixel value, default is 5.
    Returns:
        noise_boxes: Bounding boxes after noise perturbation, returned as a float
            torch.Tensor with one (x0, y0, x1, y1) row per box.
    Notes:
        - When the mask has at least box_num connected regions, the box_num largest
          ones are used; with fewer regions the existing boxes are repeated
          cyclically until box_num boxes are available.
        - A mask without any foreground region falls back to a single box spanning
          the whole mask instead of failing with a ZeroDivisionError.
    """
    if isinstance(mask, torch.Tensor):
        # detach + cpu so GPU tensors and tensors tracked by autograd convert cleanly
        mask = mask.detach().cpu().numpy()

    regions = regionprops(label(mask))

    if regions:
        # Enough regions: keep only the box_num largest ones (by area).
        if len(regions) >= box_num:
            regions = sorted(regions, key=lambda region: region.area, reverse=True)[:box_num]
        boxes = [tuple(region.bbox) for region in regions]
    elif box_num > 0:
        # Empty mask: use the full extent, in regionprops (row0, col0, row1, col1) order.
        height, width = mask.shape[-2:]
        boxes = [(0, 0, height, width)]
    else:
        boxes = []

    # Too few boxes: cycle through the existing ones until box_num is reached.
    if boxes and len(boxes) < box_num:
        boxes = [boxes[i % len(boxes)] for i in range(box_num)]

    # Perturb each bounding box with noise
    noise_boxes = []
    for box in boxes:
        y0, x0, y1, x1 = box
        width, height = abs(x1 - x0), abs(y1 - y0)
        # The noise amplitude scales with the shorter box side and is capped at max_pixel.
        noise_std = min(width, height) * std
        max_noise = min(max_pixel, int(noise_std * 5))
        # np.random.randint needs low < high; a non-positive amplitude means "no noise".
        if max_noise > 0:
            noise_x = np.random.randint(-max_noise, max_noise)
            noise_y = np.random.randint(-max_noise, max_noise)
        else:
            noise_x = 0
            noise_y = 0
        # Both corners receive the same offset, so the box is translated, not resized.
        x0, y0 = x0 + noise_x, y0 + noise_y
        x1, y1 = x1 + noise_x, y1 + noise_y
        noise_boxes.append((x0, y0, x1, y1))
    return torch.as_tensor(noise_boxes, dtype=torch.float)

class TrainingDataset(Dataset):
    """Image / multi-mask dataset driven by an ``image2label_{mode}.json`` index file."""

    def __init__(self, data_dir, image_size=256, mode='train', requires_name=True, point_num=1, mask_num=5):
        """
        Initializes a training dataset.
        Args:
            data_dir (str): Directory containing the ``image2label_{mode}.json`` index.
            image_size (int, optional): Desired size for the input images. Defaults to 256.
            mode (str, optional): Mode of the dataset. Defaults to 'train'.
            requires_name (bool, optional): Indicates whether to include image names in the output. Defaults to True.
            point_num (int, optional): Number of points to sample. Stored but not used yet. Defaults to 1.
            mask_num (int, optional): Number of masks to sample per image. Defaults to 5.
        """
        self.image_size = image_size
        self.requires_name = requires_name
        self.point_num = point_num
        self.mask_num = mask_num
        self.pixel_mean = [123.675, 116.28, 103.53]
        self.pixel_std = [58.395, 57.12, 57.375]

        # The index maps each image path to the list of its mask paths.
        with open(os.path.join(data_dir, f'image2label_{mode}.json'), "r") as index_file:
            dataset = json.load(index_file)

        # Image paths are rewritten from the demo layout to the local image directory.
        self.image_paths = [
            image_path.replace("data_demo/images/", "images_dir/")
            for image_path in dataset.keys()
        ]
        self.label_paths = list(dataset.values())

    def __getitem__(self, index):
        """
        Returns a sample from the dataset.
        Args:
            index (int): Index of the sample.
        Returns:
            dict: A dictionary containing the sample data:
                "image": the transformed image tensor with a leading dimension of 1,
                "label": the mask_num sampled masks, stacked, with a channel axis added,
                "boxes": the noisy box prompts of those masks, stacked,
                "name":  the image file name (only when requires_name is True).
        Raises:
            FileNotFoundError: If the image or one of the sampled masks cannot be read.
        """
        image_input = {}

        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # NOTE(review): cv2.imread yields BGR channels while pixel_mean/pixel_std look like
        # the usual RGB statistics — confirm whether a BGR->RGB conversion is intended.
        # Cast to float32 so the tensor matches the default dtype of model weights.
        image = ((image - self.pixel_mean) / self.pixel_std).astype(np.float32)

        h, w, _ = image.shape
        transforms = train_transforms(self.image_size, h, w)

        masks_list = []
        boxes_list = []
        # Sample mask_num masks with replacement from the masks listed for this image.
        mask_paths = random.choices(self.label_paths[index], k=self.mask_num)
        for mask_path in mask_paths:
            pre_mask = cv2.imread(mask_path, 0)
            if pre_mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
            # Masks stored as 0/255 are rescaled to 0/1.
            if pre_mask.max() == 255:
                pre_mask = pre_mask / 255

            augments = transforms(image=image, mask=pre_mask)
            image_tensor, mask_tensor = augments['image'], augments['mask'].to(torch.int64)

            boxes = get_boxes_from_mask(mask_tensor)

            masks_list.append(mask_tensor)
            boxes_list.append(boxes)

        mask = torch.stack(masks_list, dim=0)
        boxes = torch.stack(boxes_list, dim=0)

        image_input["image"] = image_tensor.unsqueeze(0)
        image_input["label"] = mask.unsqueeze(1)
        image_input["boxes"] = boxes

        if self.requires_name:
            image_input["name"] = image_path.split('/')[-1]
        return image_input

    def __len__(self):
        """Number of images listed in the index file."""
        return len(self.image_paths)

In [None]:
# Smoke test for the TrainingDataset class.
# NOTE(review): "./dataset" is assumed to be the directory that holds
# image2label_train.json — adjust it to the real data root.
dataset = TrainingDataset("./dataset")

# Print only shapes (or plain values) so the cell output stays readable.
for i in range(len(dataset)):
    sample = dataset[i]
    print({key: (tuple(value.shape) if hasattr(value, "shape") else value)
           for key, value in sample.items()})

# later

In [71]:
# Download the SAM ViT-B checkpoint only when no file matching sam_vit_b_* is
# already present, then list the working tree to show what is on disk.
! [ ! -f sam_vit_b_* ] &&  wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
!ls -R

--2025-11-23 13:45:34--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.171.22.13, 3.171.22.118, 3.171.22.33, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.171.22.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 375042383 (358M) [binary/octet-stream]
Saving to: ‘sam_vit_b_01ec64.pth’


2025-11-23 13:45:36 (267 MB/s) - ‘sam_vit_b_01ec64.pth’ saved [375042383/375042383]

.:
lite-sammed2d.py  SAMMed2D-lite.ipynb  sam_vit_b_01ec64.pth  segment_anything

./segment_anything:
automatic_mask_generator.py  __init__.py  predictor.py
build_sam.py		     modeling	  utils

./segment_anything/modeling:
common.py	  __init__.py	   prompt_encoder.py  transformer.py
image_encoder.py  mask_decoder.py  sam.py

./segment_anything/utils:
amg.py	__init__.py  onnx.py  transforms.py


In [72]:
# File name of the SAM ViT-B checkpoint, relative to the working directory.
sam_checkpoint = "sam_vit_b_01ec64.pth"

In [73]:
! cat segment_anything/build_sam.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from functools import partial

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
from .modeling.image_encoder import Adapted_Block 
from torch.nn import functional as F

def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
     

In [74]:
# Build the "vit_b" SAM variant from the local checkpoint.
# NOTE(review): the recorded output of this cell is
#   TypeError: _build_sam() missing 1 required positional argument: 'transformer_block'
# so the vit_b builder in segment_anything/build_sam.py presumably does not pass
# `transformer_block` to _build_sam — confirm and fix it there.
from segment_anything.build_sam import sam_model_registry
sam = sam_model_registry["vit_b"](checkpoint=sam_checkpoint)



TypeError: _build_sam() missing 1 required positional argument: 'transformer_block'

In [None]:
# Print the name of every image-encoder parameter (handy for spotting the adapter weights).
for name, _param in sam.image_encoder.named_parameters():
    print(f"name : {name}")


## Model implementation SAM-Med2D-Lite

In [None]:
class SAMMed2DLite(nn.Module):
    """Lite version of SAM-Med2D with adapter layers for medical image segmentation.

    Wraps a SAM model whose image-encoder blocks already contain the adapter
    layers. Only encoder parameters whose name contains "adapter"
    (case-insensitive), the prompt encoder and the mask decoder are trainable.
    """

    def __init__(
        self,
        sam_model: "Sam",
        embed_dim=768,
    ):
        """
        Args:
            sam_model: SAM model providing image_encoder, prompt_encoder and mask_decoder.
            embed_dim: Currently unused; kept so existing call sites stay valid.
        """
        super().__init__()
        self.sam = sam_model

        # Freeze the image encoder except for the adapter layers.
        for name, param in self.sam.image_encoder.named_parameters():
            param.requires_grad = "adapter" in name.lower()

        # Fine-tune prompt encoder and mask decoder
        for param in self.sam.prompt_encoder.parameters():
            param.requires_grad = True
        for param in self.sam.mask_decoder.parameters():
            param.requires_grad = True

    def forward(self, images, boxes):
        """
        Args:
            images : [B, 3, H, W]
            boxes : [B, 4] in xyxy format
        Returns:
            masks: mask logits bilinearly upsampled to the spatial size (H, W) of `images`.
            iou_predictions: output of the mask decoder's IoU head.
        """
        # Image encoding
        image_embeddings = self._encode_with_adapters(images)

        # Prepare prompts (boxes)
        sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
            points=None,
            boxes=boxes,
            masks=None,
        )

        # Decode masks
        low_res_masks, iou_predictions = self.sam.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        # Upscale masks to the input resolution
        masks = F.interpolate(
            low_res_masks,
            size=(images.shape[2], images.shape[3]),
            mode='bilinear',
            align_corners=False
        )

        return masks, iou_predictions

    def _encode_with_adapters(self, x):
        """Run the image encoder step by step: patch embedding, positional
        embedding, transformer blocks, neck.

        The adapters are part of the encoder blocks themselves (that is where the
        constructor looks for "adapter" parameters), so no separate adapter module
        is applied after each block.
        """
        encoder = self.sam.image_encoder
        x = encoder.patch_embed(x)

        # Add positional encoding
        if encoder.pos_embed is not None:
            x = x + encoder.pos_embed

        # Pass through the adapter-equipped transformer blocks
        for block in encoder.blocks:
            x = block(x)

        # The blocks emit channels-last features; the neck gets channels-first input.
        return encoder.neck(x.permute(0, 3, 1, 2))



## Loss Function

In [None]:
def dice_loss(pred, target, smooth=1.0):
    """
    Soft Dice loss for binary segmentation, averaged over all (batch, channel) entries.
    Args:
        pred: logits (B, 1, H, W)
        target: binary mask (B, 1, H, W)
        smooth: additive smoothing term applied to numerator and denominator
    Returns:
        Scalar tensor equal to 1 minus the mean Dice score.
    """
    spatial_axes = (2, 3)
    probs = pred.sigmoid()

    overlap = torch.sum(probs * target, dim=spatial_axes)
    mass = torch.sum(probs, dim=spatial_axes) + torch.sum(target, dim=spatial_axes)

    per_entry_dice = (2.0 * overlap + smooth) / (mass + smooth)
    return 1.0 - torch.mean(per_entry_dice)


def focal_loss(pred, target, alpha=0.25, gamma=2.0):
    """
    Focal loss for handling class imbalance.
    Args:
        pred: logits (B, 1, H, W)
        target: binary mask (B, 1, H, W); integer or boolean masks are accepted
            and converted to the dtype of `pred`
        alpha: constant weight applied to every element
        gamma: focusing exponent; larger values down-weight well-classified elements more
    Returns:
        Scalar tensor: mean of alpha * (1 - pt) ** gamma * bce over all elements.
    """
    # binary_cross_entropy_with_logits rejects integer targets, and the dataset
    # emits int64 masks, so align the target dtype with the logits first.
    target = target.to(dtype=pred.dtype)
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    # exp(-bce) recovers the probability the model assigns to the true class.
    pt = torch.exp(-bce)
    loss = alpha * (1 - pt) ** gamma * bce
    return loss.mean()


def iou_mse_loss(pred_mask, gt_mask, pred_iou):
    """
    MSE loss between predicted IoU score and true (soft) IoU.
    Args:
        pred_mask: logits (B, 1, H, W)
        gt_mask: ground-truth mask (B, 1, H, W)
        pred_iou: predicted IoU head output, (B,) or (B, 1)
    Returns:
        Scalar tensor: mean squared error between pred_iou and the soft IoU of
        sigmoid(pred_mask) against gt_mask, compared element by element.
    """
    probs = torch.sigmoid(pred_mask)

    intersection = (probs * gt_mask).sum(dim=(2, 3))
    union = probs.sum(dim=(2, 3)) + gt_mask.sum(dim=(2, 3)) - intersection

    true_iou = (intersection + 1e-6) / (union + 1e-6)
    # Flatten both sides: comparing a squeezed (B,) prediction with the (B, 1)
    # target would broadcast to (B, B) and average over unrelated pairs.
    return F.mse_loss(pred_iou.reshape(-1), true_iou.reshape(-1))


def combined_loss(pred, target, focal_w=20.0, dice_w=1.0):
    """
    Weighted sum of Dice loss and Focal loss.
    The default weights follow the paper ratio of 20 (Focal) : 1 (Dice).
    Args:
        pred: logits, forwarded unchanged to dice_loss and focal_loss
        target: ground-truth mask, forwarded unchanged to both losses
        focal_w: weight of the focal term
        dice_w: weight of the Dice term
    """
    dice_term = dice_w * dice_loss(pred, target)
    focal_term = focal_w * focal_loss(pred, target)
    return dice_term + focal_term


def total_loss_fn(pred_mask, gt_mask, pred_iou):
    """
    Full training objective: mask loss (Focal + Dice via combined_loss) plus the
    IoU-head regression loss (iou_mse_loss), summed with equal weight.
    """
    return combined_loss(pred_mask, gt_mask) + iou_mse_loss(pred_mask, gt_mask, pred_iou)


### Evaluation 

In [None]:

def compute_dice_coefficient(pred, target, threshold=0.5):
    """
    Dice coefficient between thresholded predictions and a target mask.
    Args:
        pred: logits; sigmoid is applied before thresholding
        target: ground-truth mask with the same shape as pred
        threshold: probability cut-off for a positive pixel
    Returns:
        float: the Dice score over all elements; 1.0 when both the binarised
        prediction and the target are empty.
    """
    binary = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary * target)
    denominator = torch.sum(binary) + torch.sum(target)

    if denominator != 0:
        return ((2.0 * overlap) / denominator).item()
    return 0.0 if overlap != 0 else 1.0

def compute_iou(pred, target, threshold=0.5):
    """
    Intersection over Union between thresholded predictions and a target mask.
    Args:
        pred: logits; sigmoid is applied before thresholding
        target: ground-truth mask with the same shape as pred
        threshold: probability cut-off for a positive pixel
    Returns:
        float: the IoU over all elements; 1.0 when both the binarised prediction
        and the target are empty.
    """
    binary = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary * target)
    combined = torch.sum(binary) + torch.sum(target) - overlap

    if combined != 0:
        return (overlap / combined).item()
    return 0.0 if overlap != 0 else 1.0

def evaluate_batch(model, dataloader, device):
    """
    Evaluate model on a dataset.
    Args:
        model: callable as model(images, boxes) returning (mask logits, iou predictions)
        dataloader: iterable of batch dicts holding 'image', a ground-truth mask under
            'mask' (or 'label', the key TrainingDataset emits) and box prompts under
            'bbox' (or 'boxes')
        device: device the batch tensors are moved to
    Returns:
        dict: {'dice': mean Dice, 'iou': mean IoU} over all evaluated samples;
        both are 0.0 when the dataloader yields no samples.
    """
    model.eval()
    total_dice = 0.0
    total_iou = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            images = batch['image'].to(device)
            # Accept both naming schemes used in this notebook.
            masks_gt = (batch['mask'] if 'mask' in batch else batch['label']).to(device)
            boxes = (batch['bbox'] if 'bbox' in batch else batch['boxes']).to(device)

            # Forward pass
            masks_pred, _ = model(images, boxes)

            # Accumulate per-sample metrics
            for i in range(images.shape[0]):
                total_dice += compute_dice_coefficient(masks_pred[i], masks_gt[i])
                total_iou += compute_iou(masks_pred[i], masks_gt[i])
                num_samples += 1

    # An empty dataloader must not end in a ZeroDivisionError.
    if num_samples == 0:
        return {'dice': 0.0, 'iou': 0.0}

    return {'dice': total_dice / num_samples, 'iou': total_iou / num_samples}


# Training

In [None]:

def train_epoch(model, dataloader, optimizer, device, epoch):
    """
    Train for one epoch.
    Args:
        model: callable as model(images, boxes) returning (mask logits, iou predictions)
        dataloader: iterable of batch dicts holding 'image', a ground-truth mask under
            'mask' (or 'label', the key TrainingDataset emits) and box prompts under
            'bbox' (or 'boxes')
        optimizer: optimizer over the trainable parameters of `model`
        device: device the batch tensors are moved to
        epoch: epoch number, used only for the progress-bar label
    Returns:
        float: mean loss over the batches of this epoch (0.0 for an empty dataloader).
    """
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch in pbar:
        images = batch['image'].to(device)
        # Accept both naming schemes used in this notebook.
        masks_gt = (batch['mask'] if 'mask' in batch else batch['label']).to(device)
        boxes = (batch['bbox'] if 'bbox' in batch else batch['boxes']).to(device)
        # The loss expects floating-point targets; dataset masks are integer tensors.
        masks_gt = masks_gt.float()

        # Forward pass
        # NOTE(review): the IoU prediction is not part of the loss below, so the IoU
        # head gets no training signal from this loop — total_loss_fn would include it.
        masks_pred, _iou_pred = model(images, boxes)

        # Compute loss
        loss = combined_loss(masks_pred, masks_gt)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
    return total_loss / max(len(dataloader), 1)

def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=50,
    learning_rate=1e-4,
    save_dir='checkpoints',
    device=None,
):
    """
    Complete training pipeline: AdamW + cosine annealing, validation after every
    epoch, and a checkpoint whenever the validation Dice improves.
    Args:
        model: model to train; only parameters with requires_grad=True are optimised
        train_loader: dataloader passed to train_epoch
        val_loader: dataloader passed to evaluate_batch
        num_epochs: number of epochs, also used as T_max of the cosine schedule
        learning_rate: initial learning rate of AdamW
        save_dir: directory for best_model.pth (created with parents if missing)
        device: device for the batches; when None it is taken from the model's
            first parameter (CPU if the model has no parameters)
    Returns:
        dict: history with per-epoch 'train_loss', 'val_dice' and 'val_iou' lists.
    """
    checkpoint_dir = Path(save_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Follow the model's own device instead of relying on a notebook-level global,
    # so the batches always end up where the weights are.
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Optimizer (only trainable parameters)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        weight_decay=0.01
    )

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs
    )

    best_dice = 0.0
    history = {'train_loss': [], 'val_dice': [], 'val_iou': []}

    for epoch in range(1, num_epochs + 1):
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
        history['train_loss'].append(train_loss)

        # Validate
        val_metrics = evaluate_batch(model, val_loader, device)
        history['val_dice'].append(val_metrics['dice'])
        history['val_iou'].append(val_metrics['iou'])

        # Learning rate step
        scheduler.step()

        # Print epoch summary
        print(f"\nEpoch {epoch}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Dice: {val_metrics['dice']:.4f}")
        print(f"Val IoU: {val_metrics['iou']:.4f}")

        # Save best model
        if val_metrics['dice'] > best_dice:
            best_dice = val_metrics['dice']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'dice': best_dice,
            }, str(checkpoint_dir / "best_model.pth"))
            print(f"Saved best model with Dice: {best_dice:.4f}")

    return history

In [None]:

from typing import List, Optional

class InferencePredictor:
    """Simple inference wrapper: loads an image, resizes and normalises it, runs the
    model with one box prompt and returns a mask at the original resolution."""

    def __init__(self, model, device, image_size=256):
        """
        Args:
            model: callable as model(images, boxes) returning (mask logits, iou predictions)
            device: device the input tensors are moved to
            image_size: side length of the square model input
        """
        self.model = model
        self.device = device
        self.image_size = image_size

        self.transform = A.Compose([
            A.Resize(image_size, image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    @staticmethod
    def _scale_bbox(bbox, original_size, image_size):
        """Map an xyxy box from original-image pixels to the resized square input."""
        orig_h, orig_w = original_size
        scale_x = image_size / orig_w
        scale_y = image_size / orig_h
        x1, y1, x2, y2 = bbox
        return [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]

    @staticmethod
    def _load_rgb(image_path):
        """Read an image from disk as RGB; raise FileNotFoundError if it cannot be read."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def predict(self, image_path: str, bbox: List[int], threshold=0.5):
        """
        Predict segmentation mask

        Args:
            image_path: Path to input image
            bbox: Bounding box [x1, y1, x2, y2] in pixels of the original image
            threshold: Prediction threshold
        Returns:
            (mask, iou): uint8 mask at the original image resolution and the
            model's IoU prediction as a float.
        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        self.model.eval()

        # Load and preprocess image
        image = self._load_rgb(image_path)
        original_size = image.shape[:2]

        transformed = self.transform(image=image)
        image_tensor = transformed['image'].unsqueeze(0).to(self.device)
        # The image is resized to image_size x image_size, so the box has to be
        # moved into that coordinate frame as well before it is used as a prompt.
        scaled_bbox = self._scale_bbox(bbox, original_size, self.image_size)
        bbox_tensor = torch.tensor([scaled_bbox], dtype=torch.float32).to(self.device)

        # Predict
        with torch.no_grad():
            mask_pred, iou_pred = self.model(image_tensor, bbox_tensor)
            mask_pred = torch.sigmoid(mask_pred[0, 0])

        # Post-process: binarise, then bring the mask back to the original resolution
        mask = (mask_pred > threshold).cpu().numpy().astype(np.uint8)
        mask = cv2.resize(mask, (original_size[1], original_size[0]),
                         interpolation=cv2.INTER_NEAREST)

        return mask, iou_pred.item()

    def visualize_prediction(self, image_path: str, bbox: List[int],
                            save_path: Optional[str] = None):
        """
        Visualize prediction result as three panels: input with box, predicted
        mask, and a red overlay of the mask on the input.

        Args:
            image_path: Path to input image
            bbox: Bounding box [x1, y1, x2, y2] in pixels of the original image
            save_path: When given, the figure is also written to this path
        """
        # Predict
        mask, iou = self.predict(image_path, bbox)

        # Load original image
        image = self._load_rgb(image_path)

        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image with bbox
        axes[0].imshow(image)
        rect = plt.Rectangle((bbox[0], bbox[1]), bbox[2]-bbox[0], bbox[3]-bbox[1],
                             fill=False, color='red', linewidth=2)
        axes[0].add_patch(rect)
        axes[0].set_title('Input + BBox')
        axes[0].axis('off')

        # Predicted mask
        axes[1].imshow(mask, cmap='gray')
        axes[1].set_title(f'Predicted Mask (IoU: {iou:.3f})')
        axes[1].axis('off')

        # Overlay
        overlay = image.copy()
        overlay[mask > 0] = [255, 0, 0]
        blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
        axes[2].imshow(blended)
        axes[2].set_title('Overlay')
        axes[2].axis('off')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.show()
