In [0]:
# Root of the mounted storage; later cells build their input paths from it.
mount_path = "/mnt/datamount"
# List the entries directly under the mount (bare last expression = cell output).
dbutils.fs.ls(mount_path)

# Image Augmentation
- Problem: class imbalance between 'abnormal' and 'normal'

In [0]:
# Read the parquet image dataset from under the mount and preview it.
df = (
    spark.read
    .format('parquet')
    .load(f'{mount_path}/images_final')
)
display(df)

In [0]:
# Number of rows per label, to inspect how the classes are distributed.
display(
    df
    .groupBy('label')
    .count()
)

In [0]:
import io
import random
from PIL import Image, ImageDraw
import numpy as np
from pyspark.sql.functions import pandas_udf, col, lit, regexp_replace
from pyspark.sql.types import BinaryType


from functools import reduce
from pyspark.sql.functions import DataFrame




@pandas_udf(BinaryType())
def transpose_img_udf(df_series):
    """Randomly flip, rotate and/or skew each encoded image and re-encode it as JPEG.

    For every row, between 1 and 3 distinct operations are drawn at random from
    horizontal flip, vertical flip, 90/180/270 degree rotation and an affine
    "squash & skew", and applied in the drawn order.

    Args:
        df_series: pandas Series of encoded image bytes. Null rows are passed
            through as None instead of failing in PIL.

    Returns:
        pandas Series with the JPEG-encoded bytes of each augmented image.
    """
    transpose_types = ['horizontal', 'vertical', 'rotate_90', 'rotate_180', 'rotate_270', 'squash&skew']

    # Pure flips/rotations map one-to-one onto a PIL transpose method.
    flip_or_rotate = {
        'horizontal': Image.FLIP_LEFT_RIGHT,
        'vertical': Image.FLIP_TOP_BOTTOM,
        'rotate_90': Image.ROTATE_90,
        'rotate_180': Image.ROTATE_180,
        'rotate_270': Image.ROTATE_270,
    }

    def transpose(img_bytes):
        '''Transpose image and serialize back as jpeg.'''
        if img_bytes is None:
            return None

        img = Image.open(io.BytesIO(img_bytes))

        # transpose will be applied randomly from 1 to 3 times for an image.
        selected = random.sample(transpose_types, random.randint(1, 3))

        for selected_tr in selected:
            if selected_tr == 'squash&skew':
                # Read the size *now*: an earlier 90/270 degree rotation swaps
                # width and height, so dimensions captured before the loop
                # would crop/pad the skewed output.
                w, h = img.size
                # squash & skew matrix
                ss_matrix = (1, 0.3, -w*0.15, 0.3, 1, -h*0.15)
                img = img.transform((w, h), Image.AFFINE, ss_matrix)
            else:
                img = img.transpose(flip_or_rotate[selected_tr])

        # save back as jpeg
        output = io.BytesIO()
        img.save(output, format='JPEG')
        return output.getvalue()

    return df_series.apply(transpose)


# The result is random; Spark treats UDFs as deterministic unless told otherwise
# and may then drop or repeat invocations during optimization.
transpose_img_udf = transpose_img_udf.asNondeterministic()

# define image metadata so the notebook renders the binary column as JPEG images
img_meta = {
    "spark.contentAnnotation":'{"mimeType":"image/jpeg"}'
}

# apply the UDF to transpose images randomly with multiple transpose types
noisy_transposed_df = (
    df.filter(col('label') == 'noise')
    .withColumn('image', transpose_img_udf('image').alias('image', metadata=img_meta))
    # Escape the dot and anchor at the end so only a literal ".jpg" extension
    # is replaced (a bare '.jpg' pattern matches any character followed by "jpg").
    .withColumn('name', regexp_replace(col('name'), r'\.jpg$', '_transposed'))
    # Augmented rows have no backing file.
    .withColumn('path', lit('N/A'))
)

# Show the augmented rows, then the untouched originals for comparison.
display(noisy_transposed_df)
display(df.filter(col('label') == 'noise'))

In [0]:
# Generate several independently augmented copies of the 'noise' rows.
num_to_apply = 5

transposed_dfs = [
    df.filter(col('label') == 'noise')
    .withColumn('image', transpose_img_udf('image').alias('image', metadata=img_meta))
    # Escape/anchor the extension and tag each pass with its index so names
    # stay unique once the copies are unioned.
    .withColumn('name', regexp_replace(col('name'), r'\.jpg$', f'_tr{i}'))
    .withColumn('path', lit('N/A'))
    for i in range(num_to_apply)
]

# Combine all augmented copies into one DataFrame
# (`union` is the non-deprecated name for `unionAll`).
transposed_df = reduce(DataFrame.union, transposed_dfs)
display(transposed_df)

In [0]:
@pandas_udf(BinaryType())
def add_salt_pepper_patches_udf(df_series):
    """Paint one random, irregular white patch on each encoded image and re-encode as JPEG.

    Despite the name, only white ("salt") patches are drawn. The patch is a
    polygon with 5-9 vertices placed around a random centre, each vertex at its
    own random distance from that centre.

    Args:
        df_series: pandas Series of encoded image bytes. Null rows are passed
            through as None instead of failing in PIL.

    Returns:
        pandas Series with the JPEG-encoded bytes of each patched image.
    """
    patch_pixels = 500  # target patch area in pixels
    noise_value = 255   # white

    # Fresh generator per batch, seeded from OS entropy, so the patches do not
    # depend on whatever global NumPy RNG state the worker process started with.
    rng = np.random.default_rng()

    # Radius of a circle with area `patch_pixels` (area = pi * r**2), min 5 px.
    r = max(int(np.sqrt(patch_pixels / np.pi)), 5)

    def _random_center(extent):
        """Pick a coordinate at least `r` from both edges, or the midpoint if the image is too small."""
        low, high = r, extent - r
        if high <= low:
            return extent // 2
        return int(rng.integers(low, high))

    def add_salt_pepper_patches(img_bytes):
        """Add salt and pepper patches to image and serialize back as jpeg."""
        if img_bytes is None:
            return None

        img = Image.open(io.BytesIO(img_bytes))
        draw = ImageDraw.Draw(img)
        w, h = img.size

        # random center point for the noise patch
        cx = _random_center(w)
        cy = _random_center(h)

        # 5 to 9 polygon vertices (upper bound is exclusive)
        num_points = int(rng.integers(5, 10))
        # Evenly spaced base angles, each jittered within its own sector so the
        # vertices stay in angular order.
        angles = np.linspace(0, 2*np.pi, num_points, endpoint=False)
        angles += rng.uniform(0, 2*np.pi/num_points, size=num_points)
        # Per-vertex distance from the centre gives the patch an irregular outline.
        radii = rng.uniform(0.5*r, 1.5*r, size=num_points)
        points = [
            (int(cx + radius*np.cos(angle)), int(cy + radius*np.sin(angle)))
            for angle, radius in zip(angles, radii)
        ]

        fill_color = (noise_value,) * 3 if img.mode == 'RGB' else noise_value
        draw.polygon(points, fill=fill_color)

        output = io.BytesIO()
        img.save(output, format='JPEG')

        return output.getvalue()

    return df_series.apply(add_salt_pepper_patches)


# The result is random; Spark treats UDFs as deterministic unless told otherwise
# and may then drop or repeat invocations during optimization.
add_salt_pepper_patches_udf = add_salt_pepper_patches_udf.asNondeterministic()

# Build the patched variant of the 'surfing' rows: replace the image bytes,
# append "_add_noise" after each run of digits in the name, and blank the path.
salt_pepper_df = (
    df
    .filter(col('label') == 'surfing')
    .withColumn(
        'image',
        add_salt_pepper_patches_udf('image').alias('image', metadata = img_meta),
    )
    .withColumn(
        'name',
        regexp_replace(col('name'), r"(\d+)", r"$1_add_noise"),
    )
    .withColumn('path', lit('N/A'))
)
display(salt_pepper_df)