# Fine tune a checkpoint of MedSAM on point prompted Data

Given a path to a MedSAM checkpoint, we want to fine tune it on pre-processed data
(subject to modifications specified by the paper and the transformation script). This will
be done initially on an anatomy-specific level.

In [1]:
# Imports
import re
import os
import cv2
import sys
import json
import torch
import monai
import random
import argparse
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from time import time, sleep
from datetime import datetime
from matplotlib import pyplot as plt
from segment_anything import sam_model_registry
from torch.utils.data import Dataset, DataLoader

# Add the setup_data_vars function as we will need it to find the directory for the training data.
dir1 = os.path.abspath(os.path.join(os.path.abspath(''), '..', '..'))
if not dir1 in sys.path: sys.path.append(dir1)

from utils.environment import setup_data_vars
setup_data_vars()

In [2]:
parser = argparse.ArgumentParser()

# Inspired by orginal code from the MedSAM/extensions/point_prompt

# 1. Add the anatomy on which we will fine-tune
parser.add_argument(
    '--anatomy',
    type=str,
    help='Anatomy on which to fine-tune the model. Note: this is case sensitive, please capitalize the first letter and accronyms such as CTVn or CTVp.',
    required=True
)

# 2. Path to the MedSAM checkpoint
parser.add_argument(
    '--checkpoint',
    type=str,
    help='Path to the checkpoint of the model to fine-tune',
    required=True
)

# 3. Path where we will be saving the checkpoints of the fine-tuned model
parser.add_argument(
    '--save_dir',
    type=str,
    help='Directory where the fine-tuned model will be saved',
    required=True
)

# 4. Add the source directory for the data
parser.add_argument(
    '--img_dir',
    type=str,
    help='Directory containing the images for the slices of the anatomy',
    required=False,
)

# 5. Add the source directory for the gts
parser.add_argument(
    '--gt_dir',
    type=str,
    help='Directory containing the ground truth masks for the slices of the anatomy',
    required=False
)

# 6. Number of epochs for the fine-tuning
parser.add_argument(
    '--epochs',
    type=int,
    help='Number of epochs for the fine-tuning',
    required=False,
    default=300
)

# 7. Batch size for the fine-tuning
parser.add_argument(
    '--batch_size',
    type=int,
    help='Batch size for the fine-tuning',
    required=False,
    default=4
)

# 8. Learning rate for the fine-tuning
parser.add_argument(
    '--lr',
    type=float,
    help='Learning rate for the fine-tuning',
    required=False,
    default=0.00005
)

# 9. Number of workers for the data loader
parser.add_argument(
    '--num_workers',
    type=int,
    help='Number of workers for the data loader',
    required=False,
    default=4
)

parser.add_argument(
    '--weight_decay',
    type=float,
    help='Weight decay for the optimizer',
    required=False,
    default=0.01
)

# 11. Resume checkpoint
parser.add_argument(
    '--resume',
    type=bool,
    help='Whether to resume training using the latest checkpoint in the save_dir',
    required=False,
    default=True
)

_StoreAction(option_strings=['--resume'], dest='resume', nargs=None, const=None, default=True, type=<class 'bool'>, choices=None, required=False, help='Whether to resume training using the latest checkpoint in the save_dir', metavar=None)

In [3]:
# args = parser.parse_args()
# Suppose for now we get the following set of required arguments:
args = parser.parse_args([
    '--anatomy', 'CTVn',
    '--checkpoint', os.path.join(os.environ['PROJECT_DIR'], 'models', 'MedSAM', 'work_dir', 'MedSAM', 'medsam_vit_b.pth'),
    '--save_dir', os.path.join(os.environ['MedSAM_finetuned'], 'CTVn')
])

## Set up the vars

In [4]:
anatomy = args.anatomy
checkpoint_path = args.checkpoint
save_dir = args.save_dir
img_dir = args.img_dir
gt_dir = args.gt_dir
epochs = args.epochs
batch_size = args.batch_size
lr = args.lr
num_workers = args.num_workers
weight_decay = args.weight_decay
resume = args.resume

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
if img_dir is None:
    img_dir = os.path.join(os.environ['MedSAM_preprocessed'], 'imgs')
if gt_dir is None:
    gt_dir = os.path.join(os.environ['MedSAM_preprocessed'], 'gts', anatomy)

In [6]:
seed = 42

torch.cuda.empty_cache()
os.environ['PYTHONHASHSEED']=str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [7]:
image_id_from_file_name_regex = r'.*_(\d+).*'
slice_id_from_file_name_regex = r'.*-(\d+).*'

## Set up Dataset class

In [8]:
# Adapted Dataset class from ../2_no_finetuning/MEDSAM_helper_functions.py
class SAM_Dataset(Dataset):
    """A torch dataset for delivering slices of any axis to a medsam model."""

    def __init__(self, img_path, gt_path, id_split, data_aug=False):
        """
        Args:
            img_path (string): Path to the directory containing the images
            gt_path (string): Path to the directory containing the ground truth masks
            id_split (list): List of image ids to include in the dataset
        """

        self.root_img_path = img_path
        self.root_gt_path = gt_path
        self.id_split = id_split
        self.data_aug = data_aug
        
        # Assume that axese 0 1 and 2 have been processed.
        filter_fn = lambda x : x.endswith('.npy') and int(re.search(image_id_from_file_name_regex, x).group(1)) in id_split
        self.axis0_imgs = list(filter(filter_fn, os.listdir(os.path.join(gt_path, 'axis0'))))
        self.axis1_imgs = list(filter(filter_fn, os.listdir(os.path.join(gt_path, 'axis1'))))
        self.axis2_imgs = list(filter(filter_fn, os.listdir(os.path.join(gt_path, 'axis2'))))

    def __len__(self):
        return len(self.axis0_imgs) + len(self.axis1_imgs) + len(self.axis2_imgs)

    def __getitem__(self, idx):
        assert 0 <= idx < self.__len__(), f"Index {idx} is out of range for dataset of size {self.__len__()}"

        # Fetch the image and ground truth mask. For safety, we index the items around the
        # ground truth masks, so that if for some reason the images are misaligned we will
        # guarantee that we will fetch the correct image

        if idx < len(self.axis0_imgs):
            axis, gt_name = 0, self.axis0_imgs[idx]
        elif idx < len(self.axis0_imgs) + len(self.axis1_imgs):
            axis, gt_name = 1, self.axis1_imgs[idx - len(self.axis0_imgs)]
        else:
            axis, gt_name = 2, self.axis2_imgs[idx - len(self.axis0_imgs) - len(self.axis1_imgs)]

        image_id = int(re.search(image_id_from_file_name_regex, gt_name).group(1))
        slice_id = int(re.search(slice_id_from_file_name_regex, gt_name).group(1))

        img_name = f'CT_zzAMLART_{image_id:03d}-{slice_id:03d}.npy'
        
        # Load the image and ground truth mask

        img_path = os.path.join(self.root_img_path, f'axis{axis}', img_name)
        gt_path = os.path.join(self.root_gt_path, f'axis{axis}', gt_name)

        img = np.load(img_path, 'r', allow_pickle=True) # (H, W, C)
        gt = np.load(gt_path, 'r', allow_pickle=True) # (H, W, C)

        # Pre-process where necessary

        img = np.transpose(img, (2, 0, 1)) # (C, H, W)
        assert np.max(img) <= 1. and np.min(img) >= 0., 'image should be normalized to [0, 1]'

        # add data augmentation: random fliplr and random flipud
        if self.data_aug:
            if random.random() > 0.5:
                img = np.ascontiguousarray(np.flip(img, axis=-1))
                gt = np.ascontiguousarray(np.flip(gt, axis=-1))
            if random.random() > 0.5:
                img = np.ascontiguousarray(np.flip(img, axis=-2))
                gt = np.ascontiguousarray(np.flip(gt, axis=-2))
        
        gt = np.uint8(gt > 0)
        y_indices, x_indices = np.where(gt > 0)
        x_point = np.random.choice(x_indices)
        y_point = np.random.choice(y_indices)
        coords = np.array([x_point, y_point])

        return {
            "image": torch.tensor(img).float(),
            "gt2D": torch.tensor(gt[None, :,:]).long(),
            "coords": torch.tensor(coords[None, ...]).float(),
            "image_name": img_name
        }

## Set up Fine-Tuning nn Module

In [9]:
class MedSAM(nn.Module):
    def __init__(self, 
                image_encoder, 
                mask_decoder,
                prompt_encoder,
                freeze_image_encoder=False,
                ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

        # freeze prompt encoder
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False
        
        self.freeze_image_encoder = freeze_image_encoder
        if self.freeze_image_encoder:
            for param in self.image_encoder.parameters():
                param.requires_grad = False

    def forward(self, image, point_prompt):

        # do not compute gradients for pretrained img encoder and prompt encoder
        with torch.no_grad():
            image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
            # not need to convert box to 1024x1024 grid
            # bbox is already in 1024x1024
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=point_prompt,
                boxes=None,
                masks=None,
            )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding, # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
          ) # (B, 1, 256, 256)

        return low_res_masks

## Set up logger

In [10]:
from nnunetv2.training.logging.nnunet_logger import nnUNetLogger

class MedSAMLogger(nnUNetLogger):
    pass

# self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)
# self.logger.log('train_losses', loss_here, self.current_epoch)
# self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch)
# self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch)
# self.logger.log('val_losses', loss_here, self.current_epoch)

## Main Training Loop

In [11]:
class MedSAMTrainer(object):
    def __init__(self
                , anatomy
                , checkpoint_path
                , save_dir
                , image_dir
                , gt_dir
                , epochs
                , batch_size
                , lr
                , num_workers
                , weight_decay
                , resume
                , device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                , ):
        
        self.anatomy = anatomy
        assert self.anatomy in ['CTVn', 'CTVp', 'Bladder', 'Anorectum', 'Uterus', 'Vagina']
        
        self.image_dir = image_dir
        self.gt_dir = gt_dir
        assert os.path.exists(self.image_dir), 'Image Directory doesn\'t exist.'
        assert os.path.exists(self.gt_dir), 'Ground Truth Directory doesn\'t exist for the requested anatomy.'

        self.save_dir = save_dir
        self.checkpoint_path = checkpoint_path
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.num_workers = num_workers
        self.weight_decay = weight_decay
        self.resume = resume

    def run_training(self):
        """Main Training Loop"""

        self.on_train_start()

        for epoch in range(self.start_epoch, self.epochs):
            # <Training>
            self.on_epoch_start(epoch)

            pbar = tqdm(self.train_loader)
            for step, batch in enumerate(pbar):
                loss = self.train_step(step, batch)

                pbar.set_description(f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}")

            self.on_epoch_end()
            # </Training>

            # # <Validation>
            # with torch.no_grad():
            #     self.medsam_model.eval()
            #     for step, batch in enumerate(self.val_loader):
            #         image = batch["image"].to(self.device)
            #         gt2D = batch["gt2D"].to(self.device)
            #         coords_torch = batch["coords"].to(self.device)
            # # </Validation>

        self.on_train_end()

    def train_step(self, step, batch):
        # Get data
        image = batch["image"].to(self.device)
        gt2D = batch["gt2D"].to(self.device)
        coords_torch = batch["coords"].to(self.device) # (B, 2)

        self.optimizer.zero_grad()
        labels_torch = torch.ones(coords_torch.shape[0]).long() # (B,)
        labels_torch = labels_torch.unsqueeze(1) # (B, 1)
        coords_torch, labels_torch = coords_torch.to(device), labels_torch.to(device)
        point_prompt = (coords_torch, labels_torch)
        medsam_lite_pred = self.medsam_model(image, point_prompt)
        loss = self.seg_loss(medsam_lite_pred, gt2D) + self.ce_loss(medsam_lite_pred, gt2D.float())
        
        self.epoch_loss[step] = loss.item()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        return loss


    def on_train_end(self):
        self.save_checkpoint(self.current_epoch, self.epoch_loss, final=True)
        torch.cuda.empty_cache()

    def on_epoch_end(self):
        # self.logger.log('epoch_end_timestamps', time(), self.current_epoch)

        self.epoch_end_time = time()
        self.epoch_time.append(self.epoch_end_time - self.epoch_start_time)

        self.save_checkpoint(self.current_epoch, self.epoch_loss)

        # Plot the progress

        # self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4))
        # self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4))
        # self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in
        #                                        self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]])
        # self.print_to_log_file(
        #     f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s")

        # self.logger.plot_progress_png(self.save_dir)

    def save_checkpoint(self, epoch, epoch_loss, final=False):
        # If the checkpoint is better than the last, save it as 'checkpoint_best.pth' otherwise, save it as checkpoint_latest.pth

        epoch_loss_reduced = epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss)

        checkpoint = {
            "model": self.medsam_model.state_dict(),
            "epochs": epoch,
            "optimizer": self.optimizer.state_dict(),
            "loss": epoch_loss_reduced,
            "best_loss": self.best_loss
        }

        if epoch_loss_reduced < self.best_loss:
            self.best_loss = epoch_loss_reduced
            torch.save(checkpoint, os.path.join(self.save_dir, 'checkpoint_best.pth'))

        if final:
            torch.save(checkpoint, os.path.join(self.save_dir, 'checkpoint_final.pth'))
            os.remove(os.path.join(self.save_dir, 'checkpoint_latest.pth'))
        else:
            torch.save(checkpoint, os.path.join(self.save_dir, 'checkpoint_latest.pth'))

    def on_train_start(self):
        """Sets up the training environment"""

        # empty cuda cache
        torch.cuda.empty_cache()

        # create save_dir if it doesn't exist yet
        os.makedirs(self.save_dir, exist_ok=True)

        # set up logger to a new instance
        self.logger = MedSAMLogger()

        # load the previous checkpoint if it exists and resume is True. Otherwise, load
        # the medSAM model from which we train.
        self._first_run_setup()
        if (not self._maybe_load_checkpoint()):
            self._setup_data_splits()

        # set up the dataloaders
        self._get_dataloaders()

        # set up loss functions
        self.seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction='mean')
        self.ce_loss = nn.BCEWithLogitsLoss(reduction="mean")

    def on_epoch_start(self, epoch):
        self.medsam_model.train()
        self.epoch_loss = [1e10 for _ in range(len(self.training_dataset))]
        self.current_epoch = epoch
        self.epoch_start_time = datetime.now()
        
        # self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)

    def _maybe_load_checkpoint(self):
        """
        Populates the variables IF a checkpoint has been found
        - checkpoint
        - medsam_model
        - optimizer
        - start_epoch
        - best_loss
        - training_split
        - validation_split
        """

        if self.resume:
            if len([f for f in os.listdir(save_dir) if f == 'checkpoint_latest.pth']) == 0:
                print("No checkpoint found in the save directory")
            else:
                # Load model details
                self.checkpoint = torch.load(os.path.join(save_dir, 'checkpoint_latest.pth'))

                sam_model = sam_model_registry["vit_b"](checkpoint=os.path.join(save_dir, 'checkpoint_latest.pth'))
                
                self.medsam_model = MedSAM(
                    image_encoder = sam_model.image_encoder,
                    mask_decoder = sam_model.mask_decoder,
                    prompt_encoder = sam_model.prompt_encoder,
                    freeze_image_encoder = True
                ).to(device)

                # self.medsam_model.load_state_dict(self.checkpoint["model"])
                self.optimizer.load_state_dict(self.checkpoint["optimizer"])
                self.start_epoch = self.checkpoint["epoch"] + 1
                self.best_loss = self.checkpoint["best_loss"]
                print(f"Loaded checkpoint from epoch {self.start_epoch}, best loss: {self.best_loss:.4f}")

                # Get Data Split
                try:
                    self._load_split_from_json()
                except FileNotFoundError as e:
                    print('[WARNING] a checkpoint was found, but a split file was not.')
                    print('          a new data split will be generated with potential training bias towards the validaiton split.')
                    self._setup_data_splits()

                return True

        return False

    def _first_run_setup(self):
        sam_model = sam_model_registry["vit_b"](checkpoint=self.checkpoint_path)
        self.medsam_model = MedSAM(
            image_encoder = sam_model.image_encoder,
            mask_decoder = sam_model.mask_decoder,
            prompt_encoder = sam_model.prompt_encoder,
            freeze_image_encoder = True
        )
        self.medsam_model = self.medsam_model.to(device)

        self.optimizer = optim.AdamW(
            self.medsam_model.mask_decoder.parameters(),
            lr=self.lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=self.weight_decay
        )

        self.start_epoch = 0
        self.best_loss = float('inf')

    def _save_splits_to_json(self, training_image_ids, validation_image_ids):
        data = {
            "training_image_ids": list(training_image_ids),
            "validation_image_ids": list(validation_image_ids)
        }
        with open(os.path.join(self.save_dir, 'data_splits.json'), 'w') as json_file:
            json.dump(data, json_file)

    def _load_split_from_json(self):
        with open(os.path.join(self.save_dir, 'data_splits.json'), 'r') as json_file:
            data = json.load(json_file)
        self.training_split = set(data["training_image_ids"])
        self.validation_split = set(data["validation_image_ids"])

    def _setup_data_splits(self, training_split=None, validation_split=None):
        """Setup the data splits either from a known source or a new setup."""
        if training_split is not None and validation_split is not None:
            self.training_split = training_split
            self.validation_split = validation_split
            return

        # get the image ids that have been processed. Use gt dir as reference
        axis0_slices = set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), os.listdir(os.path.join(self.gt_dir, 'axis0'))))
        axis1_slices = set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), os.listdir(os.path.join(self.gt_dir, 'axis1'))))
        axis2_slices = set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), os.listdir(os.path.join(self.gt_dir, 'axis2'))))

        if not axis0_slices == axis1_slices == axis2_slices:
            print('[WARNING]: The slices for the anatomy are not consistent across the three axes. Some axese are missing data, please check')

        training_split = 0.8
        validation_split = 1 - training_split

        # Split the data into training and validation
        self.training_image_ids = random.sample(list(axis0_slices), int(len(axis0_slices) * training_split))
        self.validation_image_ids = list(set(axis0_slices) - set(self.training_image_ids))
        assert set.intersection(set(self.training_image_ids), set(self.validation_image_ids)).__len__() == 0, 'Training and Validation sets are not disjoint'

        # Save the splits in a json file
        self._save_splits_to_json(self.training_image_ids, self.validation_image_ids)

    def _get_dataloaders(self):
        self.training_dataset = SAM_Dataset(self.image_dir, self.gt_dir, self.training_image_ids)
        self.validation_dataset = SAM_Dataset(self.image_dir, self.gt_dir, self.validation_image_ids)
        
        # Quick check
        assert set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), self.validation_dataset.axis0_imgs)) == set(self.validation_image_ids), 'DataSet incorrectly loaded image ids that don\'t match supplied validation set image ids'
        assert set(map(lambda x : int(re.search(image_id_from_file_name_regex, x).group(1)), self.training_dataset.axis0_imgs)) == set(self.training_image_ids), 'DataSet incorrectly loaded image ids that don\'t match supplied validation set image ids'

        self.train_loader = DataLoader(self.training_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        self.val_loader = DataLoader(self.validation_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

In [12]:
trainer = MedSAMTrainer(
    anatomy
    , checkpoint_path
    , save_dir
    , img_dir
    , gt_dir
    , epochs
    , batch_size
    , lr
    , num_workers
    , weight_decay
    , resume
)

In [13]:
trainer.run_training()