# Introduction
<div>
<img src="../assets/gtMsk.png" width="250"/>
<img src="../assets/samAutoMsk.png" width="250"/>
<img src="../assets/samPointPrompt.png" width="250"/>
<img src="../assets/samBoxPrompt1.png" width="250"/>
</div>



**What is the Segment Anything Model (SAM)?**
The Segment Anything Model (SAM) is a segmentation model developed by Meta AI, often described as the first foundation model for computer vision. It was trained on SA-1B, a dataset of about 11 million images and over 1 billion masks, which lets it produce accurate segmentation masks for a wide variety of images. SAM is designed to take human prompts into account, which makes it particularly well suited to human-in-the-loop annotation. A prompt can be one or more points on the area to segment, a bounding box around the object, or a text description of what to segment (text prompts were explored in the SAM paper).


**Fine-tuning SAM with Low-Rank Adaptation (LoRA)**
LoRA is an adapter built from two matrices, B and A, with shapes (input_size, r) and (r, input_size). Choosing a rank r < input_size reduces the number of trainable parameters, on the bet that the task-specific update can be captured at a small enough rank. The product B*A has shape (input_size, input_size), the same as the weight matrix it adapts, so it can be added to the frozen weights without changing the layer's input or output dimensions. Only this low-rank update is learned during training.
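
To make the parameter saving concrete, here is a small arithmetic sketch. The values `input_size=768` and `r=4` are illustrative assumptions, not settings taken from this project, and `lora_param_counts` is a hypothetical helper name.

```python
def lora_param_counts(input_size, r):
    """Return (full, lora): parameters in a dense update vs. in B and A."""
    full = input_size * input_size          # one (input_size, input_size) matrix
    lora = input_size * r + r * input_size  # B is (input_size, r), A is (r, input_size)
    return full, lora


# Illustrative values only (assumed, not read from the project's config).
full, lora = lora_param_counts(input_size=768, r=4)
print(full, lora)  # 589824 6144
```

With these assumed values the adapter trains roughly 1% of the parameters a full update of the same matrix would need. The saving shrinks as r grows and disappears once r reaches input_size / 2.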

For any application, we only need to initialize the two matrices, freeze SAM, and train the adapter, so that the frozen model plus LoRA learns to segment whatever the task requires.
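
The freeze-and-adapt idea can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the training code used in this repository: `LoRALinear` is a hypothetical wrapper name, and initializing B to zeros (so the adapted layer starts out identical to the frozen one) is a common convention that is assumed here.

```python
import torch
from torch import nn


class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear and add a trainable low-rank update B @ A."""

    def __init__(self, base: nn.Linear, r: int = 4):
        super().__init__()
        self.base = base
        # Freeze the original weights; only A and B receive gradients.
        for param in self.base.parameters():
            param.requires_grad = False
        # A: (r, in_features), B: (out_features, r).
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        # Assumed convention: B starts at zero, so B @ A is zero at step 0.
        self.B = nn.Parameter(torch.zeros(base.out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen output plus the low-rank correction x @ (B @ A)^T.
        return self.base(x) + x @ self.A.T @ self.B.T
```

In a setup like this, wrapping a 16x16 linear layer with `r=2` leaves 64 trainable parameters (2 x 16 x 2), while the 16x16 base weights and bias stay fixed.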


## Data Preprocessing
The `masks_orig` folder of the BCSS dataset labels each pixel with one of 22 classes:
```
outside_roi              0
tumor                    1
stroma                   2
lymphocytic_infiltrate   3
necrosis_or_debris       4
glandular_secretions     5
blood                    6
exclude                  7
metaplasia_NOS           8
fat                      9
plasma_cells             10
other_immune_infiltrate  11
mucoid_material          12
normal_acinus_or_duct    13
lymphatics               14
undetermined             15
nerve                    16
skin_adnexa              17
blood_vessel             18
angioinvasion            19
dcis                     20
other                    21
```
The masks are converted to two classes, keeping `tumor` as 1 and mapping every other label to 0:
```
no tumor                0
tumor                   1
```

In [3]:
from torchvision import transforms
from PIL import Image
import os
import torch

in_folder = "../data/BCSS_small/train/masks_orig/"
out_folder = "../data/BCSS_small/train/masks/"
overwrite = False

_mask_transformer = transforms.Compose([
    transforms.PILToTensor()
])
_image_transformer = transforms.Compose([
    transforms.ToPILImage()
])

def process_and_save_mask(image_file):
    """Binarize one BCSS mask (tumor -> 1, everything else -> 0) and save it."""
    mask = Image.open(os.path.join(in_folder, image_file))
    mask_tensor = _mask_transformer(mask)
    mask_tensor = (mask_tensor == 1).to(torch.uint8)  # label 1 is tumor
    mask_image = _image_transformer(mask_tensor)
    mask_image.save(os.path.join(out_folder, image_file))

os.makedirs(out_folder, exist_ok=True)  # create the output folder once
in_masks = [file for file in os.listdir(in_folder) if file.endswith('.png')]
for image_file in in_masks:
    # Skip masks that were already converted unless overwrite is set.
    if os.path.isfile(os.path.join(out_folder, image_file)) and not overwrite:
        continue
    process_and_save_mask(image_file)

#### Original Mask

In [4]:
mask = Image.open(os.path.join(in_folder,in_masks[0]))
mask_tensor = _mask_transformer(mask)
print(mask_tensor.unique(return_counts=True))
mask_tensor

(tensor([1, 2], dtype=torch.uint8), tensor([31460, 18716]))


tensor([[[2, 2, 2,  ..., 2, 2, 2],
         [2, 2, 2,  ..., 1, 1, 1],
         [2, 2, 2,  ..., 1, 1, 1],
         ...,
         [2, 2, 2,  ..., 2, 2, 2],
         [2, 2, 2,  ..., 2, 2, 2],
         [2, 2, 2,  ..., 2, 2, 2]]], dtype=torch.uint8)

#### Processed Mask

In [6]:
mask = Image.open(os.path.join(out_folder,in_masks[0]))
mask_tensor = _mask_transformer(mask)
print(mask_tensor.unique(return_counts=True))
mask_tensor

(tensor([0, 1], dtype=torch.uint8), tensor([18716, 31460]))


tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8)

## Code Repo and Setup
https://github.com/PiyushBhadauriya26/Semantic_Segmentation

In [12]:
from IPython.display import Markdown, display
display(Markdown("../README.md"))

#### Setup
- Use Python 3.12.
- Create and activate a virtual environment.
- Install PyTorch by following https://pytorch.org/get-started/locally/#windows-python
- Install the required Python libraries with `pip install -r requirements.txt`.

#### Server
- Start the server: `python LitServe_SAM.py`
- Health check: `http://localhost:8000/health`
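
Before sending requests, it can help to wait until the health endpoint responds. The sketch below polls the URL listed above using only the standard library. `wait_until_healthy` and its `retries`/`delay` parameters are hypothetical names, and treating HTTP 200 as "healthy" is an assumption about the server's behavior.

```python
import time
import urllib.error
import urllib.request


def wait_until_healthy(url="http://localhost:8000/health", retries=10, delay=1.0):
    """Return True once `url` answers with HTTP 200, False if retries run out."""
    for _ in range(retries):
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server is not accepting connections yet
        time.sleep(delay)
    return False
```

A script could call `wait_until_healthy()` after launching `python LitServe_SAM.py` and only start the client once it returns `True`.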

#### Client 
Once the server is up, use `client.py` to call the `/predict` API.

```
usage: client.py [-h] --image IMAGE [--p1 P1] [--p2 P2] [--model MODEL] [--alpha ALPHA]

Send text & image to server and receive a response.

optional arguments:  
  -h, --help     show this help message and exit  
  --image IMAGE  URL for the image file.  
  --p1 P1        Single Point input in '(x,y)' format.
  --p2 P2        Point2 '(x1,y1)' for box input.
  --model MODEL  Name of the model [sam-vit_l, sam-vit_h, med_sam-vit_b, sam-vit_b-lora512]
  --alpha ALPHA  Transparency mask between 0-1.
```
The response contains the segmented image: the identified region is left unmasked, and the background region is covered with a green mask.
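
The overlay described above could look roughly like this. It is a sketch only, not the server's actual rendering code: `overlay_background` is a hypothetical helper, and it assumes that `alpha` is the opacity of the green layer (0 leaves the image unchanged, 1 paints the background solid green).

```python
import numpy as np


def overlay_background(image, mask, alpha=0.5, color=(0, 255, 0)):
    """Blend `color` over pixels where mask == 0; keep mask == 1 pixels as-is.

    image: (H, W, 3) uint8 array, mask: (H, W) array of 0/1, alpha in [0, 1].
    """
    image = np.asarray(image, dtype=np.float32)
    background = np.asarray(mask) == 0
    out = image.copy()
    tint = np.array(color, dtype=np.float32)
    # Standard alpha blend, applied to background pixels only.
    out[background] = (1.0 - alpha) * image[background] + alpha * tint
    return np.clip(out.round(), 0, 255).astype(np.uint8)
```
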
###### Example
- `python client.py --image .\data\test1.png --p1 "(60,40)" --p2 "(180,120)" --alpha 0.8` # Box input. For best results, provide a box around the region of interest.
- `python client.py --image .\data\test1.png --alpha 0.5 --model "med_sam-vit_b"` # Use the med_sam-vit_b model to segment the whole image
- `python client.py --image .\data\test1.png --p1 "(130,80)" --alpha 0.5 --model "med_sam-vit_b"` # Point input
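
The three examples above correspond to three prompt modes: box (`--p1` and `--p2`), point (`--p1` only), and whole image (neither). A minimal sketch of how the `'(x,y)'` strings might be parsed and turned into a prompt is shown below. `parse_point`, `to_prompt`, and the returned dict layout are hypothetical, and treating `--p1`/`--p2` as opposite corners of the box is an assumption, not something confirmed by `client.py`.

```python
import re

_POINT_RE = re.compile(r"^\s*\(\s*(-?\d+)\s*,\s*(-?\d+)\s*\)\s*$")


def parse_point(text):
    """Parse a string such as '(60,40)' into an (x, y) tuple of ints."""
    match = _POINT_RE.match(text)
    if match is None:
        raise ValueError(f"expected '(x,y)', got {text!r}")
    return int(match.group(1)), int(match.group(2))


def to_prompt(p1=None, p2=None):
    """Build a prompt dict: box if both points are given, point if only p1."""
    if p1 is not None and p2 is not None:
        (x1, y1), (x2, y2) = parse_point(p1), parse_point(p2)
        # Normalize to [x_min, y_min, x_max, y_max] whatever the corner order.
        return {"box": [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]}
    if p1 is not None:
        return {"point": list(parse_point(p1))}
    return {}  # no prompt: segment the whole image


print(to_prompt("(60,40)", "(180,120)"))  # {'box': [60, 40, 180, 120]}
```
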

#### Batch Inference 
Use the `batch_inference.py` script to run inference on multiple images:
```
usage: batch_inference.py [-h] [-i DATA_PATH] [-o SEG_PATH] [--device DEVICE] [--overwrite OVERWRITE] [--model MODEL]

run inference on testing set based on MedSAM

options:
  -h, --help                               show this help message and exit
  -i DATA_PATH, --data_path DATA_PATH      path to the data folder
  -o SEG_PATH, --seg_path SEG_PATH         path to the segmentation folder
  --device DEVICE                          Device cuda or cpu
  --overwrite OVERWRITE                    Overwrite existing results with a new mask.
  --model MODEL                            Name of the model [sam-vit_l, sam-vit_h, med_sam-vit_b, sam-vit_b-lora512]

```
###### Example
- `python .\batch_inference.py -i "data/BCSS_small/test/images" --overwrite True` # Run inference and save the predicted masks in the data/Results folder

#### Train
- Source: https://github.com/WangRongsheng/SAM-fine-tune
- Update `config.yaml`: set the DATASET paths, the CHECKPOINT for the base SAM model, and the TRAIN settings.
- Start training: `python train.py`
- After training, the LoRA weights are saved as a safetensors file in the `model_checkpoint` folder.

#### References
1. https://github.com/facebookresearch/segment-anything
2. https://github.com/facebookresearch/sam2
3. https://github.com/bowang-lab/MedSAM
4. https://github.com/mazurowski-lab/finetune-SAM
5. https://github.com/WangRongsheng/SAM-fine-tune
6. https://github.com/Lightning-AI/LitServe


## Metrics And Comparison