# リポジトリクローン

In [None]:
# Clone the Painter repository (which contains SegGPT) and move into the SegGPT
# inference directory, so that the later `seggpt_inference` / `seggpt_engine`
# imports and relative paths such as `examples/...` resolve from here.
# NOTE(review): after the `%cd`, re-running this cell would clone again into the
# (already nested) current directory — run it once per fresh runtime.
!git clone https://github.com/baaivision/Painter
%cd Painter/SegGPT/SegGPT_inference

# パッケージインストール

In [None]:
!pip install --upgrade -q timm
!pip install -q fvcore
!pip install -q fairscale
!pip install -q 'git+https://github.com/facebookresearch/detectron2.git'

# 重みダウンロード

In [3]:
!wget https://huggingface.co/BAAI/SegGPT/resolve/main/seggpt_vit_large.pth -q

# モデル読み込み

In [None]:
import torch
from seggpt_inference import prepare_model

torch.manual_seed(42)
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs (slowly) on a CPU-only runtime instead of failing here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build SegGPT from the downloaded checkpoint and move it to `device`.
# The architecture takes an 896x448 input: run_inference stacks a 448x448 prompt
# image on top of a 448x448 target image.
# NOTE(review): the third argument 'instance' presumably selects SegGPT's
# segmentation type — confirm in seggpt_inference.prepare_model.
model = prepare_model(
    'seggpt_vit_large.pth',
    'seggpt_vit_large_patch16_input896x448',
    'instance',
).to(device)

In [9]:
import cv2
import numpy as np
import torch.nn.functional as F

from seggpt_engine import run_one_image

def _to_rgb(image):
    """Convert an OpenCV image to a 3-channel RGB array.

    Accepts 3-channel BGR images (what `cv2.imread` returns by default) as well as
    single-channel grayscale images, either 2-D (H, W) or (H, W, 1).

    Raises:
        ValueError: if `image` is None, which is what `cv2.imread` returns when a
            file is missing or unreadable.
    """
    if image is None:
        raise ValueError('image is None (cv2.imread could not read the file?)')
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _preprocess(image, size, interpolation):
    """Convert `image` to RGB, resize it to `size` = (width, height) and scale to [0, 1]."""
    return cv2.resize(_to_rgb(image), size, interpolation=interpolation) / 255.0


# Inference helper
def run_inference(model, input_image, prompt_image, prompt_mask, input_shape=(448, 448), device=None):
    """Segment `input_image` with SegGPT, using one prompt image/mask pair as the example.

    Args:
        model: SegGPT model, e.g. as returned by `prepare_model(...)`.
        input_image: target image to segment, as read by `cv2.imread` (BGR).
        prompt_image: example image, BGR like `input_image`.
        prompt_mask: mask for `prompt_image`, BGR or single-channel grayscale with
            0..255 values (in this notebook the region of interest is 255).
        input_shape: (width, height) every image is resized to before inference.
            The prompt and target are stacked vertically, so the default (448, 448)
            yields the 896x448 input the loaded model was built for.
        device: torch device to run on. Defaults to the device of the model's
            parameters, so the function no longer depends on a notebook-global
            `device` variable.

    Returns:
        np.ndarray of shape (H, W, 3) at `input_image`'s original resolution,
        holding the predicted mask (the overlay cells treat it as 0..255).

    Raises:
        ValueError: if any of the three images is None (failed `cv2.imread`).
    """
    if device is None:
        device = next(model.parameters()).device

    resize_width, resize_height = input_shape[0], input_shape[1]
    resize_size = (resize_width, resize_height)

    # Preprocess: BGR/gray -> RGB, resize, scale to [0, 1]. Masks use
    # nearest-neighbour so resizing does not create blended edge values.
    input_rgb = _preprocess(input_image, resize_size, cv2.INTER_LINEAR)
    prompt_rgb = _preprocess(prompt_image, resize_size, cv2.INTER_LINEAR)
    prompt_mask_rgb = _preprocess(prompt_mask, resize_size, cv2.INTER_NEAREST)
    original_height, original_width = input_image.shape[:2]

    # Stack prompt (top) and target (bottom) along the height axis. The target's
    # mask is unknown, so the prompt mask is repeated as a placeholder for the
    # lower half — presumably ignored/masked by the model (see run_one_image).
    combined_mask = np.concatenate((prompt_mask_rgb, prompt_mask_rgb), axis=0)
    combined_image = np.concatenate((prompt_rgb, input_rgb), axis=0)

    # Normalize image and mask with the ImageNet mean and standard deviation.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized_image = (combined_image - imagenet_mean) / imagenet_std
    normalized_mask = (combined_mask - imagenet_mean) / imagenet_std

    # Run SegGPT on a batch of one.
    output = run_one_image(
        np.stack([normalized_image], axis=0),
        np.stack([normalized_mask], axis=0),
        model,
        device,
    )

    # Resize the (h, w, c) prediction back to the original image size.
    # `.cpu()` is a no-op if run_one_image already returned a CPU tensor.
    output = F.interpolate(
        output[None, ...].permute(0, 3, 1, 2),
        size=[original_height, original_width],
        mode='nearest',
    ).permute(0, 2, 3, 1)[0].cpu().numpy()

    return output

# サンプル画像での確認

In [None]:
import cv2
from google.colab.patches import cv2_imshow

# Target image to segment.
input_image_path = 'examples/hmbb_2.jpg'
input_image = cv2.imread(input_image_path)
# cv2.imread returns None (it does not raise) when the file is missing/unreadable.
assert input_image is not None, f'failed to read {input_image_path}'

cv2_imshow(input_image)

In [None]:
# Prompt (example) image and its mask.
input_prompt_image_path = 'examples/hmbb_1.jpg'
input_prompt_mask_path = 'examples/hmbb_1_target.png'

prompt_image = cv2.imread(input_prompt_image_path)
prompt_mask = cv2.imread(input_prompt_mask_path)
# cv2.imread returns None (it does not raise) when a file is missing/unreadable.
assert prompt_image is not None, f'failed to read {input_prompt_image_path}'
assert prompt_mask is not None, f'failed to read {input_prompt_mask_path}'

# Show image and mask side by side (cv2.hconcat needs equal heights and types).
cv2_imshow(cv2.hconcat([prompt_image, prompt_mask]))

In [None]:
%%time

# Inference: segment the example target using the example prompt image/mask pair.
result = run_inference(
    model,
    input_image=input_image,
    prompt_image=prompt_image,
    prompt_mask=prompt_mask,
)

In [None]:
import copy

# Visualize the result: pixels where `result` is 0 are dimmed to 40% brightness,
# pixels where it is 255 keep full brightness. The arithmetic already returns a
# new array, so `input_image` is left untouched without an explicit copy.
debug_image = (input_image * (0.6 * result / 255.0 + 0.4)).astype(np.uint8)
cv2_imshow(debug_image)

# 別画像でテスト

In [15]:
# Download three extra test photos as sample01.jpg / sample02.jpg / sample03.jpg.
# `-O` writes to a fixed name, so re-running the cell overwrites them.
!wget https://user0514.cdnw.net/shared/img/thumb/sakiphotoPAR541761180_TP_V4.jpg -O sample01.jpg -q
!wget https://user0514.cdnw.net/shared/img/thumb/sakiphotoPAR541661179_TP_V4.jpg -O sample02.jpg -q
!wget https://user0514.cdnw.net/shared/img/thumb/sakiphotoPAR542051187_TP_V4.jpg -O sample03.jpg -q

In [17]:
!wget https://github.com/Kazuhito00/simple-annotation-on-colab/raw/main/colab_utils.py -q

In [None]:
sample_image01 = cv2.imread('sample01.jpg')
sample_image02 = cv2.imread('sample02.jpg')
sample_image03 = cv2.imread('sample03.jpg')
# cv2.imread returns None (it does not raise) if a download failed or the file
# is not a decodable image.
assert all(img is not None for img in (sample_image01, sample_image02, sample_image03)), \
    'failed to read sample01.jpg / sample02.jpg / sample03.jpg (check the download cell)'

In [None]:
import colab_utils

# Interactively draw a polygon around the target object in sample01.
# The widget writes the drawn points into `polygons` (passed as the storage list);
# the next cell treats them as normalized (x, y) coordinates.
# NOTE(review): the prompt mask depends on this manual annotation, so downstream
# results cannot be reproduced from code alone.
polygons = []
colab_utils.annotate_polygon([cv2.cvtColor(sample_image01, cv2.COLOR_BGR2RGB)], polygon_storage_pointer=polygons)

In [72]:
# Rasterize the annotated polygon into a prompt mask for sample01:
# 255 inside the polygon, 0 elsewhere.
mask_height, mask_width = sample_image01.shape[:2]
sample_image01_mask = np.zeros((mask_height, mask_width), dtype=np.uint8)

# Fail with a clear message (instead of an obscure numpy broadcasting error)
# when the annotation cell was skipped or nothing was drawn.
assert len(polygons) > 0, 'no polygon annotated: run the annotation cell and draw a polygon first'

# Scale the normalized (x, y) points to pixel coordinates as int32 points of
# shape (N, 1, 2), the layout cv2.fillPoly works with.
# NOTE(review): reshape((-1, 1, 2)) flattens every point into ONE polygon, so
# several separately drawn polygons would be merged — confirm the structure
# produced by colab_utils.annotate_polygon.
absolute_coords = (np.array(polygons) * [mask_width, mask_height]).astype(np.int32)
absolute_coords = absolute_coords.reshape((-1, 1, 2))
cv2.fillPoly(sample_image01_mask, [absolute_coords], 255)

# Make the mask 3-channel, like the example mask loaded with cv2.imread.
sample_image01_mask = cv2.cvtColor(sample_image01_mask, cv2.COLOR_GRAY2BGR)

In [None]:
# Check the generated mask: everything outside the annotated polygon is dimmed
# to 40% brightness. No copy is needed (the arithmetic creates a new array), so
# this cell no longer relies on `copy` having been imported in an earlier cell.
debug_image = (sample_image01 * (0.6 * sample_image01_mask / 255.0 + 0.4)).astype(np.uint8)
cv2_imshow(debug_image)

In [None]:
%%time

# Inference: segment sample02, using sample01 and its hand-drawn mask as the prompt.
result = run_inference(
    model,
    input_image=sample_image02,
    prompt_image=sample_image01,
    prompt_mask=sample_image01_mask,
)

In [None]:
import copy

# Visualize the result on sample02: pixels where `result` is 0 are dimmed to 40%,
# pixels where it is 255 keep full brightness. The arithmetic returns a new
# array, so no explicit copy of the input is needed.
debug_image = (sample_image02 * (0.6 * result / 255.0 + 0.4)).astype(np.uint8)
cv2_imshow(debug_image)

In [None]:
%%time

# Inference: segment sample03, using sample01 and its hand-drawn mask as the prompt.
result = run_inference(
    model,
    input_image=sample_image03,
    prompt_image=sample_image01,
    prompt_mask=sample_image01_mask,
)

In [None]:
import copy

# Visualize the result on sample03: pixels where `result` is 0 are dimmed to 40%,
# pixels where it is 255 keep full brightness. The arithmetic returns a new
# array, so no explicit copy of the input is needed.
debug_image = (sample_image03 * (0.6 * result / 255.0 + 0.4)).astype(np.uint8)
cv2_imshow(debug_image)