In [0]:
# Root of the mounted storage; the cells below build their paths from it.
mount_path = "/mnt/datamount"
# Bare last expression: the listing of the mount root is shown as the cell output.
dbutils.fs.ls(mount_path)

# Process Images

In [0]:
from pyspark.sql.functions import regexp_replace
import os
import matplotlib.pyplot as plt
from PIL import Image
import pyspark.sql.functions as F

In [0]:
def create_file_info_df(imgs_dir: str):
    """
    List a directory under ``mount_path`` and return it as a Spark DataFrame.

    Args:
        imgs_dir (str): Directory name relative to ``mount_path``.

    Returns:
        DataFrame: One row per listed entry with the string columns
        ``name`` and ``path`` (in that order). An empty directory yields an
        empty DataFrame with the same schema.

    Raises:
        Exception: If the directory cannot be listed. The underlying error
        is attached as ``__cause__``.
    """
    full_path = f"{mount_path}/{imgs_dir}"
    try:
        files = dbutils.fs.ls(full_path)
    except Exception as exc:
        # Chain the original error so the real reason stays visible in the traceback.
        raise Exception(f"Directory {full_path} does not exist") from exc

    # Explicit schema: schema inference fails on an empty listing, and the
    # column order (name, path) matches what dict-based inference produced.
    rows = [(file.name, file.path) for file in files]
    return spark.createDataFrame(rows, schema="name STRING, path STRING")

# List the "images" directory and strip the file extension from 'name'
# (the regex removes the last dot and everything after it).
file_info_df = create_file_info_df("images").withColumn(
    "name", regexp_replace("name", "\\.[^.]+$", "")
)
display(file_info_df)

In [0]:
# List the labels directory; only the first listed entry is read as the label CSV.
label_files = dbutils.fs.ls(f"{mount_path}/labels")
label_path = label_files[0].path
# The CSV has a header row; column types are inferred from the data.
label_df = spark.read.options(header=True, inferSchema=True).csv(label_path)
display(label_df)

In [0]:
# Rename the label columns: 'img_name' -> 'name', 'sports_activity' -> 'label'.
label_df = (
    label_df
    .withColumnRenamed("img_name", "name")
    .withColumnRenamed("sports_activity", "label")
)
display(label_df)

In [0]:
# Left join on 'name': every listed file is kept, with label columns attached where a match exists.
joined_df = file_info_df.join(label_df, "name", "left")
display(joined_df)

In [0]:
def display_image(df, sample_size: int = 5, keep_order = True):
    """
    Display sample images from a Spark DataFrame containing image file paths.

    Args:
        df (DataFrame): Spark DataFrame with at least 'path' and 'name' columns.
        sample_size (int): Maximum number of images to display.
        keep_order (bool): If True, display the first N images in order.
                           If False, display N random images.

    Returns:
        None. Displays images using matplotlib. Nothing is drawn when the
        DataFrame yields no rows.
    """
    if keep_order:
        rows = df.take(sample_size)  # First N rows, already a Python list.
    else:
        # limit() returns a DataFrame, so collect() is needed to get a list of rows.
        rows = df.orderBy(F.rand()).limit(sample_size).collect()

    if not rows:
        # Avoid creating an empty figure when there is nothing to show.
        return

    plt.figure(figsize=(10, 10))
    # Size the grid by the rows actually returned, which may be fewer than sample_size.
    n_cols = len(rows)
    for i, row in enumerate(rows):
        # Image.open needs a local filesystem path; a "dbfs:/..." URI is otherwise
        # treated as a relative path and fails with FileNotFoundError.
        img_path = row.path.replace("dbfs:/", "/dbfs/")
        plt.subplot(1, n_cols, i + 1)
        plt.title(row.name)
        plt.axis("off")
        # Context manager closes the underlying file once the image has been drawn.
        with Image.open(img_path) as img:
            plt.imshow(img)


In [0]:
# Show the first three images of the joined DataFrame, in order.
display_image(df=joined_df, sample_size=3, keep_order=True)

In [0]:
# Preview the preprocessing (center crop + resize) on a single image.
# first() fetches one row instead of collecting the whole DataFrame to the driver.
row = joined_df.first()
if row is None:
    print("joined_df is empty - nothing to preview")
else:
    img_dir = row.path.replace("dbfs:/", "/dbfs/")
    # crop() returns a new image, so the source file can be closed right after it.
    with Image.open(img_dir) as src:
        w, h = src.size
        new_size = min(w, h)
        # center crop (x1, y1, x2, y2)
        img = src.crop(((w-new_size)//2, (h-new_size)//2, (w+new_size)//2, (h+new_size)//2))
    print(f"{img.size = }")

    # resize to 256 x 256; nearest-neighbour resampling is used whether the
    # image is enlarged or shrunk
    img = img.resize((256, 256), Image.NEAREST)
    print(f"{img.size = }")
    plt.imshow(img)

In [0]:
import io
import pandas as pd
from pyspark.sql.functions import pandas_udf, col, regexp_replace, lit
from pyspark.sql.types import BinaryType

# Edge length (pixels) of the square output images.
IMG_SIZE = 256

# The UDF returns BinaryType: the JPEG-encoded bytes of each processed image.
@pandas_udf(BinaryType())
def resize_image_udf(img_paths: pd.Series) -> pd.Series:
    """
    Center-crop each image to a square, resize it to IMG_SIZE x IMG_SIZE and
    return it re-encoded as JPEG bytes.

    Args:
        img_paths (pd.Series): Image paths; a "dbfs:/" prefix is rewritten to
            "/dbfs/" so the file can be opened through the local filesystem.

    Returns:
        pd.Series: JPEG bytes, one element per input path.
    """
    # Modes Pillow's JPEG encoder can write as-is; anything else (e.g. RGBA or
    # palette images coming from PNG files) is converted to RGB before saving,
    # because saving such modes as JPEG raises an error.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def preprocess(path):
        "resize image and serialize back as jpg"
        # load image; crop() returns a new image, so the file can be closed afterwards
        with Image.open(path.replace("dbfs:/", "/dbfs/")) as img:
            w, h = img.size
            new_size = min(w, h)
            # center crop
            cropped = img.crop(((w-new_size)//2, (h-new_size)//2, (w+new_size)//2, (h+new_size)//2))
        # resize
        resized = cropped.resize((IMG_SIZE, IMG_SIZE), Image.NEAREST)
        if resized.mode not in jpeg_modes:
            resized = resized.convert("RGB")
        output = io.BytesIO()
        resized.save(output, format="JPEG")
        return output.getvalue()
    return img_paths.apply(preprocess)

# Column metadata used to enable the image preview of the binary column.
img_meta = {"spark.contentAnnotation": '{"mimeType":"image/jpeg"}'}

# Step 1: 'image' holds the resized JPEG bytes computed from 'path'.
df = joined_df.withColumn("image", resize_image_udf("path"))
# Step 2: re-alias 'image' with the metadata so it is shown as a thumbnail.
df = df.withColumn("image", col("image").alias("image", metadata=img_meta))

display(df)



In [0]:
# Persist the resized images as Parquet under the mount, replacing any previous output.
df.write.parquet(f"{mount_path}/resized", mode="overwrite")

In [0]:
@pandas_udf(BinaryType())
def flip_image_horizontally_udf(df_series: pd.Series) -> pd.Series:
    """Mirror each encoded image left-to-right and return it re-encoded as JPEG bytes."""
    def _mirror(encoded):
        # Decode from the in-memory bytes, flip, then encode back to JPEG.
        source = Image.open(io.BytesIO(encoded))
        mirrored = source.transpose(Image.FLIP_LEFT_RIGHT)
        buffer = io.BytesIO()
        mirrored.save(buffer, format="JPEG")
        return buffer.getvalue()

    return df_series.apply(_mirror)

# Augmented copy of df: mirrored image, "_flipped" inserted after each digit
# run in 'name', and 'path' set to a placeholder.
df_flipped = (
    df
    .withColumn("image", flip_image_horizontally_udf("image").alias("image", metadata=img_meta))
    .withColumn("name", regexp_replace("name", r"(\d+)", r"$1_flipped"))
    .withColumn("path", lit("N/A"))
)
display(df_flipped)

In [0]:
# Append the flipped images to the same Parquet location as the resized ones.
df_flipped.write.parquet(f"{mount_path}/resized", mode="append")

In [0]:
# Read back everything written under "resized".
new_df = spark.read.parquet(f"{mount_path}/resized")
display(new_df)

In [0]:
# Files under "noisy_images", all labelled "noise", with a resized preview-enabled 'image' column.
# NOTE(review): unlike file_info_df, 'name' keeps its file extension here - confirm this is intended.
noise_df = (
    create_file_info_df("noisy_images")
    .withColumn("label", lit("noise"))
    .withColumn("image", resize_image_udf("path"))
    .withColumn("image", col("image").alias("image", metadata=img_meta))
)
display(noise_df)

In [0]:
# Mirrored copy of noise_df: flipped image, "_noisy_flipped" inserted after each
# digit run in 'name', and 'path' set to a placeholder.
flipped_noise_df = (
    noise_df
    .withColumn("image", flip_image_horizontally_udf("image").alias("image", metadata=img_meta))
    .withColumn("name", regexp_replace("name", r"(\d+)", r"$1_noisy_flipped"))
    .withColumn("path", lit("N/A"))
)
display(flipped_noise_df)

In [0]:
from functools import reduce
from pyspark.sql import DataFrame

# Stack the three DataFrames into one. Columns are matched by name rather than
# by position: 'name', 'path' and 'label' are all strings, so a positional
# union would silently mix them up if the column order ever differed.
final_df = reduce(DataFrame.unionByName, [new_df, noise_df, flipped_noise_df])
display(final_df)

In [0]:
# Persist the combined dataset as Parquet, replacing any previous output.
final_df.write.parquet(f"{mount_path}/images_final", mode="overwrite")

In [0]:
# Number of rows per label in the final dataset.
label_counts = final_df.groupBy("label").count()
display(label_counts)