<a href="https://colab.research.google.com/github/Somvit09/Test-LLMS/blob/master/SAM2_TEST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
!pip install -r requirements.txt

Collecting segment_anything@ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 (from -r requirements.txt (line 22))
  Using cached segment_anything-1.0-py3-none-any.whl


In [18]:
!pip install numpy opencv-python torch torchvision python-dotenv fastapi uvicorn



In [19]:
!pip install requests



In [20]:
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import numpy as np
import cv2, os, torch, requests
from dotenv import load_dotenv
from enum import Enum

In [21]:
load_dotenv()
BASE_DIR = os.getcwd()
print(BASE_DIR)
CONFIG_PATH = os.path.join(BASE_DIR, "configs")
os.makedirs(CONFIG_PATH, exist_ok=True)
print(CONFIG_PATH)

/content
/content/configs


In [22]:
class ModelType(Enum):
    VIT_H = "vit_h"
    VIT_B = "vit_b"
    VIT = "vit_l"

In [32]:
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import numpy as np
import cv2, os, torch, requests
from fastapi import status, HTTPException
from dotenv import load_dotenv
from enum import Enum

load_dotenv()


BASE_DIR = os.getcwd()
CONFIG_PATH = os.path.join(BASE_DIR, "configs")
os.makedirs(CONFIG_PATH, exist_ok=True)

class ModelType(Enum):
    VIT_H = "vit_h"
    VIT_L = "vit_l"
    VIT_B = "vit_b"



class SAM2Annotator:
    def __init__(self):
        self.VIT_H_PTH_URL: str = os.getenv("VIT_H_PTH")
        self.VIT_L_PTH_URL: str = os.getenv("VIT_L_PTH")
        self.VIT_B_PTH_URL: str = os.getenv("VIT_B_PTH")
        self.__CHECKPOINTS = ["sam_vit_h_4b8939.pth", "sam_vit_l_0b3195.pth", "sam_vit_b_01ec64.pth"]
        # empty cache
        torch.cuda.empty_cache()

        # download pth file
        self.download_configs()

        self.models = dict(
            vit_h="sam_vit_h_4b8939.pth",
            vit_l="sam_vit_l_0b3195.pth",
            vit_b="sam_vit_b_01ec64.pth"
        )


    def download_configs(self):
        try:
            # make the downloadable dirs
            self.download_dir = os.path.join(CONFIG_PATH, "SAM2")
            os.makedirs(self.download_dir, exist_ok=True)

            # download the models
            self.MODELS = [m.value for m in ModelType]
            for index in range(len(self.MODELS)):
                # main download path
                downloaded_model_path = os.path.join(self.download_dir, self.__CHECKPOINTS[index])

                # determine the models accroding to the path
                if not os.path.exists(downloaded_model_path):
                    if self.MODELS[index] == ModelType.VIT_H.value:
                        response = requests.get(self.VIT_H_PTH_URL)
                    elif self.MODELS[index] == ModelType.VIT_L.value:
                        response = requests.get(self.VIT_L_PTH_URL)
                    else:
                        response = requests.get(self.VIT_B_PTH_URL)

                    # finally save the model in the directory
                    if response.status_code == 200:
                        with open(downloaded_model_path, 'wb') as file:
                            file.write(response.content)
                        print(f"File downloaded successfully as {self.__CHECKPOINTS[index]}")
                    else:
                        print(f"Failed to download file. HTTP Status code: {response.status_code}")

        except (FileExistsError, FileNotFoundError) as fe:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(fe))
        except IOError as i:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(i))
        except TypeError as t:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(t))
        except (ValueError, KeyError) as kv:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(kv))
        except Exception as e:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))


    def load_model(self, model_type: str):
        """
        Load the SAM model based on the specified model type.
        Args:
            model_type (str): The type of model to load (e.g., 'vit_h', 'vit_l', 'vit_b').
        """
        try:
            if model_type not in self.MODELS:
                raise ValueError(f"Invalid model type: {model_type}. Allowed types are: {', '.join(self.MODELS)}")

            # Check if the checkpoint file exists
            checkpoint_path = os.path.join(self.download_dir, self.models[model_type])
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}. Ensure it is downloaded.")

            # Initialize the model
            self.__DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
            self.__model = sam_model_registry[model_type](checkpoint=checkpoint_path)
            self.__model.to(device=self.__DEVICE)

            # Log successful initialization
            print(f"SAM model of type '{model_type}' has been initialized on {self.__DEVICE}.")

            # Initialize mask generator and predictor
            self.mask_generator = SamAutomaticMaskGenerator(self.__model)
            self.predictor = SamPredictor(self.__model)

        except ValueError as ve:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
        except FileNotFoundError as fe:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(fe))
        except torch.cuda.OutOfMemoryError as ome:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="CUDA out of memory. Try reducing model size or using a machine with more GPU memory.")
        except Exception as e:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An unexpected error occurred while loading the model: {str(e)}")

    def convert_image_to_npArray(self, file_object: bytes) -> np.ndarray:
        """
        Automatically generate np array ob the binary image object
        Args:
            file_object (binary obj): file
        Returns:
            list: List of NumPy arrays.
        """
        np_array = np.asarray(bytearray(file_object), dtype="uint8")
        self.image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        if self.image is None:
            raise ValueError("Failed to decode the image. Please check the input file format.")
        return self.image

    def auto_annotate(self) -> list:
        """
        Automatically generate segmentation masks for an image.
        Args:
            image (np.ndarray): Input image as a NumPy array (H, W, C).
        Returns:
            list: List of segmentation masks as NumPy arrays.
        """
        masks = self.mask_generator.generate(self.image)
        return masks


    def annotate_with_box(self, box: list) -> np.ndarray:
        """
        Generate a segmentation mask for a given bounding box.
        Args:
            image (np.ndarray): Input image as a NumPy array (H, W, C).
            box (list): Bounding box [x1, y1, x2, y2].
        Returns:
            np.ndarray: Segmentation mask as a NumPy array.
        """
        self.predictor.set_image(self.image)
        mask, _, _ = self.predictor.predict(box=box)
        return mask

    def annotate_with_points(self, points: list, labels: list) -> np.ndarray:
        """
        Generate a segmentation mask for given points and labels.
        Args:
            image (np.ndarray): Input image as a NumPy array (H, W, C).
            points (list): List of points as (x, y) coordinates.
            labels (list): List of labels (1 for foreground, 0 for background).
        Returns:
            np.ndarray: Segmentation mask as a NumPy array.
        """
        self.predictor.set_image(self.image)
        mask, _, _ = self.predictor.predict(points=points, labels=labels)
        return mask

    def save_mask(self, mask: np.ndarray, filepath: str):
        """
        Save a segmentation mask to a file.
        Args:
            mask (np.ndarray): Segmentation mask as a NumPy array.
            filepath (str): File path to save the mask.
        """
        cv2.imwrite(filepath, mask.astype(np.uint8) * 255)


In [36]:
sam = SAM2Annotator()
sam.load_model(model_type="vit_h")

  state_dict = torch.load(f)


SAM model of type 'vit_h' has been initialized on cuda:0.


In [37]:
first_image = "Road621.jpg"
image_path = os.path.join(BASE_DIR, first_image)
img_arr = None

In [38]:
with open(image_path, 'rb') as file:
    file_object = file.read()
    print(type(file_object))
    img_arr = sam.convert_image_to_npArray(file_object=file_object)
    print(img_arr)
    print(sam.auto_annotate())

<class 'bytes'>
[[[236 237 228]
  [235 236 227]
  [232 235 226]
  ...
  [ 99 110 108]
  [121 132 130]
  [156 168 168]]

 [[173 170 162]
  [171 169 161]
  [169 167 159]
  ...
  [ 75  86  84]
  [ 92 103 101]
  [ 76  88  88]]

 [[155 144 140]
  [154 145 141]
  [155 146 142]
  ...
  [116 127 125]
  [144 155 153]
  [ 95 107 107]]

 ...

 [[ 88  76  70]
  [ 86  74  68]
  [ 85  73  67]
  ...
  [144 203 219]
  [144 204 220]
  [144 204 220]]

 [[ 93  81  75]
  [ 91  79  73]
  [ 90  78  72]
  ...
  [141 203 221]
  [140 204 222]
  [140 204 222]]

 [[ 99  87  81]
  [ 98  86  80]
  [ 96  84  78]
  ...
  [139 203 221]
  [138 204 222]
  [136 205 222]]]
[{'segmentation': array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, Fal