<a href="https://colab.research.google.com/github/aliyassine1/LiteMedSAM_quant/blob/main/Lite_MedSam_size_reduction_quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive (forcing a fresh mount) so the checkpoint and demo data
# under /content/drive/MyDrive are reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [None]:
pip install timm



In [None]:
from os import listdir, makedirs
from os.path import join, isfile, basename
from glob import glob
from tqdm import tqdm
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from segment_anything.modeling import MaskDecoder, PromptEncoder, TwoWayTransformer
from tiny_vit_sam import TinyViT
from matplotlib import pyplot as plt
import cv2
import argparse
from collections import OrderedDict
import pandas as pd

# medsamlite
class MedSAM_Lite(nn.Module):
    """Box-prompted segmentation model: image encoder + prompt encoder + mask decoder.

    Parameters
    ----------
    image_encoder : nn.Module
        Encodes the preprocessed image batch into image embeddings.
    mask_decoder : nn.Module
        Decodes image + prompt embeddings into mask logits and IoU predictions.
    prompt_encoder : nn.Module
        Embeds box prompts and provides the dense positional encoding through
        ``get_dense_pe()``.
    """
    def __init__(
            self,
            image_encoder,
            mask_decoder,
            prompt_encoder
        ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, box_np):
        """Predict low-resolution mask logits for one box prompt per image.

        Parameters
        ----------
        image : torch.Tensor
            Preprocessed image batch passed to ``image_encoder``.
        box_np : array-like or torch.Tensor
            Box prompts of shape (B, 4) or (B, 1, 4). They are converted to a
            float32 tensor on ``image.device`` before use.

        Returns
        -------
        torch.Tensor
            Mask logits from ``mask_decoder`` (single mask per prompt,
            ``multimask_output=False``).
        """
        image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
        # The box prompt is a constant input: convert it to a float32 tensor on
        # the image's device without tracking gradients.
        with torch.no_grad():
            box_torch = torch.as_tensor(box_np, dtype=torch.float32, device=image.device)
            if box_torch.dim() == 2:
                box_torch = box_torch[:, None, :] # (B, 1, 4)

        # Pass the converted tensor, not the raw ``box_np``. Previously the
        # conversion above was discarded, so numpy boxes reached the prompt
        # encoder unconverted and on the wrong device.
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding, # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
          ) # (B, 1, 256, 256)

        return low_res_masks

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """
        Do cropping and resizing

        Parameters
        ----------
        masks : torch.Tensor
            masks predicted by the model; bilinear interpolation needs a
            4-D (B, C, H, W) tensor
        new_size : tuple
            the shape of the image after resizing to the longest side of 256,
            i.e. the non-padded region of ``masks``
        original_size : tuple
            the original shape of the image

        Returns
        -------
        torch.Tensor
            the upsampled mask to the original size
        """
        # Crop away the bottom/right padding added during preprocessing
        masks = masks[..., :new_size[0], :new_size[1]]
        # Resize back to the original image resolution
        masks = F.interpolate(
            masks,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )

        return masks


In [None]:
import time
import numpy as np
from skimage import transform, io

from torch.nn import functional as F
from PIL import Image
from segment_anything import sam_model_registry
from os import listdir, makedirs
from os.path import join, isfile, basename
from glob import glob
from tqdm import tqdm
# wrap it up as a function
import base64

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tiny_vit_sam import TinyViT
from segment_anything.modeling import MaskDecoder, PromptEncoder, TwoWayTransformer
#from tiny_vit_sam import TinyViT
from matplotlib import pyplot as plt
import cv2
import argparse
from collections import OrderedDict
import pandas as pd
import torch.nn.functional as F


SEED = 2023

# freeze seeds: torch (CPU + CUDA) generators and NumPy's global RNG
torch.manual_seed(SEED)
torch.cuda.empty_cache()
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

# Pick the compute device once for the whole notebook: GPU if present, else CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")


def set_up_model():
    """Assemble an untrained MedSAM-Lite model.

    Components: a TinyViT image encoder for 256x256 RGB input, a
    segment_anything ``PromptEncoder`` with a 64x64 embedding grid, and a
    ``MaskDecoder`` built on a two-way transformer (embedding dim 256).
    Trained weights are not loaded here; use ``load_state_dict`` afterwards.

    Returns
    -------
    MedSAM_Lite
        The assembled model.
    """
    # Per-stage feature shapes noted by the original authors:
    # (64, 256, 256), (128, 128, 128), (160, 64, 64), (320, 64, 64).
    image_encoder = TinyViT(
        img_size=256,
        in_chans=3,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8,
    )

    prompt_encoder = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(256, 256),
        mask_in_chans=16,
    )

    # The transformer is built first and then handed to the mask decoder.
    decoder_transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=256,
        mlp_dim=2048,
        num_heads=8,
    )
    mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=decoder_transformer,
        transformer_dim=256,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    return MedSAM_Lite(
        image_encoder=image_encoder,
        mask_decoder=mask_decoder,
        prompt_encoder=prompt_encoder,
    )

def postprocess_masks(self, masks, new_size, original_size):
    """
    Crop padded mask predictions to the resized image area, then upsample.

    Module-level copy of ``MedSAM_Lite.postprocess_masks``. ``self`` is
    accepted for signature compatibility but never used.

    Parameters
    ----------
    masks : torch.Tensor
        masks predicted by the model; bilinear interpolation needs a 4-D
        (B, C, H, W) tensor
    new_size : tuple
        (height, width) after resizing the longest side to 256, i.e. the
        non-padded region of ``masks``
    original_size : tuple
        (height, width) of the original image

    Returns
    -------
    torch.Tensor
        the masks bilinearly resized to ``original_size``
    """
    valid_h, valid_w = new_size[0], new_size[1]
    out_h, out_w = original_size[0], original_size[1]
    cropped = masks[..., :valid_h, :valid_w]
    return F.interpolate(cropped, size=(out_h, out_w), mode="bilinear", align_corners=False)

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_256, new_size, original_size):
    """Segment box prompt(s) from a precomputed image embedding.

    Parameters
    ----------
    medsam_model : MedSAM_Lite
        Supplies ``prompt_encoder``, ``mask_decoder`` and ``postprocess_masks``.
    img_embed : torch.Tensor
        Image-encoder output; the box tensor is created on its device.
    box_256 : array-like
        Box prompt(s) of shape (B, 4) or (B, 1, 4). Presumably in the
        256-px resized frame (per the name); confirm against callers.
    new_size, original_size : tuple
        Forwarded to ``postprocess_masks`` to crop padding and resize back.

    Returns
    -------
    np.ndarray
        uint8 binary mask(s), 1 where sigmoid(logit) > 0.5, with singleton
        dimensions squeezed out.
    """
    boxes = torch.as_tensor(box_256, dtype=torch.float, device=img_embed.device)
    if boxes.ndim == 2:
        boxes = boxes.unsqueeze(1)  # (B, 4) -> (B, 1, 4)

    prompt_encoder = medsam_model.prompt_encoder
    sparse, dense = prompt_encoder(points=None, boxes=boxes, masks=None)
    logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )

    probs = torch.sigmoid(medsam_model.postprocess_masks(logits, new_size, original_size))
    return (probs.squeeze().cpu().numpy() > 0.5).astype(np.uint8)



def resize_longest_side(image, target_length=256):
    """
    Scale ``image`` so that its longer side equals ``target_length``, keeping
    the aspect ratio (each side is rounded half-up to a whole pixel).
    Expects a numpy array with shape HxWxC in uint8 format.
    """
    height, width = image.shape[0], image.shape[1]
    factor = target_length * 1.0 / max(height, width)
    # cv2.resize takes dsize as (width, height).
    dsize = (int(width * factor + 0.5), int(height * factor + 0.5))
    return cv2.resize(image, dsize, interpolation=cv2.INTER_AREA)

def pad_image(image, target_size=256):
    """
    Zero-pad the first two (spatial) axes of ``image`` up to ``target_size``.

    Works for an HxW mask, an HxWxC image, and arrays with any number of
    extra trailing axes, which are left unpadded.

    Raises
    ------
    ValueError
        If either spatial dimension already exceeds ``target_size``.
    """
    h, w = image.shape[0], image.shape[1]
    padh = target_size - h
    padw = target_size - w
    if padh < 0 or padw < 0:
        raise ValueError(
            f"image of size {h}x{w} is larger than target_size={target_size}; resize it first"
        )
    # Pad only the bottom/right so the top-left origin (and therefore any box
    # coordinates in the resized frame) is unchanged.
    pad_width = [(0, padh), (0, padw)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width)

def load_image(img_np,medsam_model):
    """Preprocess an image for MedSAM-Lite and compute its image embedding.

    Steps: a 2-D (grayscale) array is replicated to 3 channels; the image is
    resized so its longest side is 256; it is min-max normalised to [0, 1],
    zero-padded to 256x256, and run through ``medsam_model.image_encoder``
    without gradients.

    Parameters
    ----------
    img_np : np.ndarray
        Image of shape (H, W) or (H, W, C) with values in [0, 255]
        (presumably C == 3 for the encoder — confirm for other inputs).
    medsam_model : MedSAM_Lite
        Model whose ``image_encoder`` produces the embedding.

    Returns
    -------
    tuple
        ``(img_3c, image_embedding, H, W, newh, neww)``: the 3-channel image,
        the encoder output, the original height/width, and the height/width
        after resizing (before padding).

    Raises
    ------
    AssertionError
        If the image contains values >= 256.
    """
    if len(img_np.shape) == 2:
        img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
    else:
        img_3c = img_np
    assert np.max(img_3c)<256, f'input data should be in range [0, 255], but got {np.unique(img_3c)}'
    H, W, _ = img_3c.shape

    ## MedSAM Lite preprocessing: resize longest side to 256, min-max normalise
    img_256 = resize_longest_side(img_3c, 256)
    newh, neww = img_256.shape[:2]
    img_256_norm = (img_256 - img_256.min()) / np.clip(
        img_256.max() - img_256.min(), a_min=1e-8, a_max=None
    )
    # Zero-pad bottom/right to 256x256, then HWC -> 1xCxHxW float tensor.
    img_256_padded = pad_image(img_256_norm, 256)

    # Send the input to the device holding the encoder's weights, so a model
    # that lives elsewhere than the global `device` (e.g. a CPU copy) still
    # works. Fall back to `device` if the encoder exposes no parameters.
    encoder_param = next(medsam_model.image_encoder.parameters(), None)
    input_device = encoder_param.device if encoder_param is not None else device
    img_256_tensor = torch.tensor(img_256_padded).float().permute(2, 0, 1).unsqueeze(0).to(input_device)

    with torch.no_grad():
        image_embedding = medsam_model.image_encoder(img_256_tensor)

    return img_3c,image_embedding,H, W,newh, neww



def save_mask(self):
    """Write ``self.mask_c`` to ``<dir>/<stem>_mask.png`` next to ``self.image_path``.

    ``<stem>`` is the image file name up to its first '.', so multi-part
    extensions such as ``.nii.gz`` are dropped entirely. Only the file name,
    not the whole path, is split on '.'. Dots in directory names (or a
    leading './') therefore no longer truncate the output path.
    """
    import os.path

    directory, filename = os.path.split(self.image_path)
    stem = filename.split('.')[0]
    out_path = os.path.join(directory, f"{stem}_mask.png")
    io.imsave(out_path, self.mask_c)

Quantizing all linear (`nn.Linear`) layers of the model from 32-bit floats to 8-bit integers.

In [None]:
# Import the timer directly: earlier cells bind `time` to both the function
# (`from time import time`) and the module (`import time`), so `time.perf_counter`
# depends on which cell ran last.
from time import perf_counter

lite_medsam_checkpoint_path="/content/drive/MyDrive/challenge_medsam/MedSAM_fast/work_dir/LiteMedSAM/lite_medsam.pth"
medsam_model=set_up_model()
print("Loading MedSAM model, a sec.")
tic = perf_counter()
lite_medsam_checkpoint = torch.load(lite_medsam_checkpoint_path, map_location=device)
medsam_model.load_state_dict(lite_medsam_checkpoint)
medsam_model.to(device)
# `eval` must be *called*. The bare attribute access `medsam_model.eval` did
# nothing and left the model in training mode, which affects layers such as
# dropout/batch-norm if present.
medsam_model.eval()

print(f"Done, took {perf_counter() - tic} to load the model")

Loading MedSAM model, a sec.
Done, took 0.34889948300042306 to load the model


In [None]:
# `print_size_of_model` is used here and in the quantization cell but was never
# defined in the notebook, so it is defined here to survive Restart & Run All.
from io import BytesIO


def print_size_of_model(model):
    """Print the size of ``model``'s serialized ``state_dict`` in KB (1 KB = 1000 bytes).

    The state_dict is serialized with ``torch.save`` into an in-memory buffer,
    so no temporary file is written to disk.
    """
    buffer = BytesIO()
    torch.save(model.state_dict(), buffer)
    print('Size (KB):', buffer.getbuffer().nbytes / 1e3)


print('Size of the model before quantization')
print_size_of_model(medsam_model)

Size of the model before quantization
Size (KB): 39376.561


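**Size reduction of 65.79 %**: dynamic int8 quantization of the `nn.Linear` layers shrinks the serialized model from 39,376.6 KB to 13,469.2 KB (see the size printouts before and after quantization).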

In [None]:
import copy

import torch
from torch.ao.quantization import quantize_dynamic

medsam_model.eval()  # Ensure the model is in evaluation mode

# Dynamic int8 quantization: only nn.Linear layers are quantized (int8 weights,
# activations quantized on the fly); every other layer stays in float.
# Eager-mode quantized kernels run on CPU, so quantize a CPU deep copy and
# leave `medsam_model` itself on `device`.
quantized_medsam_model = quantize_dynamic(
    model=copy.deepcopy(medsam_model).to("cpu"),
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
    inplace=True,  # the deep copy above is private to this call
)

print("Dynamic quantization complete.")
print('Size of the model after quantization')
print_size_of_model(quantized_medsam_model)

Dynamic quantization complete.
Size of the model after quantization
Size (KB): 13469.215


Testing

non quantized model

In [None]:
# Batch inference on the demo images (`-i`), writing to the `-o` directory
# .../T4_env/Final_quant/Original_cpumap. The script file really carries a doubled
# `.py.py` extension (the run below succeeded with this name).
# NOTE(review): the script name suggests the 8-bit quantized model, but this cell
# sits under the "non quantized model" heading. Confirm which model it runs.
!python CVPR24_LiteMedSAM_infer_8bit_quantized.py.py -i /content/drive/MyDrive/challenge_medsam/test_demo/imgs/ -o /content/drive/MyDrive/challenge_medsam/test_demo/T4_env/Final_quant/Original_cpumap

  0% 0/10 [00:00<?, ?it/s]2DBox_CXR_demo.npz, box: [ 84  55 203 193], predicted iou: 0.9303
 10% 1/10 [00:03<00:32,  3.65s/it]2DBox_Dermoscopy_demo.npz, box: [124 244 815 843], predicted iou: 0.959
 20% 2/10 [00:09<00:38,  4.87s/it]2DBox_Endoscopy_demo.npz, box: [ 564  207 1879 1080], predicted iou: 0.9652
 30% 3/10 [00:14<00:34,  4.86s/it]2DBox_Fundus_demo.npz, box: [1150  978 1696 1516], predicted iou: 0.9566
 40% 4/10 [00:21<00:33,  5.66s/it]2DBox_Mammography_demo.npz, box: [ 131 1532  475 1895], predicted iou: 0.8332
2DBox_Mammography_demo.npz, box: [ 701 1952  816 2096], predicted iou: 0.8386
2DBox_Mammography_demo.npz, box: [ 982 1991 1088 2139], predicted iou: 0.8355
2DBox_Mammography_demo.npz, box: [ 431 2075  594 2238], predicted iou: 0.7594
 50% 5/10 [00:30<00:34,  6.84s/it]2DBox_Microscope_demo.npz, box: [  0  97  66 236], predicted iou: 0.9024
2DBox_Microscope_demo.npz, box: [  0 267  84 386], predicted iou: 0.9121
2DBox_Microscope_demo.npz, box: [  0 387  34 522], predicte

quantized model

In [None]:
# Batch inference on the demo images (`-i`) with `CVPR24_LiteMedSAM_infer.py`,
# writing to the `-o` directory .../test_demo/Final_quant/Quantization_of_whole_linear.
# NOTE(review): this cell sits under the "quantized model" heading, yet the script
# name has no quantization suffix (unlike the previous cell's script). Confirm
# which model it actually loads before comparing the two runs.
!python CVPR24_LiteMedSAM_infer.py -i /content/drive/MyDrive/challenge_medsam/test_demo/imgs/ -o /content/drive/MyDrive/challenge_medsam/test_demo/Final_quant/Quantization_of_whole_linear

  0% 0/10 [00:00<?, ?it/s]2DBox_CXR_demo.npz, box: [ 84  55 203 193], predicted iou: 0.9353
 10% 1/10 [00:05<00:46,  5.21s/it]2DBox_Dermoscopy_demo.npz, box: [124 244 815 843], predicted iou: 0.9578
 20% 2/10 [00:12<00:52,  6.57s/it]2DBox_Endoscopy_demo.npz, box: [ 564  207 1879 1080], predicted iou: 0.9687
 30% 3/10 [00:17<00:39,  5.61s/it]2DBox_Fundus_demo.npz, box: [1150  978 1696 1516], predicted iou: 0.9498
 40% 4/10 [00:24<00:36,  6.09s/it]2DBox_Mammography_demo.npz, box: [ 131 1532  475 1895], predicted iou: 0.8508
2DBox_Mammography_demo.npz, box: [ 701 1952  816 2096], predicted iou: 0.832
2DBox_Mammography_demo.npz, box: [ 982 1991 1088 2139], predicted iou: 0.8436
2DBox_Mammography_demo.npz, box: [ 431 2075  594 2238], predicted iou: 0.7804
 50% 5/10 [00:33<00:36,  7.21s/it]2DBox_Microscope_demo.npz, box: [  0  97  66 236], predicted iou: 0.9119
2DBox_Microscope_demo.npz, box: [  0 267  84 386], predicted iou: 0.9206
2DBox_Microscope_demo.npz, box: [  0 387  34 522], predicte

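Both runs give similar predicted IoU scores. The speed claim needs checking, though. In the logs above, the run under the *quantized model* heading is slower per image than the run under *non quantized model* (e.g. 5.21 s/it vs 3.65 s/it after the first image). However, the first run invokes `CVPR24_LiteMedSAM_infer_8bit_quantized.py.py` and the second `CVPR24_LiteMedSAM_infer.py`, so the two headings may be swapped. Confirm which run used the quantized model before concluding that the quantized model is faster.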