In [1]:
import cv2
import pandas as pd
import numpy as np
from glob import glob
from tqdm import tqdm
import os
from natsort import natsorted
from matplotlib import pyplot as plt
from skimage.restoration import inpaint
import shutil

노트북 설명

목적: 포토샵으로 세그멘테이션한 마스크 이미지를 불러와, 원본 이미지에서 세그멘테이션한 부분(결함)을 추출하는 코드.

2023.03.10

이미지 상의 결함은 GrayScale 값의 변화로 나타난다. 결함 영역에서 결함이 없을 때의 GrayScale 값(마스크 영역을 인페인팅하여 추정한 배경)과 결함이 있을 때의 GrayScale 값(원본)의 차이를 이용하여 결함을 추출하려 한다.

inpaint 관련 여러 알고리즘 또는 모델을 알아보려 함.

In [None]:
def extract_flaw(ground_truth, origin, padding=0, file_name=None,
                 out_dir='/home/VirtualFlaw/Dataset_Unet/Extracted'):
    """
    Extract flaw patches from an origin image using a hand-made mask.

    The masked region of ``origin`` is filled in with biharmonic inpainting to
    estimate the flaw-free background. The difference ``origin - inpainted`` is
    then cropped around every connected component of the (dilated) mask. Each
    crop is saved as ``{out_dir}/{stem}_{k}.npy``.

    :param ground_truth: path to the mask image (marked pixels = flaw)
    :param origin: path to the original image that contains the flaw
    :param padding: extra pixels added on every side of each crop; the crop is
        clipped to the image borders
    :param file_name: stem used for the output file names; defaults to the
        stem of ``ground_truth``
    :param out_dir: directory the ``.npy`` crops are written to (created if
        missing)
    :return: list of the saved crops (uint8 arrays), in component-label order
    :raises FileNotFoundError: if either image cannot be read
    """
    image_name = file_name if file_name else os.path.splitext(os.path.basename(ground_truth))[0]

    # Hand-made mask and the original (flawed) image, both as 8-bit grayscale.
    mask = cv2.imread(ground_truth, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f'cannot read mask image: {ground_truth}')
    image_orig = cv2.imread(origin, cv2.IMREAD_GRAYSCALE)
    if image_orig is None:
        raise FileNotFoundError(f'cannot read origin image: {origin}')

    # Estimate the flaw-free background: the masked pixels are re-estimated
    # from their surroundings by biharmonic inpainting.
    mask_float = mask / 255
    image_defect = image_orig * (1 - mask_float)
    inpainted = inpaint.inpaint_biharmonic(image_defect, mask_float)
    # Biharmonic inpainting can overshoot the input range; clip before the
    # cast so out-of-range values do not wrap around in uint8.
    image_result = np.clip(np.rint(inpainted), 0, 255).astype(np.uint8)

    # Saturating subtraction: pixels darker than the inpainted background
    # become 0, so only "brighter than background" differences are kept.
    # NOTE(review): darker-than-background flaws are lost here - confirm intended.
    image_diff = cv2.subtract(image_orig, image_result)

    # Binarize the mask (with THRESH_OTSU the 0 threshold argument is ignored
    # and Otsu picks it) and dilate it so each box encloses its flaw with margin.
    _, gt_bin = cv2.threshold(mask, 0, 255, cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    gt_bin_dil = cv2.dilate(gt_bin, kernel, iterations=1)

    cnt, _, stats, _ = cv2.connectedComponentsWithStats(gt_bin_dil)
    height, width = mask.shape
    if cnt > 1:
        os.makedirs(out_dir, exist_ok=True)

    crops = []
    # Label 0 is the background, so real components start at stats[1].
    for count, stat in enumerate(stats[1:]):
        x, y, w, h = (int(v) for v in stat[:4])
        # Pad the box, then clip each edge independently so that clamping the
        # left/top edge does not push the right/bottom edge further out.
        x0, y0 = max(x - padding, 0), max(y - padding, 0)
        x1, y1 = min(x + w + padding, width), min(y + h + padding, height)
        crop = image_diff[y0:y1, x0:x1]
        np.save(os.path.join(out_dir, f'{image_name}_{count}.npy'), crop)
        crops.append(crop)
    return crops
        


In [None]:
# Extract flaw patches for every "Leftover" origin image, using the
# ground-truth mask that has the same file name.
ground_truths = "/home/VirtualFlaw/Dataset_Unet/Ground_truth/"
origins = natsorted(glob('/home/VirtualFlaw/Dataset_Unet/origin/Leftover/*.png'))

for origin_path in origins:
    file_name = os.path.basename(origin_path).split('.')[0]
    ground_truth = os.path.join(ground_truths, file_name + '.png')
    extract_flaw(ground_truth, origin_path, padding=0, file_name=file_name)

In [None]:
# Collect the raw REJ images (2022-11 and 2022-12) whose file stem also appears
# among the normalized segmented PO images, and copy them to Segmented_PO/Origin.
REJ_image = (
    glob("/home/VirtualFlaw/Data/IMAGE/Raw_jpg/202211/REJ/**/*.jpg", recursive=True)
    + glob("/home/VirtualFlaw/Data/IMAGE/Raw_jpg/202212/REJ/**/*.jpg", recursive=True)
)
Seg_image = glob("/home/VirtualFlaw/Data/Hyun/data/Segmented_PO/Origin_normalized/*.jpg")
print(len(REJ_image), len(Seg_image))

# Compare by file stem (name up to the first '.').
Seg_image = [os.path.basename(p).split('.')[0] for p in Seg_image]
REJ_image_name = [os.path.basename(p).split('.')[0] for p in REJ_image]
seg_names = set(Seg_image)  # O(1) membership instead of scanning the list each time

image_list = []
for path, name in zip(REJ_image, REJ_image_name):
    if name in seg_names:
        print(path)
        image_list.append(path)

dest_dir = '/home/VirtualFlaw/Data/Hyun/data/Segmented_PO/Origin'
if image_list:
    # If dest_dir did not exist, shutil.copy would write a *file* named
    # "Origin" and overwrite it on every iteration.
    os.makedirs(dest_dir, exist_ok=True)
# NOTE(review): the same stem under both 202211 and 202212 would be copied twice,
# and the later copy overwrites the earlier one - confirm stems are unique.
for path in image_list:
    shutil.copy(path, dest_dir)


In [None]:
# Extract flaw patches for the PO study images.
# NOTE(review): "origin_folder" points to PO_Inpainted, and those images are
# passed as extract_flaw's `ground_truth` argument, which extract_flaw reads as
# the *mask*. Confirm the intended argument order before relying on the output.
origin_folder = "/home/VirtualFlaw/RT_Project/src/Study/PO_Inpainted"
removed_flaw_folder = "/home/VirtualFlaw/RT_Project/src/Study/PO_Origin"
origin_images = natsorted(glob(origin_folder + "/*.jpg"))
removed_flaw_images = natsorted(glob(removed_flaw_folder + "/*.jpg"))

# Pair files by name instead of by sorted position: one missing or extra file
# in either folder would otherwise shift every later pair or raise IndexError.
removed_by_name = {os.path.basename(p): p for p in removed_flaw_images}
for origin_path in origin_images:
    base = os.path.basename(origin_path)
    partner = removed_by_name.get(base)
    if partner is None:
        print(f'skipped, no matching file in {removed_flaw_folder}: {origin_path}')
        continue
    print(origin_path)
    extract_flaw(origin_path, partner, padding=0, file_name=base.split('.')[0])
    


In [None]:
# Copy the raw REJ originals of the images in Study/PO into Study/PO_Origin.
file_name = glob("/home/VirtualFlaw/RT_Project/src/Study/PO/*.jpg")
file_name = [os.path.basename(i).split('.')[0] for i in file_name]

# NOTE(review): the source tree lives under another user's home (/home/dais01) - confirm.
origin_file = glob("/home/dais01/HyundaiRB/Data/Raw_Data/REJ/**/*.jpg", recursive=True)

# Index the originals by exact file stem; the first path in glob order wins,
# like the original loop's `break`.
origin_by_stem = {}
for origin in origin_file:
    origin_by_stem.setdefault(os.path.basename(origin).split('.')[0], origin)

# Same folder name (capital "O") that the other PO cells read from; on a
# case-sensitive filesystem "PO_origin" would be a different path.
dest_dir = '/home/VirtualFlaw/RT_Project/src/Study/PO_Origin'
for file in file_name:
    src = origin_by_stem.get(file)
    if src is None:
        # Fall back to the old substring rule, but test only the file name.
        # Testing the whole path also matched directory names.
        src = next((p for p in origin_file if file in os.path.basename(p)), None)
    if src is None:
        print(f'no original found for {file}')
        continue
    # Without an existing directory, shutil.copy would create/overwrite a file.
    os.makedirs(dest_dir, exist_ok=True)
    shutil.copy(src, dest_dir)

In [None]:
# Inpaint every PO origin image with its hand-made mask and save the result to
# PO_Inpainted under the origin's file name.
# (The previous version wrote `image_result`, a local of extract_flaw that does
# not exist at module level, so it failed on a fresh kernel.)
origin_images = natsorted(glob('/home/VirtualFlaw/RT_Project/src/Study/PO_Origin/*.jpg'))
mask_images = natsorted(glob('/home/VirtualFlaw/RT_Project/src/Study/Mask/*.jpg'))

print(len(origin_images), len(mask_images))
# Images and masks are paired by sorted position, so the folders must line up.
assert len(origin_images) == len(mask_images), 'origin/mask count mismatch'

inpainted_dir = '/home/VirtualFlaw/RT_Project/src/Study/PO_Inpainted/'
if origin_images:
    os.makedirs(inpainted_dir, exist_ok=True)

for origin_path, mask_path in zip(origin_images, mask_images):
    print(origin_path, mask_path)
    image_orig = cv2.imread(origin_path, cv2.IMREAD_GRAYSCALE)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if image_orig is None or mask is None:
        raise FileNotFoundError(f'cannot read {origin_path} or {mask_path}')

    # The masks are JPEGs, whose compression noise can leave small non-zero
    # values around the strokes; keep only clearly-marked pixels (> 127).
    mask_bin = (mask > 127).astype(np.float64)
    image_result = inpaint.inpaint_biharmonic(image_orig * (1 - mask_bin), mask_bin)
    # Biharmonic inpainting can overshoot 0-255; clip before casting to uint8.
    image_result = np.clip(np.rint(image_result), 0, 255).astype(np.uint8)

    out_path = inpainted_dir + os.path.basename(origin_path)
    if not cv2.imwrite(out_path, image_result, [int(cv2.IMWRITE_JPEG_QUALITY), 100]):
        raise IOError(f'failed to write {out_path}')

In [None]:
# Copy the Leftover and Scratch origin images, min-max normalized to 0-255, and
# their ground-truth masks into Dataset_Unet/inpainting/{origin,Ground_truth}.
ground_truths = "/home/VirtualFlaw/Dataset_Unet/Ground_truth/"
origins = natsorted(
    glob('/home/VirtualFlaw/Dataset_Unet/origin/Leftover/*.png')
    + glob('/home/VirtualFlaw/Dataset_Unet/origin/Scratch/*.png')
)

out_origin_dir = '/home/VirtualFlaw/Dataset_Unet/inpainting/origin/'
out_gt_dir = '/home/VirtualFlaw/Dataset_Unet/inpainting/Ground_truth/'
if origins:
    # cv2.imwrite reports a missing folder only through its return value.
    os.makedirs(out_origin_dir, exist_ok=True)
    os.makedirs(out_gt_dir, exist_ok=True)
png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]  # lossless, no compression

for origin_path in origins:
    file_name = os.path.basename(origin_path).split('.')[0]
    ground_truth = ground_truths + file_name + '.png'
    # NOTE(review): the default imread flag yields an 8-bit 3-channel image; if the
    # originals are 16-bit PNGs, precision is lost before normalization - confirm.
    img = cv2.imread(origin_path)
    if img is None:
        raise FileNotFoundError(f'cannot read origin image: {origin_path}')
    gt = cv2.imread(ground_truth)
    if gt is None:
        raise FileNotFoundError(f'cannot read ground truth: {ground_truth}')
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)

    for out_path, data in ((out_origin_dir + file_name + '.png', img),
                           (out_gt_dir + file_name + '.png', gt)):
        if not cv2.imwrite(out_path, data, png_params):
            raise IOError(f'failed to write {out_path}')