In [1]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [62]:
class Image_Augmentation():
    """Load images, shrink them, augment them and write the results to disk.

    Typical flow: ``load_augmented_images`` reads every image file under
    ``base_path/<category>``, resizes it and expands it into several
    augmented variants; ``fair_save_images`` / ``defect_save_images`` then
    write those variants to ``target_dir`` under generated file names.
    """

    # Product type (file-name prefix before the first "_") of every source
    # image read by the most recent load_augmented_images() call, in load
    # order. defect_save_images() uses it to build the output file names.
    product_types = []

    # Number of images augment_image() produced per source image during the
    # most recent load. The default mirrors augment_image()'s current output
    # (original + flip + two rotations + blur).
    _augmentations_per_image = 5

    def __init__(self):
        # Give every instance its own list rather than sharing the class-level one.
        self.product_types = []

    def augment_image(self, img):
        """Return the original image followed by its augmented variants.

        Args:
            img: A single image as a NumPy array (e.g. the result of
                ``cv2.imread``), not a list of images.

        Returns:
            list: ``[original, horizontal flip, +15 degree rotation,
            -15 degree rotation, 5x5 Gaussian blur]``.
        """
        rows, cols = img.shape[:2]
        center = (cols / 2, rows / 2)

        # flipCode=1 flips around the vertical axis (left <-> right).
        augmented_images = [img, cv2.flip(img, 1)]

        for angle in (15, -15):
            rotation = cv2.getRotationMatrix2D(center, angle, 1)
            # The output canvas keeps the input size (cols x rows).
            augmented_images.append(cv2.warpAffine(img, rotation, (cols, rows)))

        augmented_images.append(cv2.GaussianBlur(img, (5, 5), 0))
        return augmented_images

    def preprocess_image(self, img_path, scale=0.1):
        """Read an image from disk and resize it by ``scale``.

        The image is kept in colour exactly as ``cv2.imread`` returns it; no
        grayscale conversion or normalisation is applied.

        Args:
            img_path: Path of the image file to read.
            scale: Factor applied to both width and height (default 0.1).

        Returns:
            The resized image array.

        Raises:
            ValueError: If the file cannot be read or decoded as an image.
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise ValueError(f"Could not read image: {img_path}")

        original_height, original_width = img.shape[:2]
        # Clamp to 1 px so tiny inputs never yield an empty target size.
        new_width = max(1, int(original_width * scale))
        new_height = max(1, int(original_height * scale))

        return cv2.resize(img, (new_width, new_height))

    def load_augmented_images(self, base_path, categories):
        """Load, preprocess and augment every image of every category.

        Resets ``self.product_types`` and refills it with one entry per source
        image (the part of the file name before the first ``_``).

        Args:
            base_path: Directory that contains one sub-directory per category.
            categories: Names of the category sub-directories to read.

        Returns:
            list[list]: One inner list per category (same order as
            ``categories``) holding all augmented images of that category.
        """
        train_images = []
        self.product_types = []

        for category in categories:
            dir_path = os.path.join(base_path, category)
            print(dir_path)

            category_images = []
            # Sorted so the numbering of the saved files is reproducible;
            # os.listdir() returns entries in arbitrary order.
            for img_name in sorted(os.listdir(dir_path)):
                img_path = os.path.join(dir_path, img_name)
                if not os.path.isfile(img_path):
                    # Skip sub-directories such as .ipynb_checkpoints.
                    continue

                img = self.preprocess_image(img_path)
                augmented_images = self.augment_image(img)

                # Recorded only after a successful load so product_types stays
                # aligned with the images that were actually produced.
                self.product_types.append(img_name.split('_')[0])
                # Remember the group size so defect_save_images() does not
                # have to hard-code it.
                self._augmentations_per_image = len(augmented_images)
                category_images.extend(augmented_images)

            train_images.append(category_images)

        return train_images

    def _write_image(self, img_path, img):
        """Write ``img`` to ``img_path`` and print the outcome; return success."""
        # cv2.imwrite can return False instead of raising when a write fails,
        # so only report "Saved" when it actually succeeded.
        saved = cv2.imwrite(img_path, img)
        if saved:
            print(f"Saved: {img_path}")
        else:
            print(f"Failed to save: {img_path}")
        return saved

    def fair_save_images(self, train_images, categories, target_dir):
        """Save good-product images as ``<category>_PASS_<nn>.jpg``.

        Args:
            train_images: Per-category image lists, as returned by
                ``load_augmented_images``.
            categories: Category names, parallel to ``train_images``.
            target_dir: Output directory; created if it does not exist.
        """
        os.makedirs(target_dir, exist_ok=True)

        for idx, category_images in enumerate(train_images):
            category_name = categories[idx]

            for i, img in enumerate(category_images, start=1):
                img_name = f"{category_name}_PASS_{i:02d}.jpg"
                # The write used to be commented out, so "Saved:" was printed
                # although no file was created.
                self._write_image(os.path.join(target_dir, img_name), img)

    def defect_save_images(self, train_images, categories, target_dir):
        """Save defect images as ``<product type>_<category>_<nn>.jpg``.

        The product type of each image is looked up in ``self.product_types``:
        consecutive groups of augmented images (one group per source image,
        counted across all categories) share one entry.

        Args:
            train_images: Per-category image lists, as returned by
                ``load_augmented_images``.
            categories: Category names, parallel to ``train_images``.
            target_dir: Output directory; created if it does not exist.

        Raises:
            IndexError: If ``self.product_types`` has fewer entries than there
                are source-image groups in ``train_images``.
        """
        os.makedirs(target_dir, exist_ok=True)

        group_size = max(1, self._augmentations_per_image)
        global_index = 0   # running image count across all categories
        last_index = None  # last per-category index that was processed

        for idx, category_images in enumerate(train_images):
            category_name = categories[idx]

            print(len(category_images))

            for i, img in enumerate(category_images):
                source_index = global_index // group_size
                if source_index >= len(self.product_types):
                    raise IndexError(
                        "product_types has no entry for source image "
                        f"{source_index}; call load_augmented_images() first"
                    )
                prod_type = self.product_types[source_index]
                img_name = f"{prod_type}_{category_name}_{(i+1):02d}.jpg"

                self._write_image(os.path.join(target_dir, img_name), img)

                global_index += 1
                last_index = i

        print('='*20)
        # Nothing to report when no image was processed (this used to crash
        # on an unbound loop variable).
        if last_index is not None:
            print(last_index)


# 양품 사용 예시

In [None]:
# Good-product example: load and augment the PASS boards, then save them.
base_path = '../PCB_DATASET/PCB_USED'
categories = ['01', '04', '05', '06', '07', '08', '09', '10', '11', '12']  # example categories
target_dir = '../PCB_DATASET/PCB_PASS'

IM_instance = Image_Augmentation()

train_images = IM_instance.load_augmented_images(base_path, categories)

# Report how many augmented images each category produced.
# np.shape() was used here before, but its result was discarded mid-cell and
# it fails on ragged input (categories with different image counts or sizes).
print({category: len(images) for category, images in zip(categories, train_images)})

IM_instance.fair_save_images(train_images, categories, target_dir)

# 불량품 사용 예시

In [None]:
# Defect example: one sub-directory of base_path per defect category.
base_path = '../PCB_DATASET/images'
categories = [
    'Missing_hole',
    'Mouse_bite',
    'Open_circuit',
    'Short',
    'Spur',
    'Spurious_copper',
]
target_dir = '../PCB_DATASET/PCB_UNPASS'

# A fresh instance, so product_types only reflects the defect images.
IM_instance = Image_Augmentation()

# Load + augment every defect image, grouped per category.
train_images = IM_instance.load_augmented_images(
    base_path=base_path,
    categories=categories,
)

# Write the augmented defect images into target_dir.
IM_instance.defect_save_images(
    train_images=train_images,
    categories=categories,
    target_dir=target_dir,
)