In [1]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt



In [2]:
%%bash
# Reset the working directory: delete every top-level entry (as listed by
# `ls`, so hidden files are untouched) except the ones in `filelist`.
# DESTRUCTIVE: uses rm -rf.
set -euo pipefail

# Keep the SAM checkpoint, Colab's sample_data and the mounted Google Drive
# folder -- the datasets are read from drive/MyDrive/... later on.
filelist=("sam_vit_b_01ec64.pth" "sample_data" "drive")

# One `-e NAME` per kept entry. With -F -x, grep compares each name literally
# against the whole line, so '.' is not a regex wildcard and partial names
# never match.
keep_args=()
for name in "${filelist[@]}"; do
  keep_args+=(-e "$name")
done

# grep exits 1 when nothing is left to delete -- that is not an error here.
# xargs: -d '\n' keeps names containing spaces intact; -r skips rm on empty input.
ls | { grep -Fxv "${keep_args[@]}" || true; } | xargs -r -d '\n' rm -rf --

In [3]:
# List the working directory recursively to check what the cleanup left behind.
!ls -R

.:


In [4]:
%%bash
# Fetch the sammed-lite sources and flatten them into the working directory,
# where the following cells import modules such as segment_anything from ./.
# Abort on the first failure so "Done" is only printed after a real copy.
set -euo pipefail

# Variables
REPO_URL="https://github.com/LIMAMMohamedlimam/sammed-lite.git"
CLONE_DIR="temp_repo"
TARGET_DIR="./"

# Remove a leftover clone from an interrupted run: git refuses to clone into
# an existing non-empty directory.
rm -rf "$CLONE_DIR"
git clone "$REPO_URL" "$CLONE_DIR"

# Create target directory if it doesn't exist
mkdir -p "$TARGET_DIR"

# Copy all contents (including hidden files such as .git)
cp -r "$CLONE_DIR"/. "$TARGET_DIR"/

# Delete cloned repo directory
rm -rf "$CLONE_DIR"

echo "Done: copied repo content into $TARGET_DIR"

Done: copied repo content into ./


Cloning into 'temp_repo'...


In [5]:
# List the working directory recursively to confirm the repository files were copied in.
!ls -R

.:
lite-sammed2d.py  SAMMed2D-lite.ipynb  segment_anything

./segment_anything:
automatic_mask_generator.py  __init__.py  predictor.py
build_sam.py		     modeling	  utils

./segment_anything/modeling:
common.py	  __init__.py	   prompt_encoder.py  transformer.py
image_encoder.py  mask_decoder.py  sam.py

./segment_anything/utils:
amg.py	__init__.py  onnx.py  transforms.py


In [7]:
# Select the compute device: the GPU when CUDA is available, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


## Loading SAM model

In [None]:
from SAMMed2DLite import SAMMed2DLite

In [None]:
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [None]:
from segment_anything import sam_model_registry

sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Build the base SAM model for `model_type` from the checkpoint file and move
# it to the selected device. Assigning the result of .to() (which returns the
# module itself) instead of leaving `sam.to(device)` as the cell's last
# expression keeps the notebook from printing the entire module repr.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)

In [None]:
# Wrap the base SAM model in SAMMed2DLite and move it to the selected device.
model = SAMMed2DLite(sam_model=sam).to(device)

# Count only the parameters that will receive gradients (requires_grad=True).
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Model loaded with {n_trainable:,} trainable parameters")

## Datasets and Dataloaders

In [None]:
from DataLoader import DatasetLoader

In [None]:
# Dataset root; presumably a Google Drive folder mounted at ./drive in Colab.
data_dir = "drive/MyDrive/SAM-Med2D-Mini/data"

# Both splits come from the same root: the default mode is used for training,
# and mode=0 selects the evaluation split (NOTE(review): confirm the meaning
# of `mode` against DatasetLoader).
train_dataset, test_dataset = (
    DatasetLoader(data_dir=data_dir),
    DatasetLoader(data_dir=data_dir, mode=0),
)

In [None]:
# DataLoader settings shared by both splits.
BATCH_SIZE = 4
NUM_WORKERS = 4
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
    # runtime (as in the run above) PyTorch just warns that it is unused.
    pin_memory=device.type == 'cuda',
)

# Shuffle only the training split; keep evaluation order deterministic.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [None]:
from train import train_model

# Fine-tuning hyperparameters. save_dir is presumably where train_model writes
# its checkpoints; the Evaluation section reads checkpoints/best_model.pth.
# NOTE(review): the test split is used as the validation set here.
train_config = {
    'num_epochs': 50,
    'learning_rate': 1e-4,
    'save_dir': 'checkpoints',
}
history = train_model(model=model, train_loader=train_loader, val_loader=test_loader, **train_config)

## Evaluation

In [None]:
from utils import evaluate_batch

# Load the best checkpoint saved during training. map_location remaps tensors
# saved from a CUDA run onto the current device, so this also works on a
# CPU-only runtime.
checkpoint = torch.load('checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch dropout / batch-norm layers (if any) to inference behavior

# Evaluate on the test split built in the "Datasets and Dataloaders" section
# (DatasetLoader(data_dir=data_dir, mode=0)). It is reused rather than rebuilt
# here so that the same constructor arguments and data root apply.
# NOTE(review): this loader was also passed to train_model as val_loader, so
# these scores are not independent of checkpoint selection.
test_metrics = evaluate_batch(model, test_loader, device)

print("\n=== Test Results ===")
print(f"Dice Coefficient: {test_metrics['dice']:.4f}")
print(f"IoU: {test_metrics['iou']:.4f}")


## Training history plot

In [None]:

def plot_training_history(history, save_path='training_curves.png'):
    """Visualize training progress as a loss panel and a validation-metric panel.

    Args:
        history: The object returned by ``train_model``, indexed by string keys.
            It must provide sequences under 'train_loss', 'val_dice' and
            'val_iou' (presumably one value per epoch). If it also contains
            'val_loss', that curve is drawn in the loss panel as well.
        save_path: Path the figure is saved to (``savefig`` with dpi=150)
            before it is shown.
    """
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 4))

    def _epochs(values):
        # Number points from 1 so the x-axis matches the "Epoch" label
        # instead of starting at list index 0.
        return range(1, len(values) + 1)

    has_val_loss = 'val_loss' in history

    # Loss curve(s)
    loss_ax.plot(_epochs(history['train_loss']), history['train_loss'], label='Train Loss')
    if has_val_loss:
        loss_ax.plot(_epochs(history['val_loss']), history['val_loss'], label='Val Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss' if has_val_loss else 'Training Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Metrics curves
    metric_ax.plot(_epochs(history['val_dice']), history['val_dice'], label='Dice', marker='o')
    metric_ax.plot(_epochs(history['val_iou']), history['val_iou'], label='IoU', marker='s')
    metric_ax.set_xlabel('Epoch')
    metric_ax.set_ylabel('Score')
    metric_ax.set_title('Validation Metrics')
    metric_ax.legend()
    metric_ax.grid(True)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Draw the curves recorded by train_model and save them to the default path.
plot_training_history(history, save_path='training_curves.png')