In [1]:
import yaml
import sys
import os

from segment_anything import build_sam, SamAutomaticMaskGenerator
from segment_anything import SamPredictor, sam_model_registry
from easydict import EasyDict as edict
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
import numpy as np
import torch
import cv2

from IA.lama_inpaint import inpaint_img_with_lama


Detectron v2 is not installed


In [2]:
ROOT_PATH  = os.getcwd()
DATA_PATH  = f'{ROOT_PATH}/data/dove.jpg'
MODEL_PATH = f'{ROOT_PATH}/checkpoint/sam_vit_h.pth'

In [3]:
device        = 'cuda' if torch.cuda.is_available() else 'cpu'
image         = np.array(Image.open(DATA_PATH).convert('RGB'))

plt.imshow(image)
plt.axis(False)
plt.show()



![원본 이미지](./data/dove.jpg)

In [4]:
model        = build_sam(checkpoint = MODEL_PATH).to(device)
generator    = SamAutomaticMaskGenerator(model)

In [5]:
masks        = generator.generate(image)

In [6]:
image_cp     = image.copy()
random_color = lambda : np.random.randint(0, 256)


for idx, mask in enumerate(masks):
    
    x, y         = list(map(int, mask['point_coords'][0]))
    seg          = mask['segmentation']
    seg          = seg + np.zeros(seg.shape, np.uint8)
    _, bin_image = cv2.threshold(seg, 0, 127, cv2.THRESH_BINARY)
    
    
    color        = (random_color(), random_color(), random_color())
    masked_image = np.where(seg[..., None], color, image_cp)
    image_cp     = cv2.addWeighted(image_cp, 0.8, masked_image, 0.2, 0, dtype = cv2.CV_32F)
    
    conts, _     = cv2.findContours(bin_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    for cont in conts: cv2.drawContours(image_cp, [cont], -1, color, 2)
    
    cv2.putText(image_cp, f'[{idx}]. x : {x}, y : {y}', (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
plt.imshow(image_cp.astype(np.uint32))
plt.axis(False)

cv2.imwrite('output/dove/masked.png', image_cp)

True

![마스크 씌워진 이미지](./data/output/dove/masked.png)

In [7]:
args = {}
args['input_image']  = DATA_PATH
args['coords_type']  = 'key_in'
args['point_coords'] = masks[19]['point_coords']
args['point_label']  = [1]
args['output_dir']   = f'{DATA_PATH}/output'
args['lama_config']  = f'{ROOT_PATH}/IA/lama/configs/prediction/default.yaml'
args['lama_ckpt']    = f'{ROOT_PATH}/checkpoint/'

args = edict(args)

In [8]:
point_coords = np.array(args.point_coords)
point_label  = np.array(args.point_label)

sam          = sam_model_registry['vit_h'](checkpoint = MODEL_PATH).to(device)
predictor    = SamPredictor(sam)

In [9]:
point_coords, point_label

(array([[421.875  , 613.59375]]), array([1]))

In [10]:
coords = torch.as_tensor(point_coords)
coords.size()

torch.Size([1, 2])

In [11]:
predictor.set_image(image)
masks_, _, _ = predictor.predict(
                point_coords     = point_coords,
                point_labels     = point_label,
                multimask_output = False
            )

In [12]:
masks_   = masks_.astype(np.uint8) * 255
img_stem = Path(DATA_PATH).stem
out_dir  = Path(args.output_dir) / img_stem
out_dir.mkdir(parents = True, exist_ok = True)

In [13]:
for idx, mask in enumerate(masks_):
    
    mask_path         = out_dir / f'mask_{idx}.png'
    image_points_path = out_dir / f'with_points.png'
    image_mask_path   = out_dir / f'with_{Path(mask_path).name}'
    
    Image.fromarray(mask.astype(np.uint8)).save(mask_path)
    dpi  = plt.rcParams['figure.dpi']
    H, W = image.shape[:2] 
    
    plt.figure(figsize = (W * 0.77 / dpi, H * 0.77 / dpi))
    plt.imshow(image)
    plt.axis(False)
    
    coords      = np.array(args.point_coords)
    labels      = np.array(args.point_label)
    color_table = {0 : 'red', 1 : 'green'}
    
    for label_value, color in color_table.items():
        
        points = coords[labels == label_value]
        plt.gca().scatter(points[:, 0], points[:, 1], color = color,
                          marker = '*', s = (W*0.04)**2, edgecolor = 'white',
                         linewidth = 1.25)
        
    plt.savefig(image_points_path, bbox_inches = 'tight', pad_inches = 0)
    mask = mask.astype(np.uint8)
    
    if np.max(mask) == 255: mask = mask / 255
    color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w  = mask.shape[-2:]
    
    mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    plt.gca().imshow(mask_img)
    
    plt.savefig(image_mask_path, bbox_inches = 'tight', pad_inches = 0)

![마스크 이미지](./data/output/dove/mask_0.png)
![마스크 씌워진 이미지](./data/output/dove/with_mask_0.png)

In [14]:
for idx, mask in enumerate(masks_):
        
        mask_path        = out_dir / f'mask{idx}.png'
        img_inpaint_path = out_dir / f'inp_with_{Path(mask_path).name}'
        img_inpainted    = inpaint_img_with_lama(
                                image, mask, args.lama_config,
                                args.lama_ckpt, device = device
                            )
        
        Image.fromarray(img_inpainted.astype(np.uint8)).save(img_inpaint_path)
        

In [15]:
plt.imshow(img_inpainted)

<matplotlib.image.AxesImage at 0x7f03b0284fd0>

![lama 이미지](./data/output/dove/inp_with_mask0.png)