<a href="https://colab.research.google.com/github/jungsh210/AI-Project/blob/main/SAM/SAM_Tutorial_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SAM Tutorial
# #1. 자동 분할 (Automatic Mask Generation)

## SAM 환경 세팅

In [None]:
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys
!{sys.executable} -m pip install opencv-python matplotlib
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

!mkdir images
!wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg

!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import io
from PIL import Image
from google.colab import files

def show_anns(anns, ax=None, alpha=0.35, rng=None):
  """Overlay SAM annotation masks on a Matplotlib axes, one random color per mask.

  Args:
    anns: list of annotation dicts (as produced by ``mask_generator.generate``);
      each must provide a 2-D ``'segmentation'`` mask and an ``'area'`` value.
    ax: axes to draw on; defaults to the current axes (``plt.gca()``).
    alpha: opacity applied to every mask pixel (default 0.35).
    rng: source of random colors, e.g. ``np.random.default_rng(0)`` for
      reproducible figures; defaults to NumPy's global random state.

  Returns None; draws one ``imshow`` layer per mask. Does nothing for an empty list.
  """
  if len(anns) == 0:
    return
  if ax is None:
    ax = plt.gca()
  if rng is None:
    rng = np.random
  # Keep the view fitted to the underlying image instead of rescaling per layer.
  ax.set_autoscale_on(False)
  # Largest masks first so smaller ones are drawn on top and stay visible.
  for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
    mask = ann['segmentation']
    # Build the RGBA layer directly: RGB = this mask's color, A = mask * alpha
    # (transparent outside the mask).
    layer = np.empty((mask.shape[0], mask.shape[1], 4))
    layer[..., :3] = rng.random(3)
    layer[..., 3] = mask * alpha
    ax.imshow(layer)

## 이미지 업로드
- 다운받은 샘플 이미지를 사용하는 경우 (아래 셀의 기본 동작)
  ```python
  image = cv2.imread("images/dog.jpg")
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  ```
- 이미지를 업로드해서 사용하는 경우 (Colab 전용)
  ```python
  image_file = files.upload()
  image = io.BytesIO(image_file[list(image_file.keys())[0]])
  image = np.array(Image.open(image).convert("RGB"))
  ```
  - `.convert("RGB")`: PNG(RGBA)나 흑백 이미지도 SAM이 기대하는 3채널 RGB 배열로 맞춰 줍니다.

In [None]:
# Load the sample image as an RGB array for display and SAM.
# cv2.imread returns None (it does not raise) when the path is wrong or unreadable,
# so fail here with a clear message instead of an opaque cvtColor error.
image_path = "images/dog.jpg"
image_bgr = cv2.imread(image_path)
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads channels in BGR order; convert to RGB.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

In [None]:
# Preview the input image on its own, without axes, before segmenting it.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image)
ax.set_axis_off()
plt.show()

## 자동 mask 생성

In [None]:
import sys
# NOTE(review): only needed when running from inside the segment-anything repo's
# notebooks/ directory; the package is pip-installed in the setup cell.
sys.path.append('..')
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import torch

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = 'vit_h'

# Use the GPU when one is attached; hard-coding 'cuda' crashes on CPU-only runtimes.
# (ViT-H on CPU works but is very slow.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic mask generator with default parameters.
mask_generator = SamAutomaticMaskGenerator(sam)

## run **generate** with default parameters
- 이미지 입력만으로 객체 분할
- Default 파라미터 값으로 실행

In [None]:
# Segment the whole image with default settings -> list of per-mask annotation dicts.
masks = mask_generator.generate(image=image)

In [None]:
# How many masks were found, and which fields each annotation dict carries.
print(len(masks))
# Guard: generate() can return an empty list, and masks[0] would raise IndexError.
if masks:
    print(masks[0].keys())

In [None]:
# Image with the default-parameter masks overlaid (show_anns draws on the current axes).
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image)
show_anns(masks)
ax.set_axis_off()
plt.show()

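## generate 파라미터 설정
- 포인트를 얼마나 조밀하게 샘플링할지 제어 (`points_per_side`, `crop_n_layers`, `crop_n_points_downscale_factor`)
- 품질이 낮은 마스크를 걸러내기 위한 임계값 제어 (`pred_iou_thresh`, `stability_score_thresh`)
- 작은 영역을 정리하는 후처리 (`min_mask_region_area`, OpenCV 필요)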

## mask generator의 파라미터를 변경하여 객체 분할

In [None]:
# Same SAM model, tuned generator settings, grouped by purpose.
# Meanings below follow the SamAutomaticMaskGenerator docstring in segment-anything.
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    # Sampling: a points_per_side x points_per_side grid of point prompts.
    points_per_side=32,
    # Extra pass on image crops; crop layer n samples points_per_side / 2**n points per side.
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    # Filtering: discard masks whose predicted IoU / stability score is below these.
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    # Post-processing: remove small disconnected regions/holes below this pixel area.
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)

In [None]:
# Re-segment the same image with the tuned generator.
masks2 = mask_generator_2.generate(image=image)

In [None]:
# Number of masks from the tuned generator (compare with len(masks) above).
len(masks2)

In [None]:
# Image with the tuned-generator masks overlaid (show_anns draws on the current axes).
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image)
show_anns(masks2)
ax.set_axis_off()
plt.show()