# Fine tune a checkpoint of MedSAM on point prompted Data

Given a path to a MedSAM checkpoint, we want to fine tune it on pre-processed data
(subject to modifications specified by the paper and the transformation script). This will
be done initially on an anatomy-specific level.

In [164]:
# Imports
import re
import os
import cv2
import sys
import torch
import random
import argparse
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


# Add the setup_data_vars function as we will need it to find the directory for the training data.
dir1 = os.path.abspath(os.path.join(os.path.abspath(''), '..', '..'))
if not dir1 in sys.path: sys.path.append(dir1)

from utils.environment import setup_data_vars
setup_data_vars()

In [54]:
parser = argparse.ArgumentParser()

# Inspired by orginal code from the MedSAM/extensions/point_prompt

# 1. Add the anatomy on which we will fine-tune
parser.add_argument(
    '--anatomy',
    type=str,
    help='Anatomy on which to fine-tune the model. Note: this is case sensitive, please capitalize the first letter and accronyms such as CTVn or CTVp.',
    required=True
)

# 2. Path to the MedSAM checkpoint
parser.add_argument(
    '--checkpoint',
    type=str,
    help='Path to the checkpoint of the model to fine-tune',
    required=True
)

# 3. Path where we will be saving the checkpoints of the fine-tuned model
parser.add_argument(
    '--save_dir',
    type=str,
    help='Directory where the fine-tuned model will be saved',
    required=True
)

# 4. Add the source directory for the data
parser.add_argument(
    '--img_dir',
    type=str,
    help='Directory containing the images for the slices of the anatomy',
    required=False,
)

# 5. Add the source directory for the gts
parser.add_argument(
    '--gt_dir',
    type=str,
    help='Directory containing the ground truth masks for the slices of the anatomy',
    required=False
)

# 6. Number of epochs for the fine-tuning
parser.add_argument(
    '--epochs',
    type=int,
    help='Number of epochs for the fine-tuning',
    required=False,
    default=300
)

# 7. Batch size for the fine-tuning
parser.add_argument(
    '--batch_size',
    type=int,
    help='Batch size for the fine-tuning',
    required=False,
    default=4
)

# 8. Learning rate for the fine-tuning
parser.add_argument(
    '--lr',
    type=float,
    help='Learning rate for the fine-tuning',
    required=False,
    default=0.00005
)

# 9. Number of workers for the data loader
parser.add_argument(
    '--num_workers',
    type=int,
    help='Number of workers for the data loader',
    required=False,
    default=4
)

# 10. Resume checkpoint
parser.add_argument(
    '--resume',
    type=bool,
    help='Whether to resume training using the latest checkpoint in the save_dir',
    required=False,
    default=True
)

_StoreAction(option_strings=['--resume'], dest='resume', nargs=None, const=None, default=True, type=<class 'bool'>, choices=None, required=False, help='Whether to resume training using the latest checkpoint in the save_dir', metavar=None)

In [55]:
# args = parser.parse_args()
# Suppose for now we get the following set of required arguments:
args = parser.parse_args([
    '--anatomy', 'CTVn',
    '--checkpoint', os.path.join(os.environ['PROJECT_DIR'], 'models', 'work_dir', 'medsam_vit_b.pth'),
    '--save_dir', os.path.join(os.environ['MedSAM_finetuned'], 'CTVn')
])

## Set up the vars

In [56]:
anatomy = args.anatomy
checkpoint_path = args.checkpoint
save_dir = args.save_dir
img_dir = args.img_dir
gt_dir = args.gt_dir
epochs = args.epochs
batch_size = args.batch_size
lr = args.lr
num_workers = args.num_workers
resume = args.resume

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [57]:
assert anatomy in ['CTVn', 'CTVp', 'Bladder', 'Anorectum', 'Uterus', 'Vagina']

if img_dir is None:
    img_dir = os.path.join(os.environ['MedSAM_preprocessed'], 'imgs')
if gt_dir is None:
    gt_dir = os.path.join(os.environ['MedSAM_preprocessed'], 'gts', anatomy)

assert os.path.exists(img_dir), 'Image Directory doesn\'t exist.'
assert os.path.exists(gt_dir), 'Ground Truth Directory doesn\'t exist for the requested anatomy.'

In [58]:
seed = 42

torch.cuda.empty_cache()
os.environ['PYTHONHASHSEED']=str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## Set up Dataset and (Train, Validation) split

In [91]:
# get the image ids that have been processed. Use gt dir as reference
image_id_from_file_name_regex = r'.*_(\d+).*'
slice_id_from_file_name_regex = r'.*-(\d+).*'

axis0_slices = set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), os.listdir(os.path.join(gt_dir, 'axis0'))))
axis1_slices = set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), os.listdir(os.path.join(gt_dir, 'axis1'))))
axis2_slices = set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), os.listdir(os.path.join(gt_dir, 'axis2'))))

if not axis0_slices == axis1_slices == axis2_slices:
    print('[WARNING]: The slices for the anatomy are not consistent across the three axes. Some axese are missing data, please check')

training_split = 0.8
validation_split = 1 - training_split

# Split the data into training and validation
training_image_ids = random.sample(list(axis0_slices), int(len(axis0_slices) * training_split))
validation_image_ids = list(set(axis0_slices) - set(training_image_ids))
assert set.intersection(set(training_image_ids), set(validation_image_ids)).__len__() == 0, 'Training and Validation sets are not disjoint'

In [153]:
# Adapted Dataset class from ../2_no_finetuning/MEDSAM_helper_functions.py
class SAM_Dataset(Dataset):
    """A torch dataset for delivering slices of any axis to a medsam model."""

    def __init__(self, img_path, gt_path, id_split, data_aug=False):
        """
        Args:
            img_path (string): Path to the directory containing the images
            gt_path (string): Path to the directory containing the ground truth masks
            id_split (list): List of image ids to include in the dataset
        """

        self.root_img_path = img_path
        self.root_gt_path = gt_path
        self.id_split = id_split
        self.data_aug = data_aug
        
        # Assume that axese 0 1 and 2 have been processed.
        filter_fn = lambda x : x.endswith('.npy') and int(re.search(image_id_from_file_name_regex, x).group(1)) in id_split
        self.axis0_imgs = list(filter(filter_fn, os.listdir(os.path.join(gt_path, 'axis0'))))
        self.axis1_imgs = list(filter(filter_fn, os.listdir(os.path.join(gt_path, 'axis1'))))
        self.axis2_imgs = list(filter(filter_fn, os.listdir(os.path.join(gt_path, 'axis2'))))

    def __len__(self):
        return len(self.axis0_imgs) + len(self.axis1_imgs) + len(self.axis2_imgs)

    def __getitem__(self, idx):
        assert 0 <= idx < self.__len__(), f"Index {idx} is out of range for dataset of size {self.__len__()}"

        # Fetch the image and ground truth mask. For safety, we index the items around the
        # ground truth masks, so that if for some reason the images are misaligned we will
        # guarantee that we will fetch the correct image

        if idx < len(self.axis0_imgs):
            axis, gt_name = 0, self.axis0_imgs[idx]
        elif idx < len(self.axis0_imgs) + len(self.axis1_imgs):
            axis, gt_name = 1, self.axis1_imgs[idx - len(self.axis0_imgs)]
        else:
            axis, gt_name = 2, self.axis2_imgs[idx - len(self.axis0_imgs) - len(self.axis1_imgs)]

        image_id = int(re.search(image_id_from_file_name_regex, gt_name).group(1))
        slice_id = int(re.search(slice_id_from_file_name_regex, gt_name).group(1))

        img_name = f'CT_zzAMLART_{image_id:03d}-{slice_id:03d}.npy'
        
        # Load the image and ground truth mask

        img_path = os.path.join(self.root_img_path, f'axis{axis}', img_name)
        gt_path = os.path.join(self.root_gt_path, f'axis{axis}', gt_name)

        img = np.load(img_path, 'r', allow_pickle=True) # (H, W, C)
        gt = np.load(gt_path, 'r', allow_pickle=True) # (H, W, C)

        # Pre-process where necessary

        img = np.transpose(img, (2, 0, 1)) # (C, H, W)
        assert np.max(img) <= 1. and np.min(img) >= 0., 'image should be normalized to [0, 1]'

        # add data augmentation: random fliplr and random flipud
        if self.data_aug:
            if random.random() > 0.5:
                img = np.ascontiguousarray(np.flip(img, axis=-1))
                gt = np.ascontiguousarray(np.flip(gt, axis=-1))
            if random.random() > 0.5:
                img = np.ascontiguousarray(np.flip(img, axis=-2))
                gt = np.ascontiguousarray(np.flip(gt, axis=-2))
        
        gt = np.uint8(gt > 0)
        y_indices, x_indices = np.where(gt > 0)
        x_point = np.random.choice(x_indices)
        y_point = np.random.choice(y_indices)
        coords = np.array([x_point, y_point])

        return {
            "image": torch.tensor(img).float(),
            "gt2D": torch.tensor(gt[None, :,:]).long(),
            "coords": torch.tensor(coords[None, ...]).float(),
            "image_name": img_name
        }

In [154]:
training_dataset = SAM_Dataset(img_dir, gt_dir, training_image_ids)
validation_dataset = SAM_Dataset(img_dir, gt_dir, validation_image_ids)

In [155]:
# Quick check
assert set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), validation_dataset.axis0_imgs)) == set(validation_image_ids), 'DataSet incorrectly loaded image ids that don\'t match supplied validation set image ids'
assert set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), training_dataset.axis0_imgs)) == set(training_image_ids), 'DataSet incorrectly loaded image ids that don\'t match supplied validation set image ids'

## Set up Fine-Tuning nn Module

In [165]:
class MedSAM(nn.Module):
    def __init__(self, 
                image_encoder, 
                mask_decoder,
                prompt_encoder,
                freeze_image_encoder=False,
                ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

        # freeze prompt encoder
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False
        
        self.freeze_image_encoder = freeze_image_encoder
        if self.freeze_image_encoder:
            for param in self.image_encoder.parameters():
                param.requires_grad = False

    def forward(self, image, point_prompt):

        # do not compute gradients for pretrained img encoder and prompt encoder
        with torch.no_grad():
            image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
            # not need to convert box to 1024x1024 grid
            # bbox is already in 1024x1024
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=point_prompt,
                boxes=None,
                masks=None,
            )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding, # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
          ) # (B, 1, 256, 256)

        return low_res_masks