In [0]:
%pip install "albumentations" "numpy<2.0"
dbutils.library.restartPython()

In [0]:
import numpy as np
import pandas as pd
import io
from PIL import Image

import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, StringType, BinaryType

import albumentations as A

In [0]:
# Mount root that holds both the input parquet (images_final) and the output (images_augmented).
mount_path = "/mnt/datamount"
# Sanity check that the mount is reachable; the listing is the cell's displayed output.
dbutils.fs.ls(mount_path)

In [0]:
# 1. Define Augmentation Pipelines
TARGET_RESIZE = 256  # side length (px) every pipeline below resizes its input to


def custom_solarize(image, threshold=128, **kwargs):
    """Solarize: Invert only pixels brighter than the threshold.

    Args:
        image: uint8 image array (any shape).
        threshold: pixels strictly greater than this value are inverted (255 - v);
            all others are kept as-is.
        **kwargs: ignored; accepted so the function can be used where extra
            keyword parameters are passed (e.g. an albumentations Lambda — not
            used by the pipelines in this cell).

    Returns:
        A uint8 array with the same shape as `image`.
    """
    bright = image > threshold
    inverted = 255 - image
    return np.where(bright, inverted, image).astype(np.uint8)


def custom_invert_img(image, **kwargs):
    """InvertImg: Invert all pixels (255 - v), returned as uint8. Extra kwargs are ignored."""
    inverted = 255 - image
    return inverted.astype(np.uint8)


# NOTE: several arguments below (CoarseDropout max_holes/min_holes/max_height/...,
# fill_value; RandomRain slant_lower/slant_upper; GaussNoise var_limit;
# ImageCompression quality_lower/quality_upper; PixelDropout fill_value) are
# albumentations 1.x names. Albumentations 2.x drops them with a warning and uses its
# own defaults, so these pipelines require albumentations<2.0 (or a migration).

# Strategy 1: Combination of weak transformations for 'normal' data.
# Compose applies the steps in order; each step fires independently with its own p.
normal_transform = A.Compose([
    A.Resize(TARGET_RESIZE, TARGET_RESIZE),
    A.Rotate(limit=10, p=0.7), # -10 to +10 degree rotation
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.7), # ~10% shift/scale, no extra rotation
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.8), # Adjust brightness/contrast
    A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=20, val_shift_limit=15, p=0.5),
    # Disabled: blur was judged too strong for 'normal' samples.
    # A.GaussianBlur(blur_limit=(3, 7), p=0.5), # Weak blur
    # A.MotionBlur(blur_limit=(3, 9), p=0.5)
])

# Strategy 2A: Geometric-only transformations for existing 'anomaly' data
# (no photometric changes, so the original defect appearance is preserved).
anomaly_transform_A = A.Compose([
    A.Resize(TARGET_RESIZE, TARGET_RESIZE),
    A.Rotate(limit=15, p=0.8),
    A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.15, rotate_limit=0, p=0.8),
])

# Strategy 2B: Strong transformations to turn 'normal' data into 'new anomalies'
anomaly_transform_B = A.Compose([
    A.Resize(TARGET_RESIZE, TARGET_RESIZE),
    
    # 1. Create 'blemishes'. With p=1.0 this OneOf always fires and applies exactly ONE
    #    child; per the albumentations OneOf docs the children's p values are normalized
    #    and only act as relative selection weights.
    A.OneOf([
        # 'Black spots/peeling' (patch-style): 1-8 black rectangles of 16-32 px
        A.CoarseDropout(
            max_holes=8, max_height=32, max_width=32,
            min_holes=1, min_height=16, min_width=16,
            fill_value=0, p=0.8
        ),
        # 'Cracks' or 'long scratches'
        A.RandomRain(slant_lower=-10, slant_upper=10, drop_length=20, drop_width=1, p=0.7),
        
        # 'Stains' or 'sensor noise' (var_limit is a variance in pixel units)
        A.GaussNoise(var_limit=(100.0, 200.0), p=0.7),
        
        # 'Compression artifacts' (JPEG quality 40-70)
        A.ImageCompression(quality_lower=40, quality_upper=70, p=0.5),

        # --- Salt & Pepper ---
        # 'Pepper' (black dots) simulation: ~3% of pixels set to 0
        A.PixelDropout(dropout_prob=0.03, p=0.7, fill_value=0),
        
        # 'Salt' (white dots) simulation: ~3% of pixels set to 255
        A.PixelDropout(dropout_prob=0.03, p=0.7, fill_value=255)

    ], p=1.0),

    # 2. Add 'disturbances' (optional): half the time, one of the two blurs
    #    (equal weights) is applied.
    A.OneOf([
        A.GaussianBlur(blur_limit=(3, 7), p=0.5),
        A.MotionBlur(blur_limit=(3, 9), p=0.5)
    ], p=0.5),
    
    # 3. Geometric transformation
    A.Rotate(limit=10, p=0.5)
])



# 2. Define Augmentation Quantities (Goal: ~800 normal, ~200 anomaly)
# Per-source-image copy counts, sized for the current dataset of 18 'normal'
# ('surfing') and 2 'anomaly' ('noise') originals:
#
#   N_PER_NORMAL      = ceil((800 - 18) / 18) = ceil(43.44) = 44
#                       -> 18 originals + 18 * 44 augmented = 810 normal
#   N_PER_ANOMALY_A   = 25   ((50 - 2) / 2 = 24 would give exactly 50; 25 adds a small margin)
#                       -> 2 originals + 2 * 25 augmented = 52 anomaly (Track A)
#   N_PER_NEW_ANOMALY = ceil(150 / 18) = ceil(8.33) = 9
#                       -> 18 * 9 = 162 synthetic anomalies made from normals (Track B)
#
# Total anomalies: 52 (Track A) + 162 (Track B) = 214
# NOTE: the 18 / 2 source counts are hard-coded; recompute these if the dataset changes.
N_PER_NORMAL, N_PER_ANOMALY_A, N_PER_NEW_ANOMALY = 44, 25, 9

In [0]:
# 3. Image Transformation Helper Functions (PIL & Numpy)

def bin_to_np(img_bin):
    """
    Decode raw image bytes (a Spark BinaryType value) into an RGB Numpy array.

    - Decoding is CPU-expensive, so callers should decode each source image once
      and reuse the array for all of its augmentations.
    - Any decoding failure is printed and reported as None instead of raising.
    """
    try:
        # The context manager closes the source handle; convert() returns an
        # independent, fully-loaded RGB image, so it remains usable afterwards.
        with Image.open(io.BytesIO(img_bin)) as decoded:
            rgb = decoded.convert('RGB')
        return np.array(rgb)
    except Exception as e:
        print(f"Error decoding image: {e}")
        return None

def np_to_bin(img_np):
    """
    Encode an RGB Numpy array (as returned by albumentations) into PNG bytes
    suitable for a Spark BinaryType column.

    PNG is lossless, so the augmented pixels are stored exactly as produced.
    Any encoding failure is printed and reported as None instead of raising.
    """
    buffer = io.BytesIO()
    try:
        # Image.fromarray takes the RGB channel order directly; no conversion needed.
        Image.fromarray(img_np).save(buffer, format='PNG')
    except Exception as e:
        print(f"Failed to encode image with PIL: {e}")
        return None
    return buffer.getvalue()


# 4. Pandas UDF (Generator) Definitions

# Define the schema that the UDFs will return
output_schema = StructType([
    StructField("name", StringType(), True),
    StructField("path", StringType(), True),
    StructField("label", StringType(), False),
    StructField("image", BinaryType(), False)
])

# Column labels/order for the pandas DataFrames the UDFs yield. Derived from the
# schema (instead of repeated by hand) so the two can never drift apart.
SCHEMA_COLUMNS = output_schema.fieldNames()

# UDF 1: For processing 'normal' images
def augment_normal_generator(pdf_iter):
    """
    mapInPandas generator that expands 'normal' images.

    'pdf_iter' yields pandas DataFrames, each one Arrow batch of a Spark partition
    (a partition may arrive as several batches). For every input row whose image
    decodes, one DataFrame with columns SCHEMA_COLUMNS is yielded, containing:
    1. The original row unchanged (original bytes, name, path and label).
    2. N_PER_NORMAL 'normal' augmentations (Strategy 1, `normal_transform`), keeping
       the original label, with path "N/A" and name "<name>_aug_normal_<i>".
    3. N_PER_NEW_ANOMALY 'new anomaly' augmentations (Strategy 2B,
       `anomaly_transform_B`), labelled "noise", with path "N/A" and name
       "<name>_aug_new_anomaly_<i>".

    Rows whose image cannot be decoded are skipped entirely (original included);
    individual augmentations that fail PNG encoding are dropped.
    """
    for pdf in pdf_iter:
        # Column-wise zip avoids building a pandas Series per row (as iterrows does).
        for name, path, label, img_bin in zip(pdf['name'], pdf['path'], pdf['label'], pdf['image']):
            # Decode once and reuse the array for every augmentation of this image.
            img_np = bin_to_np(img_bin)
            if img_np is None:
                continue

            # 1. The original image, passed through untouched.
            output_rows = [(name, path, label, img_bin)]

            # 2. 'Normal' augmentation (Strategy 1)
            for i in range(N_PER_NORMAL):
                aug_bin = np_to_bin(normal_transform(image=img_np)['image'])
                if aug_bin:
                    output_rows.append((f"{name}_aug_normal_{i}", "N/A", label, aug_bin))

            # 3. 'New Anomaly' generation (Strategy 2B)
            for i in range(N_PER_NEW_ANOMALY):
                aug_bin = np_to_bin(anomaly_transform_B(image=img_np)['image'])
                if aug_bin:
                    output_rows.append((f"{name}_aug_new_anomaly_{i}", "N/A", "noise", aug_bin))

            # Yield per source image rather than per input batch: each image fans out to
            # 1 + N_PER_NORMAL + N_PER_NEW_ANOMALY encoded images, so buffering a whole
            # batch (up to spark.sql.execution.arrow.maxRecordsPerBatch rows) can exhaust
            # worker memory and overflow the 32-bit offsets of an Arrow binary column.
            yield pd.DataFrame(output_rows, columns=SCHEMA_COLUMNS)

# UDF 2: For processing 'anomaly' images
def augment_anomaly_generator(pdf_iter):
    """
    mapInPandas generator that expands existing 'anomaly' images (Strategy 2A).

    'pdf_iter' yields pandas DataFrames, each one Arrow batch of a Spark partition
    (a partition may arrive as several batches, not as one DataFrame). For every
    input row whose image decodes, one DataFrame with columns SCHEMA_COLUMNS is
    yielded, containing:
    1. The original anomaly row unchanged (original bytes, name, path and label).
    2. N_PER_ANOMALY_A geometric augmentations (`anomaly_transform_A`), keeping the
       original label, with path "N/A" and name "<name>_aug_anomaly_<i>".

    Rows whose image cannot be decoded are skipped entirely (original included);
    individual augmentations that fail PNG encoding are dropped.
    """
    for pdf in pdf_iter:
        # Column-wise zip avoids building a pandas Series per row (as iterrows does).
        for name, path, label, img_bin in zip(pdf['name'], pdf['path'], pdf['label'], pdf['image']):
            # Decode once and reuse the array for every augmentation of this image.
            img_np = bin_to_np(img_bin)
            if img_np is None:
                continue

            # 1. The original anomaly image, passed through untouched.
            output_rows = [(name, path, label, img_bin)]

            # 2. 'Anomaly' augmentation (Strategy 2A)
            for i in range(N_PER_ANOMALY_A):
                aug_bin = np_to_bin(anomaly_transform_A(image=img_np)['image'])
                if aug_bin:
                    output_rows.append((f"{name}_aug_anomaly_{i}", "N/A", label, aug_bin))

            # Yield per source image rather than per input batch: each image fans out to
            # 1 + N_PER_ANOMALY_A encoded images, so buffering a whole batch (up to
            # spark.sql.execution.arrow.maxRecordsPerBatch rows) can exhaust worker memory
            # and overflow the 32-bit offsets of an Arrow binary column.
            yield pd.DataFrame(output_rows, columns=SCHEMA_COLUMNS)



In [0]:
# 5. Execute Spark Job
print("Loading original DataFrame...")
df = spark.read.format('parquet').load(f'{mount_path}/images_final')


# Separate data and cache (each source split is read by one count and one UDF pass)
df_normal = df.filter(F.col('label') == 'surfing').cache()
df_anomaly = df.filter(F.col('label') == 'noise').cache()

print(f"Original Normal Count: {df_normal.count()}, Original Anomaly Count: {df_anomaly.count()}")

print("Starting Augmentation (UDF 1: Normal -> Normal + New Anomaly)...")
df_aug_1 = df_normal.mapInPandas(augment_normal_generator, schema=output_schema)

print("Starting Augmentation (UDF 2: Anomaly -> Anomaly)...")
df_aug_2 = df_anomaly.mapInPandas(augment_anomaly_generator, schema=output_schema)

# Combine the two results and cache them. The UDFs are expensive and random: without a
# cache, every later action (the count below, display(), the parquet write) re-runs both
# UDFs and produces a *different* set of augmented images, so what is inspected would not
# be what is saved. Call df_augmented.unpersist() once the final write has finished.
df_augmented = df_aug_1.unionByName(df_aug_2).cache()

# Check the results (this action also fully materializes the df_augmented cache)
print("Augmentation complete. Final result aggregation:")
df_augmented.groupBy('label').count().show()

# Release the source caches. df_augmented is already materialized above, so Spark keeps
# serving it from its own cache instead of recomputing it from the (now uncached) inputs.
df_normal.unpersist()
df_anomaly.unpersist()

In [0]:
# Attach a content annotation to the binary column so Databricks' display() can
# render it as an image preview rather than raw bytes.
# NOTE(review): augmented rows are PNG (np_to_bin saves format='PNG') while originals keep
# their source encoding, yet the annotation declares image/jpeg for every row. Browsers
# usually sniff the real format, but confirm previews render and consider a matching type.
img_meta = dict()
img_meta["spark.contentAnnotation"] = '{"mimeType":"image/jpeg"}'

df_augmented = df_augmented.withColumn(
    'image',
    F.col('image').alias('image', metadata=img_meta),
)
display(df_augmented, limit = 10)

In [0]:
# Persist originals + augmentations as parquet, replacing any output from a previous run.
df_augmented.write.format('parquet').mode("overwrite").save(f'{mount_path}/images_augmented')