In [0]:
mount_path = "/mnt/datamount"
dbutils.fs.ls(mount_path)

In [0]:
%run /Workspace/Users/ark0723@gmail.com/anomaly_detection/notebook/00_utils

# Image Augmentation
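- Problem: class imbalance between the 'abnormal' and 'normal' images. The under-represented labels are augmented with random image transforms until they reach a target count.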

In [0]:
df = spark.read.format('parquet').load(f'{mount_path}/images_final')
display(df)

In [0]:
# Number of images per label -- shows how imbalanced the classes are.
display(
    df.groupby('label').count()
)

In [0]:
import io
import math
import random
from functools import reduce

import numpy as np
from PIL import Image, ImageDraw
from pyspark.sql import functions as F
from pyspark.sql.functions import DataFrame
from pyspark.sql.functions import pandas_udf, col, lit, regexp_replace
from pyspark.sql.types import BinaryType


def create_transform_udf(transforms: list):
    """
    적용할 변환(transform) 타입 리스트를 인자로 받아,
    해당 리스트에서 무작위로 변환을 선택하여 적용하는 Pandas UDF를 생성합니다.
    (Geometric transforms + Salt&Pepper Noise)
    """
    @pandas_udf(BinaryType())
    def dynamic_transform_udf(df_series):
        def apply_transform(img_bytes):
            '''Apply transform to images and serialize back as jpeg.'''
            try:
                img = Image.open(io.BytesIO(img_bytes))
            except:
                return None
            
            # randomly select transform
            num_to_apply = random.randint(1, len(transforms))
            seleted_list = random.sample(transforms, num_to_apply)

            for selected_tr in seleted_list:
                # *** 중요 ***
                # 변환(예: 회전)으로 인해 이미지 크기가 바뀔 수 있으므로,
                # 매번 루프가 돌 때마다 현재 크기를 다시 가져옵니다.
                w, h = img.size

                match selected_tr:
                    case 'horizontal':
                        img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    case 'vertical':
                        img = img.transpose(Image.FLIP_TOP_BOTTOM)
                    case 'rotate_90':
                        img = img.transpose(Image.ROTATE_90)    
                    case 'rotate_180':
                        img = img.transpose(Image.ROTATE_180)   
                    case 'rotate_270':
                        img = img.transpose(Image.ROTATE_270)  
                    case 'squash&skew':
                        # 현재 w, h를 기준으로 매트릭스 계산
                        ss_matrix = (1, 0.3, -w*0.15, 0.3, 1, -h*0.15)
                        img = img.transform((w,h), Image.AFFINE, ss_matrix)
                    case 'salt&pepper':
                        # draw 객체는 'img'를 직접 수정
                        draw = ImageDraw.Draw(img)
                        num_patches = np.random.randint(1, 5)

                        for _ in range(num_patches):
                            patch_pixels = np.random.randint(100, 500)
                            noise_value = np.random.choice([0, 255]) # white or black
                            # radius
                            r = int(np.sqrt(patch_pixels / np.pi))
                            r = max(r, 5) 

                            # 현재 w, h 기준으로 패치 생성 (이미지가 너무 작으면 skip)
                            if w <=2*r or h<=2*r:
                                continue

                            cx = np.random.randint(r, w -r)
                            cy = np.random.randint(r, h -r)

                            # polygon
                            num_points = np.random.randint(5,10) 
                            angles = np.linspace(0, 2*np.pi, num_points, endpoint=False)
                            angles += np.random.uniform(0, 2*np.pi/num_points, size = num_points)
                            radii = np.random.uniform(0.5*r, 1.5*r, size = num_points)
                            points = [
                                (int(cx + rad*np.cos(angle)), int(cy + rad*np.sin(angle)))
                                for angle, rad in zip(angles, radii)
                            ]

                            fill_color = (noise_value,)*3 if img.mode == 'RGB' else noise_value
                            draw.polygon(points, fill_color)
            # save back as jpeg
            output = io.BytesIO()
            img.save(output, format='JPEG')
            return output.getvalue()
        return df_series.apply(apply_transform)
    # return udf function
    return dynamic_transform_udf

# Column metadata for the encoded "image" column. The content annotation marks the
# bytes as JPEG, which Databricks display() can use to render them as images.
img_meta = {}
img_meta["spark.contentAnnotation"] = '{"mimeType":"image/jpeg"}'

In [0]:
# Transforms for the normal class: orientation changes only (flips and rotations).
transform_to_normal = ['horizontal', 'vertical', 'rotate_90', 'rotate_180', 'rotate_270']

# Transforms for the anomaly class: the same flips/rotations plus an affine
# squash/skew and salt-and-pepper noise patches.
transform_to_anomaly = transform_to_normal + ['squash&skew', 'salt&pepper']

### refactor: optimize spark performance
- Filter Once: call df.filter(...) only one time, not in a loop.
- posexplode: the standard Spark function for duplicating rows. It takes an array column and emits one output row per element, together with that element's index (a column named `pos`) for free.
- Single Transformation: The UDF is applied to the entire set of duplicated rows at once. Spark handles the parallel execution across the whole cluster efficiently.
- Unique Naming: Using the pos column from posexplode allows you to easily create a unique name for every new image.
- One Write Operation: You trigger one single, large, and efficient Spark job, not hundreds of small ones. This is significantly faster and easier for Spark to optimize.

In [0]:
label_to_augment = 'surfing'
# Desired total number of images for this label (originals + augmented copies).
target_normal = 800

# 1. Filter the base images once; the same DataFrame is used for counting and augmenting.
base_df = df.filter(col('label') == label_to_augment)
num_normal = base_df.count()

print(f"Found {num_normal} base images with label '{label_to_augment}'.")

# 2. Decide how many augmented copies each base image needs, handling edge cases.
if num_normal == 0:
    print("No base images found. Skipping augmentation.")
else:
    # Ceiling division of the shortfall, so originals + augmentations reach at least the
    # target. Clamped at 0 when the label already has enough images.
    n_augmentations_per_image = max(0, math.ceil((target_normal - num_normal) / num_normal))

    if n_augmentations_per_image == 0:
        print(f"Target ({target_normal}) is not greater than current count ({num_normal}). No augmentation needed.")
    else:
        print(f"Generating {n_augmentations_per_image} augmentations per image...")

        # 3. UDF that applies a random subset of the flip/rotation transforms.
        normal_transformer_udf = create_transform_udf(transform_to_normal)

        # 4. Duplicate each base row N times. posexplode over an N-element dummy array
        #    yields one row per element plus its index `pos` (0..N-1).
        explode_array = F.array([F.lit(1)] * n_augmentations_per_image)

        augmented_df = (
            base_df
            .select("*", F.posexplode(explode_array).alias("pos", "val"))
            # The random UDF runs once per exploded row, so each copy gets its own transforms.
            .withColumn("image", normal_transformer_udf("image").alias("image", metadata=img_meta))
            # Unique name per copy, built from the augmentation index.
            .withColumn(
                "name",
                F.concat(F.col("name"), F.lit("_augmented_"), F.col("pos").cast("string")),
            )
            .withColumn("path", F.lit("N/A"))
            .select("name", "path", "label", "image")  # re-order and drop helper columns
        )

        # 5. Single write: executes the whole plan (filter, explode, transform) as one Spark job.
        #    It lives inside this branch so a skipped run never writes a missing or stale DataFrame.
        augmented_df.write.mode('overwrite').format('parquet').save(f'{mount_path}/images_normal_final')

        print(f"Successfully generated and saved {n_augmentations_per_image * num_normal} augmented images.")

In [0]:
# Preview the images that were actually saved. Re-displaying `augmented_df` would re-execute
# its lazy plan, including the random transform UDF. The preview would then show different
# augmentations than the ones written to storage.
display(spark.read.parquet(f'{mount_path}/images_normal_final').limit(10))

In [0]:
label_to_augment = 'noise'
# Desired total number of images for this label (originals + augmented copies).
target_anormaly = 200

# 1. Filter the base images once; the same DataFrame is used for counting and augmenting.
base_df = df.filter(col('label') == label_to_augment)
num_anormaly = base_df.count()

print(f"Found {num_anormaly} base images with label '{label_to_augment}'.")

# 2. Decide how many augmented copies each base image needs, handling edge cases.
if num_anormaly == 0:
    print("No base images found. Skipping augmentation.")
else:
    # Ceiling division of the shortfall, so originals + augmentations reach at least the
    # target. Clamped at 0 when the label already has enough images.
    n_augmentations_per_image = max(0, math.ceil((target_anormaly - num_anormaly) / num_anormaly))

    if n_augmentations_per_image == 0:
        print(f"Target ({target_anormaly}) is not greater than current count ({num_anormaly}). No augmentation needed.")
    else:
        print(f"Generating {n_augmentations_per_image} augmentations per image...")

        # 3. UDF that applies a random subset of all transforms, including
        #    squash/skew and salt-and-pepper noise.
        abnormal_transformer_udf = create_transform_udf(transform_to_anomaly)

        # 4. Duplicate each base row N times. posexplode over an N-element dummy array
        #    yields one row per element plus its index `pos` (0..N-1).
        explode_array = F.array([F.lit(1)] * n_augmentations_per_image)

        augmented_df = (
            base_df
            .select("*", F.posexplode(explode_array).alias("pos", "val"))
            # The random UDF runs once per exploded row, so each copy gets its own transforms.
            .withColumn("image", abnormal_transformer_udf("image").alias("image", metadata=img_meta))
            # Unique name per copy, built from the augmentation index.
            .withColumn(
                "name",
                F.concat(F.col("name"), F.lit("_augmented_"), F.col("pos").cast("string")),
            )
            .withColumn("path", F.lit("N/A"))
            .select("name", "path", "label", "image")  # re-order and drop helper columns
        )

        # 5. Single write: executes the whole plan (filter, explode, transform) as one Spark job.
        #    It lives inside this branch so a skipped run never writes the previous cell's
        #    DataFrame to this path.
        augmented_df.write.mode('overwrite').format('parquet').save(f'{mount_path}/images_anormaly_final')

        print(f"Successfully generated and saved {n_augmentations_per_image * num_anormaly} augmented images.")

In [0]:
# Preview the images that were actually saved. Re-displaying `augmented_df` would re-execute
# its lazy plan, including the random transform UDF. The preview would then show different
# augmentations than the ones written to storage.
display(spark.read.parquet(f'{mount_path}/images_anormaly_final').limit(10))

In [0]:
# NOTE(review): `transposed_df` is not defined anywhere in this notebook, so this cell fails
# on a fresh kernel. It looks like a leftover from an earlier per-transform version.
# `display_bytes_image_in_df` is presumably provided by 00_utils via %run -- confirm.
# FIXME: point this at a DataFrame that exists, e.g. one of the augmented outputs saved above.
display_bytes_image_in_df(transposed_df)

In [0]:
# NOTE(review): neither `salt_pepper_df` nor `transposed_df` is defined in this notebook, so
# this cell fails on a fresh kernel. It is likely left over from an earlier version. FIXME.
# DataFrame.union matches columns by position, not by name; prefer unionByName if the two
# schemas might list their columns in different orders.
noise_df_final = salt_pepper_df.union(transposed_df)