# Fine-Tuning MedSAM

We provide advanced tutorials that show:
   - a. How MedSAM was trained, or how to fine-tune SAM on customized datasets: [training](https://github.com/bowang-lab/MedSAM/tree/main#model-training)
   - b. How to fine-tune the model with text-based prompts: [training](https://github.com/bowang-lab/MedSAM/tree/main/extensions/text_prompt) and [inference colab](https://colab.research.google.com/drive/1wexPLewVMI-9EMiplfyoEtGGayYDH3tt?usp=sharing)
   - c. How to fine-tune the model with point-based prompts: [training](https://github.com/bowang-lab/MedSAM/tree/main/extensions/point_prompt) and [inference colab](https://colab.research.google.com/drive/1cCBw_IhdPiWE4sN7QwqKJPgAFlWsKgkm?usp=sharing)

## Data preprocessing

Download the [SAM checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) and place it at `work_dir/SAM/sam_vit_b_01ec64.pth`.

Download the demo [dataset](https://zenodo.org/record/7860267) and unzip it to `data/FLARE22Train/`.

This dataset contains 50 abdominal CT scans, each with an annotation mask covering 13 organs. The organ label names are available at [MICCAI FLARE2022](https://flare22.grand-challenge.org/).

Install `cc3d` (`pip install connected-components-3d`), then run the pre-processing script:

```bash
python pre_CT_MR.py
```

The script performs the following steps:

- split the dataset: 80% for training and 20% for testing
- adjust CT scans to the [soft tissue](https://radiopaedia.org/articles/windowing-ct) window: level 40, width 400
- max-min normalization
- resample images to `1024x1024`
- save the pre-processed images and labels as `npy` files
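
The arithmetic behind the split, windowing, and normalization steps can be illustrated with a minimal pure-Python sketch. It is not taken from `pre_CT_MR.py`: the helper names `split_names` and `window_and_normalize`, the seeded shuffle, the case-name pattern, and the `[0, 1]` output range are all assumptions made for illustration, and the real script may split and scale differently.

```python
import random


def split_names(names, train_fraction=0.8, seed=0):
    """Shuffle case names reproducibly and split them into train/test lists.

    Hypothetical helper: the seeded shuffle is an assumption, not the script's method.
    """
    shuffled = sorted(names)
    random.Random(seed).shuffle(shuffled)
    n_train = int(round(len(shuffled) * train_fraction))
    return shuffled[:n_train], shuffled[n_train:]


def window_and_normalize(hu_values, level=40, width=400):
    """Clip Hounsfield units to the window, then min-max scale to [0, 1]."""
    lower = level - width / 2  # 40 - 200 = -160 HU
    upper = level + width / 2  # 40 + 200 = 240 HU
    clipped = [min(max(v, lower), upper) for v in hu_values]
    lo, hi = min(clipped), max(clipped)
    if hi == lo:
        # A constant input has no range to stretch; map everything to 0.
        return [0.0 for _ in clipped]
    return [(v - lo) / (hi - lo) for v in clipped]


# Illustrative case names only; 50 scans split 80/20 gives 40 and 10.
names = [f"FLARE22_Tr_{i:04d}" for i in range(1, 51)]
train, test = split_names(names)
print(len(train), len(test))  # 40 10

# Air (-1000 HU) and bone (1000 HU) saturate at the window edges.
print(window_and_normalize([-1000, -160, 40, 240, 1000]))
# [0.0, 0.0, 0.5, 1.0, 1.0]
```

With level 40 and width 400, every value below -160 HU or above 240 HU is clipped, so the normalized range is spent entirely on soft tissue contrast.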

```bash
pip install connected-components-3d
wget -P /vol/bitbucket/az620/radiotherapy/models/MedSAM/checkpoints/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```

```bash
wget -P data/ https://zenodo.org/records/7860267/files/FLARE22Train.zip
```

```bash
unzip data/FLARE22Train.zip -d data/FLARE22Train
```

```bash
python /vol/bitbucket/az620/radiotherapy/models/MedSAM/pre_CT_MR.py
```


## Training on multiple GPUs (Recommended)

The model was trained on five nodes, each with four 80 GB A100 GPUs (20 GPUs in total). Use the Slurm script to start the training process:

```bash
sbatch train_multi_gpus.sh
```
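
For orientation, the five-node, four-GPU layout maps to process ranks as in the sketch below. It follows the common convention in distributed training of one process per GPU numbered node by node. The contents of `train_multi_gpus.sh` are not shown here, so the helper `global_rank` and the variable names are hypothetical and may not match what the script uses.

```python
def global_rank(node_rank, local_rank, gpus_per_node=4):
    """Global process index when ranks are assigned node by node (assumed convention)."""
    return node_rank * gpus_per_node + local_rank


num_nodes, gpus_per_node = 5, 4
world_size = num_nodes * gpus_per_node
print(world_size)  # 20

# First GPU on the first node and last GPU on the last node.
print(global_rank(node_rank=0, local_rank=0))  # 0
print(global_rank(node_rank=4, local_rank=3))  # 19
```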

When the training process is done, please convert the checkpoint to SAM's format for convenient inference.

```bash
python utils/ckpt_convert.py # Please set the corresponding checkpoint path first
```
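
Conceptually, a conversion of this kind extracts the bare model weights from the training checkpoint and renames keys saved by a multi-GPU wrapper. The sketch below illustrates that idea with plain dictionaries and is not the code in `utils/ckpt_convert.py`. It assumes that the training checkpoint stores weights under a `model` key and that multi-GPU training prefixed each key with `module.`; the helper names and the example keys are illustrative only.

```python
def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading wrapper prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def to_sam_format(training_ckpt, model_key="model"):
    """Extract bare model weights from a training checkpoint dict (assumed layout)."""
    # Fall back to treating the checkpoint itself as the weights if the key is absent.
    weights = training_ckpt.get(model_key, training_ckpt)
    return strip_ddp_prefix(weights)


# Toy stand-in for a checkpoint; in practice the values would be tensors.
training_ckpt = {
    "model": {
        "module.image_encoder.pos_embed": 1,
        "module.mask_decoder.iou_token.weight": 2,
    },
    "optimizer": {},
    "epoch": 10,
}
print(to_sam_format(training_ckpt))
# {'image_encoder.pos_embed': 1, 'mask_decoder.iou_token.weight': 2}
```

The optimizer state and epoch counter are dropped because inference only needs the weights.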

## Training on one GPU

```bash
python train_one_gpu.py
```

If you only want to train the mask decoder, please check the tutorial on the [0.1 branch](https://github.com/bowang-lab/MedSAM/tree/0.1).