
Ly-kc/SAM-ON-BTCV


Code Structure

├── ckpts
│   ├── sam_vit_h_4b8939.pth   # original SAM weights
│   ├── vit_h_mask_2024-01-15-05-42-30   # fine-tuned mask decoder and training log
│   └── vit_h_semantic_mask_2024-01-15-22-51-30   # trained semantic mask decoder and training log
├── data
│   ├── processed         # preprocessed dataset
│   ├── sam_embedding     # image embeddings generated by the SAM encoder
│   ├── Testing           # testing split of the original dataset
│   ├── Training          # training and validation splits of the original dataset
│   ├── centers.txt
│   └── widths.txt
├── id_to_color.txt       # mapping from organ class to color
└── sam_on_btcv
    ├── segment_anything            # original SAM code plus build_sem_sam.py and SemanticMaskDecoder
    ├── btcv_dataset.py             # Dataset class
    ├── criterion.py                # Dice loss
    ├── grid_sam.py                 # applications using grid-point prompts
    ├── myAutomaticMaskGenerator.py # automatic mask generator that supports the semantic mask decoder
    ├── my_pridictor.py             # predictor that supports the semantic mask decoder
    ├── preprocess_dataset.py       # data preprocessing
    ├── finetune.py                 # fine-tunes the mask decoder
    ├── train_semantic.py           # trains the semantic mask decoder
    └── visualize.py                # visualization utilities
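criterion.py implements a Dice loss. Its exact formulation is not shown in this README, so the following is only a hypothetical sketch of a common variant: a smoothed soft Dice loss for one binary mask, averaged over classes for the multi-class case. The function names and the smoothing default of 1.0 are assumptions, not taken from the repository.

```python
from typing import Sequence


def soft_dice_loss(pred: Sequence[float], target: Sequence[float], smooth: float = 1.0) -> float:
    """Soft Dice loss for one binary mask (hypothetical helper).

    pred:   predicted foreground probabilities in [0, 1], flattened.
    target: ground-truth mask values (0 or 1), flattened, same length.
    smooth: additive constant that avoids division by zero on empty masks
            (an assumed default, not taken from criterion.py).
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    dice = (2.0 * intersection + smooth) / (total + smooth)
    return 1.0 - dice


def multiclass_dice_loss(probs, labels, num_classes: int, smooth: float = 1.0) -> float:
    """Average the binary Dice loss over classes (hypothetical helper).

    probs:  probs[c][i] is the predicted probability of class c at pixel i.
    labels: labels[i] is the integer class of pixel i.
    """
    losses = []
    for c in range(num_classes):
        # One-vs-rest target mask for class c.
        target_c = [1.0 if y == c else 0.0 for y in labels]
        losses.append(soft_dice_loss(probs[c], target_c, smooth))
    return sum(losses) / num_classes
```

A perfect prediction gives a loss of 0; the smoothing term keeps the loss defined when both the prediction and the target are empty for a class.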

Usage

Installation

pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install opencv-python tqdm SimpleITK open3d matplotlib
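To confirm the installation, a small check can look up each package's import name. This is a sketch, not part of the repository. Note that some import names differ from the pip package names: opencv-python, for example, is imported as cv2.

```python
import importlib.util

# Import names for the packages installed above. The pip package
# "opencv-python" is imported as "cv2".
REQUIRED_MODULES = ["torch", "torchvision", "torchaudio", "cv2", "tqdm",
                    "SimpleITK", "open3d", "matplotlib"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("all dependencies found" if not missing else f"missing: {missing}")
```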

Data preparation

1. Download the pretrained SAM ViT-H weights (sam_vit_h_4b8939.pth) into ckpts/.
2. Download the BTCV dataset into data/Training and data/Testing.
3. Run preprocess_dataset.py to preprocess the data.
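Before running preprocess_dataset.py, it may help to confirm that steps 1 and 2 put the files where the Code Structure tree expects them. The following is a minimal sketch. missing_inputs is a hypothetical helper, not part of the repository, and it assumes the paths are taken relative to the repository root.

```python
from pathlib import Path

# Paths taken from the Code Structure section (relative to the repository root).
REQUIRED_PATHS = [
    "ckpts/sam_vit_h_4b8939.pth",
    "data/Training",
    "data/Testing",
]


def missing_inputs(root):
    """Return the required paths (relative to root) that do not exist yet."""
    root = Path(root)
    return [p for p in REQUIRED_PATHS if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_inputs(".")
    print("ready to preprocess" if not missing else f"missing: {missing}")
```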

Train semantic decoder

1. Run finetune.py to fine-tune the (non-semantic) mask decoder on BTCV.
2. Run train_semantic.py to train the semantic mask decoder (you need to modify the parameter 'from_pretrain' at the bottom of the script first).
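Fine-tuning runs are saved in timestamped directories such as ckpts/vit_h_mask_2024-01-15-05-42-30. This README does not document what 'from_pretrain' expects (a directory, a file inside it, or something else), so check the script before relying on the following hypothetical helper. It locates the newest run directory for a given prefix by parsing the timestamp. Parsing, rather than sorting the names, also skips directories whose suffix is not a valid timestamp.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional

# Timestamp layout used in the run directory names, e.g. 2024-01-15-05-42-30.
TIMESTAMP_FORMAT = "%Y-%m-%d-%H-%M-%S"


def latest_run_dir(ckpt_root, prefix: str = "vit_h_mask_") -> Optional[Path]:
    """Return the most recent directory named <prefix><timestamp>, or None."""
    best, best_time = None, None
    for path in Path(ckpt_root).iterdir():
        if not path.is_dir() or not path.name.startswith(prefix):
            continue
        try:
            run_time = datetime.strptime(path.name[len(prefix):], TIMESTAMP_FORMAT)
        except ValueError:
            continue  # suffix is not a timestamp; ignore this directory
        if best_time is None or run_time > best_time:
            best, best_time = path, run_time
    return best
```

With the default prefix, semantic runs (vit_h_semantic_mask_...) are not matched. Pass prefix="vit_h_semantic_mask_" to find those instead.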

Automask

Once the training above is complete, you can run semantic automatic mask generation on each image:

python grid_sam.py

This generates a semantic segmentation for each slice and saves it as a single-channel image with pixel values from 0 to 13, one value per class.
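grid_sam.py prompts SAM with a regular grid of points. The repository's exact grid settings are not documented here. As a sketch of the idea, the original segment_anything automatic mask generator places one point at the center of each cell of an n x n grid in normalized coordinates (its default is 32 points per side), with the same layout as build_point_grid in segment_anything's utils.amg module. This plain-Python version reproduces that layout. to_pixel_prompts is a hypothetical helper that scales the points to a slice of a given size.

```python
from typing import List, Tuple


def build_point_grid(n_per_side: int) -> List[Tuple[float, float]]:
    """Return n_per_side**2 (x, y) points in normalized [0, 1] coordinates.

    Each point sits at the center of one cell of an n_per_side x n_per_side grid,
    listed row by row (y in the outer loop, x in the inner loop).
    """
    offset = 1.0 / (2 * n_per_side)
    step = 1.0 / n_per_side
    coords = [offset + i * step for i in range(n_per_side)]
    return [(x, y) for y in coords for x in coords]


def to_pixel_prompts(points, width: int, height: int) -> List[Tuple[float, float]]:
    """Scale normalized points to pixel coordinates of a width x height slice."""
    return [(x * width, y * height) for x, y in points]


if __name__ == "__main__":
    grid = build_point_grid(2)
    print(grid)  # four cell centers of a 2 x 2 grid
    print(to_pixel_prompts(grid, 512, 512))
```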

Visualize

visualize.py provides the following visualization functions:

vis_semantic_masks: visualizes the semantic segmentation results.

plot_history: plots the training log, including loss, Dice score, and accuracy.

save_to_ply: converts the semantic segmentation results of the 2D slices into a 3D point cloud and saves it as a PLY file.
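The following is a rough sketch of how 2D label maps can become a colored point cloud. The installation step lists open3d, which the repository may use for this step. The sketch swaps in a hand-written ASCII PLY writer so it has no dependencies. It assumes label 0 is background, uses slice index times a hypothetical z_spacing as depth, and ignores in-plane pixel spacing. The colors dict stands in for id_to_color.txt, whose file format is not documented here.

```python
from typing import Dict, List, Sequence, Tuple

Color = Tuple[int, int, int]


def slices_to_points(label_slices: Sequence[Sequence[Sequence[int]]],
                     z_spacing: float = 1.0) -> List[Tuple[float, float, float, int]]:
    """Turn a stack of 2D label maps into (x, y, z, label) points.

    Background pixels (assumed to be label 0) are skipped; z is the slice
    index multiplied by z_spacing.
    """
    points = []
    for z, label_map in enumerate(label_slices):
        for y, row in enumerate(label_map):
            for x, label in enumerate(row):
                if label != 0:
                    points.append((float(x), float(y), z * z_spacing, label))
    return points


def to_ascii_ply(points, colors: Dict[int, Color]) -> str:
    """Serialize labeled points as an ASCII PLY string with per-vertex RGB."""
    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(points)}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "end_header",
    ]
    body = []
    for x, y, z, label in points:
        r, g, b = colors[label]  # color assigned to this organ class
        body.append(f"{x} {y} {z} {r} {g} {b}")
    return "\n".join(header + body) + "\n"
```

Writing the returned string to a file with a .ply extension produces a file that standard point-cloud viewers can open.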

About

This repository is the final project for the Machine Learning course, Fall 2023. Course website: https://youchengli.com/teaching/machine_learning_23_fall.html
